@@ -86,13 +86,14 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
         norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
             If `None`, it will skip the normalization and activation layers in post-processing
         norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
-        cross_attention_dim (`int`, *optional*, defaults to 1280): The dimension of the cross attention features.
+        cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
+            The dimension of the cross attention features.
         attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
         resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
             for resnet blocks, see [`~models.resnet.ResnetBlock2D`]. Choose from `default` or `scale_shift`.
         class_embed_type (`str`, *optional*, defaults to None):
             The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
-            `"timestep"`, `"identity"`, or `"projection"`.
+            `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
         num_class_embeds (`int`, *optional*, defaults to None):
             Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
             class conditioning with `class_embed_type` equal to `None`.
@@ -106,6 +107,8 @@ class conditioning with `class_embed_type` equal to `None`.
         conv_out_kernel (`int`, *optional*, defaults to `3`): The kernel size of `conv_out` layer.
         projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
             using the "projection" `class_embed_type`. Required when using the "projection" `class_embed_type`.
+        class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
+            embeddings with the class embeddings.
     """

     _supports_gradient_checkpointing = True
@@ -135,7 +138,7 @@ def __init__(
         act_fn: str = "silu",
         norm_num_groups: Optional[int] = 32,
         norm_eps: float = 1e-5,
-        cross_attention_dim: int = 1280,
+        cross_attention_dim: Union[int, Tuple[int]] = 1280,
         attention_head_dim: Union[int, Tuple[int]] = 8,
         dual_cross_attention: bool = False,
         use_linear_projection: bool = False,
@@ -149,6 +152,7 @@ def __init__(
         conv_in_kernel: int = 3,
         conv_out_kernel: int = 3,
         projection_class_embeddings_input_dim: Optional[int] = None,
+        class_embeddings_concat: bool = False,
     ):
         super().__init__()

@@ -175,6 +179,11 @@ def __init__(
                 f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
             )

+        if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
+            raise ValueError(
+                f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
+            )
+
         # input
         conv_in_padding = (conv_in_kernel - 1) // 2
         self.conv_in = nn.Conv2d(
@@ -228,6 +237,12 @@ def __init__(
             # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
             # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
             self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
+        elif class_embed_type == "simple_projection":
+            if projection_class_embeddings_input_dim is None:
+                raise ValueError(
+                    "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
+                )
+            self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)
         else:
             self.class_embedding = None

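
For reference, a minimal sketch of what the new "simple_projection" branch does (the sizes below are illustrative, not taken from this diff): class labels are mapped to the time-embedding width by a single linear layer, with no sinusoidal projection in between.

    import torch
    from torch import nn

    # Hypothetical sizes: projection_class_embeddings_input_dim=512, time_embed_dim=1280.
    class_embedding = nn.Linear(512, 1280)
    class_labels = torch.randn(2, 512)          # float features, not integer class ids
    class_emb = class_embedding(class_labels)   # shape (2, 1280), same width as the time embedding
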
@@ -240,6 +255,17 @@ def __init__(
         if isinstance(attention_head_dim, int):
             attention_head_dim = (attention_head_dim,) * len(down_block_types)

+        if isinstance(cross_attention_dim, int):
+            cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
+
+        if class_embeddings_concat:
+            # The time embeddings are concatenated with the class embeddings. The dimension of the
+            # time embeddings passed to the down, middle, and up blocks is twice the dimension of the
+            # regular time embeddings
+            blocks_time_embed_dim = time_embed_dim * 2
+        else:
+            blocks_time_embed_dim = time_embed_dim
+
         # down
         output_channel = block_out_channels[0]
         for i, down_block_type in enumerate(down_block_types):
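
A minimal sketch of the dimension bookkeeping this introduces, assuming the existing `time_embed_dim = block_out_channels[0] * 4` computed earlier in `__init__` (the concrete numbers are illustrative):

    block_out_channels = (320, 640, 1280, 1280)
    time_embed_dim = block_out_channels[0] * 4    # 1280; both the time and class embeddings use this width
    class_embeddings_concat = True
    blocks_time_embed_dim = time_embed_dim * 2 if class_embeddings_concat else time_embed_dim
    print(blocks_time_embed_dim)                  # 2560, passed as temb_channels to every down/mid/up block
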
@@ -252,12 +278,12 @@ def __init__(
                 num_layers=layers_per_block,
                 in_channels=input_channel,
                 out_channels=output_channel,
-                temb_channels=time_embed_dim,
+                temb_channels=blocks_time_embed_dim,
                 add_downsample=not is_final_block,
                 resnet_eps=norm_eps,
                 resnet_act_fn=act_fn,
                 resnet_groups=norm_num_groups,
-                cross_attention_dim=cross_attention_dim,
+                cross_attention_dim=cross_attention_dim[i],
                 attn_num_head_channels=attention_head_dim[i],
                 downsample_padding=downsample_padding,
                 dual_cross_attention=dual_cross_attention,
@@ -272,12 +298,12 @@ def __init__(
         if mid_block_type == "UNetMidBlock2DCrossAttn":
             self.mid_block = UNetMidBlock2DCrossAttn(
                 in_channels=block_out_channels[-1],
-                temb_channels=time_embed_dim,
+                temb_channels=blocks_time_embed_dim,
                 resnet_eps=norm_eps,
                 resnet_act_fn=act_fn,
                 output_scale_factor=mid_block_scale_factor,
                 resnet_time_scale_shift=resnet_time_scale_shift,
-                cross_attention_dim=cross_attention_dim,
+                cross_attention_dim=cross_attention_dim[-1],
                 attn_num_head_channels=attention_head_dim[-1],
                 resnet_groups=norm_num_groups,
                 dual_cross_attention=dual_cross_attention,
@@ -287,11 +313,11 @@ def __init__(
         elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
             self.mid_block = UNetMidBlock2DSimpleCrossAttn(
                 in_channels=block_out_channels[-1],
-                temb_channels=time_embed_dim,
+                temb_channels=blocks_time_embed_dim,
                 resnet_eps=norm_eps,
                 resnet_act_fn=act_fn,
                 output_scale_factor=mid_block_scale_factor,
-                cross_attention_dim=cross_attention_dim,
+                cross_attention_dim=cross_attention_dim[-1],
                 attn_num_head_channels=attention_head_dim[-1],
                 resnet_groups=norm_num_groups,
                 resnet_time_scale_shift=resnet_time_scale_shift,
@@ -307,6 +333,7 @@ def __init__(
         # up
         reversed_block_out_channels = list(reversed(block_out_channels))
         reversed_attention_head_dim = list(reversed(attention_head_dim))
+        reversed_cross_attention_dim = list(reversed(cross_attention_dim))
         only_cross_attention = list(reversed(only_cross_attention))

         output_channel = reversed_block_out_channels[0]
@@ -330,12 +357,12 @@ def __init__(
                 in_channels=input_channel,
                 out_channels=output_channel,
                 prev_output_channel=prev_output_channel,
-                temb_channels=time_embed_dim,
+                temb_channels=blocks_time_embed_dim,
                 add_upsample=add_upsample,
                 resnet_eps=norm_eps,
                 resnet_act_fn=act_fn,
                 resnet_groups=norm_num_groups,
-                cross_attention_dim=cross_attention_dim,
+                cross_attention_dim=reversed_cross_attention_dim[i],
                 attn_num_head_channels=reversed_attention_head_dim[i],
                 dual_cross_attention=dual_cross_attention,
                 use_linear_projection=use_linear_projection,
@@ -571,7 +598,11 @@ def forward(
                 class_labels = self.time_proj(class_labels)

             class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
-            emb = emb + class_emb
+
+            if self.config.class_embeddings_concat:
+                emb = torch.cat([emb, class_emb], dim=-1)
+            else:
+                emb = emb + class_emb

         # 2. pre-process
         sample = self.conv_in(sample)
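
Taken together, a hedged usage sketch of the new options (all concrete values below are illustrative; only the tuple form of `cross_attention_dim`, `class_embed_type="simple_projection"`, `projection_class_embeddings_input_dim`, and `class_embeddings_concat` come from this diff):

    import torch
    from diffusers import UNet2DConditionModel

    unet = UNet2DConditionModel(
        sample_size=32,
        block_out_channels=(320, 640, 1280, 1280),
        down_block_types=("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"),
        up_block_types=("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
        cross_attention_dim=(768, 768, 768, 768),      # one entry per down block
        class_embed_type="simple_projection",
        projection_class_embeddings_input_dim=512,     # required by "simple_projection"
        class_embeddings_concat=True,                  # concatenate class and time embeddings
    )

    sample = torch.randn(1, 4, 32, 32)                 # default in_channels is 4
    encoder_hidden_states = torch.randn(1, 77, 768)    # last dim must match cross_attention_dim
    class_labels = torch.randn(1, 512)                 # fed to the simple_projection linear layer
    out = unet(sample, timestep=10, encoder_hidden_states=encoder_hidden_states, class_labels=class_labels).sample
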