Skip to content

Commit e55687e

Browse files
unet check length inputs (#2327)
* unet check length input
* prep test file for changes
* correct all tests
* clean up

---------

Co-authored-by: Patrick von Platen <[email protected]>
1 parent 9e8ee2a commit e55687e

File tree

7 files changed: 66 additions (+66) and 9 deletions (-9)

src/diffusers/models/unet_2d.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ def __init__(
9494
mid_block_scale_factor: float = 1,
9595
downsample_padding: int = 1,
9696
act_fn: str = "silu",
97-
attention_head_dim: int = 8,
97+
attention_head_dim: Optional[int] = 8,
9898
norm_num_groups: int = 32,
9999
norm_eps: float = 1e-5,
100100
resnet_time_scale_shift: str = "default",
@@ -107,6 +107,17 @@ def __init__(
107107
self.sample_size = sample_size
108108
time_embed_dim = block_out_channels[0] * 4
109109

110+
# Check inputs
111+
if len(down_block_types) != len(up_block_types):
112+
raise ValueError(
113+
f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
114+
)
115+
116+
if len(block_out_channels) != len(down_block_types):
117+
raise ValueError(
118+
f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
119+
)
120+
110121
# input
111122
self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1))
112123

src/diffusers/models/unet_2d_condition.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,27 @@ def __init__(
150150

151151
self.sample_size = sample_size
152152

153+
# Check inputs
154+
if len(down_block_types) != len(up_block_types):
155+
raise ValueError(
156+
f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
157+
)
158+
159+
if len(block_out_channels) != len(down_block_types):
160+
raise ValueError(
161+
f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
162+
)
163+
164+
if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
165+
raise ValueError(
166+
f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
167+
)
168+
169+
if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
170+
raise ValueError(
171+
f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
172+
)
173+
153174
# input
154175
conv_in_padding = (conv_in_kernel - 1) // 2
155176
self.conv_in = nn.Conv2d(

src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,31 @@ def __init__(
236236

237237
self.sample_size = sample_size
238238

239+
# Check inputs
240+
if len(down_block_types) != len(up_block_types):
241+
raise ValueError(
242+
"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`:"
243+
f" {down_block_types}. `up_block_types`: {up_block_types}."
244+
)
245+
246+
if len(block_out_channels) != len(down_block_types):
247+
raise ValueError(
248+
"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`:"
249+
f" {block_out_channels}. `down_block_types`: {down_block_types}."
250+
)
251+
252+
if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
253+
raise ValueError(
254+
"Must provide the same number of `only_cross_attention` as `down_block_types`."
255+
f" `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
256+
)
257+
258+
if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
259+
raise ValueError(
260+
"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`:"
261+
f" {attention_head_dim}. `down_block_types`: {down_block_types}."
262+
)
263+
239264
# input
240265
conv_in_padding = (conv_in_kernel - 1) // 2
241266
self.conv_in = LinearMultiDim(

tests/pipelines/stable_diffusion_2/test_stable_diffusion.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def get_dummy_components(self):
5656
up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
5757
cross_attention_dim=32,
5858
# SD2-specific config below
59-
attention_head_dim=(2, 4, 8, 8),
59+
attention_head_dim=(2, 4),
6060
use_linear_projection=True,
6161
)
6262
scheduler = DDIMScheduler(

tests/pipelines/stable_diffusion_2/test_stable_diffusion_depth.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def get_dummy_components(self):
6565
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
6666
up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
6767
cross_attention_dim=32,
68-
attention_head_dim=(2, 4, 8, 8),
68+
attention_head_dim=(2, 4),
6969
use_linear_projection=True,
7070
)
7171
scheduler = PNDMScheduler(skip_prk_steps=True)
@@ -284,7 +284,7 @@ def test_stable_diffusion_depth2img_default_case(self):
284284
if torch_device == "mps":
285285
expected_slice = np.array([0.6071, 0.5035, 0.4378, 0.5776, 0.5753, 0.4316, 0.4513, 0.5263, 0.4546])
286286
else:
287-
expected_slice = np.array([0.6374, 0.5039, 0.4199, 0.4819, 0.5563, 0.4617, 0.4028, 0.5381, 0.4711])
287+
expected_slice = np.array([0.6312, 0.4984, 0.4154, 0.4788, 0.5535, 0.4599, 0.4017, 0.5359, 0.4716])
288288

289289
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
290290

@@ -305,7 +305,7 @@ def test_stable_diffusion_depth2img_negative_prompt(self):
305305
if torch_device == "mps":
306306
expected_slice = np.array([0.5825, 0.5135, 0.4095, 0.5452, 0.6059, 0.4211, 0.3994, 0.5177, 0.4335])
307307
else:
308-
expected_slice = np.array([0.6332, 0.5167, 0.3911, 0.4446, 0.5971, 0.4619, 0.3821, 0.5323, 0.4621])
308+
expected_slice = np.array([0.6296, 0.5125, 0.3890, 0.4456, 0.5955, 0.4621, 0.3810, 0.5310, 0.4626])
309309

310310
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
311311

@@ -327,7 +327,7 @@ def test_stable_diffusion_depth2img_multiple_init_images(self):
327327
if torch_device == "mps":
328328
expected_slice = np.array([0.6501, 0.5150, 0.4939, 0.6688, 0.5437, 0.5758, 0.5115, 0.4406, 0.4551])
329329
else:
330-
expected_slice = np.array([0.6248, 0.5206, 0.6007, 0.6749, 0.5022, 0.6442, 0.5352, 0.4140, 0.4681])
330+
expected_slice = np.array([0.6267, 0.5232, 0.6001, 0.6738, 0.5029, 0.6429, 0.5364, 0.4159, 0.4674])
331331

332332
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
333333

@@ -382,7 +382,7 @@ def test_stable_diffusion_depth2img_pil(self):
382382
if torch_device == "mps":
383383
expected_slice = np.array([0.53232, 0.47015, 0.40868, 0.45651, 0.4891, 0.4668, 0.4287, 0.48822, 0.47439])
384384
else:
385-
expected_slice = np.array([0.6374, 0.5039, 0.4199, 0.4819, 0.5563, 0.4617, 0.4028, 0.5381, 0.4711])
385+
expected_slice = np.array([0.6312, 0.4984, 0.4154, 0.4788, 0.5535, 0.4599, 0.4017, 0.5359, 0.4716])
386386

387387
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
388388

tests/pipelines/stable_diffusion_2/test_stable_diffusion_inpaint.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def get_dummy_components(self):
4747
up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
4848
cross_attention_dim=32,
4949
# SD2-specific config below
50-
attention_head_dim=(2, 4, 8, 8),
50+
attention_head_dim=(2, 4),
5151
use_linear_projection=True,
5252
)
5353
scheduler = PNDMScheduler(skip_prk_steps=True)

tests/pipelines/stable_diffusion_2/test_stable_diffusion_v_pred.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def dummy_cond_unet(self):
5656
up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
5757
cross_attention_dim=32,
5858
# SD2-specific config below
59-
attention_head_dim=(2, 4, 8, 8),
59+
attention_head_dim=(2, 4),
6060
use_linear_projection=True,
6161
)
6262
return model

Commit comments: 0