Skip to content

Commit 433d2a2

Browse files
authored
Update SAM/SAM HQ attention implementation + fix CUDA sync issues (#39386)
* update attention implementation and improve inference speed
* modular sam_hq + fix integration tests on A10
* fixup
* fix after review
* softmax in correct place
* return attn_weights in sam/sam_hq
1 parent 541bed2 commit 433d2a2

File tree

4 files changed

+151
-189
lines changed

4 files changed

+151
-189
lines changed

src/transformers/models/sam/modeling_sam.py

Lines changed: 70 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
import collections
1818
from dataclasses import dataclass
19-
from typing import Optional, Union
19+
from typing import Callable, Optional, Union
2020

2121
import numpy as np
2222
import torch
@@ -28,7 +28,7 @@
2828
from ...activations import ACT2FN
2929
from ...modeling_layers import GradientCheckpointingLayer
3030
from ...modeling_outputs import BaseModelOutput
31-
from ...modeling_utils import PreTrainedModel
31+
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
3232
from ...processing_utils import Unpack
3333
from ...utils import (
3434
ModelOutput,
@@ -178,6 +178,28 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
178178
return x
179179

180180

181+
def eager_attention_forward(
182+
module: nn.Module,
183+
query: torch.Tensor,
184+
key: torch.Tensor,
185+
value: torch.Tensor,
186+
attention_mask: Optional[torch.Tensor],
187+
scaling: float,
188+
dropout: float = 0.0,
189+
**kwargs,
190+
):
191+
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
192+
if attention_mask is not None:
193+
attn_weights = attn_weights + attention_mask
194+
195+
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
196+
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
197+
attn_output = torch.matmul(attn_weights, value)
198+
attn_output = attn_output.transpose(1, 2).contiguous()
199+
200+
return attn_output, attn_weights
201+
202+
181203
class SamAttention(nn.Module):
182204
"""
183205
SAM's attention layer that allows for downscaling the size of the embedding after projection to queries, keys, and
@@ -186,6 +208,7 @@ class SamAttention(nn.Module):
186208

187209
def __init__(self, config, downsample_rate=None):
188210
super().__init__()
211+
self.config = config
189212
self.hidden_size = config.hidden_size
190213

191214
downsample_rate = config.attention_downsample_rate if downsample_rate is None else downsample_rate
@@ -194,25 +217,32 @@ def __init__(self, config, downsample_rate=None):
194217
self.num_attention_heads = config.num_attention_heads
195218
if self.internal_dim % config.num_attention_heads != 0:
196219
raise ValueError("num_attention_heads must divide hidden_size.")
220+
self.scaling = (self.internal_dim // config.num_attention_heads) ** -0.5
197221

198222
self.q_proj = nn.Linear(self.hidden_size, self.internal_dim)
199223
self.k_proj = nn.Linear(self.hidden_size, self.internal_dim)
200224
self.v_proj = nn.Linear(self.hidden_size, self.internal_dim)
201225
self.out_proj = nn.Linear(self.internal_dim, self.hidden_size)
202226

227+
self.is_causal = False
228+
203229
def _separate_heads(self, hidden_states: Tensor, num_attention_heads: int) -> Tensor:
204230
batch, point_batch_size, n_tokens, channel = hidden_states.shape
205231
c_per_head = channel // num_attention_heads
206232
hidden_states = hidden_states.reshape(batch * point_batch_size, n_tokens, num_attention_heads, c_per_head)
207233
return hidden_states.transpose(1, 2)
208234

209235
def _recombine_heads(self, hidden_states: Tensor, point_batch_size: int) -> Tensor:
210-
batch, n_heads, n_tokens, c_per_head = hidden_states.shape
211-
hidden_states = hidden_states.transpose(1, 2)
236+
batch, n_tokens, n_heads, c_per_head = hidden_states.shape
212237
return hidden_states.reshape(batch // point_batch_size, point_batch_size, n_tokens, n_heads * c_per_head)
213238

214239
def forward(
215-
self, query: Tensor, key: Tensor, value: Tensor, attention_similarity: Optional[Tensor] = None
240+
self,
241+
query: Tensor,
242+
key: Tensor,
243+
value: Tensor,
244+
attention_similarity: Optional[Tensor] = None,
245+
**kwargs: Unpack[TransformersKwargs],
216246
) -> Tensor:
217247
# Input projections
218248
query = self.q_proj(query)
@@ -226,64 +256,26 @@ def forward(
226256
value = self._separate_heads(value, self.num_attention_heads)
227257

228258
# SamAttention
229-
_, _, _, c_per_head = query.shape
230-
attn = query @ key.permute(0, 1, 3, 2) # batch_size * point_batch_size x N_heads x N_tokens x N_tokens
231-
attn = attn / (c_per_head**0.5)
232-
attn = torch.softmax(attn, dim=-1)
233-
234-
if attention_similarity is not None:
235-
attn = attn + attention_similarity
236-
attn = torch.softmax(attn, dim=-1)
237-
238-
# Get output
239-
out = attn @ value
240-
out = self._recombine_heads(out, point_batch_size)
241-
out = self.out_proj(out)
242-
243-
return out
244-
245-
246-
class SamSdpaAttention(SamAttention):
247-
"""
248-
SAM's attention layer that allows for downscaling the size of the embedding after projection to queries, keys, and
249-
values. Using SDPA instead of the default attention.
250-
"""
251-
252-
def __init__(self, config, downsample_rate=None):
253-
super().__init__(config, downsample_rate)
254-
255-
def forward(
256-
self, query: Tensor, key: Tensor, value: Tensor, attention_similarity: Optional[Tensor] = None
257-
) -> Tensor:
258-
# Input projections
259-
query = self.q_proj(query)
260-
key = self.k_proj(key)
261-
value = self.v_proj(value)
262-
263-
point_batch_size = query.shape[1]
264-
# Separate into heads
265-
query = self._separate_heads(query, self.num_attention_heads)
266-
key = self._separate_heads(key, self.num_attention_heads)
267-
value = self._separate_heads(value, self.num_attention_heads)
268-
269-
# Scaled dot product attention
270-
attn_mask = None
271-
if attention_similarity is not None:
272-
attn_mask = attention_similarity.unsqueeze(1).expand(-1, self.num_attention_heads, -1, -1)
273-
274-
out = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask)
275-
276-
# Get output
277-
out = self._recombine_heads(out, point_batch_size)
278-
out = self.out_proj(out)
279-
280-
return out
259+
attention_interface: Callable = eager_attention_forward
260+
if self.config._attn_implementation != "eager":
261+
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
262+
263+
attn_output, attn_weights = attention_interface(
264+
self,
265+
query,
266+
key,
267+
value,
268+
attention_mask=attention_similarity,
269+
dropout=0.0 if not self.training else self.dropout_p,
270+
scaling=self.scaling,
271+
is_causal=self.is_causal,
272+
**kwargs,
273+
)
281274

275+
attn_output = self._recombine_heads(attn_output, point_batch_size)
276+
attn_output = self.out_proj(attn_output)
282277

283-
SAM_ATTENTION_CLASSES = {
284-
"eager": SamAttention,
285-
"sdpa": SamSdpaAttention,
286-
}
278+
return attn_output, attn_weights
287279

288280

289281
class SamTwoWayAttentionBlock(nn.Module):
@@ -306,21 +298,17 @@ def __init__(self, config, attention_downsample_rate: int = 2, skip_first_layer_
306298
self.hidden_size = config.hidden_size
307299
self.layer_norm_eps = config.layer_norm_eps
308300

309-
self.self_attn = SAM_ATTENTION_CLASSES[config._attn_implementation](config, downsample_rate=1)
301+
self.self_attn = SamAttention(config, downsample_rate=1)
310302
self.layer_norm1 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)
311303

312-
self.cross_attn_token_to_image = SAM_ATTENTION_CLASSES[config._attn_implementation](
313-
config, downsample_rate=attention_downsample_rate
314-
)
304+
self.cross_attn_token_to_image = SamAttention(config, downsample_rate=attention_downsample_rate)
315305
self.layer_norm2 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)
316306

317307
self.mlp = SamMLPBlock(config)
318308
self.layer_norm3 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)
319309

320310
self.layer_norm4 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)
321-
self.cross_attn_image_to_token = SAM_ATTENTION_CLASSES[config._attn_implementation](
322-
config, downsample_rate=attention_downsample_rate
323-
)
311+
self.cross_attn_image_to_token = SamAttention(config, downsample_rate=attention_downsample_rate)
324312
self.skip_first_layer_pe = skip_first_layer_pe
325313

326314
def forward(
@@ -330,21 +318,22 @@ def forward(
330318
query_point_embedding: Tensor,
331319
key_point_embedding: Tensor,
332320
attention_similarity: Tensor,
321+
**kwargs: Unpack[TransformersKwargs],
333322
):
334323
# Self attention block
335324
if self.skip_first_layer_pe:
336-
queries = self.self_attn(query=queries, key=queries, value=queries)
325+
queries, _ = self.self_attn(query=queries, key=queries, value=queries)
337326
else:
338327
query = queries + query_point_embedding
339-
attn_out = self.self_attn(query=query, key=query, value=queries)
328+
attn_out, _ = self.self_attn(query=query, key=query, value=queries)
340329
queries = queries + attn_out
341330
queries = self.layer_norm1(queries)
342331

343332
# Cross attention block, tokens attending to image embedding
344333
query = queries + query_point_embedding
345334
key = keys + key_point_embedding
346335

347-
attn_out = self.cross_attn_token_to_image(
336+
attn_out, _ = self.cross_attn_token_to_image(
348337
query=query, key=key, value=keys, attention_similarity=attention_similarity
349338
)
350339
queries = queries + attn_out
@@ -360,7 +349,7 @@ def forward(
360349
query = queries + query_point_embedding
361350
key = keys + key_point_embedding
362351

363-
attn_out = self.cross_attn_image_to_token(query=key, key=query, value=queries)
352+
attn_out, _ = self.cross_attn_image_to_token(query=key, key=query, value=queries)
364353
keys = keys + attn_out
365354

366355
keys = self.layer_norm4(keys)
@@ -378,7 +367,7 @@ def __init__(self, config: SamMaskDecoderConfig):
378367
for i in range(self.num_hidden_layers):
379368
self.layers.append(SamTwoWayAttentionBlock(config, skip_first_layer_pe=(i == 0)))
380369

381-
self.final_attn_token_to_image = SAM_ATTENTION_CLASSES[config._attn_implementation](config)
370+
self.final_attn_token_to_image = SamAttention(config)
382371
self.layer_norm_final_attn = nn.LayerNorm(config.hidden_size)
383372

384373
def forward(
@@ -388,6 +377,7 @@ def forward(
388377
image_positional_embeddings: Tensor,
389378
attention_similarity: Tensor,
390379
target_embedding=None,
380+
**kwargs: Unpack[TransformersKwargs],
391381
) -> Union[tuple, BaseModelOutput]:
392382
if image_embeddings is None:
393383
raise ValueError("You have to specify an image_embedding")
@@ -410,12 +400,13 @@ def forward(
410400
query_point_embedding=point_embeddings,
411401
key_point_embedding=image_positional_embeddings,
412402
attention_similarity=attention_similarity,
403+
**kwargs,
413404
)
414405
# Apply the final attention layer from the points to the image
415406
query = queries + point_embeddings
416407
key = keys + image_positional_embeddings
417408

418-
attn_out = self.final_attn_token_to_image(query=query, key=key, value=keys)
409+
attn_out, _ = self.final_attn_token_to_image(query=query, key=key, value=keys)
419410

420411
queries = queries + attn_out
421412
queries = self.layer_norm_final_attn(queries)
@@ -501,12 +492,12 @@ def forward(
501492
Whether to return multiple masks or a single mask.
502493
"""
503494
batch_size, num_channels, height, width = image_embeddings.shape
504-
point_batch_size = sparse_prompt_embeddings.shape[1]
495+
point_batch_size = sparse_prompt_embeddings.shape[1] if sparse_prompt_embeddings is not None else 1
505496
# Concatenate output tokens
506497
output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
507498
output_tokens = output_tokens.repeat(batch_size, point_batch_size, 1, 1)
508499

509-
if sparse_prompt_embeddings.sum().item() != 0:
500+
if sparse_prompt_embeddings is not None:
510501
tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=2)
511502
else:
512503
tokens = output_tokens
@@ -611,7 +602,7 @@ def forward(self, masks):
611602

612603

613604
class SamPromptEncoder(nn.Module):
614-
def __init__(self, config: SamPromptEncoderConfig):
605+
def __init__(self, config: SamConfig):
615606
super().__init__()
616607
self.shared_embedding = SamPositionalEmbedding(config.vision_config)
617608
config = config.prompt_encoder_config
@@ -645,11 +636,7 @@ def _embed_points(self, points: torch.Tensor, labels: torch.Tensor, pad: bool) -
645636

646637
# This is required for the ONNX export. The dtype, device need to be explicitly
647638
# specified as otherwise torch.onnx.export interprets as double
648-
point_embedding = torch.where(
649-
labels[..., None] != -10,
650-
point_embedding,
651-
torch.tensor(0.0, dtype=point_embedding.dtype, device=point_embedding.device),
652-
)
639+
point_embedding = torch.where(labels[..., None] != -10, point_embedding, torch.zeros_like(point_embedding))
653640

654641
point_embedding = torch.where(
655642
(labels == 0)[:, :, :, None],
@@ -696,9 +683,8 @@ def forward(
696683
"""
697684
sparse_embeddings = None
698685
batch_size = 1
699-
target_device = self.shared_embedding.positional_embedding.device
700686
if input_points is not None:
701-
batch_size, point_batch_size = input_points.shape[:2]
687+
batch_size = input_points.shape[0]
702688
if input_labels is None:
703689
raise ValueError("If points are provided, labels must also be provided.")
704690
point_embeddings = self._embed_points(input_points, input_labels, pad=(input_boxes is None))
@@ -717,9 +703,6 @@ def forward(
717703
batch_size, -1, self.image_embedding_size[0], self.image_embedding_size[1]
718704
)
719705

720-
if sparse_embeddings is None:
721-
sparse_embeddings = torch.zeros((batch_size, 1, 1, self.hidden_size), device=target_device)
722-
723706
return sparse_embeddings, dense_embeddings
724707

725708

@@ -1184,10 +1167,6 @@ def get_image_embeddings(self, pixel_values, **kwargs: Unpack[TransformersKwargs
11841167
Args:
11851168
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
11861169
Input pixel values
1187-
output_attentions (`bool`, *optional*):
1188-
Whether or not to return the attentions tensors of all attention layers.
1189-
output_hidden_states (`bool`, *optional*):
1190-
Whether or not to return the hidden states of all layers.
11911170
"""
11921171
vision_output = self.vision_encoder(
11931172
pixel_values,

0 commit comments

Comments
 (0)