@@ -808,7 +808,9 @@ def _setup_cache(self, cache_cls, max_batch_size, max_cache_len: Optional[int] =
            )
 
        if max_cache_len > self.model.causal_mask.shape[-1] or self.device != self.model.causal_mask.device:
-            causal_mask = torch.full((max_cache_len, max_cache_len), fill_value=1, device=self.device)
+            causal_mask = torch.full(
+                (max_cache_len, max_cache_len), fill_value=True, device=self.device, dtype=torch.bool
+            )
            self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False)
 
        for layer in self.model.layers:
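Note (not part of the diff): a minimal standalone sketch of what the updated buffer holds after this hunk. The `max_cache_len` value below is an illustrative placeholder, not one taken from the PR.

import torch

max_cache_len = 8  # illustrative placeholder only
causal_mask = torch.full(
    (max_cache_len, max_cache_len), fill_value=True, dtype=torch.bool
)
causal_mask = torch.triu(causal_mask, diagonal=1)

# True marks a future position to ignore: entry [i, j] is True exactly when j > i.
assert causal_mask.dtype == torch.bool
assert causal_mask[2, 3] and not causal_mask[3, 2]

Storing the mask as torch.bool keeps the non-persistent buffer much smaller than the int64 tensor that `fill_value=1` produces by default, and the result can be consumed directly by boolean-mask operations.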
@@ -916,8 +918,11 @@ def __init__(self, config: LlamaConfig):
        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.gradient_checkpointing = False
 
-        # register a causal mask to separate causal and padding mask creation. Merging happends in the attention class
-        causal_mask = torch.full((config.max_position_embeddings, config.max_position_embeddings), fill_value=1)
+        # Register a causal mask to separate causal and padding mask creation. Merging happens in the attention class.
+        # NOTE: This is not friendly with TorchScript, ONNX, ExportedProgram serialization for very large `max_position_embeddings`.
+        causal_mask = torch.full(
+            (config.max_position_embeddings, config.max_position_embeddings), fill_value=True, dtype=torch.bool
+        )
        self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False)
        # Initialize weights and apply final processing
        self.post_init()
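Note (not part of the diff): a hedged sketch of how a boolean causal buffer like this is typically merged with a per-batch padding mask before attention. The tensor names, shapes, and the masked_fill-based merge below are illustrative assumptions, not the attention class's actual code.

import torch

batch, seq_len = 2, 5  # illustrative sizes only
causal_mask = torch.triu(
    torch.full((seq_len, seq_len), fill_value=True, dtype=torch.bool), diagonal=1
)
# Hypothetical padding pattern: 1 = real token, 0 = padding.
attention_mask = torch.tensor([[1, 1, 1, 1, 1], [1, 1, 1, 0, 0]])

padding_mask = attention_mask[:, None, None, :].eq(0)    # (batch, 1, 1, seq_len)
combined = causal_mask[None, None, :, :] | padding_mask  # (batch, 1, seq_len, seq_len)

# Turn the boolean mask into an additive bias on the attention scores.
min_value = torch.finfo(torch.float32).min
attn_bias = torch.zeros(batch, 1, seq_len, seq_len).masked_fill(combined, min_value)

Because the registered buffer is already boolean, it can be OR-ed with a padding mask directly, with no comparison against a sentinel fill value first.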