
Commit 7768afb

Update flash_attention_patch.py
To be compatible with a recent change in the Transformers library, where a new argument 'padding_mask' was added to the forward function of the attention layer (huggingface/transformers#25598).
1 parent 611a5a8 commit 7768afb

File tree

1 file changed: +1 −0 lines changed

applications/Colossal-LLaMA-2/colossal_llama2/utils/flash_attention_patch.py

Lines changed: 1 addition & 0 deletions
@@ -65,6 +65,7 @@ def attention_forward(
     past_key_value: Optional[Tuple[torch.Tensor]] = None,
     output_attentions: bool = False,
     use_cache: bool = False,
+    **kwargs
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
     """
     Re-define LLaMA-2 `LlamaAttention` forward method using flash-attention.
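
For context, here is a minimal, self-contained sketch (hypothetical function names, not the actual Colossal-LLaMA-2 patch code) of why accepting **kwargs keeps a monkey-patched forward compatible once the surrounding library starts passing a new keyword such as padding_mask:

# Sketch only: illustrates the **kwargs compatibility trick, assuming a caller
# that forwards a new `padding_mask` keyword (as Transformers does after
# huggingface/transformers#25598).

def old_patched_forward(hidden_states, output_attentions=False, use_cache=False):
    # No **kwargs: an unexpected `padding_mask` keyword raises TypeError.
    return hidden_states

def new_patched_forward(hidden_states, output_attentions=False, use_cache=False, **kwargs):
    # **kwargs absorbs keywords added by newer library versions (e.g. `padding_mask`)
    # without otherwise changing the patched function's behaviour.
    return hidden_states

def decoder_layer_call(attn_forward):
    # Stands in for the decoder layer that now forwards `padding_mask` to attention.
    return attn_forward("hidden", padding_mask=None)

if __name__ == "__main__":
    try:
        decoder_layer_call(old_patched_forward)
    except TypeError as exc:
        print("old patch breaks:", exc)   # unexpected keyword argument 'padding_mask'
    print("new patch works:", decoder_layer_call(new_patched_forward))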
