diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py
index aab4d16172d9..c56fbf588bf0 100644
--- a/src/diffusers/models/modeling_utils.py
+++ b/src/diffusers/models/modeling_utils.py
@@ -517,6 +517,18 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
         variant = kwargs.pop("variant", None)
         use_safetensors = kwargs.pop("use_safetensors", None)
 
+        if torch_dtype is not None:
+            deprecate("torch_dtype", "0.30.0", "Using `torch_dtype` is deprecated. Use `dtype` instead.")
+
+        dtype_kwarg = kwargs.pop("dtype", None)
+
+        if torch_dtype is not None and dtype_kwarg is not None:
+            raise ValueError(
+                "You have passed both `torch_dtype` and `dtype` as keyword arguments. Please make sure to only pass `dtype`."
+            )
+
+        dtype = torch_dtype or dtype_kwarg
+
         allow_pickle = False
         if use_safetensors is None:
             use_safetensors = True
@@ -670,7 +682,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                     model,
                     state_dict,
                     device=param_device,
-                    dtype=torch_dtype,
+                    dtype=dtype,
                     model_name_or_path=pretrained_model_name_or_path,
                 )
 
@@ -755,12 +767,12 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                 "error_msgs": error_msgs,
             }
 
-        if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
+        if dtype is not None and not isinstance(dtype, torch.dtype):
             raise ValueError(
-                f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
+                f"{dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(dtype)}."
             )
-        elif torch_dtype is not None:
-            model = model.to(torch_dtype)
+        elif dtype is not None:
+            model = model.to(dtype)
 
         model.register_to_config(_name_or_path=pretrained_model_name_or_path)
 
diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py
index 769fcd2e832a..d51019e2ea84 100644
--- a/src/diffusers/pipelines/pipeline_utils.py
+++ b/src/diffusers/pipelines/pipeline_utils.py
@@ -1078,6 +1078,18 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
         use_onnx = kwargs.pop("use_onnx", None)
         load_connected_pipeline = kwargs.pop("load_connected_pipeline", False)
 
+        if torch_dtype is not None:
+            deprecate("torch_dtype", "0.30.0", "Using `torch_dtype` is deprecated. Use `dtype` instead.")
+
+        dtype_kwarg = kwargs.pop("dtype", None)
+
+        if torch_dtype is not None and dtype_kwarg is not None:
+            raise ValueError(
+                "You have passed both `torch_dtype` and `dtype` as keyword arguments. Please make sure to only pass `dtype`."
+            )
+
+        dtype = torch_dtype or dtype_kwarg
+
         # 1. Download the checkpoints and configs
         # use snapshot download here to get it working from from_pretrained
         if not os.path.isdir(pretrained_model_name_or_path):
@@ -1268,7 +1280,7 @@ def load_module(name, value):
                 pipelines=pipelines,
                 is_pipeline_module=is_pipeline_module,
                 pipeline_class=pipeline_class,
-                torch_dtype=torch_dtype,
+                torch_dtype=dtype,
                 provider=provider,
                 sess_options=sess_options,
                 device_map=device_map,
@@ -1300,7 +1312,7 @@ def load_module(name, value):
                 "local_files_only": local_files_only,
                 "token": token,
                 "revision": revision,
-                "torch_dtype": torch_dtype,
+                "torch_dtype": dtype,
                 "custom_pipeline": custom_pipeline,
                 "custom_revision": custom_revision,
                 "provider": provider,
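
For context, a minimal usage sketch of the calling convention this patch introduces, assuming a diffusers build with the change applied; the repository id is purely illustrative and not part of the diff.

# Minimal sketch (assumption: diffusers with this patch applied; the repo id below is illustrative).
import torch

from diffusers import DiffusionPipeline

# Preferred after this change: pass `dtype`, which from_pretrained pops from **kwargs
# and eventually applies via `model.to(dtype)`.
pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", dtype=torch.float16)

# Still accepted: `torch_dtype` keeps working but now goes through deprecate(...), which
# warns that the argument is scheduled for removal in 0.30.0.
pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)

# Passing both `torch_dtype` and `dtype` raises a ValueError, since the resolution
# `dtype = torch_dtype or dtype_kwarg` cannot tell which value the caller intended.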