From fe91837522d33da45ea134e2cdb0aad7c0840df1 Mon Sep 17 00:00:00 2001 From: johngiorgi Date: Fri, 10 Dec 2021 13:40:16 -0500 Subject: [PATCH 01/17] Arg to re-init some layers of pretrained transformer --- .../pretrained_transformer_embedder.py | 25 ++++++++++++++++++- 1 file changed, 24 insertions(+), 1 deletion(-) diff --git a/allennlp/modules/token_embedders/pretrained_transformer_embedder.py b/allennlp/modules/token_embedders/pretrained_transformer_embedder.py index 6ef220fbcca..40589c61502 100644 --- a/allennlp/modules/token_embedders/pretrained_transformer_embedder.py +++ b/allennlp/modules/token_embedders/pretrained_transformer_embedder.py @@ -1,6 +1,6 @@ import logging import math -from typing import Optional, Tuple, Dict, Any +from typing import Optional, Tuple, Dict, Any, Union, List from overrides import overrides @@ -49,6 +49,12 @@ class PretrainedTransformerEmbedder(TokenEmbedder): When `True` (the default), only the final layer of the pretrained transformer is taken for the embeddings. But if set to `False`, a scalar mix of all of the layers is used. + reinit_layers: `Optional[Union[int, List[int]]]`, optional (default = `None`) + If this is an integer, the last `reinit_layers` layers of the transformer will be + re-initialized. If this is a list, the layers indexed by `reinit_layers` will be + re-initialized. Re-initializing the last few layers of a pretrained transformer can reduce + the instability of fine-tuning on small datasets and may improve performance + (https://arxiv.org/abs/2006.05987v3). Has no effect if `load_weights` is `False`. override_weights_file: `Optional[str]`, optional (default = `None`) If set, this specifies a file from which to load alternate weights that override the weights from huggingface. The file is expected to contain a PyTorch `state_dict`, created @@ -83,6 +89,7 @@ def __init__( train_parameters: bool = True, eval_mode: bool = False, last_layer_only: bool = True, + reinit_layers: Optional[Union[int, List[int]]] = None, override_weights_file: Optional[str] = None, override_weights_strip_prefix: Optional[str] = None, load_weights: bool = True, @@ -120,6 +127,22 @@ def __init__( self._scalar_mix = ScalarMix(self.config.num_hidden_layers) self.config.output_hidden_states = True + # Optionally, re-initialize the parameters of certain layers. + self._reinit_layers = reinit_layers + if self._reinit_layers and load_weights: + num_layers = len(self.transformer_model.encoder.layer) + if isinstance(reinit_layers, int): + self._reinit_layers = list(range(num_layers - reinit_layers, num_layers)) + if any(layer_idx > num_layers for layer_idx in self._reinit_layers): + raise ValueError( + f"A layer index in reinit_layers ({self._reinit_layers}) is larger than the" + f" maximum layer index {num_layers - 1}." + ) + for layer_idx in self._reinit_layers: + self.transformer_model.encoder.layer[layer_idx].apply( + self.transformer_model._init_weights + ) + tokenizer = PretrainedTransformerTokenizer( model_name, tokenizer_kwargs=tokenizer_kwargs, From c75bc562e062cacf1eae3958b62637055da7856f Mon Sep 17 00:00:00 2001 From: johngiorgi Date: Fri, 10 Dec 2021 14:45:41 -0500 Subject: [PATCH 02/17] Better error handelling for bad indices --- .../token_embedders/pretrained_transformer_embedder.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/allennlp/modules/token_embedders/pretrained_transformer_embedder.py b/allennlp/modules/token_embedders/pretrained_transformer_embedder.py index 44127349502..dee7b718c1d 100644 --- a/allennlp/modules/token_embedders/pretrained_transformer_embedder.py +++ b/allennlp/modules/token_embedders/pretrained_transformer_embedder.py @@ -132,10 +132,10 @@ def __init__( num_layers = len(self.transformer_model.encoder.layer) if isinstance(reinit_layers, int): self._reinit_layers = list(range(num_layers - reinit_layers, num_layers)) - if any(layer_idx > num_layers for layer_idx in self._reinit_layers): + if any(layer_idx < 0 or layer_idx > num_layers for layer_idx in self._reinit_layers): raise ValueError( - f"A layer index in reinit_layers ({self._reinit_layers}) is larger than the" - f" maximum layer index {num_layers - 1}." + f"A layer index in reinit_layers ({self._reinit_layers}) is invalid. Must be" + f" between 0 and the maximum layer index ({num_layers - 1}.)" ) for layer_idx in self._reinit_layers: self.transformer_model.encoder.layer[layer_idx].apply( From 69fef3675c1a807a2058c10eb344ec5da8da0ef5 Mon Sep 17 00:00:00 2001 From: johngiorgi Date: Fri, 10 Dec 2021 14:45:52 -0500 Subject: [PATCH 03/17] Add a test for reinit_layers --- .../pretrained_transformer_embedder_test.py | 35 +++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/tests/modules/token_embedders/pretrained_transformer_embedder_test.py b/tests/modules/token_embedders/pretrained_transformer_embedder_test.py index 6cd14a1cf9e..580fcac6d15 100644 --- a/tests/modules/token_embedders/pretrained_transformer_embedder_test.py +++ b/tests/modules/token_embedders/pretrained_transformer_embedder_test.py @@ -341,6 +341,41 @@ def test_embeddings_resize(self): == 28997 ) + def test_reinit_layers(self): + regular_token_embedder = PretrainedTransformerEmbedder("bert-base-cased") + assert regular_token_embedder._reinit_layers is None + # Test the case when reinit_layers is a valid int. Comparing all weights of the model is + # rather complicated, so arbitrarily compare the weights of attention module. + preinit_weights = torch.cat( + [ + layer.attention.output.dense.weight + for layer in regular_token_embedder.transformer_model.encoder.layer + ] + ) + reinit_token_embedder = PretrainedTransformerEmbedder("bert-base-cased", reinit_layers=2) + postinit_weights = torch.cat( + [ + layer.attention.output.dense.weight + for layer in reinit_token_embedder.transformer_model.encoder.layer + ] + ) + assert reinit_token_embedder._reinit_layers == [10, 11] + assert torch.equal(postinit_weights[:10], preinit_weights[:10]) + assert not torch.equal(postinit_weights[10:], preinit_weights[10:]) + # Test the case when reinit_layers is a valid list of integers. + reinit_token_embedder = PretrainedTransformerEmbedder( + "bert-base-cased", reinit_layers=[10, 11] + ) + assert reinit_token_embedder._reinit_layers == [10, 11] + assert torch.equal(postinit_weights[:10], preinit_weights[:10]) + assert not torch.equal(postinit_weights[10:], preinit_weights[10:]) + # Should raise a ValueError because reinit_layers contains at least one index that is + # greater than the models maximum number of layers + with pytest.raises(ValueError): + _ = PretrainedTransformerEmbedder("bert-base-cased", reinit_layers=1000) + with pytest.raises(ValueError): + _ = PretrainedTransformerEmbedder("bert-base-cased", reinit_layers=[1, 1000]) + def test_eval_mode(self): token_embedder = PretrainedTransformerEmbedder("epwalsh/bert-xsmall-dummy", eval_mode=True) assert token_embedder.training and not token_embedder.transformer_model.training From e6799ca2abcf808b11f978f628e53ca594d0853f Mon Sep 17 00:00:00 2001 From: johngiorgi Date: Fri, 10 Dec 2021 14:54:42 -0500 Subject: [PATCH 04/17] Type cast to appease mypy --- .../token_embedders/pretrained_transformer_embedder.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/allennlp/modules/token_embedders/pretrained_transformer_embedder.py b/allennlp/modules/token_embedders/pretrained_transformer_embedder.py index dee7b718c1d..48f2e063e9a 100644 --- a/allennlp/modules/token_embedders/pretrained_transformer_embedder.py +++ b/allennlp/modules/token_embedders/pretrained_transformer_embedder.py @@ -1,6 +1,6 @@ import logging import math -from typing import Optional, Tuple, Dict, Any, Union, List +from typing import Optional, Tuple, Dict, Any, Union, List, cast import torch @@ -127,7 +127,7 @@ def __init__( self.config.output_hidden_states = True # Optionally, re-initialize the parameters of certain layers. - self._reinit_layers = reinit_layers + self._reinit_layers = cast(List[int], reinit_layers) if self._reinit_layers and load_weights: num_layers = len(self.transformer_model.encoder.layer) if isinstance(reinit_layers, int): From 5e55ad055f0ec5e8d4d11aef605229c7d9a438c3 Mon Sep 17 00:00:00 2001 From: johngiorgi Date: Wed, 15 Dec 2021 12:44:34 -0500 Subject: [PATCH 05/17] Get number of hidden layers in model agnostic way --- .../modules/token_embedders/pretrained_transformer_embedder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/allennlp/modules/token_embedders/pretrained_transformer_embedder.py b/allennlp/modules/token_embedders/pretrained_transformer_embedder.py index 48f2e063e9a..ccb5326ff4b 100644 --- a/allennlp/modules/token_embedders/pretrained_transformer_embedder.py +++ b/allennlp/modules/token_embedders/pretrained_transformer_embedder.py @@ -129,7 +129,7 @@ def __init__( # Optionally, re-initialize the parameters of certain layers. self._reinit_layers = cast(List[int], reinit_layers) if self._reinit_layers and load_weights: - num_layers = len(self.transformer_model.encoder.layer) + num_layers = self.transformer_model.config.num_hidden_layers if isinstance(reinit_layers, int): self._reinit_layers = list(range(num_layers - reinit_layers, num_layers)) if any(layer_idx < 0 or layer_idx > num_layers for layer_idx in self._reinit_layers): From fe4297944c9168f89f0f79e1b1469ae88a37defd Mon Sep 17 00:00:00 2001 From: johngiorgi Date: Thu, 16 Dec 2021 11:29:13 -0500 Subject: [PATCH 06/17] Support a list of regexes in reinit_modules --- .../pretrained_transformer_embedder.py | 72 ++++++++++++------- 1 file changed, 48 insertions(+), 24 deletions(-) diff --git a/allennlp/modules/token_embedders/pretrained_transformer_embedder.py b/allennlp/modules/token_embedders/pretrained_transformer_embedder.py index ccb5326ff4b..3dda8e23b02 100644 --- a/allennlp/modules/token_embedders/pretrained_transformer_embedder.py +++ b/allennlp/modules/token_embedders/pretrained_transformer_embedder.py @@ -1,16 +1,16 @@ import logging import math -from typing import Optional, Tuple, Dict, Any, Union, List, cast - +import re +from typing import Any, Dict, List, Optional, Tuple, Union, cast import torch import torch.nn.functional as F -from transformers import XLNetConfig - +from allennlp.common.checks import ConfigurationError from allennlp.data.tokenizers import PretrainedTransformerTokenizer from allennlp.modules.scalar_mix import ScalarMix from allennlp.modules.token_embedders.token_embedder import TokenEmbedder from allennlp.nn.util import batched_index_select +from transformers import XLNetConfig logger = logging.getLogger(__name__) @@ -48,12 +48,14 @@ class PretrainedTransformerEmbedder(TokenEmbedder): When `True` (the default), only the final layer of the pretrained transformer is taken for the embeddings. But if set to `False`, a scalar mix of all of the layers is used. - reinit_layers: `Optional[Union[int, List[int]]]`, optional (default = `None`) - If this is an integer, the last `reinit_layers` layers of the transformer will be - re-initialized. If this is a list, the layers indexed by `reinit_layers` will be - re-initialized. Re-initializing the last few layers of a pretrained transformer can reduce - the instability of fine-tuning on small datasets and may improve performance - (https://arxiv.org/abs/2006.05987v3). Has no effect if `load_weights` is `False`. + reinit_modules: `Optional[Union[int, List[int]]]`, optional (default = `None`) + If this is an integer, the last `reinit_modules` layers of the transformer will be + re-initialized. If this is a list of integers, the layers indexed by `reinit_modules` will + be re-initialized. If this is a list of strings, they will be treated as regexes and any + module with a name matching the regex will be re-initialized. Re-initializing the last few + layers of a pretrained transformer can reduce the instability of fine-tuning on small + datasets and may improve performance (https://arxiv.org/abs/2006.05987v3). Has no effect + if `load_weights` is `False`. override_weights_file: `Optional[str]`, optional (default = `None`) If set, this specifies a file from which to load alternate weights that override the weights from huggingface. The file is expected to contain a PyTorch `state_dict`, created @@ -88,7 +90,7 @@ def __init__( train_parameters: bool = True, eval_mode: bool = False, last_layer_only: bool = True, - reinit_layers: Optional[Union[int, List[int]]] = None, + reinit_modules: Optional[Union[int, List[int], List[str]]] = None, override_weights_file: Optional[str] = None, override_weights_strip_prefix: Optional[str] = None, load_weights: bool = True, @@ -127,20 +129,42 @@ def __init__( self.config.output_hidden_states = True # Optionally, re-initialize the parameters of certain layers. - self._reinit_layers = cast(List[int], reinit_layers) - if self._reinit_layers and load_weights: + self._reinit_modules = reinit_modules + if self._reinit_modules and load_weights: num_layers = self.transformer_model.config.num_hidden_layers - if isinstance(reinit_layers, int): - self._reinit_layers = list(range(num_layers - reinit_layers, num_layers)) - if any(layer_idx < 0 or layer_idx > num_layers for layer_idx in self._reinit_layers): - raise ValueError( - f"A layer index in reinit_layers ({self._reinit_layers}) is invalid. Must be" - f" between 0 and the maximum layer index ({num_layers - 1}.)" - ) - for layer_idx in self._reinit_layers: - self.transformer_model.encoder.layer[layer_idx].apply( - self.transformer_model._init_weights - ) + if isinstance(self._reinit_modules, int): + self._reinit_modules = list(range(num_layers - self._reinit_modules, num_layers)) + # This type cast is neccessary to avoid a mypy error. + self._reinit_modules = cast(list, self._reinit_modules) + if all(isinstance(x, int) for x in self._reinit_modules): + self._reinit_modules = cast(List[int], self._reinit_modules) + if any( + layer_idx < 0 or layer_idx > num_layers for layer_idx in self._reinit_modules + ): + raise ValueError( + f"A layer index in reinit_modules ({self._reinit_modules}) is invalid." + f" Must be between 0 and the maximum layer index ({num_layers - 1}.)" + ) + # Some transformer models organize their modules differently, so if this fails, + # raise an error with a helpful message. + try: + for layer_idx in self._reinit_modules: + self.transformer_model.encoder.layer[layer_idx].apply( + self.transformer_model._init_weights + ) + except AttributeError: + raise ConfigurationError( + f"Unable to re-initialize the layers of transformer model" + f" {model_name} using layer indices. Please provide a list of" + " strings corresponding to the names of the layers to re-initialize." + ) + elif all(isinstance(x, str) for x in self._reinit_modules): + for regex in self._reinit_modules: + for name, module in self.transformer_model.named_modules(): + if re.search(regex, name): + module.apply(self.transformer_model._init_weights) + else: + raise ValueError("reinit_modules must be a list of strings or a list of integers.") tokenizer = PretrainedTransformerTokenizer( model_name, From 0c33cf838ce531a2d685625ddc6229211d8106a0 Mon Sep 17 00:00:00 2001 From: johngiorgi Date: Thu, 16 Dec 2021 11:29:25 -0500 Subject: [PATCH 07/17] Add tests for when reinit_modules is a list of str --- .../pretrained_transformer_embedder_test.py | 61 ++++++++++++++----- 1 file changed, 47 insertions(+), 14 deletions(-) diff --git a/tests/modules/token_embedders/pretrained_transformer_embedder_test.py b/tests/modules/token_embedders/pretrained_transformer_embedder_test.py index 580fcac6d15..37fd1175576 100644 --- a/tests/modules/token_embedders/pretrained_transformer_embedder_test.py +++ b/tests/modules/token_embedders/pretrained_transformer_embedder_test.py @@ -1,8 +1,9 @@ import math + import pytest import torch - from allennlp.common import Params, cached_transformers +from allennlp.common.checks import ConfigurationError from allennlp.common.testing import AllenNlpTestCase, requires_gpu from allennlp.data import Vocabulary from allennlp.data.batch import Batch @@ -341,40 +342,72 @@ def test_embeddings_resize(self): == 28997 ) - def test_reinit_layers(self): - regular_token_embedder = PretrainedTransformerEmbedder("bert-base-cased") - assert regular_token_embedder._reinit_layers is None - # Test the case when reinit_layers is a valid int. Comparing all weights of the model is - # rather complicated, so arbitrarily compare the weights of attention module. + def test_reinit_modules(self): + # Test the base case, where reinit_modules is None. + transformer_model = cached_transformers.get("bert-base-cased", True) + # Comparing all weights of the model is rather complicated, so arbitrarily compare the + # weights of attention module. preinit_weights = torch.cat( + [layer.attention.output.dense.weight for layer in transformer_model.encoder.layer] + ) + regular_token_embedder = PretrainedTransformerEmbedder("bert-base-cased") + postinit_weights = torch.cat( [ layer.attention.output.dense.weight for layer in regular_token_embedder.transformer_model.encoder.layer ] ) - reinit_token_embedder = PretrainedTransformerEmbedder("bert-base-cased", reinit_layers=2) + assert regular_token_embedder._reinit_modules is None + assert torch.equal(postinit_weights, preinit_weights) + + # Test the case when reinit_modules is a valid int. + reinit_token_embedder = PretrainedTransformerEmbedder("bert-base-cased", reinit_modules=2) postinit_weights = torch.cat( [ layer.attention.output.dense.weight for layer in reinit_token_embedder.transformer_model.encoder.layer ] ) - assert reinit_token_embedder._reinit_layers == [10, 11] + assert reinit_token_embedder._reinit_modules == [10, 11] assert torch.equal(postinit_weights[:10], preinit_weights[:10]) assert not torch.equal(postinit_weights[10:], preinit_weights[10:]) - # Test the case when reinit_layers is a valid list of integers. + + # Test the case when reinit_modules is a valid list of integers. reinit_token_embedder = PretrainedTransformerEmbedder( - "bert-base-cased", reinit_layers=[10, 11] + "bert-base-cased", reinit_modules=[10, 11] ) - assert reinit_token_embedder._reinit_layers == [10, 11] + assert reinit_token_embedder._reinit_modules == [10, 11] assert torch.equal(postinit_weights[:10], preinit_weights[:10]) assert not torch.equal(postinit_weights[10:], preinit_weights[10:]) - # Should raise a ValueError because reinit_layers contains at least one index that is + + # Test the case where reinit_modules is a list of regex strings. + transformer_model = cached_transformers.get("xlm-mlm-enfr-1024", True) + preinit_weights = list(transformer_model.parameters("position_embeddings")) + reinit_token_embedder = PretrainedTransformerEmbedder( + "xlm-mlm-enfr-1024", reinit_modules=["position_embeddings"] + ) + postinit_weights = list( + reinit_token_embedder.transformer_model.parameters("position_embeddings") + ) + assert all( + (not torch.equal(pre, post) for pre, post in zip(preinit_weights, postinit_weights)) + ) + + # Should raise a ValueError because reinit_modules contains at least one index that is # greater than the models maximum number of layers with pytest.raises(ValueError): - _ = PretrainedTransformerEmbedder("bert-base-cased", reinit_layers=1000) + _ = PretrainedTransformerEmbedder("bert-base-cased", reinit_modules=1000) with pytest.raises(ValueError): - _ = PretrainedTransformerEmbedder("bert-base-cased", reinit_layers=[1, 1000]) + _ = PretrainedTransformerEmbedder("bert-base-cased", reinit_modules=[1, 1000]) + # This model has a non-standard structure, so if a layer index or list of layer indexes + # is provided, we raise a ConfigurationError. + with pytest.raises(ConfigurationError): + _ = PretrainedTransformerEmbedder("xlm-mlm-enfr-1024", reinit_modules=1) + with pytest.raises(ConfigurationError): + _ = PretrainedTransformerEmbedder("xlm-mlm-enfr-1024", reinit_modules=[1, 2]) + # The argument cannot mix layer indices and regex strings. + with pytest.raises(ConfigurationError): + _ = PretrainedTransformerEmbedder("xlm-mlm-enfr-1024", reinit_modules=[1, "attentions"]) def test_eval_mode(self): token_embedder = PretrainedTransformerEmbedder("epwalsh/bert-xsmall-dummy", eval_mode=True) From a84d069bdf7ac8abe545ee884681fd691e51791a Mon Sep 17 00:00:00 2001 From: johngiorgi Date: Mon, 20 Dec 2021 11:09:10 -0500 Subject: [PATCH 08/17] Fix broken test for re-initializing modules --- .../pretrained_transformer_embedder_test.py | 20 +++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/tests/modules/token_embedders/pretrained_transformer_embedder_test.py b/tests/modules/token_embedders/pretrained_transformer_embedder_test.py index 37fd1175576..96903ac6eae 100644 --- a/tests/modules/token_embedders/pretrained_transformer_embedder_test.py +++ b/tests/modules/token_embedders/pretrained_transformer_embedder_test.py @@ -1,3 +1,4 @@ +import copy import math import pytest @@ -382,12 +383,19 @@ def test_reinit_modules(self): # Test the case where reinit_modules is a list of regex strings. transformer_model = cached_transformers.get("xlm-mlm-enfr-1024", True) - preinit_weights = list(transformer_model.parameters("position_embeddings")) + # Comparing all weights of the model is rather complicated, so arbitrarily compare the + # weights of position_embeddings module. + reinit_module = "position_embeddings" + # This MUST be a deep copy, otherwise the parameters will be re-initialized and the + # test will break. + preinit_weights = copy.deepcopy( + list(transformer_model.get_submodule(reinit_module).parameters()) + ) reinit_token_embedder = PretrainedTransformerEmbedder( - "xlm-mlm-enfr-1024", reinit_modules=["position_embeddings"] + "xlm-mlm-enfr-1024", reinit_modules=[reinit_module] ) postinit_weights = list( - reinit_token_embedder.transformer_model.parameters("position_embeddings") + reinit_token_embedder.transformer_model.get_submodule(reinit_module).parameters() ) assert all( (not torch.equal(pre, post) for pre, post in zip(preinit_weights, postinit_weights)) @@ -399,15 +407,15 @@ def test_reinit_modules(self): _ = PretrainedTransformerEmbedder("bert-base-cased", reinit_modules=1000) with pytest.raises(ValueError): _ = PretrainedTransformerEmbedder("bert-base-cased", reinit_modules=[1, 1000]) + # The argument cannot mix layer indices and regex strings. + with pytest.raises(ValueError): + _ = PretrainedTransformerEmbedder("bert-base-cased", reinit_modules=[1, "attentions"]) # This model has a non-standard structure, so if a layer index or list of layer indexes # is provided, we raise a ConfigurationError. with pytest.raises(ConfigurationError): _ = PretrainedTransformerEmbedder("xlm-mlm-enfr-1024", reinit_modules=1) with pytest.raises(ConfigurationError): _ = PretrainedTransformerEmbedder("xlm-mlm-enfr-1024", reinit_modules=[1, 2]) - # The argument cannot mix layer indices and regex strings. - with pytest.raises(ConfigurationError): - _ = PretrainedTransformerEmbedder("xlm-mlm-enfr-1024", reinit_modules=[1, "attentions"]) def test_eval_mode(self): token_embedder = PretrainedTransformerEmbedder("epwalsh/bert-xsmall-dummy", eval_mode=True) From 39d92ca9124d0818a5f92001cf2a9a8c527dccc3 Mon Sep 17 00:00:00 2001 From: johngiorgi Date: Mon, 20 Dec 2021 19:25:35 -0500 Subject: [PATCH 09/17] Break reinit_modules unit test into two --- .../pretrained_transformer_embedder_test.py | 43 ++++++++++--------- 1 file changed, 22 insertions(+), 21 deletions(-) diff --git a/tests/modules/token_embedders/pretrained_transformer_embedder_test.py b/tests/modules/token_embedders/pretrained_transformer_embedder_test.py index 96903ac6eae..38fe4149682 100644 --- a/tests/modules/token_embedders/pretrained_transformer_embedder_test.py +++ b/tests/modules/token_embedders/pretrained_transformer_embedder_test.py @@ -343,7 +343,7 @@ def test_embeddings_resize(self): == 28997 ) - def test_reinit_modules(self): + def test_reinit_modules_with_layer_indices(self): # Test the base case, where reinit_modules is None. transformer_model = cached_transformers.get("bert-base-cased", True) # Comparing all weights of the model is rather complicated, so arbitrarily compare the @@ -381,18 +381,35 @@ def test_reinit_modules(self): assert torch.equal(postinit_weights[:10], preinit_weights[:10]) assert not torch.equal(postinit_weights[10:], preinit_weights[10:]) + # Should raise a ValueError because reinit_modules contains at least one index that is + # greater than the models maximum number of layers + with pytest.raises(ValueError): + _ = PretrainedTransformerEmbedder("bert-base-cased", reinit_modules=1000) + with pytest.raises(ValueError): + _ = PretrainedTransformerEmbedder("bert-base-cased", reinit_modules=[1, 1000]) + # The argument cannot mix layer indices and regex strings. + with pytest.raises(ValueError): + _ = PretrainedTransformerEmbedder("bert-base-cased", reinit_modules=[1, "attentions"]) + # This model has a non-standard structure, so if a layer index or list of layer indexes + # is provided, we raise a ConfigurationError. + with pytest.raises(ConfigurationError): + _ = PretrainedTransformerEmbedder("sshleifer/tiny-gpt2", reinit_modules=1) + with pytest.raises(ConfigurationError): + _ = PretrainedTransformerEmbedder("sshleifer/tiny-gpt2", reinit_modules=[1, 2]) + + def test_reinit_modules_with_regex_strings(self): # Test the case where reinit_modules is a list of regex strings. - transformer_model = cached_transformers.get("xlm-mlm-enfr-1024", True) + transformer_model = cached_transformers.get("sshleifer/tiny-gpt2", True) # Comparing all weights of the model is rather complicated, so arbitrarily compare the - # weights of position_embeddings module. - reinit_module = "position_embeddings" + # weights of wpe module. + reinit_module = "wpe" # This MUST be a deep copy, otherwise the parameters will be re-initialized and the # test will break. preinit_weights = copy.deepcopy( list(transformer_model.get_submodule(reinit_module).parameters()) ) reinit_token_embedder = PretrainedTransformerEmbedder( - "xlm-mlm-enfr-1024", reinit_modules=[reinit_module] + "sshleifer/tiny-gpt2", reinit_modules=[reinit_module] ) postinit_weights = list( reinit_token_embedder.transformer_model.get_submodule(reinit_module).parameters() @@ -401,22 +418,6 @@ def test_reinit_modules(self): (not torch.equal(pre, post) for pre, post in zip(preinit_weights, postinit_weights)) ) - # Should raise a ValueError because reinit_modules contains at least one index that is - # greater than the models maximum number of layers - with pytest.raises(ValueError): - _ = PretrainedTransformerEmbedder("bert-base-cased", reinit_modules=1000) - with pytest.raises(ValueError): - _ = PretrainedTransformerEmbedder("bert-base-cased", reinit_modules=[1, 1000]) - # The argument cannot mix layer indices and regex strings. - with pytest.raises(ValueError): - _ = PretrainedTransformerEmbedder("bert-base-cased", reinit_modules=[1, "attentions"]) - # This model has a non-standard structure, so if a layer index or list of layer indexes - # is provided, we raise a ConfigurationError. - with pytest.raises(ConfigurationError): - _ = PretrainedTransformerEmbedder("xlm-mlm-enfr-1024", reinit_modules=1) - with pytest.raises(ConfigurationError): - _ = PretrainedTransformerEmbedder("xlm-mlm-enfr-1024", reinit_modules=[1, 2]) - def test_eval_mode(self): token_embedder = PretrainedTransformerEmbedder("epwalsh/bert-xsmall-dummy", eval_mode=True) assert token_embedder.training and not token_embedder.transformer_model.training From 718020d919818ee227a87e98b8b625458d4e6cb9 Mon Sep 17 00:00:00 2001 From: johngiorgi Date: Mon, 20 Dec 2021 19:25:41 -0500 Subject: [PATCH 10/17] Update changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index c4b2a1d99f1..ae762057438 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added - Added a way to resize the vocabulary in the T5 module +- Added an argument `reinit_modules` to `PretrainedTransformerEmbedder` that allows you to re-initialize the pretrained weights of a transformer model, using layer indices or regex strings. ### Fixed From 9d7915123a7d978c2c1292de6725021fe9bd1fb2 Mon Sep 17 00:00:00 2001 From: johngiorgi Date: Mon, 20 Dec 2021 19:47:03 -0500 Subject: [PATCH 11/17] Tests for when reinit_modules should have no effect --- .../pretrained_transformer_embedder.py | 4 ++-- .../pretrained_transformer_embedder_test.py | 20 +++++++++++++++++-- 2 files changed, 20 insertions(+), 4 deletions(-) diff --git a/allennlp/modules/token_embedders/pretrained_transformer_embedder.py b/allennlp/modules/token_embedders/pretrained_transformer_embedder.py index 3dda8e23b02..2ecbf1b84d7 100644 --- a/allennlp/modules/token_embedders/pretrained_transformer_embedder.py +++ b/allennlp/modules/token_embedders/pretrained_transformer_embedder.py @@ -129,8 +129,8 @@ def __init__( self.config.output_hidden_states = True # Optionally, re-initialize the parameters of certain layers. - self._reinit_modules = reinit_modules - if self._reinit_modules and load_weights: + self._reinit_modules = reinit_modules if load_weights else None + if self._reinit_modules is not None: num_layers = self.transformer_model.config.num_hidden_layers if isinstance(self._reinit_modules, int): self._reinit_modules = list(range(num_layers - self._reinit_modules, num_layers)) diff --git a/tests/modules/token_embedders/pretrained_transformer_embedder_test.py b/tests/modules/token_embedders/pretrained_transformer_embedder_test.py index 38fe4149682..6721b9aba39 100644 --- a/tests/modules/token_embedders/pretrained_transformer_embedder_test.py +++ b/tests/modules/token_embedders/pretrained_transformer_embedder_test.py @@ -343,14 +343,15 @@ def test_embeddings_resize(self): == 28997 ) - def test_reinit_modules_with_layer_indices(self): - # Test the base case, where reinit_modules is None. + def test_reinit_modules_no_ops(self): transformer_model = cached_transformers.get("bert-base-cased", True) # Comparing all weights of the model is rather complicated, so arbitrarily compare the # weights of attention module. preinit_weights = torch.cat( [layer.attention.output.dense.weight for layer in transformer_model.encoder.layer] ) + + # Test the case where reinit_modules is None regular_token_embedder = PretrainedTransformerEmbedder("bert-base-cased") postinit_weights = torch.cat( [ @@ -361,6 +362,21 @@ def test_reinit_modules_with_layer_indices(self): assert regular_token_embedder._reinit_modules is None assert torch.equal(postinit_weights, preinit_weights) + # Test the case when reinit_modules is a valid argument but load_weights is False. + reinit_token_embedder = PretrainedTransformerEmbedder( + "bert-base-cased", reinit_modules=2, load_weights=False + ) + assert reinit_token_embedder._reinit_modules is None + assert torch.equal(postinit_weights, preinit_weights) + + def test_reinit_modules_with_layer_indices(self): + # Comparing all weights of the model is rather complicated, so arbitrarily compare the + # weights of attention module. + transformer_model = cached_transformers.get("bert-base-cased", True) + preinit_weights = torch.cat( + [layer.attention.output.dense.weight for layer in transformer_model.encoder.layer] + ) + # Test the case when reinit_modules is a valid int. reinit_token_embedder = PretrainedTransformerEmbedder("bert-base-cased", reinit_modules=2) postinit_weights = torch.cat( From f9b5164a86576966f630035d5148a01c9da0d829 Mon Sep 17 00:00:00 2001 From: johngiorgi Date: Thu, 23 Dec 2021 12:35:00 -0500 Subject: [PATCH 12/17] Revert pretrained transformer embedder to main --- .../pretrained_transformer_embedder.py | 55 +---------- .../pretrained_transformer_embedder_test.py | 95 +------------------ 2 files changed, 5 insertions(+), 145 deletions(-) diff --git a/allennlp/modules/token_embedders/pretrained_transformer_embedder.py b/allennlp/modules/token_embedders/pretrained_transformer_embedder.py index 2ecbf1b84d7..6275f86de70 100644 --- a/allennlp/modules/token_embedders/pretrained_transformer_embedder.py +++ b/allennlp/modules/token_embedders/pretrained_transformer_embedder.py @@ -1,16 +1,16 @@ import logging import math -import re -from typing import Any, Dict, List, Optional, Tuple, Union, cast +from typing import Optional, Tuple, Dict, Any + import torch import torch.nn.functional as F -from allennlp.common.checks import ConfigurationError +from transformers import XLNetConfig + from allennlp.data.tokenizers import PretrainedTransformerTokenizer from allennlp.modules.scalar_mix import ScalarMix from allennlp.modules.token_embedders.token_embedder import TokenEmbedder from allennlp.nn.util import batched_index_select -from transformers import XLNetConfig logger = logging.getLogger(__name__) @@ -48,14 +48,6 @@ class PretrainedTransformerEmbedder(TokenEmbedder): When `True` (the default), only the final layer of the pretrained transformer is taken for the embeddings. But if set to `False`, a scalar mix of all of the layers is used. - reinit_modules: `Optional[Union[int, List[int]]]`, optional (default = `None`) - If this is an integer, the last `reinit_modules` layers of the transformer will be - re-initialized. If this is a list of integers, the layers indexed by `reinit_modules` will - be re-initialized. If this is a list of strings, they will be treated as regexes and any - module with a name matching the regex will be re-initialized. Re-initializing the last few - layers of a pretrained transformer can reduce the instability of fine-tuning on small - datasets and may improve performance (https://arxiv.org/abs/2006.05987v3). Has no effect - if `load_weights` is `False`. override_weights_file: `Optional[str]`, optional (default = `None`) If set, this specifies a file from which to load alternate weights that override the weights from huggingface. The file is expected to contain a PyTorch `state_dict`, created @@ -90,7 +82,6 @@ def __init__( train_parameters: bool = True, eval_mode: bool = False, last_layer_only: bool = True, - reinit_modules: Optional[Union[int, List[int], List[str]]] = None, override_weights_file: Optional[str] = None, override_weights_strip_prefix: Optional[str] = None, load_weights: bool = True, @@ -128,44 +119,6 @@ def __init__( self._scalar_mix = ScalarMix(self.config.num_hidden_layers) self.config.output_hidden_states = True - # Optionally, re-initialize the parameters of certain layers. - self._reinit_modules = reinit_modules if load_weights else None - if self._reinit_modules is not None: - num_layers = self.transformer_model.config.num_hidden_layers - if isinstance(self._reinit_modules, int): - self._reinit_modules = list(range(num_layers - self._reinit_modules, num_layers)) - # This type cast is neccessary to avoid a mypy error. - self._reinit_modules = cast(list, self._reinit_modules) - if all(isinstance(x, int) for x in self._reinit_modules): - self._reinit_modules = cast(List[int], self._reinit_modules) - if any( - layer_idx < 0 or layer_idx > num_layers for layer_idx in self._reinit_modules - ): - raise ValueError( - f"A layer index in reinit_modules ({self._reinit_modules}) is invalid." - f" Must be between 0 and the maximum layer index ({num_layers - 1}.)" - ) - # Some transformer models organize their modules differently, so if this fails, - # raise an error with a helpful message. - try: - for layer_idx in self._reinit_modules: - self.transformer_model.encoder.layer[layer_idx].apply( - self.transformer_model._init_weights - ) - except AttributeError: - raise ConfigurationError( - f"Unable to re-initialize the layers of transformer model" - f" {model_name} using layer indices. Please provide a list of" - " strings corresponding to the names of the layers to re-initialize." - ) - elif all(isinstance(x, str) for x in self._reinit_modules): - for regex in self._reinit_modules: - for name, module in self.transformer_model.named_modules(): - if re.search(regex, name): - module.apply(self.transformer_model._init_weights) - else: - raise ValueError("reinit_modules must be a list of strings or a list of integers.") - tokenizer = PretrainedTransformerTokenizer( model_name, tokenizer_kwargs=tokenizer_kwargs, diff --git a/tests/modules/token_embedders/pretrained_transformer_embedder_test.py b/tests/modules/token_embedders/pretrained_transformer_embedder_test.py index 6721b9aba39..6cd14a1cf9e 100644 --- a/tests/modules/token_embedders/pretrained_transformer_embedder_test.py +++ b/tests/modules/token_embedders/pretrained_transformer_embedder_test.py @@ -1,10 +1,8 @@ -import copy import math - import pytest import torch + from allennlp.common import Params, cached_transformers -from allennlp.common.checks import ConfigurationError from allennlp.common.testing import AllenNlpTestCase, requires_gpu from allennlp.data import Vocabulary from allennlp.data.batch import Batch @@ -343,97 +341,6 @@ def test_embeddings_resize(self): == 28997 ) - def test_reinit_modules_no_ops(self): - transformer_model = cached_transformers.get("bert-base-cased", True) - # Comparing all weights of the model is rather complicated, so arbitrarily compare the - # weights of attention module. - preinit_weights = torch.cat( - [layer.attention.output.dense.weight for layer in transformer_model.encoder.layer] - ) - - # Test the case where reinit_modules is None - regular_token_embedder = PretrainedTransformerEmbedder("bert-base-cased") - postinit_weights = torch.cat( - [ - layer.attention.output.dense.weight - for layer in regular_token_embedder.transformer_model.encoder.layer - ] - ) - assert regular_token_embedder._reinit_modules is None - assert torch.equal(postinit_weights, preinit_weights) - - # Test the case when reinit_modules is a valid argument but load_weights is False. - reinit_token_embedder = PretrainedTransformerEmbedder( - "bert-base-cased", reinit_modules=2, load_weights=False - ) - assert reinit_token_embedder._reinit_modules is None - assert torch.equal(postinit_weights, preinit_weights) - - def test_reinit_modules_with_layer_indices(self): - # Comparing all weights of the model is rather complicated, so arbitrarily compare the - # weights of attention module. - transformer_model = cached_transformers.get("bert-base-cased", True) - preinit_weights = torch.cat( - [layer.attention.output.dense.weight for layer in transformer_model.encoder.layer] - ) - - # Test the case when reinit_modules is a valid int. - reinit_token_embedder = PretrainedTransformerEmbedder("bert-base-cased", reinit_modules=2) - postinit_weights = torch.cat( - [ - layer.attention.output.dense.weight - for layer in reinit_token_embedder.transformer_model.encoder.layer - ] - ) - assert reinit_token_embedder._reinit_modules == [10, 11] - assert torch.equal(postinit_weights[:10], preinit_weights[:10]) - assert not torch.equal(postinit_weights[10:], preinit_weights[10:]) - - # Test the case when reinit_modules is a valid list of integers. - reinit_token_embedder = PretrainedTransformerEmbedder( - "bert-base-cased", reinit_modules=[10, 11] - ) - assert reinit_token_embedder._reinit_modules == [10, 11] - assert torch.equal(postinit_weights[:10], preinit_weights[:10]) - assert not torch.equal(postinit_weights[10:], preinit_weights[10:]) - - # Should raise a ValueError because reinit_modules contains at least one index that is - # greater than the models maximum number of layers - with pytest.raises(ValueError): - _ = PretrainedTransformerEmbedder("bert-base-cased", reinit_modules=1000) - with pytest.raises(ValueError): - _ = PretrainedTransformerEmbedder("bert-base-cased", reinit_modules=[1, 1000]) - # The argument cannot mix layer indices and regex strings. - with pytest.raises(ValueError): - _ = PretrainedTransformerEmbedder("bert-base-cased", reinit_modules=[1, "attentions"]) - # This model has a non-standard structure, so if a layer index or list of layer indexes - # is provided, we raise a ConfigurationError. - with pytest.raises(ConfigurationError): - _ = PretrainedTransformerEmbedder("sshleifer/tiny-gpt2", reinit_modules=1) - with pytest.raises(ConfigurationError): - _ = PretrainedTransformerEmbedder("sshleifer/tiny-gpt2", reinit_modules=[1, 2]) - - def test_reinit_modules_with_regex_strings(self): - # Test the case where reinit_modules is a list of regex strings. - transformer_model = cached_transformers.get("sshleifer/tiny-gpt2", True) - # Comparing all weights of the model is rather complicated, so arbitrarily compare the - # weights of wpe module. - reinit_module = "wpe" - # This MUST be a deep copy, otherwise the parameters will be re-initialized and the - # test will break. - preinit_weights = copy.deepcopy( - list(transformer_model.get_submodule(reinit_module).parameters()) - ) - reinit_token_embedder = PretrainedTransformerEmbedder( - "sshleifer/tiny-gpt2", reinit_modules=[reinit_module] - ) - postinit_weights = list( - reinit_token_embedder.transformer_model.get_submodule(reinit_module).parameters() - ) - assert all( - (not torch.equal(pre, post) for pre, post in zip(preinit_weights, postinit_weights)) - ) - def test_eval_mode(self): token_embedder = PretrainedTransformerEmbedder("epwalsh/bert-xsmall-dummy", eval_mode=True) assert token_embedder.training and not token_embedder.transformer_model.training From d8392ee61c4449a8ffad84029b81b6b429850b83 Mon Sep 17 00:00:00 2001 From: johngiorgi Date: Thu, 23 Dec 2021 12:38:32 -0500 Subject: [PATCH 13/17] Move reinit_modules feature to cached transformers --- CHANGELOG.md | 2 +- allennlp/common/cached_transformers.py | 72 +++++++++++- .../pretrained_transformer_embedder.py | 16 ++- tests/common/cached_transformers_test.py | 103 +++++++++++++++++- 4 files changed, 178 insertions(+), 15 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 8db78d2bccf..e214b5ede1f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,7 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added - Added a way to resize the vocabulary in the T5 module -- Added an argument `reinit_modules` to `PretrainedTransformerEmbedder` that allows you to re-initialize the pretrained weights of a transformer model, using layer indices or regex strings. +- Added an argument `reinit_modules` to `cached_transformers.get()` that allows you to re-initialize the pretrained weights of a transformer model, using layer indices or regex strings. ### Fixed diff --git a/allennlp/common/cached_transformers.py b/allennlp/common/cached_transformers.py index bc7cc4a6dfd..fcbc72ece2b 100644 --- a/allennlp/common/cached_transformers.py +++ b/allennlp/common/cached_transformers.py @@ -1,10 +1,11 @@ import logging +import re import warnings -from typing import NamedTuple, Optional, Dict, Tuple +from typing import Dict, NamedTuple, Optional, Tuple, Union, cast import transformers -from transformers import AutoModel, AutoConfig - +from allennlp.common.checks import ConfigurationError +from transformers import AutoConfig, AutoModel logger = logging.getLogger(__name__) @@ -13,6 +14,7 @@ class TransformerSpec(NamedTuple): model_name: str override_weights_file: Optional[str] = None override_weights_strip_prefix: Optional[str] = None + reinit_modules: Optional[Union[int, Tuple[int, ...], Tuple[str, ...]]] = None _model_cache: Dict[TransformerSpec, transformers.PreTrainedModel] = {} @@ -23,6 +25,7 @@ def get( make_copy: bool, override_weights_file: Optional[str] = None, override_weights_strip_prefix: Optional[str] = None, + reinit_modules: Optional[Union[int, Tuple[int, ...], Tuple[str, ...]]] = None, load_weights: bool = True, **kwargs, ) -> transformers.PreTrainedModel: @@ -43,13 +46,26 @@ def get( with `torch.save()`. override_weights_strip_prefix : `str`, optional (default = `None`) If set, strip the given prefix from the state dict when loading it. + reinit_modules: `Optional[Union[int, Tuple[int, ...], Tuple[str, ...]]]`, optional (default = `None`) + If this is an integer, the last `reinit_modules` layers of the transformer will be + re-initialized. If this is a tuple of integers, the layers indexed by `reinit_modules` will + be re-initialized. If this is a tuple of strings, they will be treated as regexes and any + module with a name matching the regex will be re-initialized. Re-initializing the last few + layers of a pretrained transformer can reduce the instability of fine-tuning on small + datasets and may improve performance (https://arxiv.org/abs/2006.05987v3). Has no effect + if `load_weights` is `False` or `override_weights_file` is not None. load_weights : `bool`, optional (default = `True`) If set to `False`, no weights will be loaded. This is helpful when you only want to initialize the architecture, like when you've already fine-tuned a model and are going to load the weights from a state dict elsewhere. """ global _model_cache - spec = TransformerSpec(model_name, override_weights_file, override_weights_strip_prefix) + spec = TransformerSpec( + model_name, + override_weights_file, + override_weights_strip_prefix, + reinit_modules, + ) transformer = _model_cache.get(spec, None) if transformer is None: if not load_weights: @@ -59,6 +75,12 @@ def get( "but 'load_weights' is set to False, so 'override_weights_file' will be ignored.", UserWarning, ) + if reinit_modules is not None: + warnings.warn( + "You specified 'reinit_modules' in allennlp.common.cached_transformers.get(), " + "but 'load_weights' is set to False, so 'reinit_modules' will be ignored.", + UserWarning, + ) transformer = AutoModel.from_config( AutoConfig.from_pretrained( model_name, @@ -66,8 +88,14 @@ def get( ) ) elif override_weights_file is not None: - from allennlp.common.file_utils import cached_path + if reinit_modules is not None: + warnings.warn( + "You specified 'reinit_modules' in allennlp.common.cached_transformers.get(), " + "but 'override_weights_file' is not None, so 'reinit_modules' will be ignored.", + UserWarning, + ) import torch + from allennlp.common.file_utils import cached_path override_weights_file = cached_path(override_weights_file) override_weights = torch.load(override_weights_file) @@ -110,6 +138,40 @@ def strip_prefix(s): transformer.module.load_state_dict(override_weights) else: transformer.load_state_dict(override_weights) + elif reinit_modules is not None: + transformer = AutoModel.from_pretrained( + model_name, + **kwargs, + ) + num_layers = transformer.config.num_hidden_layers + if isinstance(reinit_modules, int): + reinit_modules = tuple(range(num_layers - reinit_modules, num_layers)) + if all(isinstance(x, int) for x in reinit_modules): + # This type cast is neccessary to avoid a mypy error. + reinit_modules = cast(Tuple[int], reinit_modules) + if any(layer_idx < 0 or layer_idx > num_layers for layer_idx in reinit_modules): + raise ValueError( + f"A layer index in reinit_modules ({reinit_modules}) is invalid." + f" Must be between 0 and the maximum layer index ({num_layers - 1}.)" + ) + # Some transformer models organize their modules differently, so if this fails, + # raise an error with a helpful message. + try: + for layer_idx in reinit_modules: + transformer.encoder.layer[layer_idx].apply(transformer._init_weights) + except AttributeError: + raise ConfigurationError( + f"Unable to re-initialize the layers of transformer model" + f" {model_name} using layer indices. Please provide a list of" + " strings corresponding to the names of the layers to re-initialize." + ) + elif all(isinstance(x, str) for x in reinit_modules): + for regex in reinit_modules: + for name, module in transformer.named_modules(): + if re.search(regex, name): + module.apply(transformer._init_weights) + else: + raise ValueError("reinit_modules must be a list of strings or a list of integers.") else: transformer = AutoModel.from_pretrained( model_name, diff --git a/allennlp/modules/token_embedders/pretrained_transformer_embedder.py b/allennlp/modules/token_embedders/pretrained_transformer_embedder.py index 6275f86de70..4ec7df8dd18 100644 --- a/allennlp/modules/token_embedders/pretrained_transformer_embedder.py +++ b/allennlp/modules/token_embedders/pretrained_transformer_embedder.py @@ -1,16 +1,14 @@ import logging import math -from typing import Optional, Tuple, Dict, Any - +from typing import Any, Dict, Optional, Tuple, Union import torch import torch.nn.functional as F -from transformers import XLNetConfig - from allennlp.data.tokenizers import PretrainedTransformerTokenizer from allennlp.modules.scalar_mix import ScalarMix from allennlp.modules.token_embedders.token_embedder import TokenEmbedder from allennlp.nn.util import batched_index_select +from transformers import XLNetConfig logger = logging.getLogger(__name__) @@ -54,6 +52,14 @@ class PretrainedTransformerEmbedder(TokenEmbedder): with `torch.save()`. override_weights_strip_prefix: `Optional[str]`, optional (default = `None`) If set, strip the given prefix from the state dict when loading it. + reinit_modules: `Optional[Union[int, Tuple[int, ...], Tuple[str, ...]]]`, optional (default = `None`) + If this is an integer, the last `reinit_modules` layers of the transformer will be + re-initialized. If this is a tuple of integers, the layers indexed by `reinit_modules` will + be re-initialized. If this is a tuple of strings, they will be treated as regexes and any + module with a name matching the regex will be re-initialized. Re-initializing the last few + layers of a pretrained transformer can reduce the instability of fine-tuning on small + datasets and may improve performance (https://arxiv.org/abs/2006.05987v3). Has no effect + if `load_weights` is `False` or `override_weights_file` is not None. load_weights: `bool`, optional (default = `True`) Whether to load the pretrained weights. If you're loading your model/predictor from an AllenNLP archive it usually makes sense to set this to `False` (via the `overrides` parameter) @@ -84,6 +90,7 @@ def __init__( last_layer_only: bool = True, override_weights_file: Optional[str] = None, override_weights_strip_prefix: Optional[str] = None, + reinit_modules: Optional[Union[int, Tuple[int, ...], Tuple[str, ...]]] = None, load_weights: bool = True, gradient_checkpointing: Optional[bool] = None, tokenizer_kwargs: Optional[Dict[str, Any]] = None, @@ -97,6 +104,7 @@ def __init__( True, override_weights_file=override_weights_file, override_weights_strip_prefix=override_weights_strip_prefix, + reinit_modules=reinit_modules, load_weights=load_weights, **(transformer_kwargs or {}), ) diff --git a/tests/common/cached_transformers_test.py b/tests/common/cached_transformers_test.py index e0a8bee054b..409da5bb8be 100644 --- a/tests/common/cached_transformers_test.py +++ b/tests/common/cached_transformers_test.py @@ -1,12 +1,12 @@ -import pytest -import torch -import os import json +import os +import pytest +import torch from allennlp.common import cached_transformers +from allennlp.common.checks import ConfigurationError from allennlp.common.testing import AllenNlpTestCase - -from transformers import AutoModel, AutoConfig +from transformers import AutoConfig, AutoModel class TestCachedTransformers(AllenNlpTestCase): @@ -72,6 +72,99 @@ def test_from_pretrained_avoids_weights_download_if_override_weights(self): for p1, p2 in zip(transformer.parameters(), override_transformer.parameters()): assert p1.data.ne(p2.data).sum() == 0 + def test_reinit_modules_no_op(self): + # Test the case where reinit_modules is None (default) + preinit_weights = torch.cat( + [ + # Comparing all weights of the model is rather complicated, so arbitrarily + # compare the weights of attention module. + layer.attention.output.dense.weight + for layer in cached_transformers.get("bert-base-cased", True).encoder.layer + ] + ) + postinit_weights = torch.cat( + [ + layer.attention.output.dense.weight + for layer in cached_transformers.get("bert-base-cased", True).encoder.layer + ] + ) + assert torch.equal(postinit_weights, preinit_weights) + + def test_reinit_modules_with_layer_indices(self): + # Comparing all weights of the model is rather complicated, so arbitrarily compare the + # weights of attention module. + preinit_weights = torch.cat( + [ + layer.attention.output.dense.weight + for layer in cached_transformers.get("bert-base-cased", True).encoder.layer + ] + ) + + # Test the case when reinit_modules is a valid int. + postinit_weights = torch.cat( + [ + layer.attention.output.dense.weight + for layer in cached_transformers.get( + "bert-base-cased", True, reinit_modules=2 + ).encoder.layer + ] + ) + assert torch.equal(postinit_weights[:10], preinit_weights[:10]) + assert not torch.equal(postinit_weights[10:], preinit_weights[10:]) + + # Test the case when reinit_modules is a valid list of integers. + postinit_weights = torch.cat( + [ + layer.attention.output.dense.weight + for layer in cached_transformers.get( + "bert-base-cased", True, reinit_modules=(10, 11) + ).encoder.layer + ] + ) + assert torch.equal(postinit_weights[:10], preinit_weights[:10]) + assert not torch.equal(postinit_weights[10:], preinit_weights[10:]) + + # Should raise a ValueError because reinit_modules contains at least one index that is + # greater than the models maximum number of layers + with pytest.raises(ValueError): + _ = cached_transformers.get("bert-base-cased", True, reinit_modules=1000) + with pytest.raises(ValueError): + _ = cached_transformers.get("bert-base-cased", True, reinit_modules=(1, 1000)) + # The argument cannot mix layer indices and regex strings. + with pytest.raises(ValueError): + _ = cached_transformers.get("bert-base-cased", True, reinit_modules=(1, "attentions")) + # This model has a non-standard structure, so if a layer index or list of layer indexes + # is provided, we raise a ConfigurationError. + with pytest.raises(ConfigurationError): + _ = cached_transformers.get("sshleifer/tiny-gpt2", True, reinit_modules=1) + with pytest.raises(ConfigurationError): + _ = cached_transformers.get("sshleifer/tiny-gpt2", True, reinit_modules=(1, 2)) + + def test_reinit_modules_with_regex_strings(self): + # Comparing all weights of the model is rather complicated, so arbitrarily compare the + # weights of wpe module. + reinit_module = "wpe" + # This MUST be a deep copy, otherwise the parameters will be re-initialized and the + # test will break. + preinit_weights = list( + cached_transformers.get("sshleifer/tiny-gpt2", True) + .get_submodule(reinit_module) + .parameters() + ) + + postinit_weights = list( + cached_transformers.get( + "sshleifer/tiny-gpt2", + True, + reinit_modules=(reinit_module,), + ) + .get_submodule(reinit_module) + .parameters() + ) + assert all( + (not torch.equal(pre, post) for pre, post in zip(preinit_weights, postinit_weights)) + ) + def test_from_pretrained_no_load_weights(self): _ = cached_transformers.get( "epwalsh/bert-xsmall-dummy", False, load_weights=False, cache_dir=self.TEST_DIR From dae591d445ae6a4e9d43456b75d0a5d254af4869 Mon Sep 17 00:00:00 2001 From: John Giorgi Date: Thu, 23 Dec 2021 15:08:03 -0500 Subject: [PATCH 14/17] Better error message for invalid reinit_modules argument Co-authored-by: Pete --- allennlp/common/cached_transformers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/allennlp/common/cached_transformers.py b/allennlp/common/cached_transformers.py index fcbc72ece2b..f2a0c1e8bed 100644 --- a/allennlp/common/cached_transformers.py +++ b/allennlp/common/cached_transformers.py @@ -171,7 +171,7 @@ def strip_prefix(s): if re.search(regex, name): module.apply(transformer._init_weights) else: - raise ValueError("reinit_modules must be a list of strings or a list of integers.") + raise ValueError("reinit_modules must be either an integer, a list of strings, or a list of integers.") else: transformer = AutoModel.from_pretrained( model_name, From 440705158893aa39291ae14c0fd37c305f160d85 Mon Sep 17 00:00:00 2001 From: johngiorgi Date: Thu, 23 Dec 2021 15:09:26 -0500 Subject: [PATCH 15/17] Correct error message to say tuple, not list --- allennlp/common/cached_transformers.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/allennlp/common/cached_transformers.py b/allennlp/common/cached_transformers.py index f2a0c1e8bed..107bfdfdec5 100644 --- a/allennlp/common/cached_transformers.py +++ b/allennlp/common/cached_transformers.py @@ -171,7 +171,9 @@ def strip_prefix(s): if re.search(regex, name): module.apply(transformer._init_weights) else: - raise ValueError("reinit_modules must be either an integer, a list of strings, or a list of integers.") + raise ValueError( + "reinit_modules must be either an integer, a tuple of strings, or a tuple of integers." + ) else: transformer = AutoModel.from_pretrained( model_name, From f371e5cef2b1e984d7ad8c2b53148a8e9bc2c132 Mon Sep 17 00:00:00 2001 From: johngiorgi Date: Thu, 23 Dec 2021 15:14:33 -0500 Subject: [PATCH 16/17] Add note about layer indices failing in docstring --- allennlp/common/cached_transformers.py | 10 ++++++---- .../token_embedders/pretrained_transformer_embedder.py | 10 ++++++---- 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/allennlp/common/cached_transformers.py b/allennlp/common/cached_transformers.py index 107bfdfdec5..0e61512162c 100644 --- a/allennlp/common/cached_transformers.py +++ b/allennlp/common/cached_transformers.py @@ -49,11 +49,13 @@ def get( reinit_modules: `Optional[Union[int, Tuple[int, ...], Tuple[str, ...]]]`, optional (default = `None`) If this is an integer, the last `reinit_modules` layers of the transformer will be re-initialized. If this is a tuple of integers, the layers indexed by `reinit_modules` will - be re-initialized. If this is a tuple of strings, they will be treated as regexes and any - module with a name matching the regex will be re-initialized. Re-initializing the last few - layers of a pretrained transformer can reduce the instability of fine-tuning on small + be re-initialized. Note, because the module structure of the transformer `model_name` can + differ, we cannot guarantee that providing an integer or tuple of integers will work. If + this fails, you can instead provide a tuple of strings, which will be treated as regexes and + any module with a name matching the regex will be re-initialized. Re-initializing the last + few layers of a pretrained transformer can reduce the instability of fine-tuning on small datasets and may improve performance (https://arxiv.org/abs/2006.05987v3). Has no effect - if `load_weights` is `False` or `override_weights_file` is not None. + if `load_weights` is `False` or `override_weights_file` is not `None`. load_weights : `bool`, optional (default = `True`) If set to `False`, no weights will be loaded. This is helpful when you only want to initialize the architecture, like when you've already fine-tuned a model diff --git a/allennlp/modules/token_embedders/pretrained_transformer_embedder.py b/allennlp/modules/token_embedders/pretrained_transformer_embedder.py index 4ec7df8dd18..6e819093c65 100644 --- a/allennlp/modules/token_embedders/pretrained_transformer_embedder.py +++ b/allennlp/modules/token_embedders/pretrained_transformer_embedder.py @@ -55,11 +55,13 @@ class PretrainedTransformerEmbedder(TokenEmbedder): reinit_modules: `Optional[Union[int, Tuple[int, ...], Tuple[str, ...]]]`, optional (default = `None`) If this is an integer, the last `reinit_modules` layers of the transformer will be re-initialized. If this is a tuple of integers, the layers indexed by `reinit_modules` will - be re-initialized. If this is a tuple of strings, they will be treated as regexes and any - module with a name matching the regex will be re-initialized. Re-initializing the last few - layers of a pretrained transformer can reduce the instability of fine-tuning on small + be re-initialized. Note, because the module structure of the transformer `model_name` can + differ, we cannot guarantee that providing an integer or tuple of integers will work. If + this fails, you can instead provide a tuple of strings, which will be treated as regexes and + any module with a name matching the regex will be re-initialized. Re-initializing the last + few layers of a pretrained transformer can reduce the instability of fine-tuning on small datasets and may improve performance (https://arxiv.org/abs/2006.05987v3). Has no effect - if `load_weights` is `False` or `override_weights_file` is not None. + if `load_weights` is `False` or `override_weights_file` is not `None`. load_weights: `bool`, optional (default = `True`) Whether to load the pretrained weights. If you're loading your model/predictor from an AllenNLP archive it usually makes sense to set this to `False` (via the `overrides` parameter) From ed39b05be0aa820eed736380028e3e10aabc7a55 Mon Sep 17 00:00:00 2001 From: johngiorgi Date: Thu, 23 Dec 2021 15:15:42 -0500 Subject: [PATCH 17/17] Correct "list" with "tuple" in error message --- allennlp/common/cached_transformers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/allennlp/common/cached_transformers.py b/allennlp/common/cached_transformers.py index 0e61512162c..3177faa2ab8 100644 --- a/allennlp/common/cached_transformers.py +++ b/allennlp/common/cached_transformers.py @@ -164,7 +164,7 @@ def strip_prefix(s): except AttributeError: raise ConfigurationError( f"Unable to re-initialize the layers of transformer model" - f" {model_name} using layer indices. Please provide a list of" + f" {model_name} using layer indices. Please provide a tuple of" " strings corresponding to the names of the layers to re-initialize." ) elif all(isinstance(x, str) for x in reinit_modules):