Skip to content

Commit 4c55dc5

Browse files
wz337 and atalman authored
remove _shard_tensor() call (pytorch#111687)
Co-authored-by: Andrey Talman <[email protected]>
1 parent f58669b commit 4c55dc5

File tree

1 file changed

+9
-3
lines changed

1 file changed

+9
-3
lines changed

torch/distributed/checkpoint/optimizer.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222
TensorStorageMetadata,
2323
ChunkStorageMetadata,
2424
)
25+
from torch.distributed.distributed_c10d import _get_default_group
26+
from torch.distributed.fsdp._shard_utils import _create_chunk_sharded_tensor
2527
from torch.distributed.checkpoint.planner_helpers import (
2628
create_read_items_for_chunk_list,
2729
_create_read_items,
@@ -32,7 +34,7 @@
3234
from torch.distributed.checkpoint.default_planner import (
3335
DefaultLoadPlanner,
3436
)
35-
from torch.distributed._shard.api import _shard_tensor
37+
3638
from torch.distributed.checkpoint.planner import LoadPlanner
3739

3840
from torch.distributed.checkpoint._nested_dict import unflatten_state_dict
@@ -293,8 +295,12 @@ def load_sharded_optimizer_state_dict(
293295
if value.size.numel() == 1:
294296
state_dict[key] = _alloc_tensor(value.properties, value.size, dp_pg_device_type)
295297
elif dp_pg is None:
296-
state_dict[key] = _shard_tensor(
297-
_alloc_tensor(value.properties, value.size, dp_pg_device_type), sharding_spec
298+
state_dict[key] = _create_chunk_sharded_tensor(
299+
_alloc_tensor(value.properties, value.size, dp_pg_device_type),
300+
rank=dist.get_rank(),
301+
world_size=dist.get_world_size(),
302+
num_devices_per_node=device_module.device_count(),
303+
pg=_get_default_group(),
298304
)
299305
else:
300306
spec_key = key_path[2]

0 commit comments

Comments (0)