21 | 21 |     shift_tokens_right,
22 | 22 | )
23 | 23 | from transformers.utils import logging
24 |    | -
   | 24 | +from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa
25 | 25 | from colossalai.pipeline.stage_manager import PipelineStageManager
26 | 26 | from colossalai.shardformer.layer import ColoAttention
27 | 27 | from colossalai.shardformer.shard import ShardConfig
@@ -539,18 +539,12 @@ def whisper_encoder_forward(
539 | 539 |                 layer_outputs = (None, None)
540 | 540 |             else:
541 | 541 |                 if self.gradient_checkpointing and self.training:
542 |     | -
543 |     | -                    def create_custom_forward(module):
544 |     | -                        def custom_forward(*inputs):
545 |     | -                            return module(*inputs, output_attentions)
546 |     | -
547 |     | -                        return custom_forward
548 |     | -
549 |     | -                    layer_outputs = torch.utils.checkpoint.checkpoint(
550 |     | -                        create_custom_forward(encoder_layer),
    | 542 | +                    layer_outputs = self._gradient_checkpointing_func(
    | 543 | +                        encoder_layer.__call__,
551 | 544 |                         hidden_states,
552 | 545 |                         None,
553 | 546 |                         (head_mask[idx] if head_mask is not None else None),
    | 547 | +                        output_attentions,
554 | 548 |                     )
555 | 549 |                 else:
556 | 550 |                     layer_outputs = encoder_layer(
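The encoder change drops the `create_custom_forward` closure in favour of the `_gradient_checkpointing_func` attribute that newer `transformers` releases attach to the module, passing `encoder_layer.__call__` and the `output_attentions` flag positionally. Below is a minimal sketch of the two patterns side by side; `ToyLayer`, the tensor shapes, and the `use_reentrant=False` setting are invented for illustration and are not part of the diff.

```python
# Minimal, self-contained comparison of the old and new checkpointing styles.
from functools import partial

import torch
import torch.nn as nn
import torch.utils.checkpoint


class ToyLayer(nn.Module):
    # Hypothetical stand-in for a WhisperEncoderLayer-like module.
    def __init__(self, dim: int = 16):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, hidden_states, attention_mask, layer_head_mask, output_attentions=False):
        out = torch.relu(self.proj(hidden_states))
        return (out, None) if output_attentions else (out,)


layer = ToyLayer()
x = torch.randn(2, 8, 16, requires_grad=True)

# Old pattern: wrap the layer in a closure so that extra flags such as
# output_attentions are captured instead of being handed to checkpoint().
def create_custom_forward(module):
    def custom_forward(*inputs):
        return module(*inputs, output_attentions=False)

    return custom_forward

out_old = torch.utils.checkpoint.checkpoint(
    create_custom_forward(layer), x, None, None, use_reentrant=False
)

# New pattern: a checkpointing callable (what recent transformers versions store
# as self._gradient_checkpointing_func) receives the layer's __call__ plus every
# argument, including output_attentions, positionally.
gradient_checkpointing_func = partial(torch.utils.checkpoint.checkpoint, use_reentrant=False)
out_new = gradient_checkpointing_func(layer.__call__, x, None, None, False)

assert torch.allclose(out_old[0], out_new[0])
```

Passing `__call__` rather than `forward` keeps module hooks in the recomputation path, which is why the new branch forwards the bound `__call__` instead of the bare module.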
@@ -701,6 +695,20 @@ def whisper_decoder_forward(
701 | 695 | 
702 | 696 |         if inputs_embeds is None:
703 | 697 |             inputs_embeds = self.embed_tokens(input_ids)
    | 698 | +
    | 699 | +        if self._use_flash_attention_2:
    | 700 | +            # 2d mask is passed through the layers
    | 701 | +            attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
    | 702 | +        elif self._use_sdpa and head_mask is None and not output_attentions:
    | 703 | +            # output_attentions=True & head_mask can not be supported when using SDPA.
    | 704 | +            attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
    | 705 | +                attention_mask, input_shape, inputs_embeds, past_key_values_length
    | 706 | +            )
    | 707 | +        else:
    | 708 | +            # 4d mask is passed through the layers
    | 709 | +            attention_mask = _prepare_4d_causal_attention_mask(
    | 710 | +                attention_mask, input_shape, inputs_embeds, past_key_values_length
    | 711 | +            )
704 | 712 | 
705 | 713 |         # embed positions
706 | 714 |         if input_ids is not None:
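This new branch mirrors the mask handling in recent `transformers` decoders: the FlashAttention-2 path keeps the raw 2d padding mask (or drops it entirely when nothing is masked), while the SDPA and eager paths expand it into a 4d additive causal mask using the helpers imported above. A small illustration of those helpers follows; the batch size, sequence length, and padding pattern are made up, and a `transformers` version that ships `modeling_attn_mask_utils` is assumed.

```python
import torch
from transformers.modeling_attn_mask_utils import (
    _prepare_4d_causal_attention_mask,
    _prepare_4d_causal_attention_mask_for_sdpa,
)

batch, seq_len, hidden = 2, 5, 8
inputs_embeds = torch.zeros(batch, seq_len, hidden)
# 2d padding mask: the last two positions of the second sequence are padding.
attention_mask_2d = torch.tensor([[1, 1, 1, 1, 1],
                                  [1, 1, 1, 0, 0]])

# Eager path: a (batch, 1, tgt_len, src_len) additive mask with dtype-min values
# on padded and future positions and zeros elsewhere.
mask_eager = _prepare_4d_causal_attention_mask(
    attention_mask_2d, (batch, seq_len), inputs_embeds, past_key_values_length=0
)
print(mask_eager.shape)  # torch.Size([2, 1, 5, 5])

# SDPA path: same semantics, but the helper may return None when the mask is
# purely causal, letting scaled_dot_product_attention use its fused is_causal path.
mask_sdpa = _prepare_4d_causal_attention_mask_for_sdpa(
    attention_mask_2d, (batch, seq_len), inputs_embeds, past_key_values_length=0
)
```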
@@ -756,23 +764,17 @@ def whisper_decoder_forward(
756 | 764 |             past_key_value = past_key_values[idx] if past_key_values is not None else None
757 | 765 | 
758 | 766 |             if self.gradient_checkpointing and self.training:
759 |     | -
760 |     | -                def create_custom_forward(module):
761 |     | -                    def custom_forward(*inputs):
762 |     | -                        # None for past_key_value
763 |     | -                        return module(*inputs, output_attentions, use_cache)
764 |     | -
765 |     | -                    return custom_forward
766 |     | -
767 |     | -                layer_outputs = torch.utils.checkpoint.checkpoint(
768 |     | -                    create_custom_forward(decoder_layer),
    | 767 | +                layer_outputs = self._gradient_checkpointing_func(
    | 768 | +                    decoder_layer.__call__,
769 | 769 |                     hidden_states,
770 | 770 |                     attention_mask,
771 | 771 |                     encoder_hidden_states,
772 | 772 |                     None,  # encoder attention mask
773 | 773 |                     head_mask[idx] if head_mask is not None else None,
774 | 774 |                     (cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None),
775 | 775 |                     None,  # past_key_value
    | 776 | +                    output_attentions,
    | 777 | +                    use_cache,
776 | 778 |                 )
777 | 779 |             else:
778 | 780 |                 layer_outputs = decoder_layer(
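Both the encoder and decoder branches now assume that `self._gradient_checkpointing_func` is already set on the module. In recent `transformers` releases that attribute is installed by `gradient_checkpointing_enable()`; the sketch below shows that wiring, where the checkpoint name and the `use_reentrant` choice are only examples.

```python
from transformers import WhisperForConditionalGeneration

model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

# Sets gradient_checkpointing = True on the encoder and decoder submodules and
# attaches _gradient_checkpointing_func (typically a partial of
# torch.utils.checkpoint.checkpoint) that the patched forwards call.
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})

print(model.model.encoder.gradient_checkpointing)  # True
```

The patched shardformer forwards above simply call that attribute, so no per-layer closure is needed anymore.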