24 | 24 | from colossalai.utils import get_current_device, get_non_persistent_buffers_set
25 | 25 | from colossalai.utils.safetensors import _flatten_optim_state_dict, load_flat
26 | 26 |
   | 27 | +from .distributed_checkpoint_utils import (
   | 28 | +    create_model_metadata,
   | 29 | +    is_pytorch_model_meta_dist_file,
   | 30 | +    load_dist_model,
   | 31 | +    save_dist_sharded_model,
   | 32 | +    save_dist_unshard_model,
   | 33 | +)
27 | 34 | from .general_checkpoint_io import GeneralCheckpointIO
28 | 35 | from .index_file import CheckpointIndexFile
29 | 36 | from .utils import (
47 | 54 |     sharded_optimizer_loading_epilogue,
48 | 55 | )
49 | 56 |
50 |    | -from .distributed_checkpoint_utils import (
51 |    | -    save_dist_sharded_model,
52 |    | -    save_dist_unshard_model,
53 |    | -    load_dist_model,
54 |    | -    is_pytorch_model_meta_dist_file,
55 |    | -    create_model_metadata
56 |    | -)
57 |    | -
58 | 57 | try:
59 | 58 |     from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX
60 | 59 | except ImportError:
@@ -244,9 +243,19 @@ def save_sharded_model(
244 | 243 |                 return
245 | 244 |             dist_id = self.tp_size * self.pp_rank + self.tp_rank
246 | 245 |             model_metadata = create_model_metadata(model, tp_size=self.tp_size, tp_rank=self.tp_rank)
247 |     | -            save_dist_sharded_model(model=model, model_metadata=model_metadata, checkpoint=checkpoint, prefix=prefix, size_per_shard=size_per_shard, use_safetensors=use_safetensors, use_async=use_async, dist_id = dist_id, pinned_state_dicts = self.pinned_state_dicts)
    | 246 | +            save_dist_sharded_model(
    | 247 | +                model=model,
    | 248 | +                model_metadata=model_metadata,
    | 249 | +                checkpoint=checkpoint,
    | 250 | +                prefix=prefix,
    | 251 | +                size_per_shard=size_per_shard,
    | 252 | +                use_safetensors=use_safetensors,
    | 253 | +                use_async=use_async,
    | 254 | +                dist_id=dist_id,
    | 255 | +                pinned_state_dicts=self.pinned_state_dicts,
    | 256 | +            )
248 | 257 |             return
249 |     | -
    | 258 | +
250 | 259 |         model = model.unwrap()
251 | 260 |
252 | 261 |         if os.path.isfile(checkpoint):
@@ -394,9 +403,15 @@ def load_sharded_model(
394 | 403 |
395 | 404 |         if is_pytorch_model_meta_dist_file(checkpoint_index_file):
396 | 405 |             model_metadata = create_model_metadata(model, tp_size=self.tp_size, tp_rank=self.tp_rank)
397 |     | -            load_dist_model(model=model, model_metadata=model_metadata, checkpoint=checkpoint_index_file, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads)
    | 406 | +            load_dist_model(
    | 407 | +                model=model,
    | 408 | +                model_metadata=model_metadata,
    | 409 | +                checkpoint=checkpoint_index_file,
    | 410 | +                low_cpu_mem_mode=low_cpu_mem_mode,
    | 411 | +                num_threads=num_threads,
    | 412 | +            )
398 | 413 |             return
399 |     | -
    | 414 | +
400 | 415 |         model_before_wrapping = model  # backup for model before wrapping
401 | 416 |         model = model.unwrap()
402 | 417 |
@@ -792,9 +807,17 @@ def save_unsharded_model(
792 | 807 |             if self.dp_rank != 0 and self.sp_rank != 0:
793 | 808 |                 return
794 | 809 |             dist_id = self.tp_size * self.pp_rank + self.tp_rank
795 |     | -            save_dist_unshard_model(model=model, model_metadata=model_metadata, checkpoint=checkpoint, use_safetensors=use_safetensors, use_async=use_async, dist_id = dist_id, pinned_state_dicts = self.pinned_state_dicts)
    | 810 | +            save_dist_unshard_model(
    | 811 | +                model=model,
    | 812 | +                model_metadata=model_metadata,
    | 813 | +                checkpoint=checkpoint,
    | 814 | +                use_safetensors=use_safetensors,
    | 815 | +                use_async=use_async,
    | 816 | +                dist_id=dist_id,
    | 817 | +                pinned_state_dicts=self.pinned_state_dicts,
    | 818 | +            )
796 | 819 |             return
797 |     | -
    | 820 | +
798 | 821 |         model = model.unwrap()
799 | 822 |         if self.dp_rank != 0:
800 | 823 |             return
@@ -867,7 +890,13 @@ def load_unsharded_model(
867 | 890 |         for filename in os.listdir(checkpoint):
868 | 891 |             if is_pytorch_model_meta_dist_file(filename):
869 | 892 |                 model_metadata = create_model_metadata(model, tp_size=self.tp_size, tp_rank=self.tp_rank)
870 |     | -                load_dist_model(model=model, model_metadata=model_metadata, checkpoint=checkpoint, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads)
    | 893 | +                load_dist_model(
    | 894 | +                    model=model,
    | 895 | +                    model_metadata=model_metadata,
    | 896 | +                    checkpoint=checkpoint,
    | 897 | +                    low_cpu_mem_mode=low_cpu_mem_mode,
    | 898 | +                    num_threads=num_threads,
    | 899 | +                )
871 | 900 |                 return
872 | 901 |
873 | 902 |         strict = False
@@ -1099,7 +1128,6 @@ def gather_from_sharded_optimizer_state(
1099 | 1128 |                     dist.all_gather(gather_tensor, v, group=dp_group)
1100 | 1129 |                     v = torch.stack(gather_tensor).view(-1)[: param.numel()].reshape_as(param)
1101 | 1130 |
1102 |      | -
1103 | 1131 |                 # Then gather TP shards.
1104 | 1132 |                 partition_dim = search_tp_partition_dim(current_shape, original_shape, tp_size)
1105 | 1133 |                 if partition_dim is not None:
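The unchanged context in the last hunk is the flatten-gather-trim-reshape pattern that reassembles a full optimizer-state tensor from per-rank ZeRO shards. A minimal standalone sketch of that pattern follows; the helper name `gather_flat_shard` and the assumption that every DP rank holds an equal-sized, end-padded flat shard are ours, not part of the diff:

```python
import torch
import torch.distributed as dist


def gather_flat_shard(v: torch.Tensor, param: torch.Tensor, dp_group) -> torch.Tensor:
    # Each DP rank holds an equal-sized flat shard `v`; the final shard is
    # zero-padded, so world_size * v.numel() >= param.numel().
    gather_tensor = [torch.zeros_like(v) for _ in range(dist.get_world_size(dp_group))]
    dist.all_gather(gather_tensor, v, group=dp_group)
    # Stack the shards, flatten, drop the padding, and restore the parameter's
    # original shape -- the same expression as in the hunk above.
    return torch.stack(gather_tensor).view(-1)[: param.numel()].reshape_as(param)
```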
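For the new distributed-checkpoint entry points wired up in the hunks above, here is a sketch of a save-then-load round trip. It uses only the keyword arguments visible in the diff; the module path, the driver function, the literal defaults (`prefix=None`, `size_per_shard=1024`, and so on), and the assumption that `load_dist_model` accepts the same checkpoint directory as the save call are illustrative guesses, not the library's documented API:

```python
from colossalai.checkpoint_io.distributed_checkpoint_utils import (
    create_model_metadata,
    load_dist_model,
    save_dist_sharded_model,
)


def save_then_load(io, model, checkpoint: str) -> None:
    # `io` stands in for a HybridParallelCheckpointIO-like object with
    # tp_size, tp_rank, pp_rank, and pinned_state_dicts attributes.
    # dist_id flattens the (pp_rank, tp_rank) grid into a unique index,
    # exactly as computed in save_sharded_model / save_unsharded_model above.
    dist_id = io.tp_size * io.pp_rank + io.tp_rank
    model_metadata = create_model_metadata(model, tp_size=io.tp_size, tp_rank=io.tp_rank)
    save_dist_sharded_model(
        model=model,
        model_metadata=model_metadata,
        checkpoint=checkpoint,
        prefix=None,
        size_per_shard=1024,
        use_safetensors=True,
        use_async=False,
        dist_id=dist_id,
        pinned_state_dicts=io.pinned_state_dicts,
    )
    load_dist_model(
        model=model,
        model_metadata=model_metadata,
        checkpoint=checkpoint,
        low_cpu_mem_mode=True,
        num_threads=1,
    )
```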