Skip to content
This repository was archived by the owner on Dec 16, 2022. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
fe91837
Arg to re-init some layers of pretrained transformer
JohnGiorgi Dec 10, 2021
1721882
Merge branch 'allenai:main' into reinit-layers-of-pretrained-transfor…
JohnGiorgi Dec 10, 2021
c75bc56
Better error handelling for bad indices
JohnGiorgi Dec 10, 2021
69fef36
Add a test for reinit_layers
JohnGiorgi Dec 10, 2021
e6799ca
Type cast to appease mypy
JohnGiorgi Dec 10, 2021
5e55ad0
Get number of hidden layers in model agnostic way
JohnGiorgi Dec 15, 2021
fe42979
Support a list of regexes in reinit_modules
JohnGiorgi Dec 16, 2021
0c33cf8
Add tests for when reinit_modules is a list of str
JohnGiorgi Dec 16, 2021
a84d069
Fix broken test for re-initializing modules
JohnGiorgi Dec 20, 2021
39d92ca
Break reinit_modules unit test into two
JohnGiorgi Dec 21, 2021
718020d
Update changelog
JohnGiorgi Dec 21, 2021
9d79151
Tests for when reinit_modules should have no effect
JohnGiorgi Dec 21, 2021
284f76c
Merge branch 'main' into reinit-layers-of-pretrained-transformer-embe…
epwalsh Dec 22, 2021
f9b5164
Revert pretrained transformer embedder to main
JohnGiorgi Dec 23, 2021
d8392ee
Move reinit_modules feature to cached transformers
JohnGiorgi Dec 23, 2021
dae591d
Better error message for invalid reinit_modules argument
JohnGiorgi Dec 23, 2021
4407051
Correct error message to say tuple, not list
JohnGiorgi Dec 23, 2021
f371e5c
Add note about layer indices failing in docstring
JohnGiorgi Dec 23, 2021
ed39b05
Correct "list" with "tuple" in error message
JohnGiorgi Dec 23, 2021
01fc4d6
Merge branch 'main' into reinit-layers-of-pretrained-transformer-embe…
epwalsh Dec 23, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging
import math
from typing import Optional, Tuple, Dict, Any
from typing import Optional, Tuple, Dict, Any, Union, List, cast


import torch
Expand Down Expand Up @@ -48,6 +48,12 @@ class PretrainedTransformerEmbedder(TokenEmbedder):
When `True` (the default), only the final layer of the pretrained transformer is taken
for the embeddings. But if set to `False`, a scalar mix of all of the layers
is used.
reinit_layers: `Optional[Union[int, List[int]]]`, optional (default = `None`)
If this is an integer, the last `reinit_layers` layers of the transformer will be
re-initialized. If this is a list, the layers indexed by `reinit_layers` will be
re-initialized. Re-initializing the last few layers of a pretrained transformer can reduce
the instability of fine-tuning on small datasets and may improve performance
(https://arxiv.org/abs/2006.05987v3). Has no effect if `load_weights` is `False`.
override_weights_file: `Optional[str]`, optional (default = `None`)
If set, this specifies a file from which to load alternate weights that override the
weights from huggingface. The file is expected to contain a PyTorch `state_dict`, created
Expand Down Expand Up @@ -82,6 +88,7 @@ def __init__(
train_parameters: bool = True,
eval_mode: bool = False,
last_layer_only: bool = True,
reinit_layers: Optional[Union[int, List[int]]] = None,
override_weights_file: Optional[str] = None,
override_weights_strip_prefix: Optional[str] = None,
load_weights: bool = True,
Expand Down Expand Up @@ -119,6 +126,22 @@ def __init__(
self._scalar_mix = ScalarMix(self.config.num_hidden_layers)
self.config.output_hidden_states = True

# Optionally, re-initialize the parameters of certain layers.
self._reinit_layers = cast(List[int], reinit_layers)
if self._reinit_layers and load_weights:
num_layers = len(self.transformer_model.encoder.layer)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should work fine for BERT, but won't for other models like XLM that name their modules differently. In general I think it'll be brittle to rely on guessing the internal module structure of transformer models.

But I can think of alternative which - although it's a little less friendly from a user perspective - still provides this useful functionality:

  • Change the type of reinit_layers to Optional[Union[int, List[int], List[str]]].
  • When reinit_layers is List[str], we interpret it as a list of regular expressions and we reinitialize all modules that match any of the regular expressions in the list.
  • Otherwise when reinit_layers is int or List[int] we try the current approach, but if the given transformer_model does not have this module structure we throw a ConfigurationError and suggest that the user use the List[str] form of this parameter instead.

@JohnGiorgi JohnGiorgi Dec 16, 2021

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay, I took a crack at this.

  • I renamed reinit_layers to reinit_modules just because re-initializing modules (instead of whole layers) is possible when it is a list of regexes.
  • Updated the type of the argument to Optional[Union[int, List[int], List[str]]]
  • A configuration error is thrown if the model's layers cannot be easily indexed. In practice, most models I tried from hugging face could (BERTs, RoBERTa's) but you were right that XLM based models cant.
  • Added tests for when reinit_modules is a list of regex strings.

Unfortunately, I cant seem to get the weights to re-initialize. I figured that after finding a match between the regex and some module name I could just call module.apply(self.transformer_model._init_weights) like so

for regex in self._reinit_modules:
    for name, module in self.transformer_model.named_modules():
        if re.search(regex, name):
            module.apply(self.transformer_model._init_weights)

but this doesn't actually re-init the modules weights? Strange. I will have to dig into the HF repo and docs to figure out why.

EDIT: My test was broken. I was making a shallow copy of the weights pre re-initialization, which were getting mutated.

if isinstance(reinit_layers, int):
self._reinit_layers = list(range(num_layers - reinit_layers, num_layers))
if any(layer_idx < 0 or layer_idx > num_layers for layer_idx in self._reinit_layers):
raise ValueError(
f"A layer index in reinit_layers ({self._reinit_layers}) is invalid. Must be"
f" between 0 and the maximum layer index ({num_layers - 1}.)"
)
for layer_idx in self._reinit_layers:
self.transformer_model.encoder.layer[layer_idx].apply(
self.transformer_model._init_weights
)

tokenizer = PretrainedTransformerTokenizer(
model_name,
tokenizer_kwargs=tokenizer_kwargs,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,41 @@ def test_embeddings_resize(self):
== 28997
)

def test_reinit_layers(self):
regular_token_embedder = PretrainedTransformerEmbedder("bert-base-cased")
assert regular_token_embedder._reinit_layers is None
# Test the case when reinit_layers is a valid int. Comparing all weights of the model is
# rather complicated, so arbitrarily compare the weights of attention module.
preinit_weights = torch.cat(
[
layer.attention.output.dense.weight
for layer in regular_token_embedder.transformer_model.encoder.layer
]
)
reinit_token_embedder = PretrainedTransformerEmbedder("bert-base-cased", reinit_layers=2)
postinit_weights = torch.cat(
[
layer.attention.output.dense.weight
for layer in reinit_token_embedder.transformer_model.encoder.layer
]
)
assert reinit_token_embedder._reinit_layers == [10, 11]
assert torch.equal(postinit_weights[:10], preinit_weights[:10])
assert not torch.equal(postinit_weights[10:], preinit_weights[10:])
# Test the case when reinit_layers is a valid list of integers.
reinit_token_embedder = PretrainedTransformerEmbedder(
"bert-base-cased", reinit_layers=[10, 11]
)
assert reinit_token_embedder._reinit_layers == [10, 11]
assert torch.equal(postinit_weights[:10], preinit_weights[:10])
assert not torch.equal(postinit_weights[10:], preinit_weights[10:])
# Should raise a ValueError because reinit_layers contains at least one index that is
# greater than the models maximum number of layers
with pytest.raises(ValueError):
_ = PretrainedTransformerEmbedder("bert-base-cased", reinit_layers=1000)
with pytest.raises(ValueError):
_ = PretrainedTransformerEmbedder("bert-base-cased", reinit_layers=[1, 1000])

def test_eval_mode(self):
token_embedder = PretrainedTransformerEmbedder("epwalsh/bert-xsmall-dummy", eval_mode=True)
assert token_embedder.training and not token_embedder.transformer_model.training
Expand Down