From 72a9489e9734a6714840d836ea9a97980a10cfb5 Mon Sep 17 00:00:00 2001 From: William Berman Date: Wed, 15 Mar 2023 13:29:04 -0700 Subject: [PATCH 1/4] deprecate AttentionBlock --- .../train_unconditional.py | 1 + .../train_unconditional.py | 2 + src/diffusers/models/attention.py | 155 ++++++----------- src/diffusers/models/attention_processor.py | 103 ++++++++++- src/diffusers/models/autoencoder_kl.py | 3 + src/diffusers/models/modeling_utils.py | 72 +++++++- src/diffusers/models/unet_2d.py | 4 + src/diffusers/models/unet_2d_blocks.py | 164 +++++++++++++++--- src/diffusers/models/vae.py | 6 + src/diffusers/models/vq_model.py | 3 + src/diffusers/pipelines/pipeline_utils.py | 7 +- tests/models/test_models_unet_2d.py | 3 + tests/models/test_models_vae.py | 1 + tests/models/test_models_vq.py | 1 + tests/test_layers_utils.py | 51 ++++++ tests/test_unet_2d_blocks.py | 4 +- tests/test_unet_blocks_common.py | 7 + 17 files changed, 459 insertions(+), 128 deletions(-) diff --git a/examples/research_projects/onnxruntime/unconditional_image_generation/train_unconditional.py b/examples/research_projects/onnxruntime/unconditional_image_generation/train_unconditional.py index 1b38036d82c0..edb315429474 100644 --- a/examples/research_projects/onnxruntime/unconditional_image_generation/train_unconditional.py +++ b/examples/research_projects/onnxruntime/unconditional_image_generation/train_unconditional.py @@ -350,6 +350,7 @@ def main(args): "UpBlock2D", "UpBlock2D", ), + attention_block_type="Attention", ) # Create EMA for the model. diff --git a/examples/unconditional_image_generation/train_unconditional.py b/examples/unconditional_image_generation/train_unconditional.py index 3b784eda6a34..3a399d4926f8 100644 --- a/examples/unconditional_image_generation/train_unconditional.py +++ b/examples/unconditional_image_generation/train_unconditional.py @@ -397,10 +397,12 @@ def load_model_hook(models, input_dir): "UpBlock2D", "UpBlock2D", ), + attention_block_type="Attention", ) else: config = UNet2DModel.load_config(args.model_config_name_or_path) model = UNet2DModel.from_config(config) + model._convert_deprecated_attention_blocks() # Create EMA for the model. if args.use_ema: diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 5c7e54e7cd32..8527f17e74c9 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -11,27 +11,23 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import math from typing import Callable, Optional import torch import torch.nn.functional as F from torch import nn -from ..utils.import_utils import is_xformers_available -from .attention_processor import Attention +from ..utils import deprecate +from .attention_processor import Attention, SpatialAttnProcessor from .embeddings import CombinedTimestepLabelEmbeddings -if is_xformers_available(): - import xformers - import xformers.ops -else: - xformers = None - - class AttentionBlock(nn.Module): """ + This class is deprecated. Its forward method will throw an error. On model load, we convert all instances of + `AttentionBlock` to `diffusers.models.attention_processor.Attention`, see + `ModelMixin#_convert_deprecated_attention_blocks`. + An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted to the N-d case. 
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. @@ -46,8 +42,6 @@ class AttentionBlock(nn.Module): eps (`float`, *optional*, defaults to 1e-5): The epsilon value to use for group norm. """ - # IMPORTANT;TODO(Patrick, William) - this class will be deprecated soon. Do not use it anymore - def __init__( self, channels: int, @@ -57,6 +51,16 @@ def __init__( eps: float = 1e-5, ): super().__init__() + + deprecation_message = ( + "`AttentionBlock` has been deprecated and will be replaced with `diffusers.models.attention_processor.Attention`." + " The DiffusionPipeline loading this block in is auto converting it to `diffusers.models.attention_processor.Attention`." + " Please call `DiffusionPipeline#save_pretrained` and re-upload the pipeline to the hub." + " If you are only loading a model instead of a whole pipeline, the same instructions apply with `ModelMixin#save_pretrained`." + ) + + deprecate("AttentionBlock", "0.18.0", deprecation_message, standard_warn=True) + self.channels = channels self.num_heads = channels // num_head_channels if num_head_channels is not None else 1 @@ -71,107 +75,54 @@ def __init__( self.rescale_output_factor = rescale_output_factor self.proj_attn = nn.Linear(channels, channels, bias=True) - self._use_memory_efficient_attention_xformers = False - self._attention_op = None - - def reshape_heads_to_batch_dim(self, tensor): - batch_size, seq_len, dim = tensor.shape - head_size = self.num_heads - tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) - tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size) - return tensor - - def reshape_batch_dim_to_heads(self, tensor): - batch_size, seq_len, dim = tensor.shape - head_size = self.num_heads - tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) - tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size) - return tensor - def set_use_memory_efficient_attention_xformers( self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None ): - if use_memory_efficient_attention_xformers: - if not is_xformers_available(): - raise ModuleNotFoundError( - ( - "Refer to https://github.com/facebookresearch/xformers for more information on how to install" - " xformers" - ), - name="xformers", - ) - elif not torch.cuda.is_available(): - raise ValueError( - "torch.cuda.is_available() should be True but is False. 
xformers' memory efficient attention is" - " only available for GPU " - ) - else: - try: - # Make sure we can run the memory efficient attention - _ = xformers.ops.memory_efficient_attention( - torch.randn((1, 2, 40), device="cuda"), - torch.randn((1, 2, 40), device="cuda"), - torch.randn((1, 2, 40), device="cuda"), - ) - except Exception as e: - raise e - self._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers - self._attention_op = attention_op + raise ValueError( + "`AttentionBlock` should have been converted after load to `diffusers.models.attention_processor.Attention`" + ) def forward(self, hidden_states): - residual = hidden_states - batch, channel, height, width = hidden_states.shape - - # norm - hidden_states = self.group_norm(hidden_states) - - hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2) - - # proj to q, k, v - query_proj = self.query(hidden_states) - key_proj = self.key(hidden_states) - value_proj = self.value(hidden_states) - - scale = 1 / math.sqrt(self.channels / self.num_heads) - - query_proj = self.reshape_heads_to_batch_dim(query_proj) - key_proj = self.reshape_heads_to_batch_dim(key_proj) - value_proj = self.reshape_heads_to_batch_dim(value_proj) + raise ValueError( + "`AttentionBlock` should have been converted after load to `diffusers.models.attention_processor.Attention`" + ) - if self._use_memory_efficient_attention_xformers: - # Memory efficient attention - hidden_states = xformers.ops.memory_efficient_attention( - query_proj, key_proj, value_proj, attn_bias=None, op=self._attention_op - ) - hidden_states = hidden_states.to(query_proj.dtype) + def _as_attention_processor_attention(self): + if self.num_head_size is None: + # When `self.num_head_size` is None, there is a single attention head + # of all the channels + dim_head = self.channels else: - attention_scores = torch.baddbmm( - torch.empty( - query_proj.shape[0], - query_proj.shape[1], - key_proj.shape[1], - dtype=query_proj.dtype, - device=query_proj.device, - ), - query_proj, - key_proj.transpose(-1, -2), - beta=0, - alpha=scale, - ) - attention_probs = torch.softmax(attention_scores.float(), dim=-1).type(attention_scores.dtype) - hidden_states = torch.bmm(attention_probs, value_proj) + dim_head = self.num_head_size + + # This will allocate some additional memory but as this is only done once during model load, + # it should be ok. 
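        # Worked example (illustrative numbers, not taken from the patch): for an `AttentionBlock`
        # built with channels=512 and num_head_channels=64, `self.num_heads` is 512 // 64 == 8 and
        # `dim_head` is 64, so the new module's inner dimension (heads * dim_head == 512) matches
        # the original `query`/`key`/`value` projections and the state dicts copied below keep
        # their shapes.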
+ attn = Attention( + self.channels, + heads=self.num_heads, + dim_head=dim_head, + bias=True, + upcast_softmax=True, + norm_num_groups=self.group_norm.num_groups, + processor=SpatialAttnProcessor(), + eps=self.group_norm.eps, + rescale_output_factor=self.rescale_output_factor, + ) - # reshape hidden_states - hidden_states = self.reshape_batch_dim_to_heads(hidden_states) + param = next(self.parameters()) - # compute next hidden_states - hidden_states = self.proj_attn(hidden_states) + device = param.device + dtype = param.dtype - hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width) + attn.to(device=device, dtype=dtype) - # res connect and rescale - hidden_states = (hidden_states + residual) / self.rescale_output_factor - return hidden_states + attn.group_norm.load_state_dict(self.group_norm.state_dict()) + attn.to_q.load_state_dict(self.query.state_dict()) + attn.to_k.load_state_dict(self.key.state_dict()) + attn.to_v.load_state_dict(self.value.state_dict()) + attn.to_out[0].load_state_dict(self.proj_attn.state_dict()) + + return attn class BasicTransformerBlock(nn.Module): diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 30026cd89ff9..c9e907e74f1c 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -62,6 +62,8 @@ def __init__( out_bias: bool = True, scale_qk: bool = True, processor: Optional["AttnProcessor"] = None, + eps: float = 1e-5, + rescale_output_factor: float = 1.0, ): super().__init__() inner_dim = dim_head * heads @@ -69,6 +71,7 @@ def __init__( self.upcast_attention = upcast_attention self.upcast_softmax = upcast_softmax self.cross_attention_norm = cross_attention_norm + self.rescale_output_factor = rescale_output_factor self.scale = dim_head**-0.5 if scale_qk else 1.0 @@ -81,7 +84,7 @@ def __init__( self.added_kv_proj_dim = added_kv_proj_dim if norm_num_groups is not None: - self.group_norm = nn.GroupNorm(num_channels=inner_dim, num_groups=norm_num_groups, eps=1e-5, affine=True) + self.group_norm = nn.GroupNorm(num_channels=inner_dim, num_groups=norm_num_groups, eps=eps, affine=True) else: self.group_norm = None @@ -117,6 +120,10 @@ def set_use_memory_efficient_attention_xformers( self.processor, (LoRAAttnProcessor, LoRAXFormersAttnProcessor) ) + is_spatial_attention = hasattr(self, "processor") and isinstance( + self.processor, (SpatialAttnProcessor, XFormersSpatialAttnProcessor) + ) + if use_memory_efficient_attention_xformers: if self.added_kv_proj_dim is not None: # TODO(Anton, Patrick, Suraj, William) - currently xformers doesn't work for UnCLIP @@ -159,6 +166,8 @@ def set_use_memory_efficient_attention_xformers( ) processor.load_state_dict(self.processor.state_dict()) processor.to(self.processor.to_q_lora.up.weight.device) + elif is_spatial_attention: + processor = XFormersSpatialAttnProcessor() else: processor = XFormersAttnProcessor(attention_op=attention_op) else: @@ -170,6 +179,8 @@ def set_use_memory_efficient_attention_xformers( ) processor.load_state_dict(self.processor.state_dict()) processor.to(self.processor.to_q_lora.up.weight.device) + elif is_spatial_attention: + processor = SpatialAttnProcessor() else: processor = AttnProcessor() @@ -684,6 +695,94 @@ def __call__(self, attn: "Attention", hidden_states, encoder_hidden_states=None, return hidden_states +class SpatialAttnProcessor: + def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None): + if attention_mask is not 
None: + raise ValueError(f"{self.__class__.__name__} does not support `attention_mask`") + + if encoder_hidden_states is not None: + raise ValueError(f"{self.__class__.__name__} does not support `encoder_hidden_states`") + + residual = hidden_states + batch, channel, height, width = hidden_states.shape + + # norm + hidden_states = attn.group_norm(hidden_states) + + hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2) + + # proj to q, k, v + query_proj = attn.to_q(hidden_states) + key_proj = attn.to_k(hidden_states) + value_proj = attn.to_v(hidden_states) + + query_proj = attn.head_to_batch_dim(query_proj) + key_proj = attn.head_to_batch_dim(key_proj) + value_proj = attn.head_to_batch_dim(value_proj) + + attention_probs = attn.get_attention_scores(query_proj, key_proj) + hidden_states = torch.bmm(attention_probs, value_proj) + + # reshape hidden_states + hidden_states = attn.batch_to_head_dim(hidden_states) + + # compute next hidden_states + hidden_states = attn.to_out[0](hidden_states) + + hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width) + + # res connect and rescale + hidden_states = (hidden_states + residual) / attn.rescale_output_factor + return hidden_states + + +class XFormersSpatialAttnProcessor: + def __init__(self, attention_op: Optional[Callable] = None): + self.attention_op = attention_op + + def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None): + if attention_mask is not None: + raise ValueError(f"{self.__class__.__name__} does not support `attention_mask`") + + if encoder_hidden_states is not None: + raise ValueError(f"{self.__class__.__name__} does not support `encoder_hidden_states`") + + residual = hidden_states + batch, channel, height, width = hidden_states.shape + + # norm + hidden_states = attn.group_norm(hidden_states) + + hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2) + + # proj to q, k, v + query_proj = attn.to_q(hidden_states) + key_proj = attn.to_k(hidden_states) + value_proj = attn.to_v(hidden_states) + + query_proj = attn.head_to_batch_dim(query_proj) + key_proj = attn.head_to_batch_dim(key_proj) + value_proj = attn.head_to_batch_dim(value_proj) + + # Memory efficient attention + hidden_states = xformers.ops.memory_efficient_attention( + query_proj, key_proj, value_proj, attn_bias=None, op=self.attention_op, scale=attn.scale + ) + hidden_states = hidden_states.to(query_proj.dtype) + + # reshape hidden_states + hidden_states = attn.batch_to_head_dim(hidden_states) + + # compute next hidden_states + hidden_states = attn.to_out[0](hidden_states) + + hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width) + + # res connect and rescale + hidden_states = (hidden_states + residual) / attn.rescale_output_factor + return hidden_states + + AttentionProcessor = Union[ AttnProcessor, XFormersAttnProcessor, @@ -692,4 +791,6 @@ def __call__(self, attn: "Attention", hidden_states, encoder_hidden_states=None, SlicedAttnAddedKVProcessor, LoRAAttnProcessor, LoRAXFormersAttnProcessor, + SpatialAttnProcessor, + XFormersSpatialAttnProcessor, ] diff --git a/src/diffusers/models/autoencoder_kl.py b/src/diffusers/models/autoencoder_kl.py index 9c0161065e4c..fb64089f0245 100644 --- a/src/diffusers/models/autoencoder_kl.py +++ b/src/diffusers/models/autoencoder_kl.py @@ -81,6 +81,7 @@ def __init__( norm_num_groups: int = 32, sample_size: int = 32, scaling_factor: float = 0.18215, + attention_block_type: str = 
"AttentionBlock", ): super().__init__() @@ -94,6 +95,7 @@ def __init__( act_fn=act_fn, norm_num_groups=norm_num_groups, double_z=True, + attention_block_type=attention_block_type, ) # pass init params to Decoder @@ -105,6 +107,7 @@ def __init__( layers_per_block=layers_per_block, norm_num_groups=norm_num_groups, act_fn=act_fn, + attention_block_type=attention_block_type, ) self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1) diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index e51b40ce4509..e43d4c0bbf58 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -25,7 +25,7 @@ from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError from packaging import version from requests import HTTPError -from torch import Tensor, device +from torch import Tensor, device, nn from .. import __version__ from ..utils import ( @@ -629,6 +629,10 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P model.register_to_config(_name_or_path=pretrained_model_name_or_path) + # The model has had its weights loaded. If the deprecated attention block weights + # were in the old format, it's now safe to convert them. + model._convert_deprecated_attention_blocks() + # Set model in evaluation mode to deactivate DropOut modules by default model.eval() if output_loading_info: @@ -783,6 +787,72 @@ def num_parameters(self, only_trainable: bool = False, exclude_embeddings: bool else: return sum(p.numel() for p in self.parameters() if p.requires_grad or not only_trainable) + def _convert_deprecated_attention_blocks(self): + """ + `diffusers.models.attention.AttentionBlock` is deprecated and must be converted to + `diffusers.models.attention_processor.Attention`. 
+ + See https://github.com/huggingface/diffusers/issues/1880 and https://github.com/huggingface/diffusers/pull/2697 + for more details + """ + models_with_deprecated_attention = ["UNet2DModel", "VQModel", "AutoencoderKL"] + + blocks_with_deprecated_attention = [ + "AttnDownBlock2D", + "AttnSkipDownBlock2D", + "AttnDownEncoderBlock2D", + "AttnUpBlock2D", + "AttnSkipUpBlock2D", + "AttnUpDecoderBlock2D", + "UNetMidBlock2D", + ] + + if self.__class__.__name__ not in models_with_deprecated_attention: + return + + if self.config.attention_block_type == "Attention": + # Model as already been converted + return + + self.register_to_config(attention_block_type="Attention") + + def _convert_deprecated_attention_blocks_in_module(module): + if module.__class__.__name__ not in blocks_with_deprecated_attention: + return + + attentions = nn.ModuleList() + + for attention in module.attentions: + if attention is not None: + attention = attention._as_attention_processor_attention() + + attentions.append(attention) + + module.attentions = attentions + + if self.__class__.__name__ == "UNet2DModel": + for down_block in self.down_blocks: + _convert_deprecated_attention_blocks_in_module(down_block) + + _convert_deprecated_attention_blocks_in_module(self.mid_block) + + for up_block in self.up_blocks: + _convert_deprecated_attention_blocks_in_module(up_block) + + elif self.__class__.__name__ in ["VQModel", "AutoencoderKL"]: + for down_block in self.encoder.down_blocks: + _convert_deprecated_attention_blocks_in_module(down_block) + + _convert_deprecated_attention_blocks_in_module(self.encoder.mid_block) + + _convert_deprecated_attention_blocks_in_module(self.decoder.mid_block) + + for up_block in self.decoder.up_blocks: + _convert_deprecated_attention_blocks_in_module(up_block) + + else: + assert False + def _get_model_file( pretrained_model_name_or_path, diff --git a/src/diffusers/models/unet_2d.py b/src/diffusers/models/unet_2d.py index 2df6e60d88c9..ef543ef3d586 100644 --- a/src/diffusers/models/unet_2d.py +++ b/src/diffusers/models/unet_2d.py @@ -102,6 +102,7 @@ def __init__( add_attention: bool = True, class_embed_type: Optional[str] = None, num_class_embeds: Optional[int] = None, + attention_block_type: str = "AttentionBlock", ): super().__init__() @@ -166,6 +167,7 @@ def __init__( attn_num_head_channels=attention_head_dim, downsample_padding=downsample_padding, resnet_time_scale_shift=resnet_time_scale_shift, + attention_block_type=attention_block_type, ) self.down_blocks.append(down_block) @@ -180,6 +182,7 @@ def __init__( attn_num_head_channels=attention_head_dim, resnet_groups=norm_num_groups, add_attention=add_attention, + attention_block_type=attention_block_type, ) # up @@ -205,6 +208,7 @@ def __init__( resnet_groups=norm_num_groups, attn_num_head_channels=attention_head_dim, resnet_time_scale_shift=resnet_time_scale_shift, + attention_block_type=attention_block_type, ) self.up_blocks.append(up_block) prev_output_channel = output_channel diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index 3070351279b8..46ae83e0d4a6 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -18,7 +18,7 @@ from torch import nn from .attention import AdaGroupNorm, AttentionBlock -from .attention_processor import Attention, AttnAddedKVProcessor +from .attention_processor import Attention, AttnAddedKVProcessor, SpatialAttnProcessor from .dual_transformer_2d import DualTransformer2DModel from .resnet import Downsample2D, 
FirDownsample2D, FirUpsample2D, KDownsample2D, KUpsample2D, ResnetBlock2D, Upsample2D from .transformer_2d import Transformer2DModel @@ -42,6 +42,7 @@ def get_down_block( only_cross_attention=False, upcast_attention=False, resnet_time_scale_shift="default", + attention_block_type="AttentionBlock", ): down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type if down_block_type == "DownBlock2D": @@ -82,6 +83,7 @@ def get_down_block( downsample_padding=downsample_padding, attn_num_head_channels=attn_num_head_channels, resnet_time_scale_shift=resnet_time_scale_shift, + attention_block_type=attention_block_type, ) elif down_block_type == "CrossAttnDownBlock2D": if cross_attention_dim is None: @@ -144,6 +146,7 @@ def get_down_block( downsample_padding=downsample_padding, attn_num_head_channels=attn_num_head_channels, resnet_time_scale_shift=resnet_time_scale_shift, + attention_block_type=attention_block_type, ) elif down_block_type == "DownEncoderBlock2D": return DownEncoderBlock2D( @@ -169,6 +172,7 @@ def get_down_block( downsample_padding=downsample_padding, attn_num_head_channels=attn_num_head_channels, resnet_time_scale_shift=resnet_time_scale_shift, + attention_block_type=attention_block_type, ) elif down_block_type == "KDownBlock2D": return KDownBlock2D( @@ -214,6 +218,7 @@ def get_up_block( only_cross_attention=False, upcast_attention=False, resnet_time_scale_shift="default", + attention_block_type="AttentionBlock", ): up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type if up_block_type == "UpBlock2D": @@ -293,6 +298,7 @@ def get_up_block( resnet_groups=resnet_groups, attn_num_head_channels=attn_num_head_channels, resnet_time_scale_shift=resnet_time_scale_shift, + attention_block_type=attention_block_type, ) elif up_block_type == "SkipUpBlock2D": return SkipUpBlock2D( @@ -318,6 +324,7 @@ def get_up_block( resnet_act_fn=resnet_act_fn, attn_num_head_channels=attn_num_head_channels, resnet_time_scale_shift=resnet_time_scale_shift, + attention_block_type=attention_block_type, ) elif up_block_type == "UpDecoderBlock2D": return UpDecoderBlock2D( @@ -341,6 +348,7 @@ def get_up_block( resnet_groups=resnet_groups, attn_num_head_channels=attn_num_head_channels, resnet_time_scale_shift=resnet_time_scale_shift, + attention_block_type=attention_block_type, ) elif up_block_type == "KUpBlock2D": return KUpBlock2D( @@ -383,6 +391,7 @@ def __init__( add_attention: bool = True, attn_num_head_channels=1, output_scale_factor=1.0, + attention_block_type: str = "AttentionBlock", ): super().__init__() resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) @@ -407,15 +416,29 @@ def __init__( for _ in range(num_layers): if self.add_attention: - attentions.append( - AttentionBlock( + if attention_block_type == "AttentionBlock": + attention = AttentionBlock( in_channels, num_head_channels=attn_num_head_channels, rescale_output_factor=output_scale_factor, eps=resnet_eps, norm_num_groups=resnet_groups, ) - ) + elif attention_block_type == "Attention": + attention = Attention( + in_channels, + heads=in_channels // attn_num_head_channels if attn_num_head_channels is not None else 1, + dim_head=attn_num_head_channels if attn_num_head_channels is not None else in_channels, + bias=True, + upcast_softmax=True, + norm_num_groups=resnet_groups, + processor=SpatialAttnProcessor(), + eps=resnet_eps, + ) + else: + raise ValueError(f"Unknown attention_block_type: {attention_block_type}") + + 
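                # Illustrative usage (hypothetical hyper-parameters): newly constructed models opt in to
                # the replacement block through their config, e.g.
                #
                #     model = UNet2DModel(sample_size=64, in_channels=3, out_channels=3,
                #                         attention_block_type="Attention")
                #
                # while configs that still say "AttentionBlock" keep building the deprecated block and
                # are converted after loading by `ModelMixin._convert_deprecated_attention_blocks`.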
attentions.append(attention) else: attentions.append(None) @@ -658,6 +681,7 @@ def __init__( output_scale_factor=1.0, downsample_padding=1, add_downsample=True, + attention_block_type: str = "AttentionBlock", ): super().__init__() resnets = [] @@ -679,15 +703,28 @@ def __init__( pre_norm=resnet_pre_norm, ) ) - attentions.append( - AttentionBlock( + if attention_block_type == "AttentionBlock": + attention = AttentionBlock( out_channels, num_head_channels=attn_num_head_channels, rescale_output_factor=output_scale_factor, eps=resnet_eps, norm_num_groups=resnet_groups, ) - ) + elif attention_block_type == "Attention": + attention = Attention( + out_channels, + heads=out_channels // attn_num_head_channels if attn_num_head_channels is not None else 1, + dim_head=attn_num_head_channels if attn_num_head_channels is not None else out_channels, + bias=True, + upcast_softmax=True, + norm_num_groups=resnet_groups, + processor=SpatialAttnProcessor(), + eps=resnet_eps, + ) + else: + raise ValueError(f"Unknown attention_block_type: {attention_block_type}") + attentions.append(attention) self.attentions = nn.ModuleList(attentions) self.resnets = nn.ModuleList(resnets) @@ -1006,6 +1043,7 @@ def __init__( output_scale_factor=1.0, add_downsample=True, downsample_padding=1, + attention_block_type: str = "AttentionBlock", ): super().__init__() resnets = [] @@ -1027,15 +1065,29 @@ def __init__( pre_norm=resnet_pre_norm, ) ) - attentions.append( - AttentionBlock( + if attention_block_type == "AttentionBlock": + attention = AttentionBlock( out_channels, num_head_channels=attn_num_head_channels, rescale_output_factor=output_scale_factor, eps=resnet_eps, norm_num_groups=resnet_groups, ) - ) + elif attention_block_type == "Attention": + attention = Attention( + out_channels, + heads=out_channels // attn_num_head_channels if attn_num_head_channels is not None else 1, + dim_head=attn_num_head_channels if attn_num_head_channels is not None else out_channels, + bias=True, + upcast_softmax=True, + norm_num_groups=resnet_groups, + processor=SpatialAttnProcessor(), + eps=resnet_eps, + rescale_output_factor=output_scale_factor, + ) + else: + raise ValueError(f"Unknown attention_block_type: {attention_block_type}") + attentions.append(attention) self.attentions = nn.ModuleList(attentions) self.resnets = nn.ModuleList(resnets) @@ -1079,6 +1131,7 @@ def __init__( output_scale_factor=np.sqrt(2.0), downsample_padding=1, add_downsample=True, + attention_block_type: str = "AttentionBlock", ): super().__init__() self.attentions = nn.ModuleList([]) @@ -1101,14 +1154,30 @@ def __init__( pre_norm=resnet_pre_norm, ) ) - self.attentions.append( - AttentionBlock( + if attention_block_type == "AttentionBlock": + attention = AttentionBlock( out_channels, num_head_channels=attn_num_head_channels, rescale_output_factor=output_scale_factor, eps=resnet_eps, + norm_num_groups=32, ) - ) + elif attention_block_type == "Attention": + attention = Attention( + out_channels, + heads=out_channels // attn_num_head_channels if attn_num_head_channels is not None else 1, + dim_head=attn_num_head_channels if attn_num_head_channels is not None else out_channels, + bias=True, + upcast_softmax=True, + norm_num_groups=32, + processor=SpatialAttnProcessor(), + eps=resnet_eps, + rescale_output_factor=output_scale_factor, + ) + else: + raise ValueError(f"Unknown attention_block_type: {attention_block_type}") + + self.attentions.append(attention) if add_downsample: self.resnet_down = ResnetBlock2D( @@ -1632,6 +1701,7 @@ def __init__( attn_num_head_channels=1, 
output_scale_factor=1.0, add_upsample=True, + attention_block_type: str = "AttentionBlock", ): super().__init__() resnets = [] @@ -1655,15 +1725,31 @@ def __init__( pre_norm=resnet_pre_norm, ) ) - attentions.append( - AttentionBlock( + + if attention_block_type == "AttentionBlock": + attention = AttentionBlock( out_channels, num_head_channels=attn_num_head_channels, rescale_output_factor=output_scale_factor, eps=resnet_eps, norm_num_groups=resnet_groups, ) - ) + elif attention_block_type == "Attention": + attention = Attention( + out_channels, + heads=out_channels // attn_num_head_channels if attn_num_head_channels is not None else 1, + dim_head=attn_num_head_channels if attn_num_head_channels is not None else out_channels, + bias=True, + upcast_softmax=True, + norm_num_groups=resnet_groups, + processor=SpatialAttnProcessor(), + eps=resnet_eps, + rescale_output_factor=output_scale_factor, + ) + else: + raise ValueError(f"Unknown attention_block_type: {attention_block_type}") + + attentions.append(attention) self.attentions = nn.ModuleList(attentions) self.resnets = nn.ModuleList(resnets) @@ -1966,6 +2052,7 @@ def __init__( attn_num_head_channels=1, output_scale_factor=1.0, add_upsample=True, + attention_block_type: str = "AttentionBlock", ): super().__init__() resnets = [] @@ -1988,15 +2075,31 @@ def __init__( pre_norm=resnet_pre_norm, ) ) - attentions.append( - AttentionBlock( + + if attention_block_type == "AttentionBlock": + attention = AttentionBlock( out_channels, num_head_channels=attn_num_head_channels, rescale_output_factor=output_scale_factor, eps=resnet_eps, norm_num_groups=resnet_groups, ) - ) + elif attention_block_type == "Attention": + attention = Attention( + out_channels, + heads=out_channels // attn_num_head_channels if attn_num_head_channels is not None else 1, + dim_head=attn_num_head_channels if attn_num_head_channels is not None else out_channels, + bias=True, + upcast_softmax=True, + norm_num_groups=resnet_groups, + processor=SpatialAttnProcessor(), + eps=resnet_eps, + rescale_output_factor=output_scale_factor, + ) + else: + raise ValueError(f"Unknown attention_block_type: {attention_block_type}") + + attentions.append(attention) self.attentions = nn.ModuleList(attentions) self.resnets = nn.ModuleList(resnets) @@ -2035,6 +2138,7 @@ def __init__( output_scale_factor=np.sqrt(2.0), upsample_padding=1, add_upsample=True, + attention_block_type: str = "AttentionBlock", ): super().__init__() self.attentions = nn.ModuleList([]) @@ -2060,14 +2164,30 @@ def __init__( ) ) - self.attentions.append( - AttentionBlock( + if attention_block_type == "AttentionBlock": + attention = AttentionBlock( out_channels, num_head_channels=attn_num_head_channels, rescale_output_factor=output_scale_factor, eps=resnet_eps, + norm_num_groups=32, ) - ) + elif attention_block_type == "Attention": + attention = Attention( + out_channels, + heads=out_channels // attn_num_head_channels if attn_num_head_channels is not None else 1, + dim_head=attn_num_head_channels if attn_num_head_channels is not None else out_channels, + bias=True, + upcast_softmax=True, + norm_num_groups=32, + processor=SpatialAttnProcessor(), + eps=resnet_eps, + rescale_output_factor=output_scale_factor, + ) + else: + raise ValueError(f"Unknown attention_block_type: {attention_block_type}") + + self.attentions.append(attention) self.upsampler = FirUpsample2D(in_channels, out_channels=out_channels) if add_upsample: diff --git a/src/diffusers/models/vae.py b/src/diffusers/models/vae.py index b4484823ac3d..baefbb1741b0 100644 --- 
a/src/diffusers/models/vae.py +++ b/src/diffusers/models/vae.py @@ -46,6 +46,7 @@ def __init__( norm_num_groups=32, act_fn="silu", double_z=True, + attention_block_type="AttentionBlock", ): super().__init__() self.layers_per_block = layers_per_block @@ -80,6 +81,7 @@ def __init__( resnet_groups=norm_num_groups, attn_num_head_channels=None, temb_channels=None, + attention_block_type=attention_block_type, ) self.down_blocks.append(down_block) @@ -93,6 +95,7 @@ def __init__( attn_num_head_channels=None, resnet_groups=norm_num_groups, temb_channels=None, + attention_block_type=attention_block_type, ) # out @@ -149,6 +152,7 @@ def __init__( layers_per_block=2, norm_num_groups=32, act_fn="silu", + attention_block_type="AttentionBlock", ): super().__init__() self.layers_per_block = layers_per_block @@ -174,6 +178,7 @@ def __init__( attn_num_head_channels=None, resnet_groups=norm_num_groups, temb_channels=None, + attention_block_type=attention_block_type, ) # up @@ -197,6 +202,7 @@ def __init__( resnet_groups=norm_num_groups, attn_num_head_channels=None, temb_channels=None, + attention_block_type=attention_block_type, ) self.up_blocks.append(up_block) prev_output_channel = output_channel diff --git a/src/diffusers/models/vq_model.py b/src/diffusers/models/vq_model.py index 65f734dccb2d..c7568c1b43c7 100644 --- a/src/diffusers/models/vq_model.py +++ b/src/diffusers/models/vq_model.py @@ -82,6 +82,7 @@ def __init__( norm_num_groups: int = 32, vq_embed_dim: Optional[int] = None, scaling_factor: float = 0.18215, + attention_block_type: str = "AttentionBlock", ): super().__init__() @@ -95,6 +96,7 @@ def __init__( act_fn=act_fn, norm_num_groups=norm_num_groups, double_z=False, + attention_block_type=attention_block_type, ) vq_embed_dim = vq_embed_dim if vq_embed_dim is not None else latent_channels @@ -112,6 +114,7 @@ def __init__( layers_per_block=layers_per_block, act_fn=act_fn, norm_num_groups=norm_num_groups, + attention_block_type=attention_block_type, ) def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> VQEncoderOutput: diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 8f33b506827a..b30e8544e7a4 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -36,7 +36,7 @@ from .. import __version__ from ..configuration_utils import ConfigMixin -from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT +from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, ModelMixin from ..schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME from ..utils import ( CONFIG_NAME, @@ -443,6 +443,11 @@ def register_modules(self, **kwargs): # set models setattr(self, name, module) + # In case the module was created only its constructor and ModelMixin.from_pretrained + # was never called. 
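            # Illustrative scenario (hypothetical pipeline class and paths): a component instantiated
            # directly from its config never passes through `ModelMixin.from_pretrained`, so this hook
            # is where it gets converted, e.g.
            #
            #     vae = AutoencoderKL.from_config(AutoencoderKL.load_config("path/to/old-vae"))
            #     pipe = MyPipeline(vae=vae, unet=unet, scheduler=scheduler)  # register_modules converts `vae` here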
+ if issubclass(module.__class__, ModelMixin): + module._convert_deprecated_attention_blocks() + def save_pretrained( self, save_directory: Union[str, os.PathLike], diff --git a/tests/models/test_models_unet_2d.py b/tests/models/test_models_unet_2d.py index 8f831fcf7cbf..a3d62f638004 100644 --- a/tests/models/test_models_unet_2d.py +++ b/tests/models/test_models_unet_2d.py @@ -61,6 +61,7 @@ def prepare_init_args_and_inputs_for_common(self): "in_channels": 3, "layers_per_block": 2, "sample_size": 32, + "attention_block_type": "Attention", } inputs_dict = self.dummy_input return init_dict, inputs_dict @@ -98,6 +99,7 @@ def prepare_init_args_and_inputs_for_common(self): "attention_head_dim": 32, "down_block_types": ("DownBlock2D", "DownBlock2D"), "up_block_types": ("UpBlock2D", "UpBlock2D"), + "attention_block_type": "Attention", } inputs_dict = self.dummy_input return init_dict, inputs_dict @@ -223,6 +225,7 @@ def prepare_init_args_and_inputs_for_common(self): "AttnSkipUpBlock2D", "SkipUpBlock2D", ], + "attention_block_type": "Attention", } inputs_dict = self.dummy_input return init_dict, inputs_dict diff --git a/tests/models/test_models_vae.py b/tests/models/test_models_vae.py index 3eb7ce861592..914eb127a694 100644 --- a/tests/models/test_models_vae.py +++ b/tests/models/test_models_vae.py @@ -58,6 +58,7 @@ def prepare_init_args_and_inputs_for_common(self): "down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D"], "up_block_types": ["UpDecoderBlock2D", "UpDecoderBlock2D"], "latent_channels": 4, + "attention_block_type": "Attention", } inputs_dict = self.dummy_input return init_dict, inputs_dict diff --git a/tests/models/test_models_vq.py b/tests/models/test_models_vq.py index 733b51d2f158..eb9b087d1f72 100644 --- a/tests/models/test_models_vq.py +++ b/tests/models/test_models_vq.py @@ -54,6 +54,7 @@ def prepare_init_args_and_inputs_for_common(self): "down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D"], "up_block_types": ["UpDecoderBlock2D", "UpDecoderBlock2D"], "latent_channels": 3, + "attention_block_type": "Attention", } inputs_dict = self.dummy_input return init_dict, inputs_dict diff --git a/tests/test_layers_utils.py b/tests/test_layers_utils.py index d0e2102b539e..e0e97ec3dc0b 100644 --- a/tests/test_layers_utils.py +++ b/tests/test_layers_utils.py @@ -17,6 +17,7 @@ import unittest import numpy as np +import pytest import torch from torch import nn @@ -25,6 +26,7 @@ from diffusers.models.resnet import Downsample2D, ResnetBlock2D, Upsample2D from diffusers.models.transformer_2d import Transformer2DModel from diffusers.utils import torch_device +from diffusers.utils.import_utils import is_xformers_available torch.backends.cuda.matmul.allow_tf32 = False @@ -331,6 +333,7 @@ def test_attention_block_default(self): eps=1e-6, norm_num_groups=32, ).to(torch_device) + attentionBlock = attentionBlock._as_attention_processor_attention() with torch.no_grad(): attention_scores = attentionBlock(sample) @@ -355,6 +358,54 @@ def test_attention_block_sd(self): eps=1e-6, norm_num_groups=32, ).to(torch_device) + attentionBlock = attentionBlock._as_attention_processor_attention() + with torch.no_grad(): + attention_scores = attentionBlock(sample) + + assert attention_scores.shape == (1, 512, 64, 64) + output_slice = attention_scores[0, -1, -3:, -3:] + + expected_slice = torch.tensor( + [-0.6621, -0.0156, -3.2766, 0.8025, -0.8609, 0.2820, 0.0905, -1.1179, -3.2126], device=torch_device + ) + assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3) + + def 
test_unconverted_block_forward_throws_errors(self): + attentionBlock = AttentionBlock(32) + + with pytest.raises( + ValueError, + match="`AttentionBlock` should have been converted after load to `diffusers.models.attention_processor.Attention`", + ): + attentionBlock(1) + + def test_unconverted_block_xformers_throws_errors(self): + attentionBlock = AttentionBlock(32) + + with pytest.raises( + ValueError, + match="`AttentionBlock` should have been converted after load to `diffusers.models.attention_processor.Attention`", + ): + attentionBlock.set_use_memory_efficient_attention_xformers(False) + + @unittest.skipIf( + torch_device != "cuda" or not is_xformers_available(), + reason="XFormers attention is only available with CUDA and `xformers` installed", + ) + def test_conversion_xformers(self): + torch.manual_seed(0) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(0) + + sample = torch.randn(1, 512, 64, 64).to(torch_device) + attentionBlock = AttentionBlock( + channels=512, + rescale_output_factor=1.0, + eps=1e-6, + norm_num_groups=32, + ).to(torch_device) + attentionBlock = attentionBlock._as_attention_processor_attention() + attentionBlock.set_use_memory_efficient_attention_xformers(True) with torch.no_grad(): attention_scores = attentionBlock(sample) diff --git a/tests/test_unet_2d_blocks.py b/tests/test_unet_2d_blocks.py index e560240422ac..a85b6894ea2f 100644 --- a/tests/test_unet_2d_blocks.py +++ b/tests/test_unet_2d_blocks.py @@ -139,6 +139,7 @@ def prepare_init_args_and_inputs_for_common(self): init_dict = { "in_channels": 32, "out_channels": 32, + "attention_block_type": "Attention", } inputs_dict = self.dummy_input return init_dict, inputs_dict @@ -156,6 +157,7 @@ def prepare_init_args_and_inputs_for_common(self): init_dict = { "in_channels": 32, "temb_channels": 128, + "attention_block_type": "Attention", } inputs_dict = self.dummy_input return init_dict, inputs_dict @@ -327,7 +329,7 @@ def dummy_input(self): return super().get_dummy_input(include_temb=False) def prepare_init_args_and_inputs_for_common(self): - init_dict = {"in_channels": 32, "out_channels": 32} + init_dict = {"in_channels": 32, "out_channels": 32, "attention_block_type": "Attention"} inputs_dict = self.dummy_input return init_dict, inputs_dict diff --git a/tests/test_unet_blocks_common.py b/tests/test_unet_blocks_common.py index 17b7f65d6da3..02d74f420b28 100644 --- a/tests/test_unet_blocks_common.py +++ b/tests/test_unet_blocks_common.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
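    # Illustrative note: `inspect.signature` is used below so the shared test mixin passes
    # `attention_block_type="Attention"` only to blocks whose constructors accept the argument, e.g.
    #
    #     "attention_block_type" in inspect.signature(AttnDownBlock2D.__init__).parameters  # True
    #     "attention_block_type" in inspect.signature(DownBlock2D.__init__).parameters      # False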
+import inspect import unittest from typing import Tuple @@ -83,6 +84,12 @@ def prepare_init_args_and_inputs_for_common(self): if self.block_type == "mid": init_dict.pop("out_channels") + constructor_args = inspect.signature(self.block_class.__init__) + constructor_args = constructor_args.parameters.keys() + + if "attention_block_type" in constructor_args: + init_dict["attention_block_type"] = "Attention" + inputs_dict = self.dummy_input return init_dict, inputs_dict From 280285f398e5437546697f7378ac7ee8c5c61b0f Mon Sep 17 00:00:00 2001 From: William Berman Date: Mon, 20 Mar 2023 14:29:51 -0700 Subject: [PATCH 2/4] combine attention processors --- src/diffusers/models/attention.py | 4 +- src/diffusers/models/attention_processor.py | 312 ++++++++++++-------- src/diffusers/models/unet_2d_blocks.py | 16 +- 3 files changed, 206 insertions(+), 126 deletions(-) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 8527f17e74c9..ed4d9dcb264b 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -18,7 +18,7 @@ from torch import nn from ..utils import deprecate -from .attention_processor import Attention, SpatialAttnProcessor +from .attention_processor import Attention from .embeddings import CombinedTimestepLabelEmbeddings @@ -104,9 +104,9 @@ def _as_attention_processor_attention(self): bias=True, upcast_softmax=True, norm_num_groups=self.group_norm.num_groups, - processor=SpatialAttnProcessor(), eps=self.group_norm.eps, rescale_output_factor=self.rescale_output_factor, + residual_connection=True, ) param = next(self.parameters()) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index c9e907e74f1c..232195ff0d9d 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -64,6 +64,7 @@ def __init__( processor: Optional["AttnProcessor"] = None, eps: float = 1e-5, rescale_output_factor: float = 1.0, + residual_connection: bool = False, ): super().__init__() inner_dim = dim_head * heads @@ -72,6 +73,7 @@ def __init__( self.upcast_softmax = upcast_softmax self.cross_attention_norm = cross_attention_norm self.rescale_output_factor = rescale_output_factor + self.residual_connection = residual_connection self.scale = dim_head**-0.5 if scale_qk else 1.0 @@ -120,10 +122,6 @@ def set_use_memory_efficient_attention_xformers( self.processor, (LoRAAttnProcessor, LoRAXFormersAttnProcessor) ) - is_spatial_attention = hasattr(self, "processor") and isinstance( - self.processor, (SpatialAttnProcessor, XFormersSpatialAttnProcessor) - ) - if use_memory_efficient_attention_xformers: if self.added_kv_proj_dim is not None: # TODO(Anton, Patrick, Suraj, William) - currently xformers doesn't work for UnCLIP @@ -166,8 +164,6 @@ def set_use_memory_efficient_attention_xformers( ) processor.load_state_dict(self.processor.state_dict()) processor.to(self.processor.to_q_lora.up.weight.device) - elif is_spatial_attention: - processor = XFormersSpatialAttnProcessor() else: processor = XFormersAttnProcessor(attention_op=attention_op) else: @@ -179,8 +175,6 @@ def set_use_memory_efficient_attention_xformers( ) processor.load_state_dict(self.processor.state_dict()) processor.to(self.processor.to_q_lora.up.weight.device) - elif is_spatial_attention: - processor = SpatialAttnProcessor() else: processor = AttnProcessor() @@ -311,10 +305,31 @@ def __call__( encoder_hidden_states=None, attention_mask=None, ): - batch_size, sequence_length, _ = ( - hidden_states.shape 
if encoder_hidden_states is None else encoder_hidden_states.shape - ) + if attn.residual_connection: + residual = hidden_states + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states) + + batch_size = hidden_states.shape[0] + + if hidden_states.ndim == 4: + reshaped_input = True + + _, channel, height, width = hidden_states.shape + + hidden_states = hidden_states.view(batch_size, channel, height * width) + hidden_states = hidden_states.transpose(1, 2) + else: + reshaped_input = False + + if encoder_hidden_states is None: + sequence_length = hidden_states.shape[1] + else: + sequence_length = encoder_hidden_states.shape[1] + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + query = attn.to_q(hidden_states) if encoder_hidden_states is None: @@ -338,6 +353,15 @@ def __call__( # dropout hidden_states = attn.to_out[1](hidden_states) + if reshaped_input: + hidden_states = hidden_states.transpose(1, 2) + hidden_states = hidden_states.reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + return hidden_states @@ -378,9 +402,29 @@ def __init__(self, hidden_size, cross_attention_dim=None, rank=4): self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank) def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0): - batch_size, sequence_length, _ = ( - hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - ) + if attn.residual_connection: + residual = hidden_states + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states) + + batch_size = hidden_states.shape[0] + + if hidden_states.ndim == 4: + reshaped_input = True + + _, channel, height, width = hidden_states.shape + + hidden_states = hidden_states.view(batch_size, channel, height * width) + hidden_states = hidden_states.transpose(1, 2) + else: + reshaped_input = False + + if encoder_hidden_states is None: + sequence_length = hidden_states.shape[1] + else: + sequence_length = encoder_hidden_states.shape[1] + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states) @@ -403,6 +447,15 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a # dropout hidden_states = attn.to_out[1](hidden_states) + if reshaped_input: + hidden_states = hidden_states.transpose(1, 2) + hidden_states = hidden_states.reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + return hidden_states @@ -453,9 +506,28 @@ def __init__(self, attention_op: Optional[Callable] = None): self.attention_op = attention_op def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None): - batch_size, sequence_length, _ = ( - hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - ) + if attn.residual_connection: + residual = hidden_states + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states) + + batch_size = hidden_states.shape[0] + + if hidden_states.ndim == 4: + reshaped_input = True + + _, channel, height, width = hidden_states.shape + + hidden_states = hidden_states.view(batch_size, channel, height * width) + 
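        # Shape walk-through (illustrative numbers): a (2, 320, 64, 64) feature map becomes
        # (2, 320, 4096) after the view above and (2, 4096, 320) after the transpose below, i.e.
        # one token per spatial position, so image self-attention and sequence attention share the
        # same code path from here on.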
hidden_states = hidden_states.transpose(1, 2) + else: + reshaped_input = False + + if encoder_hidden_states is None: + sequence_length = hidden_states.shape[1] + else: + sequence_length = encoder_hidden_states.shape[1] attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) @@ -483,6 +555,16 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a hidden_states = attn.to_out[0](hidden_states) # dropout hidden_states = attn.to_out[1](hidden_states) + + if reshaped_input: + hidden_states = hidden_states.transpose(1, 2) + hidden_states = hidden_states.reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + return hidden_states @@ -492,9 +574,29 @@ def __init__(self): raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None): - batch_size, sequence_length, _ = ( - hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - ) + if attn.residual_connection: + residual = hidden_states + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states) + + batch_size = hidden_states.shape[0] + + if hidden_states.ndim == 4: + reshaped_input = True + + _, channel, height, width = hidden_states.shape + + hidden_states = hidden_states.view(batch_size, channel, height * width) + hidden_states = hidden_states.transpose(1, 2) + else: + reshaped_input = False + + if encoder_hidden_states is None: + sequence_length = hidden_states.shape[1] + else: + sequence_length = encoder_hidden_states.shape[1] + inner_dim = hidden_states.shape[-1] if attention_mask is not None: @@ -531,6 +633,16 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a hidden_states = attn.to_out[0](hidden_states) # dropout hidden_states = attn.to_out[1](hidden_states) + + if reshaped_input: + hidden_states = hidden_states.transpose(1, 2) + hidden_states = hidden_states.reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + return hidden_states @@ -549,9 +661,29 @@ def __init__(self, hidden_size, cross_attention_dim, rank=4, attention_op: Optio self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank) def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0): - batch_size, sequence_length, _ = ( - hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - ) + if attn.residual_connection: + residual = hidden_states + + if attn.group_norm: + hidden_states = attn.group_norm(hidden_states) + + batch_size = hidden_states.shape[0] + + if hidden_states.ndim == 4: + reshaped_input = True + + _, channel, height, width = hidden_states.shape + + hidden_states = hidden_states.view(batch_size, channel, height * width) + hidden_states = hidden_states.transpose(1, 2) + else: + reshaped_input = False + + if encoder_hidden_states is None: + sequence_length = hidden_states.shape[1] + else: + sequence_length = encoder_hidden_states.shape[1] + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states) @@ -575,6 +707,15 @@ def 
__call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a # dropout hidden_states = attn.to_out[1](hidden_states) + if reshaped_input: + hidden_states = hidden_states.transpose(1, 2) + hidden_states = hidden_states.reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + return hidden_states @@ -583,9 +724,29 @@ def __init__(self, slice_size): self.slice_size = slice_size def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None): - batch_size, sequence_length, _ = ( - hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - ) + if attn.residual_connection: + residual = hidden_states + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states) + + batch_size = hidden_states.shape[0] + + if hidden_states.ndim == 4: + reshaped_input = True + + _, channel, height, width = hidden_states.shape + + hidden_states = hidden_states.view(batch_size, channel, height * width) + hidden_states = hidden_states.transpose(1, 2) + else: + reshaped_input = False + + if encoder_hidden_states is None: + sequence_length = hidden_states.shape[1] + else: + sequence_length = encoder_hidden_states.shape[1] + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) query = attn.to_q(hidden_states) @@ -628,6 +789,15 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a # dropout hidden_states = attn.to_out[1](hidden_states) + if reshaped_input: + hidden_states = hidden_states.transpose(1, 2) + hidden_states = hidden_states.reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + return hidden_states @@ -695,94 +865,6 @@ def __call__(self, attn: "Attention", hidden_states, encoder_hidden_states=None, return hidden_states -class SpatialAttnProcessor: - def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None): - if attention_mask is not None: - raise ValueError(f"{self.__class__.__name__} does not support `attention_mask`") - - if encoder_hidden_states is not None: - raise ValueError(f"{self.__class__.__name__} does not support `encoder_hidden_states`") - - residual = hidden_states - batch, channel, height, width = hidden_states.shape - - # norm - hidden_states = attn.group_norm(hidden_states) - - hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2) - - # proj to q, k, v - query_proj = attn.to_q(hidden_states) - key_proj = attn.to_k(hidden_states) - value_proj = attn.to_v(hidden_states) - - query_proj = attn.head_to_batch_dim(query_proj) - key_proj = attn.head_to_batch_dim(key_proj) - value_proj = attn.head_to_batch_dim(value_proj) - - attention_probs = attn.get_attention_scores(query_proj, key_proj) - hidden_states = torch.bmm(attention_probs, value_proj) - - # reshape hidden_states - hidden_states = attn.batch_to_head_dim(hidden_states) - - # compute next hidden_states - hidden_states = attn.to_out[0](hidden_states) - - hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width) - - # res connect and rescale - hidden_states = (hidden_states + residual) / attn.rescale_output_factor - return hidden_states - - -class XFormersSpatialAttnProcessor: - def __init__(self, attention_op: 
Optional[Callable] = None): - self.attention_op = attention_op - - def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None): - if attention_mask is not None: - raise ValueError(f"{self.__class__.__name__} does not support `attention_mask`") - - if encoder_hidden_states is not None: - raise ValueError(f"{self.__class__.__name__} does not support `encoder_hidden_states`") - - residual = hidden_states - batch, channel, height, width = hidden_states.shape - - # norm - hidden_states = attn.group_norm(hidden_states) - - hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2) - - # proj to q, k, v - query_proj = attn.to_q(hidden_states) - key_proj = attn.to_k(hidden_states) - value_proj = attn.to_v(hidden_states) - - query_proj = attn.head_to_batch_dim(query_proj) - key_proj = attn.head_to_batch_dim(key_proj) - value_proj = attn.head_to_batch_dim(value_proj) - - # Memory efficient attention - hidden_states = xformers.ops.memory_efficient_attention( - query_proj, key_proj, value_proj, attn_bias=None, op=self.attention_op, scale=attn.scale - ) - hidden_states = hidden_states.to(query_proj.dtype) - - # reshape hidden_states - hidden_states = attn.batch_to_head_dim(hidden_states) - - # compute next hidden_states - hidden_states = attn.to_out[0](hidden_states) - - hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width) - - # res connect and rescale - hidden_states = (hidden_states + residual) / attn.rescale_output_factor - return hidden_states - - AttentionProcessor = Union[ AttnProcessor, XFormersAttnProcessor, @@ -791,6 +873,4 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a SlicedAttnAddedKVProcessor, LoRAAttnProcessor, LoRAXFormersAttnProcessor, - SpatialAttnProcessor, - XFormersSpatialAttnProcessor, ] diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index 46ae83e0d4a6..c0fcc8174374 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -18,7 +18,7 @@ from torch import nn from .attention import AdaGroupNorm, AttentionBlock -from .attention_processor import Attention, AttnAddedKVProcessor, SpatialAttnProcessor +from .attention_processor import Attention, AttnAddedKVProcessor from .dual_transformer_2d import DualTransformer2DModel from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, KDownsample2D, KUpsample2D, ResnetBlock2D, Upsample2D from .transformer_2d import Transformer2DModel @@ -432,8 +432,8 @@ def __init__( bias=True, upcast_softmax=True, norm_num_groups=resnet_groups, - processor=SpatialAttnProcessor(), eps=resnet_eps, + residual_connection=True, ) else: raise ValueError(f"Unknown attention_block_type: {attention_block_type}") @@ -719,8 +719,8 @@ def __init__( bias=True, upcast_softmax=True, norm_num_groups=resnet_groups, - processor=SpatialAttnProcessor(), eps=resnet_eps, + residual_connection=True, ) else: raise ValueError(f"Unknown attention_block_type: {attention_block_type}") @@ -1081,9 +1081,9 @@ def __init__( bias=True, upcast_softmax=True, norm_num_groups=resnet_groups, - processor=SpatialAttnProcessor(), eps=resnet_eps, rescale_output_factor=output_scale_factor, + residual_connection=True, ) else: raise ValueError(f"Unknown attention_block_type: {attention_block_type}") @@ -1170,9 +1170,9 @@ def __init__( bias=True, upcast_softmax=True, norm_num_groups=32, - processor=SpatialAttnProcessor(), eps=resnet_eps, 
rescale_output_factor=output_scale_factor, + residual_connection=True, ) else: raise ValueError(f"Unknown attention_block_type: {attention_block_type}") @@ -1742,9 +1742,9 @@ def __init__( bias=True, upcast_softmax=True, norm_num_groups=resnet_groups, - processor=SpatialAttnProcessor(), eps=resnet_eps, rescale_output_factor=output_scale_factor, + residual_connection=True, ) else: raise ValueError(f"Unknown attention_block_type: {attention_block_type}") @@ -2092,9 +2092,9 @@ def __init__( bias=True, upcast_softmax=True, norm_num_groups=resnet_groups, - processor=SpatialAttnProcessor(), eps=resnet_eps, rescale_output_factor=output_scale_factor, + residual_connection=True, ) else: raise ValueError(f"Unknown attention_block_type: {attention_block_type}") @@ -2180,9 +2180,9 @@ def __init__( bias=True, upcast_softmax=True, norm_num_groups=32, - processor=SpatialAttnProcessor(), eps=resnet_eps, rescale_output_factor=output_scale_factor, + residual_connection=True, ) else: raise ValueError(f"Unknown attention_block_type: {attention_block_type}") From fdffad430195a3e28d2f2289a9bda80038b73a16 Mon Sep 17 00:00:00 2001 From: William Berman Date: Tue, 21 Mar 2023 10:47:44 -0700 Subject: [PATCH 3/4] fixes re: @patrickvonplaten --- src/diffusers/models/attention_processor.py | 97 +++++++-------------- 1 file changed, 31 insertions(+), 66 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 232195ff0d9d..474760e72646 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -305,28 +305,22 @@ def __call__( encoder_hidden_states=None, attention_mask=None, ): - if attn.residual_connection: - residual = hidden_states + residual = hidden_states if attn.group_norm is not None: hidden_states = attn.group_norm(hidden_states) batch_size = hidden_states.shape[0] - if hidden_states.ndim == 4: - reshaped_input = True + input_ndim = hidden_states.ndim + if input_ndim == 4: _, channel, height, width = hidden_states.shape hidden_states = hidden_states.view(batch_size, channel, height * width) hidden_states = hidden_states.transpose(1, 2) - else: - reshaped_input = False - if encoder_hidden_states is None: - sequence_length = hidden_states.shape[1] - else: - sequence_length = encoder_hidden_states.shape[1] + sequence_length = hidden_states.shape[1] if encoder_hidden_states is None else encoder_hidden_states.shape[1] attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) @@ -353,7 +347,7 @@ def __call__( # dropout hidden_states = attn.to_out[1](hidden_states) - if reshaped_input: + if input_ndim == 4: hidden_states = hidden_states.transpose(1, 2) hidden_states = hidden_states.reshape(batch_size, channel, height, width) @@ -402,28 +396,22 @@ def __init__(self, hidden_size, cross_attention_dim=None, rank=4): self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank) def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0): - if attn.residual_connection: - residual = hidden_states + residual = hidden_states if attn.group_norm is not None: hidden_states = attn.group_norm(hidden_states) batch_size = hidden_states.shape[0] - if hidden_states.ndim == 4: - reshaped_input = True + input_ndim = hidden_states.ndim + if input_ndim == 4: _, channel, height, width = hidden_states.shape hidden_states = hidden_states.view(batch_size, channel, height * width) hidden_states = hidden_states.transpose(1, 2) - else: - 
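The `XFormersAttnProcessor` hunk that begins above threads the same spatial bookkeeping around xformers' memory-efficient kernel. For context, here is a hedged sketch of the core call it wraps, assuming xformers is installed; `run_xformers_attention` is an illustrative name, and the tensors use the (batch*heads, seq_len, head_dim) layout produced by `attn.head_to_batch_dim`.

```python
import xformers.ops

def run_xformers_attention(query, key, value, attention_mask=None, attention_op=None, scale=None):
    # query/key/value: (batch*heads, seq_len, head_dim); attention_mask, if given,
    # is passed through as the additive bias xformers expects (attn_bias).
    hidden_states = xformers.ops.memory_efficient_attention(
        query, key, value, attn_bias=attention_mask, op=attention_op, scale=scale
    )
    # cast back to the query dtype, matching what the processors above do
    return hidden_states.to(query.dtype)
```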
reshaped_input = False - if encoder_hidden_states is None: - sequence_length = hidden_states.shape[1] - else: - sequence_length = encoder_hidden_states.shape[1] + sequence_length = hidden_states.shape[1] if encoder_hidden_states is None else encoder_hidden_states.shape[1] attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) @@ -447,7 +435,7 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a # dropout hidden_states = attn.to_out[1](hidden_states) - if reshaped_input: + if input_ndim == 4: hidden_states = hidden_states.transpose(1, 2) hidden_states = hidden_states.reshape(batch_size, channel, height, width) @@ -506,28 +494,23 @@ def __init__(self, attention_op: Optional[Callable] = None): self.attention_op = attention_op def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None): - if attn.residual_connection: - residual = hidden_states + residual = hidden_states if attn.group_norm is not None: hidden_states = attn.group_norm(hidden_states) batch_size = hidden_states.shape[0] - if hidden_states.ndim == 4: - reshaped_input = True + input_ndim = hidden_states.ndim + + if input_ndim == 4: _, channel, height, width = hidden_states.shape hidden_states = hidden_states.view(batch_size, channel, height * width) hidden_states = hidden_states.transpose(1, 2) - else: - reshaped_input = False - if encoder_hidden_states is None: - sequence_length = hidden_states.shape[1] - else: - sequence_length = encoder_hidden_states.shape[1] + sequence_length = hidden_states.shape[1] if encoder_hidden_states is None else encoder_hidden_states.shape[1] attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) @@ -556,7 +539,7 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a # dropout hidden_states = attn.to_out[1](hidden_states) - if reshaped_input: + if input_ndim == 4: hidden_states = hidden_states.transpose(1, 2) hidden_states = hidden_states.reshape(batch_size, channel, height, width) @@ -574,28 +557,22 @@ def __init__(self): raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None): - if attn.residual_connection: - residual = hidden_states + residual = hidden_states if attn.group_norm is not None: hidden_states = attn.group_norm(hidden_states) batch_size = hidden_states.shape[0] - if hidden_states.ndim == 4: - reshaped_input = True + input_ndim = hidden_states.ndim + if input_ndim == 4: _, channel, height, width = hidden_states.shape hidden_states = hidden_states.view(batch_size, channel, height * width) hidden_states = hidden_states.transpose(1, 2) - else: - reshaped_input = False - if encoder_hidden_states is None: - sequence_length = hidden_states.shape[1] - else: - sequence_length = encoder_hidden_states.shape[1] + sequence_length = hidden_states.shape[1] if encoder_hidden_states is None else encoder_hidden_states.shape[1] inner_dim = hidden_states.shape[-1] @@ -634,7 +611,7 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a # dropout hidden_states = attn.to_out[1](hidden_states) - if reshaped_input: + if input_ndim == 4: hidden_states = hidden_states.transpose(1, 2) hidden_states = hidden_states.reshape(batch_size, channel, height, width) @@ -661,28 +638,22 @@ def __init__(self, hidden_size, cross_attention_dim, rank=4, attention_op: Optio self.to_out_lora = 
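`AttnProcessor2_0` in the hunk above applies the same pattern around PyTorch 2.0's fused kernel. Below is a minimal sketch of that core, assuming PyTorch >= 2.0; `sdpa_attention` is an illustrative name, and the head reshaping uses the (batch, heads, seq_len, head_dim) layout that `F.scaled_dot_product_attention` expects.

```python
import torch
import torch.nn.functional as F

def sdpa_attention(query, key, value, heads, attention_mask=None):
    # query/key/value: (batch, seq_len, inner_dim) after the to_q/to_k/to_v projections;
    # attention_mask, if given, is assumed already broadcastable to (batch, heads, q_len, k_len)
    batch_size, _, inner_dim = query.shape
    head_dim = inner_dim // heads

    # (batch, seq_len, inner_dim) -> (batch, heads, seq_len, head_dim)
    query, key, value = (
        t.view(batch_size, -1, heads, head_dim).transpose(1, 2) for t in (query, key, value)
    )

    hidden_states = F.scaled_dot_product_attention(
        query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
    )

    # back to (batch, seq_len, inner_dim) for the output projection
    return hidden_states.transpose(1, 2).reshape(batch_size, -1, heads * head_dim)
```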
LoRALinearLayer(hidden_size, hidden_size, rank) def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0): - if attn.residual_connection: - residual = hidden_states + residual = hidden_states if attn.group_norm: hidden_states = attn.group_norm(hidden_states) batch_size = hidden_states.shape[0] - if hidden_states.ndim == 4: - reshaped_input = True + input_ndim = hidden_states.ndim + if input_ndim == 4: _, channel, height, width = hidden_states.shape hidden_states = hidden_states.view(batch_size, channel, height * width) hidden_states = hidden_states.transpose(1, 2) - else: - reshaped_input = False - if encoder_hidden_states is None: - sequence_length = hidden_states.shape[1] - else: - sequence_length = encoder_hidden_states.shape[1] + sequence_length = hidden_states.shape[1] if encoder_hidden_states is None else encoder_hidden_states.shape[1] attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) @@ -707,7 +678,7 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a # dropout hidden_states = attn.to_out[1](hidden_states) - if reshaped_input: + if input_ndim == 4: hidden_states = hidden_states.transpose(1, 2) hidden_states = hidden_states.reshape(batch_size, channel, height, width) @@ -724,28 +695,22 @@ def __init__(self, slice_size): self.slice_size = slice_size def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None): - if attn.residual_connection: - residual = hidden_states + residual = hidden_states if attn.group_norm is not None: hidden_states = attn.group_norm(hidden_states) batch_size = hidden_states.shape[0] - if hidden_states.ndim == 4: - reshaped_input = True + input_ndim = hidden_states.ndim + if input_ndim == 4: _, channel, height, width = hidden_states.shape hidden_states = hidden_states.view(batch_size, channel, height * width) hidden_states = hidden_states.transpose(1, 2) - else: - reshaped_input = False - if encoder_hidden_states is None: - sequence_length = hidden_states.shape[1] - else: - sequence_length = encoder_hidden_states.shape[1] + sequence_length = hidden_states.shape[1] if encoder_hidden_states is None else encoder_hidden_states.shape[1] attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) @@ -789,7 +754,7 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a # dropout hidden_states = attn.to_out[1](hidden_states) - if reshaped_input: + if input_ndim == 4: hidden_states = hidden_states.transpose(1, 2) hidden_states = hidden_states.reshape(batch_size, channel, height, width) From 0bd5271ba832227a3170e093a8dd801742d2714e Mon Sep 17 00:00:00 2001 From: William Berman Date: Tue, 21 Mar 2023 10:50:48 -0700 Subject: [PATCH 4/4] style --- src/diffusers/models/attention_processor.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 474760e72646..b2134ee47c89 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -504,7 +504,6 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a input_ndim = hidden_states.ndim if input_ndim == 4: - _, channel, height, width = hidden_states.shape hidden_states = hidden_states.view(batch_size, channel, height * width)
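Finally, `SlicedAttnProcessor` keeps the same entry and exit logic but computes attention in slices along the batch*heads dimension, so only a bounded number of score matrices are alive at once. A hedged sketch of that idea; `sliced_attention` is illustrative, assumes self-attention shapes, and omits the attention mask for brevity.

```python
import torch

def sliced_attention(query, key, value, slice_size, scale):
    # query/key/value: (batch*heads, seq_len, head_dim), as produced by
    # attn.head_to_batch_dim; self-attention shapes assumed (q/k/v seq_len equal)
    hidden_states = torch.zeros_like(query)

    for start in range(0, query.shape[0], slice_size):
        end = start + slice_size
        # only a (slice_size, seq_len, seq_len) block of scores is materialized here
        scores = torch.bmm(query[start:end], key[start:end].transpose(1, 2)) * scale
        probs = scores.softmax(dim=-1)
        hidden_states[start:end] = torch.bmm(probs, value[start:end])

    return hidden_states
```

With `slice_size` equal to the full batch*heads dimension this reduces to ordinary attention; smaller values trade throughput for lower peak memory.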