
Commit a50a208

Author: bozheng-hit
Commit message: fix style
Parent: 8d74bb0

File tree: 1 file changed (+98 −98 lines changed)


src/transformers/models/qwen2_moe/modeling_qwen2_moe.py

Lines changed: 98 additions & 98 deletions
@@ -44,6 +44,7 @@
 )
 from .configuration_qwen2_moe import Qwen2MoeConfig
 
+
 if is_flash_attn_2_available():
     from flash_attn import flash_attn_func, flash_attn_varlen_func
     from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
@@ -63,8 +64,7 @@
 
 # Copied from transformers.models.mixtral.modeling_mixtral.load_balancing_loss_func
 def load_balancing_loss_func(
-    gate_logits: torch.Tensor, num_experts: torch.Tensor = None, top_k=2,
-    attention_mask: Optional[torch.Tensor] = None
+    gate_logits: torch.Tensor, num_experts: torch.Tensor = None, top_k=2, attention_mask: Optional[torch.Tensor] = None
 ) -> float:
     r"""
     Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.
@@ -112,9 +112,9 @@ def load_balancing_loss_func(
         # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask
         expert_attention_mask = (
             attention_mask[None, :, :, None, None]
-                .expand((num_hidden_layers, batch_size, sequence_length, 2, num_experts))
-                .reshape(-1, 2, num_experts)
-                .to(compute_device)
+            .expand((num_hidden_layers, batch_size, sequence_length, 2, num_experts))
+            .reshape(-1, 2, num_experts)
+            .to(compute_device)
         )
 
         # Compute the percentage of tokens routed to each experts
@@ -125,9 +125,9 @@ def load_balancing_loss_func(
         # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert
         router_per_expert_attention_mask = (
             attention_mask[None, :, :, None]
-                .expand((num_hidden_layers, batch_size, sequence_length, num_experts))
-                .reshape(-1, num_experts)
-                .to(compute_device)
+            .expand((num_hidden_layers, batch_size, sequence_length, num_experts))
+            .reshape(-1, num_experts)
+            .to(compute_device)
         )
 
         # Compute the average probability of routing to these experts
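
For context on the two hunks above: `load_balancing_loss_func` computes the Switch Transformer auxiliary loss that these masks feed into. A minimal sketch of the unmasked computation, with illustrative names and the padding-mask handling omitted (not the exact library code):

import torch
import torch.nn.functional as F

def aux_load_balancing_loss(gate_logits: torch.Tensor, num_experts: int, top_k: int = 2) -> torch.Tensor:
    # gate_logits: [num_tokens, num_experts] router logits for a single layer
    routing_weights = F.softmax(gate_logits, dim=-1)
    _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
    expert_mask = F.one_hot(selected_experts, num_experts)         # [num_tokens, top_k, num_experts]
    # Fraction of tokens dispatched to each expert (per top-k slot) ...
    tokens_per_expert = torch.mean(expert_mask.float(), dim=0)     # [top_k, num_experts]
    # ... and the mean router probability assigned to each expert.
    router_prob_per_expert = torch.mean(routing_weights, dim=0)    # [num_experts]
    return num_experts * torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))

loss = aux_load_balancing_loss(torch.randn(128, 4), num_experts=4)  # toy example: 128 tokens, 4 experts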
@@ -211,7 +211,7 @@ def forward(self, x, seq_len=None):
 def rotate_half(x):
     """Rotates half the hidden dims of the input."""
     x1 = x[..., : x.shape[-1] // 2]
-    x2 = x[..., x.shape[-1] // 2:]
+    x2 = x[..., x.shape[-1] // 2 :]
     return torch.cat((-x2, x1), dim=-1)
 
 
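The slice-spacing change above (`// 2:` to `// 2 :`) matches the formatter's convention for slices whose bounds are expressions; behavior is unchanged. A small self-contained sketch of what `rotate_half` does in the rotary-embedding pipeline (the cos/sin tables below are placeholders, not real frequency tables):

import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Split the head dimension in half and swap the halves with a sign flip.
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

# RoPE combines the original and rotated tensors with per-position cos/sin tables,
# roughly: q_embed = (q * cos) + (rotate_half(q) * sin)
q = torch.randn(1, 8, 16, 64)   # [batch, heads, seq_len, head_dim]
cos = torch.ones(16, 64)        # placeholder table, broadcast over batch/heads
sin = torch.zeros(16, 64)
q_embed = (q * cos) + (rotate_half(q) * sin)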
@@ -318,14 +318,14 @@ def __init__(self, config: Qwen2MoeConfig, layer_idx: Optional[int] = None):
         )
 
     def forward(
-            self,
-            hidden_states: torch.Tensor,
-            attention_mask: Optional[torch.Tensor] = None,
-            position_ids: Optional[torch.LongTensor] = None,
-            past_key_value: Optional[Cache] = None,
-            output_attentions: bool = False,
-            use_cache: bool = False,
-            **kwargs,
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Cache] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+        **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         if "padding_mask" in kwargs:
             warnings.warn(
@@ -419,14 +419,14 @@ def __init__(self, *args, **kwargs):
         self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
 
     def forward(
-            self,
-            hidden_states: torch.Tensor,
-            attention_mask: Optional[torch.Tensor] = None,
-            position_ids: Optional[torch.LongTensor] = None,
-            past_key_value: Optional[Cache] = None,
-            output_attentions: bool = False,
-            use_cache: bool = False,
-            **kwargs,
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Cache] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+        **kwargs,
     ):
         if "padding_mask" in kwargs:
             warnings.warn(
@@ -462,10 +462,10 @@ def forward(
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
 
         use_sliding_windows = (
-                _flash_supports_window_size
-                and getattr(self.config, "sliding_window", None) is not None
-                and kv_seq_len > self.config.sliding_window
-                and self.config.use_sliding_window
+            _flash_supports_window_size
+            and getattr(self.config, "sliding_window", None) is not None
+            and kv_seq_len > self.config.sliding_window
+            and self.config.use_sliding_window
         )
 
         if not _flash_supports_window_size:
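
The reindented condition above decides whether flash attention runs with a sliding window. A hedged restatement as a standalone predicate (names are illustrative, not the module's own helpers):

def should_use_sliding_window(
    flash_supports_window_size: bool,
    sliding_window,            # int or None, read from the model config
    kv_seq_len: int,
    use_sliding_window: bool,  # config flag enabling the feature
) -> bool:
    # Window attention is only used when the kernel supports it, a window is
    # configured, the cached sequence is longer than the window, and the
    # feature is switched on.
    return (
        flash_supports_window_size
        and sliding_window is not None
        and kv_seq_len > sliding_window
        and use_sliding_window
    )

assert should_use_sliding_window(True, 4096, 5000, True) is True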
@@ -478,9 +478,9 @@ def forward(
             # Activate slicing cache only if the config has a value `sliding_windows` attribute
             cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
             if (
-                    getattr(self.config, "sliding_window", None) is not None
-                    and kv_seq_len > self.config.sliding_window
-                    and cache_has_contents
+                getattr(self.config, "sliding_window", None) is not None
+                and kv_seq_len > self.config.sliding_window
+                and cache_has_contents
             ):
                 slicing_tokens = 1 - self.config.sliding_window
 
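`slicing_tokens = 1 - self.config.sliding_window` in the context above sets up the cache trim that this condition guards. A toy illustration of that negative-index slice, assuming a cache layout of `[batch, heads, kv_len, head_dim]`:

import torch

sliding_window = 4
past_key = torch.randn(1, 2, 10, 8)     # cached keys: [batch, heads, kv_len, head_dim]
slicing_tokens = 1 - sliding_window     # -3: keep only the most recent window-1 positions
past_key = past_key[:, :, slicing_tokens:, :]
assert past_key.shape[2] == sliding_window - 1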
@@ -555,15 +555,15 @@ def forward(
         return attn_output, attn_weights, past_key_value
 
     def _flash_attention_forward(
-            self,
-            query_states,
-            key_states,
-            value_states,
-            attention_mask,
-            query_length,
-            dropout=0.0,
-            softmax_scale=None,
-            use_sliding_windows=False,
+        self,
+        query_states,
+        key_states,
+        value_states,
+        attention_mask,
+        query_length,
+        dropout=0.0,
+        softmax_scale=None,
+        use_sliding_windows=False,
     ):
         """
         Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
@@ -666,7 +666,7 @@ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query
         # by slicing it on the proper place
         if kv_seq_len != attention_mask.shape[-1]:
             attention_mask_num_tokens = attention_mask.shape[-1]
-            attention_mask = attention_mask[:, attention_mask_num_tokens - kv_seq_len:]
+            attention_mask = attention_mask[:, attention_mask_num_tokens - kv_seq_len :]
 
         indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
 
@@ -712,13 +712,13 @@ class Qwen2MoeSdpaAttention(Qwen2MoeAttention):
 
     # Adapted from Qwen2MoeAttention.forward
    def forward(
-            self,
-            hidden_states: torch.Tensor,
-            attention_mask: Optional[torch.Tensor] = None,
-            position_ids: Optional[torch.LongTensor] = None,
-            past_key_value: Optional[Cache] = None,
-            output_attentions: bool = False,
-            use_cache: bool = False,
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Cache] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         if output_attentions:
             # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
@@ -887,15 +887,15 @@ def __init__(self, config: Qwen2MoeConfig, layer_idx: int):
         self.post_attention_layernorm = Qwen2MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
     def forward(
-            self,
-            hidden_states: torch.Tensor,
-            attention_mask: Optional[torch.Tensor] = None,
-            position_ids: Optional[torch.LongTensor] = None,
-            past_key_value: Optional[Tuple[torch.Tensor]] = None,
-            output_attentions: Optional[bool] = False,
-            output_router_logits: Optional[bool] = False,
-            use_cache: Optional[bool] = False,
-            **kwargs,
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        output_attentions: Optional[bool] = False,
+        output_router_logits: Optional[bool] = False,
+        use_cache: Optional[bool] = False,
+        **kwargs,
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         if "padding_mask" in kwargs:
             warnings.warn(
@@ -1112,17 +1112,17 @@ def set_input_embeddings(self, value):
 
     @add_start_docstrings_to_model_forward(QWEN2MOE_INPUTS_DOCSTRING)
     def forward(
-            self,
-            input_ids: torch.LongTensor = None,
-            attention_mask: Optional[torch.Tensor] = None,
-            position_ids: Optional[torch.LongTensor] = None,
-            past_key_values: Optional[List[torch.FloatTensor]] = None,
-            inputs_embeds: Optional[torch.FloatTensor] = None,
-            use_cache: Optional[bool] = None,
-            output_attentions: Optional[bool] = None,
-            output_hidden_states: Optional[bool] = None,
-            output_router_logits: Optional[bool] = None,
-            return_dict: Optional[bool] = None,
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        output_router_logits: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
     ) -> Union[Tuple, MoeModelOutputWithPast]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_router_logits = (
@@ -1309,18 +1309,18 @@ def get_decoder(self):
     @add_start_docstrings_to_model_forward(QWEN2MOE_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
     def forward(
-            self,
-            input_ids: torch.LongTensor = None,
-            attention_mask: Optional[torch.Tensor] = None,
-            position_ids: Optional[torch.LongTensor] = None,
-            past_key_values: Optional[List[torch.FloatTensor]] = None,
-            inputs_embeds: Optional[torch.FloatTensor] = None,
-            labels: Optional[torch.LongTensor] = None,
-            use_cache: Optional[bool] = None,
-            output_attentions: Optional[bool] = None,
-            output_hidden_states: Optional[bool] = None,
-            output_router_logits: Optional[bool] = None,
-            return_dict: Optional[bool] = None,
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        output_router_logits: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
     ) -> Union[Tuple, MoeCausalLMOutputWithPast]:
         r"""
         Args:
@@ -1416,7 +1416,7 @@ def forward(
         )
 
     def prepare_inputs_for_generation(
-            self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
+        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
     ):
         # Omit tokens covered by past_key_values
         if past_key_values is not None:
@@ -1433,7 +1433,7 @@ def prepare_inputs_for_generation(
             # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
             # input)
            if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
-                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length):]
+                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
             # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
             # input_ids based on the past_length.
             elif past_length < input_ids.shape[1]:
@@ -1442,9 +1442,9 @@ def prepare_inputs_for_generation(
 
             # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
             if (
-                    max_cache_length is not None
-                    and attention_mask is not None
-                    and cache_length + input_ids.shape[1] > max_cache_length
+                max_cache_length is not None
+                and attention_mask is not None
+                and cache_length + input_ids.shape[1] > max_cache_length
             ):
                 attention_mask = attention_mask[:, -max_cache_length:]
 
@@ -1454,7 +1454,7 @@ def prepare_inputs_for_generation(
             position_ids = attention_mask.long().cumsum(-1) - 1
             position_ids.masked_fill_(attention_mask == 0, 1)
             if past_key_values:
-                position_ids = position_ids[:, -input_ids.shape[1]:]
+                position_ids = position_ids[:, -input_ids.shape[1] :]
 
         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
         if inputs_embeds is not None and past_key_values is None:
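
The reformatted slice above crops `position_ids` to the freshly generated tokens. A small worked example of the surrounding logic from this hunk (values are illustrative):

import torch

attention_mask = torch.tensor([[0, 0, 1, 1, 1]])      # left-padded prompt
position_ids = attention_mask.long().cumsum(-1) - 1   # running position per real token
position_ids.masked_fill_(attention_mask == 0, 1)     # dummy value on padding slots
# position_ids is now tensor([[1, 1, 0, 1, 2]])
num_new_tokens = 1                                    # one token per step once a cache exists
position_ids = position_ids[:, -num_new_tokens:]      # tensor([[2]])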
@@ -1515,17 +1515,17 @@ def set_input_embeddings(self, value):
 
     @add_start_docstrings_to_model_forward(QWEN2MOE_INPUTS_DOCSTRING)
     def forward(
-            self,
-            input_ids: torch.LongTensor = None,
-            attention_mask: Optional[torch.Tensor] = None,
-            position_ids: Optional[torch.LongTensor] = None,
-            past_key_values: Optional[List[torch.FloatTensor]] = None,
-            inputs_embeds: Optional[torch.FloatTensor] = None,
-            labels: Optional[torch.LongTensor] = None,
-            use_cache: Optional[bool] = None,
-            output_attentions: Optional[bool] = None,
-            output_hidden_states: Optional[bool] = None,
-            return_dict: Optional[bool] = None,
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
     ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
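
Taken together, the hunks in this commit are formatting-only fixes (indentation, slice spacing, a joined signature, one added blank line); no forward-pass behavior changes. A hedged end-to-end sketch of exercising the edited module through the high-level transformers API (the checkpoint name is an assumption, not taken from this diff):

from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "Qwen/Qwen1.5-MoE-A2.7B"  # assumed Qwen2MoE checkpoint name
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

inputs = tokenizer("Hello", return_tensors="pt")
# output_router_logits=True surfaces the gate logits consumed by load_balancing_loss_func
outputs = model(**inputs, output_router_logits=True)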
