From 4ac1bb489244fc1dec6dd9dbd7027b6c33b6c994 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 26 Oct 2023 06:29:15 +0000 Subject: [PATCH 1/3] update model to prepare for upstreaming --- mlc_llm/relax_model/llama.py | 505 ++++++++++++++++-- ...llama_batched.py => llama_batched_vllm.py} | 0 2 files changed, 446 insertions(+), 59 deletions(-) rename mlc_llm/relax_model/{llama_batched.py => llama_batched_vllm.py} (100%) diff --git a/mlc_llm/relax_model/llama.py b/mlc_llm/relax_model/llama.py index 7241a33a28..8294313324 100644 --- a/mlc_llm/relax_model/llama.py +++ b/mlc_llm/relax_model/llama.py @@ -1,10 +1,10 @@ import math from dataclasses import dataclass -from typing import Any, List, Optional, Tuple +from typing import Any, List, Optional, Tuple, Union import numpy as np import tvm -from tvm import relax, te +from tvm import relax, te, tir from tvm.relax.op import ccl from tvm.relax.testing import nn from tvm.script import relax as R @@ -59,6 +59,7 @@ def __init__( self.position_embedding_base = position_embedding_base self.combine_matmul = combine_matmul self.sliding_window = sliding_window + if build_model_only and num_shards > 1: self.num_shards = num_shards else: @@ -247,7 +248,7 @@ def rotary_compute(*idx): return q_embed, k_embed -class LlamaAttention(nn.Module): +class LlamaAttentionBase(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" def __init__(self, config: LlamaConfig): @@ -335,15 +336,12 @@ def project_qkv(self, hidden_states, query_output_shape, kv_output_shape): def forward( self, hidden_states: relax.Expr, - all_seq_len_shape: relax.Expr, - past_key_value: Tuple[relax.Expr], + all_seq_len_shape: Optional[relax.Expr], + past_key_values: Union[relax.Expr, Tuple[relax.Expr]], + layer_id: int, attention_mask: Optional[relax.Expr] = None, - ) -> Tuple[relax.Expr, Optional[relax.Expr], Optional[Tuple[relax.Expr]]]: - from tvm.relax.op import astype, matmul, maximum, permute_dims, reshape, squeeze - from tvm.relax.op.nn import softmax - + ) -> Tuple[relax.Expr, Union[relax.Expr, Tuple[relax.Expr]]]: bsz, q_len, _ = hidden_states.struct_info.shape - assert bsz == 1, "Only support batch size 1 at this moment." 
query_states, key_states, value_states = self.project_qkv( hidden_states, @@ -351,7 +349,123 @@ def forward( (bsz, q_len, self.num_key_value_heads, self.head_dim), ) - kv_seq_len = all_seq_len_shape.struct_info.values[0] + from tvm.relax.op import reshape + + attn_output, past_key_values = self.attention_fwd( + query_states, + key_states, + value_states, + past_key_values, + bsz, + q_len, + layer_id=layer_id, + all_seq_len_shape=all_seq_len_shape, + attention_mask=attention_mask, + ) + + attn_output = nn.emit( + reshape(attn_output, (bsz, q_len, self.head_dim * self.num_query_heads)) + ) + attn_output = self.o_proj(attn_output) + return attn_output, past_key_values + + def attention_fwd( + self, + query_states: relax.Expr, + key_states: relax.Expr, + value_states: relax.Expr, + past_key_values: relax.Expr, + batch_size: tir.PrimExpr, + q_len: tir.PrimExpr, + **kwargs, + ): + raise NotImplementedError() + + +class LlamaPagedAttention(LlamaAttentionBase): + def __init__(self, config: LlamaConfig): + super().__init__(config) + ctx_mod = relax.BlockBuilder.current().get() + self.kv_cache_transpose_append = ctx_mod.get_global_var("kv_cache_transpose_append") + self.attention_compute = ctx_mod.get_global_var("attention") + + def attention_fwd( + self, + query_states: relax.Expr, + key_states: relax.Expr, + value_states: relax.Expr, + past_key_values: relax.Expr, + batch_size: tir.PrimExpr, + q_len: tir.PrimExpr, + **kwargs, + ) -> Tuple[relax.Expr, relax.Expr]: + assert "layer_id" in kwargs and isinstance(kwargs["layer_id"], int) + layer_id = kwargs["layer_id"] + + f_kv_cache_append = relax.extern("vm.builtin.paged_attention_kv_cache_append") + past_key_values = nn.emit( + relax.call_pure_packed( + f_kv_cache_append, + past_key_values, + self.kv_cache_transpose_append, + key_states, + value_states, + relax.PrimValue(layer_id), + sinfo_args=relax.ObjectStructInfo(), + ) + ) + + f_kv_cache_attention = relax.extern("vm.builtin.paged_attention_kv_cache_attention") + attn_output = nn.emit( + relax.call_dps_packed( + f_kv_cache_attention, + [ + past_key_values, + self.attention_compute, + query_states, + relax.PrimValue(layer_id), + True, + 1.0, + self.position_embedding_base, + ], + out_sinfo=relax.TensorStructInfo( + ((batch_size, q_len, self.num_query_heads, self.head_dim)), + query_states.struct_info.dtype, + ), + ) + ) + return attn_output, past_key_values + + +class LlamaAttention(LlamaAttentionBase): + def __init__(self, config: LlamaConfig): + super().__init__(config) + + def attention_fwd( + self, + query_states: relax.Expr, + key_states: relax.Expr, + value_states: relax.Expr, + past_key_values: relax.Expr, + batch_size: tir.PrimExpr, + q_len: tir.PrimExpr, + **kwargs, + ) -> Tuple[relax.Expr, Tuple[relax.Expr]]: + assert "attention_mask" in kwargs + assert "all_seq_len_shape" in kwargs + attention_mask = kwargs["attention_mask"] + kv_seq_len = kwargs["all_seq_len_shape"].struct_info.values[0] + + from tvm.relax.op import ( + astype, + matmul, + maximum, + permute_dims, + reshape, + squeeze, + ) + from tvm.relax.op.nn import softmax + offset = kv_seq_len - q_len query_states, key_states = apply_rotary_pos_emb( query_states, @@ -371,7 +485,7 @@ def forward( squeezed_key = nn.emit(squeeze(key_states, axis=0)) squeezed_value = nn.emit(squeeze(value_states, axis=0)) - k_cache, v_cache = past_key_value + k_cache, v_cache = past_key_values f_kv_cache_append = relax.extern("vm.builtin.attention_kv_cache_append") k_cache = nn.emit( relax.Call( @@ -387,7 +501,7 @@ def forward( 
sinfo_args=[relax.ObjectStructInfo()], ) ) - past_key_value = (k_cache, v_cache) + past_key_values = (k_cache, v_cache) f_kv_cache_view = relax.extern("vm.builtin.attention_kv_cache_view") k_cache = nn.emit( relax.Call( @@ -421,7 +535,7 @@ def forward( tvm.ir.assert_structural_equal( attention_mask.struct_info.shape.values, - (bsz, tvm.tir.IntImm("int64", 1), q_len, kv_seq_len), + (batch_size, tvm.tir.IntImm("int64", 1), q_len, kv_seq_len), ) attn_weights = nn.emit( @@ -444,18 +558,14 @@ def forward( attn_output = nn.emit(matmul(attn_weights, value_states)) attn_output = nn.emit(permute_dims(attn_output, [0, 2, 1, 3])) - attn_output = nn.emit( - reshape(attn_output, (bsz, q_len, self.head_dim * self.num_query_heads)) - ) - - attn_output = self.o_proj(attn_output) - return attn_output, ((None, None) if past_key_value is None else past_key_value) + return attn_output, past_key_values class LlamaDecoderLayer(nn.Module): - def __init__(self, config: LlamaConfig): + def __init__(self, config: LlamaConfig, enable_batching: bool): + attn_class = LlamaPagedAttention if enable_batching else LlamaAttention self.hidden_size = config.hidden_size - self.self_attn = LlamaAttention(config) + self.self_attn = attn_class(config) self.mlp = LlamaMLP(config) self.input_layernorm = LlamaRMSNorm( config.hidden_size, dtype=config.dtype, eps=config.rms_norm_eps @@ -490,8 +600,9 @@ def post_self_attn(self, hidden_states, residual): def forward( self, hidden_states: relax.Expr, - all_seq_len_shape: relax.Expr, - past_key_value: Tuple[relax.Expr], + all_seq_len_shape: Optional[relax.Expr], + past_key_values: Union[relax.Expr, Tuple[relax.Expr]], + layer_id: int, attention_mask: Optional[relax.Expr] = None, ) -> Tuple[relax.Expr, Optional[Tuple[relax.Expr, relax.Expr]]]: residual = hidden_states @@ -501,17 +612,17 @@ def forward( # Self Attention hidden_states, present_key_value = self.self_attn( hidden_states=hidden_states, - past_key_value=past_key_value, + past_key_values=past_key_values, attention_mask=attention_mask, all_seq_len_shape=all_seq_len_shape, + layer_id=layer_id, ) - hidden_states = self.post_self_attn(hidden_states, residual) return hidden_states, present_key_value def _make_causal_mask(input_ids_shape, dtype, src_len): - from tvm.relax.op import broadcast_to, full, triu + from tvm.relax.op import broadcast_to bsz, tgt_len = input_ids_shape @@ -560,8 +671,14 @@ def forward(self, input_ids: relax.Expr): return inputs_embeds -class LlamaModel(nn.Module): - def __init__(self, config: LlamaConfig, vocab_size_var: tvm.tir.Var, sep_embed: bool = False): +class LlamaModelBase(nn.Module): + def __init__( + self, + config: LlamaConfig, + vocab_size_var: tir.Var, + sep_embed: bool = False, + enable_batching: bool = False, + ): self.num_shards = config.num_shards self.padding_idx = config.pad_token_id self.embed_tokens = None @@ -570,10 +687,23 @@ def __init__(self, config: LlamaConfig, vocab_size_var: tvm.tir.Var, sep_embed: self.embed_tokens = Embedding(vocab_size_var, config.hidden_size, dtype=config.dtype) self.layers = ModuleList( - [LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)] + [LlamaDecoderLayer(config, enable_batching) for _ in range(config.num_hidden_layers)] ) self.norm = LlamaRMSNorm(config.hidden_size, dtype=config.dtype, eps=config.rms_norm_eps) + def forward( + self, + inputs: relax.Expr, + all_seq_len_shape: Optional[relax.Expr], + past_key_values: relax.Expr, + ): + raise NotImplementedError() + + +class LlamaModelForSingleSequence(LlamaModelBase): + def __init__(self, 
config: LlamaConfig, vocab_size_var: tvm.tir.Var, sep_embed: bool = False): + super().__init__(config, vocab_size_var, sep_embed, enable_batching=False) + def _prepare_decoder_attention_mask(self, input_shape, src_len, dtype): # create causal mask # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] @@ -596,7 +726,7 @@ def _prepare_decoder_attention_mask(self, input_shape, src_len, dtype): def forward( self, inputs: relax.Expr, - all_seq_len_shape: relax.Expr, + all_seq_len_shape: Optional[relax.Expr], past_key_values: relax.Expr, ): if self.num_shards > 1: @@ -627,8 +757,9 @@ def forward( hidden_states, key_value_cache = decoder_layer( hidden_states, attention_mask=attention_mask, - past_key_value=past_key_value, + past_key_values=past_key_value, all_seq_len_shape=all_seq_len_shape, + layer_id=idx, ) next_decoder_cache += key_value_cache @@ -638,9 +769,51 @@ def forward( return hidden_states, next_decoder_cache +class LlamaModelForBatching(LlamaModelBase): + def __init__(self, config: LlamaConfig, vocab_size_var: tvm.tir.Var, sep_embed: bool): + assert sep_embed + super().__init__(config, vocab_size_var, sep_embed=True, enable_batching=True) + + def forward( + self, + inputs: relax.Expr, + all_seq_len_shape: Optional[relax.Expr], + past_key_values: relax.Expr, + ): + assert all_seq_len_shape is None + if self.num_shards > 1: + inputs = nn.emit(ccl.broadcast_from_worker0(inputs)) + if self.embed_tokens: + inputs_embeds = self.embed_tokens(inputs) + else: + inputs_embeds = inputs + + hidden_states = inputs_embeds + + for idx, decoder_layer in enumerate(self.layers): + assert past_key_values is not None + hidden_states, past_key_values = decoder_layer( + hidden_states, + attention_mask=None, + past_key_values=past_key_values, + all_seq_len_shape=all_seq_len_shape, + layer_id=idx, + ) + + hidden_states = self.norm(hidden_states) + return hidden_states, past_key_values + + class LlamaForCausalLM(nn.Module): - def __init__(self, config: LlamaConfig, vocab_size_var: tvm.tir.Var, sep_embed: bool = False): - self.model = LlamaModel(config, vocab_size_var, sep_embed) + def __init__( + self, + config: LlamaConfig, + vocab_size_var: tvm.tir.Var, + sep_embed: bool = False, + enable_batching: bool = False, + ): + model_class = LlamaModelForBatching if enable_batching else LlamaModelForSingleSequence + self.model = model_class(config, vocab_size_var, sep_embed) self.lm_head = Linear(config.hidden_size, vocab_size_var, dtype=config.dtype, bias=False) ############ Rotary embedding constants ############ @@ -657,7 +830,7 @@ def __init__(self, config: LlamaConfig, vocab_size_var: tvm.tir.Var, sep_embed: def forward( self, inputs: relax.Expr, - all_seq_len_shape: relax.Expr, + all_seq_len_shape: Optional[relax.Expr], past_key_values: relax.Expr, ): hidden_states, key_value_cache = self.model( @@ -667,8 +840,9 @@ def forward( ) def te_slicing(x: te.Tensor): + assert x.ndim == 3 return te.compute( - shape=(1, 1, x.shape[-1]), + shape=(x.shape[0], 1, x.shape[2]), fcompute=lambda i, j, k: x[i, x.shape[1] - 1, k], name="slice", ) @@ -699,7 +873,7 @@ def create_embed_func( ) -> None: func_name = "embed" - bsz = 1 + bsz = tvm.tir.Var("nseq", "int64") seq_len = tvm.tir.Var("n", "int64") with bb.function(func_name): model = LlamaEmbedTokensWrapper(config, tvm.tir.Var("vocab_size", "int64")) @@ -717,7 +891,7 @@ def create_embed_func( bb.update_func(gv, mod[gv].with_attr("num_input", 1)) -def create_encoding_func( +def create_prefill_func_for_single_seq( bb: relax.BlockBuilder, param_manager: ParamManager, config: 
LlamaConfig, @@ -731,7 +905,9 @@ def create_encoding_func( all_seq_len = tvm.tir.Var("m", "int64") hidden_size = config.hidden_size with bb.function(func_name): - model = LlamaForCausalLM(config, tvm.tir.Var("vocab_size", "int64"), sep_embed) + model = LlamaForCausalLM( + config, tvm.tir.Var("vocab_size", "int64"), sep_embed, enable_batching=False + ) param_manager.register_params(model, func_name, quant_scheme, get_param_quant_kind) inputs = ( @@ -763,7 +939,43 @@ def create_encoding_func( bb.update_func(gv, mod[gv].with_attr("num_input", 3)) -def create_decoding_func( +def create_prefill_func_for_batching( + bb: relax.BlockBuilder, + param_manager: ParamManager, + config: LlamaConfig, + quant_scheme: QuantizationScheme, +) -> None: + func_name = "prefill_with_embed" + + bsz = 1 + seq_len = tvm.tir.Var("n", "int64") + hidden_size = config.hidden_size + with bb.function(func_name): + model = LlamaForCausalLM( + config, tvm.tir.Var("vocab_size", "int64"), sep_embed=True, enable_batching=True + ) + param_manager.register_params(model, func_name, quant_scheme, get_param_quant_kind) + + inputs = nn.Placeholder( + (bsz, seq_len, hidden_size), dtype=config.dtype, name="inputs_embeds" + ) + past_key_values = relax.Var("kv_cache", relax.ObjectStructInfo()) + with bb.dataflow(): + logits, key_value_cache = model( + inputs, + all_seq_len_shape=None, + past_key_values=past_key_values, + ) + params = [inputs, past_key_values] + model.parameters() + gv = bb.emit_output((logits, key_value_cache)) + bb.emit_func_output(gv, params) + + mod = bb.get() + gv = mod.get_global_var(func_name) + bb.update_func(gv, mod[gv].with_attr("num_input", 2)) + + +def create_decoding_func_for_single_seq( bb: relax.BlockBuilder, param_manager: ParamManager, config: LlamaConfig, @@ -803,6 +1015,37 @@ def create_decoding_func( bb.update_func(gv, mod[gv].with_attr("num_input", 3)) +def create_decoding_func_for_batching( + bb: relax.BlockBuilder, + param_manager: ParamManager, + config: LlamaConfig, + quant_scheme: QuantizationScheme, +) -> None: + func_name = "decode_with_embed" + + bsz = tir.Var("nseq", "int64") + hidden_size = config.hidden_size + with bb.function(func_name): + model = LlamaForCausalLM( + config, tvm.tir.Var("vocab_size", "int64"), sep_embed=True, enable_batching=True + ) + param_manager.register_params(model, func_name, quant_scheme, get_param_quant_kind) + + inputs = nn.Placeholder((bsz, 1, hidden_size), dtype=config.dtype, name="inputs_embeds") + past_key_values = relax.Var("kv_cache", relax.ObjectStructInfo()) + with bb.dataflow(): + logits, key_value_cache = model( + inputs, all_seq_len_shape=None, past_key_values=past_key_values + ) + params = [inputs, past_key_values] + model.parameters() + gv = bb.emit_output((logits, key_value_cache)) + bb.emit_func_output(gv, params) + + mod = bb.get() + gv = mod.get_global_var(func_name) + bb.update_func(gv, mod[gv].with_attr("num_input", 2)) + + def create_kv_cache_func(bb: relax.BlockBuilder, config: LlamaConfig) -> None: num_key_value_heads = config.get_num_key_value_heads() // config.num_shards init_shape = relax.ShapeExpr( @@ -831,7 +1074,39 @@ def create_kv_cache_func(bb: relax.BlockBuilder, config: LlamaConfig) -> None: bb.emit_func_output(gv) -def create_softmax_func(bb: relax.BlockBuilder, config: LlamaConfig) -> None: +def create_paged_kv_cache_func(bb: relax.BlockBuilder, config: LlamaConfig) -> None: + head_dim = config.hidden_size // config.num_attention_heads + num_key_value_heads = config.get_num_key_value_heads() // config.num_shards + + page_size 
= tir.Var("page_size", "int64") + total_seq_len = tir.Var("total_seq_len", "int64") + reserved_nseq = tir.Var("reserved_nseq", "int64") + cache_config = relax.Var( + "cache_config", + relax.ShapeStructInfo([reserved_nseq, total_seq_len, page_size]), + ) + + with bb.function("create_kv_cache", [cache_config]): + with bb.dataflow(): + zeros = bb.emit(relax.op.zeros((), config.dtype)) + f_kv_cache_create = relax.extern("vm.builtin.paged_attention_kv_cache_create") + cache = bb.emit_output( + relax.Call( + f_kv_cache_create, + args=[ + cache_config, + relax.PrimValue(config.num_hidden_layers), + relax.PrimValue(num_key_value_heads), + relax.PrimValue(head_dim), + zeros, + ], + sinfo_args=[relax.ObjectStructInfo()], + ) + ) + bb.emit_func_output(cache) + + +def create_softmax_func_for_single_seq(bb: relax.BlockBuilder, config: LlamaConfig) -> None: with bb.function("softmax_with_temperature"): logits = nn.Placeholder( (1, 1, tvm.tir.Var("vocab_size", "int64")), dtype="float32", name="logits" @@ -844,6 +1119,89 @@ def create_softmax_func(bb: relax.BlockBuilder, config: LlamaConfig) -> None: bb.emit_func_output(gv, [logits, temperature]) +def create_softmax_func_for_batching(bb: relax.BlockBuilder, config: LlamaConfig) -> None: + with bb.function("softmax_with_temperature"): + bsz = tvm.tir.Var("nseq", "int64") + logits = nn.Placeholder( + (bsz, 1, tvm.tir.Var("vocab_size", "int64")), + dtype="float32", + name="logits", + ) + temperature = nn.Placeholder((bsz,), dtype="float32", name="temperature") + with bb.dataflow(): + t_reshaped = bb.emit(relax.op.reshape(temperature, (bsz, 1, 1))) + div = bb.emit(relax.op.divide(logits, t_reshaped)) + softmax = bb.emit(relax.op.nn.softmax(div, axis=-1)) + gv = bb.emit_output(softmax) + bb.emit_func_output(gv, [logits, temperature]) + + +def emit_paged_kv_cache_op(bb: relax.BlockBuilder, dtype: str) -> None: + from tvm.script import tir as T + + # fmt: off + @T.prim_func + def kv_cache_transpose_append( + var_pages: T.handle, + var_k_data: T.handle, + var_v_data: T.handle, + var_page_table_indptr: T.handle, + var_page_table_values: T.handle, + var_last_page_offset: T.handle, + var_append_length_indptr: T.handle, + var_pos2seqidx: T.handle, + layer_id: T.int32, + ): + nseq = T.int32() + ntoken = T.int32() + nhead = T.int32() + nfeat = T.int32() + nlayer = T.int32() + npage = T.int32() + page_size = T.int32() + num_pages = T.int32() + + pages = T.match_buffer(var_pages, (num_pages, nlayer, 2, nhead, page_size, nfeat), dtype) + k_data = T.match_buffer(var_k_data, (ntoken, nhead, nfeat), dtype) + v_data = T.match_buffer(var_v_data, (ntoken, nhead, nfeat), dtype) + last_page_offset = T.match_buffer(var_last_page_offset, (nseq,), "int32") + page_table_indptr = T.match_buffer(var_page_table_indptr, (nseq + 1,), "int32") + page_table_values = T.match_buffer(var_page_table_values, (npage,), "int32") + append_length_indptr = T.match_buffer(var_append_length_indptr, (nseq + 1,), "int32") + pos2seqidx = T.match_buffer(var_pos2seqidx, (ntoken,), "int32") + + for global_pos, h, f in T.grid(ntoken, nhead, nfeat): + with T.block("k_transpose_append"): + vgpos, vh, vf = T.axis.remap("SSS", [global_pos, h, f]) + seq_idx = pos2seqidx[vgpos] + seqlen: T.int32 = (page_table_indptr[seq_idx + 1] - page_table_indptr[seq_idx] - 1) * page_size + last_page_offset[seq_idx] + pages[ + page_table_values[page_table_indptr[seq_idx] + T.floordiv(seqlen - (append_length_indptr[seq_idx + 1] - vgpos), page_size)], + layer_id, + 0, + vh, + T.floormod(seqlen - (append_length_indptr[seq_idx + 1] - 
vgpos), page_size), + vf, + ] = k_data[vgpos, vh, vf] + with T.block("v_transpose_append"): + vgpos, vh, vf = T.axis.remap("SSS", [global_pos, h, f]) + seq_idx = pos2seqidx[vgpos] + seqlen: T.int32 = (page_table_indptr[seq_idx + 1] - page_table_indptr[seq_idx] - 1) * page_size + last_page_offset[seq_idx] + pages[ + page_table_values[page_table_indptr[seq_idx] + T.floordiv(seqlen - (append_length_indptr[seq_idx + 1] - vgpos), page_size)], + layer_id, + 1, + vh, + T.floormod(seqlen - (append_length_indptr[seq_idx + 1] - vgpos), page_size), + vf, + ] = v_data[vgpos, vh, vf] + # fmt: on + + bb.add_func(kv_cache_transpose_append, "kv_cache_transpose_append") + # Todo: integrating attention TIR func/kernel. + bb.add_func(relax.extern("attention_func"), "attention") + + def setup_params(mod, param_manager, dtype, config, args): def f_convert_pname_fwd(pname: str) -> List[str]: if not config.combine_matmul: @@ -932,36 +1290,65 @@ def f_compute_relax_param(relax_pname: str, torch_params: List[Any]): def get_model(args, hf_config): model_name = args.model dtype = args.quantization.model_dtype - max_seq_len = args.max_seq_len + enable_batching = args.enable_batching sep_embed = args.sep_embed + if enable_batching and not sep_embed: + raise ValueError("`sep_embed` is required when batching is enabled.") + position_embedding_base = 10000 - max_position_embeddings = 2048 + if "rope_theta" in hf_config: position_embedding_base = hf_config["rope_theta"] - if "max_position_embeddings" in hf_config: - max_position_embeddings = hf_config["max_position_embeddings"] - config = LlamaConfig( - **hf_config, - dtype=dtype, - position_embedding_base=position_embedding_base, - combine_matmul=True, - num_shards=args.num_shards, - build_model_only=args.build_model_only, - ) - if max_seq_len != -1: - config.max_sequence_length = max_seq_len + # Llama-2 variants use `max_position_embeddings` to encode maximum sequence length in their hf model cards, + # while Llama-1 variants use `max_sequence_length`. + # Thus, use `max_sequence_length` if defined. Otherwise, use `max_position_embeddings`. + # If none of them is defined, throw an error. + if "max_sequence_length" in hf_config: + config = LlamaConfig( + **hf_config, + dtype=dtype, + position_embedding_base=position_embedding_base, + combine_matmul=True, + num_shards=args.num_shards, + build_model_only=args.build_model_only, + ) + elif "max_position_embeddings" in hf_config: + config = LlamaConfig( + **hf_config, + dtype=dtype, + max_sequence_length=hf_config["max_position_embeddings"], + position_embedding_base=position_embedding_base, + combine_matmul=True, + num_shards=args.num_shards, + build_model_only=args.build_model_only, + ) + else: + raise Exception("The model config should contain information about maximum sequence length.") + + # If there is a user-provided maximum sequence length, override hf config. 
+ if args.max_seq_len != -1: + config.max_sequence_length = args.max_seq_len param_manager = ParamManager() bb = relax.BlockBuilder() if sep_embed: create_embed_func(bb, param_manager, config, args.quantization) - create_encoding_func(bb, param_manager, config, args.quantization, sep_embed) - create_decoding_func(bb, param_manager, config, args.quantization) - create_kv_cache_func(bb, config) - create_softmax_func(bb, config) + + if enable_batching: + emit_paged_kv_cache_op(bb, dtype) + create_prefill_func_for_batching(bb, param_manager, config, args.quantization) + create_decoding_func_for_batching(bb, param_manager, config, args.quantization) + create_paged_kv_cache_func(bb, config) + create_softmax_func_for_batching(bb, config) + else: + create_prefill_func_for_single_seq(bb, param_manager, config, args.quantization, sep_embed) + create_decoding_func_for_single_seq(bb, param_manager, config, args.quantization) + create_kv_cache_func(bb, config) + create_softmax_func_for_single_seq(bb, config) + create_metadata_func( bb, model_name=model_name, diff --git a/mlc_llm/relax_model/llama_batched.py b/mlc_llm/relax_model/llama_batched_vllm.py similarity index 100% rename from mlc_llm/relax_model/llama_batched.py rename to mlc_llm/relax_model/llama_batched_vllm.py From f94ab4ddb9516e1c12633a0f536274109ca59130 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 26 Oct 2023 06:30:01 +0000 Subject: [PATCH 2/3] more --- mlc_llm/relax_model/llama_batched_vllm.py | 71 +++++++++++++++-------- 1 file changed, 46 insertions(+), 25 deletions(-) diff --git a/mlc_llm/relax_model/llama_batched_vllm.py b/mlc_llm/relax_model/llama_batched_vllm.py index e1ba527cf5..49b9c24b43 100644 --- a/mlc_llm/relax_model/llama_batched_vllm.py +++ b/mlc_llm/relax_model/llama_batched_vllm.py @@ -18,7 +18,7 @@ Linear, Embedding, LlamaRMSNorm, - LlamaAttention, + LlamaAttentionBase, LlamaDecoderLayer, get_param_quant_kind, setup_params, @@ -48,8 +48,8 @@ def rotary_compute(*idx): return q_embed, k_embed -class LlamaAttentionBatched(LlamaAttention): - def __init__(self, config: LlamaConfig, head_mapping): +class LlamaAttentionBatched(LlamaAttentionBase): + def __init__(self, config: LlamaConfig, head_mapping: relax.Constant): super().__init__(config) self.head_mapping = head_mapping self.sliding_window = None @@ -62,9 +62,9 @@ def forward( hidden_states: relax.Expr, positions: relax.Expr, seq_lens: relax.Expr, - kv_cache: Optional[relax.Expr], + kv_cache: Optional[Tuple[relax.Expr, relax.Expr]], slot_mapping: Optional[relax.Expr], - max_seqlen: Optional[relax.Expr], + max_seqlen: Optional[relax.Expr], # Must be on CPU seqstart: Optional[relax.Expr], # For prefill block_tables: Optional[relax.Expr], # For decode indices_within_window: Optional[relax.Expr], # For prefill with sliding-window attention @@ -151,8 +151,8 @@ def forward( class LlamaDecoderLayerBatched(LlamaDecoderLayer): - def __init__(self, config: LlamaConfig, head_mapping): - super().__init__(config) + def __init__(self, config: LlamaConfig, head_mapping: relax.Constant): + super().__init__(config, False) self.self_attn = LlamaAttentionBatched(config, head_mapping) def forward( @@ -160,7 +160,7 @@ def forward( hidden_states: relax.Expr, positions: relax.Expr, seq_lens: relax.Expr, - kv_cache: Optional[relax.Expr], + kv_cache: Optional[Tuple[relax.Expr, relax.Expr]], slot_mapping: Optional[relax.Expr], max_seqlen: Optional[relax.Expr], seqstart: Optional[relax.Expr], @@ -240,7 +240,9 @@ def forward( hidden_states = inputs_embeds - # max_seqlen needs to be on 
CPU + # max_seqlen needs to be on CPU, so that vLLM and Flash Attention can directly get the + # integer length by max_seqlen->data[0]. Otherwise, we need to repeatedly do cudaMemcpy + # of a single int32. max_seqlen = R.to_vdevice(R.max(seq_lens), self.cpu_device) new_kvs = () @@ -568,37 +570,56 @@ def create_decoding_func( def get_model(args, hf_config): - model_name = args.model dtype = args.quantization.model_dtype - max_seq_len = args.max_seq_len sep_embed = False position_embedding_base = 10000 - max_position_embeddings = 2048 + if "rope_theta" in hf_config: position_embedding_base = hf_config["rope_theta"] - if "max_position_embeddings" in hf_config: - max_position_embeddings = hf_config["max_position_embeddings"] - - config = LlamaConfig( - **hf_config, - dtype=dtype, - position_embedding_base=position_embedding_base, - combine_matmul=True, - num_shards=args.num_shards, - build_model_only=args.build_model_only, - ) - if max_seq_len != -1: - config.max_sequence_length = max_seq_len + + # Llama-2 variants use `max_position_embeddings` to encode maximum sequence length in their hf model cards, + # while Llama-1 variants use `max_sequence_length`. + # Thus, use `max_sequence_length` if defined. Otherwise, use `max_position_embeddings`. + # If none of them is defined, throw an error. + if "max_sequence_length" in hf_config: + config = LlamaConfig( + **hf_config, + dtype=dtype, + position_embedding_base=position_embedding_base, + combine_matmul=True, + num_shards=args.num_shards, + build_model_only=args.build_model_only, + ) + elif "max_position_embeddings" in hf_config: + config = LlamaConfig( + **hf_config, + dtype=dtype, + max_sequence_length=hf_config["max_position_embeddings"], + position_embedding_base=position_embedding_base, + combine_matmul=True, + num_shards=args.num_shards, + build_model_only=args.build_model_only, + ) + else: + raise Exception( + "The model config should contain information about maximum sequence length." + ) + + # If there is a user-provided maximum sequence length, override hf config. + if args.max_seq_len != -1: + config.max_sequence_length = args.max_seq_len param_manager = ParamManager() bb = relax.BlockBuilder() + # The CPU device to copy the result of relax.op.max(seq_lens) to CPU. 
cpu_dev = VDevice("llvm", 0, "global") create_evaluate_func(bb, param_manager, config, cpu_dev, args.quantization, sep_embed) create_encoding_func(bb, param_manager, config, cpu_dev, args.quantization, sep_embed) create_decoding_func(bb, param_manager, config, cpu_dev, args.quantization) + bb.get().update_global_info("vdevice", [cpu_dev]) mod = bb.get() From d0f6fa92296004baf41661ad7b91475f1a8b4eb4 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 26 Oct 2023 06:36:50 +0000 Subject: [PATCH 3/3] ok --- mlc_llm/core.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mlc_llm/core.py b/mlc_llm/core.py index 3f3e9573a9..6adb3f4fe1 100644 --- a/mlc_llm/core.py +++ b/mlc_llm/core.py @@ -22,7 +22,7 @@ gpt_neox, gptj, llama, - llama_batched, + llama_batched_vllm, minigpt, param_manager, rwkv, @@ -606,7 +606,7 @@ def build_model_from_args(args: argparse.Namespace): config = json.load(i_f) if not use_cache or args.convert_weight_only: if args.model_category in ["llama", "mistral"] and args.batched: - mod, param_manager, params, model_config = llama_batched.get_model(args, config) + mod, param_manager, params, model_config = llama_batched_vllm.get_model(args, config) elif args.model_category == "llama": mod, param_manager, params, model_config = llama.get_model(args, config) elif args.model_category == "mistral":
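
Reference note on the paged-cache indexing (illustrative, not part of the patches): the slot arithmetic in the new `kv_cache_transpose_append` TIR kernel from PATCH 1/3 is dense, so the sketch below restates the same indexing in plain NumPy, assuming the kernel's buffer layout `pages[num_pages, num_layers, 2, num_heads, page_size, head_dim]`. The helper name `append_kv_reference` and its loop-based form are hypothetical; the patch itself implements this as a TIR prim_func registered under "kv_cache_transpose_append".

    import numpy as np

    def append_kv_reference(
        pages,                 # (num_pages, num_layers, 2, num_heads, page_size, head_dim)
        k_data, v_data,        # (ntoken, num_heads, head_dim): new K/V for all sequences, flattened
        page_table_indptr,     # (nseq + 1,): CSR offsets into page_table_values, per sequence
        page_table_values,     # (npage,): physical page ids, grouped per sequence
        last_page_offset,      # (nseq,): used slots in each sequence's last page, after the append
        append_length_indptr,  # (nseq + 1,): CSR offsets of each sequence's newly appended tokens
        pos2seqidx,            # (ntoken,): owning sequence index of each flattened token
        layer_id, page_size,
    ):
        ntoken = k_data.shape[0]
        for gpos in range(ntoken):
            seq = int(pos2seqidx[gpos])
            # Total length of this sequence after the append.
            npages_seq = page_table_indptr[seq + 1] - page_table_indptr[seq]
            seqlen = (npages_seq - 1) * page_size + last_page_offset[seq]
            # Absolute position of this token within its sequence.
            pos_in_seq = seqlen - (append_length_indptr[seq + 1] - gpos)
            # Logical page within the sequence -> physical page via the per-sequence page table,
            # then the slot inside that page (the kernel's floordiv/floormod pair).
            page = page_table_values[page_table_indptr[seq] + pos_in_seq // page_size]
            slot = pos_in_seq % page_size
            pages[page, layer_id, 0, :, slot, :] = k_data[gpos]  # keys
            pages[page, layer_id, 1, :, slot, :] = v_data[gpos]  # values
        return pages

Each appended token thus lands at a (physical page, in-page slot) pair exactly as the TIR kernel computes it; the attention over these pages is still delegated to an external "attention" function in the patch (see the `# Todo` placeholder in `emit_paged_kv_cache_op`).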