Skip to content

Fix test_generate_sp_model for stress test #798

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 28, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 3 additions & 27 deletions test/common/assets.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,8 @@
import os
import shutil
import atexit
import tempfile
from pathlib import Path

# Resolved path of the "asset" directory that lives in the parent of the
# directory containing this file.
_ASSET_DIR = (Path(__file__).parent.parent / "asset").resolve()

# Module-level handle for a shared temporary directory; None until
# `_init_temp_dir()` assigns a `tempfile.TemporaryDirectory` to it.
_TEMP_DIR = None


def _init_temp_dir():
    """Create the module-wide temporary directory.

    The new `tempfile.TemporaryDirectory` is stored in the `_TEMP_DIR` global
    and its `cleanup` method is registered with `atexit`, so the directory is
    removed when the interpreter exits.
    """
    global _TEMP_DIR
    temp_dir = tempfile.TemporaryDirectory()
    _TEMP_DIR = temp_dir
    atexit.register(temp_dir.cleanup)


def get_asset_path(*path_components, use_temp_dir=False):
    """Return the path of an asset file located under `_ASSET_DIR`.

    `path_components` are joined onto `_ASSET_DIR` to locate the asset.
    With `use_temp_dir=False` (the default) that joined path is returned
    as a string. With `use_temp_dir=True` the asset is first copied into the
    shared temporary directory (created on first use) under the name given
    by the last path component, and the path of that copy is returned.
    """
    source = str(_ASSET_DIR.joinpath(*path_components))
    if use_temp_dir:
        # Lazily create the shared temporary directory on first request.
        if _TEMP_DIR is None:
            _init_temp_dir()
        destination = os.path.join(_TEMP_DIR.name, path_components[-1])
        shutil.copy(source, destination)
        return destination
    return source
def get_asset_path(*path_components):
    """Return, as a string, `path_components` joined onto `_ASSET_DIR`."""
    asset = _ASSET_DIR.joinpath(*path_components)
    return str(asset)
45 changes: 24 additions & 21 deletions test/data/test_functional.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import os
import unittest
import sys
import uuid
import shutil
import tempfile

import sentencepiece as spm
Expand All @@ -20,27 +22,28 @@

class TestFunctional(TorchtextTestCase):
def test_generate_sp_model(self):
# NOTE(review): this span comes from a rendered diff, so it appears to hold two
# versions of the body back to back: the earlier one first (ending with the
# os.remove clean-up), then its replacement (starting at the docstring line).
# Confirm against the real file before editing.
# Test the function to train a sentencepiece tokenizer

# buck (fb internal) generates a test environment which contains ',' in its path.
# SentencePieceTrainer considers such a path as a comma-delimited file list.
# So as a workaround we copy the asset data to a temporary directory and load it from there.
data_path = get_asset_path(
'text_normalization_ag_news_test.csv',
use_temp_dir=True)
generate_sp_model(data_path,
vocab_size=23456,
model_prefix='spm_user')

sp_user = spm.SentencePieceProcessor()
sp_user.Load('spm_user.model')

self.assertEqual(len(sp_user), 23456)

if os.path.isfile('spm_user.model'):
os.remove('spm_user.model')
if os.path.isfile('spm_user.vocab'):
os.remove('spm_user.vocab')
"""Test the function to train a sentencepiece tokenizer"""

asset_name = 'text_normalization_ag_news_test.csv'
asset_path = get_asset_path(asset_name)
# We use a temporary directory for two reasons:
# 1. buck (fb internal) generates a test environment which contains ',' in its path.
# SentencePieceTrainer considers such a path as a comma-delimited file list.
# So as a workaround we copy the asset data to a temporary directory and load it from there.
# 2. when fb infra performs stress tests, multiple instances of this test run.
# The names of the generated models have to be unique and they need to be cleaned up.
with tempfile.TemporaryDirectory() as dir_name:
data_path = os.path.join(dir_name, asset_name)
shutil.copy(asset_path, data_path)

model_prefix = os.path.join(dir_name, f'spm_user_{uuid.uuid4()}')
model_file = f'{model_prefix}.model'
generate_sp_model(data_path, vocab_size=23456, model_prefix=model_prefix)

sp_user = spm.SentencePieceProcessor()
sp_user.Load(model_file)

self.assertEqual(len(sp_user), 23456)

def test_sentencepiece_numericalizer(self):
test_sample = 'SentencePiece is an unsupervised text tokenizer and detokenizer'
Expand Down