Add Shuffle and sharding datapipes to datasets #1729
Conversation
I enabled the AWS extension last night. It seems the C-Extension is using glibc-2.29 during compilation. Will do a quick fix.
Nice, LGTM.
Perhaps it would make sense to add a small test that makes sure all datapipes come with a sharding filter and a shuffler? We have that in torchvision.
Thanks @NicolasHug for the suggestion. I think it is indeed a good idea. Let me add it as well.
Do we expect to have variations in shuffle and sharding between datasets? For now all datasets use `.shuffle().set_shuffle(False).sharding_filter()`. If not, we could consider using a decorator to wrap this call on the returned object and clean up the code.
Good point @VirgileHlav. In general you want to shard and shuffle on light objects (before decoding, before transforms) to avoid unnecessary computation and to save memory. For now torchtext datasets yield light objects (simple text), but maybe in the future this will change? In torchvision, we wrap in different places for each dataset, so using a decorator isn't an option.
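To make the "shuffle and shard on light objects" point concrete, here is a minimal pure-Python sketch (not the torchdata API; `build_pipeline` and `expensive_decode` are hypothetical names): the shuffle and the round-robin shard operate on cheap objects such as file paths, and only each worker's own shard ever reaches the expensive decode step.

```python
import random

def expensive_decode(path):
    # Placeholder for tokenization / image decoding / transforms.
    return path.upper()

def build_pipeline(paths, rank, world_size, seed=0):
    """Shuffle and shard the light objects *before* decoding, so each
    worker only decodes its own shard instead of the full dataset."""
    order = list(paths)
    random.Random(seed).shuffle(order)           # shuffle light objects
    shard = order[rank::world_size]              # round-robin sharding
    return (expensive_decode(p) for p in shard)  # decode only the shard
```

Because every rank shuffles with the same seed, the per-rank shards partition the dataset with no overlap, which is the same guarantee the `.shuffle().sharding_filter()` ordering provides in a datapipe graph.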
@NicolasHug thanks for the clarification, in that case let's keep it as is for now. @parmeet I will also add this to the dataset effort in #1710.
Otherwise LGTM
Thanks @VirgileHlav, SGTM! As @NicolasHug mentioned, it could potentially vary from dataset to dataset. So if in some cases processing is needed at the sample level, you could shard the pipe before processing, but otherwise adding it at the end is fine.
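For reference, the decorator idea discussed above could look like the sketch below. `with_shuffle_and_shard` is a hypothetical name, and `FakePipe` is a stand-in for a real `IterDataPipe` (it only records which methods were chained) so the sketch is self-contained:

```python
class FakePipe:
    """Stand-in for a torchdata IterDataPipe; records chained ops."""
    def __init__(self, ops=()):
        self.ops = tuple(ops)
    def shuffle(self):
        return FakePipe(self.ops + ("shuffle",))
    def set_shuffle(self, value):
        return FakePipe(self.ops + (f"set_shuffle({value})",))
    def sharding_filter(self):
        return FakePipe(self.ops + ("sharding_filter",))

def with_shuffle_and_shard(dataset_fn):
    """Hypothetical decorator appending .shuffle().set_shuffle(False)
    .sharding_filter() to whatever datapipe(s) the builder returns."""
    def wrapper(*args, **kwargs):
        dp = dataset_fn(*args, **kwargs)
        if isinstance(dp, tuple):  # some datasets return (train, test) splits
            return tuple(d.shuffle().set_shuffle(False).sharding_filter() for d in dp)
        return dp.shuffle().set_shuffle(False).sharding_filter()
    return wrapper
```

As the thread concludes, this only works if every dataset wants the wrapping at the very end of its pipeline, which is why it was not adopted.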
from parameterized import parameterized
from torch.utils.data.graph import traverse
from torch.utils.data.graph_settings import get_all_graph_pipes
from torchdata.datapipes.iter import Shuffler, ShardingFilter
from torchtext.datasets import DATASETS

from ..common.torchtext_test_case import TorchtextTestCase


class TestShuffleShardDatasetWrapper(TorchtextTestCase):
    # Note on order (i.e. shuffle before sharding): TorchData will provide a linter warning.
    # Modify this test when the linter warning is available.
    @parameterized.expand(list(DATASETS.items()))
    def test_shuffle_shard_wrapper(self, dataset_name, dataset_fn):
        dp = dataset_fn()
        if isinstance(dp, tuple):
            dp = list(dp)
        else:
            dp = [dp]

        for dp_split in dp:
            dp_graph = get_all_graph_pipes(traverse(dp_split))
            for annotation_dp_type in [Shuffler, ShardingFilter]:
                if not any(isinstance(dp, annotation_dp_type) for dp in dp_graph):
                    raise AssertionError(f"The dataset doesn't contain a {annotation_dp_type.__name__}() datapipe.")
@Nayef211 Just FYI, in case we can do something similar for pickle :). @ejguan I left a comment to update the test once linter warnings are available; it is not a blocker for landing this PR.
cc: @NicolasHug
Thanks. I will do a fix for the manylinux1 wheel first, then add the linter for you.
class TestShuffleShardDatasetWrapper(TorchtextTestCase):
    # Note on order (i.e. shuffle before sharding): TorchData will provide a linter warning.
    # Modify this test when the linter warning is available.
    @parameterized.expand(list(DATASETS.items()))
QQ: it looks like we're not using the `dataset_name` anywhere. Why don't we just pass in the `dataset_fn` to the test by doing something like `@parameterized.expand(list(DATASETS.values()))`?
Ahh my bad, thanks for the catch. Will fix it!
I actually realized that the `parameterized.expand` decorator complains when passing in the list of `dataset_fn`. Lmk if you're able to figure out how to resolve the error.
Yupp, the problem is we need to pass tuples inside the list. Just created a PR to fix it: #1733
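A minimal sketch of the shape `parameterized.expand` expects, without depending on the library itself (the `DATASETS` registry below is a hypothetical stand-in for torchtext's): each test case must be a tuple, so `dict.items()` works as-is, while bare values need wrapping in 1-tuples.

```python
# Toy registry standing in for torchtext's DATASETS (hypothetical values).
DATASETS = {"AG_NEWS": lambda: "ag_news_pipe", "IMDB": lambda: "imdb_pipe"}

# parameterized.expand expects an iterable of tuples, one per test case.
# dict.items() already yields (name, fn) tuples, so this form works:
cases_with_names = list(DATASETS.items())

# Bare values are not tuples, so they must be wrapped in 1-tuples:
cases_values_only = [(fn,) for fn in DATASETS.values()]
```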
Awesome. Just incorporated this in my PR #1732
Summary: This PR introduces the linter function to validate there is a shuffle operation before sharding. (Required by TorchText in pytorch/text#1729)
- When `sharding_filter` is not present in the graph, this function always returns `True`
- For a single-path graph, `shuffle` needs to be placed before `sharding_filter`
- For a multi-path graph, any `sharding_filter` requires a `shuffle` before it along the path

This linter function won't check whether there are multiple `sharding_filter`s in the graph or whether `sharding_filter` is at the right place.

Pull Request resolved: #429
Reviewed By: NivekT
Differential Revision: D36529167
Pulled By: ejguan
fbshipit-source-id: 56e734eac98b2ddadcd7707ee92ea4032a896969
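The linter rule summarized above can be sketched in a few lines on a toy graph representation (this is not torchdata's implementation; the graph here is just a dict mapping each node name to its upstream nodes, with the node "type" inferred from the name prefix):

```python
def check_shuffle_before_sharding(graph):
    """Toy version of the rule: every path that reaches a
    'sharding_filter' node must pass through a 'shuffle' node first.
    Vacuously True when the graph has no sharding_filter at all."""
    def shuffled(node):
        preds = graph.get(node, [])
        if not preds:
            return False  # reached a source without seeing a shuffle
        return all(p.startswith("shuffle") or shuffled(p) for p in preds)

    shards = [n for n in graph if n.startswith("sharding_filter")]
    return all(shuffled(n) for n in shards)
```

Walking upstream from each `sharding_filter` naturally handles the multi-path case: a branch that merges into the sharding node without a shuffle on it fails the check.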
Reference Issue: #1727