Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions mlc_llm/relax_model/stablelm_3b.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,11 @@ def __init__(
self.num_shards = 1
self.kwargs = kwargs

def get_num_key_value_heads(self):
    """Return the effective number of key/value attention heads.

    When ``num_key_value_heads`` is not configured (``None``), fall back
    to ``num_attention_heads`` — i.e. plain multi-head attention where
    every query head has its own key/value head.
    """
    kv_heads = self.num_key_value_heads
    return self.num_attention_heads if kv_heads is None else kv_heads


class LayerNorm(nn.Module):
def __init__(
Expand Down Expand Up @@ -579,7 +584,7 @@ def create_embed_func(
bsz = 1
seq_len = tvm.tir.Var("n", "int64")
with bb.function(func_name):
model = StableLM3bEmbedTokensWrapper(config, tvm.tir.Var("v", "int64"))
model = StableLM3bEmbedTokensWrapper(config, tvm.tir.Var("vocab_size", "int64"))
param_manager.register_params(model, func_name, quant_scheme, get_param_quant_kind)

input_ids = nn.Placeholder((bsz, seq_len), dtype="int32", name="input_ids")
Expand Down Expand Up @@ -608,7 +613,7 @@ def create_encoding_func(
all_seq_len = tvm.tir.Var("m", "int64")
hidden_size = config.hidden_size
with bb.function(func_name):
model = StableLM3bForCausalLM(config, tvm.tir.Var("v", "int64"), sep_embed)
model = StableLM3bForCausalLM(config, tvm.tir.Var("vocab_size", "int64"), sep_embed)
param_manager.register_params(model, func_name, quant_scheme, get_param_quant_kind)

inputs = (
Expand Down Expand Up @@ -652,7 +657,7 @@ def create_decoding_func(
all_seq_len = tvm.tir.Var("n", "int64")

with bb.function(func_name):
model = StableLM3bForCausalLM(config, tvm.tir.Var("v", "int64"))
model = StableLM3bForCausalLM(config, tvm.tir.Var("vocab_size", "int64"))
param_manager.register_params(model, func_name, quant_scheme, get_param_quant_kind)

input_ids = nn.Placeholder((bsz, 1), dtype="int32", name="input_ids")
Expand Down Expand Up @@ -714,7 +719,9 @@ def create_kv_cache_func(bb: relax.BlockBuilder, config: StableLM3bConfig) -> No

def create_softmax_func(bb: relax.BlockBuilder, config: StableLM3bConfig) -> None:
with bb.function("softmax_with_temperature"):
logits = nn.Placeholder((1, 1, tvm.tir.Var("v", "int64")), dtype="float32", name="logits")
logits = nn.Placeholder(
(1, 1, tvm.tir.Var("vocab_size", "int64")), dtype="float32", name="logits"
)
temperature = nn.Placeholder((), dtype="float32", name="temperature")
with bb.dataflow():
div = bb.emit(relax.op.divide(logits, temperature))
Expand Down