Add support for STS-B dataset with unit tests #1714
Merged

Changes from all commits (11 commits)
672227e  Add support for STS-B dataset _ unit test  (vcm2114)
6667396  Merge branch 'pytorch:main' into stsb_dataset  (vcm2114)
e81d8ec  Fix quote issue  (vcm2114)
d5afa01  Modify tests + docstring  (vcm2114)
0b493bf  Merge branch 'pytorch:main' into stsb_dataset  (vcm2114)
9cb201e  Remove lambda functions  (vcm2114)
b605c46  Merge branch 'stsb_dataset' of github.com:VirgileHlav/text into stsb_…  (vcm2114)
271a044  Lint, adjust test float & quote issues in parsing  (vcm2114)
f6e18c9  Add dataset documentation  (vcm2114)
4350c4b  Add shuffle and sharding  (vcm2114)
0dc7a65  Lint  (vcm2114)
Dataset documentation (reStructuredText):

@@ -62,6 +62,11 @@ SST2

.. autofunction:: SST2

STSB
~~~~

.. autofunction:: STSB

YahooAnswers
~~~~~~~~~~~~
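This entry adds STSB to the dataset listing so that Sphinx's autofunction directive renders the docstring defined in torchtext/datasets/stsb.py (shown further below). As a quick, purely illustrative way to read that same docstring without building the docs (assuming torchtext and torchdata are installed):

    from torchtext.datasets.stsb import STSB

    # Prints the "STSB Dataset" docstring that `.. autofunction:: STSB` renders in the docs.
    help(STSB)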
Unit tests for the STSB dataset:

@@ -0,0 +1,89 @@
import os
import random
import tarfile
from collections import defaultdict
from unittest.mock import patch

from parameterized import parameterized
from torchtext.datasets.stsb import STSB

from ..common.case_utils import TempDirMixin, zip_equal, get_random_unicode
from ..common.torchtext_test_case import TorchtextTestCase


def _get_mock_dataset(root_dir):
    """
    root_dir: directory to the mocked dataset
    """
    base_dir = os.path.join(root_dir, "STSB")
    temp_dataset_dir = os.path.join(base_dir, "stsbenchmark")
    os.makedirs(temp_dataset_dir, exist_ok=True)

    seed = 1
    mocked_data = defaultdict(list)
    for file_name, name in zip(["sts-train.csv", "sts-dev.csv", "sts-test.csv"], ["train", "dev", "test"]):
        txt_file = os.path.join(temp_dataset_dir, file_name)
        with open(txt_file, "w", encoding="utf-8") as f:
            for i in range(5):
                label = random.uniform(0, 5)
                rand_string_1 = get_random_unicode(seed)
                rand_string_2 = get_random_unicode(seed + 1)
                rand_string_3 = get_random_unicode(seed + 2)
                rand_string_4 = get_random_unicode(seed + 3)
                rand_string_5 = get_random_unicode(seed + 4)
                dataset_line = (i, label, rand_string_4, rand_string_5)
                # append line to correct dataset split
                mocked_data[name].append(dataset_line)
                f.write(
                    f"{rand_string_1}\t{rand_string_2}\t{rand_string_3}\t{i}\t{label}\t{rand_string_4}\t{rand_string_5}\n"
                )
                seed += 1
                # case with quotes to test arg `quoting=csv.QUOTE_NONE`
                dataset_line = (i, label, rand_string_4, rand_string_5)
                # append line to correct dataset split
                mocked_data[name].append(dataset_line)
                f.write(
                    f'{rand_string_1}"\t"{rand_string_2}\t{rand_string_3}\t{i}\t{label}\t{rand_string_4}\t{rand_string_5}\n'
                )

    compressed_dataset_path = os.path.join(base_dir, "Stsbenchmark.tar.gz")
    # create tar file from dataset folder
    with tarfile.open(compressed_dataset_path, "w:gz") as tar:
        tar.add(temp_dataset_dir, arcname="stsbenchmark")

    return mocked_data


class TestSTSB(TempDirMixin, TorchtextTestCase):
    root_dir = None
    samples = []

    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        cls.root_dir = cls.get_base_temp_dir()
        cls.samples = _get_mock_dataset(cls.root_dir)
        cls.patcher = patch("torchdata.datapipes.iter.util.cacheholder._hash_check", return_value=True)
        cls.patcher.start()

    @classmethod
    def tearDownClass(cls):
        cls.patcher.stop()
        super().tearDownClass()

    @parameterized.expand(["train", "dev", "test"])
    def test_stsb(self, split):
        dataset = STSB(root=self.root_dir, split=split)

        samples = list(dataset)
        expected_samples = self.samples[split]
        for sample, expected_sample in zip_equal(samples, expected_samples):
            self.assertEqual(sample, expected_sample)

    @parameterized.expand(["train", "dev", "test"])
    def test_stsb_split_argument(self, split):
        dataset1 = STSB(root=self.root_dir, split=split)
        (dataset2,) = STSB(root=self.root_dir, split=(split,))

        for d1, d2 in zip_equal(dataset1, dataset2):
            self.assertEqual(d1, d2)
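The mock writes the raw STS-B layout: seven tab-separated columns per line, with the example index, the similarity score, and the two sentences in the last four columns. The second write injects stray double quotes so that only a reader configured with quoting=csv.QUOTE_NONE parses the line correctly. A minimal sketch of that column mapping (the sample line below is illustrative, not taken from the dataset or the test):

    import csv
    import io

    # Roughly the STS-B column layout: genre, source file, year, index, score, sentence1, sentence2
    raw = 'main-captions\tMSRvid\t2012test\t0\t3.8\tA "quoted" sentence.\tAnother sentence.\n'

    # QUOTE_NONE keeps stray double quotes as literal characters instead of
    # treating them as field quoting, which is what the second mocked line exercises.
    reader = csv.reader(io.StringIO(raw), delimiter="\t", quoting=csv.QUOTE_NONE)
    row = next(reader)
    sample = (int(row[3]), float(row[4]), row[5], row[6])
    print(sample)  # (0, 3.8, 'A "quoted" sentence.', 'Another sentence.')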
STSB dataset implementation (torchtext/datasets/stsb.py):

@@ -0,0 +1,90 @@
import csv
import os

from torchtext._internal.module_utils import is_module_available
from torchtext.data.datasets_utils import (
    _create_dataset_directory,
    _wrap_split_argument,
)

if is_module_available("torchdata"):
    from torchdata.datapipes.iter import FileOpener, IterableWrapper

    # we import HttpReader from _download_hooks so we can swap out public URLs
    # with internal URLs when the dataset is used within Facebook
    from torchtext._download_hooks import HttpReader


URL = "http://ixa2.si.ehu.es/stswiki/images/4/48/Stsbenchmark.tar.gz"

MD5 = "4eb0065aba063ef77873d3a9c8088811"

NUM_LINES = {
    "train": 5749,
    "dev": 1500,
    "test": 1379,
}

_PATH = "Stsbenchmark.tar.gz"

DATASET_NAME = "STSB"

_EXTRACTED_FILES = {
    "train": os.path.join("stsbenchmark", "sts-train.csv"),
    "dev": os.path.join("stsbenchmark", "sts-dev.csv"),
    "test": os.path.join("stsbenchmark", "sts-test.csv"),
}


@_create_dataset_directory(dataset_name=DATASET_NAME)
@_wrap_split_argument(("train", "dev", "test"))
def STSB(root, split):
    """STSB Dataset

    For additional details refer to https://ixa2.si.ehu.eus/stswiki/index.php/STSbenchmark

    Number of lines per split:
        - train: 5749
        - dev: 1500
        - test: 1379

    Args:
        root: Directory where the datasets are saved. Default: os.path.expanduser('~/.torchtext/cache')
        split: split or splits to be returned. Can be a string or tuple of strings. Default: (`train`, `dev`, `test`)

    :returns: DataPipe that yields tuple of (index (int), label (float), sentence1 (str), sentence2 (str))
    :rtype: (int, float, str, str)
    """
    # TODO Remove this after removing conditional dependency
    if not is_module_available("torchdata"):
        raise ModuleNotFoundError(
            "Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`"
        )

    def _filepath_fn(x=_PATH):
        return os.path.join(root, os.path.basename(x))

    def _extracted_filepath_fn(_=None):
        return _filepath_fn(_EXTRACTED_FILES[split])

    def _filter_fn(x):
        return _EXTRACTED_FILES[split] in x[0]

    def _modify_res(x):
        return (int(x[3]), float(x[4]), x[5], x[6])

    url_dp = IterableWrapper([URL])
    cache_compressed_dp = url_dp.on_disk_cache(
        filepath_fn=_filepath_fn,
        hash_dict={_filepath_fn(URL): MD5},
        hash_type="md5",
    )
    cache_compressed_dp = HttpReader(cache_compressed_dp).end_caching(mode="wb", same_filepath_fn=True)

    cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=_extracted_filepath_fn)
    cache_decompressed_dp = FileOpener(cache_decompressed_dp, mode="b").read_from_tar().filter(_filter_fn)
    cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True)

    data_dp = FileOpener(cache_decompressed_dp, encoding="utf-8")
    parsed_data = data_dp.parse_csv(delimiter="\t", quoting=csv.QUOTE_NONE).map(_modify_res)
    return parsed_data.shuffle().set_shuffle(False).sharding_filter()
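Taken together, the datapipe downloads and caches the archive on first use (checked against the MD5 hash), extracts only the requested split's CSV, and yields the 4-tuples described in the docstring. A minimal usage sketch, assuming torchdata is installed and the download URL is reachable; the DataLoader wrapping below is illustrative rather than part of this PR:

    from torch.utils.data import DataLoader
    from torchtext.datasets.stsb import STSB

    # Yields (index, score, sentence1, sentence2) tuples for the dev split.
    dev_dp = STSB(split="dev")
    for index, score, sentence1, sentence2 in dev_dp:
        print(index, score, sentence1[:40], sentence2[:40])
        break

    # The result is a standard IterDataPipe, so it can also be fed to a DataLoader;
    # batch_size=None yields the tuples one at a time.
    loader = DataLoader(dev_dp, batch_size=None)

Ending the pipe with shuffle().set_shuffle(False).sharding_filter() follows the convention used by the other torchtext datasets (see the "Add shuffle and sharding" commit): shuffling is declared on the pipe but left off by default so it can be enabled at the DataLoader level, and sharding_filter splits the samples correctly across DataLoader workers.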