Add Shuffle and sharding datapipes to datasets #1729
@@ -0,0 +1,25 @@
from parameterized import parameterized
from torch.utils.data.graph import traverse
from torch.utils.data.graph_settings import get_all_graph_pipes
from torchdata.datapipes.iter import Shuffler, ShardingFilter
from torchtext.datasets import DATASETS

from ..common.torchtext_test_case import TorchtextTestCase


class TestShuffleShardDatasetWrapper(TorchtextTestCase):
    # Note that for order i.e shuffle before sharding, TorchData will provide linter warning
    # Modify this test when linter warning is available
    @parameterized.expand(list(DATASETS.items()))
    def test_shuffle_shard_wrapper(self, dataset_name, dataset_fn):
        dp = dataset_fn()
        if type(dp) == tuple:
            dp = list(dp)
        else:
            dp = [dp]

        for dp_split in dp:
            dp_graph = get_all_graph_pipes(traverse(dp_split))
            for annotation_dp_type in [Shuffler, ShardingFilter]:
                if not any(isinstance(dp, annotation_dp_type) for dp in dp_graph):
                    raise AssertionError(f"The dataset doesn't contain a {annotation_dp_type.__name__}() datapipe.")
Comment on lines +1 to +25
@Nayef211 Just FYI, if we can do something similar for pickle :). @ejguan I left a comment to update the test once linter warnings are available; it is not a blocker for landing this PR. cc: @NicolasHug

Thanks. I will do a fix for the manylinux1 wheel first, then add the linter for you.
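The ordering check that the test comment defers (shuffle before sharding) is not part of this diff. As a rough sketch only, using hypothetical stand-in classes rather than the real torchdata datapipes and assuming a simple linear pipeline, such a check might look like:

```python
# Hypothetical stand-ins for illustration; these are NOT the torchdata classes.
class Pipe:
    def __init__(self, source=None):
        self.source = source  # upstream pipe, or None for the root


class Shuffler(Pipe):
    pass


class ShardingFilter(Pipe):
    pass


def upstream_chain(pipe):
    """Return the pipes from the root down to `pipe`, in data-flow order."""
    chain = []
    while pipe is not None:
        chain.append(pipe)
        pipe = pipe.source
    return chain[::-1]


def shuffles_before_sharding(pipe):
    """True if a Shuffler appears upstream of the first ShardingFilter."""
    seen_shuffler = False
    for p in upstream_chain(pipe):
        if isinstance(p, Shuffler):
            seen_shuffler = True
        elif isinstance(p, ShardingFilter):
            return seen_shuffler
    return False  # no ShardingFilter found in the chain


good = ShardingFilter(Shuffler(Pipe()))  # shuffle, then shard
bad = Shuffler(ShardingFilter(Pipe()))   # shard, then shuffle
```

Here `shuffles_before_sharding(good)` holds and `shuffles_before_sharding(bad)` does not. A real check would have to walk the graph produced by `traverse`, which can branch, so this is only an approximation of the idea.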
QQ: it looks like we're not using the `dataset_name` anywhere. Why don't we just pass in the `dataset_fn` to the test by doing something like `@parameterized.expand(list(DATASETS.values()))`?
Ahh my bad, thanks for the catch. Will fix it!
I actually realized that the `parameterized.expand` decorator complains when passing in the list of `dataset_fn`. Lmk if you're able to figure out how to resolve the error.
Yupp, the problem is we need to pass tuples inside list. Just created PR to fix it #1733
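As a minimal illustration of the tuple requirement described above, assuming each parameter entry is unpacked into the test's positional arguments (the registry and dataset functions below are hypothetical stand-ins, not torchtext's `DATASETS`):

```python
# Hypothetical stand-in for a name -> dataset function registry.
def make_a():
    return "a-pipe"


def make_b():
    return "b-pipe"


DATASETS = {"A": make_a, "B": make_b}

# Wrap each function in a one-element tuple so that every entry
# unpacks into exactly one positional argument.
params = [(fn,) for fn in DATASETS.values()]


def run_case(dataset_fn):
    # Stand-in for the test body: build the dataset and return it.
    return dataset_fn()


# Each entry is star-unpacked into the call, one case per dataset function.
results = [run_case(*entry) for entry in params]
```

With the real decorator this would correspond to something like `@parameterized.expand([(fn,) for fn in DATASETS.values()])`; the exact change is the one in the PR referenced above.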
Awesome. Just incorporated this in my PR #1732