Skip to content
This repository was archived by the owner on Dec 16, 2022. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
fe91837
Arg to re-init some layers of pretrained transformer
JohnGiorgi Dec 10, 2021
1721882
Merge branch 'allenai:main' into reinit-layers-of-pretrained-transfor…
JohnGiorgi Dec 10, 2021
c75bc56
Better error handelling for bad indices
JohnGiorgi Dec 10, 2021
69fef36
Add a test for reinit_layers
JohnGiorgi Dec 10, 2021
e6799ca
Type cast to appease mypy
JohnGiorgi Dec 10, 2021
5e55ad0
Get number of hidden layers in model agnostic way
JohnGiorgi Dec 15, 2021
fe42979
Support a list of regexes in reinit_modules
JohnGiorgi Dec 16, 2021
0c33cf8
Add tests for when reinit_modules is a list of str
JohnGiorgi Dec 16, 2021
a84d069
Fix broken test for re-initializing modules
JohnGiorgi Dec 20, 2021
39d92ca
Break reinit_modules unit test into two
JohnGiorgi Dec 21, 2021
718020d
Update changelog
JohnGiorgi Dec 21, 2021
9d79151
Tests for when reinit_modules should have no effect
JohnGiorgi Dec 21, 2021
284f76c
Merge branch 'main' into reinit-layers-of-pretrained-transformer-embe…
epwalsh Dec 22, 2021
f9b5164
Revert pretrained transformer embedder to main
JohnGiorgi Dec 23, 2021
d8392ee
Move reinit_modules feature to cached transformers
JohnGiorgi Dec 23, 2021
dae591d
Better error message for invalid reinit_modules argument
JohnGiorgi Dec 23, 2021
4407051
Correct error message to say tuple, not list
JohnGiorgi Dec 23, 2021
f371e5c
Add note about layer indices failing in docstring
JohnGiorgi Dec 23, 2021
ed39b05
Correct "list" with "tuple" in error message
JohnGiorgi Dec 23, 2021
01fc4d6
Merge branch 'main' into reinit-layers-of-pretrained-transformer-embe…
epwalsh Dec 23, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added

- Added a way to resize the vocabulary in the T5 module
- Added an argument `reinit_modules` to `cached_transformers.get()` that allows you to re-initialize the pretrained weights of a transformer model, using layer indices or regex strings.

### Fixed

Expand Down
76 changes: 71 additions & 5 deletions allennlp/common/cached_transformers.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import logging
import re
import warnings
from typing import NamedTuple, Optional, Dict, Tuple
from typing import Dict, NamedTuple, Optional, Tuple, Union, cast

import transformers
from transformers import AutoModel, AutoConfig

from allennlp.common.checks import ConfigurationError
from transformers import AutoConfig, AutoModel

logger = logging.getLogger(__name__)

Expand All @@ -13,6 +14,7 @@ class TransformerSpec(NamedTuple):
model_name: str
override_weights_file: Optional[str] = None
override_weights_strip_prefix: Optional[str] = None
reinit_modules: Optional[Union[int, Tuple[int, ...], Tuple[str, ...]]] = None


_model_cache: Dict[TransformerSpec, transformers.PreTrainedModel] = {}
Expand All @@ -23,6 +25,7 @@ def get(
make_copy: bool,
override_weights_file: Optional[str] = None,
override_weights_strip_prefix: Optional[str] = None,
reinit_modules: Optional[Union[int, Tuple[int, ...], Tuple[str, ...]]] = None,
load_weights: bool = True,
**kwargs,
) -> transformers.PreTrainedModel:
Expand All @@ -43,13 +46,28 @@ def get(
with `torch.save()`.
override_weights_strip_prefix : `str`, optional (default = `None`)
If set, strip the given prefix from the state dict when loading it.
reinit_modules: `Optional[Union[int, Tuple[int, ...], Tuple[str, ...]]]`, optional (default = `None`)
If this is an integer, the last `reinit_modules` layers of the transformer will be
re-initialized. If this is a tuple of integers, the layers indexed by `reinit_modules` will
be re-initialized. Note, because the module structure of the transformer `model_name` can
differ, we cannot guarantee that providing an integer or tuple of integers will work. If
this fails, you can instead provide a tuple of strings, which will be treated as regexes and
any module with a name matching the regex will be re-initialized. Re-initializing the last
Comment on lines +52 to +55

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@epwalsh Updated part of docstring here

few layers of a pretrained transformer can reduce the instability of fine-tuning on small
datasets and may improve performance (https://arxiv.org/abs/2006.05987v3). Has no effect
if `load_weights` is `False` or `override_weights_file` is not `None`.
load_weights : `bool`, optional (default = `True`)
If set to `False`, no weights will be loaded. This is helpful when you only
want to initialize the architecture, like when you've already fine-tuned a model
and are going to load the weights from a state dict elsewhere.
"""
global _model_cache
spec = TransformerSpec(model_name, override_weights_file, override_weights_strip_prefix)
spec = TransformerSpec(
model_name,
override_weights_file,
override_weights_strip_prefix,
reinit_modules,
)
transformer = _model_cache.get(spec, None)
if transformer is None:
if not load_weights:
Expand All @@ -59,15 +77,27 @@ def get(
"but 'load_weights' is set to False, so 'override_weights_file' will be ignored.",
UserWarning,
)
if reinit_modules is not None:
warnings.warn(
"You specified 'reinit_modules' in allennlp.common.cached_transformers.get(), "
"but 'load_weights' is set to False, so 'reinit_modules' will be ignored.",
UserWarning,
)
transformer = AutoModel.from_config(
AutoConfig.from_pretrained(
model_name,
**kwargs,
)
)
elif override_weights_file is not None:
from allennlp.common.file_utils import cached_path
if reinit_modules is not None:
warnings.warn(
"You specified 'reinit_modules' in allennlp.common.cached_transformers.get(), "
"but 'override_weights_file' is not None, so 'reinit_modules' will be ignored.",
UserWarning,
)
import torch
from allennlp.common.file_utils import cached_path

override_weights_file = cached_path(override_weights_file)
override_weights = torch.load(override_weights_file)
Expand Down Expand Up @@ -110,6 +140,42 @@ def strip_prefix(s):
transformer.module.load_state_dict(override_weights)
else:
transformer.load_state_dict(override_weights)
elif reinit_modules is not None:
transformer = AutoModel.from_pretrained(
model_name,
**kwargs,
)
num_layers = transformer.config.num_hidden_layers
if isinstance(reinit_modules, int):
reinit_modules = tuple(range(num_layers - reinit_modules, num_layers))
if all(isinstance(x, int) for x in reinit_modules):
# This type cast is neccessary to avoid a mypy error.
reinit_modules = cast(Tuple[int], reinit_modules)
if any(layer_idx < 0 or layer_idx > num_layers for layer_idx in reinit_modules):
raise ValueError(
f"A layer index in reinit_modules ({reinit_modules}) is invalid."
f" Must be between 0 and the maximum layer index ({num_layers - 1}.)"
)
# Some transformer models organize their modules differently, so if this fails,
# raise an error with a helpful message.
try:
for layer_idx in reinit_modules:
transformer.encoder.layer[layer_idx].apply(transformer._init_weights)
except AttributeError:
raise ConfigurationError(
f"Unable to re-initialize the layers of transformer model"
f" {model_name} using layer indices. Please provide a tuple of"
" strings corresponding to the names of the layers to re-initialize."
)
elif all(isinstance(x, str) for x in reinit_modules):
for regex in reinit_modules:
for name, module in transformer.named_modules():
if re.search(regex, name):
module.apply(transformer._init_weights)
else:
raise ValueError(
"reinit_modules must be either an integer, a tuple of strings, or a tuple of integers."
)
else:
transformer = AutoModel.from_pretrained(
model_name,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,16 +1,14 @@
import logging
import math
from typing import Optional, Tuple, Dict, Any

from typing import Any, Dict, Optional, Tuple, Union

import torch
import torch.nn.functional as F
from transformers import XLNetConfig

from allennlp.data.tokenizers import PretrainedTransformerTokenizer
from allennlp.modules.scalar_mix import ScalarMix
from allennlp.modules.token_embedders.token_embedder import TokenEmbedder
from allennlp.nn.util import batched_index_select
from transformers import XLNetConfig

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -54,6 +52,16 @@ class PretrainedTransformerEmbedder(TokenEmbedder):
with `torch.save()`.
override_weights_strip_prefix: `Optional[str]`, optional (default = `None`)
If set, strip the given prefix from the state dict when loading it.
reinit_modules: `Optional[Union[int, Tuple[int, ...], Tuple[str, ...]]]`, optional (default = `None`)
If this is an integer, the last `reinit_modules` layers of the transformer will be
re-initialized. If this is a tuple of integers, the layers indexed by `reinit_modules` will
be re-initialized. Note, because the module structure of the transformer `model_name` can
differ, we cannot guarantee that providing an integer or tuple of integers will work. If
this fails, you can instead provide a tuple of strings, which will be treated as regexes and
any module with a name matching the regex will be re-initialized. Re-initializing the last
few layers of a pretrained transformer can reduce the instability of fine-tuning on small
datasets and may improve performance (https://arxiv.org/abs/2006.05987v3). Has no effect
if `load_weights` is `False` or `override_weights_file` is not `None`.
load_weights: `bool`, optional (default = `True`)
Whether to load the pretrained weights. If you're loading your model/predictor from an AllenNLP archive
it usually makes sense to set this to `False` (via the `overrides` parameter)
Expand Down Expand Up @@ -84,6 +92,7 @@ def __init__(
last_layer_only: bool = True,
override_weights_file: Optional[str] = None,
override_weights_strip_prefix: Optional[str] = None,
reinit_modules: Optional[Union[int, Tuple[int, ...], Tuple[str, ...]]] = None,
load_weights: bool = True,
gradient_checkpointing: Optional[bool] = None,
tokenizer_kwargs: Optional[Dict[str, Any]] = None,
Expand All @@ -97,6 +106,7 @@ def __init__(
True,
override_weights_file=override_weights_file,
override_weights_strip_prefix=override_weights_strip_prefix,
reinit_modules=reinit_modules,
load_weights=load_weights,
**(transformer_kwargs or {}),
)
Expand Down
103 changes: 98 additions & 5 deletions tests/common/cached_transformers_test.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import pytest
import torch
import os
import json
import os

import pytest
import torch
from allennlp.common import cached_transformers
from allennlp.common.checks import ConfigurationError
from allennlp.common.testing import AllenNlpTestCase

from transformers import AutoModel, AutoConfig
from transformers import AutoConfig, AutoModel


class TestCachedTransformers(AllenNlpTestCase):
Expand Down Expand Up @@ -72,6 +72,99 @@ def test_from_pretrained_avoids_weights_download_if_override_weights(self):
for p1, p2 in zip(transformer.parameters(), override_transformer.parameters()):
assert p1.data.ne(p2.data).sum() == 0

def test_reinit_modules_no_op(self):
    """Two loads with the default `reinit_modules` (None) must yield identical weights."""

    def attention_output_weights():
        # Comparing all weights of the model is rather complicated, so arbitrarily
        # compare only the attention output projection of every encoder layer.
        transformer = cached_transformers.get("bert-base-cased", True)
        per_layer = [layer.attention.output.dense.weight for layer in transformer.encoder.layer]
        return torch.cat(per_layer)

    before = attention_output_weights()
    after = attention_output_weights()
    assert torch.equal(after, before)

def test_reinit_modules_with_layer_indices(self):
    """Check `reinit_modules` when given as an int or as a tuple of layer indices.

    Covers three things: valid values re-initialize only the requested layers, invalid
    values raise `ValueError`, and a model without the `encoder.layer` structure raises
    `ConfigurationError`.
    """

    def attention_output_weights(**kwargs):
        # Comparing all weights of the model is rather complicated, so arbitrarily compare
        # the weights of the attention output projection of each encoder layer.
        transformer = cached_transformers.get("bert-base-cased", True, **kwargs)
        # Stack (rather than concatenate) the per-layer matrices so that dim 0 indexes
        # layers. With `torch.cat` the slices below would select *rows* of the first
        # layer's matrix instead of whole layers.
        return torch.stack(
            [layer.attention.output.dense.weight for layer in transformer.encoder.layer]
        )

    preinit_weights = attention_output_weights()

    # Test the case when reinit_modules is a valid int: only the last two layers change.
    # The slices assume the model has 12 encoder layers (indices 10 and 11 are the last two).
    postinit_weights = attention_output_weights(reinit_modules=2)
    assert torch.equal(postinit_weights[:10], preinit_weights[:10])
    assert not torch.equal(postinit_weights[10:], preinit_weights[10:])

    # Test the case when reinit_modules is a valid tuple of integers.
    postinit_weights = attention_output_weights(reinit_modules=(10, 11))
    assert torch.equal(postinit_weights[:10], preinit_weights[:10])
    assert not torch.equal(postinit_weights[10:], preinit_weights[10:])

    # Should raise a ValueError because reinit_modules contains at least one index that is
    # greater than the models maximum number of layers
    with pytest.raises(ValueError):
        _ = cached_transformers.get("bert-base-cased", True, reinit_modules=1000)
    with pytest.raises(ValueError):
        _ = cached_transformers.get("bert-base-cased", True, reinit_modules=(1, 1000))
    # The argument cannot mix layer indices and regex strings.
    with pytest.raises(ValueError):
        _ = cached_transformers.get("bert-base-cased", True, reinit_modules=(1, "attentions"))
    # This model has a non-standard structure, so if a layer index or tuple of layer indices
    # is provided, we raise a ConfigurationError.
    with pytest.raises(ConfigurationError):
        _ = cached_transformers.get("sshleifer/tiny-gpt2", True, reinit_modules=1)
    with pytest.raises(ConfigurationError):
        _ = cached_transformers.get("sshleifer/tiny-gpt2", True, reinit_modules=(1, 2))

def test_reinit_modules_with_regex_strings(self):
    """A module whose name matches a regex in `reinit_modules` gets new weights."""
    # Comparing all weights of the model is rather complicated, so arbitrarily compare the
    # weights of the wpe module.
    reinit_module = "wpe"

    def submodule_parameters(**kwargs):
        # NOTE(review): the second positional argument is `make_copy`; the comparison below
        # presumably relies on each call returning an independent (deep) copy, otherwise
        # both parameter lists could end up re-initialized - confirm against `get()`.
        transformer = cached_transformers.get("sshleifer/tiny-gpt2", True, **kwargs)
        return list(transformer.get_submodule(reinit_module).parameters())

    preinit_weights = submodule_parameters()
    postinit_weights = submodule_parameters(reinit_modules=(reinit_module,))

    # Every parameter of the matched module must differ after re-initialization.
    unchanged = (torch.equal(pre, post) for pre, post in zip(preinit_weights, postinit_weights))
    assert not any(unchanged)

def test_from_pretrained_no_load_weights(self):
_ = cached_transformers.get(
"epwalsh/bert-xsmall-dummy", False, load_weights=False, cache_dir=self.TEST_DIR
Expand Down