torch/distributed/checkpoint: 1 file changed, +9 -3 lines changed

@@ -22,6 +22,8 @@
     TensorStorageMetadata,
     ChunkStorageMetadata,
 )
+from torch.distributed.distributed_c10d import _get_default_group
+from torch.distributed.fsdp._shard_utils import _create_chunk_sharded_tensor
 from torch.distributed.checkpoint.planner_helpers import (
     create_read_items_for_chunk_list,
     _create_read_items,
@@ -32,7 +34,7 @@
 from torch.distributed.checkpoint.default_planner import (
     DefaultLoadPlanner,
 )
-from torch.distributed._shard.api import _shard_tensor
+
 from torch.distributed.checkpoint.planner import LoadPlanner

 from torch.distributed.checkpoint._nested_dict import unflatten_state_dict
@@ -293,8 +295,12 @@ def load_sharded_optimizer_state_dict(
         if value.size.numel() == 1:
             state_dict[key] = _alloc_tensor(value.properties, value.size, dp_pg_device_type)
         elif dp_pg is None:
-            state_dict[key] = _shard_tensor(
-                _alloc_tensor(value.properties, value.size, dp_pg_device_type), sharding_spec
+            state_dict[key] = _create_chunk_sharded_tensor(
+                _alloc_tensor(value.properties, value.size, dp_pg_device_type),
+                rank=dist.get_rank(),
+                world_size=dist.get_world_size(),
+                num_devices_per_node=device_module.device_count(),
+                pg=_get_default_group(),
             )
         else:
             spec_key = key_path[2]
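
The substance of the change: when no data-parallel process group is passed, the optimizer state tensor is now wrapped into a ShardedTensor via _create_chunk_sharded_tensor, which builds the shard layout from local rank/world-size information, rather than via _shard_tensor, which reshards through the process group. Below is a minimal sketch of the new call path, assuming torch.distributed is already initialized (e.g. via dist.init_process_group); alloc_chunk_sharded and its parameters are illustrative stand-ins rather than the PR's code, and _create_chunk_sharded_tensor / _get_default_group are private PyTorch utilities whose signatures may change between releases.

# Minimal sketch (not the PR's code): construct a chunk-sharded tensor
# the way the new call site does.
import torch
import torch.distributed as dist
from torch.distributed.distributed_c10d import _get_default_group
from torch.distributed.fsdp._shard_utils import _create_chunk_sharded_tensor

def alloc_chunk_sharded(size, dtype=torch.float32, device="cpu"):
    # Allocate the full tensor on this rank, then wrap it as a
    # ShardedTensor chunked along dim 0; each rank retains only its
    # own chunk as the local shard.
    full = torch.empty(size, dtype=dtype, device=device)
    return _create_chunk_sharded_tensor(
        full,
        rank=dist.get_rank(),
        world_size=dist.get_world_size(),
        # GPUs per host; fall back to 1 on CPU-only hosts (the PR
        # itself uses device_module.device_count()).
        num_devices_per_node=max(torch.cuda.device_count(), 1),
        pg=_get_default_group(),
    )

Unlike _shard_tensor, this constructs the ShardedTensor from local metadata alone, so no collective needs to run while loading the optimizer state; that appears to be the point of the swap.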