diff --git a/examples/research_projects/onnxruntime/unconditional_image_generation/train_unconditional.py b/examples/research_projects/onnxruntime/unconditional_image_generation/train_unconditional.py index 1b38036d82c0..edb315429474 100644 --- a/examples/research_projects/onnxruntime/unconditional_image_generation/train_unconditional.py +++ b/examples/research_projects/onnxruntime/unconditional_image_generation/train_unconditional.py @@ -350,6 +350,7 @@ def main(args): "UpBlock2D", "UpBlock2D", ), + attention_block_type="Attention", ) # Create EMA for the model. diff --git a/examples/unconditional_image_generation/train_unconditional.py b/examples/unconditional_image_generation/train_unconditional.py index 3b784eda6a34..3a399d4926f8 100644 --- a/examples/unconditional_image_generation/train_unconditional.py +++ b/examples/unconditional_image_generation/train_unconditional.py @@ -397,10 +397,12 @@ def load_model_hook(models, input_dir): "UpBlock2D", "UpBlock2D", ), + attention_block_type="Attention", ) else: config = UNet2DModel.load_config(args.model_config_name_or_path) model = UNet2DModel.from_config(config) + model._convert_deprecated_attention_blocks() # Create EMA for the model. if args.use_ema: diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index f271e00f8639..07baa294b1ac 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -11,27 +11,23 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import math from typing import Callable, Optional import torch import torch.nn.functional as F from torch import nn -from ..utils.import_utils import is_xformers_available +from ..utils import deprecate from .attention_processor import Attention from .embeddings import CombinedTimestepLabelEmbeddings -if is_xformers_available(): - import xformers - import xformers.ops -else: - xformers = None - - class AttentionBlock(nn.Module): """ + This class is deprecated. Its forward method will throw an error. On model load, we convert all instances of + `AttentionBlock` to `diffusers.models.attention_processor.Attention`, see + `ModelMixin#_convert_deprecated_attention_blocks`. + An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted to the N-d case. https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. @@ -46,8 +42,6 @@ class AttentionBlock(nn.Module): eps (`float`, *optional*, defaults to 1e-5): The epsilon value to use for group norm. """ - # IMPORTANT;TODO(Patrick, William) - this class will be deprecated soon. Do not use it anymore - def __init__( self, channels: int, @@ -57,6 +51,16 @@ def __init__( eps: float = 1e-5, ): super().__init__() + + deprecation_message = ( + "`AttentionBlock` has been deprecated and will be replaced with `diffusers.models.attention_processor.Attention`." + " The `DiffusionPipeline` that is loading this block will automatically convert it to `diffusers.models.attention_processor.Attention`." + " Please call `DiffusionPipeline#save_pretrained` and re-upload the pipeline to the hub." + " If you are only loading a model instead of a whole pipeline, the same instructions apply with `ModelMixin#save_pretrained`." 
+ ) + + deprecate("AttentionBlock", "0.18.0", deprecation_message, standard_warn=True) + self.channels = channels self.num_heads = channels // num_head_channels if num_head_channels is not None else 1 @@ -71,107 +75,54 @@ def __init__( self.rescale_output_factor = rescale_output_factor self.proj_attn = nn.Linear(channels, channels, bias=True) - self._use_memory_efficient_attention_xformers = False - self._attention_op = None - - def reshape_heads_to_batch_dim(self, tensor): - batch_size, seq_len, dim = tensor.shape - head_size = self.num_heads - tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) - tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size) - return tensor - - def reshape_batch_dim_to_heads(self, tensor): - batch_size, seq_len, dim = tensor.shape - head_size = self.num_heads - tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) - tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size) - return tensor - def set_use_memory_efficient_attention_xformers( self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None ): - if use_memory_efficient_attention_xformers: - if not is_xformers_available(): - raise ModuleNotFoundError( - ( - "Refer to https://github.com/facebookresearch/xformers for more information on how to install" - " xformers" - ), - name="xformers", - ) - elif not torch.cuda.is_available(): - raise ValueError( - "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is" - " only available for GPU " - ) - else: - try: - # Make sure we can run the memory efficient attention - _ = xformers.ops.memory_efficient_attention( - torch.randn((1, 2, 40), device="cuda"), - torch.randn((1, 2, 40), device="cuda"), - torch.randn((1, 2, 40), device="cuda"), - ) - except Exception as e: - raise e - self._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers - self._attention_op = attention_op + raise ValueError( + "`AttentionBlock` should have been converted after load to `diffusers.models.attention_processor.Attention`" + ) def forward(self, hidden_states): - residual = hidden_states - batch, channel, height, width = hidden_states.shape - - # norm - hidden_states = self.group_norm(hidden_states) - - hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2) - - # proj to q, k, v - query_proj = self.query(hidden_states) - key_proj = self.key(hidden_states) - value_proj = self.value(hidden_states) - - scale = 1 / math.sqrt(self.channels / self.num_heads) - - query_proj = self.reshape_heads_to_batch_dim(query_proj) - key_proj = self.reshape_heads_to_batch_dim(key_proj) - value_proj = self.reshape_heads_to_batch_dim(value_proj) + raise ValueError( + "`AttentionBlock` should have been converted after load to `diffusers.models.attention_processor.Attention`" + ) - if self._use_memory_efficient_attention_xformers: - # Memory efficient attention - hidden_states = xformers.ops.memory_efficient_attention( - query_proj, key_proj, value_proj, attn_bias=None, op=self._attention_op - ) - hidden_states = hidden_states.to(query_proj.dtype) + def _as_attention_processor_attention(self): + if self.num_head_size is None: + # When `self.num_head_size` is None, there is a single attention head + # of all the channels + dim_head = self.channels else: - attention_scores = torch.baddbmm( - torch.empty( - query_proj.shape[0], - query_proj.shape[1], - 
key_proj.shape[1], - dtype=query_proj.dtype, - device=query_proj.device, - ), - query_proj, - key_proj.transpose(-1, -2), - beta=0, - alpha=scale, - ) - attention_probs = torch.softmax(attention_scores.float(), dim=-1).type(attention_scores.dtype) - hidden_states = torch.bmm(attention_probs, value_proj) + dim_head = self.num_head_size + + # This will allocate some additional memory but as this is only done once during model load, + # it should be ok. + attn = Attention( + self.channels, + heads=self.num_heads, + dim_head=dim_head, + bias=True, + upcast_softmax=True, + norm_num_groups=self.group_norm.num_groups, + eps=self.group_norm.eps, + rescale_output_factor=self.rescale_output_factor, + residual_connection=True, + ) - # reshape hidden_states - hidden_states = self.reshape_batch_dim_to_heads(hidden_states) + param = next(self.parameters()) - # compute next hidden_states - hidden_states = self.proj_attn(hidden_states) + device = param.device + dtype = param.dtype - hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width) + attn.to(device=device, dtype=dtype) - # res connect and rescale - hidden_states = (hidden_states + residual) / self.rescale_output_factor - return hidden_states + attn.group_norm.load_state_dict(self.group_norm.state_dict()) + attn.to_q.load_state_dict(self.query.state_dict()) + attn.to_k.load_state_dict(self.key.state_dict()) + attn.to_v.load_state_dict(self.value.state_dict()) + attn.to_out[0].load_state_dict(self.proj_attn.state_dict()) + + return attn class BasicTransformerBlock(nn.Module): diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 30026cd89ff9..b2134ee47c89 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -62,6 +62,9 @@ def __init__( out_bias: bool = True, scale_qk: bool = True, processor: Optional["AttnProcessor"] = None, + eps: float = 1e-5, + rescale_output_factor: float = 1.0, + residual_connection: bool = False, ): super().__init__() inner_dim = dim_head * heads @@ -69,6 +72,8 @@ def __init__( self.upcast_attention = upcast_attention self.upcast_softmax = upcast_softmax self.cross_attention_norm = cross_attention_norm + self.rescale_output_factor = rescale_output_factor + self.residual_connection = residual_connection self.scale = dim_head**-0.5 if scale_qk else 1.0 @@ -81,7 +86,7 @@ def __init__( self.added_kv_proj_dim = added_kv_proj_dim if norm_num_groups is not None: - self.group_norm = nn.GroupNorm(num_channels=inner_dim, num_groups=norm_num_groups, eps=1e-5, affine=True) + self.group_norm = nn.GroupNorm(num_channels=inner_dim, num_groups=norm_num_groups, eps=eps, affine=True) else: self.group_norm = None @@ -300,10 +305,25 @@ def __call__( encoder_hidden_states=None, attention_mask=None, ): - batch_size, sequence_length, _ = ( - hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - ) + residual = hidden_states + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states) + + batch_size = hidden_states.shape[0] + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + _, channel, height, width = hidden_states.shape + + hidden_states = hidden_states.view(batch_size, channel, height * width) + hidden_states = hidden_states.transpose(1, 2) + + sequence_length = hidden_states.shape[1] if encoder_hidden_states is None else encoder_hidden_states.shape[1] + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) 
+ query = attn.to_q(hidden_states) if encoder_hidden_states is None: @@ -327,6 +347,15 @@ def __call__( # dropout hidden_states = attn.to_out[1](hidden_states) + if input_ndim == 4: + hidden_states = hidden_states.transpose(1, 2) + hidden_states = hidden_states.reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + return hidden_states @@ -367,9 +396,23 @@ def __init__(self, hidden_size, cross_attention_dim=None, rank=4): self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank) def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0): - batch_size, sequence_length, _ = ( - hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - ) + residual = hidden_states + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states) + + batch_size = hidden_states.shape[0] + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + _, channel, height, width = hidden_states.shape + + hidden_states = hidden_states.view(batch_size, channel, height * width) + hidden_states = hidden_states.transpose(1, 2) + + sequence_length = hidden_states.shape[1] if encoder_hidden_states is None else encoder_hidden_states.shape[1] + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states) @@ -392,6 +435,15 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a # dropout hidden_states = attn.to_out[1](hidden_states) + if input_ndim == 4: + hidden_states = hidden_states.transpose(1, 2) + hidden_states = hidden_states.reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + return hidden_states @@ -442,9 +494,22 @@ def __init__(self, attention_op: Optional[Callable] = None): self.attention_op = attention_op def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None): - batch_size, sequence_length, _ = ( - hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - ) + residual = hidden_states + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states) + + batch_size = hidden_states.shape[0] + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + _, channel, height, width = hidden_states.shape + + hidden_states = hidden_states.view(batch_size, channel, height * width) + hidden_states = hidden_states.transpose(1, 2) + + sequence_length = hidden_states.shape[1] if encoder_hidden_states is None else encoder_hidden_states.shape[1] attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) @@ -472,6 +537,16 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a hidden_states = attn.to_out[0](hidden_states) # dropout hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(1, 2) + hidden_states = hidden_states.reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + return hidden_states @@ -481,9 +556,23 @@ def __init__(self): raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, 
please upgrade PyTorch to 2.0.") def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None): - batch_size, sequence_length, _ = ( - hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - ) + residual = hidden_states + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states) + + batch_size = hidden_states.shape[0] + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + _, channel, height, width = hidden_states.shape + + hidden_states = hidden_states.view(batch_size, channel, height * width) + hidden_states = hidden_states.transpose(1, 2) + + sequence_length = hidden_states.shape[1] if encoder_hidden_states is None else encoder_hidden_states.shape[1] + inner_dim = hidden_states.shape[-1] if attention_mask is not None: @@ -520,6 +609,16 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a hidden_states = attn.to_out[0](hidden_states) # dropout hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(1, 2) + hidden_states = hidden_states.reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + return hidden_states @@ -538,9 +637,23 @@ def __init__(self, hidden_size, cross_attention_dim, rank=4, attention_op: Optio self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank) def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0): - batch_size, sequence_length, _ = ( - hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - ) + residual = hidden_states + + if attn.group_norm: + hidden_states = attn.group_norm(hidden_states) + + batch_size = hidden_states.shape[0] + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + _, channel, height, width = hidden_states.shape + + hidden_states = hidden_states.view(batch_size, channel, height * width) + hidden_states = hidden_states.transpose(1, 2) + + sequence_length = hidden_states.shape[1] if encoder_hidden_states is None else encoder_hidden_states.shape[1] + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states) @@ -564,6 +677,15 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a # dropout hidden_states = attn.to_out[1](hidden_states) + if input_ndim == 4: + hidden_states = hidden_states.transpose(1, 2) + hidden_states = hidden_states.reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + return hidden_states @@ -572,9 +694,23 @@ def __init__(self, slice_size): self.slice_size = slice_size def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None): - batch_size, sequence_length, _ = ( - hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - ) + residual = hidden_states + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states) + + batch_size = hidden_states.shape[0] + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + _, channel, height, width = hidden_states.shape + + hidden_states = hidden_states.view(batch_size, channel, height * width) + hidden_states = 
hidden_states.transpose(1, 2) + + sequence_length = hidden_states.shape[1] if encoder_hidden_states is None else encoder_hidden_states.shape[1] + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) query = attn.to_q(hidden_states) @@ -617,6 +753,15 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a # dropout hidden_states = attn.to_out[1](hidden_states) + if input_ndim == 4: + hidden_states = hidden_states.transpose(1, 2) + hidden_states = hidden_states.reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + return hidden_states diff --git a/src/diffusers/models/autoencoder_kl.py b/src/diffusers/models/autoencoder_kl.py index 8f65c2357cac..cd841a98f41c 100644 --- a/src/diffusers/models/autoencoder_kl.py +++ b/src/diffusers/models/autoencoder_kl.py @@ -81,6 +81,7 @@ def __init__( norm_num_groups: int = 32, sample_size: int = 32, scaling_factor: float = 0.18215, + attention_block_type: str = "AttentionBlock", ): super().__init__() @@ -94,6 +95,7 @@ def __init__( act_fn=act_fn, norm_num_groups=norm_num_groups, double_z=True, + attention_block_type=attention_block_type, ) # pass init params to Decoder @@ -105,6 +107,7 @@ def __init__( layers_per_block=layers_per_block, norm_num_groups=norm_num_groups, act_fn=act_fn, + attention_block_type=attention_block_type, ) self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1) diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index e51b40ce4509..e43d4c0bbf58 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -25,7 +25,7 @@ from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError from packaging import version from requests import HTTPError -from torch import Tensor, device +from torch import Tensor, device, nn from .. import __version__ from ..utils import ( @@ -629,6 +629,10 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P model.register_to_config(_name_or_path=pretrained_model_name_or_path) + # The model has had its weights loaded. If the deprecated attention block weights + # were in the old format, it's now safe to convert them. + model._convert_deprecated_attention_blocks() + # Set model in evaluation mode to deactivate DropOut modules by default model.eval() if output_loading_info: @@ -783,6 +787,72 @@ def num_parameters(self, only_trainable: bool = False, exclude_embeddings: bool else: return sum(p.numel() for p in self.parameters() if p.requires_grad or not only_trainable) + def _convert_deprecated_attention_blocks(self): + """ + `diffusers.models.attention.AttentionBlock` is deprecated and must be converted to + `diffusers.models.attention_processor.Attention`. 
+ + See https://github.com/huggingface/diffusers/issues/1880 and https://github.com/huggingface/diffusers/pull/2697 + for more details. + """ + models_with_deprecated_attention = ["UNet2DModel", "VQModel", "AutoencoderKL"] + + blocks_with_deprecated_attention = [ + "AttnDownBlock2D", + "AttnSkipDownBlock2D", + "AttnDownEncoderBlock2D", + "AttnUpBlock2D", + "AttnSkipUpBlock2D", + "AttnUpDecoderBlock2D", + "UNetMidBlock2D", + ] + + if self.__class__.__name__ not in models_with_deprecated_attention: + return + + if self.config.attention_block_type == "Attention": + # Model has already been converted + return + + self.register_to_config(attention_block_type="Attention") + + def _convert_deprecated_attention_blocks_in_module(module): + if module.__class__.__name__ not in blocks_with_deprecated_attention: + return + + attentions = nn.ModuleList() + + for attention in module.attentions: + if attention is not None: + attention = attention._as_attention_processor_attention() + + attentions.append(attention) + + module.attentions = attentions + + if self.__class__.__name__ == "UNet2DModel": + for down_block in self.down_blocks: + _convert_deprecated_attention_blocks_in_module(down_block) + + _convert_deprecated_attention_blocks_in_module(self.mid_block) + + for up_block in self.up_blocks: + _convert_deprecated_attention_blocks_in_module(up_block) + + elif self.__class__.__name__ in ["VQModel", "AutoencoderKL"]: + for down_block in self.encoder.down_blocks: + _convert_deprecated_attention_blocks_in_module(down_block) + + _convert_deprecated_attention_blocks_in_module(self.encoder.mid_block) + + _convert_deprecated_attention_blocks_in_module(self.decoder.mid_block) + + for up_block in self.decoder.up_blocks: + _convert_deprecated_attention_blocks_in_module(up_block) + + else: + assert False + def _get_model_file( pretrained_model_name_or_path, diff --git a/src/diffusers/models/unet_2d.py b/src/diffusers/models/unet_2d.py index 2df6e60d88c9..ef543ef3d586 100644 --- a/src/diffusers/models/unet_2d.py +++ b/src/diffusers/models/unet_2d.py @@ -102,6 +102,7 @@ def __init__( add_attention: bool = True, class_embed_type: Optional[str] = None, num_class_embeds: Optional[int] = None, + attention_block_type: str = "AttentionBlock", ): super().__init__() @@ -166,6 +167,7 @@ def __init__( attn_num_head_channels=attention_head_dim, downsample_padding=downsample_padding, resnet_time_scale_shift=resnet_time_scale_shift, + attention_block_type=attention_block_type, ) self.down_blocks.append(down_block) @@ -180,6 +182,7 @@ def __init__( attn_num_head_channels=attention_head_dim, resnet_groups=norm_num_groups, add_attention=add_attention, + attention_block_type=attention_block_type, ) # up @@ -205,6 +208,7 @@ def __init__( resnet_groups=norm_num_groups, attn_num_head_channels=attention_head_dim, resnet_time_scale_shift=resnet_time_scale_shift, + attention_block_type=attention_block_type, ) self.up_blocks.append(up_block) prev_output_channel = output_channel diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index 3070351279b8..c0fcc8174374 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -42,6 +42,7 @@ def get_down_block( only_cross_attention=False, upcast_attention=False, resnet_time_scale_shift="default", + attention_block_type="AttentionBlock", ): down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type if down_block_type == "DownBlock2D": @@ -82,6 +83,7 @@ def 
get_down_block( downsample_padding=downsample_padding, attn_num_head_channels=attn_num_head_channels, resnet_time_scale_shift=resnet_time_scale_shift, + attention_block_type=attention_block_type, ) elif down_block_type == "CrossAttnDownBlock2D": if cross_attention_dim is None: @@ -144,6 +146,7 @@ def get_down_block( downsample_padding=downsample_padding, attn_num_head_channels=attn_num_head_channels, resnet_time_scale_shift=resnet_time_scale_shift, + attention_block_type=attention_block_type, ) elif down_block_type == "DownEncoderBlock2D": return DownEncoderBlock2D( @@ -169,6 +172,7 @@ def get_down_block( downsample_padding=downsample_padding, attn_num_head_channels=attn_num_head_channels, resnet_time_scale_shift=resnet_time_scale_shift, + attention_block_type=attention_block_type, ) elif down_block_type == "KDownBlock2D": return KDownBlock2D( @@ -214,6 +218,7 @@ def get_up_block( only_cross_attention=False, upcast_attention=False, resnet_time_scale_shift="default", + attention_block_type="AttentionBlock", ): up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type if up_block_type == "UpBlock2D": @@ -293,6 +298,7 @@ def get_up_block( resnet_groups=resnet_groups, attn_num_head_channels=attn_num_head_channels, resnet_time_scale_shift=resnet_time_scale_shift, + attention_block_type=attention_block_type, ) elif up_block_type == "SkipUpBlock2D": return SkipUpBlock2D( @@ -318,6 +324,7 @@ def get_up_block( resnet_act_fn=resnet_act_fn, attn_num_head_channels=attn_num_head_channels, resnet_time_scale_shift=resnet_time_scale_shift, + attention_block_type=attention_block_type, ) elif up_block_type == "UpDecoderBlock2D": return UpDecoderBlock2D( @@ -341,6 +348,7 @@ def get_up_block( resnet_groups=resnet_groups, attn_num_head_channels=attn_num_head_channels, resnet_time_scale_shift=resnet_time_scale_shift, + attention_block_type=attention_block_type, ) elif up_block_type == "KUpBlock2D": return KUpBlock2D( @@ -383,6 +391,7 @@ def __init__( add_attention: bool = True, attn_num_head_channels=1, output_scale_factor=1.0, + attention_block_type: str = "AttentionBlock", ): super().__init__() resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) @@ -407,15 +416,29 @@ def __init__( for _ in range(num_layers): if self.add_attention: - attentions.append( - AttentionBlock( + if attention_block_type == "AttentionBlock": + attention = AttentionBlock( in_channels, num_head_channels=attn_num_head_channels, rescale_output_factor=output_scale_factor, eps=resnet_eps, norm_num_groups=resnet_groups, ) - ) + elif attention_block_type == "Attention": + attention = Attention( + in_channels, + heads=in_channels // attn_num_head_channels if attn_num_head_channels is not None else 1, + dim_head=attn_num_head_channels if attn_num_head_channels is not None else in_channels, + bias=True, + upcast_softmax=True, + norm_num_groups=resnet_groups, + eps=resnet_eps, + residual_connection=True, + ) + else: + raise ValueError(f"Unknown attention_block_type: {attention_block_type}") + + attentions.append(attention) else: attentions.append(None) @@ -658,6 +681,7 @@ def __init__( output_scale_factor=1.0, downsample_padding=1, add_downsample=True, + attention_block_type: str = "AttentionBlock", ): super().__init__() resnets = [] @@ -679,15 +703,28 @@ def __init__( pre_norm=resnet_pre_norm, ) ) - attentions.append( - AttentionBlock( + if attention_block_type == "AttentionBlock": + attention = AttentionBlock( out_channels, num_head_channels=attn_num_head_channels, 
rescale_output_factor=output_scale_factor, eps=resnet_eps, norm_num_groups=resnet_groups, ) - ) + elif attention_block_type == "Attention": + attention = Attention( + out_channels, + heads=out_channels // attn_num_head_channels if attn_num_head_channels is not None else 1, + dim_head=attn_num_head_channels if attn_num_head_channels is not None else out_channels, + bias=True, + upcast_softmax=True, + norm_num_groups=resnet_groups, + eps=resnet_eps, + residual_connection=True, + ) + else: + raise ValueError(f"Unknown attention_block_type: {attention_block_type}") + attentions.append(attention) self.attentions = nn.ModuleList(attentions) self.resnets = nn.ModuleList(resnets) @@ -1006,6 +1043,7 @@ def __init__( output_scale_factor=1.0, add_downsample=True, downsample_padding=1, + attention_block_type: str = "AttentionBlock", ): super().__init__() resnets = [] @@ -1027,15 +1065,29 @@ def __init__( pre_norm=resnet_pre_norm, ) ) - attentions.append( - AttentionBlock( + if attention_block_type == "AttentionBlock": + attention = AttentionBlock( out_channels, num_head_channels=attn_num_head_channels, rescale_output_factor=output_scale_factor, eps=resnet_eps, norm_num_groups=resnet_groups, ) - ) + elif attention_block_type == "Attention": + attention = Attention( + out_channels, + heads=out_channels // attn_num_head_channels if attn_num_head_channels is not None else 1, + dim_head=attn_num_head_channels if attn_num_head_channels is not None else out_channels, + bias=True, + upcast_softmax=True, + norm_num_groups=resnet_groups, + eps=resnet_eps, + rescale_output_factor=output_scale_factor, + residual_connection=True, + ) + else: + raise ValueError(f"Unknown attention_block_type: {attention_block_type}") + attentions.append(attention) self.attentions = nn.ModuleList(attentions) self.resnets = nn.ModuleList(resnets) @@ -1079,6 +1131,7 @@ def __init__( output_scale_factor=np.sqrt(2.0), downsample_padding=1, add_downsample=True, + attention_block_type: str = "AttentionBlock", ): super().__init__() self.attentions = nn.ModuleList([]) @@ -1101,14 +1154,30 @@ def __init__( pre_norm=resnet_pre_norm, ) ) - self.attentions.append( - AttentionBlock( + if attention_block_type == "AttentionBlock": + attention = AttentionBlock( out_channels, num_head_channels=attn_num_head_channels, rescale_output_factor=output_scale_factor, eps=resnet_eps, + norm_num_groups=32, ) - ) + elif attention_block_type == "Attention": + attention = Attention( + out_channels, + heads=out_channels // attn_num_head_channels if attn_num_head_channels is not None else 1, + dim_head=attn_num_head_channels if attn_num_head_channels is not None else out_channels, + bias=True, + upcast_softmax=True, + norm_num_groups=32, + eps=resnet_eps, + rescale_output_factor=output_scale_factor, + residual_connection=True, + ) + else: + raise ValueError(f"Unknown attention_block_type: {attention_block_type}") + + self.attentions.append(attention) if add_downsample: self.resnet_down = ResnetBlock2D( @@ -1632,6 +1701,7 @@ def __init__( attn_num_head_channels=1, output_scale_factor=1.0, add_upsample=True, + attention_block_type: str = "AttentionBlock", ): super().__init__() resnets = [] @@ -1655,15 +1725,31 @@ def __init__( pre_norm=resnet_pre_norm, ) ) - attentions.append( - AttentionBlock( + + if attention_block_type == "AttentionBlock": + attention = AttentionBlock( out_channels, num_head_channels=attn_num_head_channels, rescale_output_factor=output_scale_factor, eps=resnet_eps, norm_num_groups=resnet_groups, ) - ) + elif attention_block_type == "Attention": + 
attention = Attention( + out_channels, + heads=out_channels // attn_num_head_channels if attn_num_head_channels is not None else 1, + dim_head=attn_num_head_channels if attn_num_head_channels is not None else out_channels, + bias=True, + upcast_softmax=True, + norm_num_groups=resnet_groups, + eps=resnet_eps, + rescale_output_factor=output_scale_factor, + residual_connection=True, + ) + else: + raise ValueError(f"Unknown attention_block_type: {attention_block_type}") + + attentions.append(attention) self.attentions = nn.ModuleList(attentions) self.resnets = nn.ModuleList(resnets) @@ -1966,6 +2052,7 @@ def __init__( attn_num_head_channels=1, output_scale_factor=1.0, add_upsample=True, + attention_block_type: str = "AttentionBlock", ): super().__init__() resnets = [] @@ -1988,15 +2075,31 @@ def __init__( pre_norm=resnet_pre_norm, ) ) - attentions.append( - AttentionBlock( + + if attention_block_type == "AttentionBlock": + attention = AttentionBlock( out_channels, num_head_channels=attn_num_head_channels, rescale_output_factor=output_scale_factor, eps=resnet_eps, norm_num_groups=resnet_groups, ) - ) + elif attention_block_type == "Attention": + attention = Attention( + out_channels, + heads=out_channels // attn_num_head_channels if attn_num_head_channels is not None else 1, + dim_head=attn_num_head_channels if attn_num_head_channels is not None else out_channels, + bias=True, + upcast_softmax=True, + norm_num_groups=resnet_groups, + eps=resnet_eps, + rescale_output_factor=output_scale_factor, + residual_connection=True, + ) + else: + raise ValueError(f"Unknown attention_block_type: {attention_block_type}") + + attentions.append(attention) self.attentions = nn.ModuleList(attentions) self.resnets = nn.ModuleList(resnets) @@ -2035,6 +2138,7 @@ def __init__( output_scale_factor=np.sqrt(2.0), upsample_padding=1, add_upsample=True, + attention_block_type: str = "AttentionBlock", ): super().__init__() self.attentions = nn.ModuleList([]) @@ -2060,14 +2164,30 @@ def __init__( ) ) - self.attentions.append( - AttentionBlock( + if attention_block_type == "AttentionBlock": + attention = AttentionBlock( out_channels, num_head_channels=attn_num_head_channels, rescale_output_factor=output_scale_factor, eps=resnet_eps, + norm_num_groups=32, ) - ) + elif attention_block_type == "Attention": + attention = Attention( + out_channels, + heads=out_channels // attn_num_head_channels if attn_num_head_channels is not None else 1, + dim_head=attn_num_head_channels if attn_num_head_channels is not None else out_channels, + bias=True, + upcast_softmax=True, + norm_num_groups=32, + eps=resnet_eps, + rescale_output_factor=output_scale_factor, + residual_connection=True, + ) + else: + raise ValueError(f"Unknown attention_block_type: {attention_block_type}") + + self.attentions.append(attention) self.upsampler = FirUpsample2D(in_channels, out_channels=out_channels) if add_upsample: diff --git a/src/diffusers/models/vae.py b/src/diffusers/models/vae.py index b4484823ac3d..baefbb1741b0 100644 --- a/src/diffusers/models/vae.py +++ b/src/diffusers/models/vae.py @@ -46,6 +46,7 @@ def __init__( norm_num_groups=32, act_fn="silu", double_z=True, + attention_block_type="AttentionBlock", ): super().__init__() self.layers_per_block = layers_per_block @@ -80,6 +81,7 @@ def __init__( resnet_groups=norm_num_groups, attn_num_head_channels=None, temb_channels=None, + attention_block_type=attention_block_type, ) self.down_blocks.append(down_block) @@ -93,6 +95,7 @@ def __init__( attn_num_head_channels=None, resnet_groups=norm_num_groups, 
temb_channels=None, + attention_block_type=attention_block_type, ) # out @@ -149,6 +152,7 @@ def __init__( layers_per_block=2, norm_num_groups=32, act_fn="silu", + attention_block_type="AttentionBlock", ): super().__init__() self.layers_per_block = layers_per_block @@ -174,6 +178,7 @@ def __init__( attn_num_head_channels=None, resnet_groups=norm_num_groups, temb_channels=None, + attention_block_type=attention_block_type, ) # up @@ -197,6 +202,7 @@ def __init__( resnet_groups=norm_num_groups, attn_num_head_channels=None, temb_channels=None, + attention_block_type=attention_block_type, ) self.up_blocks.append(up_block) prev_output_channel = output_channel diff --git a/src/diffusers/models/vq_model.py b/src/diffusers/models/vq_model.py index 65f734dccb2d..c7568c1b43c7 100644 --- a/src/diffusers/models/vq_model.py +++ b/src/diffusers/models/vq_model.py @@ -82,6 +82,7 @@ def __init__( norm_num_groups: int = 32, vq_embed_dim: Optional[int] = None, scaling_factor: float = 0.18215, + attention_block_type: str = "AttentionBlock", ): super().__init__() @@ -95,6 +96,7 @@ def __init__( act_fn=act_fn, norm_num_groups=norm_num_groups, double_z=False, + attention_block_type=attention_block_type, ) vq_embed_dim = vq_embed_dim if vq_embed_dim is not None else latent_channels @@ -112,6 +114,7 @@ def __init__( layers_per_block=layers_per_block, act_fn=act_fn, norm_num_groups=norm_num_groups, + attention_block_type=attention_block_type, ) def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> VQEncoderOutput: diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 8f33b506827a..b30e8544e7a4 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -36,7 +36,7 @@ from .. import __version__ from ..configuration_utils import ConfigMixin -from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT +from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, ModelMixin from ..schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME from ..utils import ( CONFIG_NAME, @@ -443,6 +443,11 @@ def register_modules(self, **kwargs): # set models setattr(self, name, module) + # In case the module was created with only its constructor and ModelMixin.from_pretrained + # was never called. 
+ if issubclass(module.__class__, ModelMixin): + module._convert_deprecated_attention_blocks() + def save_pretrained( self, save_directory: Union[str, os.PathLike], diff --git a/tests/models/test_models_unet_2d.py b/tests/models/test_models_unet_2d.py index 8f831fcf7cbf..a3d62f638004 100644 --- a/tests/models/test_models_unet_2d.py +++ b/tests/models/test_models_unet_2d.py @@ -61,6 +61,7 @@ def prepare_init_args_and_inputs_for_common(self): "in_channels": 3, "layers_per_block": 2, "sample_size": 32, + "attention_block_type": "Attention", } inputs_dict = self.dummy_input return init_dict, inputs_dict @@ -98,6 +99,7 @@ def prepare_init_args_and_inputs_for_common(self): "attention_head_dim": 32, "down_block_types": ("DownBlock2D", "DownBlock2D"), "up_block_types": ("UpBlock2D", "UpBlock2D"), + "attention_block_type": "Attention", } inputs_dict = self.dummy_input return init_dict, inputs_dict @@ -223,6 +225,7 @@ def prepare_init_args_and_inputs_for_common(self): "AttnSkipUpBlock2D", "SkipUpBlock2D", ], + "attention_block_type": "Attention", } inputs_dict = self.dummy_input return init_dict, inputs_dict diff --git a/tests/models/test_models_vae.py b/tests/models/test_models_vae.py index abd4a078e692..8b18d9cf63fd 100644 --- a/tests/models/test_models_vae.py +++ b/tests/models/test_models_vae.py @@ -57,6 +57,7 @@ def prepare_init_args_and_inputs_for_common(self): "down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D"], "up_block_types": ["UpDecoderBlock2D", "UpDecoderBlock2D"], "latent_channels": 4, + "attention_block_type": "Attention", } inputs_dict = self.dummy_input return init_dict, inputs_dict diff --git a/tests/models/test_models_vq.py b/tests/models/test_models_vq.py index 66c33e07371e..dc3c452026a3 100644 --- a/tests/models/test_models_vq.py +++ b/tests/models/test_models_vq.py @@ -54,6 +54,7 @@ def prepare_init_args_and_inputs_for_common(self): "down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D"], "up_block_types": ["UpDecoderBlock2D", "UpDecoderBlock2D"], "latent_channels": 3, + "attention_block_type": "Attention", } inputs_dict = self.dummy_input return init_dict, inputs_dict diff --git a/tests/test_layers_utils.py b/tests/test_layers_utils.py index d0e2102b539e..e0e97ec3dc0b 100644 --- a/tests/test_layers_utils.py +++ b/tests/test_layers_utils.py @@ -17,6 +17,7 @@ import unittest import numpy as np +import pytest import torch from torch import nn @@ -25,6 +26,7 @@ from diffusers.models.resnet import Downsample2D, ResnetBlock2D, Upsample2D from diffusers.models.transformer_2d import Transformer2DModel from diffusers.utils import torch_device +from diffusers.utils.import_utils import is_xformers_available torch.backends.cuda.matmul.allow_tf32 = False @@ -331,6 +333,7 @@ def test_attention_block_default(self): eps=1e-6, norm_num_groups=32, ).to(torch_device) + attentionBlock = attentionBlock._as_attention_processor_attention() with torch.no_grad(): attention_scores = attentionBlock(sample) @@ -355,6 +358,54 @@ def test_attention_block_sd(self): eps=1e-6, norm_num_groups=32, ).to(torch_device) + attentionBlock = attentionBlock._as_attention_processor_attention() + with torch.no_grad(): + attention_scores = attentionBlock(sample) + + assert attention_scores.shape == (1, 512, 64, 64) + output_slice = attention_scores[0, -1, -3:, -3:] + + expected_slice = torch.tensor( + [-0.6621, -0.0156, -3.2766, 0.8025, -0.8609, 0.2820, 0.0905, -1.1179, -3.2126], device=torch_device + ) + assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3) + + def 
test_unconverted_block_forward_throws_errors(self): + attentionBlock = AttentionBlock(32) + + with pytest.raises( + ValueError, + match="`AttentionBlock` should have been converted after load to `diffusers.models.attention_processor.Attention`", + ): + attentionBlock(1) + + def test_unconverted_block_xformers_throws_errors(self): + attentionBlock = AttentionBlock(32) + + with pytest.raises( + ValueError, + match="`AttentionBlock` should have been converted after load to `diffusers.models.attention_processor.Attention`", + ): + attentionBlock.set_use_memory_efficient_attention_xformers(False) + + @unittest.skipIf( + torch_device != "cuda" or not is_xformers_available(), + reason="XFormers attention is only available with CUDA and `xformers` installed", + ) + def test_conversion_xformers(self): + torch.manual_seed(0) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(0) + + sample = torch.randn(1, 512, 64, 64).to(torch_device) + attentionBlock = AttentionBlock( + channels=512, + rescale_output_factor=1.0, + eps=1e-6, + norm_num_groups=32, + ).to(torch_device) + attentionBlock = attentionBlock._as_attention_processor_attention() + attentionBlock.set_use_memory_efficient_attention_xformers(True) with torch.no_grad(): attention_scores = attentionBlock(sample) diff --git a/tests/test_unet_2d_blocks.py b/tests/test_unet_2d_blocks.py index e560240422ac..a85b6894ea2f 100644 --- a/tests/test_unet_2d_blocks.py +++ b/tests/test_unet_2d_blocks.py @@ -139,6 +139,7 @@ def prepare_init_args_and_inputs_for_common(self): init_dict = { "in_channels": 32, "out_channels": 32, + "attention_block_type": "Attention", } inputs_dict = self.dummy_input return init_dict, inputs_dict @@ -156,6 +157,7 @@ def prepare_init_args_and_inputs_for_common(self): init_dict = { "in_channels": 32, "temb_channels": 128, + "attention_block_type": "Attention", } inputs_dict = self.dummy_input return init_dict, inputs_dict @@ -327,7 +329,7 @@ def dummy_input(self): return super().get_dummy_input(include_temb=False) def prepare_init_args_and_inputs_for_common(self): - init_dict = {"in_channels": 32, "out_channels": 32} + init_dict = {"in_channels": 32, "out_channels": 32, "attention_block_type": "Attention"} inputs_dict = self.dummy_input return init_dict, inputs_dict diff --git a/tests/test_unet_blocks_common.py b/tests/test_unet_blocks_common.py index 17b7f65d6da3..02d74f420b28 100644 --- a/tests/test_unet_blocks_common.py +++ b/tests/test_unet_blocks_common.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import inspect import unittest from typing import Tuple @@ -83,6 +84,12 @@ def prepare_init_args_and_inputs_for_common(self): if self.block_type == "mid": init_dict.pop("out_channels") + constructor_args = inspect.signature(self.block_class.__init__) + constructor_args = constructor_args.parameters.keys() + + if "attention_block_type" in constructor_args: + init_dict["attention_block_type"] = "Attention" + inputs_dict = self.dummy_input return init_dict, inputs_dict
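Note (not part of the diff): a minimal usage sketch of the migration path this change introduces, assuming the post-change `diffusers` API shown above. The checkpoint paths below are placeholders, not real repositories.

```python
# Hypothetical paths; substitute any UNet2DModel / VQModel / AutoencoderKL checkpoint
# that was saved with the deprecated `AttentionBlock` weights.
from diffusers import UNet2DModel

# `from_pretrained` builds the model (still containing `AttentionBlock` modules for an
# old config), loads the old-format weights, and then calls
# `_convert_deprecated_attention_blocks()` to swap each block for
# `diffusers.models.attention_processor.Attention`.
model = UNet2DModel.from_pretrained("path/to/old-unet-checkpoint")

# The conversion is recorded in the config, and the mid block now holds `Attention` modules.
assert model.config.attention_block_type == "Attention"
print(type(model.mid_block.attentions[0]).__name__)  # -> "Attention"

# Re-saving persists the converted weights and config, so future loads skip the
# deprecation path entirely; `DiffusionPipeline.save_pretrained` does the same for a
# whole pipeline, as the deprecation message suggests.
model.save_pretrained("path/to/converted-unet-checkpoint")
```

The conversion itself maps `AttentionBlock(channels, num_head_channels)` onto `Attention(channels, heads=channels // num_head_channels, dim_head=num_head_channels)` (a single head with `dim_head=channels` when `num_head_channels` is None), which is why the old `query`/`key`/`value`/`proj_attn` weights can be loaded into `to_q`/`to_k`/`to_v`/`to_out[0]` without modification.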