Commit 4f4dfe5

fxmarty authored and ArthurZucker committed
Use torch.bool instead of torch.int64 for non-persistent causal mask buffer (#29241)
use torch.bool instead of torch.int64
1 parent 4b3af6d commit 4f4dfe5
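For context on the change: `torch.full` with a plain integer `fill_value` and no explicit `dtype` infers `torch.int64`, i.e. 8 bytes per mask element, whereas `torch.bool` needs only 1. A minimal sketch of the difference, assuming a hypothetical `max_position_embeddings` of 8192 (not a value taken from either model config):

import torch

max_position_embeddings = 8192  # hypothetical value, for illustration only

# Previous behaviour: an int fill_value makes torch.full infer torch.int64 (8 bytes per element).
int_mask = torch.full((max_position_embeddings, max_position_embeddings), fill_value=1)

# New behaviour: an explicit boolean buffer (1 byte per element).
bool_mask = torch.full(
    (max_position_embeddings, max_position_embeddings), fill_value=True, dtype=torch.bool
)

print(int_mask.dtype, int_mask.nelement() * int_mask.element_size() // 2**20, "MiB")    # torch.int64 512 MiB
print(bool_mask.dtype, bool_mask.nelement() * bool_mask.element_size() // 2**20, "MiB")  # torch.bool 64 MiB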

File tree

2 files changed: +13 / -5 lines


src/transformers/models/gemma/modeling_gemma.py

Lines changed: 5 additions & 2 deletions

@@ -811,8 +811,11 @@ def __init__(self, config: GemmaConfig):
         self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.gradient_checkpointing = False

-        # register a causal mask to separate causal and padding mask creation. Merging happends in the attention class
-        causal_mask = torch.full((config.max_position_embeddings, config.max_position_embeddings), fill_value=1)
+        # Register a causal mask to separate causal and padding mask creation. Merging happens in the attention class.
+        # NOTE: This is not friendly with TorchScript, ONNX, ExportedProgram serialization for very large `max_position_embeddings`.
+        causal_mask = torch.full(
+            (config.max_position_embeddings, config.max_position_embeddings), fill_value=True, dtype=torch.bool
+        )
         self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False)
         # Initialize weights and apply final processing
         self.post_init()
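Both models register the mask with `persistent=False`, so the boolean buffer moves with the module (e.g. on `.to(device)`) but is never written to checkpoints. A minimal standalone sketch of that behaviour, using a made-up `TinyModel` rather than the actual Transformers classes:

import torch
from torch import nn

class TinyModel(nn.Module):
    def __init__(self, max_positions: int = 16):  # toy size, for illustration only
        super().__init__()
        causal_mask = torch.full((max_positions, max_positions), fill_value=True, dtype=torch.bool)
        # persistent=False: the buffer follows device/dtype moves but is excluded from state_dict().
        self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False)

model = TinyModel()
print(model.causal_mask.dtype)              # torch.bool
print("causal_mask" in model.state_dict())  # False: the mask is not serialized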

src/transformers/models/llama/modeling_llama.py

Lines changed: 8 additions & 3 deletions

@@ -808,7 +808,9 @@ def _setup_cache(self, cache_cls, max_batch_size, max_cache_len: Optional[int] =
             )

         if max_cache_len > self.model.causal_mask.shape[-1] or self.device != self.model.causal_mask.device:
-            causal_mask = torch.full((max_cache_len, max_cache_len), fill_value=1, device=self.device)
+            causal_mask = torch.full(
+                (max_cache_len, max_cache_len), fill_value=True, device=self.device, dtype=torch.bool
+            )
             self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False)

         for layer in self.model.layers:

@@ -916,8 +918,11 @@ def __init__(self, config: LlamaConfig):
         self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.gradient_checkpointing = False

-        # register a causal mask to separate causal and padding mask creation. Merging happends in the attention class
-        causal_mask = torch.full((config.max_position_embeddings, config.max_position_embeddings), fill_value=1)
+        # Register a causal mask to separate causal and padding mask creation. Merging happens in the attention class.
+        # NOTE: This is not friendly with TorchScript, ONNX, ExportedProgram serialization for very large `max_position_embeddings`.
+        causal_mask = torch.full(
+            (config.max_position_embeddings, config.max_position_embeddings), fill_value=True, dtype=torch.bool
+        )
         self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False)
         # Initialize weights and apply final processing
         self.post_init()
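Downstream, a boolean mask of this kind can be applied to attention scores with `masked_fill`, where `True` marks the future positions to block. A rough sketch of that pattern, with illustrative shapes rather than the actual mask-merging code in the attention classes:

import torch

seq_len = 6  # illustrative sequence length
causal_mask = torch.triu(
    torch.full((seq_len, seq_len), fill_value=True, dtype=torch.bool), diagonal=1
)

scores = torch.randn(1, 1, seq_len, seq_len)  # (batch, heads, query, key) attention logits
scores = scores.masked_fill(causal_mask, torch.finfo(scores.dtype).min)
probs = scores.softmax(dim=-1)  # future positions receive ~zero attention weight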
