This repository was archived by the owner on Dec 16, 2022. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 2.2k
Reinit layers of pretrained transformer in cached_transformers.get() #5505
Merged
epwalsh
merged 20 commits into
allenai:main
from
JohnGiorgi:reinit-layers-of-pretrained-transformer-embedder
Dec 23, 2021
Merged
Changes from 5 commits
Commits
Show all changes
20 commits
Select commit
Hold shift + click to select a range
fe91837
Arg to re-init some layers of pretrained transformer
JohnGiorgi 1721882
Merge branch 'allenai:main' into reinit-layers-of-pretrained-transfor…
JohnGiorgi c75bc56
Better error handelling for bad indices
JohnGiorgi 69fef36
Add a test for reinit_layers
JohnGiorgi e6799ca
Type cast to appease mypy
JohnGiorgi 5e55ad0
Get number of hidden layers in model agnostic way
JohnGiorgi fe42979
Support a list of regexes in reinit_modules
JohnGiorgi 0c33cf8
Add tests for when reinit_modules is a list of str
JohnGiorgi a84d069
Fix broken test for re-initializing modules
JohnGiorgi 39d92ca
Break reinit_modules unit test into two
JohnGiorgi 718020d
Update changelog
JohnGiorgi 9d79151
Tests for when reinit_modules should have no effect
JohnGiorgi 284f76c
Merge branch 'main' into reinit-layers-of-pretrained-transformer-embe…
epwalsh f9b5164
Revert pretrained transformer embedder to main
JohnGiorgi d8392ee
Move reinit_modules feature to cached transformers
JohnGiorgi dae591d
Better error message for invalid reinit_modules argument
JohnGiorgi 4407051
Correct error message to say tuple, not list
JohnGiorgi f371e5c
Add note about layer indices failing in docstring
JohnGiorgi ed39b05
Correct "list" with "tuple" in error message
JohnGiorgi 01fc4d6
Merge branch 'main' into reinit-layers-of-pretrained-transformer-embe…
epwalsh File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This should work fine for BERT, but won't for other models like XLM that name their modules differently. In general I think it'll be brittle to rely on guessing the internal module structure of transformer models.
But I can think of alternative which - although it's a little less friendly from a user perspective - still provides this useful functionality:
reinit_layerstoOptional[Union[int, List[int], List[str]]].reinit_layersisList[str], we interpret it as a list of regular expressions and we reinitialize all modules that match any of the regular expressions in the list.reinit_layersisintorList[int]we try the current approach, but if the giventransformer_modeldoes not have this module structure we throw aConfigurationErrorand suggest that the user use theList[str]form of this parameter instead.Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Okay, I took a crack at this.
reinit_layerstoreinit_modulesjust because re-initializing modules (instead of whole layers) is possible when it is a list of regexes.Optional[Union[int, List[int], List[str]]]reinit_modulesis a list of regex strings.Unfortunately, I cant seem to get the weights to re-initialize. I figured that after finding a match between the regex and somemodulename I could just callmodule.apply(self.transformer_model._init_weights)like sobut this doesn't actually re-init themodules weights? Strange. I will have to dig into the HF repo and docs to figure out why.EDIT: My test was broken. I was making a shallow copy of the weights pre re-initialization, which were getting mutated.