Skip to content

Commit 125684c

Browse files
datumboxfacebook-github-bot
authored andcommitted
switch to_ivalue to __prepare_scriptable__ (#1080)
Reviewed By: zhangguanheng66 Differential Revision: D26368995 fbshipit-source-id: 0352c04e422c835350bd42df35d4054d543fee36
1 parent ce0cb15 commit 125684c

14 files changed

+70
-93
lines changed

benchmark/benchmark_basic_english_normalize.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ def _run_benchmark_lookup(train, tokenizer):
1515

1616
existing_basic_english_tokenizer = get_tokenizer("basic_english")
1717
experimental_basic_english_normalize = basic_english_normalize()
18-
experimental_jit_basic_english_normalize = torch.jit.script(experimental_basic_english_normalize.to_ivalue())
18+
experimental_jit_basic_english_normalize = torch.jit.script(experimental_basic_english_normalize)
1919

2020
# existing eager lookup
2121
train, _ = AG_NEWS()

benchmark/benchmark_experimental_vectors.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def _run_benchmark_lookup(tokens, vector):
4242

4343
# experimental FastText jit lookup
4444
print("FastText Experimental - Jit Mode")
45-
jit_fast_text_experimental = torch.jit.script(fast_text_experimental.to_ivalue())
45+
jit_fast_text_experimental = torch.jit.script(fast_text_experimental)
4646
_run_benchmark_lookup(tokens, jit_fast_text_experimental)
4747

4848

benchmark/benchmark_experimental_vocab.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def benchmark_experimental_vocab_construction(vocab_file_path, is_raw_text=True,
6767
print("Loading from raw text file with basic_english_normalize tokenizer")
6868
for _ in range(num_iters):
6969
tokenizer = basic_english_normalize()
70-
jited_tokenizer = torch.jit.script(tokenizer.to_ivalue())
70+
jited_tokenizer = torch.jit.script(tokenizer)
7171
build_vocab_from_text_file(f, jited_tokenizer, num_cpus=1)
7272
print("Construction time:", time.monotonic() - t0)
7373
else:
@@ -140,7 +140,7 @@ def token_iterator(file_path):
140140
t0 = time.monotonic()
141141
v_experimental = VocabExperimental(ordered_dict)
142142
print("Construction time:", time.monotonic() - t0)
143-
jit_v_experimental = torch.jit.script(v_experimental.to_ivalue())
143+
jit_v_experimental = torch.jit.script(v_experimental)
144144

145145
# existing Vocab eager lookup
146146
print("Vocab - Eager Mode")
@@ -154,7 +154,7 @@ def token_iterator(file_path):
154154
_run_benchmark_lookup([tokens], v_experimental)
155155
_run_benchmark_lookup(tokens_lists, v_experimental)
156156

157-
jit_v_experimental = torch.jit.script(v_experimental.to_ivalue())
157+
jit_v_experimental = torch.jit.script(v_experimental)
158158
# experimental Vocab jit lookup
159159
print("Vocab Experimental - Jit Mode")
160160
_run_benchmark_lookup(tokens, jit_v_experimental)

benchmark/benchmark_pytext_vocab.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ def benchmark_experimental_vocab():
150150
t0 = time.monotonic()
151151
experimental_script_vocab = ExperimentalScriptVocabulary(ordered_dict, unk_token="<unk>")
152152
print("Construction time:", time.monotonic() - t0)
153-
jit_experimental_script_vocab = torch.jit.script(experimental_script_vocab.to_ivalue())
153+
jit_experimental_script_vocab = torch.jit.script(experimental_script_vocab)
154154

155155
# pytext Vocab eager lookup
156156
print("Pytext Vocabulary - Eager Mode")

examples/data_pipeline/pipelines.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,10 @@ def build_sp_pipeline(spm_file):
3232
vocab = PretrainedSPVocab(load_sp_model(spm_file))
3333

3434
# Insert token in vocab to match a pretrained vocab
35-
vocab.insert_token('<pad>', 1)
3635
pipeline = TextSequentialTransforms(tokenizer, vocab)
37-
jit_pipeline = torch.jit.script(pipeline.to_ivalue())
36+
jit_pipeline = torch.jit.script(pipeline)
3837
print('jit sentencepiece pipeline success!')
39-
return pipeline, pipeline.to_ivalue(), jit_pipeline
38+
return pipeline, pipeline, jit_pipeline
4039

4140

4241
def build_legacy_torchtext_vocab_pipeline(vocab_file):
@@ -59,9 +58,9 @@ def build_experimental_torchtext_pipeline(hf_vocab_file):
5958
with open(hf_vocab_file, 'r') as f:
6059
vocab = load_vocab_from_file(f)
6160
pipeline = TextSequentialTransforms(tokenizer, vocab)
62-
jit_pipeline = torch.jit.script(pipeline.to_ivalue())
61+
jit_pipeline = torch.jit.script(pipeline)
6362
print('jit experimental torchtext pipeline success!')
64-
return pipeline, pipeline.to_ivalue(), jit_pipeline
63+
return pipeline, pipeline, jit_pipeline
6564

6665

6766
def build_legacy_batch_torchtext_vocab_pipeline(vocab_file):
@@ -104,9 +103,9 @@ def build_legacy_pytext_script_vocab_pipeline(vocab_file):
104103
vocab_list.insert(0, "<unk>")
105104
pipeline = TextSequentialTransforms(tokenizer,
106105
PyTextScriptVocabTransform(ScriptVocabulary(vocab_list)))
107-
jit_pipeline = torch.jit.script(pipeline.to_ivalue())
106+
jit_pipeline = torch.jit.script(pipeline)
108107
print('jit legacy PyText pipeline success!')
109-
return pipeline, pipeline.to_ivalue(), jit_pipeline
108+
return pipeline, pipeline, jit_pipeline
110109

111110

112111
def build_experimental_pytext_script_pipeline(vocab_file):
@@ -125,9 +124,9 @@ def build_experimental_pytext_script_pipeline(vocab_file):
125124
# Insert token in vocab to match a pretrained vocab
126125
pipeline = TextSequentialTransforms(tokenizer,
127126
PyTextScriptVocabTransform(script_vocab(ordered_dict)))
128-
jit_pipeline = torch.jit.script(pipeline.to_ivalue())
127+
jit_pipeline = torch.jit.script(pipeline)
129128
print('jit legacy PyText pipeline success!')
130-
return pipeline, pipeline.to_ivalue(), jit_pipeline
129+
return pipeline, pipeline, jit_pipeline
131130

132131

133132
def build_legacy_fasttext_vector_pipeline():
@@ -143,10 +142,10 @@ def build_experimental_fasttext_vector_pipeline():
143142
vector = FastTextExperimental()
144143

145144
pipeline = TextSequentialTransforms(tokenizer, vector)
146-
jit_pipeline = torch.jit.script(pipeline.to_ivalue())
145+
jit_pipeline = torch.jit.script(pipeline)
147146

148147
print('jit legacy fasttext pipeline success!')
149-
return pipeline, pipeline.to_ivalue(), jit_pipeline
148+
return pipeline, pipeline, jit_pipeline
150149

151150

152151
def run_benchmark_lookup(text_classification_dataset, pipeline):

examples/data_pipeline/transforms.py

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -24,14 +24,6 @@ def forward(self, tokens: List[str]) -> List[int]:
2424
def insert_token(self, token: str, index: int) -> None:
2525
self.vocab.insert_token(token, index)
2626

27-
def to_ivalue(self):
28-
if hasattr(self.vocab, 'to_ivalue'):
29-
sp_model = self.sp_model
30-
new_module = PretrainedSPVocab(sp_model)
31-
new_module.vocab = self.vocab.to_ivalue()
32-
return new_module
33-
return self
34-
3527

3628
class PyTextVocabTransform(nn.Module):
3729
r"""PyTextVocabTransform transform
@@ -57,12 +49,6 @@ def __init__(self, vocab):
5749
def forward(self, tokens: List[str]) -> List[int]:
5850
return self.vocab.lookup_indices_1d(tokens)
5951

60-
def to_ivalue(self):
61-
if hasattr(self.vocab, 'to_ivalue'):
62-
vocab = self.vocab.to_ivalue()
63-
return PyTextScriptVocabTransform(vocab)
64-
return self
65-
6652

6753
class ToLongTensor(nn.Module):
6854
r"""Convert a list of integers to long tensor

test/data/test_functional.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -94,14 +94,16 @@ def test_BasicEnglishNormalize(self):
9494
basic_eng_norm = basic_english_normalize()
9595
experimental_eager_tokens = basic_eng_norm(test_sample)
9696

97-
jit_basic_eng_norm = torch.jit.script(basic_eng_norm.to_ivalue())
97+
jit_basic_eng_norm = torch.jit.script(basic_eng_norm)
9898
experimental_jit_tokens = jit_basic_eng_norm(test_sample)
9999

100100
basic_english_tokenizer = data.get_tokenizer("basic_english")
101101
eager_tokens = basic_english_tokenizer(test_sample)
102102

103103
assert not basic_eng_norm.is_jitable
104-
assert basic_eng_norm.to_ivalue().is_jitable
104+
# Call the __prepare_scriptable__() func and convert the building block to the torbhind version
105+
# Not expect users to use the torchbind version on eager mode but still need a CI test here.
106+
assert basic_eng_norm.__prepare_scriptable__().is_jitable
105107

106108
self.assertEqual(experimental_jit_tokens, ref_results)
107109
self.assertEqual(eager_tokens, ref_results)
@@ -121,7 +123,9 @@ def test_basicEnglishNormalize_load_and_save(self):
121123

122124
with self.subTest('torchscript'):
123125
save_path = os.path.join(self.test_dir, 'ben_torchscrip.pt')
124-
ben = basic_english_normalize().to_ivalue()
126+
# Call the __prepare_scriptable__() func and convert the building block to the torbhind version
127+
# Not expect users to use the torchbind version on eager mode but still need a CI test here.
128+
ben = basic_english_normalize().__prepare_scriptable__()
125129
torch.save(ben, save_path)
126130
loaded_ben = torch.load(save_path)
127131
self.assertEqual(loaded_ben(test_sample), ref_results)
@@ -149,11 +153,13 @@ def test_RegexTokenizer(self):
149153
r_tokenizer = regex_tokenizer(patterns_list)
150154
eager_tokens = r_tokenizer(test_sample)
151155

152-
jit_r_tokenizer = torch.jit.script(r_tokenizer.to_ivalue())
156+
jit_r_tokenizer = torch.jit.script(r_tokenizer)
153157
jit_tokens = jit_r_tokenizer(test_sample)
154158

155159
assert not r_tokenizer.is_jitable
156-
assert r_tokenizer.to_ivalue().is_jitable
160+
# Call the __prepare_scriptable__() func and convert the building block to the torbhind version
161+
# Not expect users to use the torchbind version on eager mode but still need a CI test here.
162+
assert r_tokenizer.__prepare_scriptable__().is_jitable
157163

158164
self.assertEqual(eager_tokens, ref_results)
159165
self.assertEqual(jit_tokens, ref_results)
@@ -186,7 +192,9 @@ def test_load_and_save(self):
186192

187193
with self.subTest('torchscript'):
188194
save_path = os.path.join(self.test_dir, 'regex_torchscript.pt')
189-
tokenizer = regex_tokenizer(patterns_list).to_ivalue()
195+
# Call the __prepare_scriptable__() func and convert the building block to the torbhind version
196+
# Not expect users to use the torchbind version on eager mode but still need a CI test here.
197+
tokenizer = regex_tokenizer(patterns_list).__prepare_scriptable__()
190198
torch.save(tokenizer, save_path)
191199
loaded_tokenizer = torch.load(save_path)
192200
results = loaded_tokenizer(test_sample)

test/experimental/test_transforms.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ class TestTransforms(TorchtextTestCase):
1616
def test_sentencepiece_processor(self):
1717
model_path = get_asset_path('spm_example.model')
1818
spm_transform = sentencepiece_processor(model_path)
19-
jit_spm_transform = torch.jit.script(spm_transform.to_ivalue())
19+
jit_spm_transform = torch.jit.script(spm_transform)
2020
test_sample = 'SentencePiece is an unsupervised text tokenizer and detokenizer'
2121
ref_results = [15340, 4286, 981, 1207, 1681, 17, 84, 684, 8896, 5366,
2222
144, 3689, 9, 5602, 12114, 6, 560, 649, 5602, 12114]
@@ -28,7 +28,7 @@ def test_sentencepiece_processor(self):
2828
def test_sentencepiece_tokenizer(self):
2929
model_path = get_asset_path('spm_example.model')
3030
spm_tokenizer = sentencepiece_tokenizer(model_path)
31-
jit_spm_tokenizer = torch.jit.script(spm_tokenizer.to_ivalue())
31+
jit_spm_tokenizer = torch.jit.script(spm_tokenizer)
3232
test_sample = 'SentencePiece is an unsupervised text tokenizer and detokenizer'
3333
ref_results = ['\u2581Sent', 'ence', 'P', 'ie', 'ce', '\u2581is',
3434
'\u2581an', '\u2581un', 'super', 'vis', 'ed', '\u2581text',
@@ -48,7 +48,7 @@ def test_vector_transform(self):
4848
data_path = os.path.join(dir_name, asset_name)
4949
shutil.copy(asset_path, data_path)
5050
vector_transform = VectorTransform(FastText(root=dir_name, validate_file=False))
51-
jit_vector_transform = torch.jit.script(vector_transform.to_ivalue())
51+
jit_vector_transform = torch.jit.script(vector_transform)
5252
# The first 3 entries in each vector.
5353
expected_fasttext_simple_en = torch.tensor([[-0.065334, -0.093031, -0.017571],
5454
[-0.32423, -0.098845, -0.0073467]])
@@ -74,7 +74,9 @@ def test_sentencepiece_load_and_save(self):
7474

7575
with self.subTest('torchscript'):
7676
save_path = os.path.join(self.test_dir, 'spm_torchscript.pt')
77-
spm = sentencepiece_tokenizer((model_path)).to_ivalue()
77+
# Call the __prepare_scriptable__() func and convert the building block to the torbhind version
78+
# Not expect users to use the torchbind version on eager mode but still need a CI test here.
79+
spm = sentencepiece_tokenizer((model_path)).__prepare_scriptable__()
7880
torch.save(spm, save_path)
7981
loaded_spm = torch.load(save_path)
8082
self.assertEqual(expected, loaded_spm(input))

test/experimental/test_vectors.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -54,10 +54,12 @@ def test_vectors_jit(self):
5454
tokens = ['a', 'b']
5555
vecs = torch.stack((tensorA, tensorB), 0)
5656
vectors_obj = build_vectors(tokens, vecs, unk_tensor=unk_tensor)
57-
jit_vectors_obj = torch.jit.script(vectors_obj.to_ivalue())
57+
jit_vectors_obj = torch.jit.script(vectors_obj)
5858

5959
assert not vectors_obj.is_jitable
60-
assert vectors_obj.to_ivalue().is_jitable
60+
# Call the __prepare_scriptable__() func and convert the building block to the torbhind version
61+
# Not expect users to use the torchbind version on eager mode but still need a CI test here.
62+
assert vectors_obj.__prepare_scriptable__().is_jitable
6163

6264
self.assertEqual(vectors_obj['a'], jit_vectors_obj['a'])
6365
self.assertEqual(vectors_obj['b'], jit_vectors_obj['b'])
@@ -71,7 +73,7 @@ def test_vectors_forward(self):
7173
tokens = ['a', 'b']
7274
vecs = torch.stack((tensorA, tensorB), 0)
7375
vectors_obj = build_vectors(tokens, vecs, unk_tensor=unk_tensor)
74-
jit_vectors_obj = torch.jit.script(vectors_obj.to_ivalue())
76+
jit_vectors_obj = torch.jit.script(vectors_obj)
7577

7678
tokens_to_lookup = ['a', 'b', 'c']
7779
expected_vectors = torch.stack((tensorA, tensorB, unk_tensor), 0)
@@ -148,7 +150,9 @@ def test_vectors_load_and_save(self):
148150

149151
with self.subTest('torchscript'):
150152
vector_path = os.path.join(self.test_dir, 'vectors_torchscript.pt')
151-
torch.save(vectors_obj.to_ivalue(), vector_path)
153+
# Call the __prepare_scriptable__() func and convert the building block to the torbhind version
154+
# Not expect users to use the torchbind version on eager mode but still need a CI test here.
155+
torch.save(vectors_obj.__prepare_scriptable__(), vector_path)
152156
loaded_vectors_obj = torch.load(vector_path)
153157

154158
self.assertEqual(loaded_vectors_obj['a'], tensorA)

test/experimental/test_vocab.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -104,13 +104,15 @@ def test_vocab_jit(self):
104104

105105
c = OrderedDict(sorted_by_freq_tuples)
106106
v = vocab(c, min_freq=3)
107-
jit_v = torch.jit.script(v.to_ivalue())
107+
jit_v = torch.jit.script(v)
108108

109109
expected_itos = ['<unk>', 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T', 'hello', 'world']
110110
expected_stoi = {x: index for index, x in enumerate(expected_itos)}
111111

112112
assert not v.is_jitable
113-
assert v.to_ivalue().is_jitable
113+
# Call the __prepare_scriptable__() func and convert the building block to the torbhind version
114+
# Not expect users to use the torchbind version on eager mode but still need a CI test here.
115+
assert v.__prepare_scriptable__().is_jitable
114116

115117
self.assertEqual(jit_v.get_itos(), expected_itos)
116118
self.assertEqual(dict(jit_v.get_stoi()), expected_stoi)
@@ -121,7 +123,7 @@ def test_vocab_forward(self):
121123

122124
c = OrderedDict(sorted_by_freq_tuples)
123125
v = vocab(c)
124-
jit_v = torch.jit.script(v.to_ivalue())
126+
jit_v = torch.jit.script(v)
125127

126128
tokens = ['b', 'a', 'c']
127129
expected_indices = [2, 1, 3]
@@ -208,7 +210,9 @@ def test_vocab_load_and_save(self):
208210

209211
with self.subTest('torchscript'):
210212
vocab_path = os.path.join(self.test_dir, 'vocab_torchscript.pt')
211-
torch.save(v.to_ivalue(), vocab_path)
213+
# Call the __prepare_scriptable__() func and convert the building block to the torbhind version
214+
# Not expect users to use the torchbind version on eager mode but still need a CI test here.
215+
torch.save(v.__prepare_scriptable__(), vocab_path)
212216
loaded_v = torch.load(vocab_path)
213217
self.assertEqual(v.get_itos(), expected_itos)
214218
self.assertEqual(dict(loaded_v.get_stoi()), expected_stoi)

0 commit comments

Comments
 (0)