2222from ..configuration_utils import ConfigMixin , register_to_config
2323from ..utils import BaseOutput , is_torch_version , logging
2424from ..utils .torch_utils import apply_freeu
25- from .attention_processor import Attention , AttentionProcessor
25+ from .attention_processor import (
26+ ADDED_KV_ATTENTION_PROCESSORS ,
27+ CROSS_ATTENTION_PROCESSORS ,
28+ Attention ,
29+ AttentionProcessor ,
30+ AttnAddedKVProcessor ,
31+ AttnProcessor ,
32+ )
2633from .controlnet import ControlNetConditioningEmbedding
2734from .embeddings import TimestepEmbedding , Timesteps
2835from .modeling_utils import ModelMixin
@@ -869,7 +876,7 @@ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors:
869876
870877 return processors
871878
872- # copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel
879+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
873880 def set_attn_processor (self , processor : Union [AttentionProcessor , Dict [str , AttentionProcessor ]]):
874881 r"""
875882 Sets the attention processor to use to compute attention.
@@ -904,7 +911,23 @@ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
904911 for name , module in self .named_children ():
905912 fn_recursive_attn_processor (name , module , processor )
906913
907- # copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
def set_default_attn_processor(self):
    """
    Disables custom attention processors and sets the default attention implementation.

    Picks the replacement processor family based on what is currently installed on the
    model: `AttnAddedKVProcessor` when every processor is an added-KV variant,
    `AttnProcessor` when every processor is a cross-attention variant.

    Raises:
        ValueError: If the installed processors mix both families (or belong to neither),
            since no single default processor can replace them all.
    """
    # All processors must belong to one family; membership is checked by exact class.
    if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
        processor = AttnAddedKVProcessor()
    elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
        processor = AttnProcessor()
    else:
        raise ValueError(
            f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
        )

    self.set_attn_processor(processor)
930+ # copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.enable_freeu
908931 def enable_freeu (self , s1 : float , s2 : float , b1 : float , b2 : float ):
909932 r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
910933
@@ -929,7 +952,7 @@ def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
929952 setattr (upsample_block , "b1" , b1 )
930953 setattr (upsample_block , "b2" , b2 )
931954
932- # copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel
955+ # copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.disable_freeu
933956 def disable_freeu (self ):
934957 """Disables the FreeU mechanism."""
935958 freeu_keys = {"s1" , "s2" , "b1" , "b2" }
@@ -938,7 +961,7 @@ def disable_freeu(self):
938961 if hasattr (upsample_block , k ) or getattr (upsample_block , k , None ) is not None :
939962 setattr (upsample_block , k , None )
940963
941- # copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel
964+ # copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
942965 def fuse_qkv_projections (self ):
943966 """
944967 Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
@@ -962,7 +985,7 @@ def fuse_qkv_projections(self):
962985 if isinstance (module , Attention ):
963986 module .fuse_projections (fuse = True )
964987
965- # copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel
988+ # copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
966989 def unfuse_qkv_projections (self ):
967990 """Disables the fused QKV projection if enabled.
968991
0 commit comments