Deprecate attention block #2697

Closed

@@ -350,6 +350,7 @@ def main(args):
"UpBlock2D",
"UpBlock2D",
),
attention_block_type="Attention",
)

# Create EMA for the model.
@@ -397,10 +397,12 @@ def load_model_hook(models, input_dir):
"UpBlock2D",
"UpBlock2D",
),
attention_block_type="Attention",
)
else:
config = UNet2DModel.load_config(args.model_config_name_or_path)
model = UNet2DModel.from_config(config)
model._convert_deprecated_attention_blocks()

# Create EMA for the model.
if args.use_ema:
153 changes: 52 additions & 101 deletions src/diffusers/models/attention.py
@@ -11,27 +11,23 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Callable, Optional

import torch
import torch.nn.functional as F
from torch import nn

from ..utils.import_utils import is_xformers_available
from ..utils import deprecate
from .attention_processor import Attention
from .embeddings import CombinedTimestepLabelEmbeddings


if is_xformers_available():
import xformers
import xformers.ops
else:
xformers = None


class AttentionBlock(nn.Module):
"""
This class is deprecated. Its forward method will throw an error. On model load, we convert all instances of
`AttentionBlock` to `diffusers.models.attention_processor.Attention`, see
`ModelMixin#_convert_deprecated_attention_blocks`.

An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted
to the N-d case.
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
@@ -46,8 +42,6 @@ class AttentionBlock(nn.Module):
eps (`float`, *optional*, defaults to 1e-5): The epsilon value to use for group norm.
"""

# IMPORTANT;TODO(Patrick, William) - this class will be deprecated soon. Do not use it anymore

def __init__(
self,
channels: int,
@@ -57,6 +51,16 @@ def __init__(
eps: float = 1e-5,
):
super().__init__()

deprecation_message = (
"`AttentionBlock` has been deprecated and will be replaced with `diffusers.models.attention_processor.Attention`."
" The DiffusionPipeline loading this block in is auto converting it to `diffusers.models.attention_processor.Attention`."
" Please call `DiffusionPipeline#save_pretrained` and re-upload the pipeline to the hub."
" If you are only loading a model instead of a whole pipeline, the same instructions apply with `ModelMixin#save_pretrained`."
)

deprecate("AttentionBlock", "0.18.0", deprecation_message, standard_warn=True)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@patrickvonplaten is this the right deprecation version? iirc we talked about two minor versions

Contributor:
Yeah, I'd maybe even bump it up a bit more to "0.20.0" maybe

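For reference, the migration the deprecation message asks for is a plain load-and-resave round trip. A minimal sketch, assuming placeholder names (the repository id and output path below are hypothetical):

from diffusers import DiffusionPipeline

# Loading auto-converts any deprecated `AttentionBlock` modules to `Attention`.
pipeline = DiffusionPipeline.from_pretrained("<your-username>/<your-pipeline>")

# Saving writes out the converted weights, so future loads skip the conversion path.
pipeline.save_pretrained("./converted-pipeline")

# Re-upload the saved folder to the Hub, for example with `huggingface_hub.upload_folder`.
# For a bare model, the same round trip works via `ModelMixin`'s
# `from_pretrained` / `save_pretrained` (e.g. on `UNet2DModel`).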

self.channels = channels

self.num_heads = channels // num_head_channels if num_head_channels is not None else 1
@@ -71,107 +75,54 @@ def __init__(
self.rescale_output_factor = rescale_output_factor
self.proj_attn = nn.Linear(channels, channels, bias=True)

self._use_memory_efficient_attention_xformers = False
self._attention_op = None

def reshape_heads_to_batch_dim(self, tensor):
batch_size, seq_len, dim = tensor.shape
head_size = self.num_heads
tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
return tensor

def reshape_batch_dim_to_heads(self, tensor):
batch_size, seq_len, dim = tensor.shape
head_size = self.num_heads
tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
return tensor

def set_use_memory_efficient_attention_xformers(
self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
):
if use_memory_efficient_attention_xformers:
if not is_xformers_available():
raise ModuleNotFoundError(
(
"Refer to https://github.com/facebookresearch/xformers for more information on how to install"
" xformers"
),
name="xformers",
)
elif not torch.cuda.is_available():
raise ValueError(
"torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
" only available for GPU "
)
else:
try:
# Make sure we can run the memory efficient attention
_ = xformers.ops.memory_efficient_attention(
torch.randn((1, 2, 40), device="cuda"),
torch.randn((1, 2, 40), device="cuda"),
torch.randn((1, 2, 40), device="cuda"),
)
except Exception as e:
raise e
self._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
self._attention_op = attention_op
raise ValueError(
"`AttentionBlock` should have been converted after load to `diffusers.models.attention_processor.Attention`"
)

def forward(self, hidden_states):
residual = hidden_states
batch, channel, height, width = hidden_states.shape

# norm
hidden_states = self.group_norm(hidden_states)

hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2)

# proj to q, k, v
query_proj = self.query(hidden_states)
key_proj = self.key(hidden_states)
value_proj = self.value(hidden_states)

scale = 1 / math.sqrt(self.channels / self.num_heads)

query_proj = self.reshape_heads_to_batch_dim(query_proj)
key_proj = self.reshape_heads_to_batch_dim(key_proj)
value_proj = self.reshape_heads_to_batch_dim(value_proj)
raise ValueError(
"`AttentionBlock` should have been converted after load to `diffusers.models.attention_processor.Attention`"
)

if self._use_memory_efficient_attention_xformers:
# Memory efficient attention
hidden_states = xformers.ops.memory_efficient_attention(
query_proj, key_proj, value_proj, attn_bias=None, op=self._attention_op
)
hidden_states = hidden_states.to(query_proj.dtype)
def _as_attention_processor_attention(self):
if self.num_head_size is None:
# When `self.num_head_size` is None, there is a single attention head
# of all the channels
dim_head = self.channels
else:
attention_scores = torch.baddbmm(
torch.empty(
query_proj.shape[0],
query_proj.shape[1],
key_proj.shape[1],
dtype=query_proj.dtype,
device=query_proj.device,
),
query_proj,
key_proj.transpose(-1, -2),
beta=0,
alpha=scale,
)
attention_probs = torch.softmax(attention_scores.float(), dim=-1).type(attention_scores.dtype)
hidden_states = torch.bmm(attention_probs, value_proj)
dim_head = self.num_head_size

# This will allocate some additional memory but as this is only done once during model load,
# it should be ok.
attn = Attention(
self.channels,
heads=self.num_heads,
dim_head=dim_head,
bias=True,
upcast_softmax=True,
norm_num_groups=self.group_norm.num_groups,
eps=self.group_norm.eps,
rescale_output_factor=self.rescale_output_factor,
residual_connection=True,
)

# reshape hidden_states
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
param = next(self.parameters())

# compute next hidden_states
hidden_states = self.proj_attn(hidden_states)
device = param.device
dtype = param.dtype

hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width)
attn.to(device=device, dtype=dtype)

# res connect and rescale
hidden_states = (hidden_states + residual) / self.rescale_output_factor
return hidden_states
attn.group_norm.load_state_dict(self.group_norm.state_dict())
attn.to_q.load_state_dict(self.query.state_dict())
attn.to_k.load_state_dict(self.key.state_dict())
attn.to_v.load_state_dict(self.value.state_dict())
attn.to_out[0].load_state_dict(self.proj_attn.state_dict())

return attn

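The class docstring above refers to `ModelMixin#_convert_deprecated_attention_blocks`, which is not part of this file's diff. A rough sketch of how such a hook could use `_as_attention_processor_attention` to swap blocks in place, reusing the `AttentionBlock` class defined above; the function name and traversal below are assumptions, not the actual implementation:

from torch import nn


def convert_deprecated_attention_blocks(model: nn.Module) -> None:
    # Collect module names first so the tree is not mutated while iterating.
    deprecated = [
        name for name, module in model.named_modules() if isinstance(module, AttentionBlock)
    ]

    for name in deprecated:
        parent_name, _, child_name = name.rpartition(".")
        parent = model.get_submodule(parent_name) if parent_name else model
        old_block = model.get_submodule(name)
        # Replace the deprecated block with an equivalent `Attention` carrying the copied weights.
        setattr(parent, child_name, old_block._as_attention_processor_attention())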

class BasicTransformerBlock(nn.Module):