@@ -811,7 +811,9 @@ def _setup_cache(self, cache_cls, max_batch_size, max_cache_len: Optional[int] =
             )
 
         if max_cache_len > self.model.causal_mask.shape[-1] or self.device != self.model.causal_mask.device:
-            causal_mask = torch.full((max_cache_len, max_cache_len), fill_value=1, device=self.device)
+            causal_mask = torch.full(
+                (max_cache_len, max_cache_len), fill_value=True, device=self.device, dtype=torch.bool
+            )
             self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False)
 
         for layer in self.model.layers:
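
For context on why the dtype matters in this hunk: in recent PyTorch versions, `torch.full` with an integer `fill_value` and no explicit `dtype` allocates an int64 tensor (8 bytes per element), whereas `fill_value=True` with `dtype=torch.bool` allocates 1 byte per element. A minimal standalone sketch of the difference (the 4096 length is an illustrative value, not taken from the diff):

import torch

n = 4096  # hypothetical cache length, for illustration only
int_mask = torch.full((n, n), fill_value=1)                        # old behaviour: default dtype is int64
bool_mask = torch.full((n, n), fill_value=True, dtype=torch.bool)  # new behaviour: 1 byte per element
print(int_mask.element_size(), bool_mask.element_size())           # 8 vs 1
print(int_mask.element_size() * int_mask.nelement())               # 134217728 bytes (~128 MiB)
print(bool_mask.element_size() * bool_mask.nelement())             # 16777216 bytes (~16 MiB)
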
@@ -919,8 +921,11 @@ def __init__(self, config: LlamaConfig):
         self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.gradient_checkpointing = False
 
-        # register a causal mask to separate causal and padding mask creation. Merging happends in the attention class
-        causal_mask = torch.full((config.max_position_embeddings, config.max_position_embeddings), fill_value=1)
+        # Register a causal mask to separate causal and padding mask creation. Merging happens in the attention class.
+        # NOTE: This is not friendly with TorchScript, ONNX, ExportedProgram serialization for very large `max_position_embeddings`.
+        causal_mask = torch.full(
+            (config.max_position_embeddings, config.max_position_embeddings), fill_value=True, dtype=torch.bool
+        )
         self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False)
         # Initialize weights and apply final processing
         self.post_init()
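
The comment in the second hunk notes that the causal mask and padding mask are merged in the attention class. A rough standalone sketch of that idea with the boolean buffer registered above (this is not the library's actual merging code; `padding_mask` and `merged_mask` are illustrative names):

import torch

seq_len = 5
# Upper-triangular (diagonal=1) boolean mask: True marks strictly-future positions
# that a query token must not attend to, matching the buffer registered above.
causal_mask = torch.triu(torch.full((seq_len, seq_len), True, dtype=torch.bool), diagonal=1)

# Hypothetical padding mask for a batch of one sequence whose last two tokens are padding.
padding_mask = torch.tensor([[False, False, False, True, True]])

# Sketch of the merge: a position is masked out if it is in the future OR is padding.
merged_mask = causal_mask[None, :, :] | padding_mask[:, None, :]
print(merged_mask.shape)  # torch.Size([1, 5, 5])
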