Commit a8315ce

[UNet3DModel] Fix with attn processor (#2790)
* [UNet3DModel] Fix attn processor
* make style
1 parent 0d633a4 commit a8315ce

File tree

1 file changed (+65, -7 lines)


src/diffusers/models/unet_3d_condition.py

Lines changed: 65 additions & 7 deletions
@@ -21,6 +21,7 @@
 
 from ..configuration_utils import ConfigMixin, register_to_config
 from ..utils import BaseOutput, logging
+from .attention_processor import AttentionProcessor
 from .embeddings import TimestepEmbedding, Timesteps
 from .modeling_utils import ModelMixin
 from .transformer_temporal import TransformerTemporalModel
@@ -249,6 +250,32 @@ def __init__(
             block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
         )
 
+    @property
+    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
+    def attn_processors(self) -> Dict[str, AttentionProcessor]:
+        r"""
+        Returns:
+            `dict` of attention processors: A dictionary containing all attention processors used in the model with
+            indexed by its weight name.
+        """
+        # set recursively
+        processors = {}
+
+        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+            if hasattr(module, "set_processor"):
+                processors[f"{name}.processor"] = module.processor
+
+            for sub_name, child in module.named_children():
+                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+            return processors
+
+        for name, module in self.named_children():
+            fn_recursive_add_processors(name, module, processors)
+
+        return processors
+
+    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice
     def set_attention_slice(self, slice_size):
         r"""
         Enable sliced attention computation.
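
The new `attn_processors` property walks the module tree and records every sub-module that exposes `set_processor`, keyed by its qualified path plus a `.processor` suffix. Below is a minimal, self-contained sketch of that recursion; `FakeAttention` and `FakeBlock` are made-up stand-ins for illustration, not real diffusers classes.

# Toy sketch of the recursive collection pattern used by `attn_processors` above.
import torch


class FakeAttention(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # stand-in for a real attention processor object
        self.processor = "default-processor"

    def set_processor(self, processor):
        self.processor = processor


class FakeBlock(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.attn1 = FakeAttention()
        self.attn2 = FakeAttention()


model = torch.nn.ModuleDict({"down_block": FakeBlock(), "up_block": FakeBlock()})

processors = {}


def fn_recursive_add_processors(name, module, processors):
    # mirror of the property above: any child with `set_processor` is recorded
    if hasattr(module, "set_processor"):
        processors[f"{name}.processor"] = module.processor
    for sub_name, child in module.named_children():
        fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
    return processors


for name, module in model.named_children():
    fn_recursive_add_processors(name, module, processors)

print(sorted(processors))
# ['down_block.attn1.processor', 'down_block.attn2.processor',
#  'up_block.attn1.processor', 'up_block.attn2.processor']

The same keys are later consumed by `set_attn_processor` when a dict of processors is passed in.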
@@ -259,34 +286,34 @@ def set_attention_slice(self, slice_size):
         Args:
             slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
                 When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
-                `"max"`, maxium amount of memory will be saved by running only one slice at a time. If a number is
+                `"max"`, maximum amount of memory will be saved by running only one slice at a time. If a number is
                 provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
                 must be a multiple of `slice_size`.
         """
         sliceable_head_dims = []
 
-        def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module):
+        def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
             if hasattr(module, "set_attention_slice"):
                 sliceable_head_dims.append(module.sliceable_head_dim)
 
             for child in module.children():
-                fn_recursive_retrieve_slicable_dims(child)
+                fn_recursive_retrieve_sliceable_dims(child)
 
         # retrieve number of attention layers
         for module in self.children():
-            fn_recursive_retrieve_slicable_dims(module)
+            fn_recursive_retrieve_sliceable_dims(module)
 
-        num_slicable_layers = len(sliceable_head_dims)
+        num_sliceable_layers = len(sliceable_head_dims)
 
         if slice_size == "auto":
             # half the attention head size is usually a good trade-off between
             # speed and memory
             slice_size = [dim // 2 for dim in sliceable_head_dims]
         elif slice_size == "max":
             # make smallest slice possible
-            slice_size = num_slicable_layers * [1]
+            slice_size = num_sliceable_layers * [1]
 
-        slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
+        slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
 
         if len(slice_size) != len(sliceable_head_dims):
             raise ValueError(
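
The hunk above only renames the `slicable`/`sliceable` helpers and fixes the `maxium` typo; the slicing behavior is unchanged. As a quick illustration of how the documented `slice_size` values expand, here is a standalone sketch with made-up per-layer head dimensions, not values taken from a real UNet3DConditionModel.

# Standalone sketch of the `slice_size` expansion logic shown in the diff above.
sliceable_head_dims = [8, 16, 8]                # hypothetical example values
num_sliceable_layers = len(sliceable_head_dims)

for requested in ("auto", "max", 4):
    if requested == "auto":
        # halve each head dim: attention runs in two slices per layer
        slice_size = [dim // 2 for dim in sliceable_head_dims]   # -> [4, 8, 4]
    elif requested == "max":
        # one row per slice: smallest possible slices, lowest memory
        slice_size = num_sliceable_layers * [1]                  # -> [1, 1, 1]
    else:
        # a single int is broadcast to every sliceable layer
        slice_size = num_sliceable_layers * [requested]          # -> [4, 4, 4]
    print(requested, slice_size)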
@@ -314,6 +341,37 @@ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
         for module in self.children():
             fn_recursive_set_attention_slice(module, reversed_slice_size)
 
+    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
+    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+        r"""
+        Parameters:
+            `processor (`dict` of `AttentionProcessor` or `AttentionProcessor`):
+                The instantiated processor class or a dictionary of processor classes that will be set as the processor
+                of **all** `Attention` layers.
+            In case `processor` is a dict, the key needs to define the path to the corresponding cross attention processor. This is strongly recommended when setting trainable attention processors.:
+
+        """
+        count = len(self.attn_processors.keys())
+
+        if isinstance(processor, dict) and len(processor) != count:
+            raise ValueError(
+                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+            )
+
+        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+            if hasattr(module, "set_processor"):
+                if not isinstance(processor, dict):
+                    module.set_processor(processor)
+                else:
+                    module.set_processor(processor.pop(f"{name}.processor"))
+
+            for sub_name, child in module.named_children():
+                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+        for name, module in self.named_children():
+            fn_recursive_attn_processor(name, module, processor)
+
     def _set_gradient_checkpointing(self, module, value=False):
         if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
             module.gradient_checkpointing = value
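
Taken together, the two copied methods give UNet3DConditionModel the same attention-processor API as the 2D UNet. A minimal usage sketch follows, assuming a diffusers build that includes this commit and that the damo-vilab/text-to-video-ms-1.7b checkpoint (used here purely as an example) exposes a "unet" subfolder.

# Usage sketch of the new attn_processors / set_attn_processor API on the 3D UNet.
import torch
from diffusers import UNet3DConditionModel
from diffusers.models.attention_processor import AttnProcessor

# Example checkpoint; any repo containing a UNet3DConditionModel works the same way.
unet = UNet3DConditionModel.from_pretrained(
    "damo-vilab/text-to-video-ms-1.7b", subfolder="unet", torch_dtype=torch.float16
)

# 1. Inspect every attention processor, keyed by module path.
processors = unet.attn_processors
print(len(processors), next(iter(processors)))

# 2. Replace all processors at once with a single instance ...
unet.set_attn_processor(AttnProcessor())

# 3. ... or pass a dict keyed by the same paths returned by `attn_processors`.
unet.set_attn_processor({name: AttnProcessor() for name in processors})

Passing a single processor applies it to every attention layer; passing a dict must supply exactly one entry per key returned by `attn_processors`, otherwise the ValueError in the diff above is raised.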

0 commit comments
