@@ -117,7 +117,7 @@ def forward(self, x):
######################################################################
- # The training process uses Wikitext-2 dataset from ``torchtext``. The
+ # This tutorial uses ``torchtext`` to generate the Wikitext-2 dataset. The
# vocab object is built based on the train dataset and is used to numericalize
# tokens into tensors. Starting from sequential data, the ``batchify()``
# function arranges the dataset into columns, trimming off any tokens remaining
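To make the column layout concrete, a toy sketch of what ``batchify()`` does (illustrative only, not part of the change):

    import torch

    # 26 toy token ids split into bsz = 4 columns; the two leftover
    # tokens (24, 25) are trimmed, mirroring batchify() below.
    toy = torch.arange(26)
    bsz = 4
    nbatch = toy.size(0) // bsz            # 6 full rows per column
    toy = toy.narrow(0, 0, nbatch * bsz)   # keep the first 24 tokens
    print(toy.view(bsz, -1).t().contiguous())
    # tensor([[ 0,  6, 12, 18],
    #         [ 1,  7, 13, 19],
    #         ...
    #         [ 5, 11, 17, 23]])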
@@ -143,18 +143,31 @@ def forward(self, x):
# efficient batch processing.
#

- import torchtext
+ import io
+ import torch
+ from torchtext.utils import download_from_url, extract_archive
from torchtext.data.utils import get_tokenizer
- TEXT = torchtext.data.Field(tokenize=get_tokenizer("basic_english"),
-                             init_token='<sos>',
-                             eos_token='<eos>',
-                             lower=True)
- train_txt, val_txt, test_txt = torchtext.datasets.WikiText2.splits(TEXT)
- TEXT.build_vocab(train_txt)
+ from torchtext.vocab import build_vocab_from_iterator
+
+ url = 'https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-v1.zip'
+ test_filepath, valid_filepath, train_filepath = extract_archive(download_from_url(url))
+ tokenizer = get_tokenizer('basic_english')
+ vocab = build_vocab_from_iterator(map(tokenizer,
+                                       iter(io.open(train_filepath,
+                                                    encoding="utf8"))))
+
+ def data_process(raw_text_iter):
+     data = [torch.tensor([vocab[token] for token in tokenizer(item)],
+                          dtype=torch.long) for item in raw_text_iter]
+     return torch.cat(tuple(filter(lambda t: t.numel() > 0, data)))
+
+ train_data = data_process(iter(io.open(train_filepath, encoding="utf8")))
+ val_data = data_process(iter(io.open(valid_filepath, encoding="utf8")))
+ test_data = data_process(iter(io.open(test_filepath, encoding="utf8")))
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def batchify(data, bsz):
-     data = TEXT.numericalize([data.examples[0].text])
    # Divide the dataset into bsz parts.
    nbatch = data.size(0) // bsz
    # Trim off any extra elements that wouldn't cleanly fit (remainders).
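For reviewers, a quick sanity check of the new pipeline using the names defined above (hypothetical snippet; the size is a rough Wikitext-2 figure, not verified here):

    print(len(vocab.stoi))                     # vocab size, roughly 28k entries
    print(train_data.dtype, train_data.dim())  # torch.int64 1 -- a flat stream of token ids
    print(train_data[:10])                     # first few numericalized tokens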
@@ -165,9 +178,9 @@ def batchify(data, bsz):
batch_size = 20
eval_batch_size = 10
- train_data = batchify(train_txt, batch_size)
- val_data = batchify(val_txt, eval_batch_size)
- test_data = batchify(test_txt, eval_batch_size)
+ train_data = batchify(train_data, batch_size)
+ val_data = batchify(val_data, eval_batch_size)
+ test_data = batchify(test_data, eval_batch_size)


######################################################################
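After this hunk each split is a 2-D ``[nbatch, bsz]`` tensor rather than a ``Field``-backed dataset; a hedged check (row counts depend on the corpus):

    print(train_data.shape)  # about torch.Size([104000, 20]): sequence dim x batch dim
    print(val_data.shape)    # second dim is eval_batch_size = 10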
@@ -209,7 +222,7 @@ def get_batch(source, i):
# equal to the length of the vocab object.
#

- ntokens = len(TEXT.vocab.stoi) # the size of vocabulary
+ ntokens = len(vocab.stoi) # the size of vocabulary
emsize = 200 # embedding dimension
nhid = 200 # the dimension of the feedforward network model in nn.TransformerEncoder
nlayers = 2 # the number of nn.TransformerEncoderLayer in nn.TransformerEncoder
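``ntokens`` now comes straight from the ``Vocab`` built earlier. Assuming the legacy torchtext ``Vocab`` API (``stoi``/``itos``), a small illustration:

    print(ntokens)                        # size of the vocabulary
    print(vocab.stoi['the'])              # token -> integer id
    print(vocab.itos[vocab.stoi['the']])  # id -> token round trip: 'the'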
@@ -246,7 +259,6 @@ def train():
    model.train() # Turn on the train mode
    total_loss = 0.
    start_time = time.time()
-     ntokens = len(TEXT.vocab.stoi)
    src_mask = model.generate_square_subsequent_mask(bptt).to(device)
    for batch, i in enumerate(range(0, train_data.size(0) - 1, bptt)):
        data, targets = get_batch(train_data, i)
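The hoisted ``src_mask`` is a square subsequent (causal) mask. A minimal sketch of its shape, not the tutorial's exact helper:

    import torch

    # -inf above the diagonal blocks attention to future positions;
    # 0 elsewhere leaves attention scores unchanged.
    sz = 4
    print(torch.triu(torch.full((sz, sz), float('-inf')), diagonal=1))
    # tensor([[0., -inf, -inf, -inf],
    #         [0., 0., -inf, -inf],
    #         [0., 0., 0., -inf],
    #         [0., 0., 0., 0.]])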
@@ -276,7 +288,6 @@ def train():
def evaluate(eval_model, data_source):
    eval_model.eval() # Turn on the evaluation mode
    total_loss = 0.
-     ntokens = len(TEXT.vocab.stoi)
    src_mask = model.generate_square_subsequent_mask(bptt).to(device)
    with torch.no_grad():
        for i in range(0, data_source.size(0) - 1, bptt):
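``evaluate()`` mirrors ``train()``: the local ``ntokens`` recomputation is dropped and the mask is built once. For reference, evaluation loops like this one usually report perplexity, the exponential of the average per-token cross-entropy (general convention, not shown in this diff):

    import math

    avg_loss = 5.0             # hypothetical mean cross-entropy on a split
    print(math.exp(avg_loss))  # ~148.4 perplexity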