diff --git a/examples/tutorials/sst2_classification_non_distributed.py b/examples/tutorials/sst2_classification_non_distributed.py index 983735ef35..e3d1f77133 100644 --- a/examples/tutorials/sst2_classification_non_distributed.py +++ b/examples/tutorials/sst2_classification_non_distributed.py @@ -116,14 +116,14 @@ def apply_transform(x): # # :: # -# def batch_transform(x): -# return {"token_ids": text_transform(x["text"]), "target": x["label"]} +# def batch_transform(x): +# return {"token_ids": text_transform(x["text"]), "target": x["label"]} # # -# train_datapipe = train_datapipe.batch(batch_size).rows2columnar(["text", "label"]) -# train_datapipe = train_datapipe.map(lambda x: batch_transform) -# dev_datapipe = dev_datapipe.batch(batch_size).rows2columnar(["text", "label"]) -# dev_datapipe = dev_datapipe.map(lambda x: batch_transform) +# train_datapipe = train_datapipe.batch(batch_size).rows2columnar(["text", "label"]) +# train_datapipe = train_datapipe.map(lambda x: batch_transform) +# dev_datapipe = dev_datapipe.batch(batch_size).rows2columnar(["text", "label"]) +# dev_datapipe = dev_datapipe.map(lambda x: batch_transform) # ######################################################################