Update SAM/SAM HQ attention implementation + fix CUDA sync issues #39386
Changes from 2 commits
@@ -16,7 +16,7 @@
 import collections
 from dataclasses import dataclass
-from typing import Optional, Union
+from typing import Callable, Optional, Union

 import numpy as np
 import torch
@@ -28,7 +28,7 @@
 from ...activations import ACT2FN
 from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import BaseModelOutput
-from ...modeling_utils import PreTrainedModel
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import (
     ModelOutput,
@@ -178,6 +178,28 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return x


+def eager_attention_forward(
+    module: nn.Module,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attention_mask: Optional[torch.Tensor],
+    scaling: float,
+    dropout: float = 0.0,
+    **kwargs,
+):
+    attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
+    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+    if attention_mask is not None:
+        attn_weights = attn_weights + attention_mask
+        attn_weights = torch.softmax(attn_weights, dim=-1)
+
+    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+    attn_output = torch.matmul(attn_weights, value)
+    attn_output = attn_output.transpose(1, 2).contiguous()
+
+    return attn_output, attn_weights
+
+
 class SamAttention(nn.Module):
     """
     SAM's attention layer that allows for downscaling the size of the embedding after projection to queries, keys, and
@@ -186,6 +208,7 @@ class SamAttention(nn.Module):

     def __init__(self, config, downsample_rate=None):
         super().__init__()
+        self.config = config
         self.hidden_size = config.hidden_size

         downsample_rate = config.attention_downsample_rate if downsample_rate is None else downsample_rate
@@ -200,19 +223,25 @@ def __init__(self, config, downsample_rate=None):
         self.v_proj = nn.Linear(self.hidden_size, self.internal_dim)
         self.out_proj = nn.Linear(self.internal_dim, self.hidden_size)

+        self.is_causal = False
+
     def _separate_heads(self, hidden_states: Tensor, num_attention_heads: int) -> Tensor:
         batch, point_batch_size, n_tokens, channel = hidden_states.shape
         c_per_head = channel // num_attention_heads
         hidden_states = hidden_states.reshape(batch * point_batch_size, n_tokens, num_attention_heads, c_per_head)
         return hidden_states.transpose(1, 2)

     def _recombine_heads(self, hidden_states: Tensor, point_batch_size: int) -> Tensor:
-        batch, n_heads, n_tokens, c_per_head = hidden_states.shape
-        hidden_states = hidden_states.transpose(1, 2)
+        batch, n_tokens, n_heads, c_per_head = hidden_states.shape
         return hidden_states.reshape(batch // point_batch_size, point_batch_size, n_tokens, n_heads * c_per_head)

     def forward(
-        self, query: Tensor, key: Tensor, value: Tensor, attention_similarity: Optional[Tensor] = None
+        self,
+        query: Tensor,
+        key: Tensor,
+        value: Tensor,
+        attention_similarity: Optional[Tensor] = None,
+        **kwargs: Unpack[TransformersKwargs],
     ) -> Tensor:
         # Input projections
         query = self.q_proj(query)
@@ -226,66 +255,29 @@ def forward(
         value = self._separate_heads(value, self.num_attention_heads)

-        # SamAttention
-        _, _, _, c_per_head = query.shape
-        attn = query @ key.permute(0, 1, 3, 2)  # batch_size * point_batch_size x N_heads x N_tokens x N_tokens
-        attn = attn / (c_per_head**0.5)
-        attn = torch.softmax(attn, dim=-1)
-
-        if attention_similarity is not None:
-            attn = attn + attention_similarity
-            attn = torch.softmax(attn, dim=-1)
-
-        # Get output
-        out = attn @ value
-        out = self._recombine_heads(out, point_batch_size)
-        out = self.out_proj(out)
-
-        return out
-
-
-class SamSdpaAttention(SamAttention):
-    """
-    SAM's attention layer that allows for downscaling the size of the embedding after projection to queries, keys, and
-    values. Using SDPA instead of the default attention.
-    """
-
-    def __init__(self, config, downsample_rate=None):
-        super().__init__(config, downsample_rate)
-
-    def forward(
-        self, query: Tensor, key: Tensor, value: Tensor, attention_similarity: Optional[Tensor] = None
-    ) -> Tensor:
-        # Input projections
-        query = self.q_proj(query)
-        key = self.k_proj(key)
-        value = self.v_proj(value)
-
-        point_batch_size = query.shape[1]
-        # Separate into heads
-        query = self._separate_heads(query, self.num_attention_heads)
-        key = self._separate_heads(key, self.num_attention_heads)
-        value = self._separate_heads(value, self.num_attention_heads)
-
-        # Scaled dot product attention
-        attn_mask = None
-        if attention_similarity is not None:
-            attn_mask = attention_similarity.unsqueeze(1).expand(-1, self.num_attention_heads, -1, -1)
-
-        out = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask)
+        scale = query.shape[-1] ** -0.5
+
+        attention_interface: Callable = eager_attention_forward
+        if self.config._attn_implementation != "eager":
+            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+        attn_output, _ = attention_interface(
+            self,
+            query,
+            key,
+            value,
+            attention_mask=attention_similarity,
+            dropout=0.0 if not self.training else self.dropout_p,
+            scaling=scale,
+            is_causal=self.is_causal,
+            **kwargs,
+        )

         # Get output
-        out = self._recombine_heads(out, point_batch_size)
+        out = self._recombine_heads(attn_output, point_batch_size)
         out = self.out_proj(out)

         return out


-SAM_ATTENTION_CLASSES = {
-    "eager": SamAttention,
-    "sdpa": SamSdpaAttention,
-}
-
-
 class SamTwoWayAttentionBlock(nn.Module):
     def __init__(self, config, attention_downsample_rate: int = 2, skip_first_layer_pe: bool = False):
         """
@@ -306,21 +298,17 @@ def __init__(self, config, attention_downsample_rate: int = 2, skip_first_layer_
         self.hidden_size = config.hidden_size
         self.layer_norm_eps = config.layer_norm_eps

-        self.self_attn = SAM_ATTENTION_CLASSES[config._attn_implementation](config, downsample_rate=1)
+        self.self_attn = SamAttention(config, downsample_rate=1)
         self.layer_norm1 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)

-        self.cross_attn_token_to_image = SAM_ATTENTION_CLASSES[config._attn_implementation](
-            config, downsample_rate=attention_downsample_rate
-        )
+        self.cross_attn_token_to_image = SamAttention(config, downsample_rate=attention_downsample_rate)
         self.layer_norm2 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)

         self.mlp = SamMLPBlock(config)
         self.layer_norm3 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)

         self.layer_norm4 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)
-        self.cross_attn_image_to_token = SAM_ATTENTION_CLASSES[config._attn_implementation](
-            config, downsample_rate=attention_downsample_rate
-        )
+        self.cross_attn_image_to_token = SamAttention(config, downsample_rate=attention_downsample_rate)
         self.skip_first_layer_pe = skip_first_layer_pe

     def forward(
@@ -330,6 +318,7 @@ def forward(
         query_point_embedding: Tensor,
         key_point_embedding: Tensor,
         attention_similarity: Tensor,
+        **kwargs: Unpack[TransformersKwargs],
     ):
         # Self attention block
         if self.skip_first_layer_pe:
@@ -378,7 +367,7 @@ def __init__(self, config: SamMaskDecoderConfig):
         for i in range(self.num_hidden_layers):
             self.layers.append(SamTwoWayAttentionBlock(config, skip_first_layer_pe=(i == 0)))

-        self.final_attn_token_to_image = SAM_ATTENTION_CLASSES[config._attn_implementation](config)
+        self.final_attn_token_to_image = SamAttention(config)
         self.layer_norm_final_attn = nn.LayerNorm(config.hidden_size)

     def forward(
@@ -388,6 +377,7 @@ def forward(
         image_positional_embeddings: Tensor,
         attention_similarity: Tensor,
         target_embedding=None,
+        **kwargs: Unpack[TransformersKwargs],
     ) -> Union[tuple, BaseModelOutput]:
         if image_embeddings is None:
             raise ValueError("You have to specify an image_embedding")
@@ -410,6 +400,7 @@
                 query_point_embedding=point_embeddings,
                 key_point_embedding=image_positional_embeddings,
                 attention_similarity=attention_similarity,
+                **kwargs,
             )
         # Apply the final attenion layer from the points to the image
         query = queries + point_embeddings
@@ -501,12 +492,12 @@ def forward(
                 Whether to return multiple masks or a single mask.
         """
         batch_size, num_channels, height, width = image_embeddings.shape
-        point_batch_size = sparse_prompt_embeddings.shape[1]
+        point_batch_size = sparse_prompt_embeddings.shape[1] if sparse_prompt_embeddings is not None else 1
         # Concatenate output tokens
         output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
         output_tokens = output_tokens.repeat(batch_size, point_batch_size, 1, 1)

-        if sparse_prompt_embeddings.sum().item() != 0:
+        if sparse_prompt_embeddings is not None:
Comment on lines -509 to +500 (PR author):

This was causing a CUDA sync issue and does not seem to be necessary if we set sparse_prompt_embeddings to None when it's not provided in the prompt encoder.
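For context, a minimal sketch of why the old guard is a synchronization point while the new one is not: `.sum().item()` has to copy a scalar back to the host, which blocks until the GPU has finished producing it, whereas an `is None` test is pure Python. The shapes and device handling below are illustrative assumptions, not code from the PR.

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
# Hypothetical prompt embeddings; shapes chosen only for illustration.
sparse_prompt_embeddings = torch.zeros(1, 1, 2, 256, device=device)

# Old guard: .item() transfers a scalar to the CPU, forcing the host to wait
# for every queued kernel that feeds the sum (a device-to-host sync point).
if sparse_prompt_embeddings.sum().item() != 0:
    has_sparse_prompts = True

# New guard: a Python-level None check, no device work and no sync point;
# the prompt encoder now returns None when no sparse prompt is given.
if sparse_prompt_embeddings is not None:
    has_sparse_prompts = True
```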
             tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=2)
         else:
             tokens = output_tokens
@@ -611,7 +602,7 @@ def forward(self, masks):


 class SamPromptEncoder(nn.Module):
-    def __init__(self, config: SamPromptEncoderConfig):
+    def __init__(self, config: SamConfig):
         super().__init__()
         self.shared_embedding = SamPositionalEmbedding(config.vision_config)
         config = config.prompt_encoder_config
@@ -648,7 +639,7 @@ def _embed_points(self, points: torch.Tensor, labels: torch.Tensor, pad: bool) -
         point_embedding = torch.where(
             labels[..., None] != -10,
             point_embedding,
-            torch.tensor(0.0, dtype=point_embedding.dtype, device=point_embedding.device),
+            torch.zeros_like(point_embedding, dtype=point_embedding.dtype, device=point_embedding.device),
         )

         point_embedding = torch.where(
@@ -717,9 +708,6 @@ def forward(
             batch_size, -1, self.image_embedding_size[0], self.image_embedding_size[1]
         )

-        if sparse_embeddings is None:
-            sparse_embeddings = torch.zeros((batch_size, 1, 1, self.hidden_size), device=target_device)
-
         return sparse_embeddings, dense_embeddings
@@ -1182,10 +1170,6 @@ def get_image_embeddings(self, pixel_values, **kwargs: Unpack[TransformersKwargs
         Args:
             pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
                 Input pixel values
-            output_attentions (`bool`, *optional*):
-                Whether or not to return the attentions tensors of all attention layers.
-            output_hidden_states (`bool`, *optional*):
-                Whether or not to return the hidden states of all layers.
         """
         vision_output = self.vision_encoder(
             pixel_values,
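Taken together, these changes route SamAttention through the shared attention-interface dispatch, so the backend follows the usual `attn_implementation` switch at load time. A hedged usage sketch; the checkpoint name is only an example and which backends are available depends on the installed packages and on what the model ends up supporting, neither of which this diff states:

```python
from transformers import SamModel, SamProcessor

# With the unified attention interface, the mask decoder's attention backend is
# taken from config._attn_implementation, e.g. "eager" or "sdpa".
model = SamModel.from_pretrained("facebook/sam-vit-base", attn_implementation="sdpa")
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
```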
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should be mask then softmax :D
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Of course 🤦Thanks
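For reference, a minimal sketch of the ordering the reviewer is asking for, written against the same signature as the eager_attention_forward added in this diff: the additive attention_similarity mask is applied to the raw scores before a single softmax. This illustrates the suggested fix rather than the merged code, and the `_fixed` suffix is my own label.

```python
from typing import Optional

import torch
from torch import nn


def eager_attention_forward_fixed(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs,
):
    # Raw scores: (batch, num_heads, q_len, kv_len)
    attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
    # Mask first ...
    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask
    # ... then a single softmax over the masked scores.
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
    attn_output = torch.matmul(attn_weights, value)
    # Back to (batch, q_len, num_heads, head_dim), ready for _recombine_heads
    attn_output = attn_output.transpose(1, 2).contiguous()
    return attn_output, attn_weights
```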