diff --git a/CHANGELOG.md b/CHANGELOG.md index 11125930d1f..e214b5ede1f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added - Added a way to resize the vocabulary in the T5 module +- Added an argument `reinit_modules` to `cached_transformers.get()` that allows you to re-initialize the pretrained weights of a transformer model, using layer indices or regex strings. ### Fixed diff --git a/allennlp/common/cached_transformers.py b/allennlp/common/cached_transformers.py index bc7cc4a6dfd..3177faa2ab8 100644 --- a/allennlp/common/cached_transformers.py +++ b/allennlp/common/cached_transformers.py @@ -1,10 +1,11 @@ import logging +import re import warnings -from typing import NamedTuple, Optional, Dict, Tuple +from typing import Dict, NamedTuple, Optional, Tuple, Union, cast import transformers -from transformers import AutoModel, AutoConfig - +from allennlp.common.checks import ConfigurationError +from transformers import AutoConfig, AutoModel logger = logging.getLogger(__name__) @@ -13,6 +14,7 @@ class TransformerSpec(NamedTuple): model_name: str override_weights_file: Optional[str] = None override_weights_strip_prefix: Optional[str] = None + reinit_modules: Optional[Union[int, Tuple[int, ...], Tuple[str, ...]]] = None _model_cache: Dict[TransformerSpec, transformers.PreTrainedModel] = {} @@ -23,6 +25,7 @@ def get( make_copy: bool, override_weights_file: Optional[str] = None, override_weights_strip_prefix: Optional[str] = None, + reinit_modules: Optional[Union[int, Tuple[int, ...], Tuple[str, ...]]] = None, load_weights: bool = True, **kwargs, ) -> transformers.PreTrainedModel: @@ -43,13 +46,28 @@ def get( with `torch.save()`. override_weights_strip_prefix : `str`, optional (default = `None`) If set, strip the given prefix from the state dict when loading it. 
+ reinit_modules: `Optional[Union[int, Tuple[int, ...], Tuple[str, ...]]]`, optional (default = `None`) + If this is an integer, the last `reinit_modules` layers of the transformer will be + re-initialized. If this is a tuple of integers, the layers indexed by `reinit_modules` will + be re-initialized. Note, because the module structure of the transformer `model_name` can + differ, we cannot guarantee that providing an integer or tuple of integers will work. If + this fails, you can instead provide a tuple of strings, which will be treated as regexes and + any module with a name matching the regex will be re-initialized. Re-initializing the last + few layers of a pretrained transformer can reduce the instability of fine-tuning on small + datasets and may improve performance (https://arxiv.org/abs/2006.05987v3). Has no effect + if `load_weights` is `False` or `override_weights_file` is not `None`. load_weights : `bool`, optional (default = `True`) If set to `False`, no weights will be loaded. This is helpful when you only want to initialize the architecture, like when you've already fine-tuned a model and are going to load the weights from a state dict elsewhere. 
""" global _model_cache - spec = TransformerSpec(model_name, override_weights_file, override_weights_strip_prefix) + spec = TransformerSpec( + model_name, + override_weights_file, + override_weights_strip_prefix, + reinit_modules, + ) transformer = _model_cache.get(spec, None) if transformer is None: if not load_weights: @@ -59,6 +77,12 @@ def get( "but 'load_weights' is set to False, so 'override_weights_file' will be ignored.", UserWarning, ) + if reinit_modules is not None: + warnings.warn( + "You specified 'reinit_modules' in allennlp.common.cached_transformers.get(), " + "but 'load_weights' is set to False, so 'reinit_modules' will be ignored.", + UserWarning, + ) transformer = AutoModel.from_config( AutoConfig.from_pretrained( model_name, @@ -66,8 +90,14 @@ def get( ) ) elif override_weights_file is not None: - from allennlp.common.file_utils import cached_path + if reinit_modules is not None: + warnings.warn( + "You specified 'reinit_modules' in allennlp.common.cached_transformers.get(), " + "but 'override_weights_file' is not None, so 'reinit_modules' will be ignored.", + UserWarning, + ) import torch + from allennlp.common.file_utils import cached_path override_weights_file = cached_path(override_weights_file) override_weights = torch.load(override_weights_file) @@ -110,6 +140,42 @@ def strip_prefix(s): transformer.module.load_state_dict(override_weights) else: transformer.load_state_dict(override_weights) + elif reinit_modules is not None: + transformer = AutoModel.from_pretrained( + model_name, + **kwargs, + ) + num_layers = transformer.config.num_hidden_layers + if isinstance(reinit_modules, int): + reinit_modules = tuple(range(num_layers - reinit_modules, num_layers)) + if all(isinstance(x, int) for x in reinit_modules): + # This type cast is necessary to avoid a mypy error. 
+ reinit_modules = cast(Tuple[int, ...], reinit_modules) + if any(layer_idx < 0 or layer_idx >= num_layers for layer_idx in reinit_modules): + raise ValueError( + f"A layer index in reinit_modules ({reinit_modules}) is invalid." + f" Must be between 0 and the maximum layer index ({num_layers - 1})." + ) + # Some transformer models organize their modules differently, so if this fails, + # raise an error with a helpful message. + try: + for layer_idx in reinit_modules: + transformer.encoder.layer[layer_idx].apply(transformer._init_weights) + except AttributeError: + raise ConfigurationError( + f"Unable to re-initialize the layers of transformer model" + f" {model_name} using layer indices. Please provide a tuple of" + " strings corresponding to the names of the layers to re-initialize." + ) + elif all(isinstance(x, str) for x in reinit_modules): + for regex in reinit_modules: + for name, module in transformer.named_modules(): + if re.search(regex, name): + module.apply(transformer._init_weights) + else: + raise ValueError( + "reinit_modules must be either an integer, a tuple of strings, or a tuple of integers." 
+ ) else: transformer = AutoModel.from_pretrained( model_name, diff --git a/allennlp/modules/token_embedders/pretrained_transformer_embedder.py b/allennlp/modules/token_embedders/pretrained_transformer_embedder.py index 6275f86de70..6e819093c65 100644 --- a/allennlp/modules/token_embedders/pretrained_transformer_embedder.py +++ b/allennlp/modules/token_embedders/pretrained_transformer_embedder.py @@ -1,16 +1,14 @@ import logging import math -from typing import Optional, Tuple, Dict, Any - +from typing import Any, Dict, Optional, Tuple, Union import torch import torch.nn.functional as F -from transformers import XLNetConfig - from allennlp.data.tokenizers import PretrainedTransformerTokenizer from allennlp.modules.scalar_mix import ScalarMix from allennlp.modules.token_embedders.token_embedder import TokenEmbedder from allennlp.nn.util import batched_index_select +from transformers import XLNetConfig logger = logging.getLogger(__name__) @@ -54,6 +52,16 @@ class PretrainedTransformerEmbedder(TokenEmbedder): with `torch.save()`. override_weights_strip_prefix: `Optional[str]`, optional (default = `None`) If set, strip the given prefix from the state dict when loading it. + reinit_modules: `Optional[Union[int, Tuple[int, ...], Tuple[str, ...]]]`, optional (default = `None`) + If this is an integer, the last `reinit_modules` layers of the transformer will be + re-initialized. If this is a tuple of integers, the layers indexed by `reinit_modules` will + be re-initialized. Note, because the module structure of the transformer `model_name` can + differ, we cannot guarantee that providing an integer or tuple of integers will work. If + this fails, you can instead provide a tuple of strings, which will be treated as regexes and + any module with a name matching the regex will be re-initialized. 
Re-initializing the last + few layers of a pretrained transformer can reduce the instability of fine-tuning on small + datasets and may improve performance (https://arxiv.org/abs/2006.05987v3). Has no effect + if `load_weights` is `False` or `override_weights_file` is not `None`. load_weights: `bool`, optional (default = `True`) Whether to load the pretrained weights. If you're loading your model/predictor from an AllenNLP archive it usually makes sense to set this to `False` (via the `overrides` parameter) @@ -84,6 +92,7 @@ def __init__( last_layer_only: bool = True, override_weights_file: Optional[str] = None, override_weights_strip_prefix: Optional[str] = None, + reinit_modules: Optional[Union[int, Tuple[int, ...], Tuple[str, ...]]] = None, load_weights: bool = True, gradient_checkpointing: Optional[bool] = None, tokenizer_kwargs: Optional[Dict[str, Any]] = None, @@ -97,6 +106,7 @@ def __init__( True, override_weights_file=override_weights_file, override_weights_strip_prefix=override_weights_strip_prefix, + reinit_modules=reinit_modules, load_weights=load_weights, **(transformer_kwargs or {}), ) diff --git a/tests/common/cached_transformers_test.py b/tests/common/cached_transformers_test.py index e0a8bee054b..409da5bb8be 100644 --- a/tests/common/cached_transformers_test.py +++ b/tests/common/cached_transformers_test.py @@ -1,12 +1,12 @@ -import pytest -import torch -import os import json +import os +import pytest +import torch from allennlp.common import cached_transformers +from allennlp.common.checks import ConfigurationError from allennlp.common.testing import AllenNlpTestCase - -from transformers import AutoModel, AutoConfig +from transformers import AutoConfig, AutoModel class TestCachedTransformers(AllenNlpTestCase): @@ -72,6 +72,99 @@ def test_from_pretrained_avoids_weights_download_if_override_weights(self): for p1, p2 in zip(transformer.parameters(), override_transformer.parameters()): assert p1.data.ne(p2.data).sum() == 0 + def 
test_reinit_modules_no_op(self): + # Test the case where reinit_modules is None (default) + preinit_weights = torch.cat( + [ + # Comparing all weights of the model is rather complicated, so arbitrarily + # compare the weights of attention module. + layer.attention.output.dense.weight + for layer in cached_transformers.get("bert-base-cased", True).encoder.layer + ] + ) + postinit_weights = torch.cat( + [ + layer.attention.output.dense.weight + for layer in cached_transformers.get("bert-base-cased", True).encoder.layer + ] + ) + assert torch.equal(postinit_weights, preinit_weights) + + def test_reinit_modules_with_layer_indices(self): + # Comparing all weights of the model is rather complicated, so arbitrarily compare the + # weights of the attention module, stacked so that dim 0 indexes the layers. + preinit_weights = torch.stack( + [ + layer.attention.output.dense.weight + for layer in cached_transformers.get("bert-base-cased", True).encoder.layer + ] + ) + + # Test the case when reinit_modules is a valid int. + postinit_weights = torch.stack( + [ + layer.attention.output.dense.weight + for layer in cached_transformers.get( + "bert-base-cased", True, reinit_modules=2 + ).encoder.layer + ] + ) + assert torch.equal(postinit_weights[:10], preinit_weights[:10]) + assert not torch.equal(postinit_weights[10:], preinit_weights[10:]) + + # Test the case when reinit_modules is a valid tuple of integers. 
+ postinit_weights = torch.stack( + [ + layer.attention.output.dense.weight + for layer in cached_transformers.get( + "bert-base-cased", True, reinit_modules=(10, 11) + ).encoder.layer + ] + ) + assert torch.equal(postinit_weights[:10], preinit_weights[:10]) + assert not torch.equal(postinit_weights[10:], preinit_weights[10:]) + + # Should raise a ValueError because reinit_modules contains at least one index that is + # greater than the model's maximum number of layers + with pytest.raises(ValueError): + _ = cached_transformers.get("bert-base-cased", True, reinit_modules=1000) + with pytest.raises(ValueError): + _ = cached_transformers.get("bert-base-cased", True, reinit_modules=(1, 1000)) + # The argument cannot mix layer indices and regex strings. + with pytest.raises(ValueError): + _ = cached_transformers.get("bert-base-cased", True, reinit_modules=(1, "attentions")) + # This model has a non-standard structure, so if a layer index or tuple of layer indices + # is provided (even a valid one), we raise a ConfigurationError. + with pytest.raises(ConfigurationError): + _ = cached_transformers.get("sshleifer/tiny-gpt2", True, reinit_modules=1) + with pytest.raises(ConfigurationError): + _ = cached_transformers.get("sshleifer/tiny-gpt2", True, reinit_modules=(0, 1)) + + def test_reinit_modules_with_regex_strings(self): + # Comparing all weights of the model is rather complicated, so arbitrarily compare the + # weights of wpe module. + reinit_module = "wpe" + # This MUST be a deep copy, otherwise the parameters will be re-initialized and the + # test will break. 
+ preinit_weights = list( + cached_transformers.get("sshleifer/tiny-gpt2", True) + .get_submodule(reinit_module) + .parameters() + ) + + postinit_weights = list( + cached_transformers.get( + "sshleifer/tiny-gpt2", + True, + reinit_modules=(reinit_module,), + ) + .get_submodule(reinit_module) + .parameters() + ) + assert all( + (not torch.equal(pre, post) for pre, post in zip(preinit_weights, postinit_weights)) + ) + def test_from_pretrained_no_load_weights(self): _ = cached_transformers.get( "epwalsh/bert-xsmall-dummy", False, load_weights=False, cache_dir=self.TEST_DIR