diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index aa4e2b0ea487..5a5d233fbb4e 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -579,10 +579,17 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P " those weights or else make sure your checkpoint file is correct." ) + empty_state_dict = model.state_dict() for param_name, param in state_dict.items(): accepts_dtype = "dtype" in set( inspect.signature(set_module_tensor_to_device).parameters.keys() ) + + if empty_state_dict[param_name].shape != param.shape: + raise ValueError( + f"Cannot load {pretrained_model_name_or_path} because {param_name} expected shape {empty_state_dict[param_name]}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example." + ) + if accepts_dtype: set_module_tensor_to_device( model, param_name, param_device, value=param, dtype=torch_dtype diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index e880950a7914..6f51894e8f03 100644 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -100,6 +100,30 @@ def test_one_request_upon_cached(self): diffusers.utils.import_utils._safetensors_available = True + def test_weight_overwrite(self): + with tempfile.TemporaryDirectory() as tmpdirname, self.assertRaises(ValueError) as error_context: + UNet2DConditionModel.from_pretrained( + "hf-internal-testing/tiny-stable-diffusion-torch", + subfolder="unet", + cache_dir=tmpdirname, + in_channels=9, + ) + + # make sure that error message states what keys are missing + assert "Cannot load" in str(error_context.exception) + + with tempfile.TemporaryDirectory() as tmpdirname: + model = UNet2DConditionModel.from_pretrained( + "hf-internal-testing/tiny-stable-diffusion-torch", + subfolder="unet", + cache_dir=tmpdirname, + in_channels=9, + low_cpu_mem_usage=False, + ignore_mismatched_sizes=True, + ) + + assert model.config.in_channels == 9 + class ModelTesterMixin: def test_from_save_pretrained(self):