src/diffusers/models/controlnet.py (8 additions & 1 deletion)

@@ -20,7 +20,7 @@

 from ..configuration_utils import ConfigMixin, register_to_config
 from ..utils import BaseOutput, logging
-from .attention_processor import AttentionProcessor
+from .attention_processor import AttentionProcessor, AttnProcessor
 from .embeddings import TimestepEmbedding, Timesteps
 from .modeling_utils import ModelMixin
 from .unet_2d_blocks import (
@@ -368,6 +368,13 @@ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
         for name, module in self.named_children():
             fn_recursive_attn_processor(name, module, processor)

+    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
+    def set_default_attn_processor(self):
+        """
+        Disables custom attention processors and sets the default attention implementation.
+        """
+        self.set_attn_processor(AttnProcessor())
+
     # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice
     def set_attention_slice(self, slice_size):
         r"""
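For orientation, the `fn_recursive_attn_processor` context lines above belong to `set_attn_processor`, the method the new helper delegates to. Below is a rough free-function sketch of that recursive walk, reconstructed from the context lines and the library's usual pattern; an approximation, not the verbatim method.

import torch

def set_attn_processor(model: torch.nn.Module, processor) -> None:
    # Walk every submodule: any module exposing `set_processor` receives the
    # processor. When `processor` is a dict, it maps fully qualified module
    # names (e.g. "down_blocks.0.attentions.0...processor") to processors.
    def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
        if hasattr(module, "set_processor"):
            if not isinstance(processor, dict):
                module.set_processor(processor)
            else:
                module.set_processor(processor.pop(f"{name}.processor"))
        for sub_name, child in module.named_children():
            fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

    for name, module in model.named_children():
        fn_recursive_attn_processor(name, module, processor)

Passing `AttnProcessor()` through this walk is exactly what the new `set_default_attn_processor` one-liner does.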
src/diffusers/models/unet_2d_condition.py (7 additions & 1 deletion)

@@ -21,7 +21,7 @@
 from ..configuration_utils import ConfigMixin, register_to_config
 from ..loaders import UNet2DConditionLoadersMixin
 from ..utils import BaseOutput, logging
-from .attention_processor import AttentionProcessor
+from .attention_processor import AttentionProcessor, AttnProcessor
 from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
 from .modeling_utils import ModelMixin
 from .unet_2d_blocks import (
@@ -415,6 +415,12 @@ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
         for name, module in self.named_children():
             fn_recursive_attn_processor(name, module, processor)

+    def set_default_attn_processor(self):
+        """
+        Disables custom attention processors and sets the default attention implementation.
+        """
+        self.set_attn_processor(AttnProcessor())
+
     def set_attention_slice(self, slice_size):
         r"""
         Enable sliced attention computation.
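This `UNet2DConditionModel` definition is the canonical copy that the `# Copied from` methods in the other models mirror. A minimal usage sketch of the new helper (illustrative only; the checkpoint ID is borrowed from the pipeline tests below):

import torch
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained(
    "CompVis/stable-diffusion-v1-4", subfolder="unet", torch_dtype=torch.float16
)

# After installing custom processors (e.g. LoRA) via unet.set_attn_processor(...),
# restore the stock attention implementation with the new one-liner, replacing
# the old unet.set_attn_processor(AttnProcessor()) idiom:
unet.set_default_attn_processor()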
src/diffusers/models/unet_3d_condition.py (8 additions & 1 deletion)

@@ -21,7 +21,7 @@

 from ..configuration_utils import ConfigMixin, register_to_config
 from ..utils import BaseOutput, logging
-from .attention_processor import AttentionProcessor
+from .attention_processor import AttentionProcessor, AttnProcessor
 from .embeddings import TimestepEmbedding, Timesteps
 from .modeling_utils import ModelMixin
 from .transformer_temporal import TransformerTemporalModel
@@ -372,6 +372,13 @@ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
         for name, module in self.named_children():
             fn_recursive_attn_processor(name, module, processor)

+    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
+    def set_default_attn_processor(self):
+        """
+        Disables custom attention processors and sets the default attention implementation.
+        """
+        self.set_attn_processor(AttnProcessor())
+
     def _set_gradient_checkpointing(self, module, value=False):
         if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
             module.gradient_checkpointing = value
src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py (7 additions & 1 deletion)

@@ -7,7 +7,7 @@
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...models import ModelMixin
 from ...models.attention import Attention
-from ...models.attention_processor import AttentionProcessor, AttnAddedKVProcessor
+from ...models.attention_processor import AttentionProcessor, AttnAddedKVProcessor, AttnProcessor
 from ...models.dual_transformer_2d import DualTransformer2DModel
 from ...models.embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
 from ...models.transformer_2d import Transformer2DModel
@@ -505,6 +505,12 @@ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
         for name, module in self.named_children():
             fn_recursive_attn_processor(name, module, processor)

+    def set_default_attn_processor(self):
+        """
+        Disables custom attention processors and sets the default attention implementation.
+        """
+        self.set_attn_processor(AttnProcessor())
+
     def set_attention_slice(self, slice_size):
         r"""
         Enable sliced attention computation.
tests/models/test_models_unet_2d_condition.py (2 additions & 2 deletions)

@@ -22,7 +22,7 @@
 from parameterized import parameterized

 from diffusers import UNet2DConditionModel
-from diffusers.models.attention_processor import AttnProcessor, LoRAAttnProcessor
+from diffusers.models.attention_processor import LoRAAttnProcessor
 from diffusers.utils import (
     floats_tensor,
     load_hf_numpy,
@@ -531,7 +531,7 @@ def test_lora_on_off(self):
         with torch.no_grad():
             sample = model(**inputs_dict, cross_attention_kwargs={"scale": 0.0}).sample

-        model.set_attn_processor(AttnProcessor())
+        model.set_default_attn_processor()

         with torch.no_grad():
             new_sample = model(**inputs_dict).sample
tests/pipelines/stable_diffusion/test_stable_diffusion.py (2 additions & 3 deletions)

@@ -35,7 +35,6 @@
     UNet2DConditionModel,
     logging,
 )
-from diffusers.models.attention_processor import AttnProcessor
 from diffusers.utils import load_numpy, nightly, slow, torch_device
 from diffusers.utils.testing_utils import CaptureLogger, require_torch_gpu

@@ -843,7 +842,7 @@ def test_stable_diffusion_pipeline_with_model_offloading(self):
             "CompVis/stable-diffusion-v1-4",
             torch_dtype=torch.float16,
         )
-        pipe.unet.set_attn_processor(AttnProcessor())
+        pipe.unet.set_default_attn_processor()
         pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
         outputs = pipe(**inputs)
@@ -856,7 +855,7 @@ def test_stable_diffusion_pipeline_with_model_offloading(self):
             "CompVis/stable-diffusion-v1-4",
             torch_dtype=torch.float16,
         )
-        pipe.unet.set_attn_processor(AttnProcessor())
+        pipe.unet.set_default_attn_processor()

         torch.cuda.empty_cache()
         torch.cuda.reset_max_memory_allocated()
tests/pipelines/stable_diffusion_2/test_stable_diffusion.py (2 additions & 3 deletions)

@@ -32,7 +32,6 @@
     UNet2DConditionModel,
     logging,
 )
-from diffusers.models.attention_processor import AttnProcessor
 from diffusers.utils import load_numpy, nightly, slow, torch_device
 from diffusers.utils.testing_utils import CaptureLogger, require_torch_gpu

@@ -410,7 +409,7 @@ def test_stable_diffusion_pipeline_with_model_offloading(self):
             "stabilityai/stable-diffusion-2-base",
             torch_dtype=torch.float16,
         )
-        pipe.unet.set_attn_processor(AttnProcessor())
+        pipe.unet.set_default_attn_processor()
         pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
         outputs = pipe(**inputs)
@@ -423,7 +422,7 @@ def test_stable_diffusion_pipeline_with_model_offloading(self):
             "stabilityai/stable-diffusion-2-base",
             torch_dtype=torch.float16,
         )
-        pipe.unet.set_attn_processor(AttnProcessor())
+        pipe.unet.set_default_attn_processor()

         torch.cuda.empty_cache()
         torch.cuda.reset_max_memory_allocated()
tests/test_modeling_common.py (8 additions & 9 deletions)

@@ -25,7 +25,6 @@
 from requests.exceptions import HTTPError

 from diffusers.models import UNet2DConditionModel
-from diffusers.models.attention_processor import AttnProcessor
 from diffusers.training_utils import EMAModel
 from diffusers.utils import torch_device

@@ -106,16 +105,16 @@ def test_from_save_pretrained(self):
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

         model = self.model_class(**init_dict)
-        if hasattr(model, "set_attn_processor"):
-            model.set_attn_processor(AttnProcessor())
+        if hasattr(model, "set_default_attn_processor"):
+            model.set_default_attn_processor()
         model.to(torch_device)
         model.eval()

         with tempfile.TemporaryDirectory() as tmpdirname:
             model.save_pretrained(tmpdirname)
             new_model = self.model_class.from_pretrained(tmpdirname)
-            if hasattr(new_model, "set_attn_processor"):
-                new_model.set_attn_processor(AttnProcessor())
+            if hasattr(new_model, "set_default_attn_processor"):
+                new_model.set_default_attn_processor()
             new_model.to(torch_device)

         with torch.no_grad():
@@ -135,16 +134,16 @@ def test_from_save_pretrained_variant(self):
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

         model = self.model_class(**init_dict)
-        if hasattr(model, "set_attn_processor"):
-            model.set_attn_processor(AttnProcessor())
+        if hasattr(model, "set_default_attn_processor"):
+            model.set_default_attn_processor()
         model.to(torch_device)
         model.eval()

         with tempfile.TemporaryDirectory() as tmpdirname:
             model.save_pretrained(tmpdirname, variant="fp16")
             new_model = self.model_class.from_pretrained(tmpdirname, variant="fp16")
-            if hasattr(new_model, "set_attn_processor"):
-                new_model.set_attn_processor(AttnProcessor())
+            if hasattr(new_model, "set_default_attn_processor"):
+                new_model.set_default_attn_processor()

             # non-variant cannot be loaded
             with self.assertRaises(OSError) as error_context:
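A closing note on the guard change in the common tests: keying the `hasattr` check on the new helper lets the file drop its `AttnProcessor` import entirely while still skipping model classes that expose no attention processors. The pattern in isolation (a sketch; `model` stands in for whichever class the suite instantiates):

# `model` is hypothetical here; the guard itself mirrors the tests above.
if hasattr(model, "set_default_attn_processor"):
    # Reset any custom processors so save/load comparisons run against the
    # stock attention implementation.
    model.set_default_attn_processor()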