Commit ca8c11b

sunggg and “Sunghyun authored
[BugFix] Set the right max_sequence_length for both Llama-1 and Llama-2 families (mlc-ai#1032)
* fix
* reflect feedback
---------
Co-authored-by: “Sunghyun <[email protected]>
1 parent bfaa5b9 commit ca8c11b

File tree

1 file changed: +29 −13 lines changed


mlc_llm/relax_model/llama.py

Lines changed: 29 additions & 13 deletions
@@ -817,26 +817,42 @@ def create_softmax_func(bb: relax.BlockBuilder, config: LlamaConfig) -> None:
 def get_model(args, hf_config):
     model_name = args.model
     dtype = args.quantization.model_dtype
-    max_seq_len = args.max_seq_len
     sep_embed = args.sep_embed
 
     position_embedding_base = 10000
     max_position_embeddings = 2048
     if "rope_theta" in hf_config:
         position_embedding_base = hf_config["rope_theta"]
-    if "max_position_embeddings" in hf_config:
-        max_position_embeddings = hf_config["max_position_embeddings"]
 
-    config = LlamaConfig(
-        **hf_config,
-        dtype=dtype,
-        position_embedding_base=position_embedding_base,
-        combine_matmul=True,
-        num_shards=args.num_shards,
-        build_model_only=args.build_model_only,
-    )
-    if max_seq_len != -1:
-        config.max_sequence_length = max_seq_len
+    # Llama-2 variants use `max_position_embeddings` to encode maximum sequence length in their hf model cards,
+    # while Llama-1 variants use `max_sequence_length`.
+    # Thus, use `max_sequence_length` if defined. Otherwise, use `max_position_embeddings`.
+    # If none of them is defined, throw an error.
+    if "max_sequence_length" in hf_config:
+        config = LlamaConfig(
+            **hf_config,
+            dtype=dtype,
+            position_embedding_base=position_embedding_base,
+            combine_matmul=True,
+            num_shards=args.num_shards,
+            build_model_only=args.build_model_only,
+        )
+    elif "max_position_embeddings" in hf_config:
+        config = LlamaConfig(
+            **hf_config,
+            dtype=dtype,
+            max_sequence_length=hf_config["max_position_embeddings"],
+            position_embedding_base=position_embedding_base,
+            combine_matmul=True,
+            num_shards=args.num_shards,
+            build_model_only=args.build_model_only,
+        )
+    else:
+        raise Exception("The model config should contain information about maximum sequence length.")
+
+    # If there is a user-provided maximum sequence length, override hf config.
+    if args.max_seq_len != -1:
+        config.max_sequence_length = args.max_seq_len
 
     param_manager = ParamManager()
     bb = relax.BlockBuilder()
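
To illustrate the selection logic this commit introduces, here is a minimal standalone sketch of how the maximum sequence length would be resolved from a Hugging Face config dict plus an optional user override. The helper name resolve_max_sequence_length and the sample config values are hypothetical and for illustration only; they are not part of mlc_llm or of this diff.

# Hypothetical helper mirroring the commit's key-selection logic (not part of mlc_llm):
# prefer `max_sequence_length` (Llama-1 style configs), fall back to
# `max_position_embeddings` (Llama-2 style configs), and let a user-provided
# value override either. Raise if neither key is present.
def resolve_max_sequence_length(hf_config: dict, user_max_seq_len: int = -1) -> int:
    if "max_sequence_length" in hf_config:
        max_len = hf_config["max_sequence_length"]
    elif "max_position_embeddings" in hf_config:
        max_len = hf_config["max_position_embeddings"]
    else:
        raise Exception(
            "The model config should contain information about maximum sequence length."
        )
    # A user-provided maximum sequence length overrides the hf config.
    if user_max_seq_len != -1:
        max_len = user_max_seq_len
    return max_len

# Illustrative configs (only the keys matter here; the values are made up):
print(resolve_max_sequence_length({"max_position_embeddings": 4096}))    # 4096
print(resolve_max_sequence_length({"max_sequence_length": 2048}))        # 2048
print(resolve_max_sequence_length({"max_sequence_length": 2048}, 4096))  # user override: 4096

In the actual code above, the first branch relies on `max_sequence_length` being forwarded through `**hf_config` into LlamaConfig, while the second branch passes it explicitly from `max_position_embeddings`; the sketch collapses both into a single lookup for clarity.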

0 commit comments