1616
1717import collections
1818from dataclasses import dataclass
19- from typing import Optional , Union
19+ from typing import Callable , Optional , Union
2020
2121import numpy as np
2222import torch
2828from ...activations import ACT2FN
2929from ...modeling_layers import GradientCheckpointingLayer
3030from ...modeling_outputs import BaseModelOutput
31- from ...modeling_utils import PreTrainedModel
31+ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS , PreTrainedModel
3232from ...processing_utils import Unpack
3333from ...utils import (
3434 ModelOutput ,
@@ -178,6 +178,28 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
178178 return x
179179
180180
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs,
):
    """
    Reference (non-fused) scaled dot-product attention.

    Args:
        module (`nn.Module`):
            The calling attention layer; only `module.training` is read, to decide whether dropout is active.
        query (`torch.Tensor`):
            4-D tensor; the last two dimensions are multiplied against `key` transposed on dims 2 and 3.
        key (`torch.Tensor`):
            4-D tensor, transposed on its last two dimensions before the product with `query`.
        value (`torch.Tensor`):
            4-D tensor combined with the attention probabilities.
        attention_mask (`torch.Tensor`, *optional*):
            Additive bias added to the scaled scores before the softmax. It must broadcast against the score
            tensor. Skipped when `None`.
        scaling (`float`):
            Factor the raw `query @ key^T` scores are multiplied by.
        dropout (`float`, *optional*, defaults to 0.0):
            Dropout probability applied to the attention probabilities when `module.training` is true.
        **kwargs:
            Accepted so that the signature is interchangeable with other attention backends; ignored here.

    Returns:
        `tuple(torch.Tensor, torch.Tensor)`: the attention output with dims 1 and 2 swapped (made contiguous), and
        the attention probabilities (after dropout).
    """
    attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask

    # Softmax is computed in float32 and cast back to the input dtype afterwards.
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
    attn_output = torch.matmul(attn_weights, value)
    # Swap dims 1 and 2 so the caller can merge the last two dims with a plain reshape.
    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output, attn_weights
201+
202+
181203class SamAttention (nn .Module ):
182204 """
183205 SAM's attention layer that allows for downscaling the size of the embedding after projection to queries, keys, and
@@ -186,6 +208,7 @@ class SamAttention(nn.Module):
186208
187209 def __init__ (self , config , downsample_rate = None ):
188210 super ().__init__ ()
211+ self .config = config
189212 self .hidden_size = config .hidden_size
190213
191214 downsample_rate = config .attention_downsample_rate if downsample_rate is None else downsample_rate
@@ -194,25 +217,32 @@ def __init__(self, config, downsample_rate=None):
194217 self .num_attention_heads = config .num_attention_heads
195218 if self .internal_dim % config .num_attention_heads != 0 :
196219 raise ValueError ("num_attention_heads must divide hidden_size." )
220+ self .scaling = (self .internal_dim // config .num_attention_heads ) ** - 0.5
197221
198222 self .q_proj = nn .Linear (self .hidden_size , self .internal_dim )
199223 self .k_proj = nn .Linear (self .hidden_size , self .internal_dim )
200224 self .v_proj = nn .Linear (self .hidden_size , self .internal_dim )
201225 self .out_proj = nn .Linear (self .internal_dim , self .hidden_size )
202226
227+ self .is_causal = False
228+
203229 def _separate_heads (self , hidden_states : Tensor , num_attention_heads : int ) -> Tensor :
204230 batch , point_batch_size , n_tokens , channel = hidden_states .shape
205231 c_per_head = channel // num_attention_heads
206232 hidden_states = hidden_states .reshape (batch * point_batch_size , n_tokens , num_attention_heads , c_per_head )
207233 return hidden_states .transpose (1 , 2 )
208234
209235 def _recombine_heads (self , hidden_states : Tensor , point_batch_size : int ) -> Tensor :
210- batch , n_heads , n_tokens , c_per_head = hidden_states .shape
211- hidden_states = hidden_states .transpose (1 , 2 )
236+ batch , n_tokens , n_heads , c_per_head = hidden_states .shape
212237 return hidden_states .reshape (batch // point_batch_size , point_batch_size , n_tokens , n_heads * c_per_head )
213238
def forward(
    self,
    query: Tensor,
    key: Tensor,
    value: Tensor,
    attention_similarity: Optional[Tensor] = None,
    **kwargs: Unpack[TransformersKwargs],
) -> tuple[Tensor, Optional[Tensor]]:
    """
    Project the inputs, run attention through the configured backend, and project the result back.

    Args:
        query (`Tensor`), key (`Tensor`), value (`Tensor`):
            Inputs to the q/k/v projections. `_separate_heads` unpacks them as
            `(batch, point_batch_size, n_tokens, channel)` after projection.
        attention_similarity (`Tensor`, *optional*):
            Forwarded to the attention backend as `attention_mask` (an additive bias in the eager path).
        **kwargs:
            Extra keyword arguments forwarded untouched to the attention backend.

    Returns:
        `tuple(Tensor, Optional[Tensor])`: the output of `out_proj` and the attention weights returned by the
        backend. NOTE(review): non-eager backends may return `None` for the weights; confirm per backend.
    """
    # Input projections
    query = self.q_proj(query)
    key = self.k_proj(key)
    value = self.v_proj(value)

    point_batch_size = query.shape[1]
    # Separate into heads
    query = self._separate_heads(query, self.num_attention_heads)
    key = self._separate_heads(key, self.num_attention_heads)
    value = self._separate_heads(value, self.num_attention_heads)

    # SamAttention: pick the attention backend selected in the config, defaulting to the eager reference.
    attention_interface: Callable = eager_attention_forward
    if self.config._attn_implementation != "eager":
        attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

    attn_output, attn_weights = attention_interface(
        self,
        query,
        key,
        value,
        attention_mask=attention_similarity,
        # This layer defines no dropout probability (there is no `dropout_p` attribute), so reading one here
        # would raise AttributeError in training mode. Attention dropout is therefore always disabled.
        dropout=0.0,
        scaling=self.scaling,
        is_causal=self.is_causal,
        **kwargs,
    )

    # The backend returns (batch * point_batch, n_tokens, n_heads, head_dim); merge heads and restore point batch.
    attn_output = self._recombine_heads(attn_output, point_batch_size)
    attn_output = self.out_proj(attn_output)

    return attn_output, attn_weights
287279
288280
289281class SamTwoWayAttentionBlock (nn .Module ):
@@ -306,21 +298,17 @@ def __init__(self, config, attention_downsample_rate: int = 2, skip_first_layer_
306298 self .hidden_size = config .hidden_size
307299 self .layer_norm_eps = config .layer_norm_eps
308300
309- self .self_attn = SAM_ATTENTION_CLASSES [ config . _attn_implementation ] (config , downsample_rate = 1 )
301+ self .self_attn = SamAttention (config , downsample_rate = 1 )
310302 self .layer_norm1 = nn .LayerNorm (self .hidden_size , eps = self .layer_norm_eps )
311303
312- self .cross_attn_token_to_image = SAM_ATTENTION_CLASSES [config ._attn_implementation ](
313- config , downsample_rate = attention_downsample_rate
314- )
304+ self .cross_attn_token_to_image = SamAttention (config , downsample_rate = attention_downsample_rate )
315305 self .layer_norm2 = nn .LayerNorm (self .hidden_size , eps = self .layer_norm_eps )
316306
317307 self .mlp = SamMLPBlock (config )
318308 self .layer_norm3 = nn .LayerNorm (self .hidden_size , eps = self .layer_norm_eps )
319309
320310 self .layer_norm4 = nn .LayerNorm (self .hidden_size , eps = self .layer_norm_eps )
321- self .cross_attn_image_to_token = SAM_ATTENTION_CLASSES [config ._attn_implementation ](
322- config , downsample_rate = attention_downsample_rate
323- )
311+ self .cross_attn_image_to_token = SamAttention (config , downsample_rate = attention_downsample_rate )
324312 self .skip_first_layer_pe = skip_first_layer_pe
325313
326314 def forward (
@@ -330,21 +318,22 @@ def forward(
330318 query_point_embedding : Tensor ,
331319 key_point_embedding : Tensor ,
332320 attention_similarity : Tensor ,
321+ ** kwargs : Unpack [TransformersKwargs ],
333322 ):
334323 # Self attention block
335324 if self .skip_first_layer_pe :
336- queries = self .self_attn (query = queries , key = queries , value = queries )
325+ queries , _ = self .self_attn (query = queries , key = queries , value = queries )
337326 else :
338327 query = queries + query_point_embedding
339- attn_out = self .self_attn (query = query , key = query , value = queries )
328+ attn_out , _ = self .self_attn (query = query , key = query , value = queries )
340329 queries = queries + attn_out
341330 queries = self .layer_norm1 (queries )
342331
343332 # Cross attention block, tokens attending to image embedding
344333 query = queries + query_point_embedding
345334 key = keys + key_point_embedding
346335
347- attn_out = self .cross_attn_token_to_image (
336+ attn_out , _ = self .cross_attn_token_to_image (
348337 query = query , key = key , value = keys , attention_similarity = attention_similarity
349338 )
350339 queries = queries + attn_out
@@ -360,7 +349,7 @@ def forward(
360349 query = queries + query_point_embedding
361350 key = keys + key_point_embedding
362351
363- attn_out = self .cross_attn_image_to_token (query = key , key = query , value = queries )
352+ attn_out , _ = self .cross_attn_image_to_token (query = key , key = query , value = queries )
364353 keys = keys + attn_out
365354
366355 keys = self .layer_norm4 (keys )
@@ -378,7 +367,7 @@ def __init__(self, config: SamMaskDecoderConfig):
378367 for i in range (self .num_hidden_layers ):
379368 self .layers .append (SamTwoWayAttentionBlock (config , skip_first_layer_pe = (i == 0 )))
380369
381- self .final_attn_token_to_image = SAM_ATTENTION_CLASSES [ config . _attn_implementation ] (config )
370+ self .final_attn_token_to_image = SamAttention (config )
382371 self .layer_norm_final_attn = nn .LayerNorm (config .hidden_size )
383372
384373 def forward (
@@ -388,6 +377,7 @@ def forward(
388377 image_positional_embeddings : Tensor ,
389378 attention_similarity : Tensor ,
390379 target_embedding = None ,
380+ ** kwargs : Unpack [TransformersKwargs ],
391381 ) -> Union [tuple , BaseModelOutput ]:
392382 if image_embeddings is None :
393383 raise ValueError ("You have to specify an image_embedding" )
@@ -410,12 +400,13 @@ def forward(
410400 query_point_embedding = point_embeddings ,
411401 key_point_embedding = image_positional_embeddings ,
412402 attention_similarity = attention_similarity ,
403+ ** kwargs ,
413404 )
414405 # Apply the final attention layer from the points to the image
415406 query = queries + point_embeddings
416407 key = keys + image_positional_embeddings
417408
418- attn_out = self .final_attn_token_to_image (query = query , key = key , value = keys )
409+ attn_out , _ = self .final_attn_token_to_image (query = query , key = key , value = keys )
419410
420411 queries = queries + attn_out
421412 queries = self .layer_norm_final_attn (queries )
@@ -501,12 +492,12 @@ def forward(
501492 Whether to return multiple masks or a single mask.
502493 """
503494 batch_size , num_channels , height , width = image_embeddings .shape
504- point_batch_size = sparse_prompt_embeddings .shape [1 ]
495+ point_batch_size = sparse_prompt_embeddings .shape [1 ] if sparse_prompt_embeddings is not None else 1
505496 # Concatenate output tokens
506497 output_tokens = torch .cat ([self .iou_token .weight , self .mask_tokens .weight ], dim = 0 )
507498 output_tokens = output_tokens .repeat (batch_size , point_batch_size , 1 , 1 )
508499
509- if sparse_prompt_embeddings . sum (). item () != 0 :
500+ if sparse_prompt_embeddings is not None :
510501 tokens = torch .cat ((output_tokens , sparse_prompt_embeddings ), dim = 2 )
511502 else :
512503 tokens = output_tokens
@@ -611,7 +602,7 @@ def forward(self, masks):
611602
612603
613604class SamPromptEncoder (nn .Module ):
614- def __init__ (self , config : SamPromptEncoderConfig ):
605+ def __init__ (self , config : SamConfig ):
615606 super ().__init__ ()
616607 self .shared_embedding = SamPositionalEmbedding (config .vision_config )
617608 config = config .prompt_encoder_config
@@ -645,11 +636,7 @@ def _embed_points(self, points: torch.Tensor, labels: torch.Tensor, pad: bool) -
645636
646637 # This is required for the ONNX export. The dtype, device need to be explicitly
647638 # specified as otherwise torch.onnx.export interprets as double
648- point_embedding = torch .where (
649- labels [..., None ] != - 10 ,
650- point_embedding ,
651- torch .tensor (0.0 , dtype = point_embedding .dtype , device = point_embedding .device ),
652- )
639+ point_embedding = torch .where (labels [..., None ] != - 10 , point_embedding , torch .zeros_like (point_embedding ))
653640
654641 point_embedding = torch .where (
655642 (labels == 0 )[:, :, :, None ],
@@ -696,9 +683,8 @@ def forward(
696683 """
697684 sparse_embeddings = None
698685 batch_size = 1
699- target_device = self .shared_embedding .positional_embedding .device
700686 if input_points is not None :
701- batch_size , point_batch_size = input_points .shape [: 2 ]
687+ batch_size = input_points .shape [0 ]
702688 if input_labels is None :
703689 raise ValueError ("If points are provided, labels must also be provided." )
704690 point_embeddings = self ._embed_points (input_points , input_labels , pad = (input_boxes is None ))
@@ -717,9 +703,6 @@ def forward(
717703 batch_size , - 1 , self .image_embedding_size [0 ], self .image_embedding_size [1 ]
718704 )
719705
720- if sparse_embeddings is None :
721- sparse_embeddings = torch .zeros ((batch_size , 1 , 1 , self .hidden_size ), device = target_device )
722-
723706 return sparse_embeddings , dense_embeddings
724707
725708
@@ -1184,10 +1167,6 @@ def get_image_embeddings(self, pixel_values, **kwargs: Unpack[TransformersKwargs
11841167 Args:
11851168 pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
11861169 Input pixel values
1187- output_attentions (`bool`, *optional*):
1188- Whether or not to return the attentions tensors of all attention layers.
1189- output_hidden_states (`bool`, *optional*):
1190- Whether or not to return the hidden states of all layers.
11911170 """
11921171 vision_output = self .vision_encoder (
11931172 pixel_values ,
0 commit comments