60 | 60 | import mindspore as ms
61 | 61 | from mindspore import Parameter, Tensor, mint, nn, ops
62 | 62 | from mindspore.nn import CrossEntropyLoss, Identity
   | 63 | +from mindspore.nn.utils import no_init_parameters
   | 64 | +from mindspore.ops import Cast
63 | 65 |
64 | 66 | from .activations import get_activation
65 | 67 | from .generation.utils import GenerationMixin

77 | 79 |     prune_linear_layer,
78 | 80 | )
79 | 81 | from .modeling_attn_mask_utils import dtype_to_min
   | 82 | +from .modeling_patch import patch_nn_default_dtype, restore_nn_default_dtype
80 | 83 | from .utils.import_utils import is_flash_attn_2_available, is_sdpa_available
81 | 84 |
82 | 85 | if is_safetensors_available():

86 | 89 |     from mindone.safetensors.mindspore import save_file as safe_save_file
87 | 90 |
88 | 91 | logger = logging.get_logger(__name__)
   | 92 | +cpu_cast = Cast().set_device("CPU")
89 | 93 |
90 | 94 | _init_weights = True
91 | 95 |

@@ -349,7 +353,7 @@ def _load_state_dict_into_model(model_to_load, state_dict, start_prefix, is_shar
349 | 353 |     local_state = {v.name: v for k, v in model_to_load.parameters_and_names()}
350 | 354 |     for k, v in state_dict.items():
351 | 355 |         if k in local_state:
352 |     | -            v.set_dtype(local_state[k].dtype)
    | 356 | +            state_dict[k] = ms.Parameter(cpu_cast(v, local_state[k].dtype), name=k)
353 | 357 |         else:
354 | 358 |             pass  # unexpect key keeps origin dtype
355 | 359 |     cm = silence_mindspore_logger() if is_sharded else nullcontext()
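
Note: the module-level `cpu_cast` added above is a `Cast` primitive pinned to CPU, so checkpoint tensors are converted to the model's parameter dtype on the host before being wrapped back into a `ms.Parameter`. A minimal standalone sketch of that behaviour (illustrative values only, not part of the patch):

import numpy as np
import mindspore as ms
from mindspore.ops import Cast

cpu_cast = Cast().set_device("CPU")                      # cast primitive that executes on CPU
v = ms.Tensor(np.ones((2, 2), dtype=np.float32))         # e.g. a checkpoint tensor loaded as fp32
param = ms.Parameter(cpu_cast(v, ms.float16), name="w")  # cast on host, then wrapped as a Parameter
print(param.dtype)                                       # expected: Float16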

@@ -387,7 +391,8 @@ def to(self, dtype: Optional[ms.Type] = None):
387 | 391 |         # Now we use `Parameter` and `Parameter.set_dtype()` instead.
388 | 392 |
389 | 393 |         for p in self.get_parameters():
390 |     | -            p.set_dtype(dtype)
    | 394 | +            if p.dtype != dtype:
    | 395 | +                p.set_dtype(dtype)
391 | 396 |         return self
392 | 397 |
393 | 398 |     def float(self):
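
Note: with the guard above, `to(dtype)` only calls `Parameter.set_dtype` when the parameter is not already in the requested dtype, so repeated `.to()` calls with the same dtype become cheap no-ops. A tiny illustration (hypothetical parameter, not part of the patch):

import mindspore as ms

p = ms.Parameter(ms.Tensor([1.0], ms.float16), name="w")
target = ms.float16
if p.dtype != target:   # already float16, so the redundant cast is skipped
    p.set_dtype(target)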

@@ -977,8 +982,12 @@ def _from_config(cls, config, **kwargs):
977 | 982 |             use_flash_attention_2=use_flash_attention_2,
978 | 983 |             mindspore_dtype=mindspore_dtype,
979 | 984 |         )
980 |     | -
981 |     | -        model = cls(config, **kwargs)
    | 985 | +        with no_init_parameters():
    | 986 | +            if mindspore_dtype is not None:
    | 987 | +                patch_nn_default_dtype(dtype=mindspore_dtype, force=True)
    | 988 | +            model = cls(config, **kwargs)
    | 989 | +            if mindspore_dtype is not None:
    | 990 | +                restore_nn_default_dtype()
982 | 991 |
983 | 992 |         # We cannot set default mindspore dtype. So we need to cast model weights after creating.
984 | 993 |         if mindspore_dtype is not None:

@@ -2348,7 +2357,12 @@ def from_pretrained(
2348 | 2357 |             config, use_flash_attention_2=use_flash_attention_2, mindspore_dtype=mindspore_dtype
2349 | 2358 |         )
2350 | 2359 |
2351 |      | -        model = cls(config, *model_args, **model_kwargs)
     | 2360 | +        with no_init_parameters():
     | 2361 | +            if mindspore_dtype is not None:
     | 2362 | +                patch_nn_default_dtype(dtype=mindspore_dtype, force=True)
     | 2363 | +            model = cls(config, *model_args, **model_kwargs)
     | 2364 | +            if mindspore_dtype is not None:
     | 2365 | +                restore_nn_default_dtype()
2352 | 2366 |
2353 | 2367 |         # Make sure to tie the weights correctly
2354 | 2368 |         model.tie_weights()
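
Note: both `_from_config` and `from_pretrained` now build the model inside `no_init_parameters()`, so weight-initialization kernels are skipped for parameters that the checkpoint overwrites anyway, and the default nn dtype is temporarily patched, presumably so newly created layers default to the target dtype rather than fp32-then-cast. A rough sketch of the pattern (assumptions: the helpers are importable as `mindone.transformers.modeling_patch`, matching the relative import above; the try/finally is added here only for illustration):

import mindspore as ms
from mindspore.nn.utils import no_init_parameters

# Assumption: full module path of the helpers imported as ".modeling_patch" in the diff above.
from mindone.transformers.modeling_patch import patch_nn_default_dtype, restore_nn_default_dtype

def build_without_init(cls, config, mindspore_dtype=ms.float16, **kwargs):
    """Sketch of the construction pattern used in `_from_config` / `from_pretrained`."""
    with no_init_parameters():                 # declare parameters without initializing their data
        if mindspore_dtype is not None:
            # presumably makes newly created nn layers default to `mindspore_dtype`
            patch_nn_default_dtype(dtype=mindspore_dtype, force=True)
        try:
            model = cls(config, **kwargs)
        finally:
            if mindspore_dtype is not None:
                restore_nn_default_dtype()     # always undo the global patch
    return model

The patch itself calls the two helpers in straight-line code around `cls(...)`; the try/finally here is only defensive, so the global default cannot stay patched if construction raises.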