From 5dc1bb3c9ff1cba7a2f717f7d3d8a54c44c34fcd Mon Sep 17 00:00:00 2001 From: Jeethu Rao Date: Fri, 13 Oct 2023 16:28:53 +0100 Subject: [PATCH 1/2] [stablelm 3b] Rename dynamic vocab size from "v" to "vocab_size" --- mlc_llm/relax_model/stablelm_3b.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/mlc_llm/relax_model/stablelm_3b.py b/mlc_llm/relax_model/stablelm_3b.py index 4bb1beedeb..40a09b39e3 100644 --- a/mlc_llm/relax_model/stablelm_3b.py +++ b/mlc_llm/relax_model/stablelm_3b.py @@ -579,7 +579,7 @@ def create_embed_func( bsz = 1 seq_len = tvm.tir.Var("n", "int64") with bb.function(func_name): - model = StableLM3bEmbedTokensWrapper(config, tvm.tir.Var("v", "int64")) + model = StableLM3bEmbedTokensWrapper(config, tvm.tir.Var("vocab_size", "int64")) param_manager.register_params(model, func_name, quant_scheme, get_param_quant_kind) input_ids = nn.Placeholder((bsz, seq_len), dtype="int32", name="input_ids") @@ -608,7 +608,7 @@ def create_encoding_func( all_seq_len = tvm.tir.Var("m", "int64") hidden_size = config.hidden_size with bb.function(func_name): - model = StableLM3bForCausalLM(config, tvm.tir.Var("v", "int64"), sep_embed) + model = StableLM3bForCausalLM(config, tvm.tir.Var("vocab_size", "int64"), sep_embed) param_manager.register_params(model, func_name, quant_scheme, get_param_quant_kind) inputs = ( @@ -652,7 +652,7 @@ def create_decoding_func( all_seq_len = tvm.tir.Var("n", "int64") with bb.function(func_name): - model = StableLM3bForCausalLM(config, tvm.tir.Var("v", "int64")) + model = StableLM3bForCausalLM(config, tvm.tir.Var("vocab_size", "int64")) param_manager.register_params(model, func_name, quant_scheme, get_param_quant_kind) input_ids = nn.Placeholder((bsz, 1), dtype="int32", name="input_ids") @@ -714,7 +714,9 @@ def create_kv_cache_func(bb: relax.BlockBuilder, config: StableLM3bConfig) -> No def create_softmax_func(bb: relax.BlockBuilder, config: StableLM3bConfig) -> None: with bb.function("softmax_with_temperature"): - logits = nn.Placeholder((1, 1, tvm.tir.Var("v", "int64")), dtype="float32", name="logits") + logits = nn.Placeholder( + (1, 1, tvm.tir.Var("vocab_size", "int64")), dtype="float32", name="logits" + ) temperature = nn.Placeholder((), dtype="float32", name="temperature") with bb.dataflow(): div = bb.emit(relax.op.divide(logits, temperature)) From b856b811c5f84e099b5eaef6f38b021814c6b3c9 Mon Sep 17 00:00:00 2001 From: Jeethu Rao Date: Fri, 13 Oct 2023 16:39:43 +0100 Subject: [PATCH 2/2] Add get_num_key_value_heads method to StableLM3bConfig --- mlc_llm/relax_model/stablelm_3b.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/mlc_llm/relax_model/stablelm_3b.py b/mlc_llm/relax_model/stablelm_3b.py index 40a09b39e3..89c15a7955 100644 --- a/mlc_llm/relax_model/stablelm_3b.py +++ b/mlc_llm/relax_model/stablelm_3b.py @@ -66,6 +66,11 @@ def __init__( self.num_shards = 1 self.kwargs = kwargs + def get_num_key_value_heads(self): + if self.num_key_value_heads is None: + return self.num_attention_heads + return self.num_key_value_heads + class LayerNorm(nn.Module): def __init__(