
Commit bb46efa

change the way of loading checkpoints
1 parent 28f3b18 commit bb46efa

File tree

5 files changed: +130 -11 lines changed

  mindone/diffusers/models/model_loading_utils.py
  mindone/diffusers/models/modeling_patch.py
  mindone/diffusers/models/modeling_utils.py
  mindone/transformers/modeling_patch.py
  mindone/transformers/modeling_utils.py

mindone/diffusers/models/model_loading_utils.py

Lines changed: 6 additions & 5 deletions
@@ -33,9 +33,9 @@

 import mindspore as ms
 from mindspore import nn, ops
+from mindspore.ops import Cast

 from ...safetensors.mindspore import load as safe_load
-from ...safetensors.mindspore import load_file as safe_load_file
 from ..utils import (
     SAFE_WEIGHTS_INDEX_NAME,
     SAFETENSORS_FILE_EXTENSION,
@@ -47,6 +47,7 @@
 )

 logger = logging.get_logger(__name__)
+cpu_cast = Cast().set_device("CPU")

 _CLASS_REMAPPING_DICT = {
     "Transformer2DModel": {
@@ -97,7 +98,7 @@ def load_state_dict(
         if disable_mmap:
             return safe_load(open(checkpoint_file, "rb").read())
         else:
-            return safe_load_file(checkpoint_file)
+            return ms.load_checkpoint(checkpoint_file, format="safetensors")
     else:
         raise NotImplementedError(
             f"Only supports deserialization of weights file in safetensors format, but got {checkpoint_file}"
@@ -140,11 +141,11 @@ def _load_state_dict_into_model(
                     and any(module_to_keep_in_fp32 in k.split(".") for module_to_keep_in_fp32 in keep_in_fp32_modules)
                     and dtype == ms.float16
                 ):
-                    v.set_dtype(ms.float32)
+                    state_dict[k] = ms.Parameter(cpu_cast(v.data, ms.float32), name=k)
                 else:
-                    v.set_dtype(local_state[k].dtype)
+                    state_dict[k] = ms.Parameter(cpu_cast(v.data, local_state[k].dtype), name=k)
             else:
-                v.set_dtype(local_state[k].dtype)
+                state_dict[k] = ms.Parameter(cpu_cast(v.data, local_state[k].dtype), name=k)
         else:
             pass  # unexpect key keeps origin dtype
     cm = silence_mindspore_logger() if is_sharded else nullcontext()
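
The new path routes safetensors files through MindSpore's own loader and casts on the host. A minimal sketch of that behaviour, assuming a local model.safetensors file (hypothetical path) and a MindSpore build that supports both Primitive.set_device and the format="safetensors" argument of ms.load_checkpoint; float16 is used here purely as an example target dtype:

import mindspore as ms
from mindspore.ops import Cast

# Cast on the CPU so the full-precision copy never has to be materialized on the device.
cpu_cast = Cast().set_device("CPU")

# Load the safetensors file directly with MindSpore's checkpoint loader.
state_dict = ms.load_checkpoint("model.safetensors", format="safetensors")

# Re-wrap every entry as a Parameter in the target dtype, mirroring what
# _load_state_dict_into_model now does key by key.
state_dict = {k: ms.Parameter(cpu_cast(v, ms.float16), name=k) for k, v in state_dict.items()}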

mindone/diffusers/models/modeling_patch.py

Lines changed: 49 additions & 0 deletions
@@ -0,0 +1,49 @@
+import inspect
+from functools import wraps
+
+import mindspore as ms
+from mindspore import mint, nn
+
+SKIP_CLASSES = {nn.Dropout}
+# Store original __init__ for manual restore
+_ORIG_INITS = {}
+
+
+def patch_nn_default_dtype(dtype=ms.float32, force=False):
+    """
+    Iterate over all Cells under nn and mint.nn,
+    automatically set or force the default dtype in __init__ if supported.
+
+    Args:
+        dtype (mindspore.dtype): target dtype to enforce
+        force (bool): if True, even when user passes dtype explicitly, override it
+    """
+    for module in [ms.nn, mint.nn]:
+        for name in dir(module):
+            attr = getattr(module, name)
+            if inspect.isclass(attr) and issubclass(attr, nn.Cell):
+                if attr in SKIP_CLASSES:
+                    continue  # skip specified classes
+                sig = inspect.signature(attr.__init__)
+                if "dtype" in sig.parameters:
+                    if attr not in _ORIG_INITS:
+                        _ORIG_INITS[attr] = attr.__init__
+
+                    _orig_init = attr.__init__
+
+                    @wraps(_orig_init)
+                    def _new_init(self, *args, _orig_init=_orig_init, **kwargs):
+                        if force or "dtype" not in kwargs:
+                            kwargs["dtype"] = dtype
+                        return _orig_init(self, *args, **kwargs)
+
+                    setattr(attr, "__init__", _new_init)
+
+
+def restore_nn_default_dtype():
+    """
+    Manually restore the original __init__ of all patched nn / mint.nn Cells.
+    """
+    for cls, orig_init in _ORIG_INITS.items():
+        cls.__init__ = orig_init
+    _ORIG_INITS.clear()
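
A short usage sketch of the helper above. It assumes the module lands at mindone.diffusers.models.modeling_patch (as the import added to modeling_utils.py below suggests) and that mint.nn.Linear exposes a dtype keyword, which is what lets the patch inject one:

import mindspore as ms
from mindspore import mint

from mindone.diffusers.models.modeling_patch import patch_nn_default_dtype, restore_nn_default_dtype

# Every dtype-aware Cell constructed inside this window is built directly in float16.
patch_nn_default_dtype(dtype=ms.float16, force=True)
try:
    layer = mint.nn.Linear(4, 4)
finally:
    restore_nn_default_dtype()  # undo the monkey patch even if construction fails

print(layer.weight.dtype)  # expected: Float16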

mindone/diffusers/models/modeling_utils.py

Lines changed: 7 additions & 1 deletion
@@ -58,6 +58,7 @@
     load_state_dict,
     split_torch_state_dict_into_shards,
 )
+from .modeling_patch import patch_nn_default_dtype, restore_nn_default_dtype


 class ContextManagers:
@@ -819,7 +820,11 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
         )

         with no_init_parameters():
+            if mindspore_dtype is not None:
+                patch_nn_default_dtype(dtype=mindspore_dtype, force=True)
             model = cls.from_config(config, **unused_kwargs)
+            if mindspore_dtype is not None:
+                restore_nn_default_dtype()

         state_dict = None
         if not is_sharded:
@@ -874,7 +879,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P

     def to(self, dtype: Optional[ms.Type] = None):
         for p in self.get_parameters():
-            p.set_dtype(dtype)
+            if p.dtype != dtype:
+                p.set_dtype(dtype)
         return self

     def half(self):
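
A hedged usage sketch of how the patched construction path is exercised from user code. The repo id is hypothetical and UNet2DConditionModel stands in for any mindone.diffusers model class; with mindspore_dtype set, the patch makes from_config build parameters directly in the requested dtype instead of creating them in fp32 and casting afterwards:

import mindspore as ms
from mindone.diffusers import UNet2DConditionModel

# Hypothetical repo id, for illustration only.
unet = UNet2DConditionModel.from_pretrained(
    "some/stable-diffusion-repo", subfolder="unet", mindspore_dtype=ms.float16
)

# to() now skips parameters that already carry the requested dtype, so this is cheap.
unet = unet.to(ms.float16)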

mindone/transformers/modeling_patch.py

Lines changed: 49 additions & 0 deletions
@@ -0,0 +1,49 @@
+import inspect
+from functools import wraps
+
+import mindspore as ms
+from mindspore import mint, nn
+
+SKIP_CLASSES = {nn.Dropout}
+# Store original __init__ for manual restore
+_ORIG_INITS = {}
+
+
+def patch_nn_default_dtype(dtype=ms.float32, force=False):
+    """
+    Iterate over all Cells under nn and mint.nn,
+    automatically set or force the default dtype in __init__ if supported.
+
+    Args:
+        dtype (mindspore.dtype): target dtype to enforce
+        force (bool): if True, even when user passes dtype explicitly, override it
+    """
+    for module in [ms.nn, mint.nn]:
+        for name in dir(module):
+            attr = getattr(module, name)
+            if inspect.isclass(attr) and issubclass(attr, nn.Cell):
+                if attr in SKIP_CLASSES:
+                    continue  # skip specified classes
+                sig = inspect.signature(attr.__init__)
+                if "dtype" in sig.parameters:
+                    if attr not in _ORIG_INITS:
+                        _ORIG_INITS[attr] = attr.__init__
+
+                    _orig_init = attr.__init__
+
+                    @wraps(_orig_init)
+                    def _new_init(self, *args, _orig_init=_orig_init, **kwargs):
+                        if force or "dtype" not in kwargs:
+                            kwargs["dtype"] = dtype
+                        return _orig_init(self, *args, **kwargs)
+
+                    setattr(attr, "__init__", _new_init)
+
+
+def restore_nn_default_dtype():
+    """
+    Manually restore the original __init__ of all patched nn / mint.nn Cells.
+    """
+    for cls, orig_init in _ORIG_INITS.items():
+        cls.__init__ = orig_init
+    _ORIG_INITS.clear()
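
The transformers copy keeps the same _orig_init=_orig_init default-argument trick inside the loop. A standalone sketch (plain Python, not mindone API) of the late-binding pitfall that trick avoids:

funcs_late, funcs_bound = [], []
for i in range(3):
    funcs_late.append(lambda: i)       # closes over the loop variable itself
    funcs_bound.append(lambda i=i: i)  # freezes the current value as a default argument

print([f() for f in funcs_late])   # [2, 2, 2] -- every closure sees the final i
print([f() for f in funcs_bound])  # [0, 1, 2] -- each value captured at definition time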

mindone/transformers/modeling_utils.py

Lines changed: 19 additions & 5 deletions
@@ -60,6 +60,8 @@
 import mindspore as ms
 from mindspore import Parameter, Tensor, mint, nn, ops
 from mindspore.nn import CrossEntropyLoss, Identity
+from mindspore.nn.utils import no_init_parameters
+from mindspore.ops import Cast

 from .activations import get_activation
 from .generation.utils import GenerationMixin
@@ -77,6 +79,7 @@
     prune_linear_layer,
 )
 from .modeling_attn_mask_utils import dtype_to_min
+from .modeling_patch import patch_nn_default_dtype, restore_nn_default_dtype
 from .utils.import_utils import is_flash_attn_2_available, is_sdpa_available

 if is_safetensors_available():
@@ -86,6 +89,7 @@
     from mindone.safetensors.mindspore import save_file as safe_save_file

 logger = logging.get_logger(__name__)
+cpu_cast = Cast().set_device("CPU")

 _init_weights = True

@@ -349,7 +353,7 @@ def _load_state_dict_into_model(model_to_load, state_dict, start_prefix, is_shar
     local_state = {v.name: v for k, v in model_to_load.parameters_and_names()}
     for k, v in state_dict.items():
         if k in local_state:
-            v.set_dtype(local_state[k].dtype)
+            state_dict[k] = ms.Parameter(cpu_cast(v, local_state[k].dtype), name=k)
         else:
             pass  # unexpect key keeps origin dtype
     cm = silence_mindspore_logger() if is_sharded else nullcontext()
@@ -387,7 +391,8 @@ def to(self, dtype: Optional[ms.Type] = None):
         # Now we use `Parameter` and `Parameter.set_dtype()` instead.

         for p in self.get_parameters():
-            p.set_dtype(dtype)
+            if p.dtype != dtype:
+                p.set_dtype(dtype)
         return self

     def float(self):
@@ -977,8 +982,12 @@ def _from_config(cls, config, **kwargs):
             use_flash_attention_2=use_flash_attention_2,
             mindspore_dtype=mindspore_dtype,
         )
-
-        model = cls(config, **kwargs)
+        with no_init_parameters():
+            if mindspore_dtype is not None:
+                patch_nn_default_dtype(dtype=mindspore_dtype, force=True)
+            model = cls(config, **kwargs)
+            if mindspore_dtype is not None:
+                restore_nn_default_dtype()

         # We cannot set default mindspore dtype. So we need to cast model weights after creating.
         if mindspore_dtype is not None:
@@ -2348,7 +2357,12 @@ def from_pretrained(
             config, use_flash_attention_2=use_flash_attention_2, mindspore_dtype=mindspore_dtype
         )

-        model = cls(config, *model_args, **model_kwargs)
+        with no_init_parameters():
+            if mindspore_dtype is not None:
+                patch_nn_default_dtype(dtype=mindspore_dtype, force=True)
+            model = cls(config, *model_args, **model_kwargs)
+            if mindspore_dtype is not None:
+                restore_nn_default_dtype()

         # Make sure to tie the weights correctly
         model.tie_weights()
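
A usage sketch of the transformers-side path. The model class and checkpoint id are illustrative; with mindspore_dtype given, parameters are created uninitialized under no_init_parameters(), built in the target dtype by the patch, and then filled from the CPU-cast state dict:

import mindspore as ms
from mindone.transformers import LlamaForCausalLM

# Hypothetical checkpoint id, for illustration only.
model = LlamaForCausalLM.from_pretrained("some/llama-checkpoint", mindspore_dtype=ms.float16)

# The updated to() skips parameters already in the requested dtype, so this second cast is a no-op.
model = model.to(ms.float16)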
