@@ -53,7 +53,6 @@ def __init__(self, ntoken, ninp, nhead, nhid, nlayers, dropout=0.5):
         super(TransformerModel, self).__init__()
         from torch.nn import TransformerEncoder, TransformerEncoderLayer
         self.model_type = 'Transformer'
-        self.src_mask = None
         self.pos_encoder = PositionalEncoding(ninp, dropout)
         encoder_layers = TransformerEncoderLayer(ninp, nhead, nhid, dropout)
         self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
@@ -63,7 +62,7 @@ def __init__(self, ntoken, ninp, nhead, nhid, nlayers, dropout=0.5):
 
         self.init_weights()
 
-    def _generate_square_subsequent_mask(self, sz):
+    def generate_square_subsequent_mask(self, sz):
         mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
         mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
         return mask
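Note on the hunk above: making the mask helper public means the training code can now call it directly instead of relying on the model's internal caching. For reference, the method builds an additive attention mask in which each position may attend to itself and to earlier positions (0.0) but not to later ones (-inf). A minimal standalone sketch of the same construction, for sz = 3:

import torch

# Same mask construction as in the diff, illustrated for a sequence length of 3.
sz = 3
mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
print(mask)
# tensor([[0., -inf, -inf],
#         [0., 0., -inf],
#         [0., 0., 0.]])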
@@ -74,18 +73,12 @@ def init_weights(self):
         self.decoder.bias.data.zero_()
         self.decoder.weight.data.uniform_(-initrange, initrange)
 
-    def forward(self, src):
-        if self.src_mask is None or self.src_mask.size(0) != len(src):
-            device = src.device
-            mask = self._generate_square_subsequent_mask(len(src)).to(device)
-            self.src_mask = mask
-
+    def forward(self, src, src_mask):
         src = self.encoder(src) * math.sqrt(self.ninp)
         src = self.pos_encoder(src)
-        output = self.transformer_encoder(src, self.src_mask)
+        output = self.transformer_encoder(src, src_mask)
         output = self.decoder(output)
-        return F.log_softmax(output, dim=-1)
-
+        return output
 
 ######################################################################
 # ``PositionalEncoding`` module injects some information about the
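Note on the hunk above: ``forward`` no longer caches the mask or applies ``log_softmax``, so the caller builds the mask once, passes it in, and pairs the raw logits with a loss that handles the softmax itself (e.g. ``nn.CrossEntropyLoss``). A minimal sketch of the new call pattern, with ``bptt``, ``data``, ``targets``, ``criterion`` and ``ntokens`` assumed from the rest of the tutorial rather than this hunk:

# Build the causal mask once for the usual sequence length and reuse it.
src_mask = model.generate_square_subsequent_mask(bptt).to(device)
output = model(data, src_mask)                       # raw logits, shape [seq_len, batch, ntokens]
loss = criterion(output.view(-1, ntokens), targets)  # e.g. nn.CrossEntropyLoss on logits

Because the model no longer resizes the mask internally, a final batch shorter than ``bptt`` needs the mask sliced down to the actual sequence length before the call.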
@@ -113,7 +106,6 @@ def forward(self, x):
         x = x + self.pe[:x.size(0), :]
         return self.dropout(x)
 
-
 ######################################################################
 # Load data
 # ---------
@@ -200,15 +192,14 @@ def get_batch(batch_data):
 # equal to the length of the vocab object.
 #
 
-ntokens = len(TEXT.vocab.stoi) # the size of vocabulary
+ntokens = len(vocab.stoi) # the size of vocabulary
 emsize = 200 # embedding dimension
 nhid = 200 # the dimension of the feedforward network model in nn.TransformerEncoder
 nlayers = 2 # the number of nn.TransformerEncoderLayer in nn.TransformerEncoder
-nhead = 2 # the number of heads in the multiheadattention models
+nhead = 2 # the number of heads in the multiheadattention models
 dropout = 0.2 # the dropout value
 model = TransformerModel(ntokens, emsize, nhead, nhid, nlayers, dropout).to(device)
 
-
 ######################################################################
 # Run the model
 # -------------
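Note on the hunk above: the switch from ``TEXT.vocab.stoi`` to ``vocab.stoi`` reflects dropping the legacy torchtext ``Field`` object in favour of a plain ``Vocab`` built from the raw dataset. A rough sketch of how such a ``vocab`` is typically constructed with the torchtext 0.9-style pipeline (assumed setup, not part of this diff):

from collections import Counter
from torchtext.datasets import WikiText2
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import Vocab

# Assumed: count tokens over the WikiText2 training split and build a Vocab,
# so that vocab.stoi (string-to-index mapping) exists as used above.
train_iter = WikiText2(split='train')
tokenizer = get_tokenizer('basic_english')
counter = Counter()
for line in train_iter:
    counter.update(tokenizer(line))
vocab = Vocab(counter)
ntokens = len(vocab.stoi)  # the size of vocabulary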