@@ -44,6 +44,7 @@
 )
 from .configuration_qwen2_moe import Qwen2MoeConfig


+
 if is_flash_attn_2_available():
     from flash_attn import flash_attn_func, flash_attn_varlen_func
     from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
@@ -63,8 +64,7 @@

 # Copied from transformers.models.mixtral.modeling_mixtral.load_balancing_loss_func
 def load_balancing_loss_func(
-    gate_logits: torch.Tensor, num_experts: torch.Tensor = None, top_k=2,
-    attention_mask: Optional[torch.Tensor] = None
+    gate_logits: torch.Tensor, num_experts: torch.Tensor = None, top_k=2, attention_mask: Optional[torch.Tensor] = None
 ) -> float:
     r"""
     Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.
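The docstring above refers to the Switch Transformer auxiliary loss, which for a single layer without padding reduces to `num_experts * sum_i(f_i * P_i)`, where `f_i` is the fraction of tokens dispatched to expert `i` and `P_i` is the mean router probability for expert `i`. A minimal single-layer sketch of that computation, assuming no attention mask and a hypothetical helper name:

```python
import torch
import torch.nn.functional as F


def aux_loss_sketch(gate_logits: torch.Tensor, num_experts: int, top_k: int = 2) -> torch.Tensor:
    # gate_logits: [num_tokens, num_experts] router scores for one layer (padding not handled here)
    routing_weights = F.softmax(gate_logits, dim=-1)
    _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
    # one-hot routing decisions: [num_tokens, top_k, num_experts]
    expert_mask = F.one_hot(selected_experts, num_experts).float()
    # f_i: fraction of tokens dispatched to each expert (per top-k slot)
    tokens_per_expert = expert_mask.mean(dim=0)
    # P_i: mean router probability assigned to each expert
    router_prob_per_expert = routing_weights.mean(dim=0)
    # the loss is minimized when routing is uniform across experts
    return num_experts * torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
```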
@@ -112,9 +112,9 @@ def load_balancing_loss_func(
         # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask
         expert_attention_mask = (
             attention_mask[None, :, :, None, None]
-            .expand((num_hidden_layers, batch_size, sequence_length, 2, num_experts))
-            .reshape(-1, 2, num_experts)
-            .to(compute_device)
+            .expand((num_hidden_layers, batch_size, sequence_length, 2, num_experts))
+            .reshape(-1, 2, num_experts)
+            .to(compute_device)
         )

         # Compute the percentage of tokens routed to each experts
@@ -125,9 +125,9 @@ def load_balancing_loss_func(
         # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert
         router_per_expert_attention_mask = (
             attention_mask[None, :, :, None]
-            .expand((num_hidden_layers, batch_size, sequence_length, num_experts))
-            .reshape(-1, num_experts)
-            .to(compute_device)
+            .expand((num_hidden_layers, batch_size, sequence_length, num_experts))
+            .reshape(-1, num_experts)
+            .to(compute_device)
         )

         # Compute the average probability of routing to these experts
@@ -211,7 +211,7 @@ def forward(self, x, seq_len=None):
 def rotate_half(x):
     """Rotates half the hidden dims of the input."""
     x1 = x[..., : x.shape[-1] // 2]
-    x2 = x[..., x.shape[-1] // 2:]
+    x2 = x[..., x.shape[-1] // 2 :]
     return torch.cat((-x2, x1), dim=-1)


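`rotate_half` implements the pairing step of rotary position embeddings: together with per-position `cos`/`sin` tables, queries and keys are rotated as `q * cos + rotate_half(q) * sin`. A minimal sketch of that application (the function name is illustrative, and the head-dimension broadcasting of `cos`/`sin` is simplified):

```python
import torch


def rotate_half(x):
    # as in the hunk above: swap the two halves of the last dim and negate the second half
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rope_sketch(q, k, cos, sin):
    # cos/sin are assumed already broadcastable to q and k (e.g. [1, 1, seq_len, head_dim])
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
```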
@@ -318,14 +318,14 @@ def __init__(self, config: Qwen2MoeConfig, layer_idx: Optional[int] = None):
         )

     def forward(
-        self,
-        hidden_states: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_value: Optional[Cache] = None,
-        output_attentions: bool = False,
-        use_cache: bool = False,
-        **kwargs,
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Cache] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+        **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         if "padding_mask" in kwargs:
             warnings.warn(
@@ -419,14 +419,14 @@ def __init__(self, *args, **kwargs):
         self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()

     def forward(
-        self,
-        hidden_states: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_value: Optional[Cache] = None,
-        output_attentions: bool = False,
-        use_cache: bool = False,
-        **kwargs,
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Cache] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+        **kwargs,
     ):
         if "padding_mask" in kwargs:
             warnings.warn(
@@ -462,10 +462,10 @@ def forward(
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

         use_sliding_windows = (
-            _flash_supports_window_size
-            and getattr(self.config, "sliding_window", None) is not None
-            and kv_seq_len > self.config.sliding_window
-            and self.config.use_sliding_window
+            _flash_supports_window_size
+            and getattr(self.config, "sliding_window", None) is not None
+            and kv_seq_len > self.config.sliding_window
+            and self.config.use_sliding_window
         )

         if not _flash_supports_window_size:
@@ -478,9 +478,9 @@ def forward(
             # Activate slicing cache only if the config has a value `sliding_windows` attribute
             cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
             if (
-                getattr(self.config, "sliding_window", None) is not None
-                and kv_seq_len > self.config.sliding_window
-                and cache_has_contents
+                getattr(self.config, "sliding_window", None) is not None
+                and kv_seq_len > self.config.sliding_window
+                and cache_has_contents
             ):
                 slicing_tokens = 1 - self.config.sliding_window

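The condition above guards a rolling-buffer trim of the key/value cache: once the cached length exceeds `sliding_window`, only the most recent `sliding_window - 1` positions are kept (hence `slicing_tokens = 1 - sliding_window`). A toy illustration with made-up sizes:

```python
import torch

sliding_window = 4
# fake cached keys with shape [batch, num_heads, seq_len, head_dim]
past_key = torch.arange(10).view(1, 1, 10, 1)
slicing_tokens = 1 - sliding_window            # -3
past_key = past_key[:, :, slicing_tokens:, :]  # keep the last sliding_window - 1 positions
print(past_key.flatten())                      # tensor([7, 8, 9])
```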
@@ -555,15 +555,15 @@ def forward(
         return attn_output, attn_weights, past_key_value

     def _flash_attention_forward(
-        self,
-        query_states,
-        key_states,
-        value_states,
-        attention_mask,
-        query_length,
-        dropout=0.0,
-        softmax_scale=None,
-        use_sliding_windows=False,
+        self,
+        query_states,
+        key_states,
+        value_states,
+        attention_mask,
+        query_length,
+        dropout=0.0,
+        softmax_scale=None,
+        use_sliding_windows=False,
     ):
         """
         Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
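The padding path described in this docstring works by unpadding: the 2D attention mask is turned into flat token indices plus cumulative sequence lengths (`cu_seqlens`) for the variable-length flash-attention kernel, and the output is re-padded afterwards. A sketch of that mask preprocessing, mirroring the `_get_unpad_data`-style helper used in this file (the function name here is illustrative):

```python
import torch
import torch.nn.functional as F


def get_unpad_data_sketch(attention_mask: torch.Tensor):
    # attention_mask: [batch, seq_len], 1 for real tokens, 0 for padding
    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
    # flat indices of the non-padding tokens, used to gather/scatter hidden states
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = int(seqlens_in_batch.max())
    # cumulative sequence lengths with a leading zero, the layout expected by varlen kernels
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    return indices, cu_seqlens, max_seqlen_in_batch
```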
@@ -666,7 +666,7 @@ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query
         # by slicing it on the proper place
         if kv_seq_len != attention_mask.shape[-1]:
             attention_mask_num_tokens = attention_mask.shape[-1]
-            attention_mask = attention_mask[:, attention_mask_num_tokens - kv_seq_len:]
+            attention_mask = attention_mask[:, attention_mask_num_tokens - kv_seq_len :]

         indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)

@@ -712,13 +712,13 @@ class Qwen2MoeSdpaAttention(Qwen2MoeAttention):

     # Adapted from Qwen2MoeAttention.forward
     def forward(
-        self,
-        hidden_states: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_value: Optional[Cache] = None,
-        output_attentions: bool = False,
-        use_cache: bool = False,
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Cache] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         if output_attentions:
             # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
@@ -887,15 +887,15 @@ def __init__(self, config: Qwen2MoeConfig, layer_idx: int):
         self.post_attention_layernorm = Qwen2MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

     def forward(
-        self,
-        hidden_states: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_value: Optional[Tuple[torch.Tensor]] = None,
-        output_attentions: Optional[bool] = False,
-        output_router_logits: Optional[bool] = False,
-        use_cache: Optional[bool] = False,
-        **kwargs,
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        output_attentions: Optional[bool] = False,
+        output_router_logits: Optional[bool] = False,
+        use_cache: Optional[bool] = False,
+        **kwargs,
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         if "padding_mask" in kwargs:
             warnings.warn(
@@ -1112,17 +1112,17 @@ def set_input_embeddings(self, value):

     @add_start_docstrings_to_model_forward(QWEN2MOE_INPUTS_DOCSTRING)
     def forward(
-        self,
-        input_ids: torch.LongTensor = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[List[torch.FloatTensor]] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
-        use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        output_router_logits: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        output_router_logits: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
     ) -> Union[Tuple, MoeModelOutputWithPast]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_router_logits = (
@@ -1309,18 +1309,18 @@ def get_decoder(self):
     @add_start_docstrings_to_model_forward(QWEN2MOE_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
     def forward(
-        self,
-        input_ids: torch.LongTensor = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[List[torch.FloatTensor]] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
-        labels: Optional[torch.LongTensor] = None,
-        use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        output_router_logits: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        output_router_logits: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
     ) -> Union[Tuple, MoeCausalLMOutputWithPast]:
         r"""
         Args:
@@ -1416,7 +1416,7 @@ def forward(
         )

     def prepare_inputs_for_generation(
-        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
+        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
     ):
         # Omit tokens covered by past_key_values
         if past_key_values is not None:
@@ -1433,7 +1433,7 @@ def prepare_inputs_for_generation(
             # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
             # input)
             if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
-                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length):]
+                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
             # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
             # input_ids based on the past_length.
             elif past_length < input_ids.shape[1]:
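The branch above trims the prompt when part of it is already covered by the KV cache: if the attention mask is longer than `input_ids`, only the trailing, not-yet-cached tokens are kept. A worked toy example (values are made up):

```python
import torch

past_length = 5                                            # tokens already in the KV cache
attention_mask = torch.ones(1, 8)                          # full history: 8 attended positions
input_ids = torch.tensor([[11, 12, 13, 14, 15, 16, 17]])   # caller passed 7 tokens

if attention_mask.shape[1] > input_ids.shape[1]:
    # keep only the last (8 - 5) = 3 tokens; the rest are already cached
    input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]

print(input_ids)  # tensor([[15, 16, 17]])
```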
@@ -1442,9 +1442,9 @@ def prepare_inputs_for_generation(

             # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
             if (
-                max_cache_length is not None
-                and attention_mask is not None
-                and cache_length + input_ids.shape[1] > max_cache_length
+                max_cache_length is not None
+                and attention_mask is not None
+                and cache_length + input_ids.shape[1] > max_cache_length
             ):
                 attention_mask = attention_mask[:, -max_cache_length:]

@@ -1454,7 +1454,7 @@ def prepare_inputs_for_generation(
             position_ids = attention_mask.long().cumsum(-1) - 1
             position_ids.masked_fill_(attention_mask == 0, 1)
             if past_key_values:
-                position_ids = position_ids[:, -input_ids.shape[1]:]
+                position_ids = position_ids[:, -input_ids.shape[1] :]

         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
         if inputs_embeds is not None and past_key_values is None:
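The `cumsum` trick above builds position ids on the fly for left-padded batches: padding positions get a dummy id while real tokens are numbered from 0. A small worked example with made-up values:

```python
import torch

attention_mask = torch.tensor([[0, 0, 1, 1, 1]])        # two left-padding tokens

position_ids = attention_mask.long().cumsum(-1) - 1     # tensor([[-1, -1, 0, 1, 2]])
position_ids.masked_fill_(attention_mask == 0, 1)       # padding positions get a dummy value of 1
print(position_ids)                                     # tensor([[1, 1, 0, 1, 2]])
```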
@@ -1515,17 +1515,17 @@ def set_input_embeddings(self, value):

     @add_start_docstrings_to_model_forward(QWEN2MOE_INPUTS_DOCSTRING)
     def forward(
-        self,
-        input_ids: torch.LongTensor = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[List[torch.FloatTensor]] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
-        labels: Optional[torch.LongTensor] = None,
-        use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
     ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):