Reinit layers of pretrained transformer in cached_transformers.get()#5505
Conversation
|
@epwalsh mind taking a look? |
| # Optionally, re-initialize the parameters of certain layers. | ||
| self._reinit_layers = cast(List[int], reinit_layers) | ||
| if self._reinit_layers and load_weights: | ||
| num_layers = len(self.transformer_model.encoder.layer) |
There was a problem hiding this comment.
This should work fine for BERT, but won't for other models like XLM that name their modules differently. In general I think it'll be brittle to rely on guessing the internal module structure of transformer models.
But I can think of alternative which - although it's a little less friendly from a user perspective - still provides this useful functionality:
- Change the type of
reinit_layerstoOptional[Union[int, List[int], List[str]]]. - When
reinit_layersisList[str], we interpret it as a list of regular expressions and we reinitialize all modules that match any of the regular expressions in the list. - Otherwise when
reinit_layersisintorList[int]we try the current approach, but if the giventransformer_modeldoes not have this module structure we throw aConfigurationErrorand suggest that the user use theList[str]form of this parameter instead.
There was a problem hiding this comment.
Okay, I took a crack at this.
- I renamed
reinit_layerstoreinit_modulesjust because re-initializing modules (instead of whole layers) is possible when it is a list of regexes. - Updated the type of the argument to
Optional[Union[int, List[int], List[str]]] - A configuration error is thrown if the model's layers cannot be easily indexed. In practice, most models I tried from hugging face could (BERTs, RoBERTa's) but you were right that XLM based models cant.
- Added tests for when
reinit_modulesis a list of regex strings.
Unfortunately, I cant seem to get the weights to re-initialize. I figured that after finding a match between the regex and some module name I could just call module.apply(self.transformer_model._init_weights) like so
for regex in self._reinit_modules:
for name, module in self.transformer_model.named_modules():
if re.search(regex, name):
module.apply(self.transformer_model._init_weights)but this doesn't actually re-init the modules weights? Strange. I will have to dig into the HF repo and docs to figure out why.
EDIT: My test was broken. I was making a shallow copy of the weights pre re-initialization, which were getting mutated.
|
@epwalsh Okay, I believe everything works and this new functionality is covered by a bunch of tests. I think the only outstanding questions are:
|
That makes sense to me!
We could test with a GPT2 model instead since those don't have an |
Decided against this as I realized it's actually a little complicated. The caching doesn't work because
Perfect! I used tiny-gpt2 and it made the tests run quite a bit faster |
We could convert it to a tuple instead, which will be hashable |
epwalsh
left a comment
There was a problem hiding this comment.
This looks really great! I just have a couple minor comments
| If this is an integer, the last `reinit_modules` layers of the transformer will be | ||
| re-initialized. If this is a tuple of integers, the layers indexed by `reinit_modules` will | ||
| be re-initialized. If this is a tuple of strings, they will be treated as regexes and any | ||
| module with a name matching the regex will be re-initialized. Re-initializing the last few | ||
| layers of a pretrained transformer can reduce the instability of fine-tuning on small | ||
| datasets and may improve performance (https://arxiv.org/abs/2006.05987v3). Has no effect | ||
| if `load_weights` is `False` or `override_weights_file` is not None. |
There was a problem hiding this comment.
I think it would be helpful to have a note here about how we can't guarantee using the int or list[int] form will work. And if that fails, the user should use the list[str] form.
There was a problem hiding this comment.
Good call! Updated this docstring in PretrainedTransformerEmbedder and cached_transformers.get() with a note warning that layer indices could fail
| be re-initialized. Note, because the module structure of the transformer `model_name` can | ||
| differ, we cannot guarantee that providing an integer or tuple of integers will work. If | ||
| this fails, you can instead provide a tuple of strings, which will be treated as regexes and | ||
| any module with a name matching the regex will be re-initialized. Re-initializing the last |
epwalsh
left a comment
There was a problem hiding this comment.
LGTM! I think #5529 should fix the failures in CI, so once that goes through and all tests pass I'll merge. Thanks @JohnGiorgi!
Awesome :) Thanks for your help as always!! |
|
@JohnGiorgi, @epwalsh, I don't think this makes a ton of sense as implemented. You want to cache the weights because you want to re-use them in many places. But We can still have this be part of |
|
Thanks! |
Fixes #5491.
Changes proposed in this pull request:
cached_transformers.get(),reinit_layers.reinit_layersis an integer, the parameters of the lastreinit_layersof the transformer model are re-initialized.reinit_layersis a list, the parameters of the layers indexed byreinit_layersof the transformer model are re-initialized.Before submitting
section of the
CONTRIBUTINGdocs.Writing docstrings section of the
CONTRIBUTINGdocs.After submitting
codecov/patchreports high test coverage (at least 90%).You can find this under the "Actions" tab of the pull request once the other checks have finished.