diff --git a/mindone/diffusers/models/model_loading_utils.py b/mindone/diffusers/models/model_loading_utils.py
index 34bc6059db..0bcdc6a983 100644
--- a/mindone/diffusers/models/model_loading_utils.py
+++ b/mindone/diffusers/models/model_loading_utils.py
@@ -33,9 +33,9 @@

 import mindspore as ms
 from mindspore import nn, ops
+from mindspore.ops import Cast

 from ...safetensors.mindspore import load as safe_load
-from ...safetensors.mindspore import load_file as safe_load_file
 from ..utils import (
     SAFE_WEIGHTS_INDEX_NAME,
     SAFETENSORS_FILE_EXTENSION,
@@ -47,6 +47,7 @@
 )

 logger = logging.get_logger(__name__)
+cpu_cast = Cast().set_device("CPU")

 _CLASS_REMAPPING_DICT = {
     "Transformer2DModel": {
@@ -97,7 +98,7 @@ def load_state_dict(
         if disable_mmap:
             return safe_load(open(checkpoint_file, "rb").read())
         else:
-            return safe_load_file(checkpoint_file)
+            return ms.load_checkpoint(checkpoint_file, format="safetensors")
     else:
         raise NotImplementedError(
             f"Only supports deserialization of weights file in safetensors format, but got {checkpoint_file}"
@@ -140,11 +141,11 @@ def _load_state_dict_into_model(
                     and any(module_to_keep_in_fp32 in k.split(".") for module_to_keep_in_fp32 in keep_in_fp32_modules)
                     and dtype == ms.float16
                 ):
-                    v.set_dtype(ms.float32)
+                    state_dict[k] = ms.Parameter(cpu_cast(v.data, ms.float32), name=k)
                 else:
-                    v.set_dtype(local_state[k].dtype)
+                    state_dict[k] = ms.Parameter(cpu_cast(v.data, local_state[k].dtype), name=k)
             else:
-                v.set_dtype(local_state[k].dtype)
+                state_dict[k] = ms.Parameter(cpu_cast(v.data, local_state[k].dtype), name=k)
         else:
             pass  # unexpect key keeps origin dtype
     cm = silence_mindspore_logger() if is_sharded else nullcontext()
diff --git a/mindone/diffusers/models/modeling_patch.py b/mindone/diffusers/models/modeling_patch.py
new file mode 100644
index 0000000000..cd9b97c2d4
--- /dev/null
+++ b/mindone/diffusers/models/modeling_patch.py
@@ -0,0 +1,49 @@
+import inspect
+from functools import wraps
+
+import mindspore as ms
+from mindspore import mint, nn
+
+SKIP_CLASSES = {nn.Dropout}
+# Store original __init__ for manual restore
+_ORIG_INITS = {}
+
+
+def patch_nn_default_dtype(dtype=ms.float32, force=False):
+    """
+    Iterate over all Cells under nn and mint.nn,
+    automatically set or force the default dtype in __init__ if supported.
+
+    Args:
+        dtype (mindspore.dtype): target dtype to enforce
+        force (bool): if True, even when user passes dtype explicitly, override it
+    """
+    for module in [ms.nn, mint.nn]:
+        for name in dir(module):
+            attr = getattr(module, name)
+            if inspect.isclass(attr) and issubclass(attr, nn.Cell):
+                if attr in SKIP_CLASSES:
+                    continue  # skip specified classes
+                sig = inspect.signature(attr.__init__)
+                if "dtype" in sig.parameters:
+                    if attr not in _ORIG_INITS:
+                        _ORIG_INITS[attr] = attr.__init__
+
+                    _orig_init = attr.__init__
+
+                    @wraps(_orig_init)
+                    def _new_init(self, *args, _orig_init=_orig_init, **kwargs):
+                        if force or "dtype" not in kwargs:
+                            kwargs["dtype"] = dtype
+                        return _orig_init(self, *args, **kwargs)
+
+                    setattr(attr, "__init__", _new_init)
+
+
+def restore_nn_default_dtype():
+    """
+    Manually restore the original __init__ of all patched nn / mint.nn Cells.
+    """
+    for cls, orig_init in _ORIG_INITS.items():
+        cls.__init__ = orig_init
+    _ORIG_INITS.clear()
diff --git a/mindone/diffusers/models/modeling_utils.py b/mindone/diffusers/models/modeling_utils.py
index a60071ceb6..20f48d9246 100644
--- a/mindone/diffusers/models/modeling_utils.py
+++ b/mindone/diffusers/models/modeling_utils.py
@@ -58,6 +58,7 @@
     load_state_dict,
     split_torch_state_dict_into_shards,
 )
+from .modeling_patch import patch_nn_default_dtype, restore_nn_default_dtype


 class ContextManagers:
@@ -819,7 +820,11 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
         )

         with no_init_parameters():
+            if mindspore_dtype is not None:
+                patch_nn_default_dtype(dtype=mindspore_dtype, force=True)
             model = cls.from_config(config, **unused_kwargs)
+            if mindspore_dtype is not None:
+                restore_nn_default_dtype()

         state_dict = None
         if not is_sharded:
@@ -874,7 +879,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P

     def to(self, dtype: Optional[ms.Type] = None):
         for p in self.get_parameters():
-            p.set_dtype(dtype)
+            if p.dtype != dtype:
+                p.set_dtype(dtype)
         return self

     def half(self):
diff --git a/mindone/transformers/modeling_patch.py b/mindone/transformers/modeling_patch.py
new file mode 100644
index 0000000000..cd9b97c2d4
--- /dev/null
+++ b/mindone/transformers/modeling_patch.py
@@ -0,0 +1,49 @@
+import inspect
+from functools import wraps
+
+import mindspore as ms
+from mindspore import mint, nn
+
+SKIP_CLASSES = {nn.Dropout}
+# Store original __init__ for manual restore
+_ORIG_INITS = {}
+
+
+def patch_nn_default_dtype(dtype=ms.float32, force=False):
+    """
+    Iterate over all Cells under nn and mint.nn,
+    automatically set or force the default dtype in __init__ if supported.
+
+    Args:
+        dtype (mindspore.dtype): target dtype to enforce
+        force (bool): if True, even when user passes dtype explicitly, override it
+    """
+    for module in [ms.nn, mint.nn]:
+        for name in dir(module):
+            attr = getattr(module, name)
+            if inspect.isclass(attr) and issubclass(attr, nn.Cell):
+                if attr in SKIP_CLASSES:
+                    continue  # skip specified classes
+                sig = inspect.signature(attr.__init__)
+                if "dtype" in sig.parameters:
+                    if attr not in _ORIG_INITS:
+                        _ORIG_INITS[attr] = attr.__init__
+
+                    _orig_init = attr.__init__
+
+                    @wraps(_orig_init)
+                    def _new_init(self, *args, _orig_init=_orig_init, **kwargs):
+                        if force or "dtype" not in kwargs:
+                            kwargs["dtype"] = dtype
+                        return _orig_init(self, *args, **kwargs)
+
+                    setattr(attr, "__init__", _new_init)
+
+
+def restore_nn_default_dtype():
+    """
+    Manually restore the original __init__ of all patched nn / mint.nn Cells.
+    """
+    for cls, orig_init in _ORIG_INITS.items():
+        cls.__init__ = orig_init
+    _ORIG_INITS.clear()
diff --git a/mindone/transformers/modeling_utils.py b/mindone/transformers/modeling_utils.py
index c62042197c..bf348938eb 100644
--- a/mindone/transformers/modeling_utils.py
+++ b/mindone/transformers/modeling_utils.py
@@ -60,6 +60,8 @@
 import mindspore as ms
 from mindspore import Parameter, Tensor, mint, nn, ops
 from mindspore.nn import CrossEntropyLoss, Identity
+from mindspore.nn.utils import no_init_parameters
+from mindspore.ops import Cast

 from .activations import get_activation
 from .generation.utils import GenerationMixin
@@ -77,6 +79,7 @@
     prune_linear_layer,
 )
 from .modeling_attn_mask_utils import dtype_to_min
+from .modeling_patch import patch_nn_default_dtype, restore_nn_default_dtype
 from .utils.import_utils import is_flash_attn_2_available, is_sdpa_available

 if is_safetensors_available():
@@ -86,6 +89,7 @@
     from mindone.safetensors.mindspore import save_file as safe_save_file

 logger = logging.get_logger(__name__)
+cpu_cast = Cast().set_device("CPU")

 _init_weights = True

@@ -349,7 +353,7 @@ def _load_state_dict_into_model(model_to_load, state_dict, start_prefix, is_shar
     local_state = {v.name: v for k, v in model_to_load.parameters_and_names()}
     for k, v in state_dict.items():
         if k in local_state:
-            v.set_dtype(local_state[k].dtype)
+            state_dict[k] = ms.Parameter(cpu_cast(v.data, local_state[k].dtype), name=k)
         else:
             pass  # unexpect key keeps origin dtype
     cm = silence_mindspore_logger() if is_sharded else nullcontext()
@@ -387,7 +391,8 @@ def to(self, dtype: Optional[ms.Type] = None):

         # Now we use `Parameter` and `Parameter.set_dtype()` instead.
         for p in self.get_parameters():
-            p.set_dtype(dtype)
+            if p.dtype != dtype:
+                p.set_dtype(dtype)
         return self

     def float(self):
@@ -977,8 +982,12 @@ def _from_config(cls, config, **kwargs):
                 use_flash_attention_2=use_flash_attention_2,
                 mindspore_dtype=mindspore_dtype,
             )
-
-        model = cls(config, **kwargs)
+        with no_init_parameters():
+            if mindspore_dtype is not None:
+                patch_nn_default_dtype(dtype=mindspore_dtype, force=True)
+            model = cls(config, **kwargs)
+            if mindspore_dtype is not None:
+                restore_nn_default_dtype()

         # We cannot set default mindspore dtype. So we need to cast model weights after creating.
         if mindspore_dtype is not None:
@@ -2348,7 +2357,12 @@
             config, use_flash_attention_2=use_flash_attention_2, mindspore_dtype=mindspore_dtype
         )

-        model = cls(config, *model_args, **model_kwargs)
+        with no_init_parameters():
+            if mindspore_dtype is not None:
+                patch_nn_default_dtype(dtype=mindspore_dtype, force=True)
+            model = cls(config, *model_args, **model_kwargs)
+            if mindspore_dtype is not None:
+                restore_nn_default_dtype()

         # Make sure to tie the weights correctly
         model.tie_weights()