from .model import (
    RobertaEncoderConf,
    RobertaModel,
-    _get_model,
)

from .transforms import get_xlmr_transform
@@ -30,56 +29,50 @@ def _is_head_available_in_checkpoint(checkpoint, head_state_dict):
class RobertaModelBundle:
    """RobertaModelBundle(_encoder_conf: torchtext.models.RobertaEncoderConf, _path: Optional[str] = None, _head: Optional[torch.nn.Module] = None, transform: Optional[Callable] = None)

-    Example - Pretrained encoder
+    Example - Pretrained base xlmr encoder
        >>> import torch, torchtext
+        >>> from torchtext.functional import to_tensor
        >>> xlmr_base = torchtext.models.XLMR_BASE_ENCODER
        >>> model = xlmr_base.get_model()
        >>> transform = xlmr_base.transform()
-        >>> model_input = torch.tensor(transform(["Hello World"]))
-        >>> output = model(model_input)
-        >>> output.shape
-        torch.Size([1, 4, 768])
        >>> input_batch = ["Hello world", "How are you!"]
-        >>> from torchtext.functional import to_tensor
        >>> model_input = to_tensor(transform(input_batch), padding_value=transform.pad_idx)
        >>> output = model(model_input)
        >>> output.shape
        torch.Size([2, 6, 768])

-    Example - Pretrained encoder attached to un-initialized classification head
+    Example - Pretrained large xlmr encoder attached to un-initialized classification head
        >>> import torch, torchtext
+        >>> from torchtext.models import RobertaClassificationHead
+        >>> from torchtext.functional import to_tensor
        >>> xlmr_large = torchtext.models.XLMR_LARGE_ENCODER
-        >>> classifier_head = torchtext.models.RobertaClassificationHead(num_classes=2, input_dim=xlmr_large.encoderConf.embedding_dim)
-        >>> classification_model = xlmr_large.get_model(head=classifier_head)
+        >>> classifier_head = torchtext.models.RobertaClassificationHead(num_classes=2, input_dim=1024)
+        >>> model = xlmr_large.get_model(head=classifier_head)
        >>> transform = xlmr_large.transform()
-        >>> model_input = torch.tensor(transform(["Hello World"]))
-        >>> output = classification_model(model_input)
+        >>> input_batch = ["Hello world", "How are you!"]
+        >>> model_input = to_tensor(transform(input_batch), padding_value=transform.pad_idx)
+        >>> output = model(model_input)
        >>> output.shape
        torch.Size([2, 2])

    Example - User-specified configuration and checkpoint
        >>> from torchtext.models import RobertaEncoderConf, RobertaModelBundle, RobertaClassificationHead
        >>> model_weights_path = "https://download.pytorch.org/models/text/xlmr.base.encoder.pt"
-        >>> roberta_encoder_conf = RobertaEncoderConf(vocab_size=250002)
-        >>> roberta_bundle = RobertaModelBundle(_encoder_conf=roberta_encoder_conf, _path=model_weights_path)
-        >>> encoder = roberta_bundle.get_model()
+        >>> encoder_conf = RobertaEncoderConf(vocab_size=250002)
        >>> classifier_head = RobertaClassificationHead(num_classes=2, input_dim=768)
-        >>> classifier = roberta_bundle.get_model(head=classifier_head)
-        >>> # using from_config
-        >>> encoder = RobertaModelBundle.from_config(config=roberta_encoder_conf, checkpoint=model_weights_path)
-        >>> classifier = RobertaModelBundle.from_config(config=roberta_encoder_conf, head=classifier_head, checkpoint=model_weights_path)
+        >>> model = RobertaModelBundle.build_model(encoder_conf=encoder_conf, head=classifier_head, checkpoint=model_weights_path)
    """
    _encoder_conf: RobertaEncoderConf
    _path: Optional[str] = None
    _head: Optional[Module] = None
    transform: Optional[Callable] = None

    def get_model(self,
+                  *,
                  head: Optional[Module] = None,
                  load_weights: bool = True,
                  freeze_encoder: bool = False,
-                  *,
-                  dl_kwargs=None) -> RobertaModel:
+                  dl_kwargs: Dict[str, Any] = None) -> RobertaModel:
        r"""get_model(head: Optional[torch.nn.Module] = None, load_weights: bool = True, freeze_encoder: bool = False, *, dl_kwargs=None) -> torchtext.models.RobertaModel

        Args:
@@ -103,35 +96,38 @@ def get_model(self,
        else:
            input_head = self._head

-        return RobertaModelBundle.from_config(encoder_conf=self._encoder_conf,
+        return RobertaModelBundle.build_model(encoder_conf=self._encoder_conf,
                                              head=input_head,
                                              freeze_encoder=freeze_encoder,
-                                              checkpoint=self._path,
-                                              override_head=True,
+                                              checkpoint=self._path if load_weights else None,
+                                              override_checkpoint_head=True,
+                                              strict=True,
                                              dl_kwargs=dl_kwargs)

    @classmethod
-    def from_config(
+    def build_model(
        cls,
        encoder_conf: RobertaEncoderConf,
+        *,
        head: Optional[Module] = None,
        freeze_encoder: bool = False,
        checkpoint: Optional[Union[str, Dict[str, torch.Tensor]]] = None,
-        *,
-        override_head: bool = False,
+        override_checkpoint_head: bool = False,
+        strict=True,
        dl_kwargs: Dict[str, Any] = None,
    ) -> RobertaModel:
-        """Class method to create model with user-defined encoder configuration and checkpoint
+        """Class method to build a model from a user-defined encoder configuration, an optional head, and an optional checkpoint

        Args:
            encoder_conf (RobertaEncoderConf): An instance of class RobertaEncoderConf that defines the encoder configuration
            head (nn.Module): A module to be attached to the encoder to perform a specific task. (Default: ``None``)
            freeze_encoder (bool): Indicates whether to freeze the encoder weights. (Default: ``False``)
            checkpoint (str or Dict[str, torch.Tensor]): Path to or actual model state_dict. state_dict can have partial weights, i.e. only for the encoder. (Default: ``None``)
-            override_head (bool): Override the checkpoint's head state dict (if present) with provided head state dict. (Default: ``False``)
+            override_checkpoint_head (bool): Override the checkpoint's head state dict (if present) with the provided head state dict. (Default: ``False``)
+            strict (bool): Passed to the :func:`torch.nn.Module.load_state_dict` method. (Default: ``True``)
            dl_kwargs (dictionary of keyword arguments): Passed to :func:`torch.hub.load_state_dict_from_url`. (Default: ``None``)
        """
-        model = _get_model(encoder_conf, head, freeze_encoder)
+        model = RobertaModel(encoder_conf, head, freeze_encoder)
        if checkpoint is not None:
            if torch.jit.isinstance(checkpoint, Dict[str, torch.Tensor]):
                state_dict = checkpoint
@@ -145,10 +141,10 @@ def from_config(
                regex = re.compile(r"^head\.")
                head_state_dict = {k: v for k, v in model.state_dict().items() if regex.findall(k)}
                # If the checkpoint does not contain head_state_dict, then we augment the checkpoint with the user-provided head state_dict
-                if not _is_head_available_in_checkpoint(state_dict, head_state_dict) or override_head:
+                if not _is_head_available_in_checkpoint(state_dict, head_state_dict) or override_checkpoint_head:
                    state_dict.update(head_state_dict)

-            model.load_state_dict(state_dict, strict=True)
+            model.load_state_dict(state_dict, strict=strict)

        return model

@@ -168,7 +164,7 @@ def encoderConf(self) -> RobertaEncoderConf:

XLMR_BASE_ENCODER.__doc__ = (
    '''
-    XLM-R Encoder with base configuration
+    XLM-R Encoder with Base configuration

    Please refer to :func:`torchtext.models.RobertaModelBundle` for the usage.
    '''
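
Taken together, this diff renames from_config to build_model, makes the optional arguments of get_model and build_model keyword-only, and threads load_weights, override_checkpoint_head, and strict through to load_state_dict. The snippet below is not part of the change set; it is a rough usage sketch written only from the signatures and docstring examples shown above. The checkpoint URL, vocab size, and head dimensions are the ones already appearing in the docstrings; other values (e.g. strict=False) are illustrative assumptions.

# Usage sketch (not part of the diff): exercising the updated bundler API.
import torchtext
from torchtext.functional import to_tensor
from torchtext.models import RobertaClassificationHead, RobertaEncoderConf, RobertaModelBundle

# get_model() now takes its options as keyword-only arguments; with
# load_weights=False the bundle passes checkpoint=None to build_model(),
# so no state dict is downloaded.
xlmr_base = torchtext.models.XLMR_BASE_ENCODER
randomly_initialized_encoder = xlmr_base.get_model(load_weights=False)

# build_model() replaces from_config(). override_checkpoint_head=True keeps the
# freshly initialized head weights even if the checkpoint also contains a head,
# and strict is forwarded to torch.nn.Module.load_state_dict().
encoder_conf = RobertaEncoderConf(vocab_size=250002)
classifier_head = RobertaClassificationHead(num_classes=2, input_dim=768)
model = RobertaModelBundle.build_model(
    encoder_conf=encoder_conf,
    head=classifier_head,
    checkpoint="https://download.pytorch.org/models/text/xlmr.base.encoder.pt",
    override_checkpoint_head=True,
    strict=False,  # illustrative; strict loading also works once the head is merged in
)

transform = xlmr_base.transform()
model_input = to_tensor(transform(["Hello world", "How are you!"]), padding_value=transform.pad_idx)
print(model(model_input).shape)  # expected: torch.Size([2, 2])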