From 1542b4ac1ed57241f6180d43517a99c6e3446e65 Mon Sep 17 00:00:00 2001
From: Virgile Mison
Date: Wed, 11 May 2022 11:48:19 -0400
Subject: [PATCH 1/6] Add support for WNLI dataset + unit tests

---
 test/datasets/test_wnli.py     | 84 +++++++++++++++++++++++++++++++++
 torchtext/datasets/__init__.py |  2 +
 torchtext/datasets/wnli.py     | 86 ++++++++++++++++++++++++++++++++++
 3 files changed, 172 insertions(+)
 create mode 100644 test/datasets/test_wnli.py
 create mode 100644 torchtext/datasets/wnli.py

diff --git a/test/datasets/test_wnli.py b/test/datasets/test_wnli.py
new file mode 100644
index 0000000000..e14760d049
--- /dev/null
+++ b/test/datasets/test_wnli.py
@@ -0,0 +1,84 @@
+import os
+import zipfile
+from collections import defaultdict
+from unittest.mock import patch
+
+from parameterized import parameterized
+from torchtext.datasets.wnli import WNLI
+
+from ..common.case_utils import TempDirMixin, zip_equal, get_random_unicode
+from ..common.torchtext_test_case import TorchtextTestCase
+
+
+def _get_mock_dataset(root_dir):
+    """
+    root_dir: directory for the mocked dataset
+    """
+    base_dir = os.path.join(root_dir, "WNLI")
+    temp_dataset_dir = os.path.join(base_dir, "temp_dataset_dir")
+    os.makedirs(temp_dataset_dir, exist_ok=True)
+
+    seed = 1
+    mocked_data = defaultdict(list)
+    for file_name in ("train.tsv", "test.tsv", "dev.tsv"):
+        txt_file = os.path.join(temp_dataset_dir, file_name)
+        with open(txt_file, "w", encoding="utf-8") as f:
+            f.write("index\tsentence1\tsentence2\tlabel\n")
+            for i in range(5):
+                label = seed % 2
+                rand_string_1 = get_random_unicode(seed)
+                rand_string_2 = get_random_unicode(seed+1)
+                if file_name == "test.tsv":
+                    dataset_line = (rand_string_1, rand_string_2)
+                    f.write(f"{i}\t{rand_string_1}\t{rand_string_2}\n")
+                else:
+                    dataset_line = (label, rand_string_1, rand_string_2)
+                    f.write(f"{i}\t{rand_string_1}\t{rand_string_2}\t{label}\n")
+
+                # append line to correct dataset split
+                mocked_data[os.path.splitext(file_name)[0]].append(dataset_line)
+                seed += 1
+
+    compressed_dataset_path = os.path.join(base_dir, "WNLI.zip")
+    # create zip file from dataset folder
+    with zipfile.ZipFile(compressed_dataset_path, "w") as zip_file:
+        for file_name in ("train.tsv", "test.tsv", "dev.tsv"):
+            txt_file = os.path.join(temp_dataset_dir, file_name)
+            zip_file.write(txt_file, arcname=os.path.join("WNLI", file_name))
+
+    return mocked_data
+
+
+class TestWNLI(TempDirMixin, TorchtextTestCase):
+    root_dir = None
+    samples = []
+
+    @classmethod
+    def setUpClass(cls):
+        super().setUpClass()
+        cls.root_dir = cls.get_base_temp_dir()
+        cls.samples = _get_mock_dataset(cls.root_dir)
+        cls.patcher = patch("torchdata.datapipes.iter.util.cacheholder._hash_check", return_value=True)
+        cls.patcher.start()
+
+    @classmethod
+    def tearDownClass(cls):
+        cls.patcher.stop()
+        super().tearDownClass()
+
+    @parameterized.expand(["train", "test", "dev"])
+    def test_wnli(self, split):
+        dataset = WNLI(root=self.root_dir, split=split)
+
+        samples = list(dataset)
+        expected_samples = self.samples[split]
+        for sample, expected_sample in zip_equal(samples, expected_samples):
+            self.assertEqual(sample, expected_sample)
+
+    @parameterized.expand(["train", "test", "dev"])
+    def test_wnli_split_argument(self, split):
+        dataset1 = WNLI(root=self.root_dir, split=split)
+        (dataset2,) = WNLI(root=self.root_dir, split=(split,))
+
+        for d1, d2 in zip_equal(dataset1, dataset2):
+            self.assertEqual(d1, d2)
diff --git a/torchtext/datasets/__init__.py b/torchtext/datasets/__init__.py
index d7d33298ad..52f890a4d9 100644
--- a/torchtext/datasets/__init__.py
+++ b/torchtext/datasets/__init__.py
@@ -19,6 +19,7 @@
 from .udpos import UDPOS
 from .wikitext103 import WikiText103
 from .wikitext2 import WikiText2
+from .wnli import WNLI
 from .yahooanswers import YahooAnswers
 from .yelpreviewfull import YelpReviewFull
 from .yelpreviewpolarity import YelpReviewPolarity
@@ -43,6 +44,7 @@
     "UDPOS": UDPOS,
     "WikiText103": WikiText103,
     "WikiText2": WikiText2,
+    "WNLI": WNLI,
     "YahooAnswers": YahooAnswers,
     "YelpReviewFull": YelpReviewFull,
     "YelpReviewPolarity": YelpReviewPolarity,
diff --git a/torchtext/datasets/wnli.py b/torchtext/datasets/wnli.py
new file mode 100644
index 0000000000..f5dfe41cdf
--- /dev/null
+++ b/torchtext/datasets/wnli.py
@@ -0,0 +1,86 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+import os
+
+from torchtext._internal.module_utils import is_module_available
+from torchtext.data.datasets_utils import (
+    _create_dataset_directory,
+    _wrap_split_argument,
+)
+
+if is_module_available("torchdata"):
+    from torchdata.datapipes.iter import FileOpener, IterableWrapper
+
+    # we import HttpReader from _download_hooks so we can swap out public URLs
+    # with internal URLs when the dataset is used within Facebook
+    from torchtext._download_hooks import HttpReader
+
+
+URL = "https://dl.fbaipublicfiles.com/glue/data/WNLI.zip"
+
+MD5 = "a1b4bd2861017d302d29e42139657a42"
+
+NUM_LINES = {
+    "train": 635,
+    "dev": 71,
+    "test": 146,
+}
+
+_PATH = "WNLI.zip"
+
+DATASET_NAME = "WNLI"
+
+_EXTRACTED_FILES = {
+    "train": os.path.join("WNLI", "train.tsv"),
+    "dev": os.path.join("WNLI", "dev.tsv"),
+    "test": os.path.join("WNLI", "test.tsv"),
+}
+
+
+@_create_dataset_directory(dataset_name=DATASET_NAME)
+@_wrap_split_argument(("train", "dev", "test"))
+def WNLI(root, split):
+    """WNLI Dataset
+
+    For additional details refer to https://arxiv.org/pdf/1804.07461v3.pdf
+
+    Number of lines per split:
+        - train: 635
+        - dev: 71
+        - test: 146
+
+    Args:
+        root: Directory where the datasets are saved. Default: os.path.expanduser('~/.torchtext/cache')
+        split: split or splits to be returned. Can be a string or tuple of strings. Default: (`train`, `dev`, `test`)
+
+    :returns: DataPipe that yields tuple of text and/or label (0 or 1). The `test` split only returns text.
+    :rtype: Union[(int, str, str), (str, str)]
+    """
+    # TODO Remove this after removing conditional dependency
+    if not is_module_available("torchdata"):
+        raise ModuleNotFoundError(
+            "Package `torchdata` not found. Please install it following the instructions at `https://github.com/pytorch/data`"
+        )
+
+    url_dp = IterableWrapper([URL])
+    cache_compressed_dp = url_dp.on_disk_cache(
+        filepath_fn=lambda x: os.path.join(root, os.path.basename(x)),
+        hash_dict={os.path.join(root, os.path.basename(URL)): MD5},
+        hash_type="md5",
+    )
+    cache_compressed_dp = HttpReader(cache_compressed_dp).end_caching(mode="wb", same_filepath_fn=True)
+
+    cache_decompressed_dp = cache_compressed_dp.on_disk_cache(
+        filepath_fn=lambda x: os.path.join(root, _EXTRACTED_FILES[split])
+    )
+    cache_decompressed_dp = (
+        FileOpener(cache_decompressed_dp, mode="b").read_from_zip().filter(lambda x: _EXTRACTED_FILES[split] in x[0])
+    )
+    cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True)
+
+    data_dp = FileOpener(cache_decompressed_dp, encoding="utf-8")
+    # test split for WNLI doesn't have labels
+    if split == "test":
+        parsed_data = data_dp.parse_csv(skip_lines=1, delimiter="\t").map(lambda t: (t[1], t[2]))
+    else:
+        parsed_data = data_dp.parse_csv(skip_lines=1, delimiter="\t").map(lambda t: (int(t[3]), t[1], t[2]))
+    return parsed_data

From ae9a982bc97f36cd5dc896935c616a2f22cfacb0 Mon Sep 17 00:00:00 2001
From: Virgile Mison
Date: Fri, 13 May 2022 11:31:53 -0400
Subject: [PATCH 2/6] Remove lambda functions

---
 torchtext/datasets/wnli.py | 29 ++++++++++++++++++++---------
 1 file changed, 20 insertions(+), 9 deletions(-)

diff --git a/torchtext/datasets/wnli.py b/torchtext/datasets/wnli.py
index f5dfe41cdf..f33d238927 100644
--- a/torchtext/datasets/wnli.py
+++ b/torchtext/datasets/wnli.py
@@ -61,26 +61,37 @@ def WNLI(root, split):
             "Package `torchdata` not found. Please install it following the instructions at `https://github.com/pytorch/data`"
         )
 
+    def _filepath_fn(x=None):
+        return os.path.join(root, os.path.basename(x))
+
+    def _extracted_filepath_fn(_=None):
+        return os.path.join(root, _EXTRACTED_FILES[split])
+
+    def _filter_fn(x):
+        return _EXTRACTED_FILES[split] in x[0]
+
+    def _modify_res(t):
+        if split == "test":
+            return (t[1], t[2])
+        else:
+            return (int(t[3]), t[1], t[2])
+
     url_dp = IterableWrapper([URL])
     cache_compressed_dp = url_dp.on_disk_cache(
-        filepath_fn=lambda x: os.path.join(root, os.path.basename(x)),
-        hash_dict={os.path.join(root, os.path.basename(URL)): MD5},
+        filepath_fn=_filepath_fn,
+        hash_dict={_filepath_fn(URL): MD5},
         hash_type="md5",
     )
     cache_compressed_dp = HttpReader(cache_compressed_dp).end_caching(mode="wb", same_filepath_fn=True)
 
     cache_decompressed_dp = cache_compressed_dp.on_disk_cache(
-        filepath_fn=lambda x: os.path.join(root, _EXTRACTED_FILES[split])
+        filepath_fn=_extracted_filepath_fn
     )
     cache_decompressed_dp = (
-        FileOpener(cache_decompressed_dp, mode="b").read_from_zip().filter(lambda x: _EXTRACTED_FILES[split] in x[0])
+        FileOpener(cache_decompressed_dp, mode="b").read_from_zip().filter(_filter_fn)
     )
     cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True)
 
     data_dp = FileOpener(cache_decompressed_dp, encoding="utf-8")
-    # test split for WNLI doesn't have labels
-    if split == "test":
-        parsed_data = data_dp.parse_csv(skip_lines=1, delimiter="\t").map(lambda t: (t[1], t[2]))
-    else:
-        parsed_data = data_dp.parse_csv(skip_lines=1, delimiter="\t").map(lambda t: (int(t[3]), t[1], t[2]))
+    parsed_data = data_dp.parse_csv(skip_lines=1, delimiter="\t").map(_modify_res)
     return parsed_data

From 60226d59298355a76af5a52cf082da1914f51b79 Mon Sep 17 00:00:00 2001
From: Virgile Mison
Date: Tue, 17 May 2022 12:02:32 -0400
Subject: [PATCH 3/6] Add dataset documentation
---
 docs/source/datasets.rst | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/docs/source/datasets.rst b/docs/source/datasets.rst
index 33eb44b21d..15097d9606 100644
--- a/docs/source/datasets.rst
+++ b/docs/source/datasets.rst
@@ -62,6 +62,11 @@ SST2
 
 .. autofunction:: SST2
 
+WNLI
+~~~~
+
+.. autofunction:: WNLI
+
 YahooAnswers
 ~~~~~~~~~~~~
 

From d9270b5a863fb9e1d4f32aec7ddc5ca7594446e0 Mon Sep 17 00:00:00 2001
From: Virgile Mison
Date: Wed, 18 May 2022 15:51:24 -0400
Subject: [PATCH 4/6] Add shuffle and sharding

---
 torchtext/datasets/wnli.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torchtext/datasets/wnli.py b/torchtext/datasets/wnli.py
index f33d238927..f291f1fff9 100644
--- a/torchtext/datasets/wnli.py
+++ b/torchtext/datasets/wnli.py
@@ -94,4 +94,4 @@ def WNLI(root, split):
 
     data_dp = FileOpener(cache_decompressed_dp, encoding="utf-8")
     parsed_data = data_dp.parse_csv(skip_lines=1, delimiter="\t").map(_modify_res)
-    return parsed_data
+    return parsed_data.shuffle().set_shuffle(False).sharding_filter()

From f6405d647191a01bbdc402e66832973195be8974 Mon Sep 17 00:00:00 2001
From: Virgile Mison
Date: Wed, 18 May 2022 16:09:05 -0400
Subject: [PATCH 5/6] Lint

---
 docs/source/datasets.rst   |  2 +-
 test/datasets/test_wnli.py |  2 +-
 torchtext/datasets/wnli.py | 12 ++++--------
 3 files changed, 6 insertions(+), 10 deletions(-)

diff --git a/docs/source/datasets.rst b/docs/source/datasets.rst
index 15097d9606..7fd16d2db7 100644
--- a/docs/source/datasets.rst
+++ b/docs/source/datasets.rst
@@ -65,7 +65,7 @@ SST2
 WNLI
 ~~~~
 
-.. autofunction:: WNLI
+.. autofunction:: WNLI
 
 YahooAnswers
 ~~~~~~~~~~~~
diff --git a/test/datasets/test_wnli.py b/test/datasets/test_wnli.py
index e14760d049..2bef2ca7de 100644
--- a/test/datasets/test_wnli.py
+++ b/test/datasets/test_wnli.py
@@ -27,7 +27,7 @@ def _get_mock_dataset(root_dir):
             for i in range(5):
                 label = seed % 2
                 rand_string_1 = get_random_unicode(seed)
-                rand_string_2 = get_random_unicode(seed+1)
+                rand_string_2 = get_random_unicode(seed + 1)
                 if file_name == "test.tsv":
                     dataset_line = (rand_string_1, rand_string_2)
                     f.write(f"{i}\t{rand_string_1}\t{rand_string_2}\n")
diff --git a/torchtext/datasets/wnli.py b/torchtext/datasets/wnli.py
index f291f1fff9..1df9467c81 100644
--- a/torchtext/datasets/wnli.py
+++ b/torchtext/datasets/wnli.py
@@ -62,7 +62,7 @@ def WNLI(root, split):
         )
 
     def _filepath_fn(x=None):
-        return os.path.join(root, os.path.basename(x))
+        return os.path.join(root, os.path.basename(x))
 
     def _extracted_filepath_fn(_=None):
         return os.path.join(root, _EXTRACTED_FILES[split])
@@ -84,14 +84,10 @@ def _modify_res(t):
     )
     cache_compressed_dp = HttpReader(cache_compressed_dp).end_caching(mode="wb", same_filepath_fn=True)
 
-    cache_decompressed_dp = cache_compressed_dp.on_disk_cache(
-        filepath_fn=_extracted_filepath_fn
-    )
-    cache_decompressed_dp = (
-        FileOpener(cache_decompressed_dp, mode="b").read_from_zip().filter(_filter_fn)
-    )
+    cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=_extracted_filepath_fn)
+    cache_decompressed_dp = FileOpener(cache_decompressed_dp, mode="b").read_from_zip().filter(_filter_fn)
     cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True)
 
     data_dp = FileOpener(cache_decompressed_dp, encoding="utf-8")
-    parsed_data = data_dp.parse_csv(skip_lines=1, delimiter="\t").map(_modify_res)
+    parsed_data = data_dp.parse_csv(skip_lines=1, delimiter="\t").map(_modify_res)
     return parsed_data.shuffle().set_shuffle(False).sharding_filter()
From 85eac40cb7d27eae594255e4bb0164f9975cfa87 Mon Sep 17 00:00:00 2001
From: Virgile Mison
Date: Thu, 19 May 2022 15:52:08 -0400
Subject: [PATCH 6/6] Move local functions to global scope + use load_from_zip

---
 torchtext/datasets/wnli.py | 47 ++++++++++++++++++++++----------------
 1 file changed, 27 insertions(+), 20 deletions(-)

diff --git a/torchtext/datasets/wnli.py b/torchtext/datasets/wnli.py
index 1df9467c81..5c0226e8c7 100644
--- a/torchtext/datasets/wnli.py
+++ b/torchtext/datasets/wnli.py
@@ -1,5 +1,6 @@
 # Copyright (c) Facebook, Inc. and its affiliates.
 import os
+from functools import partial
 
 from torchtext._internal.module_utils import is_module_available
 from torchtext.data.datasets_utils import (
@@ -36,6 +37,25 @@
 }
 
 
+def _filepath_fn(root, x=None):
+    return os.path.join(root, os.path.basename(x))
+
+
+def _extracted_filepath_fn(root, split, _=None):
+    return os.path.join(root, _EXTRACTED_FILES[split])
+
+
+def _filter_fn(split, x):
+    return _EXTRACTED_FILES[split] in x[0]
+
+
+def _modify_res(split, t):
+    if split == "test":
+        return (t[1], t[2])
+    else:
+        return (int(t[3]), t[1], t[2])
+
+
 @_create_dataset_directory(dataset_name=DATASET_NAME)
 @_wrap_split_argument(("train", "dev", "test"))
 def WNLI(root, split):
@@ -61,33 +81,20 @@ def WNLI(root, split):
             "Package `torchdata` not found. Please install it following the instructions at `https://github.com/pytorch/data`"
         )
 
-    def _filepath_fn(x=None):
-        return os.path.join(root, os.path.basename(x))
-
-    def _extracted_filepath_fn(_=None):
-        return os.path.join(root, _EXTRACTED_FILES[split])
-
-    def _filter_fn(x):
-        return _EXTRACTED_FILES[split] in x[0]
-
-    def _modify_res(t):
-        if split == "test":
-            return (t[1], t[2])
-        else:
-            return (int(t[3]), t[1], t[2])
-
     url_dp = IterableWrapper([URL])
     cache_compressed_dp = url_dp.on_disk_cache(
-        filepath_fn=_filepath_fn,
-        hash_dict={_filepath_fn(URL): MD5},
+        filepath_fn=partial(_filepath_fn, root),
+        hash_dict={_filepath_fn(root, URL): MD5},
         hash_type="md5",
     )
     cache_compressed_dp = HttpReader(cache_compressed_dp).end_caching(mode="wb", same_filepath_fn=True)
 
-    cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=_extracted_filepath_fn)
-    cache_decompressed_dp = FileOpener(cache_decompressed_dp, mode="b").read_from_zip().filter(_filter_fn)
+    cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=partial(_extracted_filepath_fn, root, split))
+    cache_decompressed_dp = (
+        FileOpener(cache_decompressed_dp, mode="b").load_from_zip().filter(partial(_filter_fn, split))
+    )
    cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True)
 
     data_dp = FileOpener(cache_decompressed_dp, encoding="utf-8")
-    parsed_data = data_dp.parse_csv(skip_lines=1, delimiter="\t").map(_modify_res)
+    parsed_data = data_dp.parse_csv(skip_lines=1, delimiter="\t").map(partial(_modify_res, split))
     return parsed_data.shuffle().set_shuffle(False).sharding_filter()
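---

A quick way to exercise the series end to end, beyond the mocked unit tests: the sketch below is not part of the patches and assumes torchtext (with these patches applied) and torchdata are installed and the GLUE download URL is reachable; the ".data" root is an arbitrary example path.

    from torchtext.datasets import WNLI

    # "train" and "dev" yield (label, sentence1, sentence2) tuples, with label 0 or 1
    train_dp = WNLI(root=".data", split="train")
    print(next(iter(train_dp)))

    # the "test" split has no labels, so it yields (sentence1, sentence2);
    # passing a tuple of splits returns one datapipe per split
    dev_dp, test_dp = WNLI(root=".data", split=("dev", "test"))
    print(next(iter(test_dp)))

On the design in patches 2 and 6: the lambdas and local closures are replaced by module-level functions bound with functools.partial because datapipes must be picklable when a DataLoader spawns multiprocessing workers, and lambdas and nested functions cannot be pickled.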