Skip to content

Commit 42d9501

Browse files
[Init] Make sure shape mismatches are caught early (#2847)
Improve init
1 parent 81125d8 commit 42d9501

File tree

2 files changed

+31
-0
lines changed

2 files changed

+31
-0
lines changed

src/diffusers/models/modeling_utils.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -579,10 +579,17 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
579579
" those weights or else make sure your checkpoint file is correct."
580580
)
581581

582+
empty_state_dict = model.state_dict()
582583
for param_name, param in state_dict.items():
583584
accepts_dtype = "dtype" in set(
584585
inspect.signature(set_module_tensor_to_device).parameters.keys()
585586
)
587+
588+
if empty_state_dict[param_name].shape != param.shape:
589+
raise ValueError(
590+
f"Cannot load {pretrained_model_name_or_path} because {param_name} expected shape {empty_state_dict[param_name].shape}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example."
591+
)
592+
586593
if accepts_dtype:
587594
set_module_tensor_to_device(
588595
model, param_name, param_device, value=param, dtype=torch_dtype

tests/test_modeling_common.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,30 @@ def test_one_request_upon_cached(self):
100100

101101
diffusers.utils.import_utils._safetensors_available = True
102102

103+
def test_weight_overwrite(self):
104+
with tempfile.TemporaryDirectory() as tmpdirname, self.assertRaises(ValueError) as error_context:
105+
UNet2DConditionModel.from_pretrained(
106+
"hf-internal-testing/tiny-stable-diffusion-torch",
107+
subfolder="unet",
108+
cache_dir=tmpdirname,
109+
in_channels=9,
110+
)
111+
112+
# make sure that the shape mismatch is reported via the "Cannot load ..." error message
113+
assert "Cannot load" in str(error_context.exception)
114+
115+
with tempfile.TemporaryDirectory() as tmpdirname:
116+
model = UNet2DConditionModel.from_pretrained(
117+
"hf-internal-testing/tiny-stable-diffusion-torch",
118+
subfolder="unet",
119+
cache_dir=tmpdirname,
120+
in_channels=9,
121+
low_cpu_mem_usage=False,
122+
ignore_mismatched_sizes=True,
123+
)
124+
125+
assert model.config.in_channels == 9
126+
103127

104128
class ModelTesterMixin:
105129
def test_from_save_pretrained(self):

0 commit comments

Comments
 (0)