Replacing lambda functions with regular functions in all datasets #1718

Merged 4 commits on May 11, 2022

12 changes: 9 additions & 3 deletions torchtext/datasets/ag_news.py
@@ -52,14 +52,20 @@ def AG_NEWS(root: str, split: Union[Tuple[str], str]):
"Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`"
)

def _filepath_fn(_=None):
return os.path.join(root, split + ".csv")

def _modify_res(t):
return int(t[0]), " ".join(t[1:])

Contributor comment on lines +55 to +60:
It's better to use a global function rather than a local function here. Local functions are still non-picklable by the pickle module.
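
(A minimal, self-contained illustration of this point, not part of the PR: the stdlib `pickle` serializes functions by reference to their qualified name, so a module-level function round-trips, while both a lambda and a function defined inside another function fail.)

```python
import os
import pickle


def _filepath_fn_global(_=None):
    # Module-level functions pickle by reference to their qualified name.
    return os.path.join(".", "train.csv")


def make_local_fn():
    # Mirrors the helpers in this PR, which are defined inside the dataset function.
    def _filepath_fn_local(_=None):
        return os.path.join(".", "train.csv")

    return _filepath_fn_local


pickle.dumps(_filepath_fn_global)  # succeeds

for fn in (lambda _=None: os.path.join(".", "train.csv"), make_local_fn()):
    try:
        pickle.dumps(fn)
    except (pickle.PicklingError, AttributeError) as exc:
        # Both a lambda and a locally defined function fail with stdlib pickle.
        print(f"{fn.__qualname__!r} is not picklable: {exc}")
```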

url_dp = IterableWrapper([URL[split]])
cache_dp = url_dp.on_disk_cache(
filepath_fn=lambda x: os.path.join(root, split + ".csv"),
hash_dict={os.path.join(root, split + ".csv"): MD5[split]},
filepath_fn=_filepath_fn,
hash_dict={_filepath_fn(): MD5[split]},
hash_type="md5",
)
cache_dp = HttpReader(cache_dp)
cache_dp = cache_dp.end_caching(mode="wb", same_filepath_fn=True)

data_dp = FileOpener(cache_dp, encoding="utf-8")
return data_dp.parse_csv().map(fn=lambda t: (int(t[0]), " ".join(t[1:])))
return data_dp.parse_csv().map(fn=_modify_res)
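
Side note on the pattern above (my reading of the diff, not stated in the PR): `on_disk_cache` invokes `filepath_fn` with the incoming item, while the same helper is also called with no argument to build the `hash_dict` key, so the ignored `_=None` parameter lets one function cover both call sites. Roughly:

```python
import os

# Hypothetical stand-ins for the dataset arguments.
root, split = ".", "train"


def _filepath_fn(_=None):
    # The optional parameter is ignored; it exists only so the helper can be
    # called both as filepath_fn(item) by on_disk_cache and as _filepath_fn()
    # when building the hash_dict key.
    return os.path.join(root, split + ".csv")


assert _filepath_fn() == _filepath_fn("https://example.invalid/train.csv")
```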
26 changes: 17 additions & 9 deletions torchtext/datasets/amazonreviewfull.py
@@ -58,21 +58,29 @@ def AmazonReviewFull(root: str, split: Union[Tuple[str], str]):
"Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`"
)

def _filepath_fn(_=None):
return os.path.join(root, _PATH)

def _extracted_filepath_fn(_=None):
return os.path.join(root, _EXTRACTED_FILES[split])

def _filter_fn(x):
return _EXTRACTED_FILES[split] in x[0]

def _modify_res(t):
return int(t[0]), " ".join(t[1:])

Contributor comment on lines +61 to +72:
Ditto

url_dp = IterableWrapper([URL])
cache_compressed_dp = url_dp.on_disk_cache(
filepath_fn=lambda x: os.path.join(root, _PATH),
hash_dict={os.path.join(root, _PATH): MD5},
filepath_fn=_filepath_fn,
hash_dict={_filepath_fn(): MD5},
hash_type="md5",
)
cache_compressed_dp = GDriveReader(cache_compressed_dp).end_caching(mode="wb", same_filepath_fn=True)

cache_decompressed_dp = cache_compressed_dp.on_disk_cache(
filepath_fn=lambda x: os.path.join(root, _EXTRACTED_FILES[split])
)
cache_decompressed_dp = (
FileOpener(cache_decompressed_dp, mode="b").load_from_tar().filter(lambda x: _EXTRACTED_FILES[split] in x[0])
)
cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=_extracted_filepath_fn)
cache_decompressed_dp = FileOpener(cache_decompressed_dp, mode="b").load_from_tar().filter(_filter_fn)
cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True)

data_dp = FileOpener(cache_decompressed_dp, encoding="utf-8")
return data_dp.parse_csv().map(fn=lambda t: (int(t[0]), " ".join(t[1:])))
return data_dp.parse_csv().map(fn=_modify_res)
26 changes: 17 additions & 9 deletions torchtext/datasets/amazonreviewpolarity.py
@@ -55,21 +55,29 @@ def AmazonReviewPolarity(root: str, split: Union[Tuple[str], str]):
"Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`"
)

def _filepath_fn(_=None):
return os.path.join(root, _PATH)

def _extracted_filepath_fn(_=None):
return os.path.join(root, _EXTRACTED_FILES[split])

def _filter_fn(x):
return _EXTRACTED_FILES[split] in x[0]

def _modify_res(t):
return int(t[0]), " ".join(t[1:])

url_dp = IterableWrapper([URL])
cache_compressed_dp = url_dp.on_disk_cache(
filepath_fn=lambda x: os.path.join(root, _PATH),
hash_dict={os.path.join(root, _PATH): MD5},
filepath_fn=_filepath_fn,
hash_dict={_filepath_fn(): MD5},
hash_type="md5",
)
cache_compressed_dp = GDriveReader(cache_compressed_dp).end_caching(mode="wb", same_filepath_fn=True)

cache_decompressed_dp = cache_compressed_dp.on_disk_cache(
filepath_fn=lambda x: os.path.join(root, _EXTRACTED_FILES[split])
)
cache_decompressed_dp = (
FileOpener(cache_decompressed_dp, mode="b").load_from_tar().filter(lambda x: _EXTRACTED_FILES[split] in x[0])
)
cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=_extracted_filepath_fn)
cache_decompressed_dp = FileOpener(cache_decompressed_dp, mode="b").load_from_tar().filter(_filter_fn)
cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True)

data_dp = FileOpener(cache_decompressed_dp, encoding="utf-8")
return data_dp.parse_csv().map(fn=lambda t: (int(t[0]), " ".join(t[1:])))
return data_dp.parse_csv().map(fn=_modify_res)
17 changes: 12 additions & 5 deletions torchtext/datasets/cc100.py
@@ -151,18 +151,25 @@ def CC100(root: str, language_code: str = "en"):
if language_code not in VALID_CODES:
raise ValueError(f"Invalid language code {language_code}")

def _filepath_fn(_=None):
return os.path.join(root, os.path.basename(url))

def _decompressed_filepath_fn(x):
return os.path.join(root, os.path.basename(x).rstrip(".xz"))

def _modify_res(x):
return language_code, x

url = URL % language_code
url_dp = IterableWrapper([url])
cache_compressed_dp = url_dp.on_disk_cache(filepath_fn=lambda x: os.path.join(root, os.path.basename(url)))
cache_compressed_dp = url_dp.on_disk_cache(filepath_fn=_filepath_fn)

cache_compressed_dp = HttpReader(cache_compressed_dp)
cache_compressed_dp = cache_compressed_dp.end_caching(mode="wb", same_filepath_fn=True)

cache_decompressed_dp = cache_compressed_dp.on_disk_cache(
filepath_fn=lambda x: os.path.join(root, os.path.basename(x).rstrip(".xz"))
)
cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=_decompressed_filepath_fn)
cache_decompressed_dp = FileOpener(cache_decompressed_dp, mode="b").load_from_xz()
cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb")

data_dp = FileOpener(cache_decompressed_dp, encoding="utf-8").readlines(return_path=False)
return data_dp.map(lambda x: (language_code, x))
return data_dp.map(_modify_res)
14 changes: 9 additions & 5 deletions torchtext/datasets/conll2000chunking.py
@@ -55,20 +55,24 @@ def CoNLL2000Chunking(root: str, split: Union[Tuple[str], str]):
"Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`"
)

def _filepath_fn(_=None):
return os.path.join(root, os.path.basename(URL[split]))

def _extracted_filepath_fn(_=None):
return os.path.join(root, _EXTRACTED_FILES[split])

url_dp = IterableWrapper([URL[split]])

# Cache and check HTTP response
cache_compressed_dp = url_dp.on_disk_cache(
filepath_fn=lambda x: os.path.join(root, os.path.basename(URL[split])),
hash_dict={os.path.join(root, os.path.basename(URL[split])): MD5[split]},
filepath_fn=_filepath_fn,
hash_dict={_filepath_fn(): MD5[split]},
hash_type="md5",
)
cache_compressed_dp = HttpReader(cache_compressed_dp).end_caching(mode="wb", same_filepath_fn=True)

# Cache and check the gzip extraction for relevant split
cache_decompressed_dp = cache_compressed_dp.on_disk_cache(
filepath_fn=lambda x: os.path.join(root, _EXTRACTED_FILES[split])
)
cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=_extracted_filepath_fn)
cache_decompressed_dp = FileOpener(cache_decompressed_dp, mode="b").extract(file_type="gzip")
cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True)

26 changes: 17 additions & 9 deletions torchtext/datasets/dbpedia.py
@@ -54,21 +54,29 @@ def DBpedia(root: str, split: Union[Tuple[str], str]):
"Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`"
)

def _filepath_fn(_=None):
return os.path.join(root, _PATH)

def _extracted_filepath_fn(_=None):
return os.path.join(root, _EXTRACTED_FILES[split])

def _filter_fn(x):
return _EXTRACTED_FILES[split] in x[0]

def _modify_res(t):
return int(t[0]), " ".join(t[1:])

url_dp = IterableWrapper([URL])
cache_compressed_dp = url_dp.on_disk_cache(
filepath_fn=lambda x: os.path.join(root, _PATH),
hash_dict={os.path.join(root, _PATH): MD5},
filepath_fn=_filepath_fn,
hash_dict={_filepath_fn(): MD5},
hash_type="md5",
)
cache_compressed_dp = GDriveReader(cache_compressed_dp).end_caching(mode="wb", same_filepath_fn=True)

cache_decompressed_dp = cache_compressed_dp.on_disk_cache(
filepath_fn=lambda x: os.path.join(root, _EXTRACTED_FILES[split])
)
cache_decompressed_dp = (
FileOpener(cache_decompressed_dp, mode="b").load_from_tar().filter(lambda x: _EXTRACTED_FILES[split] in x[0])
)
cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=_extracted_filepath_fn)
cache_decompressed_dp = FileOpener(cache_decompressed_dp, mode="b").load_from_tar().filter(_filter_fn)
cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True)

data_dp = FileOpener(cache_decompressed_dp, encoding="utf-8")
return data_dp.parse_csv().map(fn=lambda t: (int(t[0]), " ".join(t[1:])))
return data_dp.parse_csv().map(fn=_modify_res)
14 changes: 9 additions & 5 deletions torchtext/datasets/enwik9.py
@@ -37,17 +37,21 @@ def EnWik9(root: str):
"Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`"
)

def _filepath_fn(_=None):
return os.path.join(root, _PATH)

def _extracted_filepath_fn(_=None):
return os.path.join(root, os.path.splitext(_PATH)[0])

url_dp = IterableWrapper([URL])
cache_compressed_dp = url_dp.on_disk_cache(
filepath_fn=lambda x: os.path.join(root, _PATH),
hash_dict={os.path.join(root, _PATH): MD5},
filepath_fn=_filepath_fn,
hash_dict={_filepath_fn(): MD5},
hash_type="md5",
)
cache_compressed_dp = HttpReader(cache_compressed_dp).end_caching(mode="wb", same_filepath_fn=True)

cache_decompressed_dp = cache_compressed_dp.on_disk_cache(
filepath_fn=lambda x: os.path.join(root, os.path.splitext(_PATH)[0])
)
cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=_extracted_filepath_fn)
cache_decompressed_dp = FileOpener(cache_decompressed_dp, mode="b").load_from_zip()
cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True)

41 changes: 29 additions & 12 deletions torchtext/datasets/imdb.py
@@ -47,20 +47,39 @@ def IMDB(root: str, split: Union[Tuple[str], str]):
"Package `torchdata` not found. Please install following instructions at `https://github.com/pytorch/data`"
)

def _filepath_fn(_=None):
return os.path.join(root, _PATH)

def _decompressed_filepath_fn(_=None):
return [os.path.join(root, decompressed_folder, split, label) for label in labels]

def _filter_fn(t):
return filter_imdb_data(split, t[0])

def _path_map_fn(t):
return Path(t[0]).parts[-2], t[1]

def _encode_map_fn(x):
return x[0], x[1].encode()

def _cache_filepath_fn(x):
return os.path.join(root, decompressed_folder, split, x)

def _modify_res(t):
return Path(t[0]).parts[-1], t[1]

url_dp = IterableWrapper([URL])

cache_compressed_dp = url_dp.on_disk_cache(
filepath_fn=lambda x: os.path.join(root, _PATH),
hash_dict={os.path.join(root, _PATH): MD5},
filepath_fn=_filepath_fn,
hash_dict={_filepath_fn(): MD5},
hash_type="md5",
)
cache_compressed_dp = HttpReader(cache_compressed_dp).end_caching(mode="wb", same_filepath_fn=True)

labels = {"neg", "pos"}
decompressed_folder = "aclImdb_v1"
cache_decompressed_dp = cache_compressed_dp.on_disk_cache(
filepath_fn=lambda x: [os.path.join(root, decompressed_folder, split, label) for label in labels]
)
cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=_decompressed_filepath_fn)
cache_decompressed_dp = FileOpener(cache_decompressed_dp, mode="b")
cache_decompressed_dp = cache_decompressed_dp.load_from_tar()

@@ -69,17 +88,15 @@ def filter_imdb_data(key, fname):
*_, split, label, file = Path(fname).parts
return key == split and label in labels

cache_decompressed_dp = cache_decompressed_dp.filter(lambda t: filter_imdb_data(split, t[0]))
cache_decompressed_dp = cache_decompressed_dp.filter(_filter_fn)

# eg. "aclImdb/train/neg/12416_3.txt" -> "neg"
cache_decompressed_dp = cache_decompressed_dp.map(lambda t: (Path(t[0]).parts[-2], t[1]))
cache_decompressed_dp = cache_decompressed_dp.map(_path_map_fn)
cache_decompressed_dp = cache_decompressed_dp.readlines(decode=True)
cache_decompressed_dp = cache_decompressed_dp.lines_to_paragraphs() # group by label in cache file
cache_decompressed_dp = cache_decompressed_dp.map(lambda x: (x[0], x[1].encode()))
cache_decompressed_dp = cache_decompressed_dp.end_caching(
mode="wb", filepath_fn=lambda x: os.path.join(root, decompressed_folder, split, x), skip_read=True
)
cache_decompressed_dp = cache_decompressed_dp.map(_encode_map_fn)
cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", filepath_fn=_cache_filepath_fn, skip_read=True)

data_dp = FileOpener(cache_decompressed_dp, encoding="utf-8")
# get label from cache file, eg. "aclImdb_v1/train/neg" -> "neg"
return data_dp.readlines().map(lambda t: (Path(t[0]).parts[-1], t[1]))
return data_dp.readlines().map(_modify_res)
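
As an aside (illustration only, matching the inline comments in this hunk), the label extraction in `_path_map_fn` and `_modify_res` relies on `pathlib.Path.parts`:

```python
from pathlib import Path

# While filtering the decompressed tar members, the label is the
# second-to-last path component, e.g. "aclImdb/train/neg/12416_3.txt" -> "neg".
assert Path("aclImdb/train/neg/12416_3.txt").parts[-2] == "neg"

# After caching, the label is the last component of the cache file path,
# e.g. "aclImdb_v1/train/neg" -> "neg".
assert Path("aclImdb_v1/train/neg").parts[-1] == "neg"
```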
34 changes: 25 additions & 9 deletions torchtext/datasets/iwslt2016.py
@@ -124,12 +124,19 @@
# TODO: migrate this to dataset_utils.py once torchdata is a hard dependency to
# avoid additional conditional imports.
def _filter_clean_cache(cache_decompressed_dp, full_filepath, uncleaned_filename):
cache_inner_decompressed_dp = cache_decompressed_dp.on_disk_cache(filepath_fn=lambda x: full_filepath)
def _return_full_filepath(_=None):
return full_filepath

def _filter_fn(x):
return os.path.basename(uncleaned_filename) in x[0]

def _clean_files_wrapper(x):
return _clean_files(full_filepath, x[0], x[1])

cache_inner_decompressed_dp = cache_decompressed_dp.on_disk_cache(filepath_fn=_return_full_filepath)
cache_inner_decompressed_dp = cache_inner_decompressed_dp.open_files(mode="b").load_from_tar()
cache_inner_decompressed_dp = cache_inner_decompressed_dp.filter(
lambda x: os.path.basename(uncleaned_filename) in x[0]
)
cache_inner_decompressed_dp = cache_inner_decompressed_dp.map(lambda x: _clean_files(full_filepath, x[0], x[1]))
cache_inner_decompressed_dp = cache_inner_decompressed_dp.filter(_filter_fn)
cache_inner_decompressed_dp = cache_inner_decompressed_dp.map(_clean_files_wrapper)
cache_inner_decompressed_dp = cache_inner_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True)
return cache_inner_decompressed_dp

@@ -234,10 +241,13 @@ def IWSLT2016(
SUPPORTED_DATASETS["year"], src_language, tgt_language, valid_set, test_set
)

def _filepath_fn(_=None):
return os.path.join(root, _PATH)

url_dp = IterableWrapper([URL])
cache_compressed_dp = url_dp.on_disk_cache(
filepath_fn=lambda x: os.path.join(root, _PATH),
hash_dict={os.path.join(root, _PATH): MD5},
filepath_fn=_filepath_fn,
hash_dict={_filepath_fn(): MD5},
hash_type="md5",
)
cache_compressed_dp = GDriveReader(cache_compressed_dp)
@@ -260,9 +270,15 @@ def IWSLT2016(
+ ".tgz"
)

cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=lambda x: inner_iwslt_tar)
def _inner_iwslt_tar_filepath_fn(_=None):
return inner_iwslt_tar

def _filter_fn(x):
return os.path.basename(inner_iwslt_tar) in x[0]

cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=_inner_iwslt_tar_filepath_fn)
cache_decompressed_dp = cache_decompressed_dp.open_files(mode="b").load_from_tar()
cache_decompressed_dp = cache_decompressed_dp.filter(lambda x: os.path.basename(inner_iwslt_tar) in x[0])
cache_decompressed_dp = cache_decompressed_dp.filter(_filter_fn)
cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True)
cache_decompressed_dp_1, cache_decompressed_dp_2 = cache_decompressed_dp.fork(num_instances=2)
