Merged
Changes from 2 commits
140 changes: 62 additions & 78 deletions src/transformers/models/sam/modeling_sam.py
@@ -16,7 +16,7 @@

import collections
from dataclasses import dataclass
from typing import Optional, Union
from typing import Callable, Optional, Union

import numpy as np
import torch
@@ -28,7 +28,7 @@
from ...activations import ACT2FN
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutput
from ...modeling_utils import PreTrainedModel
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import (
ModelOutput,
@@ -178,6 +178,28 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return x


def eager_attention_forward(
module: nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Optional[torch.Tensor],
scaling: float,
dropout: float = 0.0,
**kwargs,
):
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
if attention_mask is not None:
attn_weights = attn_weights + attention_mask
Contributor: Should be mask then softmax :D

Member (Author): Of course 🤦 Thanks

attn_weights = torch.softmax(attn_weights, dim=-1)
Contributor: The second softmax confuses me, to be honest (and based on that I have some doubts that SDPA will match eager well 👀). Is it really necessary, or could we just use one softmax after the mask has been applied? There is also a dtype difference between the two (fp32 vs. the dtype of the weights). Maybe we could inherit the eager attention from somewhere else, like Llama or BART?

Member (Author): You're right, it is weird. I reproduced the eager attention implementation that was there before the refactor (see below) without paying too much attention to it, but it might just be that attention_mask is never used here, so this path is never taken and didn't cause any issue. I'll check and remove it if that's the case, thanks for pointing it out!


attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
attn_output = torch.matmul(attn_weights, value)
attn_output = attn_output.transpose(1, 2).contiguous()

return attn_output, attn_weights
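
For reference on the exchange above: a minimal sketch, assuming the Llama/BART-style eager attention the reviewer refers to, of the conventional ordering (additive mask applied to the raw scores, then a single fp32 softmax). This is an illustration only, not code from this PR.

from typing import Optional

import torch
from torch import nn


def conventional_eager_attention(
    query: torch.Tensor,                     # (batch, num_heads, q_len, head_dim)
    key: torch.Tensor,                       # (batch, num_heads, kv_len, head_dim)
    value: torch.Tensor,                     # (batch, num_heads, kv_len, head_dim)
    attention_mask: Optional[torch.Tensor],  # additive mask, broadcastable to the score shape
    scaling: float,
    dropout: float = 0.0,
    training: bool = False,
):
    # Raw attention scores.
    attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
    # Apply the additive mask *before* the single softmax.
    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask
    # Softmax in fp32 for numerical stability, then cast back to the input dtype.
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=training)
    attn_output = torch.matmul(attn_weights, value)
    return attn_output.transpose(1, 2).contiguous(), attn_weights
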


class SamAttention(nn.Module):
"""
SAM's attention layer that allows for downscaling the size of the embedding after projection to queries, keys, and
@@ -186,6 +208,7 @@ class SamAttention(nn.Module):

def __init__(self, config, downsample_rate=None):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size

downsample_rate = config.attention_downsample_rate if downsample_rate is None else downsample_rate
@@ -200,19 +223,25 @@ def __init__(self, config, downsample_rate=None):
self.v_proj = nn.Linear(self.hidden_size, self.internal_dim)
self.out_proj = nn.Linear(self.internal_dim, self.hidden_size)

self.is_causal = False

def _separate_heads(self, hidden_states: Tensor, num_attention_heads: int) -> Tensor:
batch, point_batch_size, n_tokens, channel = hidden_states.shape
c_per_head = channel // num_attention_heads
hidden_states = hidden_states.reshape(batch * point_batch_size, n_tokens, num_attention_heads, c_per_head)
return hidden_states.transpose(1, 2)

def _recombine_heads(self, hidden_states: Tensor, point_batch_size: int) -> Tensor:
batch, n_heads, n_tokens, c_per_head = hidden_states.shape
hidden_states = hidden_states.transpose(1, 2)
batch, n_tokens, n_heads, c_per_head = hidden_states.shape
return hidden_states.reshape(batch // point_batch_size, point_batch_size, n_tokens, n_heads * c_per_head)

def forward(
self, query: Tensor, key: Tensor, value: Tensor, attention_similarity: Optional[Tensor] = None
self,
query: Tensor,
key: Tensor,
value: Tensor,
attention_similarity: Optional[Tensor] = None,
**kwargs: Unpack[TransformersKwargs],
) -> Tensor:
# Input projections
query = self.q_proj(query)
@@ -226,66 +255,29 @@ def forward(
value = self._separate_heads(value, self.num_attention_heads)

# SamAttention
_, _, _, c_per_head = query.shape
attn = query @ key.permute(0, 1, 3, 2) # batch_size * point_batch_size x N_heads x N_tokens x N_tokens
attn = attn / (c_per_head**0.5)
attn = torch.softmax(attn, dim=-1)

if attention_similarity is not None:
attn = attn + attention_similarity
attn = torch.softmax(attn, dim=-1)

# Get output
out = attn @ value
out = self._recombine_heads(out, point_batch_size)
out = self.out_proj(out)

return out


class SamSdpaAttention(SamAttention):
"""
SAM's attention layer that allows for downscaling the size of the embedding after projection to queries, keys, and
values. Using SDPA instead of the default attention.
"""

def __init__(self, config, downsample_rate=None):
super().__init__(config, downsample_rate)

def forward(
self, query: Tensor, key: Tensor, value: Tensor, attention_similarity: Optional[Tensor] = None
) -> Tensor:
# Input projections
query = self.q_proj(query)
key = self.k_proj(key)
value = self.v_proj(value)

point_batch_size = query.shape[1]
# Separate into heads
query = self._separate_heads(query, self.num_attention_heads)
key = self._separate_heads(key, self.num_attention_heads)
value = self._separate_heads(value, self.num_attention_heads)

# Scaled dot product attention
attn_mask = None
if attention_similarity is not None:
attn_mask = attention_similarity.unsqueeze(1).expand(-1, self.num_attention_heads, -1, -1)

out = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask)
scale = query.shape[-1] ** -0.5
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

attn_output, _ = attention_interface(
Member: We never want the attention outputs here?

self,
query,
key,
value,
attention_mask=attention_similarity,
dropout=0.0 if not self.training else self.dropout_p,
scaling=scale,
is_causal=self.is_causal,
**kwargs,
)

# Get output
out = self._recombine_heads(out, point_batch_size)
out = self._recombine_heads(attn_output, point_batch_size)
out = self.out_proj(out)

return out


SAM_ATTENTION_CLASSES = {
"eager": SamAttention,
"sdpa": SamSdpaAttention,
}
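
With the SamAttention/SamSdpaAttention split and the SAM_ATTENTION_CLASSES registry gone, the attention backend is selected the way it is for other models in the library: through the config's _attn_implementation and the shared ALL_ATTENTION_FUNCTIONS registry. A hedged usage sketch follows; the checkpoint name is illustrative, and backend availability depends on what is installed.

from transformers import SamModel

# "eager" keeps the module-local python implementation defined above.
model_eager = SamModel.from_pretrained("facebook/sam-vit-base", attn_implementation="eager")

# Other values ("sdpa", "flash_attention_2", ...) are resolved through ALL_ATTENTION_FUNCTIONS.
model_sdpa = SamModel.from_pretrained("facebook/sam-vit-base", attn_implementation="sdpa")
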


class SamTwoWayAttentionBlock(nn.Module):
def __init__(self, config, attention_downsample_rate: int = 2, skip_first_layer_pe: bool = False):
"""
@@ -306,21 +298,17 @@ def __init__(self, config, attention_downsample_rate: int = 2, skip_first_layer_
self.hidden_size = config.hidden_size
self.layer_norm_eps = config.layer_norm_eps

self.self_attn = SAM_ATTENTION_CLASSES[config._attn_implementation](config, downsample_rate=1)
self.self_attn = SamAttention(config, downsample_rate=1)
self.layer_norm1 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)

self.cross_attn_token_to_image = SAM_ATTENTION_CLASSES[config._attn_implementation](
config, downsample_rate=attention_downsample_rate
)
self.cross_attn_token_to_image = SamAttention(config, downsample_rate=attention_downsample_rate)
self.layer_norm2 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)

self.mlp = SamMLPBlock(config)
self.layer_norm3 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)

self.layer_norm4 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)
self.cross_attn_image_to_token = SAM_ATTENTION_CLASSES[config._attn_implementation](
config, downsample_rate=attention_downsample_rate
)
self.cross_attn_image_to_token = SamAttention(config, downsample_rate=attention_downsample_rate)
self.skip_first_layer_pe = skip_first_layer_pe

def forward(
@@ -330,6 +318,7 @@ def forward(
query_point_embedding: Tensor,
key_point_embedding: Tensor,
attention_similarity: Tensor,
**kwargs: Unpack[TransformersKwargs],
):
# Self attention block
if self.skip_first_layer_pe:
@@ -378,7 +367,7 @@ def __init__(self, config: SamMaskDecoderConfig):
for i in range(self.num_hidden_layers):
self.layers.append(SamTwoWayAttentionBlock(config, skip_first_layer_pe=(i == 0)))

self.final_attn_token_to_image = SAM_ATTENTION_CLASSES[config._attn_implementation](config)
self.final_attn_token_to_image = SamAttention(config)
self.layer_norm_final_attn = nn.LayerNorm(config.hidden_size)

def forward(
@@ -388,6 +377,7 @@ def forward(
image_positional_embeddings: Tensor,
attention_similarity: Tensor,
target_embedding=None,
**kwargs: Unpack[TransformersKwargs],
) -> Union[tuple, BaseModelOutput]:
if image_embeddings is None:
raise ValueError("You have to specify an image_embedding")
@@ -410,6 +400,7 @@ def forward(
query_point_embedding=point_embeddings,
key_point_embedding=image_positional_embeddings,
attention_similarity=attention_similarity,
**kwargs,
)
# Apply the final attention layer from the points to the image
query = queries + point_embeddings
@@ -501,12 +492,12 @@ def forward(
Whether to return multiple masks or a single mask.
"""
batch_size, num_channels, height, width = image_embeddings.shape
point_batch_size = sparse_prompt_embeddings.shape[1]
point_batch_size = sparse_prompt_embeddings.shape[1] if sparse_prompt_embeddings is not None else 1
# Concatenate output tokens
output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
output_tokens = output_tokens.repeat(batch_size, point_batch_size, 1, 1)

if sparse_prompt_embeddings.sum().item() != 0:
if sparse_prompt_embeddings is not None:
Comment on lines -509 to +500. Member (Author): This was causing a CUDA sync issue, and it does not seem to be necessary if we set sparse_prompt_embeddings to None when it is not provided by the prompt encoder.

tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=2)
else:
tokens = output_tokens
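
On the CUDA sync mentioned in the comment above: calling .item() copies a scalar from the GPU back to the host, which blocks until all queued kernels have finished, whereas the new is-not-None test never touches the device. A minimal sketch of the contrast (shapes are illustrative only):

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
sparse_prompt_embeddings = torch.randn(2, 1, 4, 256, device=device)

# Old-style check: .item() forces a device-to-host copy, i.e. a CUDA synchronization.
if sparse_prompt_embeddings.sum().item() != 0:
    print("prompts present (the GPU had to sync to tell us)")

# New-style check: a plain Python None test, no device work at all.
if sparse_prompt_embeddings is not None:
    print("prompts present (no synchronization needed)")
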
@@ -611,7 +602,7 @@ def forward(self, masks):


class SamPromptEncoder(nn.Module):
def __init__(self, config: SamPromptEncoderConfig):
def __init__(self, config: SamConfig):
super().__init__()
self.shared_embedding = SamPositionalEmbedding(config.vision_config)
config = config.prompt_encoder_config
@@ -648,7 +639,7 @@ def _embed_points(self, points: torch.Tensor, labels: torch.Tensor, pad: bool) -
point_embedding = torch.where(
labels[..., None] != -10,
point_embedding,
torch.tensor(0.0, dtype=point_embedding.dtype, device=point_embedding.device),
torch.zeros_like(point_embedding, dtype=point_embedding.dtype, device=point_embedding.device),
)

point_embedding = torch.where(
@@ -717,9 +708,6 @@ def forward(
batch_size, -1, self.image_embedding_size[0], self.image_embedding_size[1]
)

if sparse_embeddings is None:
sparse_embeddings = torch.zeros((batch_size, 1, 1, self.hidden_size), device=target_device)

return sparse_embeddings, dense_embeddings


@@ -1182,10 +1170,6 @@ def get_image_embeddings(self, pixel_values, **kwargs: Unpack[TransformersKwargs
Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Input pixel values
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers.
"""
vision_output = self.vision_encoder(
pixel_values,