
Commit e727032

Helper function to disable custom attention processors (huggingface#2791)
* Helper function to disable custom attention processors.
* Restore code deleted by mistake.
* Format.
* Fix modeling_text_unet copy.
1 parent c81df45 commit e727032
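The new helper makes it a one-liner to drop back to the stock attention path after experimenting with custom attention processors. A minimal usage sketch: the checkpoint id is illustrative (any repo with a `unet` subfolder works), and `MyAttentionControlProcessor` is a hypothetical stand-in for whatever custom processor was installed:

```python
import torch
from diffusers import UNet2DConditionModel

# Illustrative checkpoint id; any checkpoint with a `unet` subfolder works.
unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet", torch_dtype=torch.float16
)

# ... experiment with a custom processor, e.g.
# unet.set_attn_processor(MyAttentionControlProcessor())  # hypothetical processor

# New in this commit: one call restores the default attention implementation
# for every attention module in the model.
unet.set_default_attn_processor()
```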

File tree

4 files changed: +30 -4 lines changed


models/controlnet.py

Lines changed: 8 additions & 1 deletion
```diff
@@ -20,7 +20,7 @@
 
 from ..configuration_utils import ConfigMixin, register_to_config
 from ..utils import BaseOutput, logging
-from .attention_processor import AttentionProcessor
+from .attention_processor import AttentionProcessor, AttnProcessor
 from .embeddings import TimestepEmbedding, Timesteps
 from .modeling_utils import ModelMixin
 from .unet_2d_blocks import (
@@ -368,6 +368,13 @@ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
         for name, module in self.named_children():
             fn_recursive_attn_processor(name, module, processor)
 
+    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
+    def set_default_attn_processor(self):
+        """
+        Disables custom attention processors and sets the default attention implementation.
+        """
+        self.set_attn_processor(AttnProcessor())
+
     # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice
     def set_attention_slice(self, slice_size):
         r"""
```

models/unet_2d_condition.py

Lines changed: 7 additions & 1 deletion
```diff
@@ -21,7 +21,7 @@
 from ..configuration_utils import ConfigMixin, register_to_config
 from ..loaders import UNet2DConditionLoadersMixin
 from ..utils import BaseOutput, logging
-from .attention_processor import AttentionProcessor
+from .attention_processor import AttentionProcessor, AttnProcessor
 from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
 from .modeling_utils import ModelMixin
 from .unet_2d_blocks import (
@@ -442,6 +442,12 @@ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
         for name, module in self.named_children():
             fn_recursive_attn_processor(name, module, processor)
 
+    def set_default_attn_processor(self):
+        """
+        Disables custom attention processors and sets the default attention implementation.
+        """
+        self.set_attn_processor(AttnProcessor())
+
     def set_attention_slice(self, slice_size):
         r"""
         Enable sliced attention computation.
```
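Because `set_attn_processor` walks every child module recursively, one call to the helper resets every attention layer in the model. Continuing the sketch above, a quick sanity check, assuming the `attn_processors` property that the diffusers UNets expose:

```python
from diffusers.models.attention_processor import AttnProcessor

unet.set_default_attn_processor()

# attn_processors maps each attention module's name to its processor;
# after the reset, every entry should be the stock AttnProcessor.
assert all(
    isinstance(proc, AttnProcessor) for proc in unet.attn_processors.values()
)
```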

models/unet_3d_condition.py

Lines changed: 8 additions & 1 deletion
```diff
@@ -21,7 +21,7 @@
 
 from ..configuration_utils import ConfigMixin, register_to_config
 from ..utils import BaseOutput, logging
-from .attention_processor import AttentionProcessor
+from .attention_processor import AttentionProcessor, AttnProcessor
 from .embeddings import TimestepEmbedding, Timesteps
 from .modeling_utils import ModelMixin
 from .transformer_temporal import TransformerTemporalModel
@@ -372,6 +372,13 @@ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
         for name, module in self.named_children():
             fn_recursive_attn_processor(name, module, processor)
 
+    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
+    def set_default_attn_processor(self):
+        """
+        Disables custom attention processors and sets the default attention implementation.
+        """
+        self.set_attn_processor(AttnProcessor())
+
     def _set_gradient_checkpointing(self, module, value=False):
         if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
             module.gradient_checkpointing = value
```

pipelines/versatile_diffusion/modeling_text_unet.py

Lines changed: 7 additions & 1 deletion
```diff
@@ -7,7 +7,7 @@
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...models import ModelMixin
 from ...models.attention import Attention
-from ...models.attention_processor import AttentionProcessor, AttnAddedKVProcessor
+from ...models.attention_processor import AttentionProcessor, AttnAddedKVProcessor, AttnProcessor
 from ...models.dual_transformer_2d import DualTransformer2DModel
 from ...models.embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
 from ...models.transformer_2d import Transformer2DModel
@@ -533,6 +533,12 @@ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
         for name, module in self.named_children():
             fn_recursive_attn_processor(name, module, processor)
 
+    def set_default_attn_processor(self):
+        """
+        Disables custom attention processors and sets the default attention implementation.
+        """
+        self.set_attn_processor(AttnProcessor())
+
     def set_attention_slice(self, slice_size):
         r"""
         Enable sliced attention computation.
```
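The body of the helper is identical in all four files. `controlnet.py` and `unet_3d_condition.py` mark theirs with `# Copied from ...UNet2DConditionModel.set_default_attn_processor`, the marker diffusers' fix-copies tooling uses to keep duplicates in sync with the canonical `UNet2DConditionModel` implementation; the "Fix modeling_text_unet copy" item in the commit message refers to bringing this file's duplicated UNet code back in line.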
