Commit 24d59c7

Use torch.bool instead of torch.int64 for non-persistent causal mask buffer (#29241)
use torch.bool instead of torch.int64
1 parent 7c4995f commit 24d59c7
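
Why this helps: torch.full with an integer fill value and no explicit dtype infers torch.int64 (8 bytes per element), while a causal mask only needs to be boolean (1 byte per element). A minimal sketch of the size difference, using an illustrative context length that is not taken from any particular config:

import torch

n = 4096  # illustrative max_position_embeddings, not from the commit

# Before: integer fill value, no dtype -> torch.int64 (8 bytes per element).
mask_int64 = torch.triu(torch.full((n, n), fill_value=1), diagonal=1)
# After: explicit torch.bool (1 byte per element).
mask_bool = torch.triu(torch.full((n, n), fill_value=True, dtype=torch.bool), diagonal=1)

print(mask_int64.dtype, mask_int64.nelement() * mask_int64.element_size() // 2**20, "MiB")  # torch.int64 128 MiB
print(mask_bool.dtype, mask_bool.nelement() * mask_bool.element_size() // 2**20, "MiB")     # torch.bool 16 MiB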

File tree

2 files changed: 13 additions & 5 deletions


src/transformers/models/gemma/modeling_gemma.py

Lines changed: 5 additions & 2 deletions

@@ -810,8 +810,11 @@ def __init__(self, config: GemmaConfig):
         self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.gradient_checkpointing = False

-        # register a causal mask to separate causal and padding mask creation. Merging happends in the attention class
-        causal_mask = torch.full((config.max_position_embeddings, config.max_position_embeddings), fill_value=1)
+        # Register a causal mask to separate causal and padding mask creation. Merging happens in the attention class.
+        # NOTE: This is not friendly with TorchScript, ONNX, ExportedProgram serialization for very large `max_position_embeddings`.
+        causal_mask = torch.full(
+            (config.max_position_embeddings, config.max_position_embeddings), fill_value=True, dtype=torch.bool
+        )
         self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False)
         # Initialize weights and apply final processing
         self.post_init()
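
For context, here is a rough sketch, not the attention class's actual merging code, of how a boolean upper-triangular buffer like this one can be sliced and combined with a padding mask into the additive float mask that attention scores expect; the merge_masks helper and all shapes below are illustrative assumptions:

import torch

def merge_masks(causal_buf, padding_mask, dtype):
    # causal_buf: (max_pos, max_pos) bool buffer, True above the diagonal (future positions).
    # padding_mask: (batch, seq_len) with 1 for real tokens and 0 for padding.
    bsz, seq_len = padding_mask.shape
    causal = causal_buf[:seq_len, :seq_len]                          # slice to the current length
    blocked = causal[None, :, :] | (padding_mask[:, None, :] == 0)   # forbid future and padded keys
    mask = torch.zeros(bsz, seq_len, seq_len, dtype=dtype)
    return mask.masked_fill(blocked, torch.finfo(dtype).min)         # additive mask for attention scores

causal_buf = torch.triu(torch.full((16, 16), True, dtype=torch.bool), diagonal=1)
padding = torch.tensor([[1, 1, 1, 0]])                               # one sequence, last token is padding
print(merge_masks(causal_buf, padding, torch.float32).shape)         # torch.Size([1, 4, 4])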

src/transformers/models/llama/modeling_llama.py

Lines changed: 8 additions & 3 deletions

@@ -811,7 +811,9 @@ def _setup_cache(self, cache_cls, max_batch_size, max_cache_len: Optional[int] =
             )

         if max_cache_len > self.model.causal_mask.shape[-1] or self.device != self.model.causal_mask.device:
-            causal_mask = torch.full((max_cache_len, max_cache_len), fill_value=1, device=self.device)
+            causal_mask = torch.full(
+                (max_cache_len, max_cache_len), fill_value=True, device=self.device, dtype=torch.bool
+            )
             self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False)

         for layer in self.model.layers:
@@ -919,8 +921,11 @@ def __init__(self, config: LlamaConfig):
         self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.gradient_checkpointing = False

-        # register a causal mask to separate causal and padding mask creation. Merging happends in the attention class
-        causal_mask = torch.full((config.max_position_embeddings, config.max_position_embeddings), fill_value=1)
+        # Register a causal mask to separate causal and padding mask creation. Merging happens in the attention class.
+        # NOTE: This is not friendly with TorchScript, ONNX, ExportedProgram serialization for very large `max_position_embeddings`.
+        causal_mask = torch.full(
+            (config.max_position_embeddings, config.max_position_embeddings), fill_value=True, dtype=torch.bool
+        )
         self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False)
         # Initialize weights and apply final processing
         self.post_init()
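
Both registrations keep persistent=False, so the large mask moves with the module (e.g. on .to(device)) but never enters the state dict, and checkpoint size is unaffected by the dtype choice. A small self-contained sketch of that behavior with a toy module (the class name and size are made up for illustration):

import torch
from torch import nn

class TinyModel(nn.Module):
    def __init__(self, max_positions: int = 32):
        super().__init__()
        causal_mask = torch.full((max_positions, max_positions), fill_value=True, dtype=torch.bool)
        # persistent=False keeps the buffer out of state_dict() while still moving it with the module.
        self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False)

m = TinyModel()
print(m.causal_mask.dtype)               # torch.bool
print("causal_mask" in m.state_dict())   # False: the mask is rebuilt in __init__, not loaded from disk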
