@@ -2,30 +2,42 @@
 from unittest.mock import patch

 import torch
-import torchtext
 from torch.nn import functional as torch_F

-from ..common.torchtext_test_case import TorchtextTestCase
+from ..common.case_utils import TestBaseMixin


-class TestModels(TorchtextTestCase):
+class BaseTestModels(TestBaseMixin):
+    def get_model(self, encoder_conf, head=None, freeze_encoder=False, checkpoint=None, override_checkpoint_head=False):
+        from torchtext.models import RobertaBundle
+
+        model = RobertaBundle.build_model(
+            encoder_conf=encoder_conf,
+            head=head,
+            freeze_encoder=freeze_encoder,
+            checkpoint=checkpoint,
+            override_checkpoint_head=override_checkpoint_head,
+        )
+        model.to(device=self.device, dtype=self.dtype)
+        return model
+
     def test_roberta_bundler_build_model(self) -> None:
-        from torchtext.models import RobertaClassificationHead, RobertaEncoderConf, RobertaModel, RobertaBundle
+        from torchtext.models import RobertaClassificationHead, RobertaEncoderConf, RobertaModel

         dummy_encoder_conf = RobertaEncoderConf(
             vocab_size=10, embedding_dim=16, ffn_dimension=64, num_attention_heads=2, num_encoder_layers=2
         )

         # case: user provide encoder checkpoint state dict
         dummy_encoder = RobertaModel(dummy_encoder_conf)
-        model = RobertaBundle.build_model(encoder_conf=dummy_encoder_conf, checkpoint=dummy_encoder.state_dict())
+        model = self.get_model(encoder_conf=dummy_encoder_conf, checkpoint=dummy_encoder.state_dict())
         self.assertEqual(model.state_dict(), dummy_encoder.state_dict())

         # case: user provide classifier checkpoint state dict when head is given and override_head is False (by default)
         dummy_classifier_head = RobertaClassificationHead(num_classes=2, input_dim=16)
         another_dummy_classifier_head = RobertaClassificationHead(num_classes=2, input_dim=16)
         dummy_classifier = RobertaModel(dummy_encoder_conf, dummy_classifier_head)
-        model = RobertaBundle.build_model(
+        model = self.get_model(
             encoder_conf=dummy_encoder_conf,
             head=another_dummy_classifier_head,
             checkpoint=dummy_classifier.state_dict(),
@@ -34,7 +46,7 @@ def test_roberta_bundler_build_model(self) -> None:

         # case: user provide classifier checkpoint state dict when head is given and override_head is set True
         another_dummy_classifier_head = RobertaClassificationHead(num_classes=2, input_dim=16)
-        model = RobertaBundle.build_model(
+        model = self.get_model(
             encoder_conf=dummy_encoder_conf,
             head=another_dummy_classifier_head,
             checkpoint=dummy_classifier.state_dict(),
@@ -48,13 +60,13 @@ def test_roberta_bundler_build_model(self) -> None:
         encoder_state_dict = {}
         for k, v in dummy_classifier.encoder.state_dict().items():
             encoder_state_dict["encoder." + k] = v
-        model = torchtext.models.RobertaBundle.build_model(
+        model = self.get_model(
             encoder_conf=dummy_encoder_conf, head=dummy_classifier_head, checkpoint=encoder_state_dict
         )
         self.assertEqual(model.state_dict(), dummy_classifier.state_dict())

     def test_roberta_bundler_train(self) -> None:
-        from torchtext.models import RobertaClassificationHead, RobertaEncoderConf, RobertaModel, RobertaBundle
+        from torchtext.models import RobertaClassificationHead, RobertaEncoderConf, RobertaModel

         dummy_encoder_conf = RobertaEncoderConf(
             vocab_size=10, embedding_dim=16, ffn_dimension=64, num_attention_heads=2, num_encoder_layers=2
@@ -63,8 +75,8 @@ def test_roberta_bundler_train(self) -> None:

         def _train(model):
             optim = SGD(model.parameters(), lr=1)
-            model_input = torch.tensor([[0, 1, 2, 3, 4, 5]])
-            target = torch.tensor([0])
+            model_input = torch.tensor([[0, 1, 2, 3, 4, 5]]).to(device=self.device)
+            target = torch.tensor([0]).to(device=self.device)
             logits = model(model_input)
             loss = torch_F.cross_entropy(logits, target)
             loss.backward()
@@ -73,7 +85,7 @@ def _train(model):
         # does not freeze encoder
         dummy_classifier_head = RobertaClassificationHead(num_classes=2, input_dim=16)
         dummy_classifier = RobertaModel(dummy_encoder_conf, dummy_classifier_head)
-        model = RobertaBundle.build_model(
+        model = self.get_model(
             encoder_conf=dummy_encoder_conf,
             head=dummy_classifier_head,
             freeze_encoder=False,
@@ -91,7 +103,7 @@ def _train(model):
         # freeze encoder
         dummy_classifier_head = RobertaClassificationHead(num_classes=2, input_dim=16)
         dummy_classifier = RobertaModel(dummy_encoder_conf, dummy_classifier_head)
-        model = RobertaBundle.build_model(
+        model = self.get_model(
             encoder_conf=dummy_encoder_conf,
             head=dummy_classifier_head,
             freeze_encoder=True,
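
Note on the refactor: the new get_model helper moves the built model to self.device and self.dtype, which TestBaseMixin is assumed to provide, so the same test body can be reused across device/precision combinations by thin concrete subclasses. A minimal sketch of such a subclass, with hypothetical file and class names that are not part of this diff and the assumption that TestBaseMixin reads class-level device/dtype attributes:

# hypothetical companion file, e.g. test_models_cpu.py (names are assumptions)
import torch

from ..common.torchtext_test_case import TorchtextTestCase
from .test_models import BaseTestModels


class TestModelsCPUFloat32(BaseTestModels, TorchtextTestCase):
    # Assumed contract: TestBaseMixin exposes these as self.device / self.dtype,
    # which get_model() uses when moving the built RobertaBundle model.
    device = torch.device("cpu")
    dtype = torch.float32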