
Add a class method in Model Bundler to facilitate model creation with user-defined configuration and checkpoint #1442


Merged 10 commits on Nov 22, 2021
37 changes: 36 additions & 1 deletion test/models/test_models.py
@@ -1,6 +1,5 @@
import torchtext
import torch

from ..common.torchtext_test_case import TorchtextTestCase
from ..common.assets import get_asset_path

@@ -91,3 +90,39 @@ def test_xlmr_transform_jit(self):
actual = transform_jit([test_text])
expected = [[0, 43523, 52005, 3647, 13293, 113307, 40514, 2]]
torch.testing.assert_close(actual, expected)

def test_roberta_bundler_from_config(self):
from torchtext.models import RobertaEncoderConf, RobertaClassificationHead, RobertaModel, RobertaModelBundle
dummy_encoder_conf = RobertaEncoderConf(vocab_size=10, embedding_dim=16, ffn_dimension=64, num_attention_heads=2, num_encoder_layers=2)

# case: user provides an encoder checkpoint state dict
dummy_encoder = RobertaModel(dummy_encoder_conf)
model = RobertaModelBundle.from_config(encoder_conf=dummy_encoder_conf,
checkpoint=dummy_encoder.state_dict())
self.assertEqual(model.state_dict(), dummy_encoder.state_dict())

# case: user provides a classifier checkpoint state dict when head is given and override_head is False (the default)
dummy_classifier_head = RobertaClassificationHead(num_classes=2, input_dim=16)
another_dummy_classifier_head = RobertaClassificationHead(num_classes=2, input_dim=16)
dummy_classifier = RobertaModel(dummy_encoder_conf, dummy_classifier_head)
model = RobertaModelBundle.from_config(encoder_conf=dummy_encoder_conf,
head=another_dummy_classifier_head,
checkpoint=dummy_classifier.state_dict())
self.assertEqual(model.state_dict(), dummy_classifier.state_dict())

# case: user provides a classifier checkpoint state dict when head is given and override_head is set to True
another_dummy_classifier_head = RobertaClassificationHead(num_classes=2, input_dim=16)
model = RobertaModelBundle.from_config(encoder_conf=dummy_encoder_conf,
head=another_dummy_classifier_head,
checkpoint=dummy_classifier.state_dict(),
override_head=True)
self.assertEqual(model.head.state_dict(), another_dummy_classifier_head.state_dict())

# case: user provides only the encoder checkpoint state dict when head is given
dummy_classifier_head = RobertaClassificationHead(num_classes=2, input_dim=16)
dummy_classifier = RobertaModel(dummy_encoder_conf, dummy_classifier_head)
encoder_state_dict = {}
for k, v in dummy_classifier.encoder.state_dict().items():
encoder_state_dict['encoder.' + k] = v
model = torchtext.models.RobertaModelBundle.from_config(encoder_conf=dummy_encoder_conf, head=dummy_classifier_head, checkpoint=encoder_state_dict)
self.assertEqual(model.state_dict(), dummy_classifier.state_dict())
83 changes: 69 additions & 14 deletions torchtext/models/roberta/bundler.py
@@ -1,13 +1,13 @@

from dataclasses import dataclass
from functools import partial
from urllib.parse import urljoin

from typing import Optional, Callable
from typing import Optional, Callable, Dict, Union, Any
from torchtext._download_hooks import load_state_dict_from_url
from torch.nn import Module
import torch
import logging

import re
logger = logging.getLogger(__name__)

from .model import (
@@ -21,6 +21,11 @@
from torchtext import _TEXT_BUCKET


def _is_head_available_in_checkpoint(checkpoint, head_state_dict):
# ensure all keys are present
return all(key in checkpoint.keys() for key in head_state_dict.keys())


@dataclass
class RobertaModelBundle:
"""RobertaModelBundle(_params: torchtext.models.RobertaEncoderParams, _path: Optional[str] = None, _head: Optional[torch.nn.Module] = None, transform: Optional[Callable] = None)
@@ -44,7 +49,7 @@ class RobertaModelBundle:
Example - Pretrained encoder attached to un-initialized classification head
>>> import torch, torchtext
>>> xlmr_large = torchtext.models.XLMR_LARGE_ENCODER
>>> classifier_head = torchtext.models.RobertaClassificationHead(num_classes=2, input_dim = xlmr_large.params.embedding_dim)
>>> classifier_head = torchtext.models.RobertaClassificationHead(num_classes=2, input_dim = xlmr_large.encoderConf.embedding_dim)
>>> classification_model = xlmr_large.get_model(head=classifier_head)
>>> transform = xlmr_large.transform()
>>> model_input = torch.tensor(transform(["Hello World"]))
@@ -60,14 +65,28 @@ class RobertaModelBundle:
>>> encoder = roberta_bundle.get_model()
>>> classifier_head = RobertaClassificationHead(num_classes=2, input_dim=768)
>>> classifier = roberta_bundle.get_model(head=classifier_head)
>>> # using from_config
>>> encoder = RobertaModelBundle.from_config(encoder_conf=roberta_encoder_conf, checkpoint=model_weights_path)
>>> classifier = RobertaModelBundle.from_config(encoder_conf=roberta_encoder_conf, head=classifier_head, checkpoint=model_weights_path)
"""
_encoder_conf: RobertaEncoderConf
_path: Optional[str] = None
_head: Optional[Module] = None
transform: Optional[Callable] = None

def get_model(self, head: Optional[Module] = None, load_weights: bool = True, freeze_encoder: bool = False, *, dl_kwargs=None) -> RobertaModel:
def get_model(self,
head: Optional[Module] = None,
load_weights: bool = True,
freeze_encoder: bool = False,
*,
dl_kwargs=None) -> RobertaModel:
r"""get_model(head: Optional[torch.nn.Module] = None, load_weights: bool = True, freeze_encoder: bool = False, *, dl_kwargs=None) -> torctext.models.RobertaModel

Args:
head (nn.Module): A module to be attached to the encoder to perform a specific task. If provided, it will replace the default member head. (Default: ``None``)
load_weights (bool): Indicates whether or not to load weights if available. (Default: ``True``)
freeze_encoder (bool): Indicates whether or not to freeze the encoder weights. (Default: ``False``)
dl_kwargs (dictionary of keyword arguments): Passed to :func:`torch.hub.load_state_dict_from_url`. (Default: ``None``)
"""

if load_weights:
@@ -84,17 +103,53 @@ def get_model(self, head: Optional[Module] = None, load_weights: bool = True, fr
else:
input_head = self._head

model = _get_model(self._encoder_conf, input_head, freeze_encoder)

if not load_weights:
return model
return RobertaModelBundle.from_config(encoder_conf=self._encoder_conf,
head=input_head,
freeze_encoder=freeze_encoder,
checkpoint=self._path,
override_head=True,
dl_kwargs=dl_kwargs)

@classmethod
def from_config(
cls,
encoder_conf: RobertaEncoderConf,
head: Optional[Module] = None,
freeze_encoder: bool = False,
checkpoint: Optional[Union[str, Dict[str, torch.Tensor]]] = None,
*,
override_head: bool = False,
dl_kwargs: Dict[str, Any] = None,
) -> RobertaModel:
"""Class method to create model with user-defined encoder configuration and checkpoint

Args:
encoder_conf (RobertaEncoderConf): An instance of RobertaEncoderConf that defines the encoder configuration
head (nn.Module): A module to be attached to the encoder to perform a specific task. (Default: ``None``)
freeze_encoder (bool): Indicates whether to freeze the encoder weights. (Default: ``False``)
checkpoint (str or Dict[str, torch.Tensor]): Path to a model checkpoint, or the model state_dict itself. The state_dict may contain partial weights, i.e. only the encoder weights. (Default: ``None``)
override_head (bool): Override the checkpoint's head state dict (if present) with the provided head state dict. (Default: ``False``)
dl_kwargs (dictionary of keyword arguments): Passed to :func:`torch.hub.load_state_dict_from_url`. (Default: ``None``)
"""
model = _get_model(encoder_conf, head, freeze_encoder)
if checkpoint is not None:
if torch.jit.isinstance(checkpoint, Dict[str, torch.Tensor]):
state_dict = checkpoint
elif isinstance(checkpoint, str):
dl_kwargs = {} if dl_kwargs is None else dl_kwargs
state_dict = load_state_dict_from_url(checkpoint, **dl_kwargs)
else:
raise TypeError("checkpoint must be of type `str` or `Dict[str, torch.Tensor]` but got {}".format(type(checkpoint)))

if head is not None:
regex = re.compile(r"^head\.")
head_state_dict = {k: v for k, v in model.state_dict().items() if regex.findall(k)}
# If the checkpoint does not contain head_state_dict, augment the checkpoint with the user-provided head state_dict
if not _is_head_available_in_checkpoint(state_dict, head_state_dict) or override_head:
state_dict.update(head_state_dict)
Comment from the contributor (PR author):

@mthrok I have updated the logic for handling whether or not the checkpoint contains the head parameters.

Case 1: the checkpoint does not contain head parameters. In this case, we update the state dict with the supplied head state dict, i.e. we keep the user-provided head module intact. This is in line with what we discussed here: #1424 (comment)

Case 2: the checkpoint does contain head parameters. In this case we use the checkpoint as-is.

Note that in both cases we now load with strict=True to ensure all keys match. Let me know if this sounds reasonable and covers the various edge cases.
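
A minimal sketch of the two cases, reusing the small dummy configuration from the tests above (illustrative only, not part of the diff):

from torchtext.models import RobertaEncoderConf, RobertaClassificationHead, RobertaModel, RobertaModelBundle

conf = RobertaEncoderConf(vocab_size=10, embedding_dim=16, ffn_dimension=64, num_attention_heads=2, num_encoder_layers=2)
head = RobertaClassificationHead(num_classes=2, input_dim=16)

# Case 1: encoder-only checkpoint. The user-provided head is kept intact and its
# state dict is merged into the checkpoint before the strict load.
encoder_only_ckpt = {"encoder." + k: v for k, v in RobertaModel(conf).encoder.state_dict().items()}
model = RobertaModelBundle.from_config(encoder_conf=conf, head=head, checkpoint=encoder_only_ckpt)

# Case 2: full checkpoint (encoder + head). The checkpoint is used as-is unless
# override_head=True, in which case the weights of the head passed to from_config take precedence.
full_ckpt = RobertaModel(conf, head).state_dict()
another_head = RobertaClassificationHead(num_classes=2, input_dim=16)
model = RobertaModelBundle.from_config(encoder_conf=conf, head=another_head, checkpoint=full_ckpt, override_head=True)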


dl_kwargs = {} if dl_kwargs is None else dl_kwargs
state_dict = load_state_dict_from_url(self._path, **dl_kwargs)
if input_head is not None:
model.load_state_dict(state_dict, strict=False)
else:
model.load_state_dict(state_dict, strict=True)

return model

@property