Commit 73bf4fa

Add support for WNLI dataset with unit tests (#1724)

* Add support for WNLI dataset + unit tests
* Add dataset documentation
* Add shuffle and sharding
* Move local to global functions + use load_from_zip
1 parent 932d776 commit 73bf4fa

File tree

4 files changed: +191 -0 lines changed
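
For orientation, a minimal usage sketch of the API this commit adds (illustrative, not part of the diff; it assumes torchtext and torchdata are installed and the GLUE download is reachable):

from torchtext.datasets import WNLI

# train/dev yield (label, sentence1, sentence2); test yields (sentence1, sentence2).
train_dp = WNLI(split="train")
for label, sentence1, sentence2 in train_dp:
    print(label, sentence1, sentence2)
    break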

docs/source/datasets.rst

Lines changed: 5 additions & 0 deletions
@@ -97,6 +97,11 @@ STSB
 
 .. autofunction:: STSB
 
+WNLI
+~~~~
+
+.. autofunction:: WNLI
+
 YahooAnswers
 ~~~~~~~~~~~~
test/datasets/test_wnli.py

Lines changed: 84 additions & 0 deletions
@@ -0,0 +1,84 @@
import os
import zipfile
from collections import defaultdict
from unittest.mock import patch

from parameterized import parameterized
from torchtext.datasets.wnli import WNLI

from ..common.case_utils import TempDirMixin, zip_equal, get_random_unicode
from ..common.torchtext_test_case import TorchtextTestCase


def _get_mock_dataset(root_dir):
    """
    root_dir: directory to the mocked dataset
    """
    base_dir = os.path.join(root_dir, "WNLI")
    temp_dataset_dir = os.path.join(base_dir, "temp_dataset_dir")
    os.makedirs(temp_dataset_dir, exist_ok=True)

    seed = 1
    mocked_data = defaultdict(list)
    for file_name in ("train.tsv", "test.tsv", "dev.tsv"):
        txt_file = os.path.join(temp_dataset_dir, file_name)
        with open(txt_file, "w", encoding="utf-8") as f:
            f.write("index\tsentence1\tsentence2\tlabel\n")
            for i in range(5):
                label = seed % 2
                rand_string_1 = get_random_unicode(seed)
                rand_string_2 = get_random_unicode(seed + 1)
                if file_name == "test.tsv":
                    dataset_line = (rand_string_1, rand_string_2)
                    f.write(f"{i}\t{rand_string_1}\t{rand_string_2}\n")
                else:
                    dataset_line = (label, rand_string_1, rand_string_2)
                    f.write(f"{i}\t{rand_string_1}\t{rand_string_2}\t{label}\n")

                # append line to correct dataset split
                mocked_data[os.path.splitext(file_name)[0]].append(dataset_line)
                seed += 1

    compressed_dataset_path = os.path.join(base_dir, "WNLI.zip")
    # create zip file from dataset folder
    with zipfile.ZipFile(compressed_dataset_path, "w") as zip_file:
        for file_name in ("train.tsv", "test.tsv", "dev.tsv"):
            txt_file = os.path.join(temp_dataset_dir, file_name)
            zip_file.write(txt_file, arcname=os.path.join("WNLI", file_name))

    return mocked_data


class TestWNLI(TempDirMixin, TorchtextTestCase):
    root_dir = None
    samples = []

    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        cls.root_dir = cls.get_base_temp_dir()
        cls.samples = _get_mock_dataset(cls.root_dir)
        cls.patcher = patch("torchdata.datapipes.iter.util.cacheholder._hash_check", return_value=True)
        cls.patcher.start()

    @classmethod
    def tearDownClass(cls):
        cls.patcher.stop()
        super().tearDownClass()

    @parameterized.expand(["train", "test", "dev"])
    def test_wnli(self, split):
        dataset = WNLI(root=self.root_dir, split=split)

        samples = list(dataset)
        expected_samples = self.samples[split]
        for sample, expected_sample in zip_equal(samples, expected_samples):
            self.assertEqual(sample, expected_sample)

    @parameterized.expand(["train", "test", "dev"])
    def test_wnli_split_argument(self, split):
        dataset1 = WNLI(root=self.root_dir, split=split)
        (dataset2,) = WNLI(root=self.root_dir, split=(split,))

        for d1, d2 in zip_equal(dataset1, dataset2):
            self.assertEqual(d1, d2)
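
A note on the helpers used above: the tests compare datasets with zip_equal rather than the built-in zip so that a length mismatch between actual and expected samples fails the test instead of being silently truncated, and setUpClass patches torchdata's internal _hash_check so the locally built WNLI.zip passes the MD5 gate in the dataset's on_disk_cache. A sketch of what a strict zip helper like zip_equal presumably looks like (an assumption about case_utils, not part of this diff):

from itertools import zip_longest

def zip_equal(*iterables):
    # Like zip(), but raise if the iterables differ in length instead of
    # silently dropping the excess items.
    sentinel = object()
    for combo in zip_longest(*iterables, fillvalue=sentinel):
        if sentinel in combo:
            raise ValueError("Iterables have different lengths")
        yield combo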

torchtext/datasets/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -26,6 +26,7 @@
 from .udpos import UDPOS
 from .wikitext103 import WikiText103
 from .wikitext2 import WikiText2
+from .wnli import WNLI
 from .yahooanswers import YahooAnswers
 from .yelpreviewfull import YelpReviewFull
 from .yelpreviewpolarity import YelpReviewPolarity
@@ -57,6 +58,7 @@
     "UDPOS": UDPOS,
     "WikiText103": WikiText103,
     "WikiText2": WikiText2,
+    "WNLI": WNLI,
     "YahooAnswers": YahooAnswers,
     "YelpReviewFull": YelpReviewFull,
     "YelpReviewPolarity": YelpReviewPolarity,

torchtext/datasets/wnli.py

Lines changed: 100 additions & 0 deletions
@@ -0,0 +1,100 @@
# Copyright (c) Facebook, Inc. and its affiliates.
import os
from functools import partial

from torchtext._internal.module_utils import is_module_available
from torchtext.data.datasets_utils import (
    _create_dataset_directory,
    _wrap_split_argument,
)

if is_module_available("torchdata"):
    from torchdata.datapipes.iter import FileOpener, IterableWrapper

    # we import HttpReader from _download_hooks so we can swap out public URLs
    # with internal URLs when the dataset is used within Facebook
    from torchtext._download_hooks import HttpReader


URL = "https://dl.fbaipublicfiles.com/glue/data/WNLI.zip"

MD5 = "a1b4bd2861017d302d29e42139657a42"

NUM_LINES = {
    "train": 635,
    "dev": 71,
    "test": 146,
}

_PATH = "WNLI.zip"

DATASET_NAME = "WNLI"

_EXTRACTED_FILES = {
    "train": os.path.join("WNLI", "train.tsv"),
    "dev": os.path.join("WNLI", "dev.tsv"),
    "test": os.path.join("WNLI", "test.tsv"),
}


def _filepath_fn(root, x=None):
    return os.path.join(root, os.path.basename(x))


def _extracted_filepath_fn(root, split, _=None):
    return os.path.join(root, _EXTRACTED_FILES[split])


def _filter_fn(split, x):
    return _EXTRACTED_FILES[split] in x[0]


def _modify_res(split, t):
    if split == "test":
        return (t[1], t[2])
    else:
        return (int(t[3]), t[1], t[2])


@_create_dataset_directory(dataset_name=DATASET_NAME)
@_wrap_split_argument(("train", "dev", "test"))
def WNLI(root, split):
    """WNLI Dataset

    For additional details refer to https://arxiv.org/pdf/1804.07461v3.pdf

    Number of lines per split:
        - train: 635
        - dev: 71
        - test: 146

    Args:
        root: Directory where the datasets are saved. Default: os.path.expanduser('~/.torchtext/cache')
        split: split or splits to be returned. Can be a string or tuple of strings. Default: (`train`, `dev`, `test`)

    :returns: DataPipe that yields tuple of text and/or label (0 or 1). The `test` split only returns text.
    :rtype: Union[(int, str, str), (str, str)]
    """
    # TODO Remove this after removing conditional dependency
    if not is_module_available("torchdata"):
        raise ModuleNotFoundError(
            "Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`"
        )

    url_dp = IterableWrapper([URL])
    cache_compressed_dp = url_dp.on_disk_cache(
        filepath_fn=partial(_filepath_fn, root),
        hash_dict={_filepath_fn(root, URL): MD5},
        hash_type="md5",
    )
    cache_compressed_dp = HttpReader(cache_compressed_dp).end_caching(mode="wb", same_filepath_fn=True)

    cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=partial(_extracted_filepath_fn, root, split))
    cache_decompressed_dp = (
        FileOpener(cache_decompressed_dp, mode="b").load_from_zip().filter(partial(_filter_fn, split))
    )
    cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True)

    data_dp = FileOpener(cache_decompressed_dp, encoding="utf-8")
    parsed_data = data_dp.parse_csv(skip_lines=1, delimiter="\t").map(partial(_modify_res, split))
    return parsed_data.shuffle().set_shuffle(False).sharding_filter()
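
The trailing shuffle().set_shuffle(False).sharding_filter() is what the commit message's "Add shuffle and sharding" refers to: it attaches shuffle and sharding nodes to the pipe while leaving shuffling disabled by default, so a DataLoader can turn shuffling on and split shards across workers. A usage sketch of that interaction (an illustration based on torchdata's documented DataLoader support, not part of this commit):

from torch.utils.data import DataLoader
from torchtext.datasets import WNLI

train_dp = WNLI(split="train")
# shuffle=True re-enables the shuffle node placed ahead of sharding_filter,
# so each worker iterates a distinct, shuffled shard.
loader = DataLoader(train_dp, batch_size=8, shuffle=True, num_workers=2)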
