
Commit 3128e13

Authored by zhangguanheng66 (Guanheng Zhang) and a co-author
Remove legacy torchtext code from Transformer tutorial (#1251)
* checkpoint
* remove dataloader
* checkpoint
* Fix ascii decode error

Co-authored-by: Guanheng Zhang <[email protected]>
1 parent 8777ee2 commit 3128e13

1 file changed: +26 additions, -15 deletions

beginner_source/transformer_tutorial.py

Lines changed: 26 additions & 15 deletions
@@ -117,7 +117,7 @@ def forward(self, x):


 ######################################################################
-# The training process uses Wikitext-2 dataset from ``torchtext``. The
+# This tutorial uses ``torchtext`` to generate Wikitext-2 dataset. The
 # vocab object is built based on the train dataset and is used to numericalize
 # tokens into tensors. Starting from sequential data, the ``batchify()``
 # function arranges the dataset into columns, trimming off any tokens remaining
@@ -143,18 +143,31 @@ def forward(self, x):
 # efficient batch processing.
 #

-import torchtext
+import io
+import torch
+from torchtext.utils import download_from_url, extract_archive
 from torchtext.data.utils import get_tokenizer
-TEXT = torchtext.data.Field(tokenize=get_tokenizer("basic_english"),
-                            init_token='<sos>',
-                            eos_token='<eos>',
-                            lower=True)
-train_txt, val_txt, test_txt = torchtext.datasets.WikiText2.splits(TEXT)
-TEXT.build_vocab(train_txt)
+from torchtext.vocab import build_vocab_from_iterator
+
+url = 'https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-v1.zip'
+test_filepath, valid_filepath, train_filepath = extract_archive(download_from_url(url))
+tokenizer = get_tokenizer('basic_english')
+vocab = build_vocab_from_iterator(map(tokenizer,
+                                      iter(io.open(train_filepath,
+                                                   encoding="utf8"))))
+
+def data_process(raw_text_iter):
+  data = [torch.tensor([vocab[token] for token in tokenizer(item)],
+                       dtype=torch.long) for item in raw_text_iter]
+  return torch.cat(tuple(filter(lambda t: t.numel() > 0, data)))
+
+train_data = data_process(iter(io.open(train_filepath, encoding="utf8")))
+val_data = data_process(iter(io.open(valid_filepath, encoding="utf8")))
+test_data = data_process(iter(io.open(test_filepath, encoding="utf8")))
+
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

 def batchify(data, bsz):
-    data = TEXT.numericalize([data.examples[0].text])
     # Divide the dataset into bsz parts.
     nbatch = data.size(0) // bsz
     # Trim off any extra elements that wouldn't cleanly fit (remainders).
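For orientation, the heart of the new pipeline is `data_process()`: tokenize each line, look the tokens up in the vocab, drop empty lines, and concatenate everything into one flat tensor of token ids. A minimal sketch with a toy corpus and a plain-dict vocab (the `toy_tokenizer`, `toy_vocab`, and sentences below are illustrative stand-ins, not the objects built by the tutorial):

```python
import torch

# Illustrative stand-ins for the tutorial's tokenizer and vocab (assumptions,
# not the torchtext objects built above).
toy_tokenizer = lambda line: line.lower().split()
toy_vocab = {"the": 0, "cat": 1, "sat": 2, "on": 3, "mat": 4}

def toy_data_process(raw_text_iter):
    # Same shape of logic as the tutorial's data_process(): tokenize each line,
    # map tokens to indices, drop empty lines, and concatenate into one 1-D tensor.
    data = [torch.tensor([toy_vocab[token] for token in toy_tokenizer(line)],
                         dtype=torch.long)
            for line in raw_text_iter]
    return torch.cat(tuple(filter(lambda t: t.numel() > 0, data)))

corpus = ["the cat sat", "", "on the mat"]
flat = toy_data_process(iter(corpus))
print(flat)         # tensor([0, 1, 2, 3, 0, 4])
print(flat.shape)   # torch.Size([6])
```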
@@ -165,9 +178,9 @@ def batchify(data, bsz):

 batch_size = 20
 eval_batch_size = 10
-train_data = batchify(train_txt, batch_size)
-val_data = batchify(val_txt, eval_batch_size)
-test_data = batchify(test_txt, eval_batch_size)
+train_data = batchify(train_data, batch_size)
+val_data = batchify(val_data, eval_batch_size)
+test_data = batchify(test_data, eval_batch_size)


 ######################################################################
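A quick worked example of what `batchify()` does to the flat token stream, as described in the comments above; the helper below is a sketch that mirrors the tutorial's narrow/view/transpose steps (minus the final `.to(device)`), not the tutorial function itself:

```python
import torch

def batchify_sketch(data, bsz):
    # Split the 1-D token stream into bsz columns, trimming the remainder
    # that does not divide evenly into bsz parts.
    nbatch = data.size(0) // bsz
    data = data.narrow(0, 0, nbatch * bsz)
    return data.view(bsz, -1).t().contiguous()

stream = torch.arange(26)            # pretend token ids 0..25
batched = batchify_sketch(stream, bsz=4)
print(batched.shape)                 # torch.Size([6, 4]); the last 2 tokens are trimmed
print(batched[:, 0])                 # tensor([0, 1, 2, 3, 4, 5]): each column is contiguous text
```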
@@ -209,7 +222,7 @@ def get_batch(source, i):
 # equal to the length of the vocab object.
 #

-ntokens = len(TEXT.vocab.stoi) # the size of vocabulary
+ntokens = len(vocab.stoi) # the size of vocabulary
 emsize = 200 # embedding dimension
 nhid = 200 # the dimension of the feedforward network model in nn.TransformerEncoder
 nlayers = 2 # the number of nn.TransformerEncoderLayer in nn.TransformerEncoder
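`ntokens` has to match the vocabulary size because it sizes the model's embedding table (and the final linear decoder); a minimal sketch of that constraint, with `ntokens_demo` standing in for `len(vocab.stoi)`:

```python
import torch
import torch.nn as nn

ntokens_demo = 1000   # stand-in for len(vocab.stoi)
emsize_demo = 200     # embedding dimension, as in the tutorial

# Every token id produced by the vocab must be a valid row index here,
# so num_embeddings has to be at least the vocabulary size.
embedding = nn.Embedding(ntokens_demo, emsize_demo)
ids = torch.tensor([[3, 17, 999]])   # (batch, seq) of token ids
print(embedding(ids).shape)          # torch.Size([1, 3, 200])
```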
@@ -246,7 +259,6 @@ def train():
     model.train() # Turn on the train mode
     total_loss = 0.
     start_time = time.time()
-    ntokens = len(TEXT.vocab.stoi)
     src_mask = model.generate_square_subsequent_mask(bptt).to(device)
     for batch, i in enumerate(range(0, train_data.size(0) - 1, bptt)):
         data, targets = get_batch(train_data, i)
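The `src_mask` used in this loop is a causal (square subsequent) attention mask. Below is a sketch of an equivalent construction with `torch.triu`; the tutorial's `generate_square_subsequent_mask` should yield the same 0.0 / -inf pattern, though this is not its exact implementation:

```python
import torch

def square_subsequent_mask(sz):
    # Positions a query is NOT allowed to attend to (the future) get -inf;
    # allowed positions (present and past) get 0.0.
    return torch.triu(torch.full((sz, sz), float('-inf')), diagonal=1)

print(square_subsequent_mask(4))
# tensor([[0., -inf, -inf, -inf],
#         [0., 0., -inf, -inf],
#         [0., 0., 0., -inf],
#         [0., 0., 0., 0.]])
```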
@@ -276,7 +288,6 @@ def train():
 def evaluate(eval_model, data_source):
     eval_model.eval() # Turn on the evaluation mode
     total_loss = 0.
-    ntokens = len(TEXT.vocab.stoi)
     src_mask = model.generate_square_subsequent_mask(bptt).to(device)
     with torch.no_grad():
         for i in range(0, data_source.size(0) - 1, bptt):
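Both the training and evaluation loops walk the batchified data in `bptt`-sized chunks along the sequence dimension. The helper below is an illustrative reconstruction of that slicing, assuming the usual language-modeling setup where the target is the input shifted by one token (it is not quoted from this diff):

```python
import torch

bptt_demo = 5
source = torch.arange(23).reshape(-1, 1)   # pretend batchified data: (seq_len=23, bsz=1)

def get_batch_sketch(source, i, bptt):
    # The input is a bptt-long slice; the target is the same slice shifted by
    # one token and flattened to line up with the model's per-token outputs.
    seq_len = min(bptt, len(source) - 1 - i)
    data = source[i:i + seq_len]
    target = source[i + 1:i + 1 + seq_len].reshape(-1)
    return data, target

for i in range(0, source.size(0) - 1, bptt_demo):
    data, target = get_batch_sketch(source, i, bptt_demo)
    print(i, data.shape, target.shape)
# 0 torch.Size([5, 1]) torch.Size([5])
# 5 torch.Size([5, 1]) torch.Size([5])
# 10 torch.Size([5, 1]) torch.Size([5])
# 15 torch.Size([5, 1]) torch.Size([5])
# 20 torch.Size([2, 1]) torch.Size([2])   <- final short chunk
```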
