@@ -279,38 +279,48 @@ def eager_attention_forward(
 class Glm4vVisionAttention(nn.Module):
     def __init__(self, config: Glm4vVisionConfig) -> None:
         super().__init__()
-        self.config = config
+        self.dim = config.hidden_size
         self.num_heads = config.num_heads
-        self.head_dim = config.hidden_size // self.num_heads
-        self.num_key_value_groups = 1
-        self.scale = self.head_dim**-0.5
-        self.attention_dropout = config.attention_dropout
+        self.head_dim = self.dim // self.num_heads
+        self.num_key_value_groups = 1  # needed for eager attention
         self.qkv = nn.Linear(config.hidden_size, config.hidden_size * 3, bias=config.attention_bias)
         self.proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
+        self.scaling = self.head_dim**-0.5
+        self.config = config
+        self.attention_dropout = config.attention_dropout
+        self.is_causal = False

     def forward(
         self,
         hidden_states: torch.Tensor,
         cu_seqlens: torch.Tensor,
         rotary_pos_emb: Optional[torch.Tensor] = None,
         position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
-        **kwargs: Unpack[FlashAttentionKwargs],
+        attention_mask: Optional[torch.Tensor] = None,
+        **kwargs,
     ) -> torch.Tensor:
         seq_length = hidden_states.shape[0]
         query_states, key_states, value_states = (
             self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
         )
-
-        cos, sin = position_embeddings
+        if position_embeddings is None:
+            logger.warning_once(
+                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
+                "through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed "
+                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be "
+                "removed and `position_embeddings` will be mandatory."
+            )
+            emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
+            cos = emb.cos()
+            sin = emb.sin()
+        else:
+            cos, sin = position_embeddings
         query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin)

         query_states = query_states.transpose(0, 1).unsqueeze(0)
         key_states = key_states.transpose(0, 1).unsqueeze(0)
         value_states = value_states.transpose(0, 1).unsqueeze(0)
-
-        attention_mask = torch.zeros([1, 1, seq_length, seq_length], device=query_states.device, dtype=torch.bool)
-        for i in range(1, len(cu_seqlens)):
-            attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = True
+        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()

         attention_interface: Callable = eager_attention_forward
         if self.config._attn_implementation != "eager":
@@ -321,13 +331,17 @@ def forward(
             query_states,
             key_states,
             value_states,
-            attention_mask,
+            attention_mask=attention_mask,
             dropout=0.0 if not self.training else self.attention_dropout,
-            scaling=self.scale,
+            scaling=self.scaling,
+            cu_seq_lens_q=cu_seqlens,  # pass cu seq lens for FA2
+            cu_seq_lens_k=cu_seqlens,
+            max_length_q=max_seqlen,
+            max_length_k=max_seqlen,
             is_causal=False,
             **kwargs,
         )
-        attn_output = attn_output.squeeze(0)
+
         attn_output = attn_output.reshape(seq_length, -1).contiguous()
         attn_output = self.proj(attn_output)
         return attn_output
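Note on the new varlen kwargs above: `cu_seq_lens_q/k` are the cumulative boundaries of the packed per-image sequences and `max_length_q/k` is the longest single segment. The snippet below is an illustration only, with made-up token counts that are not part of this diff, showing how those two values relate:

```python
# Illustration only (hypothetical token counts, not part of the diff): how `cu_seqlens`
# and `max_seqlen` describe a packed batch of variable-length image sequences.
import torch

seqlens = torch.tensor([4, 7, 3])  # tokens per image, made-up values
cu_seqlens = torch.cat([torch.zeros(1, dtype=torch.long), seqlens.cumsum(0)])
# cu_seqlens == tensor([0, 4, 11, 14]): image i occupies positions [cu_seqlens[i], cu_seqlens[i + 1])
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()  # same formula as in forward() above
print(cu_seqlens.tolist(), max_seqlen)  # [0, 4, 11, 14] 7
```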
@@ -347,13 +361,15 @@ def forward(
         cu_seqlens: torch.Tensor,
         rotary_pos_emb: Optional[torch.Tensor] = None,
         position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
+        attention_mask: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> torch.Tensor:
         hidden_states = hidden_states + self.attn(
             self.norm1(hidden_states),
             cu_seqlens=cu_seqlens,
             rotary_pos_emb=rotary_pos_emb,
             position_embeddings=position_embeddings,
+            attention_mask=attention_mask,
             **kwargs,
         )
         hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
@@ -451,6 +467,25 @@ def rot_pos_emb(self, grid_thw):
         rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
         return rotary_pos_emb, pos_ids

+    def _prepare_attention_mask(self, inputs_tensor: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor:
+        # Flash Attention 2 doesn't need a 4D mask and relies on `cu_seqlens/max_seqlen`
+        # NOTE: the created attention mask only approximates the ragged FA2 attention by
+        # allowing bidirectional attention within `cu_seqlens` blocks and not attending between
+        # blocks, so it will not be a 100% match for FA2's `varlen` path
+        if self.config._attn_implementation == "flash_attention_2":
+            return None
+
+        seq_length = inputs_tensor.shape[0]
+        attention_mask = torch.full(
+            [1, 1, seq_length, seq_length],
+            torch.finfo(inputs_tensor.dtype).min,
+            device=inputs_tensor.device,
+            dtype=inputs_tensor.dtype,
+        )
+        for i in range(1, len(cu_seqlens)):
+            attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
+        return attention_mask
+
     def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor:
         """
         Args:
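For readers unfamiliar with additive masks: the old code built a boolean mask (True = attend) inside every attention call, whereas `_prepare_attention_mask` builds a float mask once per forward pass, with 0 where attention is allowed and `dtype.min` everywhere else, which is the convention the eager/SDPA attention interfaces expect. A small sketch with hypothetical boundaries, not taken from the PR:

```python
# Sketch of the mask-semantics change (illustrative values, not from the PR): boolean
# "True = attend" vs additive "0 added to the logits = attend, dtype.min = blocked".
import torch

cu_seqlens = torch.tensor([0, 4, 11, 14])  # hypothetical image boundaries
L = int(cu_seqlens[-1])

bool_mask = torch.zeros(1, 1, L, L, dtype=torch.bool)
additive_mask = torch.full((1, 1, L, L), torch.finfo(torch.float32).min)
for i in range(1, len(cu_seqlens)):
    s, e = cu_seqlens[i - 1], cu_seqlens[i]
    bool_mask[..., s:e, s:e] = True      # old convention: True means "may attend"
    additive_mask[..., s:e, s:e] = 0.0   # new convention: 0 means "may attend"

# both encode the same block-diagonal structure over the packed sequence
print(torch.equal(bool_mask, additive_mask == 0))  # True
```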
@@ -480,14 +515,15 @@ def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.
         cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
         seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
         hidden_states = self.embeddings(hidden_states, seqlens, grid_thw, image_type_ids[:, 0], image_type_ids[:, 1])
+        attention_mask = self._prepare_attention_mask(hidden_states, cu_seqlens=cu_seqlens)

         for blk in self.blocks:
-            if self.gradient_checkpointing and self.training:
-                hidden_states = self._gradient_checkpointing_func(
-                    blk.__call__, hidden_states, cu_seqlens, None, position_embeddings
-                )
-            else:
-                hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens, position_embeddings=position_embeddings)
+            hidden_states = blk(
+                hidden_states,
+                cu_seqlens=cu_seqlens,
+                position_embeddings=position_embeddings,
+                attention_mask=attention_mask,
+            )

         hidden_states = self.post_layernorm(hidden_states)
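As a sanity check on the approach (a standalone sketch with made-up shapes, not part of this PR), a single SDPA pass over the packed sequence with the block-diagonal additive mask reproduces attending each `cu_seqlens` segment separately, which is what the removed per-layer boolean mask and FA2's varlen path both compute:

```python
# Rough equivalence check (illustrative shapes, not from the PR): masked attention over the
# packed sequence matches running attention on each `cu_seqlens` segment independently.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
cu = torch.tensor([0, 4, 11, 14])
q = torch.randn(1, 2, 14, 8)  # (batch=1, heads=2, packed_len, head_dim)
k, v = torch.randn_like(q), torch.randn_like(q)

mask = torch.full([1, 1, 14, 14], torch.finfo(q.dtype).min)
for i in range(1, len(cu)):
    mask[..., cu[i - 1] : cu[i], cu[i - 1] : cu[i]] = 0

packed = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
per_segment = torch.cat(
    [
        F.scaled_dot_product_attention(q[..., s:e, :], k[..., s:e, :], v[..., s:e, :])
        for s, e in zip(cu[:-1], cu[1:])
    ],
    dim=-2,
)
print(torch.allclose(packed, per_segment, atol=1e-6))  # True
```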