Commit 3c5f910

zucchini-nlp authored and Cyrilvallez committed

[qwen2-vl] fix FA2 inference (#39121)

* fix FA2
* update is causal flag and remove mask for FA2
* update for FA2 with varlen path
* how the tests were passing with different devices?
* add comment and ref to the PR
* move mask preparation to base pretrained model
* seq len is the first dim, not second
* fix copies to fix GLM4V
1 parent 5e1c914 commit 3c5f910

File tree: 11 files changed, +365 −199 lines


src/transformers/modeling_flash_attention_utils.py

Lines changed: 17 additions & 8 deletions
```diff
@@ -508,6 +508,22 @@ def _flash_attention_forward(
         query_states, key_states, value_states, target_dtype
     )
 
+    # We will use `flash_attn_varlen_func` to prevent cross-example attention and also allow a padding-free approach
+    # in two cases:
+    # Case 1. `position_ids` is provided and not all examples contain a single sequence: if the tensor is monotonically
+    # increasing we probably have one sequence, otherwise it is packed. Additionally, check that we are in the prefill/training stage.
+    # Case 2. Some models directly pass pre-computed `cu_seqlens`, so we don't need to infer them from position ids. It is
+    # safe to use `flash_attn_varlen_func` since we already have all the necessary kwargs. NOTE: it is the user's
+    # responsibility to flatten `position_ids` if the model requires it. See #39121 for more information.
+    is_fa2_with_position_ids = (
+        position_ids is not None
+        and query_states.shape[0] == 1
+        and (max_length_q is not None or (query_length != 1 and not (torch.diff(position_ids, dim=-1) >= 0).all()))
+    )
+    is_fa2_with_varlen_kwargs = all(
+        kwarg is not None for kwarg in (cu_seq_lens_q, cu_seq_lens_k, max_length_q, max_length_k)
+    )
+
     # Contains at least one padding token in the sequence
     if attention_mask is not None:
         batch_size = query_states.shape[0]
@@ -531,14 +547,7 @@ def _flash_attention_forward(
         )
         attn_output = _pad_input(attn_output_unpad, indices_q, batch_size, query_length)
 
-    # If position_ids is provided and check all examples do not contain only 1 sequence, If tensor in increasing
-    # then we probably have one sequence, otherwise it is packed. Additionally check we are in pre-fill/training stage.
-    # Use `flash_attn_varlen_func` to prevent cross-example attention and also allow padding free approach
-    elif (
-        position_ids is not None
-        and query_states.shape[0] == 1
-        and (max_length_q is not None or (query_length != 1 and not (torch.diff(position_ids, dim=-1) >= 0).all()))
-    ):
+    elif is_fa2_with_varlen_kwargs or is_fa2_with_position_ids:
         batch_size = query_states.size(0)
 
         if cu_seq_lens_q is None or cu_seq_lens_k is None:
```
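
The two flags above gate the `varlen` dispatch. Below is a minimal, self-contained sketch (plain PyTorch, not library code; all names and values are illustrative) of how the Case 1 packed-sequence check behaves and how the pre-computed kwargs of Case 2 are typically derived:

```python
import torch


def looks_packed(position_ids: torch.Tensor, query_length: int) -> bool:
    # Mirrors Case 1: position ids that are NOT monotonically increasing mean
    # several sequences were packed back to back (prefill/training only).
    return query_length != 1 and not (torch.diff(position_ids, dim=-1) >= 0).all()


# Two sequences of lengths 3 and 2 packed into one row: positions restart at 0.
packed = torch.tensor([[0, 1, 2, 0, 1]])
single = torch.tensor([[0, 1, 2, 3, 4]])
print(looks_packed(packed, query_length=5))  # True  -> take the varlen path
print(looks_packed(single, query_length=5))  # False -> regular FA2 path

# Case 2: a caller passes pre-computed cumulative sequence lengths directly,
# so nothing has to be inferred from position ids.
cu_seqlens = torch.tensor([0, 3, 5], dtype=torch.int32)
max_length = int((cu_seqlens[1:] - cu_seqlens[:-1]).max())  # 3, the longest block
```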

src/transformers/models/glm4v/modeling_glm4v.py

Lines changed: 57 additions & 21 deletions
```diff
@@ -279,38 +279,48 @@ def eager_attention_forward(
 class Glm4vVisionAttention(nn.Module):
     def __init__(self, config: Glm4vVisionConfig) -> None:
         super().__init__()
-        self.config = config
+        self.dim = config.hidden_size
         self.num_heads = config.num_heads
-        self.head_dim = config.hidden_size // self.num_heads
-        self.num_key_value_groups = 1
-        self.scale = self.head_dim**-0.5
-        self.attention_dropout = config.attention_dropout
+        self.head_dim = self.dim // self.num_heads
+        self.num_key_value_groups = 1  # needed for eager attention
         self.qkv = nn.Linear(config.hidden_size, config.hidden_size * 3, bias=config.attention_bias)
         self.proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
+        self.scaling = self.head_dim**-0.5
+        self.config = config
+        self.attention_dropout = config.attention_dropout
+        self.is_causal = False
 
     def forward(
         self,
         hidden_states: torch.Tensor,
         cu_seqlens: torch.Tensor,
         rotary_pos_emb: Optional[torch.Tensor] = None,
         position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
-        **kwargs: Unpack[FlashAttentionKwargs],
+        attention_mask: Optional[torch.Tensor] = None,
+        **kwargs,
     ) -> torch.Tensor:
         seq_length = hidden_states.shape[0]
         query_states, key_states, value_states = (
             self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
         )
-
-        cos, sin = position_embeddings
+        if position_embeddings is None:
+            logger.warning_once(
+                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
+                "through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed "
+                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be "
+                "removed and `position_embeddings` will be mandatory."
+            )
+            emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
+            cos = emb.cos()
+            sin = emb.sin()
+        else:
+            cos, sin = position_embeddings
         query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin)
 
         query_states = query_states.transpose(0, 1).unsqueeze(0)
         key_states = key_states.transpose(0, 1).unsqueeze(0)
         value_states = value_states.transpose(0, 1).unsqueeze(0)
-
-        attention_mask = torch.zeros([1, 1, seq_length, seq_length], device=query_states.device, dtype=torch.bool)
-        for i in range(1, len(cu_seqlens)):
-            attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = True
+        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
 
         attention_interface: Callable = eager_attention_forward
         if self.config._attn_implementation != "eager":
@@ -321,13 +331,17 @@ def forward(
             query_states,
             key_states,
             value_states,
-            attention_mask,
+            attention_mask=attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
-            scaling=self.scale,
+            scaling=self.scaling,
+            cu_seq_lens_q=cu_seqlens,  # pass cu seq lens for FA2
+            cu_seq_lens_k=cu_seqlens,
+            max_length_q=max_seqlen,
+            max_length_k=max_seqlen,
             is_causal=False,
             **kwargs,
         )
-        attn_output = attn_output.squeeze(0)
+
         attn_output = attn_output.reshape(seq_length, -1).contiguous()
         attn_output = self.proj(attn_output)
         return attn_output
```
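
For context, here is a small sketch (toy values, not library code) of how `cu_seqlens` and `max_seqlen` are typically built for a vision batch whose patches are flattened with the sequence length as the first dimension; passing them as `cu_seq_lens_q/k` and `max_length_q/k` lets the FA2 path above hit the `varlen` kernel without any dense mask:

```python
import torch
import torch.nn.functional as F

# Three images contributing 4, 9 and 4 patches to one flat sequence.
seq_lens = torch.tensor([4, 9, 4])
cu_seqlens = F.pad(torch.cumsum(seq_lens, dim=0), (1, 0), value=0).to(torch.int32)
max_seqlen = int((cu_seqlens[1:] - cu_seqlens[:-1]).max())

print(cu_seqlens.tolist())  # [0, 4, 13, 17] -> per-image block boundaries
print(max_seqlen)           # 9 -> longest image, used as max_length_q / max_length_k
```

With these kwargs the attention never crosses an image boundary, which is exactly what the block-diagonal mask emulates for the non-FA2 backends.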
```diff
@@ -347,13 +361,15 @@ def forward(
         cu_seqlens: torch.Tensor,
         rotary_pos_emb: Optional[torch.Tensor] = None,
         position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
+        attention_mask: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> torch.Tensor:
         hidden_states = hidden_states + self.attn(
             self.norm1(hidden_states),
             cu_seqlens=cu_seqlens,
             rotary_pos_emb=rotary_pos_emb,
             position_embeddings=position_embeddings,
+            attention_mask=attention_mask,
             **kwargs,
         )
         hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
```
```diff
@@ -451,6 +467,25 @@ def rot_pos_emb(self, grid_thw):
         rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
         return rotary_pos_emb, pos_ids
 
+    def _prepare_attention_mask(self, inputs_tensor: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor:
+        # Flash Attention 2 doesn't need a 4D mask and relies on `cu_seqlens`/`max_seqlen`.
+        # NOTE: the created attention mask only approximates the ragged FA2 attention by
+        # allowing bidirectional attention within `cu_seqlens` blocks and not attending between
+        # blocks; it will not be a 100% match for FA2's `varlen` path.
+        if self.config._attn_implementation == "flash_attention_2":
+            return None
+
+        seq_length = inputs_tensor.shape[0]
+        attention_mask = torch.full(
+            [1, 1, seq_length, seq_length],
+            torch.finfo(inputs_tensor.dtype).min,
+            device=inputs_tensor.device,
+            dtype=inputs_tensor.dtype,
+        )
+        for i in range(1, len(cu_seqlens)):
+            attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
+        return attention_mask
+
     def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor:
         """
         Args:
```
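
As a toy illustration (assumed values, not library code), this is the additive block-diagonal mask the helper above builds for the eager/SDPA backends; it reproduces the per-image boundaries that FA2's `varlen` path enforces through `cu_seqlens`:

```python
import torch

dtype = torch.float32
cu_seqlens = [0, 2, 5]  # two images: 2 patches and 3 patches
seq_length = 5
mask = torch.full([1, 1, seq_length, seq_length], torch.finfo(dtype).min, dtype=dtype)
for i in range(1, len(cu_seqlens)):
    mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0

# 0 where attention is allowed (within an image), a large negative number elsewhere,
# so softmax gives ~0 weight across image boundaries.
print((mask[0, 0] == 0).int())
# tensor([[1, 1, 0, 0, 0],
#         [1, 1, 0, 0, 0],
#         [0, 0, 1, 1, 1],
#         [0, 0, 1, 1, 1],
#         [0, 0, 1, 1, 1]], dtype=torch.int32)
```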
```diff
@@ -480,14 +515,15 @@ def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor:
         cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
         seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
         hidden_states = self.embeddings(hidden_states, seqlens, grid_thw, image_type_ids[:, 0], image_type_ids[:, 1])
+        attention_mask = self._prepare_attention_mask(hidden_states, cu_seqlens=cu_seqlens)
 
         for blk in self.blocks:
-            if self.gradient_checkpointing and self.training:
-                hidden_states = self._gradient_checkpointing_func(
-                    blk.__call__, hidden_states, cu_seqlens, None, position_embeddings
-                )
-            else:
-                hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens, position_embeddings=position_embeddings)
+            hidden_states = blk(
+                hidden_states,
+                cu_seqlens=cu_seqlens,
+                position_embeddings=position_embeddings,
+                attention_mask=attention_mask,
+            )
 
         hidden_states = self.post_layernorm(hidden_states)
 
```

src/transformers/models/glm4v/modular_glm4v.py

Lines changed: 28 additions & 57 deletions
```diff
@@ -50,8 +50,8 @@
     Qwen2_5_VLPreTrainedModel,
     Qwen2_5_VLRotaryEmbedding,
     Qwen2_5_VLTextModel,
+    Qwen2_5_VLVisionAttention,
     Qwen2_5_VLVisionBlock,
-    apply_rotary_pos_emb_vision,
 )
 from ..qwen2_5_vl.processing_qwen2_5_vl import (
     Qwen2_5_VLProcessor,
@@ -505,62 +505,13 @@ def forward(self, embeddings, lengths, image_shapes, h_coords, w_coords) -> torch.Tensor:
         return embeddings
 
 
-class Glm4vVisionAttention(nn.Module):
+class Glm4vVisionAttention(Qwen2_5_VLVisionAttention):
     def __init__(self, config: Glm4vVisionConfig) -> None:
         super().__init__()
-        self.config = config
-        self.num_heads = config.num_heads
-        self.head_dim = config.hidden_size // self.num_heads
-        self.num_key_value_groups = 1
-        self.scale = self.head_dim**-0.5
         self.attention_dropout = config.attention_dropout
         self.qkv = nn.Linear(config.hidden_size, config.hidden_size * 3, bias=config.attention_bias)
         self.proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
 
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        cu_seqlens: torch.Tensor,
-        rotary_pos_emb: Optional[torch.Tensor] = None,
-        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
-        **kwargs: Unpack[FlashAttentionKwargs],
-    ) -> torch.Tensor:
-        seq_length = hidden_states.shape[0]
-        query_states, key_states, value_states = (
-            self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
-        )
-
-        cos, sin = position_embeddings
-        query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin)
-
-        query_states = query_states.transpose(0, 1).unsqueeze(0)
-        key_states = key_states.transpose(0, 1).unsqueeze(0)
-        value_states = value_states.transpose(0, 1).unsqueeze(0)
-
-        attention_mask = torch.zeros([1, 1, seq_length, seq_length], device=query_states.device, dtype=torch.bool)
-        for i in range(1, len(cu_seqlens)):
-            attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = True
-
-        attention_interface: Callable = eager_attention_forward
-        if self.config._attn_implementation != "eager":
-            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
-
-        attn_output, _ = attention_interface(
-            self,
-            query_states,
-            key_states,
-            value_states,
-            attention_mask,
-            dropout=0.0 if not self.training else self.attention_dropout,
-            scaling=self.scale,
-            is_causal=False,
-            **kwargs,
-        )
-        attn_output = attn_output.squeeze(0)
-        attn_output = attn_output.reshape(seq_length, -1).contiguous()
-        attn_output = self.proj(attn_output)
-        return attn_output
-
 
 class Glm4vVisionBlock(Qwen2_5_VLVisionBlock):
     def __init__(self, config) -> None:
```
```diff
@@ -652,6 +603,25 @@ def rot_pos_emb(self, grid_thw):
         rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
         return rotary_pos_emb, pos_ids
 
+    def _prepare_attention_mask(self, inputs_tensor: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor:
+        # Flash Attention 2 doesn't need a 4D mask and relies on `cu_seqlens`/`max_seqlen`.
+        # NOTE: the created attention mask only approximates the ragged FA2 attention by
+        # allowing bidirectional attention within `cu_seqlens` blocks and not attending between
+        # blocks; it will not be a 100% match for FA2's `varlen` path.
+        if self.config._attn_implementation == "flash_attention_2":
+            return None
+
+        seq_length = inputs_tensor.shape[0]
+        attention_mask = torch.full(
+            [1, 1, seq_length, seq_length],
+            torch.finfo(inputs_tensor.dtype).min,
+            device=inputs_tensor.device,
+            dtype=inputs_tensor.dtype,
+        )
+        for i in range(1, len(cu_seqlens)):
+            attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
+        return attention_mask
+
     def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor:
         """
         Args:
```
```diff
@@ -681,14 +651,15 @@ def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor:
         cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
         seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
         hidden_states = self.embeddings(hidden_states, seqlens, grid_thw, image_type_ids[:, 0], image_type_ids[:, 1])
+        attention_mask = self._prepare_attention_mask(hidden_states, cu_seqlens=cu_seqlens)
 
         for blk in self.blocks:
-            if self.gradient_checkpointing and self.training:
-                hidden_states = self._gradient_checkpointing_func(
-                    blk.__call__, hidden_states, cu_seqlens, None, position_embeddings
-                )
-            else:
-                hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens, position_embeddings=position_embeddings)
+            hidden_states = blk(
+                hidden_states,
+                cu_seqlens=cu_seqlens,
+                position_embeddings=position_embeddings,
+                attention_mask=attention_mask,
+            )
 
         hidden_states = self.post_layernorm(hidden_states)
 
```
