from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput, logging
+from .attention_processor import AttentionProcessor
from .embeddings import TimestepEmbedding, Timesteps
from .modeling_utils import ModelMixin
from .transformer_temporal import TransformerTemporalModel
@@ -249,6 +250,32 @@ def __init__(
            block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
        )

+    @property
+    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
+    def attn_processors(self) -> Dict[str, AttentionProcessor]:
+        r"""
+        Returns:
+            `dict` of attention processors: A dictionary containing all attention processors used in the model,
+            indexed by their weight names.
+        """
+        # set recursively
+        processors = {}
+
+        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+            if hasattr(module, "set_processor"):
+                processors[f"{name}.processor"] = module.processor
+
+            for sub_name, child in module.named_children():
+                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+            return processors
+
+        for name, module in self.named_children():
+            fn_recursive_add_processors(name, module, processors)
+
+        return processors
+
+    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice
    def set_attention_slice(self, slice_size):
        r"""
        Enable sliced attention computation.
@@ -259,34 +286,34 @@ def set_attention_slice(self, slice_size):
        Args:
            slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
                When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
-                `"max"`, maxium amount of memory will be saved by running only one slice at a time. If a number is
+                `"max"`, maximum amount of memory will be saved by running only one slice at a time. If a number is
                provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
                must be a multiple of `slice_size`.
        """
        sliceable_head_dims = []

-        def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module):
+        def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
            if hasattr(module, "set_attention_slice"):
                sliceable_head_dims.append(module.sliceable_head_dim)

            for child in module.children():
-                fn_recursive_retrieve_slicable_dims(child)
+                fn_recursive_retrieve_sliceable_dims(child)

        # retrieve number of attention layers
        for module in self.children():
-            fn_recursive_retrieve_slicable_dims(module)
+            fn_recursive_retrieve_sliceable_dims(module)

-        num_slicable_layers = len(sliceable_head_dims)
+        num_sliceable_layers = len(sliceable_head_dims)

        if slice_size == "auto":
            # half the attention head size is usually a good trade-off between
            # speed and memory
            slice_size = [dim // 2 for dim in sliceable_head_dims]
        elif slice_size == "max":
            # make smallest slice possible
-            slice_size = num_slicable_layers * [1]
+            slice_size = num_sliceable_layers * [1]

-        slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
+        slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size

        if len(slice_size) != len(sliceable_head_dims):
            raise ValueError(
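As a quick illustration of the sliced-attention API this hunk touches, here is a minimal usage sketch; the checkpoint id is an assumed example and is not referenced anywhere in this diff:

```python
# Minimal sketch (assumed checkpoint id): enable sliced attention on the 3D UNet.
from diffusers import UNet3DConditionModel

unet = UNet3DConditionModel.from_pretrained(
    "damo-vilab/text-to-video-ms-1.7b",  # illustrative checkpoint, not from this PR
    subfolder="unet",
)

# "auto" halves each sliceable head dimension (two slices per layer),
# "max" runs one slice at a time for the lowest peak memory,
# and an int yields attention_head_dim // slice_size slices per layer.
unet.set_attention_slice("auto")
```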
@@ -314,6 +341,37 @@ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[i
        for module in self.children():
            fn_recursive_set_attention_slice(module, reversed_slice_size)

+    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
+    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+        r"""
+        Parameters:
+            processor (`dict` of `AttentionProcessor` or `AttentionProcessor`):
+                The instantiated processor class or a dictionary of processor classes that will be set as the
+                processor of **all** `Attention` layers. If `processor` is a dict, the key needs to define the path
+                to the corresponding cross attention processor. This is strongly recommended when setting trainable
+                attention processors.
+        """
+        count = len(self.attn_processors.keys())
+
+        if isinstance(processor, dict) and len(processor) != count:
+            raise ValueError(
+                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+            )
+
+        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+            if hasattr(module, "set_processor"):
+                if not isinstance(processor, dict):
+                    module.set_processor(processor)
+                else:
+                    module.set_processor(processor.pop(f"{name}.processor"))
+
+            for sub_name, child in module.named_children():
+                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+        for name, module in self.named_children():
+            fn_recursive_attn_processor(name, module, processor)
+
    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
            module.gradient_checkpointing = value
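To show how the two copied-in attention-processor methods work together, a minimal sketch follows; the checkpoint id and the choice of `AttnProcessor` as the processor class are assumptions made for illustration, not part of this diff:

```python
# Minimal sketch (assumed checkpoint id): inspect and reset attention processors
# on the 3D UNet via the methods added in this diff.
from diffusers import UNet3DConditionModel
from diffusers.models.attention_processor import AttnProcessor

unet = UNet3DConditionModel.from_pretrained(
    "damo-vilab/text-to-video-ms-1.7b",  # illustrative checkpoint, not from this PR
    subfolder="unet",
)

# `attn_processors` returns a dict keyed by weight name (keys ending in ".processor").
processors = unet.attn_processors
print(len(processors), sorted(processors)[0])

# A single instance is applied to every Attention layer; passing a dict instead
# requires exactly one entry per key returned by `attn_processors`.
unet.set_attn_processor(AttnProcessor())
```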