Skip to content

Commit e4e8da8

Browse files
fjoswsoumith
authored andcommitted
fix: fixed local device name in multinode example.
1 parent d8456a3 commit e4e8da8

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

distributed/ddp-tutorial-series/multinode.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def __init__(
3737
self.model = DDP(self.model, device_ids=[self.local_rank])
3838

3939
def _load_snapshot(self, snapshot_path):
40-
loc = f"cuda:{self.gpu_id}"
40+
loc = f"cuda:{self.local_rank}"
4141
snapshot = torch.load(snapshot_path, map_location=loc)
4242
self.model.load_state_dict(snapshot["MODEL_STATE"])
4343
self.epochs_run = snapshot["EPOCHS_RUN"]

0 commit comments

Comments
 (0)