Skip to content

Commit 2a712f4

Browse files
authored
Add Shuffle and sharding datapipes to datasets (#1729)
1 parent 88086d9 commit 2a712f4

23 files changed

+47
-22
lines changed

test/datasets/common.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
from parameterized import parameterized
2+
from torch.utils.data.graph import traverse
3+
from torch.utils.data.graph_settings import get_all_graph_pipes
4+
from torchdata.datapipes.iter import Shuffler, ShardingFilter
5+
from torchtext.datasets import DATASETS
6+
7+
from ..common.torchtext_test_case import TorchtextTestCase
8+
9+
10+
class TestShuffleShardDatasetWrapper(TorchtextTestCase):
    """Verify that every registered dataset's datapipe graph ends with the
    required annotation datapipes: a ``Shuffler`` and a ``ShardingFilter``.

    Each dataset factory in ``DATASETS`` is expected to append
    ``.shuffle().set_shuffle(False).sharding_filter()`` to its pipeline;
    this test traverses the resulting graph and asserts both node types
    are present in every returned split.
    """

    # Note on ordering (shuffle must come before sharding): TorchData will
    # provide a linter warning for incorrect ordering in the future.
    # Modify this test when that linter warning becomes available.
    @parameterized.expand(list(DATASETS.items()))
    def test_shuffle_shard_wrapper(self, dataset_name, dataset_fn):
        """Assert the dataset's graph contains Shuffler and ShardingFilter.

        Args:
            dataset_name: key of the dataset under test (from ``DATASETS``).
            dataset_fn: factory returning a single datapipe or a tuple of
                datapipes (one per split, e.g. train/test).

        Raises:
            AssertionError: if any split's graph lacks a ``Shuffler`` or a
                ``ShardingFilter`` datapipe.
        """
        result = dataset_fn()
        # A dataset may return a tuple of splits; normalize to a list so
        # every split is checked uniformly.
        splits = list(result) if isinstance(result, tuple) else [result]

        for split in splits:
            graph_pipes = get_all_graph_pipes(traverse(split))
            for required_type in (Shuffler, ShardingFilter):
                if not any(isinstance(pipe, required_type) for pipe in graph_pipes):
                    raise AssertionError(f"The dataset doesn't contain a {required_type.__name__}() datapipe.")

torchtext/datasets/ag_news.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,4 +71,4 @@ def AG_NEWS(root: str, split: Union[Tuple[str], str]):
7171
cache_dp = cache_dp.end_caching(mode="wb", same_filepath_fn=True)
7272

7373
data_dp = FileOpener(cache_dp, encoding="utf-8")
74-
return data_dp.parse_csv().map(fn=_modify_res)
74+
return data_dp.parse_csv().map(fn=_modify_res).shuffle().set_shuffle(False).sharding_filter()

torchtext/datasets/amazonreviewfull.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,4 +90,4 @@ def AmazonReviewFull(root: str, split: Union[Tuple[str], str]):
9090
cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True)
9191

9292
data_dp = FileOpener(cache_decompressed_dp, encoding="utf-8")
93-
return data_dp.parse_csv().map(fn=_modify_res)
93+
return data_dp.parse_csv().map(fn=_modify_res).shuffle().set_shuffle(False).sharding_filter()

torchtext/datasets/amazonreviewpolarity.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,4 +87,4 @@ def AmazonReviewPolarity(root: str, split: Union[Tuple[str], str]):
8787
cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True)
8888

8989
data_dp = FileOpener(cache_decompressed_dp, encoding="utf-8")
90-
return data_dp.parse_csv().map(fn=_modify_res)
90+
return data_dp.parse_csv().map(fn=_modify_res).shuffle().set_shuffle(False).sharding_filter()

torchtext/datasets/cc100.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -176,4 +176,4 @@ def CC100(root: str, language_code: str = "en"):
176176
cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb")
177177

178178
data_dp = FileOpener(cache_decompressed_dp, encoding="utf-8").readlines(return_path=False)
179-
return data_dp.map(partial(_modify_res, language_code))
179+
return data_dp.map(partial(_modify_res, language_code)).shuffle().set_shuffle(False).sharding_filter()

torchtext/datasets/conll2000chunking.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,4 +80,4 @@ def CoNLL2000Chunking(root: str, split: Union[Tuple[str], str]):
8080
cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True)
8181

8282
data_dp = FileOpener(cache_decompressed_dp, encoding="utf-8")
83-
return data_dp.readlines().read_iob(sep=" ")
83+
return data_dp.readlines().read_iob(sep=" ").shuffle().set_shuffle(False).sharding_filter()

torchtext/datasets/dbpedia.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,4 +86,4 @@ def DBpedia(root: str, split: Union[Tuple[str], str]):
8686
cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True)
8787

8888
data_dp = FileOpener(cache_decompressed_dp, encoding="utf-8")
89-
return data_dp.parse_csv().map(fn=_modify_res)
89+
return data_dp.parse_csv().map(fn=_modify_res).shuffle().set_shuffle(False).sharding_filter()

torchtext/datasets/enwik9.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,4 +59,4 @@ def EnWik9(root: str):
5959
cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True)
6060

6161
data_dp = FileOpener(cache_decompressed_dp, encoding="utf-8")
62-
return data_dp.readlines(return_path=False)
62+
return data_dp.readlines(return_path=False).shuffle().set_shuffle(False).sharding_filter()

torchtext/datasets/imdb.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,4 +111,4 @@ def filter_imdb_data(key, fname):
111111

112112
data_dp = FileOpener(cache_decompressed_dp, encoding="utf-8")
113113
# get label from cache file, eg. "aclImdb_v1/train/neg" -> "neg"
114-
return data_dp.readlines().map(_modify_res)
114+
return data_dp.readlines().map(_modify_res).shuffle().set_shuffle(False).sharding_filter()

torchtext/datasets/iwslt2016.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -322,4 +322,4 @@ def IWSLT2016(
322322
src_lines = src_data_dp.readlines(return_path=False, strip_newline=False)
323323
tgt_lines = tgt_data_dp.readlines(return_path=False, strip_newline=False)
324324

325-
return src_lines.zip(tgt_lines)
325+
return src_lines.zip(tgt_lines).shuffle().set_shuffle(False).sharding_filter()

0 commit comments

Comments
 (0)