11 changes: 6 additions & 5 deletions mindone/diffusers/models/model_loading_utils.py
@@ -33,9 +33,9 @@
 
 import mindspore as ms
 from mindspore import nn, ops
+from mindspore.ops import Cast
 
 from ...safetensors.mindspore import load as safe_load
-from ...safetensors.mindspore import load_file as safe_load_file
 from ..utils import (
     SAFE_WEIGHTS_INDEX_NAME,
     SAFETENSORS_FILE_EXTENSION,
@@ -47,6 +47,7 @@
 )
 
 logger = logging.get_logger(__name__)
+cpu_cast = Cast().set_device("CPU")
 
 _CLASS_REMAPPING_DICT = {
     "Transformer2DModel": {
@@ -97,7 +98,7 @@ def load_state_dict(
             if disable_mmap:
                 return safe_load(open(checkpoint_file, "rb").read())
             else:
-                return safe_load_file(checkpoint_file)
+                return ms.load_checkpoint(checkpoint_file, format="safetensors")
         else:
             raise NotImplementedError(
                 f"Only supports deserialization of weights file in safetensors format, but got {checkpoint_file}"
@@ -140,11 +141,11 @@ def _load_state_dict_into_model(
                     and any(module_to_keep_in_fp32 in k.split(".") for module_to_keep_in_fp32 in keep_in_fp32_modules)
                     and dtype == ms.float16
                 ):
-                    v.set_dtype(ms.float32)
+                    state_dict[k] = ms.Parameter(cpu_cast(v.data, ms.float32), name=k)
                 else:
-                    v.set_dtype(local_state[k].dtype)
+                    state_dict[k] = ms.Parameter(cpu_cast(v.data, local_state[k].dtype), name=k)
            else:
-                v.set_dtype(local_state[k].dtype)
+                state_dict[k] = ms.Parameter(cpu_cast(v.data, local_state[k].dtype), name=k)
        else:
            pass  # unexpect key keeps origin dtype
    cm = silence_mindspore_logger() if is_sharded else nullcontext()
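The hunks above replace in-place `Parameter.set_dtype()` calls with a host-side cast: each checkpoint tensor is converted to the target dtype on CPU via the `Cast().set_device("CPU")` primitive and rewrapped as a fresh `ms.Parameter`, so the conversion happens before anything lands on the device. A minimal sketch of that pattern, assuming the state-dict values expose a `.data` tensor as they do in this file (the helper name `cast_state_dict_on_cpu` is illustrative, not part of the PR):

import mindspore as ms
from mindspore.ops import Cast

# Cast primitive pinned to CPU, mirroring `cpu_cast` in the diff above.
cpu_cast = Cast().set_device("CPU")


def cast_state_dict_on_cpu(state_dict, target_dtype=ms.float16):
    # Illustrative helper (not from the PR): rebuild every entry as a new
    # Parameter whose data was cast on the host rather than on the device.
    return {
        name: ms.Parameter(cpu_cast(param.data, target_dtype), name=name)
        for name, param in state_dict.items()
    }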
49 changes: 49 additions & 0 deletions mindone/diffusers/models/modeling_patch.py
@@ -0,0 +1,49 @@
import inspect
from functools import wraps

import mindspore as ms
from mindspore import mint, nn

SKIP_CLASSES = {nn.Dropout}
# Store original __init__ for manual restore
_ORIG_INITS = {}


def patch_nn_default_dtype(dtype=ms.float32, force=False):
    """
    Iterate over all Cells under nn and mint.nn,
    automatically set or force the default dtype in __init__ if supported.

    Args:
        dtype (mindspore.dtype): target dtype to enforce
        force (bool): if True, even when user passes dtype explicitly, override it
    """
    for module in [ms.nn, mint.nn]:
        for name in dir(module):
            attr = getattr(module, name)
            if inspect.isclass(attr) and issubclass(attr, nn.Cell):
                if attr in SKIP_CLASSES:
                    continue  # skip specified classes
                sig = inspect.signature(attr.__init__)
                if "dtype" in sig.parameters:
                    if attr not in _ORIG_INITS:
                        _ORIG_INITS[attr] = attr.__init__

                    _orig_init = attr.__init__

                    @wraps(_orig_init)
                    def _new_init(self, *args, _orig_init=_orig_init, **kwargs):
                        if force or "dtype" not in kwargs:
                            kwargs["dtype"] = dtype
                        return _orig_init(self, *args, **kwargs)

                    setattr(attr, "__init__", _new_init)


def restore_nn_default_dtype():
    """
    Manually restore the original __init__ of all patched nn / mint.nn Cells.
    """
    for cls, orig_init in _ORIG_INITS.items():
        cls.__init__ = orig_init
    _ORIG_INITS.clear()
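The docstring above is terse, so here is a hedged usage sketch of the two helpers: patch the dtype-aware constructors, build a layer, then restore the originals. Only classes whose `__init__` accepts a `dtype` keyword are affected (e.g. `mint.nn.Linear`), `nn.Dropout` is skipped, and the concrete layer and dtype below are illustrative:

import mindspore as ms
from mindspore import mint

from mindone.diffusers.models.modeling_patch import (
    patch_nn_default_dtype,
    restore_nn_default_dtype,
)

# Force every dtype-aware nn / mint.nn layer created below into float16.
patch_nn_default_dtype(dtype=ms.float16, force=True)
try:
    dense = mint.nn.Linear(128, 256)  # weight should be created directly in float16
finally:
    restore_nn_default_dtype()  # later constructions fall back to the stock defaults

print(dense.weight.dtype)  # expected: Float16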
8 changes: 7 additions & 1 deletion mindone/diffusers/models/modeling_utils.py
@@ -58,6 +58,7 @@
     load_state_dict,
     split_torch_state_dict_into_shards,
 )
+from .modeling_patch import patch_nn_default_dtype, restore_nn_default_dtype
 
 
 class ContextManagers:
@@ -819,7 +820,11 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
             )
 
         with no_init_parameters():
+            if mindspore_dtype is not None:
+                patch_nn_default_dtype(dtype=mindspore_dtype, force=True)
             model = cls.from_config(config, **unused_kwargs)
+            if mindspore_dtype is not None:
+                restore_nn_default_dtype()
 
         state_dict = None
         if not is_sharded:
@@ -874,7 +879,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
 
     def to(self, dtype: Optional[ms.Type] = None):
         for p in self.get_parameters():
-            p.set_dtype(dtype)
+            if p.dtype != dtype:
+                p.set_dtype(dtype)
         return self
 
     def half(self):
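In `from_pretrained` above, `no_init_parameters()` (which skips weight initialization) is paired with the dtype patch so parameters are created directly in `mindspore_dtype`, and the guarded `to()` then avoids casts that would be no-ops. A possible convenience wrapper for the same construction pattern, shown purely as a sketch (the name `construct_empty_in_dtype` and the model class in the usage comment are not part of this PR):

from contextlib import contextmanager

import mindspore as ms
from mindspore.nn.utils import no_init_parameters

from mindone.diffusers.models.modeling_patch import (
    patch_nn_default_dtype,
    restore_nn_default_dtype,
)


@contextmanager
def construct_empty_in_dtype(dtype=None):
    # Sketch of the pattern used in from_pretrained: no parameter init,
    # and (optionally) every dtype-aware layer built directly in `dtype`.
    with no_init_parameters():
        if dtype is not None:
            patch_nn_default_dtype(dtype=dtype, force=True)
        try:
            yield
        finally:
            if dtype is not None:
                restore_nn_default_dtype()


# Usage sketch (MyDiffusionModel is a placeholder class name):
# with construct_empty_in_dtype(ms.float16):
#     model = MyDiffusionModel.from_config(config)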
49 changes: 49 additions & 0 deletions mindone/transformers/modeling_patch.py
@@ -0,0 +1,49 @@
import inspect
from functools import wraps

import mindspore as ms
from mindspore import mint, nn

SKIP_CLASSES = {nn.Dropout}
# Store original __init__ for manual restore
_ORIG_INITS = {}


def patch_nn_default_dtype(dtype=ms.float32, force=False):
    """
    Iterate over all Cells under nn and mint.nn,
    automatically set or force the default dtype in __init__ if supported.

    Args:
        dtype (mindspore.dtype): target dtype to enforce
        force (bool): if True, even when user passes dtype explicitly, override it
    """
    for module in [ms.nn, mint.nn]:
        for name in dir(module):
            attr = getattr(module, name)
            if inspect.isclass(attr) and issubclass(attr, nn.Cell):
                if attr in SKIP_CLASSES:
                    continue  # skip specified classes
                sig = inspect.signature(attr.__init__)
                if "dtype" in sig.parameters:
                    if attr not in _ORIG_INITS:
                        _ORIG_INITS[attr] = attr.__init__

                    _orig_init = attr.__init__

                    @wraps(_orig_init)
                    def _new_init(self, *args, _orig_init=_orig_init, **kwargs):
                        if force or "dtype" not in kwargs:
                            kwargs["dtype"] = dtype
                        return _orig_init(self, *args, **kwargs)

                    setattr(attr, "__init__", _new_init)


def restore_nn_default_dtype():
    """
    Manually restore the original __init__ of all patched nn / mint.nn Cells.
    """
    for cls, orig_init in _ORIG_INITS.items():
        cls.__init__ = orig_init
    _ORIG_INITS.clear()
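This module is a verbatim copy of the diffusers helper above; one detail worth noting in both is the `_orig_init=_orig_init` default argument in `_new_init`. Without it, every wrapper would close over the loop variable and dispatch to whichever `__init__` was patched last. A small illustration of that late-binding trap (not code from the PR):

# Late binding: both lambdas see the final value of `tag`.
funcs = [lambda: tag for tag in ("a", "b")]
print([f() for f in funcs])  # ['b', 'b']

# Freezing the value as a default argument, as _new_init does with _orig_init.
funcs = [lambda tag=tag: tag for tag in ("a", "b")]
print([f() for f in funcs])  # ['a', 'b']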
24 changes: 19 additions & 5 deletions mindone/transformers/modeling_utils.py
@@ -60,6 +60,8 @@
 import mindspore as ms
 from mindspore import Parameter, Tensor, mint, nn, ops
 from mindspore.nn import CrossEntropyLoss, Identity
+from mindspore.nn.utils import no_init_parameters
+from mindspore.ops import Cast
 
 from .activations import get_activation
 from .generation.utils import GenerationMixin
@@ -77,6 +79,7 @@
     prune_linear_layer,
 )
 from .modeling_attn_mask_utils import dtype_to_min
+from .modeling_patch import patch_nn_default_dtype, restore_nn_default_dtype
 from .utils.import_utils import is_flash_attn_2_available, is_sdpa_available
 
 if is_safetensors_available():
@@ -86,6 +89,7 @@
     from mindone.safetensors.mindspore import save_file as safe_save_file
 
 logger = logging.get_logger(__name__)
+cpu_cast = Cast().set_device("CPU")
 
 _init_weights = True
 
@@ -349,7 +353,7 @@ def _load_state_dict_into_model(model_to_load, state_dict, start_prefix, is_shar
     local_state = {v.name: v for k, v in model_to_load.parameters_and_names()}
     for k, v in state_dict.items():
         if k in local_state:
-            v.set_dtype(local_state[k].dtype)
+            state_dict[k] = ms.Parameter(cpu_cast(v.data, local_state[k].dtype), name=k)
         else:
             pass  # unexpect key keeps origin dtype
     cm = silence_mindspore_logger() if is_sharded else nullcontext()
@@ -387,7 +391,8 @@ def to(self, dtype: Optional[ms.Type] = None):
     # Now we use `Parameter` and `Parameter.set_dtype()` instead.
 
     for p in self.get_parameters():
-        p.set_dtype(dtype)
+        if p.dtype != dtype:
+            p.set_dtype(dtype)
     return self
 
 def float(self):
@@ -977,8 +982,12 @@ def _from_config(cls, config, **kwargs):
             use_flash_attention_2=use_flash_attention_2,
             mindspore_dtype=mindspore_dtype,
         )
-
-        model = cls(config, **kwargs)
+        with no_init_parameters():
+            if mindspore_dtype is not None:
+                patch_nn_default_dtype(dtype=mindspore_dtype, force=True)
+            model = cls(config, **kwargs)
+            if mindspore_dtype is not None:
+                restore_nn_default_dtype()
 
         # We cannot set default mindspore dtype. So we need to cast model weights after creating.
         if mindspore_dtype is not None:
@@ -2348,7 +2357,12 @@ def from_pretrained(
                 config, use_flash_attention_2=use_flash_attention_2, mindspore_dtype=mindspore_dtype
             )
 
-        model = cls(config, *model_args, **model_kwargs)
+        with no_init_parameters():
+            if mindspore_dtype is not None:
+                patch_nn_default_dtype(dtype=mindspore_dtype, force=True)
+            model = cls(config, *model_args, **model_kwargs)
+            if mindspore_dtype is not None:
+                restore_nn_default_dtype()
 
         # Make sure to tie the weights correctly
         model.tie_weights()
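Taken together, the `_from_config` and `from_pretrained` changes mean weights are materialized once, in the requested dtype, rather than allocated in float32 and cast afterwards. A rough end-to-end check, where `SomeModel` and the checkpoint id are placeholders for any `mindone.transformers` model class and repository (not taken from this PR):

import mindspore as ms
from mindone.transformers import SomeModel  # placeholder model class

model = SomeModel.from_pretrained(
    "org/checkpoint-id",  # placeholder checkpoint id
    mindspore_dtype=ms.float16,
)

# With the dtype patch active during construction, dtype-aware layers are
# built directly in float16, so the post-construction cast is mostly a no-op.
print({str(p.dtype) for p in model.get_parameters()})  # expected: {'Float16'}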