Commit cbcd908

eddogola authored and pytorchmergebot committed
[DCP] Modify tensor saving logic in DCP (#106415)
Currently, DCP treats tensors as duplicates and only saves them on rank0. This does not work for PiPPy, which has unique tensors across different ranks; with the current setup we would only save the tensors that live on rank0 (the coordinator rank). This PR changes the logic so that each rank creates its own WriteItem for its tensors. Tensors that are replicated across ranks are handled by dedup_tensors(), which removes the duplicate WriteItems so the actual write happens only once. Pull Request resolved: #106415 Approved by: https://github.com/wz337
1 parent c913f38 · commit cbcd908
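To illustrate the new save flow described above, here is a minimal sketch (not code from this commit; the plain-tensor state_dict and the four-rank loop are hypothetical, mirroring the updated test below): every rank now emits a WriteItem for its tensors, and replicated WriteItems are removed by dedup_tensors() before the global plan is built.

import torch
from torch.distributed.checkpoint._dedup_tensors import dedup_tensors
from torch.distributed.checkpoint.default_planner import (
    create_default_global_save_plan,
    create_default_local_save_plan,
)

# Hypothetical state_dict that is replicated on every rank.
state_dict = {"tensor": torch.rand(10)}

# Each rank (not just the coordinator) now creates a WriteItem for "tensor".
all_plans = [
    create_default_local_save_plan(state_dict, rank == 0) for rank in range(4)
]

# Replicated WriteItems are deduplicated so the tensor is written only once.
all_plans = dedup_tensors(all_plans)
final_plans, metadata = create_default_global_save_plan(all_plans=all_plans)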

File tree: 2 files changed (+17, -8 lines)


test/distributed/checkpoint/test_planner.py

Lines changed: 16 additions & 6 deletions
@@ -38,6 +38,7 @@
 )

 from torch.distributed.checkpoint.planner_helpers import create_read_items_for_chunk_list
+from torch.distributed.checkpoint._dedup_tensors import dedup_tensors


 if TEST_WITH_DEV_DBG_ASAN:
@@ -86,14 +87,22 @@ def test_local_plan(self):
             "st": st
         }
         plan = create_default_local_save_plan(state_dict, False)
-        self.assertEqual(1, len(plan.items))
+        self.assertEqual(2, len(plan.items))
         wi = plan.items[0]
-        self.assertEqual(wi.index, MetadataIndex("st", [8]))
-        self.assertEqual(wi.type, WriteItemType.SHARD)
-        self.assertEqual(wi.tensor_data.size, st.size())
+        self.assertEqual(wi.index, MetadataIndex("tensor", [0]))
+        self.assertEqual(wi.type, WriteItemType.TENSOR)
+        self.assertEqual(wi.tensor_data.size, tensor.size())
         self.assertEqual(wi.tensor_data.properties, TensorProperties.create_from_tensor(torch.zeros(1)))
-        self.assertEqual(wi.tensor_data.chunk.offsets, torch.Size([8]))
-        self.assertEqual(wi.tensor_data.chunk.sizes, torch.Size([8]))
+        self.assertEqual(wi.tensor_data.chunk.offsets, torch.Size([0]))
+        self.assertEqual(wi.tensor_data.chunk.sizes, torch.Size([10]))
+
+        st_wi = plan.items[1]
+        self.assertEqual(st_wi.index, MetadataIndex("st", [8]))
+        self.assertEqual(st_wi.type, WriteItemType.SHARD)
+        self.assertEqual(st_wi.tensor_data.size, st.size())
+        self.assertEqual(st_wi.tensor_data.properties, TensorProperties.create_from_tensor(torch.zeros(1)))
+        self.assertEqual(st_wi.tensor_data.chunk.offsets, torch.Size([8]))
+        self.assertEqual(st_wi.tensor_data.chunk.sizes, torch.Size([8]))

         # Coordinator rank, should include replicated items as well
         plan = create_default_local_save_plan(state_dict, True)
@@ -124,6 +133,7 @@ def create_data(rank):
             return create_default_local_save_plan(state_dict, rank == 0)

         all_plans = [create_data(0), create_data(1), create_data(2), create_data(3)]
+        all_plans = dedup_tensors(all_plans)
         final_plans, metadata = create_default_global_save_plan(all_plans=all_plans)

         # The default global plan updates all indexes to include hints

torch/distributed/checkpoint/default_planner.py

Lines changed: 1 addition & 2 deletions
@@ -11,7 +11,6 @@
 import torch

 from torch.distributed._shard._utils import narrow_tensor_by_index
-from torch.distributed._shard.sharded_tensor import ShardedTensor
 from torch.distributed._tensor import DTensor


@@ -294,7 +293,7 @@ def create_default_local_save_plan(
         if isinstance(obj, DTensor):
             if obj.device_mesh.get_coordinate() is not None:
                 requests += _create_write_items(fqn, obj)
-        elif isinstance(obj, (ShardedTensor)) or is_coordinator:
+        elif isinstance(obj, (torch.Tensor)) or is_coordinator:
             requests += _create_write_items(fqn, obj)

     return SavePlan(requests)
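The effect of the isinstance change above, as a minimal sketch (a hypothetical single-process call with a plain-tensor state_dict, not part of the commit): a non-coordinator rank now produces a WriteItem for a plain torch.Tensor, where previously only ShardedTensor entries, or the coordinator rank, did.

import torch
from torch.distributed.checkpoint.default_planner import create_default_local_save_plan

state_dict = {"tensor": torch.rand(10)}

# is_coordinator=False: before this commit the plan would have been empty,
# because a plain tensor only matched the old ShardedTensor/coordinator check.
plan = create_default_local_save_plan(state_dict, False)
assert len(plan.items) == 1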
