diff --git a/docs/source/en/api/attnprocessor.md b/docs/source/en/api/attnprocessor.md
index 7a4812e0961e..0b11c1f5bc5d 100644
--- a/docs/source/en/api/attnprocessor.md
+++ b/docs/source/en/api/attnprocessor.md
@@ -17,6 +17,9 @@ An attention processor is a class for applying different types of attention mech
 ## CustomDiffusionAttnProcessor
 [[autodoc]] models.attention_processor.CustomDiffusionAttnProcessor
 
+## CustomDiffusionAttnProcessor2_0
+[[autodoc]] models.attention_processor.CustomDiffusionAttnProcessor2_0
+
 ## AttnAddedKVProcessor
 [[autodoc]] models.attention_processor.AttnAddedKVProcessor
 
@@ -39,4 +42,4 @@ An attention processor is a class for applying different types of attention mech
 [[autodoc]] models.attention_processor.SlicedAttnProcessor
 
 ## SlicedAttnAddedKVProcessor
-[[autodoc]] models.attention_processor.SlicedAttnAddedKVProcessor
\ No newline at end of file
+[[autodoc]] models.attention_processor.SlicedAttnAddedKVProcessor
diff --git a/examples/custom_diffusion/train_custom_diffusion.py b/examples/custom_diffusion/train_custom_diffusion.py
index 0d5cf695bd4f..60d8d6723dcf 100644
--- a/examples/custom_diffusion/train_custom_diffusion.py
+++ b/examples/custom_diffusion/train_custom_diffusion.py
@@ -51,7 +51,11 @@
     UNet2DConditionModel,
 )
 from diffusers.loaders import AttnProcsLayers
-from diffusers.models.attention_processor import CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor
+from diffusers.models.attention_processor import (
+    CustomDiffusionAttnProcessor,
+    CustomDiffusionAttnProcessor2_0,
+    CustomDiffusionXFormersAttnProcessor,
+)
 from diffusers.optimization import get_scheduler
 from diffusers.utils import check_min_version, is_wandb_available
 from diffusers.utils.import_utils import is_xformers_available
@@ -870,7 +874,9 @@ def main(args):
     unet.to(accelerator.device, dtype=weight_dtype)
     vae.to(accelerator.device, dtype=weight_dtype)
 
-    attention_class = CustomDiffusionAttnProcessor
+    attention_class = (
+        CustomDiffusionAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else CustomDiffusionAttnProcessor
+    )
     if args.enable_xformers_memory_efficient_attention:
         if is_xformers_available():
             import xformers
diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py
index 45c866c1aa16..e2589f3abcf6 100644
--- a/src/diffusers/loaders.py
+++ b/src/diffusers/loaders.py
@@ -559,6 +559,7 @@ def save_attn_procs(
         """
         from .models.attention_processor import (
             CustomDiffusionAttnProcessor,
+            CustomDiffusionAttnProcessor2_0,
             CustomDiffusionXFormersAttnProcessor,
         )
 
@@ -578,7 +579,10 @@ def save_function(weights, filename):
         os.makedirs(save_directory, exist_ok=True)
 
         is_custom_diffusion = any(
-            isinstance(x, (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor))
+            isinstance(
+                x,
+                (CustomDiffusionAttnProcessor, CustomDiffusionAttnProcessor2_0, CustomDiffusionXFormersAttnProcessor),
+            )
             for (_, x) in self.attn_processors.items()
         )
         if is_custom_diffusion:
@@ -586,7 +590,14 @@ def save_function(weights, filename):
                 {
                     y: x
                     for (y, x) in self.attn_processors.items()
-                    if isinstance(x, (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor))
+                    if isinstance(
+                        x,
+                        (
+                            CustomDiffusionAttnProcessor,
+                            CustomDiffusionAttnProcessor2_0,
+                            CustomDiffusionXFormersAttnProcessor,
+                        ),
+                    )
                 }
             )
             state_dict = model_to_save.state_dict()
diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py
index 34593a4e77ec..fba5bddb5def 100644
--- a/src/diffusers/models/attention_processor.py
+++ b/src/diffusers/models/attention_processor.py
@@ -173,7 +173,8 @@ def set_use_memory_efficient_attention_xformers(
             LORA_ATTENTION_PROCESSORS,
         )
         is_custom_diffusion = hasattr(self, "processor") and isinstance(
-            self.processor, (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor)
+            self.processor,
+            (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor, CustomDiffusionAttnProcessor2_0),
         )
         is_added_kv_processor = hasattr(self, "processor") and isinstance(
             self.processor,
@@ -261,7 +262,12 @@ def set_use_memory_efficient_attention_xformers(
                 processor.load_state_dict(self.processor.state_dict())
                 processor.to(self.processor.to_q_lora.up.weight.device)
             elif is_custom_diffusion:
-                processor = CustomDiffusionAttnProcessor(
+                attn_processor_class = (
+                    CustomDiffusionAttnProcessor2_0
+                    if hasattr(F, "scaled_dot_product_attention")
+                    else CustomDiffusionAttnProcessor
+                )
+                processor = attn_processor_class(
                     train_kv=self.processor.train_kv,
                     train_q_out=self.processor.train_q_out,
                     hidden_size=self.processor.hidden_size,
@@ -1156,6 +1162,111 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a
         return hidden_states
 
 
+class CustomDiffusionAttnProcessor2_0(nn.Module):
+    r"""
+    Processor for implementing attention for the Custom Diffusion method using PyTorch 2.0's memory-efficient scaled
+    dot-product attention.
+
+    Args:
+        train_kv (`bool`, defaults to `True`):
+            Whether to newly train the key and value matrices corresponding to the text features.
+        train_q_out (`bool`, defaults to `True`):
+            Whether to newly train the query matrices corresponding to the latent image features.
+        hidden_size (`int`, *optional*, defaults to `None`):
+            The hidden size of the attention layer.
+        cross_attention_dim (`int`, *optional*, defaults to `None`):
+            The number of channels in the `encoder_hidden_states`.
+        out_bias (`bool`, defaults to `True`):
+            Whether to include a bias parameter in the output projection trained when `train_q_out` is enabled.
+        dropout (`float`, *optional*, defaults to 0.0):
+            The dropout probability to use.
+    """
+
+    def __init__(
+        self,
+        train_kv=True,
+        train_q_out=True,
+        hidden_size=None,
+        cross_attention_dim=None,
+        out_bias=True,
+        dropout=0.0,
+    ):
+        super().__init__()
+        self.train_kv = train_kv
+        self.train_q_out = train_q_out
+
+        self.hidden_size = hidden_size
+        self.cross_attention_dim = cross_attention_dim
+
+        # The `_custom_diffusion` suffix in the parameter names allows easy serialization and loading.
+        if self.train_kv:
+            self.to_k_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
+            self.to_v_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
+        if self.train_q_out:
+            self.to_q_custom_diffusion = nn.Linear(hidden_size, hidden_size, bias=False)
+            self.to_out_custom_diffusion = nn.ModuleList([])
+            self.to_out_custom_diffusion.append(nn.Linear(hidden_size, hidden_size, bias=out_bias))
+            self.to_out_custom_diffusion.append(nn.Dropout(dropout))
+
+    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
+        batch_size, sequence_length, _ = hidden_states.shape
+        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+        if self.train_q_out:
+            query = self.to_q_custom_diffusion(hidden_states)
+        else:
+            query = attn.to_q(hidden_states)
+
+        if encoder_hidden_states is None:
+            crossattn = False
+            encoder_hidden_states = hidden_states
+        else:
+            crossattn = True
+            if attn.norm_cross:
+                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+        if self.train_kv:
+            key = self.to_k_custom_diffusion(encoder_hidden_states)
+            value = self.to_v_custom_diffusion(encoder_hidden_states)
+        else:
+            key = attn.to_k(encoder_hidden_states)
+            value = attn.to_v(encoder_hidden_states)
+
+        if crossattn:
+            detach = torch.ones_like(key)
+            detach[:, :1, :] = detach[:, :1, :] * 0.0
+            key = detach * key + (1 - detach) * key.detach()
+            value = detach * value + (1 - detach) * value.detach()
+
+        inner_dim = hidden_states.shape[-1]
+
+        head_dim = inner_dim // attn.heads
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        # the output of sdp = (batch, num_heads, seq_len, head_dim)
+        # TODO: add support for attn.scale when we move to Torch 2.1
+        hidden_states = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        )
+
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        if self.train_q_out:
+            # linear proj
+            hidden_states = self.to_out_custom_diffusion[0](hidden_states)
+            # dropout
+            hidden_states = self.to_out_custom_diffusion[1](hidden_states)
+        else:
+            # linear proj
+            hidden_states = attn.to_out[0](hidden_states)
+            # dropout
+            hidden_states = attn.to_out[1](hidden_states)
+
+        return hidden_states
+
+
 class SlicedAttnProcessor:
     r"""
     Processor for implementing sliced attention.
@@ -1639,6 +1750,7 @@ def __call__(self, attn: Attention, hidden_states, *args, **kwargs):
     XFormersAttnAddedKVProcessor,
     CustomDiffusionAttnProcessor,
     CustomDiffusionXFormersAttnProcessor,
+    CustomDiffusionAttnProcessor2_0,
     # deprecated
     LoRAAttnProcessor,
     LoRAAttnProcessor2_0,
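
For context, here is roughly how the new processor gets wired into a UNet, following the same `hasattr(F, "scaled_dot_product_attention")` fallback that the diff adds to `train_custom_diffusion.py`. This is a minimal sketch, not part of the diff: the checkpoint name, the `train_kv`/`train_q_out` choices, and the layer-name heuristics are illustrative assumptions borrowed from the Custom Diffusion example.

```python
# Sketch: choose the SDPA-backed processor on PyTorch 2.0+, fall back otherwise,
# and install it on every attention layer of a UNet. Checkpoint and
# hyperparameter choices below are illustrative assumptions.
import torch.nn.functional as F

from diffusers import UNet2DConditionModel
from diffusers.models.attention_processor import (
    CustomDiffusionAttnProcessor,
    CustomDiffusionAttnProcessor2_0,
)

attention_class = (
    CustomDiffusionAttnProcessor2_0
    if hasattr(F, "scaled_dot_product_attention")  # present in torch >= 2.0
    else CustomDiffusionAttnProcessor
)

unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet"  # hypothetical checkpoint
)

custom_diffusion_attn_procs = {}
for name in unet.attn_processors.keys():
    # Self-attention layers end in "attn1.processor" and never see text features,
    # so they get cross_attention_dim=None and no trained KV.
    cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
    if name.startswith("mid_block"):
        hidden_size = unet.config.block_out_channels[-1]
    elif name.startswith("up_blocks"):
        block_id = int(name[len("up_blocks.")])
        hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
    else:  # down_blocks
        block_id = int(name[len("down_blocks.")])
        hidden_size = unet.config.block_out_channels[block_id]
    custom_diffusion_attn_procs[name] = attention_class(
        train_kv=cross_attention_dim is not None,  # train KV only on cross-attention
        train_q_out=False,
        hidden_size=hidden_size,
        cross_attention_dim=cross_attention_dim,
    )
unet.set_attn_processor(custom_diffusion_attn_procs)
```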
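The `if crossattn:` branch in `__call__` is Custom Diffusion's masking trick: the first text token's key/value contribution is detached so it receives no gradient, while all later tokens remain trainable. A tiny self-contained sketch of that pattern, with made-up shapes:

```python
# Standalone demonstration of the detach trick from __call__: token 0 of the
# key gets no gradient, tokens 1.. keep theirs. Shapes are made up.
import torch

key = torch.randn(2, 5, 8, requires_grad=True)  # (batch, tokens, channels)

detach = torch.ones_like(key)
detach[:, :1, :] = 0.0  # zero the mask for the first token only

# Token 0:    0 * key + 1 * key.detach() -> no gradient path
# Tokens 1..: 1 * key + 0 * key.detach() -> gradient flows as usual
mixed = detach * key + (1 - detach) * key.detach()

mixed.sum().backward()
print(key.grad[:, 0].abs().sum())  # tensor(0.): first token received no gradient
print(key.grad[:, 1].abs().sum())  # nonzero: remaining tokens still train
```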
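Because `save_attn_procs` now also matches `CustomDiffusionAttnProcessor2_0`, weights trained through the SDPA path serialize exactly like the existing Custom Diffusion variants. A sketch of the round trip; the directory and weight-file name follow the Custom Diffusion example's conventions and are assumptions here:

```python
# Sketch: save and reload Custom Diffusion attention weights. The loaders.py
# change makes save_attn_procs treat CustomDiffusionAttnProcessor2_0 like the
# other Custom Diffusion processors. Paths/weight names are assumptions.
import torch

from diffusers import DiffusionPipeline

# `unet` is assumed to hold trained Custom Diffusion processors
# (e.g. the model configured in the first sketch above, after training).
unet.save_attn_procs(
    "./custom-diffusion-model", weight_name="pytorch_custom_diffusion_weights.bin"
)

# Reload into a fresh pipeline for inference.
pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
)
pipe.unet.load_attn_procs(
    "./custom-diffusion-model", weight_name="pytorch_custom_diffusion_weights.bin"
)
```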