Commit 63882e5

parmeet authored and facebook-github-bot committed
Import torchtext from github aea6ad6,#1449 to 9f2fb3f,#1452
Summary:
command: `python pytorch/import.py --project_name text --commit_ids aea6ad6 9f2fb3f --squash`

Reviewed By: abhinavarora

Differential Revision: D32690771

fbshipit-source-id: cde616182ecfe643ab48d727b66bbf0194480d3e
1 parent 5af1436 commit 63882e5

3 files changed: +99 −61 lines

test/models/test_models.py

Lines changed: 54 additions & 6 deletions
@@ -1,5 +1,7 @@
 import torchtext
 import torch
+from torch.nn import functional as torch_F
+import copy
 from ..common.torchtext_test_case import TorchtextTestCase
 from ..common.assets import get_asset_path

@@ -91,31 +93,31 @@ def test_xlmr_transform_jit(self):
         expected = [[0, 43523, 52005, 3647, 13293, 113307, 40514, 2]]
         torch.testing.assert_close(actual, expected)

-    def test_roberta_bundler_from_config(self):
+    def test_roberta_bundler_build_model(self):
         from torchtext.models import RobertaEncoderConf, RobertaClassificationHead, RobertaModel, RobertaModelBundle
         dummy_encoder_conf = RobertaEncoderConf(vocab_size=10, embedding_dim=16, ffn_dimension=64, num_attention_heads=2, num_encoder_layers=2)

         # case: user provide encoder checkpoint state dict
         dummy_encoder = RobertaModel(dummy_encoder_conf)
-        model = RobertaModelBundle.from_config(encoder_conf=dummy_encoder_conf,
+        model = RobertaModelBundle.build_model(encoder_conf=dummy_encoder_conf,
                                                checkpoint=dummy_encoder.state_dict())
         self.assertEqual(model.state_dict(), dummy_encoder.state_dict())

         # case: user provide classifier checkpoint state dict when head is given and override_head is False (by default)
         dummy_classifier_head = RobertaClassificationHead(num_classes=2, input_dim=16)
         another_dummy_classifier_head = RobertaClassificationHead(num_classes=2, input_dim=16)
         dummy_classifier = RobertaModel(dummy_encoder_conf, dummy_classifier_head)
-        model = RobertaModelBundle.from_config(encoder_conf=dummy_encoder_conf,
+        model = RobertaModelBundle.build_model(encoder_conf=dummy_encoder_conf,
                                                head=another_dummy_classifier_head,
                                                checkpoint=dummy_classifier.state_dict())
         self.assertEqual(model.state_dict(), dummy_classifier.state_dict())

         # case: user provide classifier checkpoint state dict when head is given and override_head is set True
         another_dummy_classifier_head = RobertaClassificationHead(num_classes=2, input_dim=16)
-        model = RobertaModelBundle.from_config(encoder_conf=dummy_encoder_conf,
+        model = RobertaModelBundle.build_model(encoder_conf=dummy_encoder_conf,
                                                head=another_dummy_classifier_head,
                                                checkpoint=dummy_classifier.state_dict(),
-                                               override_head=True)
+                                               override_checkpoint_head=True)
         self.assertEqual(model.head.state_dict(), another_dummy_classifier_head.state_dict())

         # case: user provide only encoder checkpoint state dict when head is given
@@ -124,5 +126,51 @@ def test_roberta_bundler_from_config(self):
         encoder_state_dict = {}
         for k, v in dummy_classifier.encoder.state_dict().items():
             encoder_state_dict['encoder.' + k] = v
-        model = torchtext.models.RobertaModelBundle.from_config(encoder_conf=dummy_encoder_conf, head=dummy_classifier_head, checkpoint=encoder_state_dict)
+        model = torchtext.models.RobertaModelBundle.build_model(encoder_conf=dummy_encoder_conf, head=dummy_classifier_head, checkpoint=encoder_state_dict)
         self.assertEqual(model.state_dict(), dummy_classifier.state_dict())
+
+    def test_roberta_bundler_train(self):
+        from torchtext.models import RobertaEncoderConf, RobertaClassificationHead, RobertaModel, RobertaModelBundle
+        dummy_encoder_conf = RobertaEncoderConf(vocab_size=10, embedding_dim=16, ffn_dimension=64, num_attention_heads=2, num_encoder_layers=2)
+        from torch.optim import SGD
+
+        def _train(model):
+            optim = SGD(model.parameters(), lr=1)
+            model_input = torch.tensor([[0, 1, 2, 3, 4, 5]])
+            target = torch.tensor([0])
+            logits = model(model_input)
+            loss = torch_F.cross_entropy(logits, target)
+            loss.backward()
+            optim.step()
+
+        # does not freeze encoder
+        dummy_classifier_head = RobertaClassificationHead(num_classes=2, input_dim=16)
+        dummy_classifier = RobertaModel(dummy_encoder_conf, dummy_classifier_head)
+        model = RobertaModelBundle.build_model(encoder_conf=dummy_encoder_conf,
+                                               head=dummy_classifier_head,
+                                               freeze_encoder=False,
+                                               checkpoint=dummy_classifier.state_dict())
+
+        encoder_current_state_dict = copy.deepcopy(model.encoder.state_dict())
+        head_current_state_dict = copy.deepcopy(model.head.state_dict())
+
+        _train(model)
+
+        self.assertNotEqual(model.encoder.state_dict(), encoder_current_state_dict)
+        self.assertNotEqual(model.head.state_dict(), head_current_state_dict)
+
+        # freeze encoder
+        dummy_classifier_head = RobertaClassificationHead(num_classes=2, input_dim=16)
+        dummy_classifier = RobertaModel(dummy_encoder_conf, dummy_classifier_head)
+        model = RobertaModelBundle.build_model(encoder_conf=dummy_encoder_conf,
+                                               head=dummy_classifier_head,
+                                               freeze_encoder=True,
+                                               checkpoint=dummy_classifier.state_dict())
+
+        encoder_current_state_dict = copy.deepcopy(model.encoder.state_dict())
+        head_current_state_dict = copy.deepcopy(model.head.state_dict())
+
+        _train(model)
+
+        self.assertEqual(model.encoder.state_dict(), encoder_current_state_dict)
+        self.assertNotEqual(model.head.state_dict(), head_current_state_dict)
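The renamed bundler entry point exercised by these tests, shown as a minimal standalone sketch (the tiny config values simply mirror the test fixtures above and are illustrative, not required):

>>> from torchtext.models import RobertaEncoderConf, RobertaClassificationHead, RobertaModelBundle
>>> encoder_conf = RobertaEncoderConf(vocab_size=10, embedding_dim=16, ffn_dimension=64, num_attention_heads=2, num_encoder_layers=2)
>>> head = RobertaClassificationHead(num_classes=2, input_dim=16)
>>> # build_model replaces from_config; freeze_encoder is now honored via RobertaEncoder's freeze flag
>>> model = RobertaModelBundle.build_model(encoder_conf=encoder_conf, head=head, freeze_encoder=True)
>>> any(p.requires_grad for p in model.encoder.parameters())
False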

torchtext/models/roberta/bundler.py

Lines changed: 29 additions & 33 deletions
@@ -13,7 +13,6 @@
 from .model import (
     RobertaEncoderConf,
     RobertaModel,
-    _get_model,
 )

 from .transforms import get_xlmr_transform
@@ -30,56 +29,50 @@ def _is_head_available_in_checkpoint(checkpoint, head_state_dict):
 class RobertaModelBundle:
     """RobertaModelBundle(_params: torchtext.models.RobertaEncoderParams, _path: Optional[str] = None, _head: Optional[torch.nn.Module] = None, transform: Optional[Callable] = None)

-    Example - Pretrained encoder
+    Example - Pretrained base xlmr encoder
     >>> import torch, torchtext
+    >>> from torchtext.functional import to_tensor
     >>> xlmr_base = torchtext.models.XLMR_BASE_ENCODER
     >>> model = xlmr_base.get_model()
     >>> transform = xlmr_base.transform()
-    >>> model_input = torch.tensor(transform(["Hello World"]))
-    >>> output = model(model_input)
-    >>> output.shape
-    torch.Size([1, 4, 768])
     >>> input_batch = ["Hello world", "How are you!"]
-    >>> from torchtext.functional import to_tensor
     >>> model_input = to_tensor(transform(input_batch), padding_value=transform.pad_idx)
     >>> output = model(model_input)
     >>> output.shape
     torch.Size([2, 6, 768])

-    Example - Pretrained encoder attached to un-initialized classification head
+    Example - Pretrained large xlmr encoder attached to un-initialized classification head
     >>> import torch, torchtext
+    >>> from torchtext.models import RobertaClassificationHead
+    >>> from torchtext.functional import to_tensor
     >>> xlmr_large = torchtext.models.XLMR_LARGE_ENCODER
-    >>> classifier_head = torchtext.models.RobertaClassificationHead(num_classes=2, input_dim = xlmr_large.encoderConf.embedding_dim)
-    >>> classification_model = xlmr_large.get_model(head=classifier_head)
+    >>> classifier_head = torchtext.models.RobertaClassificationHead(num_classes=2, input_dim = 1024)
+    >>> model = xlmr_large.get_model(head=classifier_head)
     >>> transform = xlmr_large.transform()
-    >>> model_input = torch.tensor(transform(["Hello World"]))
-    >>> output = classification_model(model_input)
+    >>> input_batch = ["Hello world", "How are you!"]
+    >>> model_input = to_tensor(transform(input_batch), padding_value=transform.pad_idx)
+    >>> output = model(model_input)
     >>> output.shape
     torch.Size([1, 2])

     Example - User-specified configuration and checkpoint
     >>> from torchtext.models import RobertaEncoderConf, RobertaModelBundle, RobertaClassificationHead
     >>> model_weights_path = "https://download.pytorch.org/models/text/xlmr.base.encoder.pt"
-    >>> roberta_encoder_conf = RobertaEncoderConf(vocab_size=250002)
-    >>> roberta_bundle = RobertaModelBundle(_encoder_conf=roberta_encoder_conf, _path=model_weights_path)
-    >>> encoder = roberta_bundle.get_model()
+    >>> encoder_conf = RobertaEncoderConf(vocab_size=250002)
     >>> classifier_head = RobertaClassificationHead(num_classes=2, input_dim=768)
-    >>> classifier = roberta_bundle.get_model(head=classifier_head)
-    >>> # using from_config
-    >>> encoder = RobertaModelBundle.from_config(config=roberta_encoder_conf, checkpoint=model_weights_path)
-    >>> classifier = RobertaModelBundle.from_config(config=roberta_encoder_conf, head=classifier_head, checkpoint=model_weights_path)
+    >>> model = RobertaModelBundle.build_model(encoder_conf=encoder_conf, head=classifier_head, checkpoint=model_weights_path)
     """
     _encoder_conf: RobertaEncoderConf
     _path: Optional[str] = None
     _head: Optional[Module] = None
     transform: Optional[Callable] = None

     def get_model(self,
+                  *,
                   head: Optional[Module] = None,
                   load_weights: bool = True,
                   freeze_encoder: bool = False,
-                  *,
-                  dl_kwargs=None) -> RobertaModel:
+                  dl_kwargs: Dict[str, Any] = None) -> RobertaModel:
         r"""get_model(head: Optional[torch.nn.Module] = None, load_weights: bool = True, freeze_encoder: bool = False, *, dl_kwargs=None) -> torctext.models.RobertaModel

         Args:
@@ -103,35 +96,38 @@ def get_model(self,
         else:
             input_head = self._head

-        return RobertaModelBundle.from_config(encoder_conf=self._encoder_conf,
+        return RobertaModelBundle.build_model(encoder_conf=self._encoder_conf,
                                               head=input_head,
                                               freeze_encoder=freeze_encoder,
-                                              checkpoint=self._path,
-                                              override_head=True,
+                                              checkpoint=self._path if load_weights else None,
+                                              override_checkpoint_head=True,
+                                              strict=True,
                                               dl_kwargs=dl_kwargs)

     @classmethod
-    def from_config(
+    def build_model(
         cls,
         encoder_conf: RobertaEncoderConf,
+        *,
         head: Optional[Module] = None,
         freeze_encoder: bool = False,
         checkpoint: Optional[Union[str, Dict[str, torch.Tensor]]] = None,
-        *,
-        override_head: bool = False,
+        override_checkpoint_head: bool = False,
+        strict=True,
         dl_kwargs: Dict[str, Any] = None,
     ) -> RobertaModel:
-        """Class method to create model with user-defined encoder configuration and checkpoint
+        """Class builder method

         Args:
             encoder_conf (RobertaEncoderConf): An instance of class RobertaEncoderConf that defined the encoder configuration
             head (nn.Module): A module to be attached to the encoder to perform specific task. (Default: ``None``)
             freeze_encoder (bool): Indicates whether to freeze the encoder weights. (Default: ``False``)
             checkpoint (str or Dict[str, torch.Tensor]): Path to or actual model state_dict. state_dict can have partial weights i.e only for encoder. (Default: ``None``)
-            override_head (bool): Override the checkpoint's head state dict (if present) with provided head state dict. (Default: ``False``)
+            override_checkpoint_head (bool): Override the checkpoint's head state dict (if present) with provided head state dict. (Default: ``False``)
+            strict (bool): Passed to :func: `torch.nn.Module.load_state_dict` method. (Default: ``True``)
             dl_kwargs (dictionary of keyword arguments): Passed to :func:`torch.hub.load_state_dict_from_url`. (Default: ``None``)
         """
-        model = _get_model(encoder_conf, head, freeze_encoder)
+        model = RobertaModel(encoder_conf, head, freeze_encoder)
         if checkpoint is not None:
             if torch.jit.isinstance(checkpoint, Dict[str, torch.Tensor]):
                 state_dict = checkpoint
@@ -145,10 +141,10 @@ def from_config(
                 regex = re.compile(r"^head\.")
                 head_state_dict = {k: v for k, v in model.state_dict().items() if regex.findall(k)}
                 # If checkpoint does not contains head_state_dict, then we augment the checkpoint with user-provided head state_dict
-                if not _is_head_available_in_checkpoint(state_dict, head_state_dict) or override_head:
+                if not _is_head_available_in_checkpoint(state_dict, head_state_dict) or override_checkpoint_head:
                     state_dict.update(head_state_dict)

-            model.load_state_dict(state_dict, strict=True)
+            model.load_state_dict(state_dict, strict=strict)

         return model

@@ -168,7 +164,7 @@ def encoderConf(self) -> RobertaEncoderConf:

 XLMR_BASE_ENCODER.__doc__ = (
     '''
-    XLM-R Encoder with base configuration
+    XLM-R Encoder with Base configuration

     Please refer to :func:`torchtext.models.RobertaModelBundle` for the usage.
     '''
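One practical consequence of the `get_model` changes above: `load_weights` now decides whether the bundle's checkpoint path is forwarded to `build_model` at all (`checkpoint=self._path if load_weights else None`). A minimal usage sketch; the calls themselves are illustrative and only the second one downloads anything:

>>> import torchtext
>>> xlmr_base = torchtext.models.XLMR_BASE_ENCODER
>>> random_init_model = xlmr_base.get_model(load_weights=False)  # no state dict is downloaded or loaded
>>> pretrained_model = xlmr_base.get_model()                     # downloads and loads the pretrained checkpoint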

torchtext/models/roberta/model.py

Lines changed: 16 additions & 22 deletions
@@ -42,6 +42,7 @@ def __init__(
         dropout: float = 0.1,
         scaling: Optional[float] = None,
         normalize_before: bool = False,
+        freeze: bool = False,
     ):
         super().__init__()
         if not scaling:
@@ -62,17 +63,17 @@ def __init__(
             return_all_layers=False,
         )

-    @classmethod
-    def from_config(cls, config: RobertaEncoderConf):
-        return cls(**asdict(config))
+        if freeze:
+            for p in self.parameters():
+                p.requires_grad = False

-    def forward(self, tokens: Tensor, mask: Optional[Tensor] = None) -> Tensor:
+    def forward(self, tokens: Tensor, masked_tokens: Optional[Tensor] = None) -> Tensor:
         output = self.transformer(tokens)
         if torch.jit.isinstance(output, List[Tensor]):
             output = output[-1]
         output = output.transpose(1, 0)
-        if mask is not None:
-            output = output[mask.to(torch.bool), :]
+        if masked_tokens is not None:
+            output = output[masked_tokens.to(torch.bool), :]
         return output


@@ -100,35 +101,28 @@ def forward(self, features):
 class RobertaModel(Module):
     """

-    Example - Instantiate model with user-specified configuration
+    Example - Instatiating model object
     >>> from torchtext.models import RobertaEncoderConf, RobertaModel, RobertaClassificationHead
     >>> roberta_encoder_conf = RobertaEncoderConf(vocab_size=250002)
     >>> encoder = RobertaModel(config=roberta_encoder_conf)
     >>> classifier_head = RobertaClassificationHead(num_classes=2, input_dim=768)
     >>> classifier = RobertaModel(config=roberta_encoder_conf, head=classifier_head)
     """

-    def __init__(self, config: RobertaEncoderConf, head: Optional[Module] = None, freeze_encoder: bool = False):
+    def __init__(self,
+                 encoder_conf: RobertaEncoderConf,
+                 head: Optional[Module] = None,
+                 freeze_encoder: bool = False):
         super().__init__()
-        assert isinstance(config, RobertaEncoderConf)
-
-        self.encoder = RobertaEncoder.from_config(config)
-        if freeze_encoder:
-            for param in self.encoder.parameters():
-                param.requires_grad = False
-
-            logger.info("Encoder weights are frozen")
+        assert isinstance(encoder_conf, RobertaEncoderConf)

+        self.encoder = RobertaEncoder(**asdict(encoder_conf), freeze=freeze_encoder)
         self.head = head

-    def forward(self, tokens: Tensor, mask: Optional[Tensor] = None) -> Tensor:
-        features = self.encoder(tokens, mask)
+    def forward(self, tokens: Tensor, masked_tokens: Optional[Tensor] = None) -> Tensor:
+        features = self.encoder(tokens, masked_tokens)
         if self.head is None:
             return features

         x = self.head(features)
         return x
-
-
-def _get_model(config: RobertaEncoderConf, head: Optional[Module] = None, freeze_encoder: bool = False) -> RobertaModel:
-    return RobertaModel(config, head, freeze_encoder)
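With `_get_model` removed and `RobertaEncoder.from_config` folded into the constructor, the direct construction path looks roughly as follows. A sketch only: the tiny config values are illustrative, and the output shape is inferred from the bundler docstring's (batch, sequence, embedding_dim) convention:

>>> import torch
>>> from torchtext.models import RobertaEncoderConf, RobertaModel
>>> conf = RobertaEncoderConf(vocab_size=10, embedding_dim=16, ffn_dimension=64, num_attention_heads=2, num_encoder_layers=2)
>>> encoder_only = RobertaModel(conf, freeze_encoder=True)  # was _get_model(conf, None, freeze_encoder=True)
>>> tokens = torch.tensor([[0, 5, 6, 2]])
>>> encoder_only(tokens).shape  # no head attached, so raw encoder features are returned
torch.Size([1, 4, 16])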
