diff --git a/mlc_llm/relax_model/stablelm_3b.py b/mlc_llm/relax_model/stablelm_3b.py index 4bb1beedeb..89c15a7955 100644 --- a/mlc_llm/relax_model/stablelm_3b.py +++ b/mlc_llm/relax_model/stablelm_3b.py @@ -66,6 +66,11 @@ def __init__( self.num_shards = 1 self.kwargs = kwargs + def get_num_key_value_heads(self): + if self.num_key_value_heads is None: + return self.num_attention_heads + return self.num_key_value_heads + class LayerNorm(nn.Module): def __init__( @@ -579,7 +584,7 @@ def create_embed_func( bsz = 1 seq_len = tvm.tir.Var("n", "int64") with bb.function(func_name): - model = StableLM3bEmbedTokensWrapper(config, tvm.tir.Var("v", "int64")) + model = StableLM3bEmbedTokensWrapper(config, tvm.tir.Var("vocab_size", "int64")) param_manager.register_params(model, func_name, quant_scheme, get_param_quant_kind) input_ids = nn.Placeholder((bsz, seq_len), dtype="int32", name="input_ids") @@ -608,7 +613,7 @@ def create_encoding_func( all_seq_len = tvm.tir.Var("m", "int64") hidden_size = config.hidden_size with bb.function(func_name): - model = StableLM3bForCausalLM(config, tvm.tir.Var("v", "int64"), sep_embed) + model = StableLM3bForCausalLM(config, tvm.tir.Var("vocab_size", "int64"), sep_embed) param_manager.register_params(model, func_name, quant_scheme, get_param_quant_kind) inputs = ( @@ -652,7 +657,7 @@ def create_decoding_func( all_seq_len = tvm.tir.Var("n", "int64") with bb.function(func_name): - model = StableLM3bForCausalLM(config, tvm.tir.Var("v", "int64")) + model = StableLM3bForCausalLM(config, tvm.tir.Var("vocab_size", "int64")) param_manager.register_params(model, func_name, quant_scheme, get_param_quant_kind) input_ids = nn.Placeholder((bsz, 1), dtype="int32", name="input_ids") @@ -714,7 +719,9 @@ def create_kv_cache_func(bb: relax.BlockBuilder, config: StableLM3bConfig) -> No def create_softmax_func(bb: relax.BlockBuilder, config: StableLM3bConfig) -> None: with bb.function("softmax_with_temperature"): - logits = nn.Placeholder((1, 1, tvm.tir.Var("v", "int64")), dtype="float32", name="logits") + logits = nn.Placeholder( + (1, 1, tvm.tir.Var("vocab_size", "int64")), dtype="float32", name="logits" + ) temperature = nn.Placeholder((), dtype="float32", name="temperature") with bb.dataflow(): div = bb.emit(relax.op.divide(logits, temperature))