Skip to content

Commit 4d1296c

Browse files
committed
unet check length input
1 parent 1e7f965 commit 4d1296c

File tree

6 files changed

+41
-4
lines changed

6 files changed

+41
-4
lines changed

src/diffusers/models/unet_2d.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,22 @@ def __init__(
107107
self.sample_size = sample_size
108108
time_embed_dim = block_out_channels[0] * 4
109109

110+
# Check inputs
111+
if len(down_block_types) != len(up_block_types):
112+
raise ValueError(
113+
f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
114+
)
115+
116+
if len(block_out_channels) != len(down_block_types):
117+
raise ValueError(
118+
f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
119+
)
120+
121+
if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
122+
raise ValueError(
123+
f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
124+
)
125+
110126
# input
111127
self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1))
112128

src/diffusers/models/unet_2d_condition.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,27 @@ def __init__(
150150

151151
self.sample_size = sample_size
152152

153+
# Check inputs
154+
if len(down_block_types) != len(up_block_types):
155+
raise ValueError(
156+
f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
157+
)
158+
159+
if len(block_out_channels) != len(down_block_types):
160+
raise ValueError(
161+
f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
162+
)
163+
164+
if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
165+
raise ValueError(
166+
f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
167+
)
168+
169+
if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
170+
raise ValueError(
171+
f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
172+
)
173+
153174
# input
154175
conv_in_padding = (conv_in_kernel - 1) // 2
155176
self.conv_in = nn.Conv2d(

tests/pipelines/stable_diffusion_2/test_stable_diffusion.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def get_dummy_components(self):
5656
up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
5757
cross_attention_dim=32,
5858
# SD2-specific config below
59-
attention_head_dim=(2, 4, 8, 8),
59+
attention_head_dim=(2, 4),
6060
use_linear_projection=True,
6161
)
6262
scheduler = DDIMScheduler(

tests/pipelines/stable_diffusion_2/test_stable_diffusion_depth.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def get_dummy_components(self):
6565
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
6666
up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
6767
cross_attention_dim=32,
68-
attention_head_dim=(2, 4, 8, 8),
68+
attention_head_dim=(2, 4),
6969
use_linear_projection=True,
7070
)
7171
scheduler = PNDMScheduler(skip_prk_steps=True)

tests/pipelines/stable_diffusion_2/test_stable_diffusion_inpaint.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def get_dummy_components(self):
4747
up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
4848
cross_attention_dim=32,
4949
# SD2-specific config below
50-
attention_head_dim=(2, 4, 8, 8),
50+
attention_head_dim=(2, 4),
5151
use_linear_projection=True,
5252
)
5353
scheduler = PNDMScheduler(skip_prk_steps=True)

tests/pipelines/stable_diffusion_2/test_stable_diffusion_v_pred.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def dummy_cond_unet(self):
5656
up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
5757
cross_attention_dim=32,
5858
# SD2-specific config below
59-
attention_head_dim=(2, 4, 8, 8),
59+
attention_head_dim=(2, 4),
6060
use_linear_projection=True,
6161
)
6262
return model

0 commit comments

Comments
 (0)