@@ -817,26 +817,42 @@ def create_softmax_func(bb: relax.BlockBuilder, config: LlamaConfig) -> None:
817
817
def get_model (args , hf_config ):
818
818
model_name = args .model
819
819
dtype = args .quantization .model_dtype
820
- max_seq_len = args .max_seq_len
821
820
sep_embed = args .sep_embed
822
821
823
822
position_embedding_base = 10000
824
823
max_position_embeddings = 2048
825
824
if "rope_theta" in hf_config :
826
825
position_embedding_base = hf_config ["rope_theta" ]
827
- if "max_position_embeddings" in hf_config :
828
- max_position_embeddings = hf_config ["max_position_embeddings" ]
829
826
830
- config = LlamaConfig (
831
- ** hf_config ,
832
- dtype = dtype ,
833
- position_embedding_base = position_embedding_base ,
834
- combine_matmul = True ,
835
- num_shards = args .num_shards ,
836
- build_model_only = args .build_model_only ,
837
- )
838
- if max_seq_len != - 1 :
839
- config .max_sequence_length = max_seq_len
827
+ # Llama-2 variants use `max_position_embeddings` to encode maximum sequence length in their hf model cards,
828
+ # while Llama-1 variants use `max_sequence_length`.
829
+ # Thus, use `max_sequence_length` if defined. Otherwise, use `max_position_embeddings`.
830
+ # If none of them is defined, throw an error.
831
+ if "max_sequence_length" in hf_config :
832
+ config = LlamaConfig (
833
+ ** hf_config ,
834
+ dtype = dtype ,
835
+ position_embedding_base = position_embedding_base ,
836
+ combine_matmul = True ,
837
+ num_shards = args .num_shards ,
838
+ build_model_only = args .build_model_only ,
839
+ )
840
+ elif "max_position_embeddings" in hf_config :
841
+ config = LlamaConfig (
842
+ ** hf_config ,
843
+ dtype = dtype ,
844
+ max_sequence_length = hf_config ["max_position_embeddings" ],
845
+ position_embedding_base = position_embedding_base ,
846
+ combine_matmul = True ,
847
+ num_shards = args .num_shards ,
848
+ build_model_only = args .build_model_only ,
849
+ )
850
+ else :
851
+ raise Exception ("The model config should contain information about maximum sequence length." )
852
+
853
+ # If there is a user-provided maximum sequence length, override hf config.
854
+ if args .max_seq_len != - 1 :
855
+ config .max_sequence_length = args .max_seq_len
840
856
841
857
param_manager = ParamManager ()
842
858
bb = relax .BlockBuilder ()
0 commit comments