
Commit 5af1436

parmeet authored and facebook-github-bot committed
Add a class method in Model Bundler to facilitate model creation with user-defined configuration and checkpoint (#1442)
Summary: Import from GitHub. Command used: `python pytorch/import.py --project_name text --commit_ids 2040d8d`. Note that we are still not importing the whole repo using import_text.sh; using import.py will be the workflow we rely on until we merge the [legacy code removal commit](2cebac3) into fbcode.

Reviewed By: Nayef211

Differential Revision: D32603181

fbshipit-source-id: 1f583e5ac96e693b583ae42d5841bf387cf3727a
1 parent c256e7f commit 5af1436
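
For orientation, here is a minimal usage sketch of the `from_config` classmethod this commit adds. It mirrors the cases exercised in the new test below; the small configuration values come from that test, and the snippet itself is illustrative rather than part of the commit.

```python
from torchtext.models import RobertaEncoderConf, RobertaClassificationHead, RobertaModelBundle

# Small, illustrative encoder configuration (same values as the new test below)
encoder_conf = RobertaEncoderConf(vocab_size=10, embedding_dim=16, ffn_dimension=64,
                                  num_attention_heads=2, num_encoder_layers=2)

# Build an encoder from a user-defined config, then rebuild it from an in-memory state dict
encoder = RobertaModelBundle.from_config(encoder_conf=encoder_conf)
model = RobertaModelBundle.from_config(encoder_conf=encoder_conf,
                                       checkpoint=encoder.state_dict())

# Attach a classification head; because the checkpoint carries no head weights,
# the head's freshly initialized weights are merged in before strict loading
head = RobertaClassificationHead(num_classes=2, input_dim=16)
classifier = RobertaModelBundle.from_config(encoder_conf=encoder_conf,
                                            head=head,
                                            checkpoint=encoder.state_dict())
```

The checkpoint may also be a URL or path string, in which case it is fetched via `load_state_dict_from_url` with any `dl_kwargs` forwarded to it.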

2 files changed (+105, -15 lines)


test/models/test_models.py

Lines changed: 36 additions & 1 deletion
```diff
@@ -1,6 +1,5 @@
 import torchtext
 import torch
-
 from ..common.torchtext_test_case import TorchtextTestCase
 from ..common.assets import get_asset_path

@@ -91,3 +90,39 @@ def test_xlmr_transform_jit(self):
         actual = transform_jit([test_text])
         expected = [[0, 43523, 52005, 3647, 13293, 113307, 40514, 2]]
         torch.testing.assert_close(actual, expected)
+
+    def test_roberta_bundler_from_config(self):
+        from torchtext.models import RobertaEncoderConf, RobertaClassificationHead, RobertaModel, RobertaModelBundle
+        dummy_encoder_conf = RobertaEncoderConf(vocab_size=10, embedding_dim=16, ffn_dimension=64, num_attention_heads=2, num_encoder_layers=2)
+
+        # case: user provides encoder checkpoint state dict
+        dummy_encoder = RobertaModel(dummy_encoder_conf)
+        model = RobertaModelBundle.from_config(encoder_conf=dummy_encoder_conf,
+                                               checkpoint=dummy_encoder.state_dict())
+        self.assertEqual(model.state_dict(), dummy_encoder.state_dict())
+
+        # case: user provides classifier checkpoint state dict when head is given and override_head is False (by default)
+        dummy_classifier_head = RobertaClassificationHead(num_classes=2, input_dim=16)
+        another_dummy_classifier_head = RobertaClassificationHead(num_classes=2, input_dim=16)
+        dummy_classifier = RobertaModel(dummy_encoder_conf, dummy_classifier_head)
+        model = RobertaModelBundle.from_config(encoder_conf=dummy_encoder_conf,
+                                               head=another_dummy_classifier_head,
+                                               checkpoint=dummy_classifier.state_dict())
+        self.assertEqual(model.state_dict(), dummy_classifier.state_dict())
+
+        # case: user provides classifier checkpoint state dict when head is given and override_head is set True
+        another_dummy_classifier_head = RobertaClassificationHead(num_classes=2, input_dim=16)
+        model = RobertaModelBundle.from_config(encoder_conf=dummy_encoder_conf,
+                                               head=another_dummy_classifier_head,
+                                               checkpoint=dummy_classifier.state_dict(),
+                                               override_head=True)
+        self.assertEqual(model.head.state_dict(), another_dummy_classifier_head.state_dict())
+
+        # case: user provides only encoder checkpoint state dict when head is given
+        dummy_classifier_head = RobertaClassificationHead(num_classes=2, input_dim=16)
+        dummy_classifier = RobertaModel(dummy_encoder_conf, dummy_classifier_head)
+        encoder_state_dict = {}
+        for k, v in dummy_classifier.encoder.state_dict().items():
+            encoder_state_dict['encoder.' + k] = v
+        model = torchtext.models.RobertaModelBundle.from_config(encoder_conf=dummy_encoder_conf, head=dummy_classifier_head, checkpoint=encoder_state_dict)
+        self.assertEqual(model.state_dict(), dummy_classifier.state_dict())
```

torchtext/models/roberta/bundler.py

Lines changed: 69 additions & 14 deletions
```diff
@@ -1,13 +1,13 @@
-
 from dataclasses import dataclass
 from functools import partial
 from urllib.parse import urljoin

-from typing import Optional, Callable
+from typing import Optional, Callable, Dict, Union, Any
 from torchtext._download_hooks import load_state_dict_from_url
 from torch.nn import Module
+import torch
 import logging
-
+import re
 logger = logging.getLogger(__name__)

 from .model import (
@@ -21,6 +21,11 @@
 from torchtext import _TEXT_BUCKET


+def _is_head_available_in_checkpoint(checkpoint, head_state_dict):
+    # ensure all keys are present
+    return all(key in checkpoint.keys() for key in head_state_dict.keys())
+
+
 @dataclass
 class RobertaModelBundle:
     """RobertaModelBundle(_params: torchtext.models.RobertaEncoderParams, _path: Optional[str] = None, _head: Optional[torch.nn.Module] = None, transform: Optional[Callable] = None)
@@ -44,7 +49,7 @@ class RobertaModelBundle:
     Example - Pretrained encoder attached to un-initialized classification head
         >>> import torch, torchtext
         >>> xlmr_large = torchtext.models.XLMR_LARGE_ENCODER
-        >>> classifier_head = torchtext.models.RobertaClassificationHead(num_classes=2, input_dim = xlmr_large.params.embedding_dim)
+        >>> classifier_head = torchtext.models.RobertaClassificationHead(num_classes=2, input_dim = xlmr_large.encoderConf.embedding_dim)
         >>> classification_model = xlmr_large.get_model(head=classifier_head)
         >>> transform = xlmr_large.transform()
         >>> model_input = torch.tensor(transform(["Hello World"]))
@@ -60,14 +65,28 @@ class RobertaModelBundle:
         >>> encoder = roberta_bundle.get_model()
         >>> classifier_head = RobertaClassificationHead(num_classes=2, input_dim=768)
         >>> classifier = roberta_bundle.get_model(head=classifier_head)
+        >>> # using from_config
+        >>> encoder = RobertaModelBundle.from_config(config=roberta_encoder_conf, checkpoint=model_weights_path)
+        >>> classifier = RobertaModelBundle.from_config(config=roberta_encoder_conf, head=classifier_head, checkpoint=model_weights_path)
     """
     _encoder_conf: RobertaEncoderConf
     _path: Optional[str] = None
     _head: Optional[Module] = None
     transform: Optional[Callable] = None

-    def get_model(self, head: Optional[Module] = None, load_weights: bool = True, freeze_encoder: bool = False, *, dl_kwargs=None) -> RobertaModel:
+    def get_model(self,
+                  head: Optional[Module] = None,
+                  load_weights: bool = True,
+                  freeze_encoder: bool = False,
+                  *,
+                  dl_kwargs=None) -> RobertaModel:
         r"""get_model(head: Optional[torch.nn.Module] = None, load_weights: bool = True, freeze_encoder: bool = False, *, dl_kwargs=None) -> torchtext.models.RobertaModel
+
+        Args:
+            head (nn.Module): A module to be attached to the encoder to perform a specific task. If provided, it will replace the default member head (Default: ``None``)
+            load_weights (bool): Indicates whether or not to load weights if available. (Default: ``True``)
+            freeze_encoder (bool): Indicates whether or not to freeze the encoder weights. (Default: ``False``)
+            dl_kwargs (dictionary of keyword arguments): Passed to :func:`torch.hub.load_state_dict_from_url`. (Default: ``None``)
         """

         if load_weights:
@@ -84,17 +103,53 @@ def get_model(self, head: Optional[Module] = None, load_weights: bool = True, fr
         else:
             input_head = self._head

-        model = _get_model(self._encoder_conf, input_head, freeze_encoder)
-
-        if not load_weights:
-            return model
+        return RobertaModelBundle.from_config(encoder_conf=self._encoder_conf,
+                                              head=input_head,
+                                              freeze_encoder=freeze_encoder,
+                                              checkpoint=self._path,
+                                              override_head=True,
+                                              dl_kwargs=dl_kwargs)
+
+    @classmethod
+    def from_config(
+            cls,
+            encoder_conf: RobertaEncoderConf,
+            head: Optional[Module] = None,
+            freeze_encoder: bool = False,
+            checkpoint: Optional[Union[str, Dict[str, torch.Tensor]]] = None,
+            *,
+            override_head: bool = False,
+            dl_kwargs: Dict[str, Any] = None,
+    ) -> RobertaModel:
+        """Class method to create a model with a user-defined encoder configuration and checkpoint
+
+        Args:
+            encoder_conf (RobertaEncoderConf): An instance of class RobertaEncoderConf that defines the encoder configuration
+            head (nn.Module): A module to be attached to the encoder to perform a specific task. (Default: ``None``)
+            freeze_encoder (bool): Indicates whether to freeze the encoder weights. (Default: ``False``)
+            checkpoint (str or Dict[str, torch.Tensor]): Path to or actual model state_dict. The state_dict can have partial weights, i.e. only for the encoder. (Default: ``None``)
+            override_head (bool): Override the checkpoint's head state dict (if present) with the provided head state dict. (Default: ``False``)
+            dl_kwargs (dictionary of keyword arguments): Passed to :func:`torch.hub.load_state_dict_from_url`. (Default: ``None``)
+        """
+        model = _get_model(encoder_conf, head, freeze_encoder)
+        if checkpoint is not None:
+            if torch.jit.isinstance(checkpoint, Dict[str, torch.Tensor]):
+                state_dict = checkpoint
+            elif isinstance(checkpoint, str):
+                dl_kwargs = {} if dl_kwargs is None else dl_kwargs
+                state_dict = load_state_dict_from_url(checkpoint, **dl_kwargs)
+            else:
+                raise TypeError("checkpoint must be of type `str` or `Dict[str, torch.Tensor]` but got {}".format(type(checkpoint)))

-        dl_kwargs = {} if dl_kwargs is None else dl_kwargs
-        state_dict = load_state_dict_from_url(self._path, **dl_kwargs)
-        if input_head is not None:
-            model.load_state_dict(state_dict, strict=False)
-        else:
+            if head is not None:
+                regex = re.compile(r"^head\.")
+                head_state_dict = {k: v for k, v in model.state_dict().items() if regex.findall(k)}
+                # If the checkpoint does not contain head_state_dict, augment the checkpoint with the user-provided head state_dict
+                if not _is_head_available_in_checkpoint(state_dict, head_state_dict) or override_head:
+                    state_dict.update(head_state_dict)

             model.load_state_dict(state_dict, strict=True)
+
         return model

     @property
```
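
To make the head-merging behavior in `from_config` easier to follow, here is a small standalone sketch of the same decision logic; the state-dict keys and tensor shapes below are placeholders chosen for illustration, not values from the commit.

```python
import re
import torch

def _is_head_available_in_checkpoint(checkpoint, head_state_dict):
    # ensure all head keys are present in the checkpoint
    return all(key in checkpoint.keys() for key in head_state_dict.keys())

# Illustrative state dicts: an encoder-only checkpoint and a freshly built model with a head
checkpoint = {"encoder.embedding.weight": torch.zeros(10, 16)}
model_state = {"encoder.embedding.weight": torch.randn(10, 16),
               "head.dense.weight": torch.randn(16, 16)}

override_head = False
head_state_dict = {k: v for k, v in model_state.items() if re.match(r"^head\.", k)}

# The checkpoint lacks head weights (or the user asked to override them), so the
# user-provided head's weights are merged in; a strict=True load can then succeed.
if not _is_head_available_in_checkpoint(checkpoint, head_state_dict) or override_head:
    checkpoint.update(head_state_dict)
```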
