
Commit 2d94c78

[Core] feat: enable fused attention projections for other SD and SDXL pipelines (#6179)
* feat: enable fused attention projections for other SD and SDXL pipelines
* add: test for SD fused projections.
1 parent a81334e commit 2d94c78

File tree: 8 files changed (+461 lines, -0 lines)
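This commit adds a fuse_qkv_projections() / unfuse_qkv_projections() pair to each of the pipelines listed below. A minimal usage sketch (the checkpoint name, dtype, and device are illustrative, not part of the commit):

import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# Fuse the attention projection matrices in the UNet and the VAE (experimental API).
pipe.fuse_qkv_projections()
image = pipe("an astronaut riding a horse on the moon").images[0]

# Restore the original, unfused projections.
pipe.unfuse_qkv_projections()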


src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py

Lines changed: 60 additions & 0 deletions
@@ -23,6 +23,7 @@
 from ...image_processor import PipelineImageInput, VaeImageProcessor
 from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel
+from ...models.attention_processor import FusedAttnProcessor2_0
 from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
@@ -655,6 +656,65 @@ def disable_freeu(self):
         """Disables the FreeU mechanism if enabled."""
         self.unet.disable_freeu()

+    def fuse_qkv_projections(self, unet: bool = True, vae: bool = True):
+        """
+        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
+        are fused. For cross-attention modules, key and value projection matrices are fused.
+
+        <Tip warning={true}>
+
+        This API is 🧪 experimental.
+
+        </Tip>
+
+        Args:
+            unet (`bool`, defaults to `True`): To apply fusion on the UNet.
+            vae (`bool`, defaults to `True`): To apply fusion on the VAE.
+        """
+        self.fusing_unet = False
+        self.fusing_vae = False
+
+        if unet:
+            self.fusing_unet = True
+            self.unet.fuse_qkv_projections()
+            self.unet.set_attn_processor(FusedAttnProcessor2_0())
+
+        if vae:
+            if not isinstance(self.vae, AutoencoderKL):
+                raise ValueError("`fuse_qkv_projections()` is only supported for the VAE of type `AutoencoderKL`.")
+
+            self.fusing_vae = True
+            self.vae.fuse_qkv_projections()
+            self.vae.set_attn_processor(FusedAttnProcessor2_0())
+
+    def unfuse_qkv_projections(self, unet: bool = True, vae: bool = True):
+        """Disable QKV projection fusion if enabled.
+
+        <Tip warning={true}>
+
+        This API is 🧪 experimental.
+
+        </Tip>
+
+        Args:
+            unet (`bool`, defaults to `True`): To apply fusion on the UNet.
+            vae (`bool`, defaults to `True`): To apply fusion on the VAE.
+
+        """
+        if unet:
+            if not self.fusing_unet:
+                logger.warning("The UNet was not initially fused for QKV projections. Doing nothing.")
+            else:
+                self.unet.unfuse_qkv_projections()
+                self.fusing_unet = False
+
+        if vae:
+            if not self.fusing_vae:
+                logger.warning("The VAE was not initially fused for QKV projections. Doing nothing.")
+            else:
+                self.vae.unfuse_qkv_projections()
+                self.fusing_vae = False
+
     def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
         """
         See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
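For reference, "fusing" here means replacing the separate query/key/value linear projections with a single projection whose weight is their concatenation, so one matmul produces all three tensors. A standalone illustration of the idea (this is not diffusers' internal implementation; the dimensions and layer names are made up):

import torch
import torch.nn as nn

dim = 320
to_q = nn.Linear(dim, dim, bias=False)
to_k = nn.Linear(dim, dim, bias=False)
to_v = nn.Linear(dim, dim, bias=False)

# One fused projection whose weight stacks the three matrices.
to_qkv = nn.Linear(dim, 3 * dim, bias=False)
with torch.no_grad():
    to_qkv.weight.copy_(torch.cat([to_q.weight, to_k.weight, to_v.weight], dim=0))

x = torch.randn(2, 77, dim)
q, k, v = to_qkv(x).chunk(3, dim=-1)  # one matmul instead of three
assert torch.allclose(q, to_q(x), atol=1e-5)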

src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py

Lines changed: 60 additions & 0 deletions
@@ -25,6 +25,7 @@
 from ...image_processor import PipelineImageInput, VaeImageProcessor
 from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel
+from ...models.attention_processor import FusedAttnProcessor2_0
 from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
@@ -715,6 +716,65 @@ def disable_freeu(self):
         """Disables the FreeU mechanism if enabled."""
         self.unet.disable_freeu()

+    def fuse_qkv_projections(self, unet: bool = True, vae: bool = True):
+        """
+        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
+        are fused. For cross-attention modules, key and value projection matrices are fused.
+
+        <Tip warning={true}>
+
+        This API is 🧪 experimental.
+
+        </Tip>
+
+        Args:
+            unet (`bool`, defaults to `True`): To apply fusion on the UNet.
+            vae (`bool`, defaults to `True`): To apply fusion on the VAE.
+        """
+        self.fusing_unet = False
+        self.fusing_vae = False
+
+        if unet:
+            self.fusing_unet = True
+            self.unet.fuse_qkv_projections()
+            self.unet.set_attn_processor(FusedAttnProcessor2_0())
+
+        if vae:
+            if not isinstance(self.vae, AutoencoderKL):
+                raise ValueError("`fuse_qkv_projections()` is only supported for the VAE of type `AutoencoderKL`.")
+
+            self.fusing_vae = True
+            self.vae.fuse_qkv_projections()
+            self.vae.set_attn_processor(FusedAttnProcessor2_0())
+
+    def unfuse_qkv_projections(self, unet: bool = True, vae: bool = True):
+        """Disable QKV projection fusion if enabled.
+
+        <Tip warning={true}>
+
+        This API is 🧪 experimental.
+
+        </Tip>
+
+        Args:
+            unet (`bool`, defaults to `True`): To apply fusion on the UNet.
+            vae (`bool`, defaults to `True`): To apply fusion on the VAE.
+
+        """
+        if unet:
+            if not self.fusing_unet:
+                logger.warning("The UNet was not initially fused for QKV projections. Doing nothing.")
+            else:
+                self.unet.unfuse_qkv_projections()
+                self.fusing_unet = False
+
+        if vae:
+            if not self.fusing_vae:
+                logger.warning("The VAE was not initially fused for QKV projections. Doing nothing.")
+            else:
+                self.vae.unfuse_qkv_projections()
+                self.fusing_vae = False
+
     def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
         """
         See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py

Lines changed: 62 additions & 0 deletions
@@ -23,6 +23,7 @@
 from ...image_processor import PipelineImageInput, VaeImageProcessor
 from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel
+from ...models.attention_processor import FusedAttnProcessor2_0
 from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
@@ -650,6 +651,67 @@ def disable_freeu(self):
         """Disables the FreeU mechanism if enabled."""
         self.unet.disable_freeu()

+    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.fuse_qkv_projections
+    def fuse_qkv_projections(self, unet: bool = True, vae: bool = True):
+        """
+        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
+        are fused. For cross-attention modules, key and value projection matrices are fused.
+
+        <Tip warning={true}>
+
+        This API is 🧪 experimental.
+
+        </Tip>
+
+        Args:
+            unet (`bool`, defaults to `True`): To apply fusion on the UNet.
+            vae (`bool`, defaults to `True`): To apply fusion on the VAE.
+        """
+        self.fusing_unet = False
+        self.fusing_vae = False
+
+        if unet:
+            self.fusing_unet = True
+            self.unet.fuse_qkv_projections()
+            self.unet.set_attn_processor(FusedAttnProcessor2_0())
+
+        if vae:
+            if not isinstance(self.vae, AutoencoderKL):
+                raise ValueError("`fuse_qkv_projections()` is only supported for the VAE of type `AutoencoderKL`.")
+
+            self.fusing_vae = True
+            self.vae.fuse_qkv_projections()
+            self.vae.set_attn_processor(FusedAttnProcessor2_0())
+
+    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.unfuse_qkv_projections
+    def unfuse_qkv_projections(self, unet: bool = True, vae: bool = True):
+        """Disable QKV projection fusion if enabled.
+
+        <Tip warning={true}>
+
+        This API is 🧪 experimental.
+
+        </Tip>
+
+        Args:
+            unet (`bool`, defaults to `True`): To apply fusion on the UNet.
+            vae (`bool`, defaults to `True`): To apply fusion on the VAE.
+
+        """
+        if unet:
+            if not self.fusing_unet:
+                logger.warning("The UNet was not initially fused for QKV projections. Doing nothing.")
+            else:
+                self.unet.unfuse_qkv_projections()
+                self.fusing_unet = False
+
+        if vae:
+            if not self.fusing_vae:
+                logger.warning("The VAE was not initially fused for QKV projections. Doing nothing.")
+            else:
+                self.vae.unfuse_qkv_projections()
+                self.fusing_vae = False
+
     # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
     def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
         """

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py

Lines changed: 62 additions & 0 deletions
@@ -25,6 +25,7 @@
 from ...image_processor import PipelineImageInput, VaeImageProcessor
 from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel
+from ...models.attention_processor import FusedAttnProcessor2_0
 from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
@@ -718,6 +719,67 @@ def disable_freeu(self):
         """Disables the FreeU mechanism if enabled."""
         self.unet.disable_freeu()

+    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.fuse_qkv_projections
+    def fuse_qkv_projections(self, unet: bool = True, vae: bool = True):
+        """
+        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
+        are fused. For cross-attention modules, key and value projection matrices are fused.
+
+        <Tip warning={true}>
+
+        This API is 🧪 experimental.
+
+        </Tip>
+
+        Args:
+            unet (`bool`, defaults to `True`): To apply fusion on the UNet.
+            vae (`bool`, defaults to `True`): To apply fusion on the VAE.
+        """
+        self.fusing_unet = False
+        self.fusing_vae = False
+
+        if unet:
+            self.fusing_unet = True
+            self.unet.fuse_qkv_projections()
+            self.unet.set_attn_processor(FusedAttnProcessor2_0())
+
+        if vae:
+            if not isinstance(self.vae, AutoencoderKL):
+                raise ValueError("`fuse_qkv_projections()` is only supported for the VAE of type `AutoencoderKL`.")
+
+            self.fusing_vae = True
+            self.vae.fuse_qkv_projections()
+            self.vae.set_attn_processor(FusedAttnProcessor2_0())
+
+    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.unfuse_qkv_projections
+    def unfuse_qkv_projections(self, unet: bool = True, vae: bool = True):
+        """Disable QKV projection fusion if enabled.
+
+        <Tip warning={true}>
+
+        This API is 🧪 experimental.
+
+        </Tip>
+
+        Args:
+            unet (`bool`, defaults to `True`): To apply fusion on the UNet.
+            vae (`bool`, defaults to `True`): To apply fusion on the VAE.
+
+        """
+        if unet:
+            if not self.fusing_unet:
+                logger.warning("The UNet was not initially fused for QKV projections. Doing nothing.")
+            else:
+                self.unet.unfuse_qkv_projections()
+                self.fusing_unet = False
+
+        if vae:
+            if not self.fusing_vae:
+                logger.warning("The VAE was not initially fused for QKV projections. Doing nothing.")
+            else:
+                self.vae.unfuse_qkv_projections()
+                self.fusing_vae = False
+
     # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
     def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
         """

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py

Lines changed: 62 additions & 0 deletions
@@ -25,6 +25,7 @@
 from ...image_processor import PipelineImageInput, VaeImageProcessor
 from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AsymmetricAutoencoderKL, AutoencoderKL, ImageProjection, UNet2DConditionModel
+from ...models.attention_processor import FusedAttnProcessor2_0
 from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
@@ -844,6 +845,67 @@ def disable_freeu(self):
         """Disables the FreeU mechanism if enabled."""
         self.unet.disable_freeu()

+    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.fuse_qkv_projections
+    def fuse_qkv_projections(self, unet: bool = True, vae: bool = True):
+        """
+        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
+        are fused. For cross-attention modules, key and value projection matrices are fused.
+
+        <Tip warning={true}>
+
+        This API is 🧪 experimental.
+
+        </Tip>
+
+        Args:
+            unet (`bool`, defaults to `True`): To apply fusion on the UNet.
+            vae (`bool`, defaults to `True`): To apply fusion on the VAE.
+        """
+        self.fusing_unet = False
+        self.fusing_vae = False
+
+        if unet:
+            self.fusing_unet = True
+            self.unet.fuse_qkv_projections()
+            self.unet.set_attn_processor(FusedAttnProcessor2_0())
+
+        if vae:
+            if not isinstance(self.vae, AutoencoderKL):
+                raise ValueError("`fuse_qkv_projections()` is only supported for the VAE of type `AutoencoderKL`.")
+
+            self.fusing_vae = True
+            self.vae.fuse_qkv_projections()
+            self.vae.set_attn_processor(FusedAttnProcessor2_0())
+
+    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.unfuse_qkv_projections
+    def unfuse_qkv_projections(self, unet: bool = True, vae: bool = True):
+        """Disable QKV projection fusion if enabled.
+
+        <Tip warning={true}>
+
+        This API is 🧪 experimental.
+
+        </Tip>
+
+        Args:
+            unet (`bool`, defaults to `True`): To apply fusion on the UNet.
+            vae (`bool`, defaults to `True`): To apply fusion on the VAE.
+
+        """
+        if unet:
+            if not self.fusing_unet:
+                logger.warning("The UNet was not initially fused for QKV projections. Doing nothing.")
+            else:
+                self.unet.unfuse_qkv_projections()
+                self.fusing_unet = False
+
+        if vae:
+            if not self.fusing_vae:
+                logger.warning("The VAE was not initially fused for QKV projections. Doing nothing.")
+            else:
+                self.vae.unfuse_qkv_projections()
+                self.fusing_vae = False
+
     # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
     def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
         """
