Commit 9fd1a3e

[shardformer] update whisper model (hpcaitech#5529)
1 parent ad7d81d commit 9fd1a3e

2 files changed: +22 -27 lines


colossalai/shardformer/modeling/whisper.py
22 additions, 20 deletions
```diff
@@ -21,7 +21,7 @@
     shift_tokens_right,
 )
 from transformers.utils import logging
-
+from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer.layer import ColoAttention
 from colossalai.shardformer.shard import ShardConfig
```
```diff
@@ -539,18 +539,12 @@ def whisper_encoder_forward(
             layer_outputs = (None, None)
         else:
             if self.gradient_checkpointing and self.training:
-
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        return module(*inputs, output_attentions)
-
-                    return custom_forward
-
-                layer_outputs = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(encoder_layer),
+                layer_outputs = self._gradient_checkpointing_func(
+                    encoder_layer.__call__,
                     hidden_states,
                     None,
                     (head_mask[idx] if head_mask is not None else None),
+                    output_attentions,
                 )
             else:
                 layer_outputs = encoder_layer(
```
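Both this hunk and the matching decoder hunk further down swap the hand-rolled `create_custom_forward` closure around `torch.utils.checkpoint.checkpoint` for `self._gradient_checkpointing_func`, the hook that newer transformers releases (roughly v4.36 onward) install via `gradient_checkpointing_enable`. That hook forwards every argument positionally, so flags such as `output_attentions` and `use_cache` no longer have to be baked into a closure. A minimal sketch of the two call styles, using a toy layer and illustrative names that are not part of the commit:

```python
# Toy comparison of the old and new checkpointing call styles; `layer_forward`
# and its signature are illustrative stand-ins, not code from the commit.
import functools

import torch
import torch.utils.checkpoint

layer = torch.nn.Linear(8, 8)  # stand-in for an encoder layer

def layer_forward(hidden_states, attn_mask, output_attentions):
    del attn_mask, output_attentions  # ignored by this toy layer
    return (layer(hidden_states),)

x = torch.randn(2, 8, requires_grad=True)

# Old pattern (transformers <= 4.33): bake the non-tensor flag into a closure
# so torch.utils.checkpoint only sees the positional tensor inputs.
def create_custom_forward(fn, output_attentions):
    def custom_forward(*inputs):
        return fn(*inputs, output_attentions)
    return custom_forward

out_old = torch.utils.checkpoint.checkpoint(
    create_custom_forward(layer_forward, False), x, None, use_reentrant=False
)

# New pattern (transformers >= 4.36): the model stores a ready-made
# checkpointing callable and every argument, flags included, is positional.
_gradient_checkpointing_func = functools.partial(
    torch.utils.checkpoint.checkpoint, use_reentrant=False
)
out_new = _gradient_checkpointing_func(layer_forward, x, None, False)

print(out_old[0].shape, out_new[0].shape)  # torch.Size([2, 8]) twice
```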
```diff
@@ -701,6 +695,20 @@ def whisper_decoder_forward(
 
     if inputs_embeds is None:
         inputs_embeds = self.embed_tokens(input_ids)
+
+    if self._use_flash_attention_2:
+        # 2d mask is passed through the layers
+        attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
+    elif self._use_sdpa and head_mask is None and not output_attentions:
+        # output_attentions=True & head_mask can not be supported when using SDPA.
+        attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
+            attention_mask, input_shape, inputs_embeds, past_key_values_length
+        )
+    else:
+        # 4d mask is passed through the layers
+        attention_mask = _prepare_4d_causal_attention_mask(
+            attention_mask, input_shape, inputs_embeds, past_key_values_length
+        )
 
     # embed positions
     if input_ids is not None:
```
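This hunk aligns the sharded decoder with the attention-mask plumbing upstream Whisper uses: Flash Attention 2 keeps the raw 2d padding mask (and only when real padding is present), SDPA goes through a helper that can hand the work to the fused causal kernel, and eager attention gets a full 4d additive mask. A self-contained sketch of the three paths, with illustrative shapes (`input_shape` is `(batch, seq_len)`); it assumes transformers >= 4.36, where `transformers.modeling_attn_mask_utils` exposes both helpers:

```python
# Illustrative walk through the three mask paths; shapes and values are toy.
import torch
from transformers.modeling_attn_mask_utils import (
    _prepare_4d_causal_attention_mask,
    _prepare_4d_causal_attention_mask_for_sdpa,
)

batch, seq_len, hidden = 2, 5, 16
inputs_embeds = torch.randn(batch, seq_len, hidden)
attention_mask = torch.ones(batch, seq_len, dtype=torch.long)
attention_mask[1, -2:] = 0  # pad out the tail of the second sequence
past_key_values_length = 0

# Eager attention: expand the 2d padding mask into a 4d additive causal mask
# of shape (batch, 1, query_len, key_len).
mask_eager = _prepare_4d_causal_attention_mask(
    attention_mask, (batch, seq_len), inputs_embeds, past_key_values_length
)
print(mask_eager.shape)  # torch.Size([2, 1, 5, 5])

# SDPA: same contract, but the helper may return None so that
# scaled_dot_product_attention can take its fused is_causal path instead.
mask_sdpa = _prepare_4d_causal_attention_mask_for_sdpa(
    attention_mask, (batch, seq_len), inputs_embeds, past_key_values_length
)

# Flash Attention 2: keep the raw 2d mask, and only if padding exists,
# which is exactly what the `0 in attention_mask` test checks.
mask_fa2 = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
```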
```diff
@@ -756,23 +764,17 @@ def whisper_decoder_forward(
         past_key_value = past_key_values[idx] if past_key_values is not None else None
 
         if self.gradient_checkpointing and self.training:
-
-            def create_custom_forward(module):
-                def custom_forward(*inputs):
-                    # None for past_key_value
-                    return module(*inputs, output_attentions, use_cache)
-
-                return custom_forward
-
-            layer_outputs = torch.utils.checkpoint.checkpoint(
-                create_custom_forward(decoder_layer),
+            layer_outputs = self._gradient_checkpointing_func(
+                decoder_layer.__call__,
                 hidden_states,
                 attention_mask,
                 encoder_hidden_states,
                 None,  # encoder attention mask
                 head_mask[idx] if head_mask is not None else None,
                 (cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None),
                 None,  # past_key_value
+                output_attentions,
+                use_cache,
             )
         else:
             layer_outputs = decoder_layer(
```

colossalai/shardformer/policies/whisper.py
0 additions, 7 deletions
```diff
@@ -29,13 +29,6 @@
 class WhisperPolicy(Policy):
     def __init__(self) -> None:
         super().__init__()
-        import transformers
-        from packaging.version import Version
-
-        # TODO: remove this version check when transformers>=4.36.0
-        assert Version(transformers.__version__) <= Version(
-            "4.33.0"
-        ), "The Whisper model should run on a transformers version not greater than 4.33.0."
 
     def config_sanity_check(self):
         pass
```
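With the modeling code now tracking the newer transformers API, the policy can drop its hard upper bound on the transformers version, as the removed TODO anticipated. If a lower bound were wanted instead, a guard in the same style might look like the sketch below; the 4.36.0 floor is an illustrative assumption, not something this commit adds:

```python
# Hypothetical version floor in the style of the removed assert; the bound
# below is an assumption for illustration, not part of the commit.
import transformers
from packaging.version import Version

MIN_TRANSFORMERS = Version("4.36.0")  # illustrative lower bound

if Version(transformers.__version__) < MIN_TRANSFORMERS:
    raise ImportError(
        f"The updated Whisper policy targets transformers >= {MIN_TRANSFORMERS}, "
        f"but found {transformers.__version__}."
    )
```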
