|
38 | 38 | )
|
39 | 39 |
|
40 | 40 | from torch.distributed.checkpoint.planner_helpers import create_read_items_for_chunk_list
|
| 41 | +from torch.distributed.checkpoint._dedup_tensors import dedup_tensors |
41 | 42 |
|
42 | 43 |
|
43 | 44 | if TEST_WITH_DEV_DBG_ASAN:
|
@@ -86,14 +87,22 @@ def test_local_plan(self):
|
86 | 87 | "st": st
|
87 | 88 | }
|
88 | 89 | plan = create_default_local_save_plan(state_dict, False)
|
89 |
| - self.assertEqual(1, len(plan.items)) |
| 90 | + self.assertEqual(2, len(plan.items)) |
90 | 91 | wi = plan.items[0]
|
91 |
| - self.assertEqual(wi.index, MetadataIndex("st", [8])) |
92 |
| - self.assertEqual(wi.type, WriteItemType.SHARD) |
93 |
| - self.assertEqual(wi.tensor_data.size, st.size()) |
| 92 | + self.assertEqual(wi.index, MetadataIndex("tensor", [0])) |
| 93 | + self.assertEqual(wi.type, WriteItemType.TENSOR) |
| 94 | + self.assertEqual(wi.tensor_data.size, tensor.size()) |
94 | 95 | self.assertEqual(wi.tensor_data.properties, TensorProperties.create_from_tensor(torch.zeros(1)))
|
95 |
| - self.assertEqual(wi.tensor_data.chunk.offsets, torch.Size([8])) |
96 |
| - self.assertEqual(wi.tensor_data.chunk.sizes, torch.Size([8])) |
| 96 | + self.assertEqual(wi.tensor_data.chunk.offsets, torch.Size([0])) |
| 97 | + self.assertEqual(wi.tensor_data.chunk.sizes, torch.Size([10])) |
| 98 | + |
| 99 | + st_wi = plan.items[1] |
| 100 | + self.assertEqual(st_wi.index, MetadataIndex("st", [8])) |
| 101 | + self.assertEqual(st_wi.type, WriteItemType.SHARD) |
| 102 | + self.assertEqual(st_wi.tensor_data.size, st.size()) |
| 103 | + self.assertEqual(st_wi.tensor_data.properties, TensorProperties.create_from_tensor(torch.zeros(1))) |
| 104 | + self.assertEqual(st_wi.tensor_data.chunk.offsets, torch.Size([8])) |
| 105 | + self.assertEqual(st_wi.tensor_data.chunk.sizes, torch.Size([8])) |
97 | 106 |
|
98 | 107 | # Coordinator rank, should include replicated items as well
|
99 | 108 | plan = create_default_local_save_plan(state_dict, True)
|
@@ -124,6 +133,7 @@ def create_data(rank):
|
124 | 133 | return create_default_local_save_plan(state_dict, rank == 0)
|
125 | 134 |
|
126 | 135 | all_plans = [create_data(0), create_data(1), create_data(2), create_data(3)]
|
| 136 | + all_plans = dedup_tensors(all_plans) |
127 | 137 | final_plans, metadata = create_default_global_save_plan(all_plans=all_plans)
|
128 | 138 |
|
129 | 139 | # The default global plan updates all indexes to include hints
|
|
0 commit comments