Skip to content

Commit b10f527

Browse files
authored
Helper function to disable custom attention processors (#2791)
* Helper function to disable custom attention processors.
* Restore code deleted by mistake.
* Format
* Fix modeling_text_unet copy.
1 parent 7bc2fff commit b10f527

File tree

8 files changed

+44
-21
lines changed

8 files changed

+44
-21
lines changed

src/diffusers/models/controlnet.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121
from ..configuration_utils import ConfigMixin, register_to_config
2222
from ..utils import BaseOutput, logging
23-
from .attention_processor import AttentionProcessor
23+
from .attention_processor import AttentionProcessor, AttnProcessor
2424
from .embeddings import TimestepEmbedding, Timesteps
2525
from .modeling_utils import ModelMixin
2626
from .unet_2d_blocks import (
@@ -368,6 +368,13 @@ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
368368
for name, module in self.named_children():
369369
fn_recursive_attn_processor(name, module, processor)
370370

371+
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
372+
def set_default_attn_processor(self):
373+
"""
374+
Disables custom attention processors and sets the default attention implementation.
375+
"""
376+
self.set_attn_processor(AttnProcessor())
377+
371378
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice
372379
def set_attention_slice(self, slice_size):
373380
r"""

src/diffusers/models/unet_2d_condition.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from ..configuration_utils import ConfigMixin, register_to_config
2222
from ..loaders import UNet2DConditionLoadersMixin
2323
from ..utils import BaseOutput, logging
24-
from .attention_processor import AttentionProcessor
24+
from .attention_processor import AttentionProcessor, AttnProcessor
2525
from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
2626
from .modeling_utils import ModelMixin
2727
from .unet_2d_blocks import (
@@ -442,6 +442,12 @@ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
442442
for name, module in self.named_children():
443443
fn_recursive_attn_processor(name, module, processor)
444444

445+
def set_default_attn_processor(self):
446+
"""
447+
Disables custom attention processors and sets the default attention implementation.
448+
"""
449+
self.set_attn_processor(AttnProcessor())
450+
445451
def set_attention_slice(self, slice_size):
446452
r"""
447453
Enable sliced attention computation.

src/diffusers/models/unet_3d_condition.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
from ..configuration_utils import ConfigMixin, register_to_config
2323
from ..utils import BaseOutput, logging
24-
from .attention_processor import AttentionProcessor
24+
from .attention_processor import AttentionProcessor, AttnProcessor
2525
from .embeddings import TimestepEmbedding, Timesteps
2626
from .modeling_utils import ModelMixin
2727
from .transformer_temporal import TransformerTemporalModel
@@ -372,6 +372,13 @@ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
372372
for name, module in self.named_children():
373373
fn_recursive_attn_processor(name, module, processor)
374374

375+
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
376+
def set_default_attn_processor(self):
377+
"""
378+
Disables custom attention processors and sets the default attention implementation.
379+
"""
380+
self.set_attn_processor(AttnProcessor())
381+
375382
def _set_gradient_checkpointing(self, module, value=False):
376383
if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
377384
module.gradient_checkpointing = value

src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from ...configuration_utils import ConfigMixin, register_to_config
88
from ...models import ModelMixin
99
from ...models.attention import Attention
10-
from ...models.attention_processor import AttentionProcessor, AttnAddedKVProcessor
10+
from ...models.attention_processor import AttentionProcessor, AttnAddedKVProcessor, AttnProcessor
1111
from ...models.dual_transformer_2d import DualTransformer2DModel
1212
from ...models.embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
1313
from ...models.transformer_2d import Transformer2DModel
@@ -533,6 +533,12 @@ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
533533
for name, module in self.named_children():
534534
fn_recursive_attn_processor(name, module, processor)
535535

536+
def set_default_attn_processor(self):
537+
"""
538+
Disables custom attention processors and sets the default attention implementation.
539+
"""
540+
self.set_attn_processor(AttnProcessor())
541+
536542
def set_attention_slice(self, slice_size):
537543
r"""
538544
Enable sliced attention computation.

tests/models/test_models_unet_2d_condition.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from parameterized import parameterized
2323

2424
from diffusers import UNet2DConditionModel
25-
from diffusers.models.attention_processor import AttnProcessor, LoRAAttnProcessor
25+
from diffusers.models.attention_processor import LoRAAttnProcessor
2626
from diffusers.utils import (
2727
floats_tensor,
2828
load_hf_numpy,
@@ -599,7 +599,7 @@ def test_lora_on_off(self):
599599
with torch.no_grad():
600600
sample = model(**inputs_dict, cross_attention_kwargs={"scale": 0.0}).sample
601601

602-
model.set_attn_processor(AttnProcessor())
602+
model.set_default_attn_processor()
603603

604604
with torch.no_grad():
605605
new_sample = model(**inputs_dict).sample

tests/pipelines/stable_diffusion/test_stable_diffusion.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@
3535
UNet2DConditionModel,
3636
logging,
3737
)
38-
from diffusers.models.attention_processor import AttnProcessor
3938
from diffusers.utils import load_numpy, nightly, slow, torch_device
4039
from diffusers.utils.testing_utils import CaptureLogger, require_torch_gpu
4140

@@ -843,7 +842,7 @@ def test_stable_diffusion_pipeline_with_model_offloading(self):
843842
"CompVis/stable-diffusion-v1-4",
844843
torch_dtype=torch.float16,
845844
)
846-
pipe.unet.set_attn_processor(AttnProcessor())
845+
pipe.unet.set_default_attn_processor()
847846
pipe.to(torch_device)
848847
pipe.set_progress_bar_config(disable=None)
849848
outputs = pipe(**inputs)
@@ -856,7 +855,7 @@ def test_stable_diffusion_pipeline_with_model_offloading(self):
856855
"CompVis/stable-diffusion-v1-4",
857856
torch_dtype=torch.float16,
858857
)
859-
pipe.unet.set_attn_processor(AttnProcessor())
858+
pipe.unet.set_default_attn_processor()
860859

861860
torch.cuda.empty_cache()
862861
torch.cuda.reset_max_memory_allocated()

tests/pipelines/stable_diffusion_2/test_stable_diffusion.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@
3232
UNet2DConditionModel,
3333
logging,
3434
)
35-
from diffusers.models.attention_processor import AttnProcessor
3635
from diffusers.utils import load_numpy, nightly, slow, torch_device
3736
from diffusers.utils.testing_utils import CaptureLogger, require_torch_gpu
3837

@@ -410,7 +409,7 @@ def test_stable_diffusion_pipeline_with_model_offloading(self):
410409
"stabilityai/stable-diffusion-2-base",
411410
torch_dtype=torch.float16,
412411
)
413-
pipe.unet.set_attn_processor(AttnProcessor())
412+
pipe.unet.set_default_attn_processor()
414413
pipe.to(torch_device)
415414
pipe.set_progress_bar_config(disable=None)
416415
outputs = pipe(**inputs)
@@ -423,7 +422,7 @@ def test_stable_diffusion_pipeline_with_model_offloading(self):
423422
"stabilityai/stable-diffusion-2-base",
424423
torch_dtype=torch.float16,
425424
)
426-
pipe.unet.set_attn_processor(AttnProcessor())
425+
pipe.unet.set_default_attn_processor()
427426

428427
torch.cuda.empty_cache()
429428
torch.cuda.reset_max_memory_allocated()

tests/test_modeling_common.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
from requests.exceptions import HTTPError
2626

2727
from diffusers.models import UNet2DConditionModel
28-
from diffusers.models.attention_processor import AttnProcessor
2928
from diffusers.training_utils import EMAModel
3029
from diffusers.utils import torch_device
3130

@@ -106,16 +105,16 @@ def test_from_save_pretrained(self):
106105
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
107106

108107
model = self.model_class(**init_dict)
109-
if hasattr(model, "set_attn_processor"):
110-
model.set_attn_processor(AttnProcessor())
108+
if hasattr(model, "set_default_attn_processor"):
109+
model.set_default_attn_processor()
111110
model.to(torch_device)
112111
model.eval()
113112

114113
with tempfile.TemporaryDirectory() as tmpdirname:
115114
model.save_pretrained(tmpdirname)
116115
new_model = self.model_class.from_pretrained(tmpdirname)
117-
if hasattr(new_model, "set_attn_processor"):
118-
new_model.set_attn_processor(AttnProcessor())
116+
if hasattr(new_model, "set_default_attn_processor"):
117+
new_model.set_default_attn_processor()
119118
new_model.to(torch_device)
120119

121120
with torch.no_grad():
@@ -135,16 +134,16 @@ def test_from_save_pretrained_variant(self):
135134
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
136135

137136
model = self.model_class(**init_dict)
138-
if hasattr(model, "set_attn_processor"):
139-
model.set_attn_processor(AttnProcessor())
137+
if hasattr(model, "set_default_attn_processor"):
138+
model.set_default_attn_processor()
140139
model.to(torch_device)
141140
model.eval()
142141

143142
with tempfile.TemporaryDirectory() as tmpdirname:
144143
model.save_pretrained(tmpdirname, variant="fp16")
145144
new_model = self.model_class.from_pretrained(tmpdirname, variant="fp16")
146-
if hasattr(new_model, "set_attn_processor"):
147-
new_model.set_attn_processor(AttnProcessor())
145+
if hasattr(new_model, "set_default_attn_processor"):
146+
new_model.set_default_attn_processor()
148147

149148
# non-variant cannot be loaded
150149
with self.assertRaises(OSError) as error_context:

0 commit comments

Comments (0)