Merged
Changes from all commits
61 commits
bf4e645
debug
sayakpaul Dec 1, 2023
afb517a
from step
sayakpaul Dec 1, 2023
55f1842
print
sayakpaul Dec 1, 2023
215bf3b
turn sigma a list
sayakpaul Dec 1, 2023
75ae3df
make str
sayakpaul Dec 1, 2023
ff04934
init_noise_sigma
sayakpaul Dec 1, 2023
096fffb
comment
sayakpaul Dec 1, 2023
bd855d7
remove prints
sayakpaul Dec 1, 2023
88c7e16
feat: introduce fused projections
sayakpaul Dec 1, 2023
f5b091d
change to a better name
sayakpaul Dec 1, 2023
a4da76b
no grad
sayakpaul Dec 1, 2023
c5a5f85
device.
sayakpaul Dec 1, 2023
4e556a9
device
sayakpaul Dec 1, 2023
86027e5
dtype
sayakpaul Dec 1, 2023
a030797
okay
sayakpaul Dec 1, 2023
01c6038
print
sayakpaul Dec 1, 2023
c4eaec3
more print
sayakpaul Dec 1, 2023
a7da467
fix: unbind -> split
sayakpaul Dec 1, 2023
94fb74a
fix: qkv >-> k
sayakpaul Dec 1, 2023
678577b
enable disable
sayakpaul Dec 1, 2023
580a1c2
apply attention processor within the method
sayakpaul Dec 1, 2023
06bb65b
attn processors
sayakpaul Dec 1, 2023
a0b9066
_enable_fused_qkv_projections
sayakpaul Dec 1, 2023
32012ce
remove print
sayakpaul Dec 1, 2023
5175b91
add fused projection to vae
sayakpaul Dec 1, 2023
7b16888
add todos.
sayakpaul Dec 2, 2023
ba14a08
merge main and resolve conflicts
sayakpaul Dec 2, 2023
23f8404
add: documentation and cleanups.
sayakpaul Dec 3, 2023
e51bc7e
add: test for qkv projection fusion.
sayakpaul Dec 3, 2023
b64e533
relax assertions.
sayakpaul Dec 3, 2023
c7f78bf
relax further
sayakpaul Dec 3, 2023
981dc3a
fix: docs
sayakpaul Dec 3, 2023
be647c3
Merge branch 'main' into sdxl/feat
sayakpaul Dec 3, 2023
0afc2b4
fix-copies
sayakpaul Dec 3, 2023
2c02f07
correct error message.
sayakpaul Dec 3, 2023
e0848eb
Empty-Commit
sayakpaul Dec 3, 2023
4b66d10
better conditioning on disable_fused_qkv_projections
sayakpaul Dec 4, 2023
6c5712c
Merge branch 'main' into sdxl/feat
sayakpaul Dec 4, 2023
8da35af
check
sayakpaul Dec 4, 2023
4e120d8
check processor
sayakpaul Dec 4, 2023
253aaf0
bfloat16 computation.
sayakpaul Dec 4, 2023
44d4263
check latent dtype
sayakpaul Dec 4, 2023
418d33c
style
sayakpaul Dec 4, 2023
8c0c3e2
remove copy temporarily
sayakpaul Dec 4, 2023
1688fee
cast latent to bfloat16
sayakpaul Dec 4, 2023
4f882ab
fix: vae -> self.vae
sayakpaul Dec 4, 2023
2632a8b
remove print.
sayakpaul Dec 4, 2023
d944d8b
add _change_to_group_norm_32
sayakpaul Dec 4, 2023
0432297
comment out stuff that didn't work
sayakpaul Dec 4, 2023
7d8b913
Apply suggestions from code review
sayakpaul Dec 4, 2023
ff28fdd
reflect patrick's suggestions.
sayakpaul Dec 4, 2023
93b5f92
fix imports
sayakpaul Dec 4, 2023
a7a952d
Merge branch 'main' into sdxl/feat
sayakpaul Dec 4, 2023
8d17831
fix: disable call.
sayakpaul Dec 4, 2023
d17bbbd
fix more
sayakpaul Dec 4, 2023
a5fb4d7
fix device and dtype
sayakpaul Dec 4, 2023
8fadb14
fix conditions.
sayakpaul Dec 4, 2023
c6d5e86
fix more
sayakpaul Dec 4, 2023
abf9ebc
Apply suggestions from code review
sayakpaul Dec 4, 2023
d485abd
Merge branch 'main' into sdxl/feat
sayakpaul Dec 4, 2023
e65ddcd
Merge branch 'main' into sdxl/feat
sayakpaul Dec 5, 2023
3 changes: 3 additions & 0 deletions docs/source/en/api/attnprocessor.md
@@ -20,6 +20,9 @@ An attention processor is a class for applying different types of attention mechanisms.
## AttnProcessor2_0
[[autodoc]] models.attention_processor.AttnProcessor2_0

## FusedAttnProcessor2_0
[[autodoc]] models.attention_processor.FusedAttnProcessor2_0

## LoRAAttnProcessor
[[autodoc]] models.attention_processor.LoRAAttnProcessor

130 changes: 127 additions & 3 deletions src/diffusers/models/attention_processor.py
@@ -113,12 +113,14 @@ def __init__(
):
super().__init__()
self.inner_dim = out_dim if out_dim is not None else dim_head * heads
self.query_dim = query_dim
self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
self.upcast_attention = upcast_attention
self.upcast_softmax = upcast_softmax
self.rescale_output_factor = rescale_output_factor
self.residual_connection = residual_connection
self.dropout = dropout
self.fused_projections = False
self.out_dim = out_dim if out_dim is not None else query_dim

# we make use of this private variable to know whether this class is loaded
@@ -180,6 +182,7 @@ def __init__(
else:
linear_cls = LoRACompatibleLinear

self.linear_cls = linear_cls
self.to_q = linear_cls(query_dim, self.inner_dim, bias=bias)

if not self.only_cross_attention:
@@ -692,6 +695,32 @@ def norm_encoder_hidden_states(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor:

return encoder_hidden_states

@torch.no_grad()
def fuse_projections(self, fuse=True):
is_cross_attention = self.cross_attention_dim != self.query_dim
device = self.to_q.weight.data.device
dtype = self.to_q.weight.data.dtype

if not is_cross_attention:
# fetch weight matrices.
concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data])
in_features = concatenated_weights.shape[1]
out_features = concatenated_weights.shape[0]

# create a new single projection layer and copy over the weights.
self.to_qkv = self.linear_cls(in_features, out_features, bias=False, device=device, dtype=dtype)
self.to_qkv.weight.copy_(concatenated_weights)

else:
concatenated_weights = torch.cat([self.to_k.weight.data, self.to_v.weight.data])
in_features = concatenated_weights.shape[1]
out_features = concatenated_weights.shape[0]

self.to_kv = self.linear_cls(in_features, out_features, bias=False, device=device, dtype=dtype)
self.to_kv.weight.copy_(concatenated_weights)

self.fused_projections = fuse


class AttnProcessor:
r"""
@@ -1184,9 +1213,6 @@ def __call__(
scale: float = 1.0,
) -> torch.FloatTensor:
residual = hidden_states

args = () if USE_PEFT_BACKEND else (scale,)

if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)

@@ -1253,6 +1279,103 @@ def __call__(
return hidden_states


class FusedAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
It uses fused projection layers. For self-attention modules, all projection matrices (i.e., query,
key, value) are fused. For cross-attention modules, key and value projection matrices are fused.

<Tip warning={true}>

This API is currently 🧪 experimental and may change in the future.

</Tip>
"""

def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(
"FusedAttnProcessor2_0 requires at least PyTorch 2.0, to use it. Please upgrade PyTorch to > 2.0."
)

def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
temb: Optional[torch.FloatTensor] = None,
scale: float = 1.0,
) -> torch.FloatTensor:
residual = hidden_states
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)

input_ndim = hidden_states.ndim

if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)

if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
# scaled_dot_product_attention expects attention_mask shape to be
# (batch, heads, source_length, target_length)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

args = () if USE_PEFT_BACKEND else (scale,)
if encoder_hidden_states is None:
qkv = attn.to_qkv(hidden_states, *args)
split_size = qkv.shape[-1] // 3
query, key, value = torch.split(qkv, split_size, dim=-1)
else:
if attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
query = attn.to_q(hidden_states, *args)

kv = attn.to_kv(encoder_hidden_states, *args)
split_size = kv.shape[-1] // 2
key, value = torch.split(kv, split_size, dim=-1)

inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads

query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)

hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)

# linear proj
hidden_states = attn.to_out[0](hidden_states, *args)
# dropout
hidden_states = attn.to_out[1](hidden_states)

if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

if attn.residual_connection:
hidden_states = hidden_states + residual

hidden_states = hidden_states / attn.rescale_output_factor

return hidden_states


class CustomDiffusionXFormersAttnProcessor(nn.Module):
r"""
Processor for implementing memory efficient attention using xFormers for the Custom Diffusion method.
@@ -2251,6 +2374,7 @@ def __call__(
AttentionProcessor = Union[
AttnProcessor,
AttnProcessor2_0,
FusedAttnProcessor2_0,
XFormersAttnProcessor,
SlicedAttnProcessor,
AttnAddedKVProcessor,
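The core of this change is that `Attention.fuse_projections` concatenates the separate `to_q`/`to_k`/`to_v` (or `to_k`/`to_v` for cross-attention) weight matrices into a single linear layer, and `FusedAttnProcessor2_0` recovers query, key, and value by splitting the fused output. Below is a minimal, self-contained sketch of that equivalence in plain PyTorch — not the diffusers classes, and with made-up shapes:

```python
import torch
import torch.nn as nn

# Hypothetical shapes; diffusers derives these from the attention config.
dim = 64
to_q, to_k, to_v = (nn.Linear(dim, dim, bias=False) for _ in range(3))

# Fuse: stack the three weight matrices into a single (3 * dim, dim) projection.
with torch.no_grad():
    to_qkv = nn.Linear(dim, 3 * dim, bias=False)
    to_qkv.weight.copy_(torch.cat([to_q.weight, to_k.weight, to_v.weight]))

x = torch.randn(2, 10, dim)  # (batch, sequence, dim)

# One matmul instead of three, then split the result back into q, k, v.
q, k, v = torch.split(to_qkv(x), dim, dim=-1)

# The fused path reproduces the separate projections.
assert torch.allclose(q, to_q(x), atol=1e-6)
assert torch.allclose(k, to_k(x), atol=1e-6)
assert torch.allclose(v, to_v(x), atol=1e-6)
```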
39 changes: 39 additions & 0 deletions src/diffusers/models/autoencoder_kl.py
@@ -22,6 +22,7 @@
from .attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS,
CROSS_ATTENTION_PROCESSORS,
Attention,
AttentionProcessor,
AttnAddedKVProcessor,
AttnProcessor,
@@ -448,3 +449,41 @@ def forward(
return (dec,)

return DecoderOutput(sample=dec)

# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
def fuse_qkv_projections(self):
"""
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
key, value) are fused. For cross-attention modules, key and value projection matrices are fused.

<Tip warning={true}>

This API is 🧪 experimental.

</Tip>
"""
self.original_attn_processors = None

for _, attn_processor in self.attn_processors.items():
if "Added" in str(attn_processor.__class__.__name__):
raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")

self.original_attn_processors = self.attn_processors

for module in self.modules():
if isinstance(module, Attention):
module.fuse_projections(fuse=True)

# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
def unfuse_qkv_projections(self):
"""Disables the fused QKV projection if enabled.

<Tip warning={true}>

This API is 🧪 experimental.

</Tip>

"""
if self.original_attn_processors is not None:
self.set_attn_processor(self.original_attn_processors)
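
A quick way to sanity-check the VAE-level API added above is to confirm that fusion replaces the separate projections with a single `to_qkv` layer on every self-attention module. A hedged usage sketch; the checkpoint id is only illustrative:

```python
from diffusers import AutoencoderKL
from diffusers.models.attention_processor import Attention, FusedAttnProcessor2_0

# Illustrative checkpoint; any AutoencoderKL checkpoint behaves the same way.
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")

vae.fuse_qkv_projections()
vae.set_attn_processor(FusedAttnProcessor2_0())

# VAE attention blocks are self-attention, so each should now carry a fused `to_qkv`.
assert all(hasattr(m, "to_qkv") for m in vae.modules() if isinstance(m, Attention))

# Restore the attention processors captured before fusion.
vae.unfuse_qkv_projections()
```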
37 changes: 37 additions & 0 deletions src/diffusers/models/unet_2d_condition.py
@@ -25,6 +25,7 @@
from .attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS,
CROSS_ATTENTION_PROCESSORS,
Attention,
AttentionProcessor,
AttnAddedKVProcessor,
AttnProcessor,
@@ -794,6 +795,42 @@ def disable_freeu(self):
if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None:
setattr(upsample_block, k, None)

def fuse_qkv_projections(self):
"""
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
key, value) are fused. For cross-attention modules, key and value projection matrices are fused.

<Tip warning={true}>

This API is 🧪 experimental.

</Tip>
"""
self.original_attn_processors = None

for _, attn_processor in self.attn_processors.items():
if "Added" in str(attn_processor.__class__.__name__):
raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")

self.original_attn_processors = self.attn_processors

for module in self.modules():
if isinstance(module, Attention):
module.fuse_projections(fuse=True)

def unfuse_qkv_projections(self):
"""Disables the fused QKV projection if enabled.

<Tip warning={true}>

This API is 🧪 experimental.

</Tip>

"""
if self.original_attn_processors is not None:
self.set_attn_processor(self.original_attn_processors)

def forward(
self,
sample: torch.FloatTensor,
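The UNet exposes the same pair of methods. Note that `fuse_qkv_projections()` only rewrites the weights; the pipeline (below) additionally swaps in `FusedAttnProcessor2_0` so the fused layers are actually used. A hedged sketch of doing both steps by hand; the checkpoint id is illustrative:

```python
import torch
from diffusers import UNet2DConditionModel
from diffusers.models.attention_processor import FusedAttnProcessor2_0

# Illustrative SDXL UNet checkpoint.
unet = UNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", torch_dtype=torch.float16
)

# Fuse the projection weights, then switch to the processor that expects them.
unet.fuse_qkv_projections()
unet.set_attn_processor(FusedAttnProcessor2_0())

# ... run the denoising loop ...

# Revert to the attention processors captured before fusion.
unet.unfuse_qkv_projections()
```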
@@ -34,6 +34,7 @@
from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel
from ...models.attention_processor import (
AttnProcessor2_0,
FusedAttnProcessor2_0,
LoRAAttnProcessor2_0,
LoRAXFormersAttnProcessor,
XFormersAttnProcessor,
@@ -681,7 +682,6 @@ def _get_add_time_ids(
add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
return add_time_ids

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae
def upcast_vae(self):
dtype = self.vae.dtype
self.vae.to(dtype=torch.float32)
@@ -692,6 +692,7 @@ def upcast_vae(self):
XFormersAttnProcessor,
LoRAXFormersAttnProcessor,
LoRAAttnProcessor2_0,
FusedAttnProcessor2_0,
),
)
# if xformers or torch_2_0 is used attention block does not need
@@ -729,6 +730,65 @@ def disable_freeu(self):
"""Disables the FreeU mechanism if enabled."""
self.unet.disable_freeu()

def fuse_qkv_projections(self, unet: bool = True, vae: bool = True):
"""
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
key, value) are fused. For cross-attention modules, key and value projection matrices are fused.

<Tip warning={true}>

This API is 🧪 experimental.

</Tip>

Args:
unet (`bool`, defaults to `True`): Whether to apply the fusion to the UNet.
vae (`bool`, defaults to `True`): Whether to apply the fusion to the VAE.
"""
self.fusing_unet = False
self.fusing_vae = False

if unet:
self.fusing_unet = True
self.unet.fuse_qkv_projections()
self.unet.set_attn_processor(FusedAttnProcessor2_0())

if vae:
if not isinstance(self.vae, AutoencoderKL):
raise ValueError("`fuse_qkv_projections()` is only supported for the VAE of type `AutoencoderKL`.")

self.fusing_vae = True
self.vae.fuse_qkv_projections()
self.vae.set_attn_processor(FusedAttnProcessor2_0())

def unfuse_qkv_projections(self, unet: bool = True, vae: bool = True):
"""Disable QKV projection fusion if enabled.

<Tip warning={true}>

This API is 🧪 experimental.

</Tip>

Args:
unet (`bool`, defaults to `True`): Whether to unfuse the QKV projections of the UNet.
vae (`bool`, defaults to `True`): Whether to unfuse the QKV projections of the VAE.

"""
if unet:
if not self.fusing_unet:
logger.warning("The UNet was not initially fused for QKV projections. Doing nothing.")
else:
self.unet.unfuse_qkv_projections()
self.fusing_unet = False

if vae:
if not self.fusing_vae:
logger.warning("The VAE was not initially fused for QKV projections. Doing nothing.")
else:
self.vae.unfuse_qkv_projections()
self.fusing_vae = False

# Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
"""
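The diff above does not show the file name, but judging from the imports it is an SDXL pipeline, so the class and checkpoint in this usage sketch are assumptions. At the pipeline level, `fuse_qkv_projections()` handles both the weight fusion and the processor swap for the UNet and the VAE:

```python
import torch
from diffusers import StableDiffusionXLPipeline  # assumed pipeline class

# Illustrative checkpoint; fp16 inference assumes a CUDA GPU is available.
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

# Fuse QKV projections in both the UNet and the VAE before running inference.
pipe.fuse_qkv_projections()
image = pipe("an astronaut riding a green horse", num_inference_steps=30).images[0]

# Restore the original, unfused projections when done.
pipe.unfuse_qkv_projections()
```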
@@ -24,6 +24,7 @@
from ...models import AutoencoderKL, UNet2DConditionModel
from ...models.attention_processor import (
AttnProcessor2_0,
FusedAttnProcessor2_0,
LoRAAttnProcessor2_0,
LoRAXFormersAttnProcessor,
XFormersAttnProcessor,
@@ -610,6 +611,7 @@ def upcast_vae(self):
XFormersAttnProcessor,
LoRAXFormersAttnProcessor,
LoRAAttnProcessor2_0,
FusedAttnProcessor2_0,
),
)
# if xformers or torch_2_0 is used attention block does not need