
Commit a2bc2e1

[feat] allow SDXL pipeline to run with fused QKV projections (#6030)
* debug
* from step
* print
* turn sigma a list
* make str
* init_noise_sigma
* comment
* remove prints
* feat: introduce fused projections
* change to a better name
* no grad
* device.
* device
* dtype
* okay
* print
* more print
* fix: unbind -> split
* fix: qkv >-> k
* enable disable
* apply attention processor within the method
* attn processors
* _enable_fused_qkv_projections
* remove print
* add fused projection to vae
* add todos.
* add: documentation and cleanups.
* add: test for qkv projection fusion.
* relax assertions.
* relax further
* fix: docs
* fix-copies
* correct error message.
* Empty-Commit
* better conditioning on disable_fused_qkv_projections
* check
* check processor
* bfloat16 computation.
* check latent dtype
* style
* remove copy temporarily
* cast latent to bfloat16
* fix: vae -> self.vae
* remove print.
* add _change_to_group_norm_32
* comment out stuff that didn't work
* Apply suggestions from code review
  Co-authored-by: Patrick von Platen <[email protected]>
* reflect patrick's suggestions.
* fix imports
* fix: disable call.
* fix more
* fix device and dtype
* fix conditions.
* fix more
* Apply suggestions from code review
  Co-authored-by: Patrick von Platen <[email protected]>

---------

Co-authored-by: Patrick von Platen <[email protected]>
1 parent f427345 commit a2bc2e1

9 files changed (+342, -7 lines)

docs/source/en/api/attnprocessor.md

Lines changed: 3 additions & 0 deletions
@@ -20,6 +20,9 @@ An attention processor is a class for applying different types of attention mechanisms.
 ## AttnProcessor2_0
 [[autodoc]] models.attention_processor.AttnProcessor2_0
 
+## FusedAttnProcessor2_0
+[[autodoc]] models.attention_processor.FusedAttnProcessor2_0
+
 ## LoRAAttnProcessor
 [[autodoc]] models.attention_processor.LoRAAttnProcessor
 

src/diffusers/models/attention_processor.py

Lines changed: 127 additions & 3 deletions
@@ -113,12 +113,14 @@ def __init__(
     ):
         super().__init__()
         self.inner_dim = out_dim if out_dim is not None else dim_head * heads
+        self.query_dim = query_dim
         self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
         self.upcast_attention = upcast_attention
         self.upcast_softmax = upcast_softmax
         self.rescale_output_factor = rescale_output_factor
         self.residual_connection = residual_connection
         self.dropout = dropout
+        self.fused_projections = False
         self.out_dim = out_dim if out_dim is not None else query_dim
 
         # we make use of this private variable to know whether this class is loaded
@@ -180,6 +182,7 @@ def __init__(
         else:
             linear_cls = LoRACompatibleLinear
 
+        self.linear_cls = linear_cls
         self.to_q = linear_cls(query_dim, self.inner_dim, bias=bias)
 
         if not self.only_cross_attention:
@@ -692,6 +695,32 @@ def norm_encoder_hidden_states(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
 
         return encoder_hidden_states
 
+    @torch.no_grad()
+    def fuse_projections(self, fuse=True):
+        is_cross_attention = self.cross_attention_dim != self.query_dim
+        device = self.to_q.weight.data.device
+        dtype = self.to_q.weight.data.dtype
+
+        if not is_cross_attention:
+            # fetch weight matrices.
+            concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data])
+            in_features = concatenated_weights.shape[1]
+            out_features = concatenated_weights.shape[0]
+
+            # create a new single projection layer and copy over the weights.
+            self.to_qkv = self.linear_cls(in_features, out_features, bias=False, device=device, dtype=dtype)
+            self.to_qkv.weight.copy_(concatenated_weights)
+
+        else:
+            concatenated_weights = torch.cat([self.to_k.weight.data, self.to_v.weight.data])
+            in_features = concatenated_weights.shape[1]
+            out_features = concatenated_weights.shape[0]
+
+            self.to_kv = self.linear_cls(in_features, out_features, bias=False, device=device, dtype=dtype)
+            self.to_kv.weight.copy_(concatenated_weights)
+
+        self.fused_projections = fuse
+
 
 class AttnProcessor:
     r"""
@@ -1184,9 +1213,6 @@ def __call__(
         scale: float = 1.0,
     ) -> torch.FloatTensor:
         residual = hidden_states
-
-        args = () if USE_PEFT_BACKEND else (scale,)
-
         if attn.spatial_norm is not None:
             hidden_states = attn.spatial_norm(hidden_states, temb)
 
@@ -1253,6 +1279,103 @@ def __call__(
         return hidden_states
 
 
+class FusedAttnProcessor2_0:
+    r"""
+    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
+    It uses fused projection layers. For self-attention modules, all projection matrices (i.e., query,
+    key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
+
+    <Tip warning={true}>
+
+    This API is currently 🧪 experimental in nature and can change in future.
+
+    </Tip>
+    """
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError(
+                "FusedAttnProcessor2_0 requires at least PyTorch 2.0, to use it. Please upgrade PyTorch to > 2.0."
+            )
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        temb: Optional[torch.FloatTensor] = None,
+        scale: float = 1.0,
+    ) -> torch.FloatTensor:
+        residual = hidden_states
+        if attn.spatial_norm is not None:
+            hidden_states = attn.spatial_norm(hidden_states, temb)
+
+        input_ndim = hidden_states.ndim
+
+        if input_ndim == 4:
+            batch_size, channel, height, width = hidden_states.shape
+            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
+
+        if attention_mask is not None:
+            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+            # scaled_dot_product_attention expects attention_mask shape to be
+            # (batch, heads, source_length, target_length)
+            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+        if attn.group_norm is not None:
+            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+        args = () if USE_PEFT_BACKEND else (scale,)
+        if encoder_hidden_states is None:
+            qkv = attn.to_qkv(hidden_states, *args)
+            split_size = qkv.shape[-1] // 3
+            query, key, value = torch.split(qkv, split_size, dim=-1)
+        else:
+            if attn.norm_cross:
+                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+            query = attn.to_q(hidden_states, *args)
+
+            kv = attn.to_kv(encoder_hidden_states, *args)
+            split_size = kv.shape[-1] // 2
+            key, value = torch.split(kv, split_size, dim=-1)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        # the output of sdp = (batch, num_heads, seq_len, head_dim)
+        # TODO: add support for attn.scale when we move to Torch 2.1
+        hidden_states = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        )
+
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states, *args)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        if input_ndim == 4:
+            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+        if attn.residual_connection:
+            hidden_states = hidden_states + residual
+
+        hidden_states = hidden_states / attn.rescale_output_factor
+
+        return hidden_states
+
+
 class CustomDiffusionXFormersAttnProcessor(nn.Module):
     r"""
     Processor for implementing memory efficient attention using xFormers for the Custom Diffusion method.
@@ -2251,6 +2374,7 @@ def __call__(
 AttentionProcessor = Union[
     AttnProcessor,
     AttnProcessor2_0,
+    FusedAttnProcessor2_0,
     XFormersAttnProcessor,
     SlicedAttnProcessor,
     AttnAddedKVProcessor,
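
The core trick in `Attention.fuse_projections` above is a pure re-parameterization: `nn.Linear` stores its weight as `(out_features, in_features)` and computes `x @ W.T`, so concatenating the separate projection weights along the output dimension and splitting the fused output reproduces the individual projections exactly. A minimal standalone sketch of that equivalence (plain `nn.Linear` layers and a made-up width, not the diffusers classes):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
dim = 64  # hypothetical model width, for illustration only

# three separate, bias-free projections, analogous to Attention.to_q / to_k / to_v
to_q, to_k, to_v = (nn.Linear(dim, dim, bias=False) for _ in range(3))

# fuse: stack the three weight matrices along the output dimension
to_qkv = nn.Linear(dim, 3 * dim, bias=False)
with torch.no_grad():
    to_qkv.weight.copy_(torch.cat([to_q.weight, to_k.weight, to_v.weight]))

x = torch.randn(2, 16, dim)  # (batch, sequence_length, dim)
q, k, v = torch.split(to_qkv(x), dim, dim=-1)

# one larger matmul gives the same result as three smaller ones
print(torch.allclose(q, to_q(x), atol=1e-6))  # True
print(torch.allclose(k, to_k(x), atol=1e-6))  # True
print(torch.allclose(v, to_v(x), atol=1e-6))  # True
```

The benefit is in launching one larger GEMM instead of three smaller ones; `FusedAttnProcessor2_0` only changes how Q, K, and V are produced and leaves the attention math itself untouched.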

src/diffusers/models/autoencoder_kl.py

Lines changed: 39 additions & 0 deletions
@@ -22,6 +22,7 @@
 from .attention_processor import (
     ADDED_KV_ATTENTION_PROCESSORS,
     CROSS_ATTENTION_PROCESSORS,
+    Attention,
     AttentionProcessor,
     AttnAddedKVProcessor,
     AttnProcessor,
@@ -448,3 +449,41 @@ def forward(
             return (dec,)
 
         return DecoderOutput(sample=dec)
+
+    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
+    def fuse_qkv_projections(self):
+        """
+        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
+        key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
+
+        <Tip warning={true}>
+
+        This API is 🧪 experimental.
+
+        </Tip>
+        """
+        self.original_attn_processors = None
+
+        for _, attn_processor in self.attn_processors.items():
+            if "Added" in str(attn_processor.__class__.__name__):
+                raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
+
+        self.original_attn_processors = self.attn_processors
+
+        for module in self.modules():
+            if isinstance(module, Attention):
+                module.fuse_projections(fuse=True)
+
+    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
+    def unfuse_qkv_projections(self):
+        """Disables the fused QKV projection if enabled.
+
+        <Tip warning={true}>
+
+        This API is 🧪 experimental.
+
+        </Tip>
+
+        """
+        if self.original_attn_processors is not None:
+            self.set_attn_processor(self.original_attn_processors)

src/diffusers/models/unet_2d_condition.py

Lines changed: 37 additions & 0 deletions
@@ -25,6 +25,7 @@
 from .attention_processor import (
     ADDED_KV_ATTENTION_PROCESSORS,
     CROSS_ATTENTION_PROCESSORS,
+    Attention,
     AttentionProcessor,
     AttnAddedKVProcessor,
     AttnProcessor,
@@ -794,6 +795,42 @@ def disable_freeu(self):
                 if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None:
                     setattr(upsample_block, k, None)
 
+    def fuse_qkv_projections(self):
+        """
+        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
+        key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
+
+        <Tip warning={true}>
+
+        This API is 🧪 experimental.
+
+        </Tip>
+        """
+        self.original_attn_processors = None
+
+        for _, attn_processor in self.attn_processors.items():
+            if "Added" in str(attn_processor.__class__.__name__):
+                raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
+
+        self.original_attn_processors = self.attn_processors
+
+        for module in self.modules():
+            if isinstance(module, Attention):
+                module.fuse_projections(fuse=True)
+
+    def unfuse_qkv_projections(self):
+        """Disables the fused QKV projection if enabled.
+
+        <Tip warning={true}>
+
+        This API is 🧪 experimental.
+
+        </Tip>
+
+        """
+        if self.original_attn_processors is not None:
+            self.set_attn_processor(self.original_attn_processors)
+
     def forward(
         self,
         sample: torch.FloatTensor,
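
`UNet2DConditionModel` (above) and `AutoencoderKL` now expose the same `fuse_qkv_projections()` / `unfuse_qkv_projections()` pair; fusing only rewrites the weights, so an attention processor that reads `to_qkv` / `to_kv` still has to be set for the fused path to be used. A rough model-level sketch — the tiny UNet config below is invented so the example is cheap to run, and PyTorch 2.0+ is assumed for `FusedAttnProcessor2_0`:

```python
import torch
from diffusers import UNet2DConditionModel
from diffusers.models.attention_processor import FusedAttnProcessor2_0

# hypothetical tiny config, for illustration only
unet = UNet2DConditionModel(
    sample_size=32,
    in_channels=4,
    out_channels=4,
    layers_per_block=1,
    block_out_channels=(32, 64),
    down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
    up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
    cross_attention_dim=32,
).eval()

# fuse the projection weights, then switch to the processor that consumes them
unet.fuse_qkv_projections()
unet.set_attn_processor(FusedAttnProcessor2_0())

sample = torch.randn(1, 4, 32, 32)
encoder_hidden_states = torch.randn(1, 77, 32)
with torch.no_grad():
    out = unet(sample, timestep=1, encoder_hidden_states=encoder_hidden_states).sample
print(out.shape)  # torch.Size([1, 4, 32, 32])

# restore the original (unfused) attention processors
unet.unfuse_qkv_projections()
```

This two-step pattern (fuse the weights, then set `FusedAttnProcessor2_0`) is exactly what the SDXL pipeline helper below wraps for both the UNet and the VAE.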

src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py

Lines changed: 61 additions & 1 deletion
@@ -34,6 +34,7 @@
 from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel
 from ...models.attention_processor import (
     AttnProcessor2_0,
+    FusedAttnProcessor2_0,
     LoRAAttnProcessor2_0,
     LoRAXFormersAttnProcessor,
     XFormersAttnProcessor,
@@ -681,7 +682,6 @@ def _get_add_time_ids(
         add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
         return add_time_ids
 
-    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae
     def upcast_vae(self):
         dtype = self.vae.dtype
         self.vae.to(dtype=torch.float32)
@@ -692,6 +692,7 @@ def upcast_vae(self):
                 XFormersAttnProcessor,
                 LoRAXFormersAttnProcessor,
                 LoRAAttnProcessor2_0,
+                FusedAttnProcessor2_0,
             ),
         )
         # if xformers or torch_2_0 is used attention block does not need
@@ -729,6 +730,65 @@ def disable_freeu(self):
         """Disables the FreeU mechanism if enabled."""
         self.unet.disable_freeu()
 
+    def fuse_qkv_projections(self, unet: bool = True, vae: bool = True):
+        """
+        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
+        key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
+
+        <Tip warning={true}>
+
+        This API is 🧪 experimental.
+
+        </Tip>
+
+        Args:
+            unet (`bool`, defaults to `True`): To apply fusion on the UNet.
+            vae (`bool`, defaults to `True`): To apply fusion on the VAE.
+        """
+        self.fusing_unet = False
+        self.fusing_vae = False
+
+        if unet:
+            self.fusing_unet = True
+            self.unet.fuse_qkv_projections()
+            self.unet.set_attn_processor(FusedAttnProcessor2_0())
+
+        if vae:
+            if not isinstance(self.vae, AutoencoderKL):
+                raise ValueError("`fuse_qkv_projections()` is only supported for the VAE of type `AutoencoderKL`.")
+
+            self.fusing_vae = True
+            self.vae.fuse_qkv_projections()
+            self.vae.set_attn_processor(FusedAttnProcessor2_0())
+
+    def unfuse_qkv_projections(self, unet: bool = True, vae: bool = True):
+        """Disable QKV projection fusion if enabled.
+
+        <Tip warning={true}>
+
+        This API is 🧪 experimental.
+
+        </Tip>
+
+        Args:
+            unet (`bool`, defaults to `True`): To apply fusion on the UNet.
+            vae (`bool`, defaults to `True`): To apply fusion on the VAE.
+
+        """
+        if unet:
+            if not self.fusing_unet:
+                logger.warning("The UNet was not initially fused for QKV projections. Doing nothing.")
+            else:
+                self.unet.unfuse_qkv_projections()
+                self.fusing_unet = False
+
+        if vae:
+            if not self.fusing_vae:
+                logger.warning("The VAE was not initially fused for QKV projections. Doing nothing.")
+            else:
+                self.vae.unfuse_qkv_projections()
+                self.fusing_vae = False
+
     # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
     def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
         """

src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py

Lines changed: 2 additions & 0 deletions
@@ -24,6 +24,7 @@
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...models.attention_processor import (
     AttnProcessor2_0,
+    FusedAttnProcessor2_0,
     LoRAAttnProcessor2_0,
     LoRAXFormersAttnProcessor,
     XFormersAttnProcessor,
@@ -610,6 +611,7 @@ def upcast_vae(self):
                 XFormersAttnProcessor,
                 LoRAXFormersAttnProcessor,
                 LoRAAttnProcessor2_0,
+                FusedAttnProcessor2_0,
             ),
         )
         # if xformers or torch_2_0 is used attention block does not need
