From 898db764b0d0fce260acd9414b2cb7b63668691b Mon Sep 17 00:00:00 2001 From: David Pissarra <61968959+davidpissarra@users.noreply.github.com> Date: Sun, 8 Oct 2023 03:36:19 +0100 Subject: [PATCH 001/116] [API] Add GenerationConfig (#1024) --- cpp/llm_chat.cc | 209 +++++++++++++++++----- python/mlc_chat/__init__.py | 1 + python/mlc_chat/chat_module.py | 223 +++++++++++++++++------- python/mlc_chat/interface/openai_api.py | 44 +++-- python/mlc_chat/rest.py | 124 ++++++------- tests/python/test_update_config.py | 89 ---------- 6 files changed, 414 insertions(+), 276 deletions(-) delete mode 100644 tests/python/test_update_config.py diff --git a/cpp/llm_chat.cc b/cpp/llm_chat.cc index f3fdcc3a36..c83f7add9d 100644 --- a/cpp/llm_chat.cc +++ b/cpp/llm_chat.cc @@ -599,7 +599,20 @@ class LLMChat { * \brief Get input tokens based on history * \param place_in_prompt The place of the input message in the prompt. */ - std::vector GetInputTokens(PlaceInPrompt place_in_prompt = PlaceInPrompt::kAll) { + std::vector GetInputTokens(PlaceInPrompt place_in_prompt = PlaceInPrompt::kAll, + picojson::object generation_config = picojson::object()) { + // prepare generation settings + // the generation_config will not override the original config + // since is only used for this generation + int64_t gen_mean_gen_len; + if (generation_config.count("mean_gen_len")) { + CHECK(generation_config["mean_gen_len"].is()); + gen_mean_gen_len = generation_config["mean_gen_len"].get(); + } else { + gen_mean_gen_len = this->mean_gen_len_; + } + + // work on input tokens std::vector tokens; std::vector prompts; @@ -619,7 +632,7 @@ class LLMChat { std::string all_prompt = GetConcatPrompt(prompts, 0, 0); std::vector encoded = this->tokenizer_->Encode(all_prompt); tokens.insert(tokens.end(), encoded.begin(), encoded.end()); - if (this->total_seq_len_ + tokens.size() + this->mean_gen_len_ < this->max_window_size_) { + if (this->total_seq_len_ + tokens.size() + gen_mean_gen_len < this->max_window_size_) { return tokens; } // need shift window and re-encode @@ -656,11 +669,11 @@ class LLMChat { if (tokens.size() >= this->max_window_size_) { LOG(WARNING) << "The prompt tokens are more than `max_window_size`, the input will be truncated."; - ICHECK_GT(this->max_window_size_, this->mean_gen_len_); + ICHECK_GT(this->max_window_size_, gen_mean_gen_len); std::vector truncated_tokens( - tokens.end() - (this->max_window_size_ - this->mean_gen_len_), tokens.end()); + tokens.end() - (this->max_window_size_ - gen_mean_gen_len), tokens.end()); return truncated_tokens; - } else if (tokens.size() + this->mean_gen_len_ >= this->max_window_size_) { + } else if (tokens.size() + gen_mean_gen_len >= this->max_window_size_) { LOG(WARNING) << "The prompt tokens are too long and the generated text may be incomplete, due to " "limited `max_window_size`. "; @@ -696,7 +709,8 @@ class LLMChat { } std::vector PrepareBeforeEmbedding(std::string inp, bool append_conversation = true, - PlaceInPrompt place_in_prompt = PlaceInPrompt::kAll) { + PlaceInPrompt place_in_prompt = PlaceInPrompt::kAll, + picojson::object generation_config = picojson::object()) { if (conversation_.separator_style == SeparatorStyle::kLM || conversation_.separator_style == SeparatorStyle::kCodeCompletion) { this->ResetChat(); @@ -713,7 +727,7 @@ class LLMChat { conversation_.AppendReplyHeader(conversation_.roles[1]); } - return this->GetInputTokens(place_in_prompt); + return this->GetInputTokens(place_in_prompt, generation_config); } /*! 
@@ -724,9 +738,18 @@ class LLMChat { * \return the embedding of the tokenized prompt. */ ObjectRef EmbedStep(std::string inp, bool append_conversation = true, - PlaceInPrompt place_in_prompt = PlaceInPrompt::kAll) { + PlaceInPrompt place_in_prompt = PlaceInPrompt::kAll, + String generation_config_str = "") { + // process generation settings + picojson::object generation_config = picojson::object(); + if(!generation_config_str.empty()) { + picojson::value generation_config_json; + picojson::parse(generation_config_json, generation_config_str); + generation_config = generation_config_json.get(); + } + std::vector prompt_tokens = - PrepareBeforeEmbedding(inp, append_conversation, place_in_prompt); + PrepareBeforeEmbedding(inp, append_conversation, place_in_prompt, generation_config); int64_t token_len = static_cast(prompt_tokens.size()); if (token_len == 0) { return NDArray::Empty({}, DataType::Float(32), device_); @@ -755,7 +778,7 @@ class LLMChat { * \param embedding The embedding to prefill with. * \param decode_next_token Whether to decode next token. */ - void PrefillWithEmbedStep(NDArray embedding, bool decode_next_token = true) { + void PrefillWithEmbedStep(NDArray embedding, bool decode_next_token = true, String generation_config_str = "") { if (ft_.use_disco) { LOG(FATAL) << "NotImplementedError: Distributed inference is not supported for this model"; throw; @@ -774,7 +797,15 @@ class LLMChat { return; } - int32_t next_token = this->SampleTokenFromLogits(logits_on_device, temperature_, top_p_); + // process generation settings + picojson::object generation_config = picojson::object(); + if(!generation_config_str.empty()) { + picojson::value generation_config_json; + picojson::parse(generation_config_json, generation_config_str); + generation_config = generation_config_json.get(); + } + + int32_t next_token = this->SampleTokenFromLogits(logits_on_device, generation_config); auto tend = std::chrono::high_resolution_clock::now(); @@ -791,20 +822,29 @@ class LLMChat { * \param place_in_prompt The place of the input message in the prompt. */ void PrefillStep(std::string inp, bool append_conversation = true, bool decode_next_token = true, - PlaceInPrompt place_in_prompt = PlaceInPrompt::kAll) { + PlaceInPrompt place_in_prompt = PlaceInPrompt::kAll, + String generation_config_str = "") { if (ft_.embed_func_.defined() && ft_.prefill_with_embed_func_.defined()) { // Temporarily placed inside `PrefillStep` for compatibility in transition. // Will be separated out in the future. 
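      // Note (illustrative, not authoritative): when non-empty, `generation_config_str` is a JSON
      // object whose keys mirror the Python-side GenerationConfig fields, e.g.
      //   {"temperature": 0.8, "repetition_penalty": 1.0, "top_p": 0.95,
      //    "mean_gen_len": 128, "max_gen_len": 512}
      // Only the keys that are present override the corresponding ChatConfig values, and only for
      // this single call; the numbers above are example values, not defaults.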
if (ft_.use_disco) { LOG(FATAL) << "NotImplementedError: Distributed inference is not supported for this model"; } - NDArray embedding = Downcast(EmbedStep(inp, append_conversation, place_in_prompt)); - PrefillWithEmbedStep(embedding, decode_next_token); + NDArray embedding = Downcast(EmbedStep(inp, append_conversation, place_in_prompt, generation_config_str)); + PrefillWithEmbedStep(embedding, decode_next_token, generation_config_str); return; } + // process generation settings + picojson::object generation_config = picojson::object(); + if(!generation_config_str.empty()) { + picojson::value generation_config_json; + picojson::parse(generation_config_json, generation_config_str); + generation_config = generation_config_json.get(); + } + std::vector prompt_tokens = - this->PrepareBeforeEmbedding(inp, append_conversation, place_in_prompt); + this->PrepareBeforeEmbedding(inp, append_conversation, place_in_prompt, generation_config); int64_t token_len = static_cast(prompt_tokens.size()); if (token_len == 0) return; if (ft_.use_disco) { @@ -824,7 +864,7 @@ class LLMChat { return; } - int32_t next_token = this->SampleTokenFromLogits(logits_on_device, temperature_, top_p_); + int32_t next_token = this->SampleTokenFromLogits(logits_on_device, generation_config); auto tend = std::chrono::high_resolution_clock::now(); @@ -833,7 +873,15 @@ class LLMChat { this->ProcessNextToken(next_token); } - void DecodeStep() { + void DecodeStep(String generation_config_str = "") { + // process generation settings + picojson::object generation_config = picojson::object(); + if(!generation_config_str.empty()) { + picojson::value generation_config_json; + picojson::parse(generation_config_json, generation_config_str); + generation_config = generation_config_json.get(); + } + ICHECK(!output_ids_.empty()); int32_t last_token = output_ids_.back(); tvm::runtime::NDArray input_data = GetInputTokenNDArray({last_token}); @@ -843,7 +891,7 @@ class LLMChat { NDArray logits_on_device = this->ForwardTokens({last_token}, total_seq_len_ + 1); total_seq_len_ += 1; - int32_t next_token = this->SampleTokenFromLogits(logits_on_device, temperature_, top_p_); + int32_t next_token = this->SampleTokenFromLogits(logits_on_device, generation_config); auto tend = std::chrono::high_resolution_clock::now(); @@ -921,7 +969,7 @@ class LLMChat { { auto tstart = std::chrono::high_resolution_clock::now(); logits_on_device = this->ForwardTokens(tokens, tokens.size()); - tokens.push_back(this->SampleTokenFromLogits(logits_on_device, temperature_, top_p_)); + tokens.push_back(this->SampleTokenFromLogits(logits_on_device)); auto tend = std::chrono::high_resolution_clock::now(); this->prefill_total_time = static_cast((tend - tstart).count()) / 1e9; @@ -933,7 +981,7 @@ class LLMChat { auto tstart = std::chrono::high_resolution_clock::now(); for (int64_t len = 1; len < generate_len; ++len) { logits_on_device = this->ForwardTokens({tokens.back()}, tokens.size()); - tokens.push_back(this->SampleTokenFromLogits(logits_on_device, temperature_, top_p_)); + tokens.push_back(this->SampleTokenFromLogits(logits_on_device)); } auto tend = std::chrono::high_resolution_clock::now(); @@ -960,26 +1008,61 @@ class LLMChat { /*! 
* \brief Sample output token from logits on device */ - int32_t SampleTokenFromLogits(NDArray logits_on_device, float temperature, float top_p) { - if (repetition_penalty_ == 1.0f) { - if (temperature_ < 1e-6f) { + int32_t SampleTokenFromLogits(NDArray logits_on_device, + picojson::object generation_config = picojson::object()) { + // prepare generation settings + // the generation_config will not override the original config + // since is only used for this generation + double gen_temperature; + NDArray gen_temperature_arr; + double gen_repetition_penalty; + double gen_top_p; + if (generation_config.count("temperature")) { + CHECK(generation_config["temperature"].is()); + gen_temperature = generation_config["temperature"].get(); + + gen_temperature_arr = NDArray::Empty({}, DataType::Float(32), device_); + float temperature_cast = static_cast(gen_temperature); + gen_temperature_arr.CopyFromBytes(&temperature_cast, sizeof(float)); + } else { + gen_temperature = this->temperature_; + gen_temperature_arr = this->temperature_arr_; + } + if (generation_config.count("repetition_penalty")) { + CHECK(generation_config["repetition_penalty"].is()); + gen_repetition_penalty = generation_config["repetition_penalty"].get(); + } else { + gen_repetition_penalty = this->repetition_penalty_; + } + if (generation_config.count("top_p")) { + CHECK(generation_config["top_p"].is()); + gen_top_p = generation_config["top_p"].get(); + } else { + gen_top_p = this->top_p_; + } + + // update logits + if (gen_repetition_penalty == 1.0f) { + if (gen_temperature < 1e-6f) { this->UpdateLogitsOrProbOnCPUSync(logits_on_device); } else { - this->UpdateLogitsOrProbOnCPUSync(this->Softmax(logits_on_device, temperature_)); + this->UpdateLogitsOrProbOnCPUSync(this->Softmax(logits_on_device, gen_temperature_arr)); } } else { this->UpdateLogitsOrProbOnCPUSync(logits_on_device); - this->ApplyRepetitionPenaltyOnCPU(); - if (temperature_ >= 1e-6f) { - this->ApplySoftmaxWithTemperatureOnCPU(); + this->ApplyRepetitionPenaltyOnCPU(gen_repetition_penalty); + if (gen_temperature >= 1e-6f) { + this->ApplySoftmaxWithTemperatureOnCPU(gen_temperature); } } + + // perform sampling auto tstart = std::chrono::high_resolution_clock::now(); int next_token; - if (temperature_ < 1e-6f) { - next_token = this->SampleFromLogitsOnCPU(); + if (gen_temperature < 1e-6f) { + next_token = this->SampleFromLogitsOnCPU(gen_temperature, gen_top_p); } else { - next_token = this->SampleFromProbOnCPU(); + next_token = this->SampleFromProbOnCPU(gen_top_p); } auto tend = std::chrono::high_resolution_clock::now(); this->sample_total_time += static_cast((tend - tstart).count()) / 1e9; @@ -990,7 +1073,19 @@ class LLMChat { * \brief Add a generated token and check for stop condition. * \param next_token The next token. 
*/ - void ProcessNextToken(int32_t next_token) { + void ProcessNextToken(int32_t next_token, + picojson::object generation_config = picojson::object()) { + // prepare generation settings + // the generation_config will not override the original config + // since is only used for this generation + int64_t gen_max_gen_len; + if (generation_config.count("max_gen_len")) { + CHECK(generation_config["max_gen_len"].is()); + gen_max_gen_len = generation_config["max_gen_len"].get(); + } else { + gen_max_gen_len = this->max_gen_len_; + } + ICHECK(!stop_triggered_) << "Cannot call process when it is stopped"; stop_triggered_ = @@ -1024,7 +1119,7 @@ class LLMChat { } } - if (static_cast(output_ids_.size()) >= max_gen_len_) { + if (static_cast(output_ids_.size()) >= gen_max_gen_len) { stop_triggered_ = true; } else if (total_seq_len_ >= max_window_size_) { stop_triggered_ = true; @@ -1077,32 +1172,32 @@ class LLMChat { return Downcast(ret[0]); } - NDArray Softmax(NDArray input, float temperature) { + NDArray Softmax(NDArray input, NDArray temperature_arr) { NDArray ret; - ret = ft_.softmax_func_(input, temperature_arr_); + ret = ft_.softmax_func_(input, temperature_arr); return ret; } - void ApplyRepetitionPenaltyOnCPU() { + void ApplyRepetitionPenaltyOnCPU(float repetition_penalty) { CHECK(logits_on_cpu_.defined()) << "Logits on CPU not defined!"; CHECK(logits_on_cpu_.DataType() == DataType::Float(32)) << "Logits data type is not float32!"; float* logits_raw_data = static_cast(logits_on_cpu_->data); for (const int32_t& token_id : this->appeared_token_ids_) { if (logits_raw_data[token_id] <= 0) { - logits_raw_data[token_id] *= this->repetition_penalty_; + logits_raw_data[token_id] *= repetition_penalty; } else { // logits > 0 - logits_raw_data[token_id] /= this->repetition_penalty_; + logits_raw_data[token_id] /= repetition_penalty; } } } - void ApplySoftmaxWithTemperatureOnCPU() { + void ApplySoftmaxWithTemperatureOnCPU(float temperature) { CHECK(logits_on_cpu_.defined()) << "Logits on CPU not defined!"; CHECK(logits_on_cpu_.DataType() == DataType::Float(32)) << "Logits data type is not float32!"; int vocab_size = logits_on_cpu_->shape[logits_on_cpu_->ndim - 1]; float* logits_raw_data = static_cast(logits_on_cpu_->data); float m = std::numeric_limits::min(); - float inv_temp = 1.0f / this->temperature_; + float inv_temp = 1.0f / temperature; double d = 0.0f; for (int i = 0; i < vocab_size; ++i) { float x = logits_raw_data[i] * inv_temp; @@ -1137,18 +1232,18 @@ class LLMChat { // Utils static double GetRandomNumber() { return RandomGenerator::GetInstance().GetRandomNumber(); } - int32_t SampleFromLogitsOnCPU() { + int32_t SampleFromLogitsOnCPU(float temperature, float top_p) { ICHECK(logits_on_cpu_.defined()) << "logits_on_cpu_ is not defined"; ICHECK_EQ(logits_on_cpu_->ndim, 3) << "logits_on_cpu_ should be 3D"; ICHECK_EQ(logits_on_cpu_->shape[0], 1) << "logits_on_cpu_ should be 1 batch"; - return fsample_topp_from_logits_(logits_on_cpu_, temperature_, top_p_, GetRandomNumber()); + return fsample_topp_from_logits_(logits_on_cpu_, temperature, top_p, GetRandomNumber()); } - int32_t SampleFromProbOnCPU() { + int32_t SampleFromProbOnCPU(float top_p) { ICHECK(logits_on_cpu_.defined()) << "logits_on_cpu_ is not defined"; ICHECK_EQ(logits_on_cpu_->ndim, 3) << "logits_on_cpu_ should be 3D"; ICHECK_EQ(logits_on_cpu_->shape[0], 1) << "logits_on_cpu_ should be 1 batch"; - return fsample_topp_from_prob_(logits_on_cpu_, top_p_, GetRandomNumber()); + return fsample_topp_from_prob_(logits_on_cpu_, top_p, 
GetRandomNumber()); } //---------------------------- @@ -1279,7 +1374,7 @@ class LLMChatModule : public ModuleNode { }); } else if (name == "prefill") { return PackedFunc([this, sptr_to_self](TVMArgs args, TVMRetValue* rv) { - ICHECK(1 <= args.size() && args.size() <= 3); + ICHECK(1 <= args.size() && args.size() <= 4); if (args.size() == 1) { // args: inp (with decode_next_token = true, place_in_prompt = kAll) GetChat()->PrefillStep(args[0]); @@ -1290,11 +1385,15 @@ class LLMChatModule : public ModuleNode { // args: inp, decode_next_token, place_in_prompt PlaceInPrompt place_in_prompt = static_cast(static_cast(args[2])); GetChat()->PrefillStep(args[0], true, args[1], place_in_prompt); + } else if (args.size() == 4) { + // args: inp, decode_next_token, place_in_prompt, generation_config_str + PlaceInPrompt place_in_prompt = static_cast(static_cast(args[2])); + GetChat()->PrefillStep(args[0], true, args[1], place_in_prompt, args[3]); } }); } else if (name == "embed") { return PackedFunc([this, sptr_to_self](TVMArgs args, TVMRetValue* rv) { - ICHECK(1 <= args.size() && args.size() <= 2); + ICHECK(1 <= args.size() && args.size() <= 3); if (args.size() == 1) { // args: inp (with place_in_prompt = kAll) *rv = GetChat()->EmbedStep(args[0]); @@ -1302,22 +1401,36 @@ class LLMChatModule : public ModuleNode { // args: inp, place_in_prompt PlaceInPrompt place_in_prompt = static_cast(static_cast(args[1])); *rv = GetChat()->EmbedStep(args[0], true, place_in_prompt); + } else if (args.size() == 3) { + // args: inp, place_in_prompt, generation_config_str + PlaceInPrompt place_in_prompt = static_cast(static_cast(args[1])); + *rv = GetChat()->EmbedStep(args[0], true, place_in_prompt, args[2]); } }); } else if (name == "prefill_with_embed") { return PackedFunc([this, sptr_to_self](TVMArgs args, TVMRetValue* rv) { - ICHECK(1 <= args.size() && args.size() <= 2); + ICHECK(1 <= args.size() && args.size() <= 3); if (args.size() == 1) { // args: embedding (with decode_next_token = true) GetChat()->PrefillWithEmbedStep(args[0]); } else if (args.size() == 2) { // args: embedding, decode_next_token GetChat()->PrefillWithEmbedStep(args[0], args[1]); + } else if (args.size() == 3) { + // args: embedding, decode_next_token, generation_config_str + GetChat()->PrefillWithEmbedStep(args[0], args[1], args[2]); } }); } else if (name == "decode") { - return PackedFunc( - [this, sptr_to_self](TVMArgs args, TVMRetValue* rv) { GetChat()->DecodeStep(); }); + return PackedFunc([this, sptr_to_self](TVMArgs args, TVMRetValue* rv) { + ICHECK(0 <= args.size() && args.size() <= 1); + if (args.size() == 0) { + GetChat()->DecodeStep(); + } else if (args.size() == 1) { + // args: generation_config_str + GetChat()->DecodeStep(args[0]); + } + }); } else if (name == "reset_chat") { return PackedFunc([this, sptr_to_self](TVMArgs args, TVMRetValue* rv) { ICHECK_EQ(args.size(), 0); diff --git a/python/mlc_chat/__init__.py b/python/mlc_chat/__init__.py index 5d55de875f..eb2bdeebc1 100644 --- a/python/mlc_chat/__init__.py +++ b/python/mlc_chat/__init__.py @@ -6,3 +6,4 @@ from .chat_module import ChatModule from .chat_module import ConvConfig from .chat_module import ChatConfig +from .chat_module import GenerationConfig diff --git a/python/mlc_chat/chat_module.py b/python/mlc_chat/chat_module.py index 9e35224801..97b7cb7670 100644 --- a/python/mlc_chat/chat_module.py +++ b/python/mlc_chat/chat_module.py @@ -175,6 +175,66 @@ def _from_json(chat_config_cls, json_obj: dict): ) +@dataclass +class GenerationConfig: + r"""A dataclass that represents 
user-defined generation configuration. + + An instance of ``GenerationConfig`` can be passed in to the generate function + of a :class:`mlc_chat.ChatModule` instance to override the default generation + setting in ``mlc-chat-config.json`` and ``ChatConfig`` under the model folder. + + Once the generation ends, ``GenerationConfig`` is discarded, since the values + will only override the ``ChatConfig`` generation settings during one generation, + unless it is recurrently passed to generate function. This allows changing generation + settings over time, without overriding ``ChatConfig`` permanently. + + Since the configuraiton is partial, everything will be ``Optional``. + + Parameters + ---------- + temperature : Optional[float] + The temperature applied to logits before sampling. The default value is + ``0.7``. A higher temperature encourages more diverse outputs, while a + lower temperature produces more deterministic outputs. + repetition_penalty : Optional[float] + The repetition penalty controls the likelihood of the model generating + repeated texts. The default value is set to ``1.0``, indicating that no + repetition penalty is applied. Increasing the value reduces the + likelihood of repeat text generation. However, setting a high + ``repetition_penalty`` may result in the model generating meaningless + texts. The ideal choice of repetition penalty may vary among models. + + For more details on how repetition penalty controls text generation, please + check out the CTRL paper (https://arxiv.org/pdf/1909.05858.pdf). + top_p : Optional[float] + This parameter determines the set of tokens from which we sample during + decoding. The default value is set to ``0.95``. At each step, we select + tokens from the minimal set that has a cumulative probability exceeding + the ``top_p`` parameter. + + For additional information on top-p sampling, please refer to this blog + post: https://huggingface.co/blog/how-to-generate#top-p-nucleus-sampling. + mean_gen_len : Optional[int] + max_gen_len : Optional[int] + """ + + temperature: Optional[float] = None + repetition_penalty: Optional[float] = None + top_p: Optional[float] = None + mean_gen_len: Optional[int] = None + max_gen_len: Optional[int] = None + + @classmethod + def _from_chat_config(generation_config_cls, chat_config_obj: ChatConfig): + return generation_config_cls( + **{ + f.name: getattr(chat_config_obj, f.name) + for f in fields(chat_config_obj) + if f.name in inspect.signature(generation_config_cls).parameters + } + ) + + class PlaceInPrompt(Enum): """The place of an input message in a prompt.""" @@ -297,6 +357,34 @@ def _get_chat_config(config_file_path: str, user_chat_config: Optional[ChatConfi return final_chat_config +def _get_generation_config( + user_chat_config: ChatConfig, user_generation_config: Optional[GenerationConfig] +) -> GenerationConfig: + """Read in the config file in model path, then potentially override with user input. + + Parameters + ---------- + user_chat_config : ChatConfig + ``ChatConfig`` that contain the generation settings to be overriden. + user_generation_config : Optional[GenerationConfig] + User's input, a partial ``GenerationConfig`` to override the ``ChatConfig``. + + Returns + ------ + final_generation_config : GenerationConfig + ``GenerationConfig`` corresponding to ``user_chat_config``, overriden by ``user_generation_config``. 
+ """ + final_generation_config = GenerationConfig._from_chat_config(user_chat_config) + if user_generation_config is not None: + # We override using user's chat config + for field in fields(user_generation_config): + field_name = field.name + field_value = getattr(user_generation_config, field_name) + if field_value is not None: + setattr(final_generation_config, field_name, field_value) + return final_generation_config + + def _get_lib_module_path( model: str, model_path: str, @@ -441,6 +529,25 @@ def _convert_chat_config_to_json_str(chat_config: Optional[ChatConfig], conv_tem return json.dumps(chat_dict) +def _convert_generation_config_to_json_str(generation_config: Optional[GenerationConfig]) -> str: + """Convert user's input GenerationConfig to a json string. + + Parameters + ---------- + generation_config : Optional[GenerationConfig] + User's input. A partial GenerationConfig for overriding ChatConfig generation settings. + + Returns + ------ + json_str : str + A JSON string that corresponds to user's ``generation_config`` input. + Returns "" if ``generation_config`` unspecified. + """ + if generation_config is None: + return "" + return json.dumps(asdict(generation_config)) + + def _detect_local_device(device_id: int = 0): """Automatically detect the local device if user does not specify. @@ -608,14 +715,13 @@ def __init__( self.chat_config, self.chat_config.conv_template ) self._reload(self.lib_path, self.model_path, user_chat_config_json_str) - - # 7. Save default config values. - self.default_chat_config = asdict(self.chat_config) - if "conv_config" in self.default_chat_config: - self.default_chat_config.pop("conv_config") - self.default_conv_config = json.loads(self._get_config_json())["conv_config"] - - def generate(self, prompt: str, progress_callback=None) -> str: + + def generate( + self, + prompt: str, + generation_config: Optional[GenerationConfig] = None, + progress_callback=None, + ) -> str: r"""A high-level method that returns the full response from the chat module given a user prompt. User can optionally specify which callback method to use upon receiving the response. By default, no callback will be applied. @@ -624,6 +730,8 @@ def generate(self, prompt: str, progress_callback=None) -> str: ---------- prompt : str The user input prompt, i.e. a question to ask the chat module. + generation_config: Optional[GenerationConfig] + The generation config object to override the ChatConfig generation settings. progress_callback: object The optional callback method used upon receiving a newly generated message from the chat module. See `mlc_chat/callback.py` for a full list of available callback classes. Currently, only @@ -643,24 +751,26 @@ def generate(self, prompt: str, progress_callback=None) -> str: # the chat module streaming to stdout piece by piece, and in the end we receive the # full response as a single string `output`. - from mlc_chat import ChatModule, callback + from mlc_chat import ChatModule, GenerationConfig, callback cm = ChatModule(xxx) prompt = "what's the color of banana?" 
- output = cm.generate(prompt, callback.StreamToStdout(callback_interval=2)) + output = cm.generate( + prompt, GenerationConfig(temperature=0.8), callback.StreamToStdout(callback_interval=2) + ) print(output) """ - self._prefill(prompt) + self._prefill(prompt, generation_config=generation_config) if not progress_callback: while not self._stopped(): - self._decode() + self._decode(generation_config=generation_config) new_msg = self._get_message() return new_msg # apply callback with a rate of callback_interval i, new_msg = 0, "" while not self._stopped(): - self._decode() + self._decode(generation_config=generation_config) if i % progress_callback.callback_interval == 0 or self._stopped(): new_msg = self._get_message() progress_callback(new_msg) @@ -696,45 +806,6 @@ def reset_chat(self, chat_config: Optional[ChatConfig] = None): # Second argument is `partial_update = True` self._load_json_override_func(user_chat_config_json_str, True) - def update_chat_config(self, new_chat_config: ChatConfig): - r"""Update the chat config, or use the currently used default values if - values are None. - - Parameters - ---------- - chat_config : ChatConfig - A ``ChatConfig`` instance partially filled. The chat module will - override the default values with it. - - Note - ---- - This is inteneded for use in the completions api to allow users to specify - config values and use defaults if they are not passed to the request. - """ - - new_chat_config_dict = asdict(new_chat_config) - - # Override chat config values if they are present. Use default values if not. - config_updates_dict = {} - for k, default_value in self.default_chat_config.items(): - new_value = new_chat_config_dict.get(k) - config_updates_dict[k] = new_value if new_value else default_value - - # Add conv_config values if there are ones. - new_conv_config_dict = new_chat_config_dict.get("conv_config") - if new_conv_config_dict: - conv_config_updates_dict = {} - for k, default_value in self.default_conv_config.items(): - new_value = new_conv_config_dict.get(k) - conv_config_updates_dict[k] = new_value if new_value else default_value - config_updates_dict["conv_config"] = conv_config_updates_dict - - # Current logic does not allow partial ChatConfig without specifying the - # conv_template. Hence we use the conv_template after considering potential overrides. - user_chat_config_json_str = json.dumps(config_updates_dict) - # Second argument is `partial_update = True` - self._load_json_override_func(user_chat_config_json_str, True) - def embed_text(self, input: str): r"""Given a text input, returns its embedding in the LLM. @@ -847,6 +918,7 @@ def _prefill( input: str, decode_next_token: bool = True, place_in_prompt: PlaceInPrompt = PlaceInPrompt.All, + generation_config: Optional[GenerationConfig] = None, ): r"""Run prefill stage for a given input and optionally decode the first output token. User can decide where to place the input in the prompt. @@ -859,10 +931,20 @@ def _prefill( Whether to decode the next token after prefilling. place_in_prompt: PlaceInPrompt The place of the input message in the prompt. See `class PlaceInPrompt` for details. + generation_config: Optional[GenerationConfig] + The generation config to override the ChatConfig generation settings. 
""" - self._prefill_func(input, decode_next_token, place_in_prompt.value) + generation_config = _get_generation_config(self.chat_config, generation_config) + generation_config_str = _convert_generation_config_to_json_str(generation_config) + + self._prefill_func(input, decode_next_token, place_in_prompt.value, generation_config_str) - def _embed(self, input: str, place_in_prompt: PlaceInPrompt = PlaceInPrompt.All): + def _embed( + self, + input: str, + place_in_prompt: PlaceInPrompt = PlaceInPrompt.All, + generation_config: Optional[GenerationConfig] = None, + ): r"""A more fine-grained embedding API. Given a text input, get the embedding of the tokenized prompt. User can decide where to place the input in the prompt. This functionality usually aids the subsequent call to :func:`_prefill_with_embed`. @@ -873,15 +955,25 @@ def _embed(self, input: str, place_in_prompt: PlaceInPrompt = PlaceInPrompt.All) The user input string. place_in_prompt: PlaceInPrompt The place of the input message in the prompt. See `class PlaceInPrompt` for details. + generation_config: Optional[GenerationConfig] + The generation config to override the ChatConfig generation settings. Returns ------- embedding : tvm.runtime.NDArray The embedding of the text. """ - return self._embed_func(input, place_in_prompt.value) + generation_config = _get_generation_config(self.chat_config, generation_config) + generation_config_str = _convert_generation_config_to_json_str(generation_config) - def _prefill_with_embed(self, embedding: tvm.runtime.NDArray, decode_next_token: bool = True): + return self._embed_func(input, place_in_prompt.value, generation_config_str) + + def _prefill_with_embed( + self, + embedding: tvm.runtime.NDArray, + decode_next_token: bool = True, + generation_config: Optional[GenerationConfig] = None, + ): r"""Given an embedding, run the prefill stage and optionally decode the first output token. Parameters @@ -890,14 +982,27 @@ def _prefill_with_embed(self, embedding: tvm.runtime.NDArray, decode_next_token: The embedding of user input. decode_next_token : bool Whether to decode the next token after prefilling. + generation_config: Optional[GenerationConfig] + The generation config to override the ChatConfig generation settings. """ - self._prefill_with_embed_func(embedding, decode_next_token) + generation_config = _get_generation_config(self.chat_config, generation_config) + generation_config_str = _convert_generation_config_to_json_str(generation_config) + + self._prefill_with_embed_func(embedding, decode_next_token, generation_config_str) - def _decode(self): + def _decode(self, generation_config: Optional[GenerationConfig] = None): r"""Decode the next token, the decoding result is stored in a buffer and can be retrieved by :func:`get_message`. + + Parameters + ---------- + generation_config: Optional[GenerationConfig] + The generation config to override the ChatConfig generation settings. """ - self._decode_func() + generation_config = _get_generation_config(self.chat_config, generation_config) + generation_config_str = _convert_generation_config_to_json_str(generation_config) + + self._decode_func(generation_config_str) def _stopped(self) -> bool: r"""Check if the stop condition is met for the current round. 
diff --git a/python/mlc_chat/interface/openai_api.py b/python/mlc_chat/interface/openai_api.py index 11a72d8ba6..a707608ab1 100644 --- a/python/mlc_chat/interface/openai_api.py +++ b/python/mlc_chat/interface/openai_api.py @@ -19,14 +19,20 @@ class ChatCompletionRequest(BaseModel): stream: bool | None = False temperature: float = None top_p: float = None + # TODO: replace by presence_penalty and frequency_penalty repetition_penalty: float = None mean_gen_len: int = None + # TODO: replace by max_tokens max_gen_len: int = None - # TODO: Implement support for the following fields + # TODO: Implement support for the OpenAI API parameters + # function [] + # function_call # n: Optional[int] = 1 # stop: Optional[Union[str, List[str]]] = None + # max_tokens: Optional[int] # presence_penalty: Optional[float] = 0.0 # frequency_penalty: Optional[float] = 0.0 + # logit_bias # user: Optional[str] = None class UsageInfo(BaseModel): @@ -65,22 +71,25 @@ class ChatCompletionStreamResponse(BaseModel): class CompletionRequest(BaseModel): model: str prompt: str | list[str] + stream: bool | None = False temperature: float = None repetition_penalty: float = None top_p: float = None mean_gen_len: int = None + # TODO: replace by max_tokens max_gen_len: int = None - system_prompt: str = None - chat_roles: List[str] = None - messages: List[List[str]] = None - offset: str = None - separator_style: int = None - seps: List[str] = None - role_msg_sep: str = None - role_empty_sep: str = None - stop_str: str = None - stop_tokens: List[int] = None - add_bos: bool = None + # TODO: Implement support for the OpenAI API parameters + # suffix + # max_tokens: Optional[int] + # n: Optional[int] = 1 + # logprobs + # echo + # stop: Optional[Union[str, List[str]]] = None + # presence_penalty: Optional[float] = 0.0 + # frequency_penalty: Optional[float] = 0.0 + # best_of + # logit_bias + # user: Optional[str] = None class CompletionResponseChoice(BaseModel): index: int @@ -95,6 +104,17 @@ class CompletionResponse(BaseModel): choices: list[CompletionResponseChoice] usage: UsageInfo +class CompletionResponseStreamChoice(BaseModel): + index: int + text: str + finish_reason: Optional[Literal["stop", "length"]] = None + +class CompletionStreamResponse(BaseModel): + id: str = Field(default_factory=lambda: f"cmpl-{shortuuid.random()}") + object: str = "text_completion" + created: int = Field(default_factory=lambda: int(time.time())) + choices: List[CompletionResponseStreamChoice] + class EmbeddingsRequest(BaseModel): model: Optional[str] = None input: Union[str, List[Any]] diff --git a/python/mlc_chat/rest.py b/python/mlc_chat/rest.py index c341b3be8c..f5038f5211 100644 --- a/python/mlc_chat/rest.py +++ b/python/mlc_chat/rest.py @@ -1,12 +1,8 @@ import argparse import asyncio -import json -import os -import subprocess -import sys from contextlib import asynccontextmanager -from mlc_chat.chat_module import ChatConfig, ConvConfig +from mlc_chat.chat_module import GenerationConfig import uvicorn from fastapi import FastAPI @@ -21,6 +17,7 @@ import numpy as np + @dataclass class RestAPIArgs: """RestAPIArgs is the dataclass that organizes the arguments used for starting a REST API server.""" @@ -46,17 +43,7 @@ class RestAPIArgs: The full path to the model library file to use (e.g. a ``.so`` file). """ ) - } - ) - config_overrides_path: str = field( - default=None, - metadata={ - "help": ( - """ - The full path to the model config file to use for overriding the default (e.g. a ``.json`` file). 
- """ - ) - } + }, ) device: str = field( default="auto", @@ -70,7 +57,7 @@ class RestAPIArgs: is provided, it will be set to 0 by default. """ ) - } + }, ) host: str = field( default="127.0.0.1", @@ -80,7 +67,7 @@ class RestAPIArgs: The host at which the server should be started, defaults to ``127.0.0.1``. """ ) - } + }, ) port: int = field( default=8000, @@ -90,7 +77,7 @@ class RestAPIArgs: The port on which the server should be started, defaults to ``8000``. """ ) - } + }, ) random_seed: int = field( default=None, @@ -101,7 +88,7 @@ class RestAPIArgs: no seed is set. """ ) - } + }, ) @@ -126,18 +113,12 @@ def convert_args_to_argparser() -> argparse.ArgumentParser: @asynccontextmanager async def lifespan(app: FastAPI): - chat_config_overrides = None - if ARGS.config_overrides_path and os.path.isfile(ARGS.config_overrides_path): - with open(ARGS.config_overrides_path, mode="rt", encoding="utf-8") as f: - json_object = json.load(f) - chat_config_overrides = ChatConfig._from_json(json_object) if ARGS.random_seed is not None: set_global_random_seed(ARGS.random_seed) chat_mod = ChatModule( model=ARGS.model, device=ARGS.device, lib_path=ARGS.lib_path, - chat_config=chat_config_overrides ) session["chat_mod"] = chat_mod @@ -160,13 +141,17 @@ async def lifespan(app: FastAPI): allow_headers=["*"], ) -class AsyncChatCompletionStream: + +class AsyncCompletionStream: + def __init__(self, generation_config: GenerationConfig): + self.generation_config = generation_config + def __aiter__(self): return self async def get_next_msg(self): if not session["chat_mod"]._stopped(): - session["chat_mod"]._decode() + session["chat_mod"]._decode(generation_config=self.generation_config) msg = session["chat_mod"]._get_message() return msg else: @@ -187,29 +172,30 @@ async def request_completion(request: ChatCompletionRequest): Creates model response for the given chat conversation. """ - chat_config = ChatConfig( + generation_config = GenerationConfig( temperature=request.temperature, repetition_penalty=request.repetition_penalty, top_p=request.top_p, mean_gen_len=request.mean_gen_len, max_gen_len=request.max_gen_len, ) - session["chat_mod"].update_chat_config(chat_config) if len(request.messages) > 1: - raise ValueError( - """ + raise ValueError( + """ The /v1/chat/completions endpoint currently only supports single message prompts. Please ensure your request contains only one message - """) + """ + ) if request.stream: - - session["chat_mod"]._prefill(input=request.messages[0].content) + session["chat_mod"]._prefill( + input=request.messages[0].content, generation_config=generation_config + ) async def iter_response(): prev_txt = "" - async for content in AsyncChatCompletionStream(): + async for content in AsyncCompletionStream(generation_config=generation_config): if content: chunk = ChatCompletionStreamResponse( choices=[ @@ -227,7 +213,9 @@ async def iter_response(): return StreamingResponse(iter_response(), media_type="text/event-stream") else: - msg = session["chat_mod"].generate(prompt=request.messages[0].content) + msg = session["chat_mod"].generate( + prompt=request.messages[0].content, generation_config=generation_config + ) return ChatCompletionResponse( choices=[ ChatCompletionResponseChoice( @@ -247,31 +235,15 @@ async def request_completion(request: CompletionRequest): Creates a completion for a given prompt. 
""" - conv_config = ConvConfig( - system=request.system_prompt, - roles=request.chat_roles, - messages=request.messages, - offset=request.offset, - separator_style=request.separator_style, - seps=request.seps, - role_msg_sep=request.role_msg_sep, - role_empty_sep=request.role_empty_sep, - stop_str=request.stop_str, - stop_tokens=request.stop_tokens, - add_bos=request.add_bos, - ) - - chat_config = ChatConfig( + generation_config = GenerationConfig( temperature=request.temperature, repetition_penalty=request.repetition_penalty, top_p=request.top_p, mean_gen_len=request.mean_gen_len, max_gen_len=request.max_gen_len, - conv_config=conv_config, ) session["chat_mod"].reset_chat() - session["chat_mod"].update_chat_config(chat_config) # Langchain's load_qa_chain.run expects the input to be a list with the query if isinstance(request.prompt, list): if len(request.prompt) > 1: @@ -279,18 +251,39 @@ async def request_completion(request: CompletionRequest): """ The /v1/completions endpoint currently only supports single message prompts. Please ensure your request contains only one message - """) + """ + ) prompt = request.prompt[0] else: prompt = request.prompt - msg = session["chat_mod"].generate(prompt=prompt) + if request.stream: + session["chat_mod"]._prefill(input=prompt, generation_config=generation_config) - return CompletionResponse( - choices=[CompletionResponseChoice(index=0, text=msg)], - # TODO: Fill in correct usage info - usage=UsageInfo(prompt_tokens=0, completion_tokens=0, total_tokens=0), - ) + async def iter_response(): + prev_txt = "" + async for content in AsyncCompletionStream(generation_config=generation_config): + if content: + chunk = CompletionStreamResponse( + choices=[ + CompletionResponseStreamChoice( + index=0, + text=content[len(prev_txt) :], + finish_reason="stop", + ) + ] + ) + prev_txt = content + yield f"data: {chunk.json(exclude_unset=True)}\n\n" + + return StreamingResponse(iter_response(), media_type="text/event-stream") + else: + msg = session["chat_mod"].generate(prompt=prompt, generation_config=generation_config) + return CompletionResponse( + choices=[CompletionResponseChoice(index=0, text=msg)], + # TODO: Fill in correct usage info + usage=UsageInfo(prompt_tokens=0, completion_tokens=0, total_tokens=0), + ) @app.post("/v1/embeddings") @@ -305,7 +298,7 @@ async def request_embeddings(request: EmbeddingsRequest): inps = request.input else: assert f"Invalid input type {type(request.input)}" - + data = [] for i, inp in enumerate(inps): session["chat_mod"].reset_chat() @@ -315,12 +308,7 @@ async def request_embeddings(request: EmbeddingsRequest): data.append({"object": "embedding", "embedding": norm_emb.tolist(), "index": i}) # TODO: Fill in correct usage info return EmbeddingsResponse( - data=data, - usage=UsageInfo( - prompt_tokens=0, - completion_tokens=0, - total_tokens=0 - ) + data=data, usage=UsageInfo(prompt_tokens=0, completion_tokens=0, total_tokens=0) ) diff --git a/tests/python/test_update_config.py b/tests/python/test_update_config.py deleted file mode 100644 index ec92e2c5d3..0000000000 --- a/tests/python/test_update_config.py +++ /dev/null @@ -1,89 +0,0 @@ -import unittest -from unittest.mock import MagicMock, patch - -from mlc_chat.chat_module import ChatConfig, ChatModule, ConvConfig - -class UpdateConfigTest(unittest.TestCase): - - - @patch("mlc_chat.chat_module.ChatModule.__init__") - def setUp(self, mock_init): - mock_init.return_value = None - self.cm_under_test = ChatModule("test") - default_conv_config = { - "prefix_tokens": [], - 
"role_empty_sep": "", - "role_msg_sep": "", - "seps": [""], - "stop_tokens": [2], - "offset": 0, - "separator_style": 1, - "messages": [], - "stop_str": "<\/s>", - "roles": ["Prompt", "Code"], - "system": "", - "add_bos": True, - "name": "codellama_completion" - } - default_chat_config = { - 'model_lib': 'default_model_lib', - 'local_id': 'default_local_id', - 'conv_template': 'codellama_completion', - 'temperature': 0.7, - 'repetition_penalty': 1.0, - 'top_p': 0.95, - 'mean_gen_len': 128, - 'max_gen_len': 512, - 'shift_fill_factor': 0.3, - 'tokenizer_files': ['tokenizer.json', 'tokenizer.model'], - 'conv_config': None, - 'model_category': 'llama', - 'model_name': 'default_model_name' - } - self.cm_under_test.default_chat_config = default_chat_config - self.cm_under_test.default_conv_config = default_conv_config - self.cm_under_test._load_json_override_func = MagicMock() - - def test_update_config(self): - expected_value = '{"model_lib": "default_model_lib", "local_id": "default_local_id", "conv_template": "codellama_completion", "temperature": 0.5, "repetition_penalty": 1.0, "top_p": 0.95, "mean_gen_len": 128, "max_gen_len": 512, "shift_fill_factor": 0.3, "tokenizer_files": ["tokenizer.json", "tokenizer.model"], "conv_config": {"prefix_tokens": [], "role_empty_sep": "", "role_msg_sep": "", "seps": [""], "stop_tokens": [2], "offset": 0, "separator_style": 1, "messages": [], "stop_str": "}", "roles": ["Prompt", "Code"], "system": "", "add_bos": true, "name": "codellama_completion"}, "model_category": "llama", "model_name": "default_model_name"}' - - conv_config = ConvConfig( - system=None, - roles=None, - messages=None, - offset=None, - separator_style=None, - seps=None, - role_msg_sep=None, - role_empty_sep=None, - stop_str="}", - stop_tokens=None, - add_bos=None, - ) - - chat_config = ChatConfig( - temperature=0.5, - repetition_penalty=None, - top_p=None, - mean_gen_len=None, - max_gen_len=None, - conv_config=conv_config, - ) - - self.cm_under_test.update_chat_config(chat_config) - self.cm_under_test._load_json_override_func.assert_called_once_with(expected_value.replace('\n', '').replace('\t', ''), True) - - def test_update_config_none_conv_config(self): - expected_value = '{"model_lib": "default_model_lib", "local_id": "default_local_id", "conv_template": "codellama_completion", "temperature": 0.5, "repetition_penalty": 1.0, "top_p": 0.95, "mean_gen_len": 128, "max_gen_len": 512, "shift_fill_factor": 0.3, "tokenizer_files": ["tokenizer.json", "tokenizer.model"], "conv_config": null, "model_category": "llama", "model_name": "default_model_name"}' - - chat_config = ChatConfig( - temperature=0.5, - repetition_penalty=None, - top_p=None, - mean_gen_len=None, - max_gen_len=None, - ) - - self.cm_under_test.update_chat_config(chat_config) - self.cm_under_test._load_json_override_func.assert_called_once_with(expected_value.replace('\n', '').replace('\t', ''), True) - \ No newline at end of file From ad3a6b998dab3df2d3a0a0f97ffb3da914b1e67c Mon Sep 17 00:00:00 2001 From: Roee Shenberg Date: Sun, 8 Oct 2023 04:59:15 +0200 Subject: [PATCH 002/116] Fix two bugs in kv-cache backtrack loop (#856) Fix two bugs in kv-cache pop loop Bug 1: old code would stop early because output_ids was shortened in-place during the loop Bug 2: off-by-one in backoff size due to break --- cpp/llm_chat.cc | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/cpp/llm_chat.cc b/cpp/llm_chat.cc index c83f7add9d..d12b0fbd92 100644 --- a/cpp/llm_chat.cc +++ b/cpp/llm_chat.cc @@ -1107,10 +1107,9 @@ class LLMChat 
{ // back tracking, find the first set of token that is smaller // than the length size_t backoff = 0; - for (; backoff < output_ids_.size(); ++backoff) { + for (; (output_ids_.size() > 0) && (output_message_.length() > stop_pos); ++backoff) { output_ids_.pop_back(); output_message_ = tokenizer_->Decode(output_ids_); - if (output_message_.length() <= stop_pos) break; } // resize kv to remove the context ft_.fkvcache_array_popn_(kv_cache_, backoff); From 6e40c21fb6433aeffe50ee321f1b589ef846b6fb Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Sat, 7 Oct 2023 22:07:09 -0500 Subject: [PATCH 003/116] [Build] Added --pdb flag to build.py, drop into pdb on error (#1017) This commit adds an optional `--pdb` flag to the `build.py` script. If passed, any exception raised that would otherwise terminate the script will first enter a pdb post-mortem, allowing the error to be inspected. --- mlc_llm/build.py | 39 ++++++++++++++++++++++++++++++++++++--- mlc_llm/core.py | 7 +++++++ 2 files changed, 43 insertions(+), 3 deletions(-) diff --git a/mlc_llm/build.py b/mlc_llm/build.py index 703856c336..c90da542b8 100644 --- a/mlc_llm/build.py +++ b/mlc_llm/build.py @@ -1,13 +1,46 @@ """Script for building/compiling models.""" +import contextlib +import sys + from mlc_llm import core + +@contextlib.contextmanager +def debug_on_except(): + try: + yield + finally: + if sys.exc_info() == (None, None, None): + return + + import traceback + + try: + import ipdb as pdb + except ImportError: + import pdb + + traceback.print_exc() + pdb.post_mortem() + + def main(): """Main method for building model from command line.""" empty_args = core.convert_build_args_to_argparser() # Create new ArgumentParser parsed_args = empty_args.parse_args() # Parse through command line - # Post processing of arguments - parsed_args = core._parse_args(parsed_args) # pylint: disable=protected-access - core.build_model_from_args(parsed_args) + + with contextlib.ExitStack() as stack: + # Enter an exception-catching context before post-processing + # the arguments, in case the post-processing itself raises an + # exception. 
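        # Illustrative invocation (other flags elided): `python build.py ... --pdb` makes any
        # failure below drop into a pdb (or ipdb, if available) post-mortem instead of exiting.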
+ if parsed_args.pdb: + stack.enter_context(debug_on_except()) + + # Post processing of arguments + parsed_args = core._parse_args(parsed_args) # pylint: disable=protected-access + + core.build_model_from_args(parsed_args) + if __name__ == "__main__": main() diff --git a/mlc_llm/core.py b/mlc_llm/core.py index e9ce63ce61..928d78da41 100644 --- a/mlc_llm/core.py +++ b/mlc_llm/core.py @@ -236,6 +236,13 @@ class BuildArgs: "action": "store_true", }, ) + pdb: bool = field( + default=False, + metadata={ + "help": ("If set, drop into a pdb debugger on error"), + "action": "store_true", + }, + ) def convert_build_args_to_argparser() -> argparse.ArgumentParser: From bae37b33034532f9e13e61cf0f9a05a10dec4779 Mon Sep 17 00:00:00 2001 From: Yaxing Cai Date: Sun, 8 Oct 2023 16:55:16 -0700 Subject: [PATCH 004/116] [Android] Use `AlertDialog` instead of `Toast` (#1039) --- .../main/java/ai/mlc/mlcchat/AppViewModel.kt | 70 +++++++++++-------- .../src/main/java/ai/mlc/mlcchat/StartView.kt | 40 +++++++++-- android/prepare_libs.sh | 5 +- 3 files changed, 80 insertions(+), 35 deletions(-) diff --git a/android/MLCChat/app/src/main/java/ai/mlc/mlcchat/AppViewModel.kt b/android/MLCChat/app/src/main/java/ai/mlc/mlcchat/AppViewModel.kt index e1b5928019..f51d56ec10 100644 --- a/android/MLCChat/app/src/main/java/ai/mlc/mlcchat/AppViewModel.kt +++ b/android/MLCChat/app/src/main/java/ai/mlc/mlcchat/AppViewModel.kt @@ -2,6 +2,9 @@ package ai.mlc.mlcchat import ai.mlc.mlcllm.ChatModule import android.app.Application +import android.content.ClipData +import android.content.ClipboardManager +import android.content.Context import android.os.Environment import android.widget.Toast import androidx.compose.runtime.mutableStateOf @@ -23,6 +26,8 @@ class AppViewModel(application: Application) : AndroidViewModel(application) { val modelList = emptyList().toMutableStateList() val chatState = ChatState() val modelSampleList = emptyList().toMutableStateList() + private var showAlert = mutableStateOf(false) + private var alertMessage = mutableStateOf("") private var appConfig = AppConfig( emptyList(), emptyList().toMutableList(), @@ -44,13 +49,38 @@ class AppViewModel(application: Application) : AndroidViewModel(application) { loadAppConfig() } + fun supportedModelLibs(): List { + return appConfig.modelLibs + } + + fun isShowingAlert(): Boolean { + return showAlert.value + } + + fun errorMessage(): String { + return alertMessage.value + } + + fun dismissAlert() { + require(showAlert.value) + showAlert.value = false + } + + fun copyError() { + require(showAlert.value) + val clipboard = + application.getSystemService(Context.CLIPBOARD_SERVICE) as ClipboardManager + clipboard.setPrimaryClip(ClipData.newPlainText("MLCChat", errorMessage())) + } + + private fun issueAlert(error: String) { + showAlert.value = true + alertMessage.value = error + } + fun requestAddModel(url: String, localId: String?) 
{ if (localId != null && localIdSet.contains(localId)) { - Toast.makeText( - application, - "localId: $localId has been occupied", - Toast.LENGTH_SHORT - ).show() + issueAlert("localId: $localId has been occupied") } else { downloadModelConfig(if (url.endsWith("/")) url else "$url/", localId, false) } @@ -58,11 +88,7 @@ class AppViewModel(application: Application) : AndroidViewModel(application) { fun requestDeleteModel(localId: String) { deleteModel(localId) - Toast.makeText( - application, - "Model: $localId has been deleted", - Toast.LENGTH_SHORT - ).show() + issueAlert("Model: $localId has been deleted") } @@ -133,11 +159,7 @@ class AppViewModel(application: Application) : AndroidViewModel(application) { private fun isModelConfigAllowed(modelConfig: ModelConfig): Boolean { if (appConfig.modelLibs.contains(modelConfig.modelLib)) return true; viewModelScope.launch { - Toast.makeText( - application, - "Model lib ${modelConfig.modelLib} is not supported.", - Toast.LENGTH_SHORT - ).show() + issueAlert("Model lib ${modelConfig.modelLib} is not supported.") } return false } @@ -169,11 +191,7 @@ class AppViewModel(application: Application) : AndroidViewModel(application) { } if (localIdSet.contains(modelConfig.localId)) { tempFile.delete() - Toast.makeText( - application, - "${modelConfig.localId} has been used, please consider another local ID", - Toast.LENGTH_SHORT - ).show() + issueAlert("${modelConfig.localId} has been used, please consider another local ID") return@launch } if (!isModelConfigAllowed(modelConfig)) { @@ -188,21 +206,13 @@ class AppViewModel(application: Application) : AndroidViewModel(application) { addModelConfig(modelConfig, modelUrl, isBuiltin) } catch (e: Exception) { viewModelScope.launch { - Toast.makeText( - application, - "Add model failed: ${e.localizedMessage}", - Toast.LENGTH_SHORT - ).show() + issueAlert("Add model failed: ${e.localizedMessage}") } } } } catch (e: Exception) { viewModelScope.launch { - Toast.makeText( - application, - "Download model config failed: ${e.localizedMessage}", - Toast.LENGTH_SHORT - ).show() + issueAlert("Download model config failed: ${e.localizedMessage}") } } diff --git a/android/MLCChat/app/src/main/java/ai/mlc/mlcchat/StartView.kt b/android/MLCChat/app/src/main/java/ai/mlc/mlcchat/StartView.kt index ee2833fca0..87fba77a05 100644 --- a/android/MLCChat/app/src/main/java/ai/mlc/mlcchat/StartView.kt +++ b/android/MLCChat/app/src/main/java/ai/mlc/mlcchat/StartView.kt @@ -20,6 +20,7 @@ import androidx.compose.material.icons.outlined.Delete import androidx.compose.material.icons.outlined.Download import androidx.compose.material.icons.outlined.Pause import androidx.compose.material.icons.outlined.Schedule +import androidx.compose.material3.AlertDialog import androidx.compose.material3.Divider import androidx.compose.material3.ExperimentalMaterial3Api import androidx.compose.material3.Icon @@ -94,9 +95,14 @@ fun StartView( } } if (isAddingModel) { - Text( - text = "Add Model Variant", modifier = Modifier.padding(top = 10.dp) - ) + Text(text = "Supported Base Model Libs", modifier = Modifier.padding(top = 10.dp)) + for (lib in appViewModel.supportedModelLibs()) { + Text( + text = lib, + style = MaterialTheme.typography.bodyMedium + ) + } + Text(text = "Add Model Variant", modifier = Modifier.padding(top = 10.dp)) LazyColumn() { items( items = appViewModel.modelSampleList @@ -148,10 +154,36 @@ fun StartView( } } } - + if (appViewModel.isShowingAlert()) { + AlertDialog( + onDismissRequest = { appViewModel.dismissAlert() }, + onConfirmation 
= { appViewModel.copyError() }, + error = appViewModel.errorMessage() + ) + } } } +@ExperimentalMaterial3Api +@Composable +fun AlertDialog( + onDismissRequest: () -> Unit, + onConfirmation: () -> Unit, + error: String, +) { + AlertDialog( + title = { Text(text = "Error") }, + text = { Text(text = error) }, + onDismissRequest = { onDismissRequest() }, + confirmButton = { + TextButton(onClick = { onConfirmation() }) { Text("Copy") } + }, + dismissButton = { + TextButton(onClick = { onDismissRequest() }) { Text("Dismiss") } + } + ) +} + @Composable fun ModelView( navController: NavController, diff --git a/android/prepare_libs.sh b/android/prepare_libs.sh index 72457954c0..938ffd5cd8 100755 --- a/android/prepare_libs.sh +++ b/android/prepare_libs.sh @@ -9,7 +9,10 @@ python prepare_model_lib.py cd build touch config.cmake -echo "set(TVM_HOME ${TVM_HOME})" >> config.cmake +if [ ${TVM_HOME-0} -ne 0 ]; then + echo "set(TVM_HOME ${TVM_HOME})" >> config.cmake +fi + cmake .. \ -DCMAKE_BUILD_TYPE=Release \ -DCMAKE_TOOLCHAIN_FILE=${ANDROID_NDK}/build/cmake/android.toolchain.cmake \ From b44f6793dac7e5cd2d0f8c83d0997c2620109272 Mon Sep 17 00:00:00 2001 From: Charlie Ruan <53290280+CharlieFRuan@users.noreply.github.com> Date: Mon, 9 Oct 2023 11:35:58 -0400 Subject: [PATCH 005/116] Add doc for ChatConfig, ConvConfig, GenerationConfig, BuildArgs (#1040) Add doc for ChatConfig, ConvConfig, GenerationConfig, BuildArgs, build model --- docs/community/guideline.rst | 2 +- docs/conf.py | 6 +++--- docs/deploy/python.rst | 9 +++++++++ 3 files changed, 13 insertions(+), 4 deletions(-) diff --git a/docs/community/guideline.rst b/docs/community/guideline.rst index 7f671f614b..eac77101e9 100644 --- a/docs/community/guideline.rst +++ b/docs/community/guideline.rst @@ -55,7 +55,7 @@ Contribute New Models to MLC-LLM * If you have compiled a model using our :doc:`/compilation/compile_models` tutorial for an existing model architecture, please upload your models to the internet (e.g., Hugging Face) by following :ref:`distribute-compiled-models` tutorial. Once you have done that, you can create a pull request to add an entry in the :doc:`/prebuilt_models` page. Additionally, you have the option to `create a speed report issue `__ to track the speed and memory consumption of your model. You don't need to test it on all devices; let the community collaborate on building it together! -* If you add a new model variant to MLC-LLM by following our :doc:`/tutorials/bring-your-own-models` tutorial. +* If you add a new model variant to MLC-LLM by following our :doc:`/tutorials/customize/define_new_models` tutorial. Please create a pull request to add your model architecture (currently model architectures are placed under `relax_models `__ folder). 
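The docs added in this patch cover three related configuration classes. Roughly, ``ConvConfig`` customizes the conversation template, ``ChatConfig`` holds the persistent per-model defaults (optionally embedding a ``ConvConfig``), and ``GenerationConfig`` overrides those defaults for a single generate call. A minimal sketch of how they fit together, with the model name as an assumption:

```python
from mlc_chat import ChatModule, ChatConfig, ConvConfig, GenerationConfig

# Persistent overrides, applied for the whole chat session.
conv = ConvConfig(system="You are a helpful assistant.")
chat = ChatConfig(temperature=0.7, conv_config=conv)

# The model name is an assumption; any compiled or prebuilt model id works here.
cm = ChatModule(model="Llama-2-7b-chat-hf-q4f16_1", chat_config=chat)

# Per-call override, discarded once this generation finishes.
print(cm.generate("Hello!", GenerationConfig(temperature=0.9, max_gen_len=128)))
```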
diff --git a/docs/conf.py b/docs/conf.py index ee42500d51..0f7ed19014 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -7,6 +7,8 @@ # -- General configuration ------------------------------------------------ sys.path.insert(0, os.path.abspath("../python")) +sys.path.insert(0, os.path.abspath("../")) +autodoc_mock_imports = ["torch"] # do not load mlc-llm.so in docs os.environ["SKIP_LOADING_MLCLLM_SO"] = "1" @@ -29,9 +31,7 @@ "sphinx_reredirects", ] -redirects = { - "get_started/try_out": "../index.html#getting-started" -} +redirects = {"get_started/try_out": "../index.html#getting-started"} source_suffix = [".rst"] diff --git a/docs/deploy/python.rst b/docs/deploy/python.rst index ccdfec743d..0bf19c7b4c 100644 --- a/docs/deploy/python.rst +++ b/docs/deploy/python.rst @@ -314,6 +314,15 @@ The :class:`mlc_chat.ChatModule` class provides the following methods: .. automethod:: __init__ +.. autoclass:: ChatConfig + :members: + +.. autoclass:: ConvConfig + :members: + +.. autoclass:: GenerationConfig + :members: + Gradio Frontend --------------- From 3a9849ab5d9abf86a80496dc653dabbfaf437e45 Mon Sep 17 00:00:00 2001 From: Bohan Hou Date: Mon, 9 Oct 2023 12:27:58 -0400 Subject: [PATCH 006/116] [Android] Add Llama2 q4f16_0 (#1041) llama2 q4f160 --- android/MLCChat/app/src/main/assets/app-config.json | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/android/MLCChat/app/src/main/assets/app-config.json b/android/MLCChat/app/src/main/assets/app-config.json index ddbcb793ae..fb7c4546b3 100644 --- a/android/MLCChat/app/src/main/assets/app-config.json +++ b/android/MLCChat/app/src/main/assets/app-config.json @@ -1,9 +1,14 @@ { "model_libs": [ + "Llama-2-7b-chat-hf-q4f16_0", "Llama-2-7b-chat-hf-q4f16_1", "RedPajama-INCITE-Chat-3B-v1-q4f16_1" ], "model_list": [ + { + "model_url": "https://huggingface.co/mlc-ai/mlc-chat-Llama-2-7b-chat-hf-q4f16_0/", + "local_id": "Llama-2-7b-chat-hf-q4f16_0" + }, { "model_url": "https://huggingface.co/mlc-ai/mlc-chat-Llama-2-7b-chat-hf-q4f16_1/", "local_id": "Llama-2-7b-chat-hf-q4f16_1" From bed9e60a5d587142d0e4257035f1f5fa7fa03e4c Mon Sep 17 00:00:00 2001 From: Charlie Ruan <53290280+CharlieFRuan@users.noreply.github.com> Date: Mon, 9 Oct 2023 12:58:36 -0400 Subject: [PATCH 007/116] [Docs] Model prebuilts tracking page revamp (#1000) --- README.md | 63 +++ docs/prebuilt_models.rst | 993 +++++++++++++++++++++++++++++---------- 2 files changed, 814 insertions(+), 242 deletions(-) diff --git a/README.md b/README.md index bb52c1c735..3354a94f8b 100644 --- a/README.md +++ b/README.md @@ -52,6 +52,69 @@ Machine Learning Compilation for Large Language Models (MLC LLM) is a high-perfo +**Prebuilt model support.** MLC LLM supports a wide range of model architectures and variants. We have the following prebuilts which you can +use off-the-shelf. Visit [Prebuilt Models](https://llm.mlc.ai/docs/prebuilt_models.html) to see the full list. + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+| Supported Model Architecture | Model Variants with Prebuilts |
+| ---------------------------- | ----------------------------- |
+| Llama                        | Llama-2, Code Llama, Vicuna, WizardLM, WizardMath, OpenOrca Platypus2, FlagAlpha Llama-2 Chinese, georgesung Llama-2 Uncensored |
+| GPT-NeoX                     | RedPajama |
+| GPT-J                        |  |
+| RWKV                         | RWKV-raven |
+| MiniGPT                      |  |
+| GPTBigCode                   | WizardCoder |
+| ChatGLM                      |  |
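
For readers who want to try one of the prebuilt variants listed above, the sketch below shows the shortest path through the Python API. It assumes the prebuilt model library and weights have already been downloaded into `dist/prebuilt` as described in the linked Prebuilt Models docs; the model name and the `stats()` call are illustrative rather than prescriptive.

```python
# Minimal sketch: run a prebuilt model, assuming its library and weights
# are already under ./dist/prebuilt as described in the docs linked above.
from mlc_chat import ChatModule

cm = ChatModule(model="Llama-2-7b-chat-hf-q4f16_1")
print(cm.generate(prompt="What is machine learning compilation?"))
print(cm.stats())  # rough prefill/decode speed for this session
```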
+ ## News * [08/25/2023] CodeLlama support is up. diff --git a/docs/prebuilt_models.rst b/docs/prebuilt_models.rst index 99a86780db..061a1e171b 100644 --- a/docs/prebuilt_models.rst +++ b/docs/prebuilt_models.rst @@ -7,168 +7,164 @@ Model Prebuilts :depth: 3 :local: -MLC-LLM is a universal solution for deploying different language models. Any language models that can be described in `TVM Relax `__ (a general representation for Neural Networks and can be imported from models written in PyTorch) can be recognized by MLC-LLM and thus deployed to different backends with the help of :doc:`TVM Unity `. +.. _model-prebuilts-overview: -The community has already supported several LLM architectures (LLaMA, GPT-NeoX, etc.) and have prebuilt some models (Vicuna, RedPajama, etc.) which you can use off the shelf. -With the goal of democratizing the deployment of LLMs, we eagerly anticipate further contributions from the community to expand the range of supported model architectures. +Overview +-------- -This page contains the list of prebuilt models for our CLI (command line interface) app, iOS and Android apps. -The models have undergone extensive testing on various devices, and their performance has been optimized by developers with the help of TVM. +MLC-LLM is a universal solution for deploying different language models. Any models that can be described in `TVM Relax `__ +(a general representation for Neural Networks and can be imported from models written in PyTorch) can be recognized by MLC-LLM and thus deployed to different backends with the +help of :doc:`TVM Unity `. -.. _prebuilt-models-cli: +There are two ways to run a model on MLC-LLM: -Prebuilt Models for CLI ------------------------ +1. Compile your own models following :doc:`the model compilation page `. +2. Use off-the-shelf prebuilts models following this current page. -.. 
list-table:: - :widths: 15 15 15 15 - :header-rows: 1 +This page focuses on the second option: - * - Model code - - Original Model - - Quantization Mode - - Hugging Face repo - * - `Llama-2-{7, 13, 70}b-chat-hf-q4f16_1` - - `Llama-2 `__ - - * Weight storage data type: int4 - * Running data type: float16 - * Symmetric quantization - - * `7B link `__ - * `13B link `__ - * `70B link `__ - * - `vicuna-v1-7b-q3f16_0` - - `Vicuna `__ - - * Weight storage data type: int3 - * Running data type: float16 - * Symmetric quantization - - `link `__ - * - `RedPajama-INCITE-Chat-3B-v1-q4f16_1` - - `RedPajama `__ - - * Weight storage data type: int4 - * Running data type: float16 - * Symmetric quantization - - `link `__ - * - `rwkv-raven-{1b5, 3b, 7b}-q8f16_0` - - `RWKV `__ - - * Weight storage data type: uint8 - * Running data type: float16 - * Symmetric quantization - - * `1b5 link `__ - * `3b link `__ - * `7b link `__ - * - `WizardLM-13B-V1.2-{q4f16_1, q4f32_1}` - - `WizardLM `__ - - * Weight storage data type: int4 - * Running data type: float{16, 32} - * Symmetric quantization - - * `q4f16_1 link `__ - * `q4f32_1 link `__ - * - `WizardCoder-15B-V1.0-{q4f16_1, q4f32_1}` - - `WizardCoder `__ - - * Weight storage data type: int4 - * Running data type: float{16, 32} - * Symmetric quantization - - * `q4f16_1 link `__ - * `q4f32_1 link `__ - * - `WizardMath-{7, 13, 70}B-V1.0-q4f16_1` - - `WizardMath `__ - - * Weight storage data type: int4 - * Running data type: float16 - * Symmetric quantization - - * `7B link `__ - * `13B link `__ - * `70B link `__ - * - `llama2-7b-chat-uncensored-{q4f16_1, q4f32_1}` - - `georgesung `__ - - * Weight storage data type: int4 - * Running data type: float{16, 32} - * Symmetric quantization - - * `q4f16_1 link `__ - * `q4f32_1 link `__ - * - `Llama2-Chinese-7b-Chat-{q4f16_1, q4f32_1}` - - `FlagAlpha `__ - - * Weight storage data type: int4 - * Running data type: float{16, 32} - * Symmetric quantization - - * `q4f16_1 link `__ - * `q4f32_1 link `__ - * - `GOAT-7B-Community-{q4f16_1, q4f32_1}` - - `GOAT-AI `__ - - * Weight storage data type: int4 - * Running data type: float{16, 32} - * Symmetric quantization - - * `q4f16_1 link `__ - * `q4f32_1 link `__ - * - `OpenOrca-Platypus2-13B-q4f16_1` - - `Llama-2 `__ - - * Weight storage data type: int4 - * Running data type: float16 - * Symmetric quantization - - `link `__ - -To download and run one model with CLI, follow the instructions below: - -.. code:: shell - - # Create conda environment and install CLI if you have not installed. - conda create -n mlc-chat-venv -c mlc-ai -c conda-forge mlc-chat-cli-nightly - conda activate mlc-chat-venv - conda install git git-lfs - git lfs install - - # Download prebuilt model binary libraries from GitHub if you have not downloaded. - mkdir -p dist/prebuilt - git clone https://github.com/mlc-ai/binary-mlc-llm-libs.git dist/prebuilt/lib - - # Download prebuilt model weights and run CLI. - cd dist/prebuilt - git clone https://huggingface.co/mlc-ai/mlc-chat-[model-code] - cd ../.. - mlc_chat_cli --model [model-code] - - # e.g., - # cd dist/prebuilt - # git clone https://huggingface.co/mlc-ai/mlc-chat-rwkv-raven-7b-q8f16_0 - # cd ../.. - # mlc_chat_cli --model rwkv-raven-7b-q8f16_0 - - -.. _prebuilt-models-ios: - -Prebuilt Models for iOS ------------------------ - -.. list-table:: Prebuilt models for iOS - :widths: 15 15 15 15 - :header-rows: 1 +- Documenting :ref:`how to use prebuilts ` for various platforms, and +- Tracking what current :ref:`prebuilt models we provide `. 
+ +Prerequisite: Model Libraries and Compiled Weights +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +In order to run a specific model on MLC-LLM, you need: + +**1. A model library:** a binary file containing the end-to-end functionality to inference a model (e.g. ``Llama-2-7b-chat-hf-q4f16_1-cuda.so``). See the full list of all precompiled model libraries `here `__. + +**2. Compiled weights:** a folder containing multiple files that store the compiled and quantized weights of a model (e.g. https://huggingface.co/mlc-ai/mlc-chat-Llama-2-7b-chat-hf-q4f16_1). See the full list of all precompiled weights `here `__. + +.. _using-model-prebuilts: + +Using Prebuilt Models for Different Platforms +--------------------------------------------- + +We quickly go over how to use prebuilt models for each platform. You can find detailed instruction on each platform's corresponding page. + +.. _using-prebuilt-models-cli: + + +Prebuilt Models on CLI / Python +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +For more, please see :doc:`the CLI page `, and the :doc:`the Python page `. + +.. collapse:: Click to show details + + First create the conda environment if you have not done so. + + .. code:: shell + + conda create -n mlc-chat-venv -c mlc-ai -c conda-forge mlc-chat-cli-nightly + conda activate mlc-chat-venv + conda install git git-lfs + git lfs install + + Download the prebuilt model libraries from github. + + .. code:: shell + + mkdir -p dist/prebuilt + git clone https://github.com/mlc-ai/binary-mlc-llm-libs.git dist/prebuilt/lib + + Download the prebuilt model weights from hugging face for the model variant you want. + + .. code:: shell + + # Say we want to run rwkv-raven-7b-q8f16_0 + cd dist/prebuilt + git clone https://huggingface.co/mlc-ai/mlc-chat-rwkv-raven-7b-q8f16_0 + cd ../.. + + # The format being: + # cd dist/prebuilt + # git clone https://huggingface.co/mlc-ai/mlc-chat-[model-code] + # cd ../.. + # mlc_chat_cli --model [model-code] + + Run the model with CLI: + + .. code:: shell - * - Model code - - Model Series - - Quantization Mode - - Hugging Face repo - * - `Llama-2-7b-q3f16_1` - - `Llama `__ - - * Weight storage data type: int3 - * Running data type: float16 - * Symmetric quantization - - `link `__ - * - `vicuna-v1-7b-q3f16_0` - - `Vicuna `__ - - * Weight storage data type: int3 - * Running data type: float16 - * Symmetric quantization - - `link `__ - * - `RedPajama-INCITE-Chat-3B-v1-q4f16_1` - - `RedPajama `__ - - * Weight storage data type: int4 - * Running data type: float16 - * Symmetric quantization - - `link `__ - -The `downloadable iOS app `_ has builtin RedPajama-3B model support. -To add a model to the iOS app, follow the steps below: - -.. collapse:: Click to show instructions + # For CLI + mlc_chat_cli --model rwkv-raven-7b-q8f16_0 + + To run the model with Python API, see :doc:`the Python page ` (all other downloading steps are the same as CLI). + + +.. for a blank line + +| + +.. _using-prebuilt-models-ios: + +Prebuilt Models on iOS +^^^^^^^^^^^^^^^^^^^^^^ + +For more, please see :doc:`the iOS page `. + +.. collapse:: Click to show details + + The `iOS app `_ has builtin RedPajama-3B and Llama-2-7b support. + + All prebuilt models with an entry in ``iOS`` in the :ref:`model library table ` are supported by iOS. Namely, we have: + + .. 
list-table:: Prebuilt model libraries integrated in the iOS app + :widths: 15 15 15 + :header-rows: 1 + + * - Model library name + - Model Family + - Quantization Mode + * - `Llama-2-7b-chat-hf-q3f16_1` + - LLaMA + - * Weight storage data type: int3 + * Running data type: float16 + * Symmetric quantization + * - `vicuna-v1-7b-q3f16_0` + - LLaMA + - * Weight storage data type: int3 + * Running data type: float16 + * Symmetric quantization + * - `RedPajama-INCITE-Chat-3B-v1-q4f16_1` + - GPT-NeoX + - * Weight storage data type: int4 + * Running data type: float16 + * Symmetric quantization + + As for prebuilt model weights, the ones we have integrated into app are listed below: + + .. list-table:: Tested prebuilt model weights for iOS + :widths: 15 15 15 15 + :header-rows: 1 + + * - Model code + - Model Series + - Quantization Mode + - Hugging Face repo + * - `Llama-2-7b-q3f16_1` + - `Llama `__ + - * Weight storage data type: int3 + * Running data type: float16 + * Symmetric quantization + - `link `__ + * - `vicuna-v1-7b-q3f16_0` + - `Vicuna `__ + - * Weight storage data type: int3 + * Running data type: float16 + * Symmetric quantization + - `link `__ + * - `RedPajama-INCITE-Chat-3B-v1-q4f16_1` + - `RedPajama `__ + - * Weight storage data type: int4 + * Running data type: float16 + * Symmetric quantization + - `link `__ + + To run a model variant you compiled on your own, you can directly reuse the above integrated prebuilt model libraries, as long as the model shares the architecture and is compiled with the same quantization mode. For example, if you compile `OpenLLaMA-7B `_ with quantization mode ``q3f16_0``, then you can run the compiled OpenLLaMA model on iPhone without rebuilding the iOS app by reusing the `vicuna-v1-7b-q3f16_0` model library. Then you can upload the compiled weights to hugging face so that you can download the weights in the app as shown below (for more on uploading to hugging face, please check the :doc:`model distribution page `). + + To add a model to the iOS app, follow the steps below: .. tabs:: @@ -210,126 +206,639 @@ To add a model to the iOS app, follow the steps below: | -The iOS app has integrated with the following model libraries, which can be directly reused when you want to run a model you compiled in iOS, as long as the model is in the supported model family and is compiled with supported quantization mode. -For example, if you compile `OpenLLaMA-7B `_ with quantization mode ``q3f16_0``, then you can run the compiled OpenLLaMA model on iPhone without rebuilding the iOS app by reusing the `vicuna-v1-7b-q3f16_0` model library. Please check the :doc:`model distribution page ` for detailed instructions. - -.. list-table:: Prebuilt model libraries which are integrated in the iOS app - :widths: 15 15 15 - :header-rows: 1 - - * - Model library name - - Model Family - - Quantization Mode - * - `Llama-2-7b-chat-hf-q3f16_1` - - LLaMA - - * Weight storage data type: int3 - * Running data type: float16 - * Symmetric quantization - * - `vicuna-v1-7b-q3f16_0` - - LLaMA - - * Weight storage data type: int3 - * Running data type: float16 - * Symmetric quantization - * - `RedPajama-INCITE-Chat-3B-v1-q4f16_1` - - GPT-NeoX - - * Weight storage data type: int4 - * Running data type: float16 - * Symmetric quantization - - .. _prebuilt-models-android: -Prebuilt Models for Android ---------------------------- - -.. 
list-table:: Prebuilt models for Android - :widths: 15 15 15 15 - :header-rows: 1 +Prebuilt Models on Android +^^^^^^^^^^^^^^^^^^^^^^^^^^ + +For more, please see :doc:`the Android page `. + +.. collapse:: Click to show details + + The apk for demo Android app includes the following models. To add more, check out the Android page. + + .. list-table:: Prebuilt Models for Android + :widths: 15 15 15 15 + :header-rows: 1 + + * - Model code + - Model Series + - Quantization Mode + - Hugging Face repo + * - `Llama-2-7b-q4f16_1` + - `Llama `__ + - * Weight storage data type: int4 + * Running data type: float16 + * Symmetric quantization + - `link `__ + * - `RedPajama-INCITE-Chat-3B-v1-q4f16_1` + - `RedPajama `__ + - * Weight storage data type: int4 + * Running data type: float16 + * Symmetric quantization + - `link `__ +.. for a blank line - * - Model code - - Model Series - - Quantization Mode - - Hugging Face repo - * - `vicuna-v1-7b-q4f16_1` - - `Vicuna `__ - - * Weight storage data type: int4 - * Running data type: float16 - * Symmetric quantization - - `link `__ - * - `RedPajama-INCITE-Chat-3B-v1-q4f16_0` - - `RedPajama `__ - - * Weight storage data type: int4 - * Running data type: float16 - * Symmetric quantization - - `link `__ +| ------------------- +.. _supported-model-architectures: -You can check `MLC-LLM pull requests `__ to track the ongoing efforts of new models. We encourage users to upload their compiled models to Hugging Face and share with the community. +Level 1: Supported Model Architectures (The All-In-One Table) +------------------------------------------------------------- -.. _supported-model-architectures: +For each model architecture (e.g. Llama), there are multiple variants (e.g. CodeLlama, WizardLM). The variants share the same code for inference and only differ in their weights. In other words, running CodeLlama and WizardLM can use the same model library file (specified in Level 2 tables), but different precompiled weights (specified in Level 3 tables). Note that we have not provided prebuilt weights for all model variants. -Supported Model Architectures ------------------------------ +Each entry below hyperlinks to the corresponding level 2 and level 3 tables. MLC-LLM supports the following model architectures: .. 
list-table:: Supported Model Architectures - :widths: 15 15 15 15 + :widths: 10 10 15 15 :header-rows: 1 - * - Category Code - - Series - - Model Definition - - Variants + * - Code + - Architecture + - Variants w/ MLC prebuilts + - Variants w/o MLC prebuilts * - ``llama`` - - `LLaMa `__ - - `Relax Code `__ - - * `Llama-2 `__ - * `Alpaca `__ - * `Vicuna `__ + - LLaMa + + * :ref:`Prebuilt library table ` + * `Official link `__ + * `Relax Code `__ + - * :ref:`Llama-2 ` + * :ref:`Code Llama ` + * :ref:`Vicuna ` + * :ref:`WizardLM ` + * :ref:`WizardMath ` + * :ref:`OpenOrca Platypus2 ` + * :ref:`FlagAlpha Llama-2 Chinese ` + * :ref:`georgesung Llama-2 Uncensored ` + - * `Alpaca `__ * `Guanaco `__ * `OpenLLaMA `__ * `Gorilla `__ - * `WizardLM `__ * `YuLan-Chat `__ - * `WizardMath `__ - * `FlagAlpha Llama-2 Chinese `__ + * `WizardCoder (new) `__ * - ``gpt-neox`` - - `GPT-NeoX `__ - - `Relax Code `__ - - * `RedPajama `__ - * `Dolly `__ + - GPT-NeoX + + * :ref:`Prebuilt library table ` + * `Official link `__ + * `Relax Code `__ + - * :ref:`RedPajama ` + - * `Dolly `__ * `Pythia `__ * `StableCode `__ * - ``gptj`` - - `GPT-J `__ - - `Relax Code `__ + - GPT-J + + * Prebuilt not compiled yet + * `Official link `__ + * `Relax Code `__ + - - * `MOSS `__ * - ``rwkv`` - - `RWKV `__ - - `Relax Code `__ - - * `RWKV-raven `__ + - RWKV + + * :ref:`Prebuilt library table ` + * `Official link `__ + * `Relax Code `__ + - * :ref:`RWKV-raven ` + - * - ``minigpt`` - - `MiniGPT `__ - - `Relax Code `__ - - + - MiniGPT + + * Prebuilt not compiled yet + * `Official link `__ + * `Relax Code `__ + - + - * `MiniGPT-4 `__ * - ``gpt_bigcode`` - - `GPTBigCode `__ - - `Relax Code `__ + - GPTBigCode + + * :ref:`Prebuilt library table ` + * `Official link `__ + * `Relax Code `__ + - * :ref:`WizardCoder (old) ` - * `StarCoder `__ - * `WizardCoder `__ * `SantaCoder `__ * - ``chatglm`` - - `ChatGLM `__ - - `Relax Code `__ + - ChatGLM + + * Prebuilt not compiled yet + * `Official link `__ + * `Relax Code `__ + - - * `ChatGLM2 `__ * `CodeGeeX2 `__ +If the model variant you are interested in is in one of these model architectures we support (but we have not provided the prebuilt weights yet), you can check the :doc:`model compilation page ` on how to compile your own models. Note that you only need to compile the weights for your model variant and reuse the library file found in Level 2 tables. + +For models structured in an architecture we have not supported yet, you could: + +- Either `create a new issue `_ to request a new model architecture. + +- Or follow our tutorial :doc:`Define New Models `, which introduces how to bring a new model architecture to MLC-LLM. + + +.. _model-library-tables: + +Level 2: Model Library Tables (Precompiled Binary Files) +-------------------------------------------------------- + +As mentioned earlier, each model architecture corresponds to a different model library file. That is, you cannot use the same model library file to run ``RedPajama`` and ``Llama-2``. However, you can use the same ``Llama`` model library file to run ``Llama-2``, ``WizardLM``, ``CodeLlama``, etc, but just with different weight files (from tables in Level 3). + +Each table below demonstrates the pre-compiled model library files for each model architecture. This is categorized by: + +- **Size**: each size of model has its own distinct model library file (e.g. 7B or 13B number of parameters) + +- **Platform**: the backend that the model library is intended to be run on (e.g. CUDA, ROCm, iphone, etc.) 
+ +- **Quantization scheme**: the model library file also differs due to the quantization scheme used. For more on this, please see the :doc:`model compilation page ` (e.g. ``q3f16_1`` vs. ``q4f16_1``) + +Each entry links to the specific model library file found in `this github repo `__. + +.. _llama_library_table: + +Llama +^^^^^ +.. list-table:: Llama + :widths: 8 8 8 8 8 8 8 8 8 8 + :header-rows: 1 + :stub-columns: 1 + + * - + - CUDA + - ROCm + - Vulkan + + (Linux) + - Vulkan + + (Windows) + - Metal + + (M1/M2) + - Metal + + (Intel) + - iOS + - webgpu + - mali + * - 7B + - `q4f16_1 `__ + - `q4f16_1 `__ + - `q4f16_1 `__ + - `q4f16_1 `__ + - `q4f16_1 `__ + - `q4f16_1 `__ + - `q3f16_1 `__ + - `q4f16_1 `__ + + `q4f32_1 `__ + - `q4f16_1 `__ + * - 13B + - `q4f16_1 `__ + - `q4f16_1 `__ + - `q4f16_1 `__ + - `q4f16_1 `__ + - `q4f16_1 `__ + - `q4f16_1 `__ + - + - `q4f16_1 `__ + + `q4f32_1 `__ + - `q4f16_1 `__ + * - 34B + - `q4f16_1 `__ + - + - `q4f16_1 `__ + - `q4f16_1 `__ + - `q4f16_1 `__ + - + - + - + - + * - 70B + - + - + - + - + - `q3f16_1 `__ + + `q4f16_1 `__ + - + - + - `q4f16_1 `__ + - + +.. _gpt_neox_library_table: + +GPT-NeoX (RedPajama-INCITE) +^^^^^^^^^^^^^^^^^^^^^^^^^^^ +.. list-table:: GPT-NeoX (RedPajama-INCITE) + :widths: 8 8 8 8 8 8 8 8 8 8 + :header-rows: 1 + :stub-columns: 1 + + * - + - CUDA + - ROCm + - Vulkan + + (Linux) + - Vulkan + + (Windows) + - Metal + + (M1/M2) + - Metal + + (Intel) + - iOS + - webgpu + - mali + * - 3B + - `q4f16_1 `__ + - `q4f16_1 `__ + - `q4f16_0 `__ + + `q4f16_1 `__ + - `q4f16_0 `__ + + `q4f16_1 `__ + - `q4f16_0 `__ + + `q4f16_1 `__ + - `q4f16_0 `__ + + `q4f16_1 `__ + - `q4f16_0 `__ + + `q4f16_1 `__ + - `q4f16_0 `__ + + `q4f16_1 `__ + + `q4f32_0 `__ + + `q4f32_1 `__ + - `q4f16_1 `__ + +.. _rwkv_library_table: + +RWKV +^^^^ +.. list-table:: RWKV + :widths: 8 8 8 8 8 8 8 8 8 8 + :header-rows: 1 + :stub-columns: 1 + + * - + - CUDA + - ROCm + - Vulkan + + (Linux) + - Vulkan + + (Windows) + - Metal + + (M1/M2) + - Metal + + (Intel) + - iOS + - webgpu + - mali + * - 1B5 + - + - + - `q8f16_0 `__ + - `q8f16_0 `__ + - `q8f16_0 `__ + - `q8f16_0 `__ + - + - + - + * - 3B + - + - + - `q8f16_0 `__ + - `q8f16_0 `__ + - `q8f16_0 `__ + - `q8f16_0 `__ + - + - + - + * - 7B + - + - + - `q8f16_0 `__ + - `q8f16_0 `__ + - `q8f16_0 `__ + - `q8f16_0 `__ + - + - + - + +.. _gpt_big_code_library_table: + +GPTBigCode +^^^^^^^^^^ +Note that these all links to model libraries for WizardCoder (the older version released in Jun. 2023). +However, any GPTBigCode model variants should be able to reuse these (e.g. StarCoder, SantaCoder). + +.. list-table:: GPTBigCode + :widths: 8 8 8 8 8 8 8 8 8 8 + :header-rows: 1 + :stub-columns: 1 + + * - + - CUDA + - ROCm + - Vulkan + + (Linux) + - Vulkan + + (Windows) + - Metal + + (M1/M2) + - Metal + + (Intel) + - iOS + - webgpu + - mali + * - 15B + - `q4f16_1 `__ + + `q4f32_1 `__ + - + - `q4f16_1 `__ + + `q4f32_1 `__ + - `q4f16_1 `__ + + `q4f32_1 `__ + - `q4f16_1 `__ + - + - + - `q4f16_1 `__ + + `q4f32_1 `__ + - + +.. _model-variant-tables: + +Level 3: Model Variant Tables (Precompiled Weights) +--------------------------------------------------- + +Finally, for each model variant, we provide the precompiled weights we uploaded to hugging face. + +Each precompiled weight is categorized by its model size (e.g. 7B vs. 13B) and the quantization scheme (e.g. ``q3f16_1`` vs. ``q4f16_1``). We note that the weights are **platform-agnostic**. + +Each model variant also loads its conversation configuration from a pre-defined :ref:`conversation template`. 
Note that multiple model variants can share a common conversation template. + +Some of these files are uploaded by our community contributors--thank you! + +.. _llama2_variant_table: + +`Llama-2 `__ +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Conversation template: ``llama-2`` + +.. list-table:: Llama-2 + :widths: 30 30 + :header-rows: 1 + + * - Size + - Hugging Face Repo Link + * - 7B + - * `q3f16_1 `__ + * `q4f16_1 `__ + * `q4f32_1 `__ + + * - 13B + - * `q4f16_1 `__ + * `q4f32_1 `__ + + * - 70B + - * `q3f16_1 `__ + * `q4f16_1 `__ + +.. _code_llama_variant_table: + +`Code Llama `__ +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Conversation template: ``codellama_completion`` + +.. list-table:: Code Llama + :widths: 30 30 + :header-rows: 1 + + * - Size + - Hugging Face Repo Link + * - 7B + - * `q4f16_1 (Base) `__ + * `q4f16_1 (Instruct) `__ + * `q4f16_1 (Python) `__ + + * - 13B + - * `q4f16_1 (Base) `__ + * `q4f16_1 (Instruct) `__ + * `q4f16_1 (Python) `__ + + * - 34B + - * `q4f16_1 (Base) `__ + * `q4f16_1 (Instruct) `__ + * `q4f16_1 (Python) `__ + + +.. _vicuna_variant_table: + +`Vicuna `__ +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Conversation template: ``vicuna_v1.1`` + +.. list-table:: Vicuna + :widths: 30 30 + :header-rows: 1 + + * - Size + - Hugging Face Repo Link + * - 7B + - * `q3f16_0 `__ + * `q4f32_0 `__ + * `int3 (demo) `__ + * `int4 (demo) `__ + + +.. _WizardLM_variant_table: + +`WizardLM `__ +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Conversation template: ``vicuna_v1.1`` + +.. list-table:: WizardLM + :widths: 30 30 + :header-rows: 1 + + * - Size + - Hugging Face Repo Link + * - 13B + - * `q4f16_1 (V1.2) `__ + * `q4f32_1 (V1.2) `__ + + * - 70B + - * `q3f16_1 (V1.0) `__ + * `q4f16_1 (V1.0) `__ + + +.. _wizard_math_variant_table: + +`WizardMath `__ +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Conversation template: ``wizard_coder_or_math`` + +.. list-table:: WizardMath + :widths: 30 30 + :header-rows: 1 + + * - Size + - Hugging Face Repo Link + * - 7B + - * `q4f16_1 `__ + * `q4f32_1 `__ + * - 13B + - `q4f16_1 `__ + * - 70B + - `q4f16_1 `__ + + +.. _open_orca_variant_table: + +`OpenOrca Platypus2 `__ +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Conversation template: ``llama-2`` + +.. list-table:: OpenOrca Platypus2 + :widths: 30 30 + :header-rows: 1 + + * - Size + - Hugging Face Repo Link + * - 13B + - `q4f16_1 `__ + + +.. _flag_alpha_llama2_variant_table: + +`FlagAlpha Llama-2 Chinese `__ +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Conversation template: ``llama-2`` + +.. list-table:: FlagAlpha Llama-2 Chinese + :widths: 30 30 + :header-rows: 1 + + * - Size + - Hugging Face Repo Link + * - 7B + - * `q4f16_1 `__ + * `q4f32_1 `__ + + +.. _llama2_uncensored_variant_table: + +`Llama2 uncensored (georgesung) `__ +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Conversation template: ``llama-default`` + +.. list-table:: Llama2 uncensored + :widths: 30 30 + :header-rows: 1 + + * - Size + - Hugging Face Repo Link + * - 7B + - * `q4f16_1 `__ + * `q4f32_1 `__ + +.. _red_pajama_variant_table: + +`RedPajama `__ +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Conversation template: ``LM`` + +.. 
list-table:: Red Pajama + :widths: 30 30 + :header-rows: 1 + + * - Size + - Hugging Face Repo Link + * - 3B + - * `q4f16_0 (Instruct) `__ + * `q4f16_0 (Chat) `__ + * `q4f16_1 (Chat) `__ + * `q4f32_0 (Chat) `__ + + +.. _rwkv_raven_variant_table: + +`RWKV-raven `__ +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Conversation template: ``rwkv`` + +.. list-table:: RWKV-raven + :widths: 30 30 + :header-rows: 1 + + * - Size + - Hugging Face Repo Link + * - 1B5 + - `q8f16_0 `__ + + * - 3B + - `q8f16_0 `__ + + * - 7B + - `q8f16_0 `__ + + +.. _wizard_coder_variant_table: + +`WizardCoder `__ +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Conversation template: ``wizard_coder_or_math`` + +.. list-table:: WizardCoder + :widths: 30 30 + :header-rows: 1 + + * - Size + - Hugging Face Repo Link + * - 15B + - `q4f16_1 `__ + +------------------ -For models structured in these model architectures, you can check the :doc:`model compilation page ` on how to compile models. -Please `create a new issue `_ if you want to request a new model architecture. -Our tutorial :doc:`Define New Models ` introduces how to bring a new model architecture to MLC-LLM. .. _contribute-models-to-mlc-llm: From c02fdafc917d8bdc941ecb18be7e943cef22d89b Mon Sep 17 00:00:00 2001 From: yongjer <54315206+yongjer@users.noreply.github.com> Date: Tue, 10 Oct 2023 00:58:51 +0800 Subject: [PATCH 008/116] Update compile_models.rst (#1038) fix permission issue --- docs/compilation/compile_models.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/compilation/compile_models.rst b/docs/compilation/compile_models.rst index e559c8fc27..b5f1044b75 100644 --- a/docs/compilation/compile_models.rst +++ b/docs/compilation/compile_models.rst @@ -30,7 +30,7 @@ The easiest way is to use MLC-LLM is to clone the repository, and compile models .. code:: bash # clone the repository - git clone git@github.com:mlc-ai/mlc-llm.git --recursive + git clone https://github.com/mlc-ai/mlc-llm.git --recursive # enter to root directory of the repo cd mlc-llm # install mlc-llm From 85001ed4b722ee99ab2329e1b2604650934cc49b Mon Sep 17 00:00:00 2001 From: Jeethu Rao Date: Mon, 9 Oct 2023 20:40:52 +0100 Subject: [PATCH 009/116] Support for the Stable LM 3B model (#1008) Support for the stablelm-3b-4e1t model --- cpp/conv_templates.cc | 20 + mlc_llm/core.py | 7 +- mlc_llm/relax_model/stablelm_3b.py | 892 +++++++++++++++++++++++++++++ mlc_llm/utils.py | 3 +- 4 files changed, 918 insertions(+), 4 deletions(-) create mode 100644 mlc_llm/relax_model/stablelm_3b.py diff --git a/cpp/conv_templates.cc b/cpp/conv_templates.cc index 69f84b2421..ae91bf2070 100644 --- a/cpp/conv_templates.cc +++ b/cpp/conv_templates.cc @@ -473,6 +473,25 @@ Conversation VanillaLM() { return conv; } +Conversation StableLM3B() { + Conversation conv; + conv.name = "stablelm-3b"; + conv.system = ""; + conv.roles = {"Prompt", "LM"}; + conv.messages = {}; + conv.separator_style = SeparatorStyle::kLM; + conv.offset = 0; + conv.seps = {""}; + conv.role_msg_sep = ""; + conv.role_empty_sep = ""; + // TODO(mlc-team): add eos to mlc-chat-config + // and remove eos from stop token setting. 
+ // so the same template works for more tokenizers + conv.stop_tokens = {0}; + conv.add_bos = true; + return conv; +} + Conversation GPTBigCode() { Conversation conv; conv.name = "gpt_bigcode"; @@ -580,6 +599,7 @@ Conversation Conversation::FromTemplate(const std::string& name) { {"minigpt", MiniGPT}, {"moss", MOSS}, {"LM", VanillaLM}, + {"stablelm-3b", StableLM3B}, {"gpt_bigcode", GPTBigCode}, {"wizardlm_7b", WizardLM7B}, {"wizard_coder_or_math", WizardCoderOrMATH}, diff --git a/mlc_llm/core.py b/mlc_llm/core.py index 928d78da41..ddf93bf09a 100644 --- a/mlc_llm/core.py +++ b/mlc_llm/core.py @@ -25,6 +25,7 @@ minigpt, param_manager, rwkv, + stablelm_3b, ) from mlc_llm.relax_model.commons import create_shard_info_func from mlc_llm.transform import fuse_split_rotary_embedding, rewrite_attention @@ -592,10 +593,10 @@ def build_model_from_args(args: argparse.Namespace): with open(os.path.join(args.model_path, "config.json"), encoding="utf-8") as i_f: config = json.load(i_f) if not use_cache or args.convert_weight_only: - if args.model_category == "llama": - mod, param_manager, params, model_config = llama.get_model(args, config) - elif args.model_category == "mistral": + if args.model_category in ("llama", "mistral"): mod, param_manager, params, model_config = llama.get_model(args, config) + elif args.model_category == "stablelm_epoch": + mod, param_manager, params, model_config = stablelm_3b.get_model(args, config) elif args.model_category == "gpt_neox": mod, param_manager, params, model_config = gpt_neox.get_model(args, config) elif args.model_category == "gpt_bigcode": diff --git a/mlc_llm/relax_model/stablelm_3b.py b/mlc_llm/relax_model/stablelm_3b.py new file mode 100644 index 0000000000..4bb1beedeb --- /dev/null +++ b/mlc_llm/relax_model/stablelm_3b.py @@ -0,0 +1,892 @@ +import math +from dataclasses import dataclass +from typing import Any, List, Optional, Tuple + +import numpy as np +import tvm +from tvm import relax, te +from tvm.relax.op import ccl +from tvm.relax.op.nn import layer_norm +from tvm.relax.testing import nn +from tvm.script import relax as R + +from ..quantization import ParamQuantKind, QuantizationScheme +from .commons import create_metadata_func +from .modules import ModuleList, RotaryEmbedding +from .param_manager import ParamManager +from .llama import Embedding, Linear + + +@dataclass +class StableLM3bConfig: + def __init__( + self, + dtype="float32", + max_sequence_length=4096, + vocab_size=50304, + hidden_size=2560, + intermediate_size=6912, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=None, + hidden_act="silu", + initializer_range=0.02, + norm_eps=1e-5, + pad_token_id=-1, + bos_token_id=0, + eos_token_id=1, + tie_word_embeddings=False, + position_embedding_base=10000, + combine_matmul=True, + num_shards=1, + build_model_only=False, + convert_weight_only=False, + **kwargs, + ): + self.dtype = dtype + self.max_sequence_length = max_sequence_length + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.norm_eps = norm_eps + self.pad_token_id = pad_token_id + self.bos_token_id = bos_token_id + self.eos_token_id = eos_token_id + self.tie_word_embeddings = tie_word_embeddings + self.position_embedding_base = position_embedding_base + self.combine_matmul = 
combine_matmul + if build_model_only and num_shards > 1: + self.num_shards = num_shards + else: + self.num_shards = 1 + self.kwargs = kwargs + + +class LayerNorm(nn.Module): + def __init__( + self, + hidden_size, + dtype, + eps=1e-5, + ): + super().__init__() + self.eps = eps + self.weight = nn.Parameter((hidden_size,), dtype="float16", name="weight") + self.bias = nn.Parameter((hidden_size,), dtype="float16", name="bias") + + def forward(self, x: relax.Expr) -> relax.Var: + x = nn.emit( + layer_norm( + x, + gamma=self.weight, + beta=self.bias, + axes=-1, + epsilon=self.eps, + ) + ) + return x + + +class StableLM3bMLP(nn.Module): + def __init__(self, config: StableLM3bConfig): + self.combine_matmul = config.combine_matmul + self.num_shards = config.num_shards + hidden_size = config.hidden_size + intermediate_size = config.intermediate_size // self.num_shards + dtype = config.dtype + if self.combine_matmul: + self.gate_up_proj = Linear(hidden_size, 2 * intermediate_size, dtype=dtype, bias=False) + self.down_proj = Linear(intermediate_size, hidden_size, dtype=dtype, bias=False) + self.gate_up_proj.weight.shard_dim = 0 + self.down_proj.weight.shard_dim = 1 + else: + self.gate_proj = Linear(hidden_size, intermediate_size, dtype=dtype, bias=False) + self.down_proj = Linear(intermediate_size, hidden_size, dtype=dtype, bias=False) + self.up_proj = Linear(hidden_size, intermediate_size, dtype=dtype, bias=False) + self.gate_proj.weight.shard_dim = 0 + self.up_proj.weight.shard_dim = 0 + self.down_proj.weight.shard_dim = 1 + + def forward(self, x): + if self.combine_matmul: + gate_up_results = nn.emit( + relax.op.split( + self.gate_up_proj(x), + indices_or_sections=2, + axis=-1, + ) + ) + gate_result = relax.TupleGetItem(gate_up_results, 0) + up_result = relax.TupleGetItem(gate_up_results, 1) + else: + gate_result = self.gate_proj(x) + up_result = self.up_proj(x) + + result = self.down_proj(relax.op.nn.silu(gate_result) * up_result) + return result + + +class StableLM3bAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: StableLM3bConfig, rotary_embedding: RotaryEmbedding): + dtype = config.dtype + self.num_shards = config.num_shards + self.hidden_size = config.hidden_size + self.num_key_value_heads = ( + config.num_key_value_heads is None + and config.num_attention_heads + or config.num_key_value_heads + ) // config.num_shards + self.num_query_heads = config.num_attention_heads // self.num_shards + self.head_dim = self.hidden_size // config.num_attention_heads + self.position_embedding_base = config.position_embedding_base + self.rotary_embedding = rotary_embedding + + self.combine_matmul = config.combine_matmul + if self.combine_matmul: + self.query_key_value_proj = Linear( + self.hidden_size, + (self.num_query_heads + 2 * self.num_key_value_heads) * self.head_dim, + dtype=dtype, + bias=False, + ) + self.query_key_value_proj.weight.shard_dim = 0 + else: + self.q_proj = Linear( + self.hidden_size, + self.num_query_heads * self.head_dim, + dtype=dtype, + bias=False, + ) + self.k_proj = Linear( + self.hidden_size, + self.num_key_value_heads * self.head_dim, + dtype=dtype, + bias=False, + ) + self.v_proj = Linear( + self.hidden_size, + self.num_key_value_heads * self.head_dim, + dtype=dtype, + bias=False, + ) + self.q_proj.weight.shard_dim = 0 + self.k_proj.weight.shard_dim = 0 + self.v_proj.weight.shard_dim = 0 + + self.o_proj = Linear( + self.head_dim * self.num_query_heads, self.hidden_size, dtype=dtype, bias=False + ) + 
self.o_proj.weight.shard_dim = 1 + + def forward( + self, + hidden_states: relax.Expr, + all_seq_len_shape: relax.Expr, + past_key_value: Tuple[relax.Expr], + attention_mask: Optional[relax.Expr] = None, + ) -> Tuple[relax.Expr, Optional[relax.Expr], Optional[Tuple[relax.Expr]]]: + from tvm.relax.op import ( + astype, + matmul, + maximum, + permute_dims, + reshape, + split, + squeeze, + ) + from tvm.relax.op.nn import softmax + + bsz, q_len, _ = hidden_states.struct_info.shape + assert bsz == 1, "Only support batch size 1 at this moment." + + if self.combine_matmul: + qkv_states = nn.emit( + split( + self.query_key_value_proj(hidden_states), + indices_or_sections=[ + self.num_query_heads * self.head_dim, + (self.num_query_heads + self.num_key_value_heads) * self.head_dim, + ], + axis=-1, + ) + ) + query_states = relax.TupleGetItem(qkv_states, 0) + key_states = relax.TupleGetItem(qkv_states, 1) + value_states = relax.TupleGetItem(qkv_states, 2) + else: + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = nn.emit( + reshape( + query_states, + (bsz, q_len, self.num_query_heads, self.head_dim), + ), + ) + key_states = nn.emit( + reshape( + key_states, + (bsz, q_len, self.num_key_value_heads, self.head_dim), + ), + ) + value_states = nn.emit( + reshape( + value_states, + (bsz, q_len, self.num_key_value_heads, self.head_dim), + ), + ) + + kv_seq_len = all_seq_len_shape.struct_info.values[0] + offset = kv_seq_len - q_len + query_states, key_states = self.rotary_embedding(query_states, key_states, offset) + # [bsz, t, nh, hd] + + kv_states_shape = key_states.struct_info.shape + kv_states_dtype = key_states.struct_info.dtype + assert kv_states_shape[0] == 1 # bsz + kv_states_shape = R.shape( + [kv_states_shape[0], kv_seq_len, kv_states_shape[2], kv_states_shape[3]] + ) + kv_cache_shape = R.shape([kv_seq_len, kv_states_shape[2], kv_states_shape[3]]) + + squeezed_key = nn.emit(squeeze(key_states, axis=0)) + squeezed_value = nn.emit(squeeze(value_states, axis=0)) + k_cache, v_cache = past_key_value + f_kv_cache_append = relax.extern("vm.builtin.attention_kv_cache_append") + k_cache = nn.emit( + relax.Call( + f_kv_cache_append, + args=[k_cache, squeezed_key], + sinfo_args=[relax.ObjectStructInfo()], + ) + ) + v_cache = nn.emit( + relax.Call( + f_kv_cache_append, + args=[v_cache, squeezed_value], + sinfo_args=[relax.ObjectStructInfo()], + ) + ) + past_key_value = (k_cache, v_cache) + f_kv_cache_view = relax.extern("vm.builtin.attention_kv_cache_view") + k_cache = nn.emit( + relax.Call( + f_kv_cache_view, + args=[k_cache, kv_cache_shape], + sinfo_args=[R.Tensor(kv_cache_shape, kv_states_dtype)], + ) + ) + v_cache = nn.emit( + relax.Call( + f_kv_cache_view, + args=[v_cache, kv_cache_shape], + sinfo_args=[R.Tensor(kv_cache_shape, kv_states_dtype)], + ) + ) + key_states = nn.emit(reshape(k_cache, kv_states_shape)) + value_states = nn.emit(reshape(v_cache, kv_states_shape)) + if self.num_key_value_heads != self.num_query_heads: + n_rep = self.num_query_heads // self.num_key_value_heads + key_states = nn.emit(relax.op.repeat(key_states, n_rep, axis=2)) + value_states = nn.emit(relax.op.repeat(value_states, n_rep, axis=2)) + + query_states = nn.emit(permute_dims(query_states, [0, 2, 1, 3])) + key_states = nn.emit(permute_dims(key_states, [0, 2, 1, 3])) + value_states = nn.emit(permute_dims(value_states, [0, 2, 1, 3])) + + attn_weights = nn.emit( + matmul(query_states, permute_dims(key_states, [0, 1, 3, 2])) + / 
relax.const(math.sqrt(self.head_dim), query_states.struct_info.dtype) + ) + + tvm.ir.assert_structural_equal( + attention_mask.struct_info.shape.values, + (bsz, tvm.tir.IntImm("int64", 1), q_len, kv_seq_len), + ) + + attn_weights = nn.emit( + maximum( + attn_weights, + relax.const( + tvm.tir.min_value(attn_weights.struct_info.dtype).value, + attn_weights.struct_info.dtype, + ), + ) + ) + attn_weights = nn.emit(relax.op.minimum(attn_weights, attention_mask)) + + # upcast attention to fp32 + if attn_weights.struct_info.dtype != "float32": + attn_weights = astype(attn_weights, "float32") + attn_weights = nn.emit(softmax(attn_weights, axis=-1)) + if attn_weights.struct_info.dtype != query_states.struct_info.dtype: + attn_weights = astype(attn_weights, query_states.struct_info.dtype) + attn_output = nn.emit(matmul(attn_weights, value_states)) + + attn_output = nn.emit(permute_dims(attn_output, [0, 2, 1, 3])) + attn_output = nn.emit( + reshape(attn_output, (bsz, q_len, self.head_dim * self.num_query_heads)) + ) + + attn_output = self.o_proj(attn_output) + return attn_output, ((None, None) if past_key_value is None else past_key_value) + + +class StableLM3bDecoderLayer(nn.Module): + def __init__(self, config: StableLM3bConfig, rotary_embedding: RotaryEmbedding): + self.hidden_size = config.hidden_size + self.self_attn = StableLM3bAttention(config, rotary_embedding) + self.mlp = StableLM3bMLP(config) + self.input_layernorm = LayerNorm( + config.hidden_size, dtype=config.dtype, eps=config.norm_eps + ) + self.post_attention_layernorm = LayerNorm( + config.hidden_size, dtype=config.dtype, eps=config.norm_eps + ) + + def forward( + self, + hidden_states: relax.Expr, + all_seq_len_shape: relax.Expr, + past_key_value: Tuple[relax.Expr], + attention_mask: Optional[relax.Expr] = None, + ) -> Tuple[relax.Expr, Optional[Tuple[relax.Expr, relax.Expr]]]: + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, present_key_value = self.self_attn( + hidden_states=hidden_states, + past_key_value=past_key_value, + attention_mask=attention_mask, + all_seq_len_shape=all_seq_len_shape, + ) + if self.self_attn.num_shards > 1: + residual = nn.emit(residual / R.const(self.self_attn.num_shards, dtype=residual.struct_info.dtype)) + hidden_states = nn.emit(residual + hidden_states) + if self.self_attn.num_shards > 1: + hidden_states = nn.emit(ccl.allreduce(hidden_states, "sum")) + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + if self.mlp.num_shards > 1: + residual = nn.emit(residual / R.const(self.mlp.num_shards, dtype=residual.struct_info.dtype)) + hidden_states = nn.emit(residual + hidden_states) + if self.mlp.num_shards > 1: + hidden_states = nn.emit(ccl.allreduce(hidden_states, "sum")) + return hidden_states, present_key_value + + +def _make_causal_mask(input_ids_shape, dtype, src_len): + from tvm.relax.op import broadcast_to + + bsz, tgt_len = input_ids_shape + + def min_max_triu_te(): + return te.compute( + (tgt_len, tgt_len), + lambda i, j: tvm.tir.Select(j > i, tvm.tir.min_value(dtype), tvm.tir.max_value(dtype)), + name="make_diag_mask_te", + ) + + mask = nn.emit_te(min_max_triu_te) + diag_mask = nn.emit(broadcast_to(mask, (bsz, 1, tgt_len, tgt_len))) + if src_len == tgt_len: + return diag_mask + + def extend_te(x, tgt_len, src_len): + return te.compute( + (bsz, 1, tgt_len, src_len), + lambda b, _, i, j: te.if_then_else( + j < src_len - tgt_len, 
+ tvm.tir.max_value(dtype), + x[b, _, i, j - (src_len - tgt_len)], + ), + name="concat_te", + ) + + return nn.emit_te(extend_te, diag_mask, tgt_len, src_len) + + +class StableLM3bEmbedTokens(nn.Module): + def __init__(self, config: StableLM3bConfig, vocab_size_var: tvm.tir.Var): + self.embed_tokens = Embedding(vocab_size_var, config.hidden_size, dtype=config.dtype) + + def forward(self, input_ids: relax.Expr): + inputs_embeds = self.embed_tokens(input_ids) + return inputs_embeds + + +class StableLM3bEmbedTokensWrapper(nn.Module): + def __init__(self, config: StableLM3bConfig, vocab_size_var: tvm.tir.Var): + # build a wrapper to ensure that the naming of the embed_tokens parameter is consistent + self.model = StableLM3bEmbedTokens(config, vocab_size_var) + + def forward(self, input_ids: relax.Expr): + inputs_embeds = self.model(input_ids) + return inputs_embeds + + +class StableLM3bModell(nn.Module): + def __init__(self, config: StableLM3bConfig, vocab_size_var: tvm.tir.Var, sep_embed: bool = False): + rotary_embedding = RotaryEmbedding( + hidden_size=config.hidden_size, + num_attention_heads=config.num_attention_heads, + position_embedding_base=config.position_embedding_base, + max_sequence_length=config.max_sequence_length, + rotary_pct=0.25, + dtype=config.dtype, + ) + self.num_shards = config.num_shards + self.padding_idx = config.pad_token_id + self.embed_tokens = None + + if not sep_embed: + self.embed_tokens = Embedding(vocab_size_var, config.hidden_size, dtype=config.dtype) + + self.layers = ModuleList( + [StableLM3bDecoderLayer(config, rotary_embedding) for _ in range(config.num_hidden_layers)] + ) + self.norm = LayerNorm(config.hidden_size, dtype=config.dtype, eps=config.norm_eps) + + def _prepare_decoder_attention_mask(self, input_shape, src_len, dtype): + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + combined_attention_mask = None + if isinstance(input_shape[-1], tvm.tir.Var) or input_shape[-1] > 1: + combined_attention_mask = _make_causal_mask(input_shape, dtype, src_len) + else: + # Get src_len from input parameters + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + bsz, tgt_len = input_shape + combined_attention_mask = nn.emit( + relax.op.full( + (bsz, 1, tgt_len, src_len), + relax.const(tvm.tir.max_value(dtype).value, dtype), + dtype, + ) + ) + return combined_attention_mask + + def forward( + self, + inputs: relax.Expr, + all_seq_len_shape: relax.Expr, + past_key_values: relax.Expr, + ): + if self.num_shards > 1: + inputs = nn.emit(ccl.broadcast_from_worker0(inputs)) + if self.embed_tokens: + inputs_embeds = self.embed_tokens(inputs) + else: + inputs_embeds = inputs + # retrieve input_ids + batch_size, seq_length, _ = inputs_embeds.struct_info.shape + seq_length_with_past = all_seq_len_shape.struct_info.values[0] + # embed positions + attention_mask = self._prepare_decoder_attention_mask( + (batch_size, seq_length), + seq_length_with_past, + inputs_embeds.struct_info.dtype, + ) + + hidden_states = inputs_embeds + + # decoder layers + next_decoder_cache = () + + for idx, decoder_layer in enumerate(self.layers): + assert past_key_values is not None + past_key_value = (past_key_values[idx * 2], past_key_values[idx * 2 + 1]) + + hidden_states, key_value_cache = decoder_layer( + hidden_states, + attention_mask=attention_mask, + past_key_value=past_key_value, + all_seq_len_shape=all_seq_len_shape, + ) + next_decoder_cache += key_value_cache + + hidden_states = self.norm(hidden_states) + + assert len(next_decoder_cache) == len(self.layers) 
* 2 + return hidden_states, next_decoder_cache + + +class StableLM3bForCausalLM(nn.Module): + def __init__(self, config: StableLM3bConfig, vocab_size_var: tvm.tir.Var, sep_embed: bool = False): + self.model = StableLM3bModell(config, vocab_size_var, sep_embed) + self.lm_head = Linear(config.hidden_size, vocab_size_var, dtype=config.dtype, bias=False) + + assert config.hidden_size % config.num_attention_heads == 0 + + def forward( + self, + inputs: relax.Expr, + all_seq_len_shape: relax.Expr, + past_key_values: relax.Expr, + ): + hidden_states, key_value_cache = self.model( + inputs=inputs, + all_seq_len_shape=all_seq_len_shape, + past_key_values=past_key_values, + ) + + def te_slicing(x: te.Tensor): + return te.compute( + shape=(1, 1, x.shape[-1]), + fcompute=lambda i, j, k: x[i, x.shape[1] - 1, k], + name="slice", + ) + + logits = self.lm_head(nn.emit_te(te_slicing, hidden_states, primfunc_name_hint="slice")) + if logits.struct_info.dtype != "float32": + logits = nn.emit(relax.op.astype(logits, "float32")) + + return logits, key_value_cache + + +def get_param_quant_kind(name: str, param_info: relax.TensorStructInfo) -> ParamQuantKind: + if "embed_tokens" in name: + return ParamQuantKind.embedding_table + elif "lm_head.weight" in name: + return ParamQuantKind.final_fc_weight + elif param_info.ndim == 2 and name.endswith(".weight"): + return ParamQuantKind.linear_weight + else: + return ParamQuantKind.others + + +def create_embed_func( + bb: relax.BlockBuilder, + param_manager: ParamManager, + config: StableLM3bConfig, + quant_scheme: QuantizationScheme, +) -> None: + func_name = "embed" + + bsz = 1 + seq_len = tvm.tir.Var("n", "int64") + with bb.function(func_name): + model = StableLM3bEmbedTokensWrapper(config, tvm.tir.Var("v", "int64")) + param_manager.register_params(model, func_name, quant_scheme, get_param_quant_kind) + + input_ids = nn.Placeholder((bsz, seq_len), dtype="int32", name="input_ids") + with bb.dataflow(): + inputs_embeds = model(input_ids) + params = [input_ids] + model.parameters() + gv = bb.emit_output(inputs_embeds) + bb.emit_func_output(gv, params) + + mod = bb.get() + gv = mod.get_global_var(func_name) + bb.update_func(gv, mod[gv].with_attr("num_input", 1)) + + +def create_encoding_func( + bb: relax.BlockBuilder, + param_manager: ParamManager, + config: StableLM3bConfig, + quant_scheme: QuantizationScheme, + sep_embed: bool = False, +) -> None: + func_name = "prefill_with_embed" if sep_embed else "prefill" + + bsz = 1 + seq_len = tvm.tir.Var("n", "int64") + all_seq_len = tvm.tir.Var("m", "int64") + hidden_size = config.hidden_size + with bb.function(func_name): + model = StableLM3bForCausalLM(config, tvm.tir.Var("v", "int64"), sep_embed) + param_manager.register_params(model, func_name, quant_scheme, get_param_quant_kind) + + inputs = ( + nn.Placeholder((bsz, seq_len, hidden_size), dtype=config.dtype, name="inputs_embeds") + if sep_embed + else nn.Placeholder((bsz, seq_len), dtype="int32", name="input_ids") + ) + all_seq_len_shape = relax.Var("all_seq_len", relax.ShapeStructInfo((all_seq_len,))) + past_key_values = relax.Var( + "kv_cache", + relax.TupleStructInfo( + [relax.ObjectStructInfo() for _ in range(config.num_hidden_layers * 2)] + ), + ) + with bb.dataflow(): + logits, key_value_cache = model( + inputs, all_seq_len_shape, past_key_values=past_key_values + ) + params = [ + inputs, + all_seq_len_shape, + past_key_values, + ] + model.parameters() + gv = bb.emit_output((logits, relax.Tuple(key_value_cache))) + bb.emit_func_output(gv, params) + + mod = bb.get() + 
gv = mod.get_global_var(func_name) + bb.update_func(gv, mod[gv].with_attr("num_input", 3)) + + +def create_decoding_func( + bb: relax.BlockBuilder, + param_manager: ParamManager, + config: StableLM3bConfig, + quant_scheme: QuantizationScheme, +) -> None: + func_name = "decode" + + bsz = 1 + all_seq_len = tvm.tir.Var("n", "int64") + + with bb.function(func_name): + model = StableLM3bForCausalLM(config, tvm.tir.Var("v", "int64")) + param_manager.register_params(model, func_name, quant_scheme, get_param_quant_kind) + + input_ids = nn.Placeholder((bsz, 1), dtype="int32", name="input_ids") + all_seq_len_shape = relax.Var("all_seq_len", relax.ShapeStructInfo((all_seq_len,))) + past_key_values = relax.Var( + "kv_cache", + relax.TupleStructInfo( + [relax.ObjectStructInfo() for _ in range(config.num_hidden_layers * 2)] + ), + ) + with bb.dataflow(): + logits, key_value_cache = model( + input_ids, all_seq_len_shape, past_key_values=past_key_values + ) + params = [ + input_ids, + all_seq_len_shape, + past_key_values, + ] + model.parameters() + gv = bb.emit_output((logits, relax.Tuple(key_value_cache))) + bb.emit_func_output(gv, params) + + mod = bb.get() + gv = mod.get_global_var(func_name) + bb.update_func(gv, mod[gv].with_attr("num_input", 3)) + + +def create_kv_cache_func(bb: relax.BlockBuilder, config: StableLM3bConfig) -> None: + num_key_value_heads = ( + config.num_attention_heads + if config.num_key_value_heads is None + else config.num_key_value_heads + ) // config.num_shards + init_shape = relax.ShapeExpr( + ( + config.max_sequence_length, + num_key_value_heads, + config.hidden_size // config.num_attention_heads, # head_dim + ) + ) + with bb.function("create_kv_cache", []): + with bb.dataflow(): + zeros = bb.emit(relax.op.zeros(init_shape, config.dtype)) + caches = [] + f_kv_cache_create = relax.extern("vm.builtin.attention_kv_cache_create") + for _ in range(config.num_hidden_layers * 2): + caches.append( + bb.emit( + relax.Call( + f_kv_cache_create, + args=[zeros, init_shape, relax.PrimValue(0)], + sinfo_args=[relax.ObjectStructInfo()], + ) + ) + ) + gv = bb.emit_output(caches) + bb.emit_func_output(gv) + + +def create_softmax_func(bb: relax.BlockBuilder, config: StableLM3bConfig) -> None: + with bb.function("softmax_with_temperature"): + logits = nn.Placeholder((1, 1, tvm.tir.Var("v", "int64")), dtype="float32", name="logits") + temperature = nn.Placeholder((), dtype="float32", name="temperature") + with bb.dataflow(): + div = bb.emit(relax.op.divide(logits, temperature)) + softmax = bb.emit(relax.op.nn.softmax(div, axis=-1)) + gv = bb.emit_output(softmax) + bb.emit_func_output(gv, [logits, temperature]) + + +def emit_shard3d(bb: relax.BlockBuilder) -> None: + from tvm.script import tir as T + + def _emit(dtype: str, global_symbol: str): + @T.prim_func + def shard_3d(a: T.handle, num_shards: T.int64, b: T.handle): + T.func_attr( + { + "tir.noalias": T.bool(True), + "global_symbol": global_symbol, + } + ) + s_0, s_1, s_2 = T.int64(), T.int64(), T.int64() + # pylint: disable=invalid-name + A = T.match_buffer(a, (s_0, s_1, s_2), dtype) + B = T.match_buffer(b, (num_shards, s_0, s_1 // num_shards, s_2), dtype) + # pylint: enable=invalid-name + for j_o, i, j_i, k in T.grid(num_shards, s_0, s_1 // num_shards, s_2): + with T.block("B"): + v_j_o = T.axis.spatial(num_shards, j_o) + v_i = T.axis.spatial(s_0, i) + v_j_i = T.axis.spatial(s_1 // num_shards, j_i) + v_k = T.axis.spatial(s_2, k) + B[v_j_o, v_i, v_j_i, v_k] = A[v_i, v_j_o * (s_1 // num_shards) + v_j_i, v_k] + + bb.add_func(shard_3d, 
global_symbol) + + _emit("float32", "shard3d_fp32") + _emit("float16", "shard3d_fp16") + _emit("uint32", "shard3d_uint32") + + +def get_model(args, hf_config): + model_name = args.model + dtype = args.quantization.model_dtype + max_seq_len = args.max_seq_len + sep_embed = args.sep_embed + + position_embedding_base = 10000 + if "rope_theta" in hf_config: + position_embedding_base = hf_config["rope_theta"] + + config = StableLM3bConfig( + **hf_config, + dtype=dtype, + position_embedding_base=position_embedding_base, + combine_matmul=True, + num_shards=args.num_shards, + build_model_only=args.build_model_only, + convert_weight_only=args.convert_weight_only, + ) + if max_seq_len != -1: + config.max_sequence_length = max_seq_len + + param_manager = ParamManager() + bb = relax.BlockBuilder() + emit_shard3d(bb) + + if sep_embed: + create_embed_func(bb, param_manager, config, args.quantization) + create_encoding_func(bb, param_manager, config, args.quantization, sep_embed) + create_decoding_func(bb, param_manager, config, args.quantization) + create_kv_cache_func(bb, config) + create_softmax_func(bb, config) + create_metadata_func( + bb, + model_name=model_name, + max_window_size=config.max_sequence_length, + stop_tokens=[2], + add_prefix_space=False, + ) + + mod = bb.get() + for gv in mod.functions: + func = mod[gv] + if isinstance(func, relax.Function): + mod[gv] = func.with_attr( + "tir_var_upper_bound", + { + "n": config.max_sequence_length, + "m": config.max_sequence_length, + }, + ) + + if args.build_model_only: + return mod, param_manager, None, config + + def f_convert_pname_fwd(pname: str) -> List[str]: + if not config.combine_matmul: + return [pname] + + qkv_str = "query_key_value_proj" + gate_up_str = "gate_up_proj" + if qkv_str in pname: + return [ + pname.replace(qkv_str, "q_proj"), + pname.replace(qkv_str, "k_proj"), + pname.replace(qkv_str, "v_proj"), + ] + elif gate_up_str in pname: + return [ + pname.replace(gate_up_str, "gate_proj"), + pname.replace(gate_up_str, "up_proj"), + ] + else: + return [pname] + + def f_convert_param_bkwd(torch_pname: str, torch_param): + if not config.combine_matmul: + return [(torch_pname, torch_param.astype(dtype))] + + combined_layers = ["q_proj", "k_proj", "v_proj", "gate_proj", "up_proj"] + if any([name in torch_pname for name in combined_layers]): + return None + return [(torch_pname, torch_param.astype(dtype))] + + def f_compute_relax_param(relax_pname: str, torch_params: List[Any]): + # Expected to enter this function only for the combined linear matmul weights. + # Other weights are supposed to be loaded in `f_convert_param_bkwd` since + # each other relax param has a unique corresponding torch param. + if not config.combine_matmul: + # When matmul combination is not turned on, each relax param has a unique + # corresponding torch param, and this function is not expected to be entered. 
+ raise NotImplementedError( + "Matmul combination is not turned on, and the function " + "is not expected to be entered" + ) + num_shards = args.num_shards + hidden_size = config.hidden_size + head_dim = config.hidden_size // config.num_attention_heads + + if "query_key_value_proj" in relax_pname: + q_heads = config.num_attention_heads + kv_heads = config.num_key_value_heads + if kv_heads is None: + kv_heads = q_heads + q, k, v = torch_params + assert q.shape == (q_heads * head_dim, hidden_size) + assert k.shape == (kv_heads * head_dim, hidden_size) + assert v.shape == (kv_heads * head_dim, hidden_size) + q = q.reshape((num_shards, q_heads // num_shards, head_dim, hidden_size)) + k = k.reshape((num_shards, kv_heads // num_shards, head_dim, hidden_size)) + v = v.reshape((num_shards, kv_heads // num_shards, head_dim, hidden_size)) + qkv = np.concatenate([q, k, v], axis=1) + qkv = qkv.reshape((-1, hidden_size)).astype(dtype) + return qkv + if "gate_up_proj" in relax_pname: + intermediate_size = config.intermediate_size + gate, up = torch_params + gate = gate.reshape((num_shards, intermediate_size // num_shards, hidden_size)) + up = up.reshape((num_shards, intermediate_size // num_shards, hidden_size)) + gate_up = np.concatenate([gate, up], axis=1) + gate_up = gate_up.reshape((-1, hidden_size)).astype(dtype) + return gate_up + raise ValueError("Unexpected param loading") + + param_manager.set_param_loading_func( + args.model_path, + args.use_safetensors, + f_convert_pname_fwd, + f_convert_param_bkwd, + f_compute_relax_param, + ) + + param_list = [None] * param_manager.nparam_to_load + + return mod, param_manager, param_list, config diff --git a/mlc_llm/utils.py b/mlc_llm/utils.py index 17329c19d4..9d8751e5d6 100644 --- a/mlc_llm/utils.py +++ b/mlc_llm/utils.py @@ -13,7 +13,7 @@ from .transform import ReorderTransformFunc supported_model_types = set( - ["llama", "gpt_neox", "gpt_bigcode", "minigpt", "moss", "rwkv", "gptj", "chatglm", "mistral"] + ["llama", "gpt_neox", "gpt_bigcode", "minigpt", "moss", "rwkv", "gptj", "chatglm", "mistral", "stablelm_epoch"] ) @@ -64,6 +64,7 @@ def argparse_postproc_common(args: argparse.Namespace) -> None: "codellama": "codellama_completion", "vicuna-": "vicuna_v1.1", "dolly-": "dolly", + "stablelm-3b-": "stablelm-3b", "stablelm-": "stablelm", "redpajama-": "redpajama_chat", "minigpt": "minigpt", From a032d40bfc4bd4a2338e5f7b44e49c71a50e103a Mon Sep 17 00:00:00 2001 From: Charlie Ruan <53290280+CharlieFRuan@users.noreply.github.com> Date: Mon, 9 Oct 2023 18:03:24 -0400 Subject: [PATCH 010/116] [Docs] Iterate model prebuilts docs (#1043) * Iterate model prebuilts docs * small fix --- README.md | 67 ++++++---------- .../distribute_compiled_models.rst | 2 +- docs/prebuilt_models.rst | 80 ++++++++----------- 3 files changed, 59 insertions(+), 90 deletions(-) diff --git a/README.md b/README.md index 3354a94f8b..e4379df567 100644 --- a/README.md +++ b/README.md @@ -52,8 +52,26 @@ Machine Learning Compilation for Large Language Models (MLC LLM) is a high-perfo -**Prebuilt model support.** MLC LLM supports a wide range of model architectures and variants. We have the following prebuilts which you can -use off-the-shelf. Visit [Prebuilt Models](https://llm.mlc.ai/docs/prebuilt_models.html) to see the full list. +## News + +* [08/25/2023] CodeLlama support is up. +* [08/14/2023] [[Post]](https://blog.mlc.ai/2023/08/09/GPU-Accelerated-LLM-on-Orange-Pi) Mali GPU support is up on Orange Pi. 
+* [08/09/2023] [[Post]](https://blog.mlc.ai/2023/08/09/Making-AMD-GPUs-competitive-for-LLM-inference) ROCm backend is mature to use. +* [08/02/2023] [Dockerfile](https://github.com/mlc-ai/llm-perf-bench/) is released for CUDA performance benchmarking. +* [07/19/2023] Support for Llama2-7B/13B/70B is up. +* [05/22/2023] [[Post]](https://blog.mlc.ai/2023/05/22/bringing-open-large-language-models-to-consumer-devices) RedPajama support is up. +* [05/08/2023] [[Post]](https://blog.mlc.ai/2023/05/08/bringing-hardware-accelerated-language-models-to-android-devices) MLC LLM is now available on Android. +* [05/01/2023] [[Post]](https://blog.mlc.ai/2023/05/01/bringing-accelerated-llm-to-consumer-hardware) MLC LLM is released with Metal, Vulkan and CUDA backends. +* [04/14/2023] [WebLLM](https://github.com/mlc-ai/web-llm) is released prior to MLC LLM with WebGPU and WebAssembly backend. + +## Getting Started + +Please visit our [documentation](https://llm.mlc.ai/docs/index.html#getting-started) for detailed instructions. + +## Prebuilt model support + +MLC LLM supports a wide range of model architectures and variants. We have the following prebuilts which you can +use off-the-shelf. Visit [Prebuilt Models](https://llm.mlc.ai/docs/prebuilt_models.html) to see the full list, and [Compile Models via MLC](https://llm.mlc.ai/docs/compilation/compile_models.html) to see how to use models not on this list. @@ -64,29 +82,8 @@ use off-the-shelf. Visit [Prebuilt Models](https://llm.mlc.ai/docs/prebuilt_mode - - - - - - - - - - - - - - - - - - - - - - - + + @@ -112,25 +109,13 @@ use off-the-shelf. Visit [Prebuilt Models](https://llm.mlc.ai/docs/prebuilt_mode + + + +
LlamaLlama-2
Code Llama
Vicuna
WizardLM
WizardMath
OpenOrca Platypus2
FlagAlpha Llama-2 Chinese
georgesung Llama-2 UncensoredLlamaLlama-2, Code Llama, Vicuna, WizardLM, WizardMath, OpenOrca Platypus2, FlagAlpha Llama-2 Chinese, georgesung Llama-2 Uncensored
GPT-NeoXChatGLM
StableLM
-## News - -* [08/25/2023] CodeLlama support is up. -* [08/14/2023] [[Post]](https://blog.mlc.ai/2023/08/09/GPU-Accelerated-LLM-on-Orange-Pi) Mali GPU support is up on Orange Pi. -* [08/09/2023] [[Post]](https://blog.mlc.ai/2023/08/09/Making-AMD-GPUs-competitive-for-LLM-inference) ROCm backend is mature to use. -* [08/02/2023] [Dockerfile](https://github.com/mlc-ai/llm-perf-bench/) is released for CUDA performance benchmarking. -* [07/19/2023] Support for Llama2-7B/13B/70B is up. -* [05/22/2023] [[Post]](https://blog.mlc.ai/2023/05/22/bringing-open-large-language-models-to-consumer-devices) RedPajama support is up. -* [05/08/2023] [[Post]](https://blog.mlc.ai/2023/05/08/bringing-hardware-accelerated-language-models-to-android-devices) MLC LLM is now available on Android. -* [05/01/2023] [[Post]](https://blog.mlc.ai/2023/05/01/bringing-accelerated-llm-to-consumer-hardware) MLC LLM is released with Metal, Vulkan and CUDA backends. -* [04/14/2023] [WebLLM](https://github.com/mlc-ai/web-llm) is released prior to MLC LLM with WebGPU and WebAssembly backend. - -## Getting Started - -Please visit our [this page](https://llm.mlc.ai/docs/index.html#getting-started) for detailed instructions. - ## Universal Deployment APIs MLC LLM provides multiple sets of APIs across platforms and environments. These include diff --git a/docs/compilation/distribute_compiled_models.rst b/docs/compilation/distribute_compiled_models.rst index b6f31f9386..96ac5a09a3 100644 --- a/docs/compilation/distribute_compiled_models.rst +++ b/docs/compilation/distribute_compiled_models.rst @@ -161,7 +161,7 @@ Download the Distributed Models and Run in iOS App For iOS app, model libraries are statically packed into the app at the time of app building. Therefore, the iOS app supports running any models whose model libraries are integrated into the app. -You can check the :ref:`list of supported model libraries `. +You can check the :ref:`list of supported model libraries `. To download and run the compiled RedPajama-3B instruct model on iPhone, we need to reuse the integrated ``RedPajama-INCITE-Chat-3B-v1-q4f16_1`` model library. Please revisit :ref:`distribute-model-step3-specify-model-lib` and make sure the ``model_lib`` field of `mlc-chat-config.json` is set to ``RedPajama-INCITE-Chat-3B-v1-q4f16_1``. 
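The doc hunk above boils down to verifying a single field of ``mlc-chat-config.json`` in the distributed weight directory. A minimal Python sketch of that check is shown below; it is illustrative only, and the config path used here is an assumption that should be replaced with the actual directory of your distributed weights.

.. code-block:: python

   # Sketch: confirm the distributed weights reuse the integrated
   # RedPajama-INCITE-Chat-3B-v1-q4f16_1 model library before running in the iOS app.
   # The path below is an assumption for illustration; adjust it to your weight directory.
   import json
   from pathlib import Path

   config_path = Path("dist/RedPajama-INCITE-Instruct-3B-v1-q4f16_1/params/mlc-chat-config.json")
   config = json.loads(config_path.read_text())

   expected_lib = "RedPajama-INCITE-Chat-3B-v1-q4f16_1"
   if config.get("model_lib") != expected_lib:
       # Point the config at the model library that is statically packed into the app.
       config["model_lib"] = expected_lib
       config_path.write_text(json.dumps(config, indent=2))
   print("model_lib =", config["model_lib"])
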
diff --git a/docs/prebuilt_models.rst b/docs/prebuilt_models.rst index 061a1e171b..8a9f74253b 100644 --- a/docs/prebuilt_models.rst +++ b/docs/prebuilt_models.rst @@ -256,16 +256,13 @@ MLC-LLM supports the following model architectures: :widths: 10 10 15 15 :header-rows: 1 - * - Code - - Architecture - - Variants w/ MLC prebuilts - - Variants w/o MLC prebuilts - * - ``llama`` - - LLaMa - - * :ref:`Prebuilt library table ` - * `Official link `__ - * `Relax Code `__ + * - Model Architecture + - Support + - Available MLC Prebuilts + - Unavailable in MLC Prebuilts + * - `LLaMA `__ + - * :ref:`Prebuilt Model Library ` + * `MLC Implementation `__ - * :ref:`Llama-2 ` * :ref:`Code Llama ` * :ref:`Vicuna ` @@ -280,64 +277,51 @@ MLC-LLM supports the following model architectures: * `Gorilla `__ * `YuLan-Chat `__ * `WizardCoder (new) `__ - * - ``gpt-neox`` - - GPT-NeoX - - * :ref:`Prebuilt library table ` - * `Official link `__ - * `Relax Code `__ + * - `GPT-NeoX `__ + - * :ref:`Prebuilt Model Library ` + * `MLC Implementation `__ - * :ref:`RedPajama ` - * `Dolly `__ * `Pythia `__ * `StableCode `__ - * - ``gptj`` - - GPT-J - - * Prebuilt not compiled yet - * `Official link `__ - * `Relax Code `__ + * - `GPT-J `__ + - * Prebuilt not compiled yet + * `MLC Implementation `__ - - * `MOSS `__ - * - ``rwkv`` - - RWKV - - * :ref:`Prebuilt library table ` - * `Official link `__ - * `Relax Code `__ + * - `RWKV `__ + - * :ref:`Prebuilt Model Library ` + * `MLC Implementation `__ - * :ref:`RWKV-raven ` - - * - ``minigpt`` - - MiniGPT - - * Prebuilt not compiled yet - * `Official link `__ - * `Relax Code `__ + * - `MiniGPT `__ + - * Prebuilt not compiled yet + * `MLC Implementation `__ - - * `MiniGPT-4 `__ - * - ``gpt_bigcode`` - - GPTBigCode - - * :ref:`Prebuilt library table ` - * `Official link `__ - * `Relax Code `__ + * - `GPTBigCode `__ + - * :ref:`Prebuilt Model Library ` + * `MLC Implementation `__ - * :ref:`WizardCoder (old) ` - * `StarCoder `__ * `SantaCoder `__ - * - ``chatglm`` - - ChatGLM - - * Prebuilt not compiled yet - * `Official link `__ - * `Relax Code `__ + * - `ChatGLM `__ + - * Prebuilt not compiled yet + * `MLC Implementation `__ - - * `ChatGLM2 `__ * `CodeGeeX2 `__ + * - `StableLM `__ + - * Prebuilt not compiled yet + * `MLC Implementation `__ + - + - * `StableLM `__ -If the model variant you are interested in is in one of these model architectures we support (but we have not provided the prebuilt weights yet), you can check the :doc:`model compilation page ` on how to compile your own models. Note that you only need to compile the weights for your model variant and reuse the library file found in Level 2 tables. +If the model variant you are interested in uses one of these model architectures we support (but we have not provided the prebuilt weights yet), you can check out :doc:`/compilation/compile_models` on how to compile your own models. Afterwards, you may follow :doc:`/compilation/distribute_compiled_models` to upload your prebuilt weights to hugging face, and submit a PR that adds an entry to this page, contributing to the community. For models structured in an architecture we have not supported yet, you could: -- Either `create a new issue `_ to request a new model architecture. +- Either `create a [Model Request] issue `__ which automatically shows up on our `Model Request Tracking Board `__. - Or follow our tutorial :doc:`Define New Models `, which introduces how to bring a new model architecture to MLC-LLM. 
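As a companion to the prebuilt-model documentation revised above, the sketch below shows how a prebuilt model is typically consumed from Python through the ``mlc_chat`` package. It is a sketch under stated assumptions, not part of the patch: the model identifier is an assumption, the corresponding weights and model library are assumed to be already downloaded, and the exact ``ChatModule`` constructor and method signatures may differ between versions.

.. code-block:: python

   # Sketch: load one of the prebuilt models with the Python chat API.
   # The model id below is an assumption; substitute any prebuilt whose
   # weights and model library have been downloaded locally.
   from mlc_chat import ChatModule

   cm = ChatModule(model="Llama-2-7b-chat-hf-q4f16_1")
   output = cm.generate(prompt="What is the capital of Canada?")
   print(output)
   # Prefill/decode speed from the most recent generation.
   print(cm.stats())
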
From a58605fca3d4601899d64bf29c208306b9443a37 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Mon, 9 Oct 2023 15:05:34 -0700 Subject: [PATCH 011/116] Update README.md --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index e4379df567..97395eb769 100644 --- a/README.md +++ b/README.md @@ -54,6 +54,7 @@ Machine Learning Compilation for Large Language Models (MLC LLM) is a high-perfo ## News +* [09/02/2023] Prebuilt ROCm 5.7 and CUDA 12.2 package is [available](https://llm.mlc.ai/docs/install/tvm.html#option-1-prebuilt-package). * [08/25/2023] CodeLlama support is up. * [08/14/2023] [[Post]](https://blog.mlc.ai/2023/08/09/GPU-Accelerated-LLM-on-Orange-Pi) Mali GPU support is up on Orange Pi. * [08/09/2023] [[Post]](https://blog.mlc.ai/2023/08/09/Making-AMD-GPUs-competitive-for-LLM-inference) ROCm backend is mature to use. From bdd9d9b94bd75052eaa3f69076e9a10f9ed29471 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Mon, 9 Oct 2023 19:08:14 -0400 Subject: [PATCH 012/116] [CPP] Separate common utils out from llm_chat.cc (#1044) This PR separates out the tokenizer creation function, the random number generator out from `llm_chat.cc` as a preparation step for batching inference support, since these functions/modules are also used in the same way in batching inference. --- cpp/base.h | 16 ++++++++ cpp/image_embed.h | 12 +----- cpp/llm_chat.cc | 97 ++++++++--------------------------------------- cpp/llm_chat.h | 12 +----- cpp/random.h | 37 ++++++++++++++++++ cpp/support.h | 31 +++++++++++++++ cpp/tokenizers.cc | 61 +++++++++++++++++++++++++++++ cpp/tokenizers.h | 24 ++++++++++++ 8 files changed, 186 insertions(+), 104 deletions(-) create mode 100644 cpp/base.h create mode 100644 cpp/random.h create mode 100644 cpp/support.h create mode 100644 cpp/tokenizers.cc create mode 100644 cpp/tokenizers.h diff --git a/cpp/base.h b/cpp/base.h new file mode 100644 index 0000000000..0cc6777dd4 --- /dev/null +++ b/cpp/base.h @@ -0,0 +1,16 @@ +/*! 
+ * Copyright (c) 2023 by Contributors + * \file base.h + */ + +#ifndef MLC_LLM_DLL +#ifdef _WIN32 +#ifdef MLC_LLM_EXPORTS +#define MLC_LLM_DLL __declspec(dllexport) +#else +#define MLC_LLM_DLL __declspec(dllimport) +#endif +#else +#define MLC_LLM_DLL __attribute__((visibility("default"))) +#endif +#endif diff --git a/cpp/image_embed.h b/cpp/image_embed.h index 87b862242c..e0e21da686 100644 --- a/cpp/image_embed.h +++ b/cpp/image_embed.h @@ -6,17 +6,7 @@ #include #include -#ifndef MLC_LLM_DLL -#ifdef _WIN32 -#ifdef MLC_LLM_EXPORTS -#define MLC_LLM_DLL __declspec(dllexport) -#else -#define MLC_LLM_DLL __declspec(dllimport) -#endif -#else -#define MLC_LLM_DLL __attribute__((visibility("default"))) -#endif -#endif +#include "base.h" namespace mlc { namespace llm { diff --git a/cpp/llm_chat.cc b/cpp/llm_chat.cc index d12b0fbd92..25f68203f2 100644 --- a/cpp/llm_chat.cc +++ b/cpp/llm_chat.cc @@ -32,6 +32,9 @@ #include #include "conversation.h" +#include "random.h" +#include "support.h" +#include "tokenizers.h" namespace mlc { namespace llm { @@ -39,62 +42,6 @@ namespace llm { using tvm::Device; using namespace tvm::runtime; namespace { -//---------------------------- -// Tokenizers -//---------------------------- -using tokenizers::Tokenizer; - -std::string LoadBytesFromFile(const std::string& path) { - std::ifstream fs(path, std::ios::in | std::ios::binary); - ICHECK(!fs.fail()) << "Cannot open " << path; - std::string data; - fs.seekg(0, std::ios::end); - size_t size = static_cast(fs.tellg()); - fs.seekg(0, std::ios::beg); - data.resize(size); - fs.read(data.data(), size); - return data; -} - -std::unique_ptr TokenizerFromPath(const std::string& _path) { - std::filesystem::path path(_path); - std::filesystem::path sentencepiece; - std::filesystem::path huggingface; - std::filesystem::path rwkvworld; - CHECK(std::filesystem::exists(path)) << "Cannot find tokenizer via path: " << _path; - if (std::filesystem::is_directory(path)) { - sentencepiece = path / "tokenizer.model"; - huggingface = path / "tokenizer.json"; - rwkvworld = path / "tokenizer_model"; - // Check ByteLevelBPE - { - std::filesystem::path merges_path = path / "merges.txt"; - std::filesystem::path vocab_path = path / "vocab.json"; - std::filesystem::path added_tokens_path = path / "added_tokens.json"; - if (std::filesystem::exists(merges_path) && std::filesystem::exists(vocab_path) && - std::filesystem::exists(added_tokens_path)) { - std::string vocab = LoadBytesFromFile(vocab_path.string()); - std::string merges = LoadBytesFromFile(merges_path.string()); - std::string added_tokens = LoadBytesFromFile(added_tokens_path.string()); - return Tokenizer::FromBlobByteLevelBPE(vocab, merges, added_tokens); - } - } - } else { - sentencepiece = path.parent_path() / "tokenizer.model"; - huggingface = path.parent_path() / "tokenizer.json"; - rwkvworld = path.parent_path() / "tokenizer_model"; - } - if (std::filesystem::exists(sentencepiece)) { - return Tokenizer::FromBlobSentencePiece(LoadBytesFromFile(sentencepiece.string())); - } - if (std::filesystem::exists(huggingface)) { - return Tokenizer::FromBlobJSON(LoadBytesFromFile(huggingface.string())); - } - if (std::filesystem::exists(rwkvworld)) { - return Tokenizer::FromBlobRWKVWorld(rwkvworld.string()); - } - LOG(FATAL) << "Cannot find any tokenizer under: " << _path; -} //------------------------------ // support functions @@ -315,23 +262,6 @@ struct FunctionTable { PackedFunc fkvcache_array_popn_; }; -class RandomGenerator { - private: - std::mt19937 gen; - 
std::uniform_real_distribution<> dis; - - RandomGenerator(int seed) : gen(seed), dis(0.0, 1.0) {} - - public: - static RandomGenerator& GetInstance(int seed = std::random_device{}()) { - static RandomGenerator instance(seed); - return instance; - } - - double GetRandomNumber() { return dis(gen); } - - void SetSeed(int seed) { gen.seed(seed); } -}; } // namespace //------------------------------ @@ -708,9 +638,10 @@ class LLMChat { return view; } - std::vector PrepareBeforeEmbedding(std::string inp, bool append_conversation = true, - PlaceInPrompt place_in_prompt = PlaceInPrompt::kAll, - picojson::object generation_config = picojson::object()) { + std::vector PrepareBeforeEmbedding( + std::string inp, bool append_conversation = true, + PlaceInPrompt place_in_prompt = PlaceInPrompt::kAll, + picojson::object generation_config = picojson::object()) { if (conversation_.separator_style == SeparatorStyle::kLM || conversation_.separator_style == SeparatorStyle::kCodeCompletion) { this->ResetChat(); @@ -742,7 +673,7 @@ class LLMChat { String generation_config_str = "") { // process generation settings picojson::object generation_config = picojson::object(); - if(!generation_config_str.empty()) { + if (!generation_config_str.empty()) { picojson::value generation_config_json; picojson::parse(generation_config_json, generation_config_str); generation_config = generation_config_json.get(); @@ -778,7 +709,8 @@ class LLMChat { * \param embedding The embedding to prefill with. * \param decode_next_token Whether to decode next token. */ - void PrefillWithEmbedStep(NDArray embedding, bool decode_next_token = true, String generation_config_str = "") { + void PrefillWithEmbedStep(NDArray embedding, bool decode_next_token = true, + String generation_config_str = "") { if (ft_.use_disco) { LOG(FATAL) << "NotImplementedError: Distributed inference is not supported for this model"; throw; @@ -799,7 +731,7 @@ class LLMChat { // process generation settings picojson::object generation_config = picojson::object(); - if(!generation_config_str.empty()) { + if (!generation_config_str.empty()) { picojson::value generation_config_json; picojson::parse(generation_config_json, generation_config_str); generation_config = generation_config_json.get(); @@ -830,14 +762,15 @@ class LLMChat { if (ft_.use_disco) { LOG(FATAL) << "NotImplementedError: Distributed inference is not supported for this model"; } - NDArray embedding = Downcast(EmbedStep(inp, append_conversation, place_in_prompt, generation_config_str)); + NDArray embedding = Downcast( + EmbedStep(inp, append_conversation, place_in_prompt, generation_config_str)); PrefillWithEmbedStep(embedding, decode_next_token, generation_config_str); return; } // process generation settings picojson::object generation_config = picojson::object(); - if(!generation_config_str.empty()) { + if (!generation_config_str.empty()) { picojson::value generation_config_json; picojson::parse(generation_config_json, generation_config_str); generation_config = generation_config_json.get(); @@ -876,7 +809,7 @@ class LLMChat { void DecodeStep(String generation_config_str = "") { // process generation settings picojson::object generation_config = picojson::object(); - if(!generation_config_str.empty()) { + if (!generation_config_str.empty()) { picojson::value generation_config_json; picojson::parse(generation_config_json, generation_config_str); generation_config = generation_config_json.get(); diff --git a/cpp/llm_chat.h b/cpp/llm_chat.h index 1839e8cb4d..39408d1685 100644 --- a/cpp/llm_chat.h +++ 
b/cpp/llm_chat.h @@ -6,17 +6,7 @@ #include #include -#ifndef MLC_LLM_DLL -#ifdef _WIN32 -#ifdef MLC_LLM_EXPORTS -#define MLC_LLM_DLL __declspec(dllexport) -#else -#define MLC_LLM_DLL __declspec(dllimport) -#endif -#else -#define MLC_LLM_DLL __attribute__((visibility("default"))) -#endif -#endif +#include "base.h" namespace mlc { namespace llm { diff --git a/cpp/random.h b/cpp/random.h new file mode 100644 index 0000000000..e6331a9699 --- /dev/null +++ b/cpp/random.h @@ -0,0 +1,37 @@ +/*! + * Copyright (c) 2023 by Contributors + * \file random.h + * \brief Header of random number generator. + */ + +#ifndef MLC_LLM_RANDOM_H_ +#define MLC_LLM_RANDOM_H_ + +#include + +namespace mlc { +namespace llm { + +// Random number generator +class RandomGenerator { + private: + std::mt19937 gen; + std::uniform_real_distribution<> dis; + + RandomGenerator(int seed) : gen(seed), dis(0.0, 1.0) {} + + public: + static RandomGenerator& GetInstance(int seed = std::random_device{}()) { + static RandomGenerator instance(seed); + return instance; + } + + double GetRandomNumber() { return dis(gen); } + + void SetSeed(int seed) { gen.seed(seed); } +}; + +} // namespace llm +} // namespace mlc + +#endif // MLC_LLM_RANDOM_H_ diff --git a/cpp/support.h b/cpp/support.h new file mode 100644 index 0000000000..20eadbbd0a --- /dev/null +++ b/cpp/support.h @@ -0,0 +1,31 @@ +/*! + * Copyright (c) 2023 by Contributors + * \file support.h + * \brief Header of utilities. + */ + +#ifndef MLC_LLM_COMMON_H_ +#define MLC_LLM_COMMON_H_ + +#include +#include + +namespace mlc { +namespace llm { + +inline std::string LoadBytesFromFile(const std::string& path) { + std::ifstream fs(path, std::ios::in | std::ios::binary); + ICHECK(!fs.fail()) << "Cannot open " << path; + std::string data; + fs.seekg(0, std::ios::end); + size_t size = static_cast(fs.tellg()); + fs.seekg(0, std::ios::beg); + data.resize(size); + fs.read(data.data(), size); + return data; +} + +} // namespace llm +} // namespace mlc + +#endif // MLC_LLM_COMMON_H_ diff --git a/cpp/tokenizers.cc b/cpp/tokenizers.cc new file mode 100644 index 0000000000..8d38dd9572 --- /dev/null +++ b/cpp/tokenizers.cc @@ -0,0 +1,61 @@ +/*! 
+ * Copyright (c) 2023 by Contributors + * \file tokenizer.cc + */ + +#include "tokenizers.h" + +#include +#include + +#include +#include +#include + +#include "support.h" + +namespace mlc { +namespace llm { + +std::unique_ptr TokenizerFromPath(const std::string& _path) { + std::filesystem::path path(_path); + std::filesystem::path sentencepiece; + std::filesystem::path huggingface; + std::filesystem::path rwkvworld; + CHECK(std::filesystem::exists(path)) << "Cannot find tokenizer via path: " << _path; + if (std::filesystem::is_directory(path)) { + sentencepiece = path / "tokenizer.model"; + huggingface = path / "tokenizer.json"; + rwkvworld = path / "tokenizer_model"; + // Check ByteLevelBPE + { + std::filesystem::path merges_path = path / "merges.txt"; + std::filesystem::path vocab_path = path / "vocab.json"; + std::filesystem::path added_tokens_path = path / "added_tokens.json"; + if (std::filesystem::exists(merges_path) && std::filesystem::exists(vocab_path) && + std::filesystem::exists(added_tokens_path)) { + std::string vocab = LoadBytesFromFile(vocab_path.string()); + std::string merges = LoadBytesFromFile(merges_path.string()); + std::string added_tokens = LoadBytesFromFile(added_tokens_path.string()); + return Tokenizer::FromBlobByteLevelBPE(vocab, merges, added_tokens); + } + } + } else { + sentencepiece = path.parent_path() / "tokenizer.model"; + huggingface = path.parent_path() / "tokenizer.json"; + rwkvworld = path.parent_path() / "tokenizer_model"; + } + if (std::filesystem::exists(sentencepiece)) { + return Tokenizer::FromBlobSentencePiece(LoadBytesFromFile(sentencepiece.string())); + } + if (std::filesystem::exists(huggingface)) { + return Tokenizer::FromBlobJSON(LoadBytesFromFile(huggingface.string())); + } + if (std::filesystem::exists(rwkvworld)) { + return Tokenizer::FromBlobRWKVWorld(rwkvworld.string()); + } + LOG(FATAL) << "Cannot find any tokenizer under: " << _path; +} + +} // namespace llm +} // namespace mlc diff --git a/cpp/tokenizers.h b/cpp/tokenizers.h new file mode 100644 index 0000000000..f44f828e97 --- /dev/null +++ b/cpp/tokenizers.h @@ -0,0 +1,24 @@ +/*! + * Copyright (c) 2023 by Contributors + * \file tokenizers.h + * \brief Header of tokenizer related functions. + */ + +#ifndef MLC_LLM_TOKENIZER_H_ +#define MLC_LLM_TOKENIZER_H_ + +#include + +#include "base.h" + +namespace mlc { +namespace llm { + +using tokenizers::Tokenizer; + +MLC_LLM_DLL std::unique_ptr TokenizerFromPath(const std::string& _path); + +} // namespace llm +} // namespace mlc + +#endif // MLC_LLM_TOKENIZER_H_ From 20131fbce9b808cd3b7ebfe95e3cd6654de9b619 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Mon, 9 Oct 2023 16:53:56 -0700 Subject: [PATCH 013/116] Update README.md (#1045) Update README.md --- README.md | 29 +++++++++++++---------------- 1 file changed, 13 insertions(+), 16 deletions(-) diff --git a/README.md b/README.md index 97395eb769..4ef145d4cd 100644 --- a/README.md +++ b/README.md @@ -4,18 +4,18 @@ [Documentation](https://llm.mlc.ai/docs) | [Blog](https://blog.mlc.ai/) | [Discord][discord-url] -Machine Learning Compilation for Large Language Models (MLC LLM) is a high-performance universal deployment solution that allows native deployment of any large language models with native APIs with compiler acceleration. The mission of this project is to enable everyone to develop, optimize and deploy AI models natively on everyone's devices with ML compilation techniques. 
+**M**achine **L**earning **C**ompilation for **L**arge **L**anguage **M**odels (MLC LLM) is a high-performance universal deployment solution that allows native deployment of any large language models with native APIs with compiler acceleration. The mission of this project is to enable everyone to develop, optimize and deploy AI models natively on everyone's devices with ML compilation techniques. **Universal deployment.** MLC LLM supports the following platforms and hardware: - - - - - + + + + + @@ -28,21 +28,18 @@ Machine Learning Compilation for Large Language Models (MLC LLM) is a high-perfo - + - + - - - - + - + @@ -69,7 +66,7 @@ Machine Learning Compilation for Large Language Models (MLC LLM) is a high-perfo Please visit our [documentation](https://llm.mlc.ai/docs/index.html#getting-started) for detailed instructions. -## Prebuilt model support +## Model Support MLC LLM supports a wide range of model architectures and variants. We have the following prebuilts which you can use off-the-shelf. Visit [Prebuilt Models](https://llm.mlc.ai/docs/prebuilt_models.html) to see the full list, and [Compile Models via MLC](https://llm.mlc.ai/docs/compilation/compile_models.html) to see how to use models not on this list. @@ -77,8 +74,8 @@ use off-the-shelf. Visit [Prebuilt Models](https://llm.mlc.ai/docs/prebuilt_mode
AMD GPUNVIDIA GPUApple M1/M2 GPUIntel GPU AMD GPUNVIDIA GPUApple GPUIntel GPU
macOS✅ Metal✅ Metal (dGPU) N/A ✅ Metal✅ Metal✅ Metal (iGPU)
Web Browser✅ WebGPU✅ WebGPU✅ WebGPU✅ WebGPU✅ WebGPU and WASM
iOS / iPadOS✅ Metal on Apple M1/M2 GPU✅ Metal on Apple A-series GPU
Android
- - + + From 1e6fb11658356b1c394871872d480eef9f2c9197 Mon Sep 17 00:00:00 2001 From: Denise Kutnick Date: Wed, 11 Oct 2023 00:06:46 -0700 Subject: [PATCH 014/116] add verbose stats to mlc-chat REST API (#1049) * add verbose stats to mlc-chat REST API * update docs --- docs/deploy/rest.rst | 3 +++ python/mlc_chat/rest.py | 6 ++++++ 2 files changed, 9 insertions(+) diff --git a/docs/deploy/rest.rst b/docs/deploy/rest.rst index d5955190e9..338f8de56c 100644 --- a/docs/deploy/rest.rst +++ b/docs/deploy/rest.rst @@ -89,6 +89,9 @@ The REST API provides the following endpoints: Get the latest runtime stats (encode/decode speed). +.. http:get:: /verbose_stats + + Get the verbose runtime stats (encode/decode speed, total runtime). Use REST API in your own program -------------------------------- diff --git a/python/mlc_chat/rest.py b/python/mlc_chat/rest.py index f5038f5211..1703f97826 100644 --- a/python/mlc_chat/rest.py +++ b/python/mlc_chat/rest.py @@ -327,6 +327,12 @@ async def read_stats(): """ return session["chat_mod"].stats() +@app.get("/verbose_stats") +async def read_stats_verbose(): + """ + Get the verbose runtime stats. + """ + return session["chat_mod"].stats(verbose=True) ARGS = convert_args_to_argparser().parse_args() if __name__ == "__main__": From b9179cfdf02e041be15871b36e74400ab9001921 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Thu, 12 Oct 2023 12:15:44 -0500 Subject: [PATCH 015/116] [Transform] Apply split_rotary optimization on prefill (#1033) * [Transform] Apply split_rotary optimization on prefill Prior to this commit, the `transform.fuse_split_rotary_embedding` function was only applicable to the `decode` function of a Llama-type model. This was due to the sequence length being restricted to one, both in the pattern-match rule and in the `split_rotary` function, and the function being restricted to operate only on the `decode` function. This commit updates the `transform.fuse_split_rotary_embedding` pass to be a `tvm.ir.transform.Pass`, operating on all applicable matched in the `IRModule`. The `split_rotary` function is now produced as a fully-generic function, with static parameters substituted in afterwards. At this stage, the sequence length is retained as a dynamic parameter, such that it can be used by the `prefill` function. 
* Avoid multiple kernel launches for split_rotary --- mlc_llm/core.py | 3 +- .../transform/fuse_split_rotary_embedding.py | 460 ++++++++++-------- 2 files changed, 260 insertions(+), 203 deletions(-) diff --git a/mlc_llm/core.py b/mlc_llm/core.py index ddf93bf09a..b9280274c9 100644 --- a/mlc_llm/core.py +++ b/mlc_llm/core.py @@ -402,12 +402,11 @@ def mod_transform_before_build( if max_seq_len: num_key_value_heads = config.get_num_key_value_heads() mod = fuse_split_rotary_embedding( - mod, config.num_attention_heads // args.num_shards, num_key_value_heads // args.num_shards, config.hidden_size // args.num_shards, config.position_embedding_base, - ) + )(mod) if args.target_kind == "cuda": patterns = [] diff --git a/mlc_llm/transform/fuse_split_rotary_embedding.py b/mlc_llm/transform/fuse_split_rotary_embedding.py index 4ecc843f4a..d04f37ee69 100644 --- a/mlc_llm/transform/fuse_split_rotary_embedding.py +++ b/mlc_llm/transform/fuse_split_rotary_embedding.py @@ -1,5 +1,5 @@ +import tvm from tvm import relax -from tvm.script import tir as T from tvm.relax.dpl import ( PatternContext, is_op, @@ -10,234 +10,292 @@ TuplePattern, is_shape, ) -from tvm.script import relax as R +from tvm.script import relax as R, tir as T -def get_split_rotary(num_attention_heads, head_dim, position_embedding_base): - hidden_size = num_attention_heads * head_dim +def get_dynamic_split_rotary(): + """Implementation of R.split(rotary_embedding(fused_qkv)) - @T.prim_func + Implementation is generic over the number of query heads, + key/value heads, sequence length, head dimension, and position + embedding base. These parameters can be replaced with static + values using `PrimFunc.specialize`. + """ + + @T.prim_func(private=True) def split_rotary( - qkv: T.handle, - split_0: T.handle, - split_1: T.handle, - split_2: T.handle, - n: T.int64, + fused_qkv_handle: T.handle, + embedded_query_handle: T.handle, + embedded_key_handle: T.handle, + value_handle: T.handle, + rotary_offset: T.int64, + batch_size: T.int64, + seq_len: T.int64, + num_query_heads: T.int64, + num_kv_heads: T.int64, + head_dim: T.int64, + position_embedding_base: T.float32, ): - A = T.match_buffer(qkv, [1, 1, hidden_size * 3], dtype="float16") - T_split = T.match_buffer(split_0, [1, 1, hidden_size], dtype="float16") - T_split_1 = T.match_buffer(split_1, [1, 1, hidden_size], dtype="float16") - T_split_2 = T.match_buffer(split_2, [1, 1, hidden_size], dtype="float16") + Fused_QKV = T.match_buffer( + fused_qkv_handle, + [batch_size, seq_len, num_query_heads + num_kv_heads * 2, head_dim], + dtype="float16", + ) + EmbeddedQuery = T.match_buffer( + embedded_query_handle, + [batch_size, seq_len, num_query_heads, head_dim], + dtype="float16", + ) + EmbeddedKey = T.match_buffer( + embedded_key_handle, + [batch_size, seq_len, num_kv_heads, head_dim], + dtype="float16", + ) + Value = T.match_buffer( + value_handle, + [batch_size, seq_len, num_kv_heads, head_dim], + dtype="float16", + ) T.func_attr({"op_pattern": 2, "tir.noalias": T.bool(True)}) - for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(hidden_size)): - with T.block("T_split"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads( - A[v_ax0, v_ax1, v_ax2], - A[v_ax0, v_ax1, v_ax2 + T.int64(hidden_size)], - A[v_ax0, v_ax1, v_ax2 + T.int64(hidden_size * 2)], - ) - T.writes( - T_split[v_ax0, v_ax1, v_ax2], - T_split_1[v_ax0, v_ax1, v_ax2], - T_split_2[v_ax0, v_ax1, v_ax2], - ) - pos: T.float32 = T.Cast("float32", n - T.int64(1)) + + for iters in T.grid(batch_size, seq_len, 
num_query_heads + num_kv_heads * 2, head_dim): + with T.block("FusedRotaryEmbeddingAndSplitQKV"): + batch_i, seq_i, head_num, head_i = T.axis.remap("SSSS", iters) + pos: T.float32 = T.Cast("float32", rotary_offset + seq_i - seq_len) + inv_freq: T.float32 = T.float32(1) / T.pow( - T.float32(position_embedding_base), - T.Cast("float32", (v_ax2 * 2) % head_dim) / T.float32(head_dim), + position_embedding_base, + T.Cast("float32", (head_i * 2) % head_dim) / T.float32(head_dim), ) freq: T.float32 = pos * inv_freq cos_value: T.float16 = T.Cast("float16", T.cos(freq)) sin_value: T.float16 = T.Cast("float16", T.sin(freq)) - T_split[v_ax0, v_ax1, v_ax2] = cos_value * A[ - v_ax0, v_ax1, v_ax2 - ] + sin_value * T.Select( - T.int64(head_dim // 2) <= v_ax2 % head_dim, - A[v_ax0, v_ax1, v_ax2 - T.int64(head_dim // 2)], - A[v_ax0, v_ax1, v_ax2 + T.int64(head_dim // 2)] * T.float16(-1), - ) - T_split_1[v_ax0, v_ax1, v_ax2] = cos_value * A[ - v_ax0, v_ax1, v_ax2 + T.int64(hidden_size) - ] + sin_value * T.Select( - T.int64(head_dim // 2) <= v_ax2 % head_dim, - A[v_ax0, v_ax1, v_ax2 + T.int64(hidden_size) - T.int64(head_dim // 2)], - A[v_ax0, v_ax1, v_ax2 + T.int64(hidden_size) + T.int64(head_dim // 2)] + + input_value = Fused_QKV[batch_i, seq_i, head_num, head_i] + embedded_value = cos_value * input_value + sin_value * T.Select( + head_i < T.int64(head_dim // 2), + Fused_QKV[batch_i, seq_i, head_num, head_i + T.int64(head_dim // 2)] * T.float16(-1), + Fused_QKV[batch_i, seq_i, head_num, head_i - T.int64(head_dim // 2)], ) - T_split_2[v_ax0, v_ax1, v_ax2] = A[v_ax0, v_ax1, v_ax2 + T.int64(hidden_size * 2)] + if head_num < num_query_heads: + EmbeddedQuery[batch_i, seq_i, head_num, head_i] = embedded_value + elif head_num < num_query_heads + num_kv_heads: + EmbeddedKey[batch_i, seq_i, head_num - num_query_heads, head_i] = embedded_value + else: + Value[ + batch_i, seq_i, head_num - num_query_heads - num_kv_heads, head_i + ] = input_value + + param_sinfo = [] + for param in split_rotary.params: + if param in split_rotary.buffer_map: + buf = split_rotary.buffer_map[param] + sinfo = relax.TensorStructInfo(shape=buf.shape, dtype=buf.dtype) + else: + sinfo = relax.PrimStructInfo(param.dtype) + param_sinfo.append(sinfo) + + relax.expr._update_struct_info( + split_rotary, + tvm.relax.FuncStructInfo( + params=param_sinfo, + ret=relax.TupleStructInfo([]), + purity=False, + ), + ) return split_rotary -def get_split_rotary_group_query_attention( - num_query_heads, num_kv_heads, head_dim, position_embedding_base +def fuse_split_rotary_embedding( + num_query_heads, num_kv_heads, hidden_size, position_embedding_base ): - query_hidden_size = num_query_heads * head_dim - kv_hidden_size = num_kv_heads * head_dim - total_size = query_hidden_size + kv_hidden_size * 2 + @tvm.ir.transform.module_pass(opt_level=0, name="fuse_split_rotary_embedding") + def ir_module_pass(mod: tvm.IRModule, _pass_context) -> tvm.IRModule: + head_dim = hidden_size // num_query_heads + split_rotary = get_dynamic_split_rotary() - @T.prim_func - def split_rotary( - qkv: T.handle, - split_0: T.handle, - split_1: T.handle, - split_2: T.handle, - n: T.int64, - ): - A = T.match_buffer(qkv, [1, 1, total_size], dtype="float16") - T_split = T.match_buffer(split_0, [1, 1, query_hidden_size], dtype="float16") - T_split_1 = T.match_buffer(split_1, [1, 1, kv_hidden_size], dtype="float16") - T_split_2 = T.match_buffer(split_2, [1, 1, kv_hidden_size], dtype="float16") + ( + dyn_batch_size, + dyn_seq_len, + dyn_num_query_heads, + dyn_num_kv_heads, + dyn_head_dim, + 
dyn_position_embedding_base, + ) = split_rotary.params[-6:] - T.func_attr({"op_pattern": 2, "tir.noalias": T.bool(True)}) - for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(query_hidden_size)): - with T.block("T_split"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads( - A[v_ax0, v_ax1, v_ax2], - ) - T.writes(T_split[v_ax0, v_ax1, v_ax2]) - pos: T.float32 = T.Cast("float32", n - T.int64(1)) - inv_freq: T.float32 = T.float32(1) / T.pow( - T.float32(position_embedding_base), - T.Cast("float32", (v_ax2 * 2) % head_dim) / T.float32(head_dim), - ) - freq: T.float32 = pos * inv_freq - cos_value: T.float16 = T.Cast("float16", T.cos(freq)) - sin_value: T.float16 = T.Cast("float16", T.sin(freq)) - T_split[v_ax0, v_ax1, v_ax2] = cos_value * A[ - v_ax0, v_ax1, v_ax2 - ] + sin_value * T.Select( - T.int64(head_dim // 2) <= v_ax2 % head_dim, - A[v_ax0, v_ax1, v_ax2 - T.int64(head_dim // 2)], - A[v_ax0, v_ax1, v_ax2 + T.int64(head_dim // 2)] * T.float16(-1), - ) - for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(kv_hidden_size)): - with T.block("T_split"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads( - A[v_ax0, v_ax1, v_ax2 + T.int64(query_hidden_size)], - A[v_ax0, v_ax1, v_ax2 + T.int64(query_hidden_size + kv_hidden_size)], - ) - T.writes( - T_split_1[v_ax0, v_ax1, v_ax2], - T_split_2[v_ax0, v_ax1, v_ax2], - ) - pos: T.float32 = T.Cast("float32", n - T.int64(1)) - inv_freq: T.float32 = T.float32(1) / T.pow( - T.float32(position_embedding_base), - T.Cast("float32", (v_ax2 * 2) % head_dim) / T.float32(head_dim), - ) - freq: T.float32 = pos * inv_freq - cos_value: T.float16 = T.Cast("float16", T.cos(freq)) - sin_value: T.float16 = T.Cast("float16", T.sin(freq)) - T_split_1[v_ax0, v_ax1, v_ax2] = cos_value * A[ - v_ax0, v_ax1, v_ax2 + T.int64(query_hidden_size) - ] + sin_value * T.Select( - T.int64(head_dim // 2) <= v_ax2 % head_dim, - A[v_ax0, v_ax1, v_ax2 + T.int64(query_hidden_size) - T.int64(head_dim // 2)], - A[v_ax0, v_ax1, v_ax2 + T.int64(query_hidden_size) + T.int64(head_dim // 2)] - * T.float16(-1), - ) - T_split_2[v_ax0, v_ax1, v_ax2] = A[ - v_ax0, v_ax1, v_ax2 + T.int64(query_hidden_size + kv_hidden_size) - ] + split_rotary = split_rotary.specialize( + { + # Static model parameters + dyn_batch_size: T.int64(1), + dyn_num_query_heads: T.int64(num_query_heads), + dyn_num_kv_heads: T.int64(num_kv_heads), + dyn_head_dim: T.int64(head_dim), + dyn_position_embedding_base: T.float32(position_embedding_base), + # Dynamic parameters, to be inferred from TIR Buffer shapes + dyn_seq_len: tvm.tir.Var("query_sequence_length", "int64"), + } + ) - return split_rotary + mod["split_rotary"] = split_rotary + split_rotary_gvar = mod.get_global_var("split_rotary") + relax.expr._update_struct_info(split_rotary_gvar, mod["split_rotary"].struct_info) -def fuse_split_rotary_embedding( - mod, num_query_heads, num_kv_heads, hidden_size, position_embedding_base -): - head_dim = hidden_size // num_query_heads - mod["split_rotary"] = ( - get_split_rotary(num_query_heads, head_dim, position_embedding_base) - if num_query_heads == num_kv_heads - else get_split_rotary_group_query_attention( - num_query_heads, num_kv_heads, head_dim, position_embedding_base - ) - ) + with PatternContext() as ctx: + # flat_qkv_tuple: R.Tuple( + # R.Tensor((batch_size, seq_len, 4096), dtype="float16"), + # R.Tensor((batch_size, seq_len, 4096), dtype="float16"), + # R.Tensor((batch_size, seq_len, 4096), dtype="float16"), + # ) = R.split(flat_fused_qkv, 
indices_or_sections=[4096, 8192], axis=2) + # + # flat_query: R.Tensor((batch_size, seq_len, 4096), dtype="float16") = flat_qkv_tuple[0] + # query: R.Tensor((batch_size, seq_len, 32, 128), dtype="float16") = R.reshape( + # flat_query, R.shape([batch_size, seq_len, 32, 128]) + # ) + # flat_key: R.Tensor((batch_size, seq_len, 4096), dtype="float16") = flat_qkv_tuple[1] + # key: R.Tensor((batch_size, seq_len, 32, 128), dtype="float16") = R.reshape( + # flat_key, R.shape([batch_size, seq_len, 32, 128]) + # ) + # flat_value: R.Tensor((batch_size, seq_len, 4096), dtype="float16") = flat_qkv_tuple[2] + # value: R.Tensor((batch_size, seq_len, 32, 128), dtype="float16") = R.reshape( + # flat_value, R.shape([batch_size, seq_len, 32, 128]) + # ) + # embedded_query = R.call_tir( + # cls.rotary_embedding1, + # [query], + # out_sinfo=R.Tensor((batch_size, seq_len, 32, 128), dtype="float16"), + # tir_vars=R.shape([n]), + # ) + # embedded_key = R.call_tir( + # cls.rotary_embedding1, + # [key], + # out_sinfo=R.Tensor((batch_size, seq_len, 32, 128), dtype="float16"), + # tir_vars=R.shape([n]), + # ) - gvar = mod.get_global_var("split_rotary") - relax.expr._update_struct_info(gvar, mod.get_global_var("rotary_embedding1").struct_info) + pat_rotary_embedding_gvar = GlobalVarPattern("rotary_embedding") | GlobalVarPattern( + "rotary_embedding1" + ) - with PatternContext() as ctx: - # lv3: R.Tuple(R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16")) = R.split(lv2, indices_or_sections=[4096, 8192], axis=2) + pat_flat_fused_qkv = wildcard() + offset = wildcard() - # lv1521: R.Tensor((1, 1, 4096), dtype="float16") = lv3[0] - # lv1522: R.Tensor((1, 1, 32, 128), dtype="float16") = R.reshape(lv1521, R.shape([1, 1, 32, 128])) - # lv1524: R.Tensor((1, 1, 4096), dtype="float16") = lv3[1] - # lv1525: R.Tensor((1, 1, 32, 128), dtype="float16") = R.reshape(lv1524, R.shape([1, 1, 32, 128])) - # lv1527: R.Tensor((1, 1, 4096), dtype="float16") = lv3[2] - # lv1528: R.Tensor((1, 1, 32, 128), dtype="float16") = R.reshape(lv1527, R.shape([1, 1, 32, 128])) - # lv1530 = R.call_tir(cls.rotary_embedding1, (lv1525, cos_cached1, sin_cached1), out_sinfo=R.Tensor((1, 1, 32, 128), dtype="float16"), tir_vars=R.shape([n])) - # lv_1 = R.call_tir(cls.rotary_embedding1, (lv1522, cos_cached1, sin_cached1), out_sinfo=R.Tensor((1, 1, 32, 128), dtype="float16"), tir_vars=R.shape( + # query_shape = is_shape([1, seq_len, num_query_heads, head_dim]) + pat_query_shape = wildcard() + # value_shape = is_shape([1, seq_len, num_kv_heads, head_dim]) + pat_key_shape = wildcard() + # value_shape = is_shape([1, seq_len, num_kv_heads, head_dim]) + pat_value_shape = wildcard() - inp_pat = wildcard() - offset = wildcard() + pat_flat_qkv_tuple = is_op("relax.split")(pat_flat_fused_qkv) + pat_flat_query = is_tuple_get_item(pat_flat_qkv_tuple, 0) + pat_query = is_op("relax.reshape")( + pat_flat_query, pat_query_shape, add_constraint=False + ) + pat_flat_query.used_by(pat_query) + pat_flat_key = is_tuple_get_item(pat_flat_qkv_tuple, 1) + pat_key = is_op("relax.reshape")(pat_flat_key, pat_key_shape, add_constraint=False) + pat_flat_key.used_by(pat_key) + pat_flat_value = is_tuple_get_item(pat_flat_qkv_tuple, 2) + pat_value = is_op("relax.reshape")( + pat_flat_value, pat_value_shape, add_constraint=False + ) + pat_flat_value.used_by(pat_value) - lv3 = is_op("relax.split")(inp_pat) - lv1521 = is_tuple_get_item(lv3, 0) - lv1522 = is_op("relax.reshape")( - lv1521, is_shape([1, 1, num_query_heads, 
head_dim]), add_constraint=False - ) - lv1521.used_by(lv1522) - lv1524 = is_tuple_get_item(lv3, 1) - lv1525 = is_op("relax.reshape")( - lv1524, is_shape([1, 1, num_kv_heads, head_dim]), add_constraint=False - ) - lv1524.used_by(lv1525) - lv1527 = is_tuple_get_item(lv3, 2) - V = is_op("relax.reshape")( - lv1527, is_shape([1, 1, num_kv_heads, head_dim]), add_constraint=False - ) - lv1527.used_by(V) + pat_embedded_query = is_op("relax.call_tir")( + pat_rotary_embedding_gvar, TuplePattern([pat_query]), offset, add_constraint=False + ) + pat_embedded_key = is_op("relax.call_tir")( + pat_rotary_embedding_gvar, TuplePattern([pat_key]), offset, add_constraint=False + ) - Q = is_op("relax.call_tir")( - GlobalVarPattern(), TuplePattern([lv1522]), offset, add_constraint=False - ) - K = is_op("relax.call_tir")( - GlobalVarPattern(), TuplePattern([lv1525]), offset, add_constraint=False - ) + pat_flat_qkv_tuple.used_by(pat_flat_query) + pat_flat_qkv_tuple.used_by(pat_flat_key) + pat_flat_qkv_tuple.used_by(pat_flat_value) + pat_query.used_by(pat_embedded_query) + pat_key.used_by(pat_embedded_key) - lv3.used_by(lv1521) - lv3.used_by(lv1524) - lv3.used_by(lv1527) - lv1522.used_by(Q) - lv1525.used_by(K) - - def rewriter(matchings, bindings): - inp = matchings[inp_pat] - call_tir = matchings[Q] - n = bindings[call_tir].args[-1] - out_sinfo = [ - R.Tensor((1, 1, num_query_heads * head_dim), dtype="float16"), - R.Tensor((1, 1, num_kv_heads * head_dim), dtype="float16"), - R.Tensor((1, 1, num_kv_heads * head_dim), dtype="float16"), - ] - lv3_new = R.call_tir( - mod.get_global_var("split_rotary"), (inp,), out_sinfo=out_sinfo, tir_vars=n - ) - lv1521_new = lv3_new[0] - lv1522_new = R.reshape(lv1521_new, R.shape([1, 1, num_query_heads, head_dim])) - lv1524_new = lv3_new[1] - lv1525_new = R.reshape(lv1524_new, R.shape([1, 1, num_kv_heads, head_dim])) - lv1527_new = lv3_new[2] - lv1528_new = R.reshape(lv1527_new, R.shape([1, 1, num_kv_heads, head_dim])) - - return { - matchings[lv3]: lv3_new, - matchings[lv1521]: lv1521_new, - matchings[lv1522]: lv1522_new, - matchings[lv1524]: lv1524_new, - matchings[lv1525]: lv1525_new, - matchings[lv1527]: lv1527_new, - matchings[V]: lv1528_new, - matchings[Q]: lv1522_new, - matchings[K]: lv1525_new, - } - - mod["decode"] = rewrite_bindings(ctx, rewriter, mod["decode"]) - return mod + def rewriter(matchings, bindings): + # Extracting all the relax and TIR variables that we'll need + flat_fused_qkv = matchings[pat_flat_fused_qkv] + + flat_qkv_tuple = matchings[pat_flat_qkv_tuple] + + flat_query = matchings[pat_flat_query] + flat_key = matchings[pat_flat_key] + flat_value = matchings[pat_flat_value] + + query = matchings[pat_query] + key = matchings[pat_key] + value = matchings[pat_value] + + embedded_query = matchings[pat_embedded_query] + embedded_key = matchings[pat_embedded_key] + + rotary_embedding_offset = bindings[query].args[-1][1] + + batch_size, seq_len, num_query_heads, head_dim = query.struct_info.shape + _batch_size, _seq_len, num_kv_heads, _head_dim = key.struct_info.shape + + # Rewriting along the new path + + fused_qkv = relax.op.reshape( + flat_fused_qkv, [batch_size, seq_len, num_query_heads + 2 * num_kv_heads, head_dim] + ) + + split_rotary_sinfo = [ + R.Tensor((batch_size, seq_len, num_query_heads, head_dim), dtype="float16"), + R.Tensor((batch_size, seq_len, num_kv_heads, head_dim), dtype="float16"), + R.Tensor((batch_size, seq_len, num_kv_heads, head_dim), dtype="float16"), + ] + qkv_tuple_new = R.call_tir( + split_rotary_gvar, + (fused_qkv,), + 
out_sinfo=split_rotary_sinfo, + tir_vars=[rotary_embedding_offset], + ) + + embedded_query_new = qkv_tuple_new[0] + embedded_key_new = qkv_tuple_new[1] + value_new = qkv_tuple_new[2] + + # Reproduce intermediates + # + # These will most likely be removed by DCE, as they were + # previously intermediate values. However, just in case + # something else was using them, we should define them. + flat_query_new = relax.op.reshape( + qkv_tuple_new[0], [batch_size, seq_len, num_query_heads * head_dim] + ) + flat_key_new = relax.op.reshape( + qkv_tuple_new[1], [batch_size, seq_len, num_kv_heads * head_dim] + ) + flat_value_new = relax.op.reshape( + qkv_tuple_new[2], [batch_size, seq_len, num_kv_heads * head_dim] + ) + flat_qkv_tuple_new = relax.Tuple([flat_query_new, flat_key_new, flat_value_new]) + + return { + flat_qkv_tuple: flat_qkv_tuple_new, + flat_query: flat_query_new, + flat_key: flat_key_new, + flat_value: flat_value_new, + value: value_new, + embedded_query: embedded_query_new, + embedded_key: embedded_key_new, + } + + new_mod = {} + for gvar, func in mod.functions.items(): + if isinstance(func, relax.Function): + func = rewrite_bindings(ctx, rewriter, func) + new_mod[gvar] = func + + new_mod = tvm.IRModule(new_mod, mod.type_definitions, mod.attrs, mod.global_infos) + assert "split_rotary" in str(new_mod["prefill"]) + return new_mod + + return ir_module_pass From 98ebd28da044cdfecfaefda85f9baf1558248de3 Mon Sep 17 00:00:00 2001 From: Lesheng Jin <34279105+LeshengJin@users.noreply.github.com> Date: Thu, 12 Oct 2023 13:24:10 -0700 Subject: [PATCH 016/116] [Docs] Add `mlc.ai/package` to `DEPENDENCY INSTALLATION` group (#1055) Co-authored-by: Junru Shao --- docs/community/faq.rst | 3 +- docs/deploy/python.rst | 3 +- docs/deploy/rest.rst | 3 +- docs/index.rst | 5 +- docs/install/gpu.rst | 2 +- docs/install/mlc_llm.rst | 133 +++++++++++++++++++++++++++++++++++++++ docs/install/tvm.rst | 15 +---- 7 files changed, 142 insertions(+), 22 deletions(-) create mode 100644 docs/install/mlc_llm.rst diff --git a/docs/community/faq.rst b/docs/community/faq.rst index 45d73b4904..f426a0c624 100644 --- a/docs/community/faq.rst +++ b/docs/community/faq.rst @@ -13,5 +13,4 @@ This is a list of Frequently Asked Questions (FAQ) about the MLC-LLM. Feel free ... Why do I encounter an error ``free(): invalid pointer, Aborted (core dumped)`` at the end of model compilation? This happens if you compiled TVM-Unity from source and didn't hide LLVM symbols in cmake configurations. - Please follow our instructions in :ref:`Building TVM Unity from Source ` tutorial to compile TVM-Unity which hides LLVM symbols, - or use our pre-builc MLC-AI pip wheels from `MLC Packages `__. + Please follow our instructions in :ref:`Building TVM Unity from Source ` tutorial to compile TVM-Unity which hides LLVM symbols, or use our pre-built MLC-LLM :doc:`pip wheels <../install/mlc_llm>`. diff --git a/docs/deploy/python.rst b/docs/deploy/python.rst index 0bf19c7b4c..b27d8ff935 100644 --- a/docs/deploy/python.rst +++ b/docs/deploy/python.rst @@ -11,8 +11,7 @@ We also provide a web demo based on `gradio `_ as an exampl Python API ---------- -The Python API is a part of the MLC-Chat package, which we have prepared pre-built pip wheels and you can install it by -following the instructions in ``_. +The Python API is a part of the MLC-Chat package, which we have prepared pre-built pip wheels via the :doc:`installation page <../install/mlc_llm>`. 
Verify Installation ^^^^^^^^^^^^^^^^^^^ diff --git a/docs/deploy/rest.rst b/docs/deploy/rest.rst index 338f8de56c..95d57f491e 100644 --- a/docs/deploy/rest.rst +++ b/docs/deploy/rest.rst @@ -11,8 +11,7 @@ for user to interact with MLC-Chat in their own programs. Install MLC-Chat Package ------------------------ -The REST API is a part of the MLC-Chat package, which we have prepared pre-built pip wheels and you can install it by -following the instructions in ``_. +The REST API is a part of the MLC-Chat package, which we have prepared pre-built :doc:`pip wheels <../install/mlc_llm>`. Verify Installation ^^^^^^^^^^^^^^^^^^^ diff --git a/docs/index.rst b/docs/index.rst index 89be4d4161..28a7d103ac 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -11,13 +11,13 @@ Getting Started --------------- To begin with, try out MLC LLM support for int4-quantized Llama2 7B. -It is recommended to have at least 4.5GB of free VRAM to run it. +It is recommended to have at least 6GB free VRAM to run it. .. tabs:: .. tab:: Python - **Install MLC Chat**. `MLC Chat `_ is available via pip. + **Install MLC Chat Python**. :doc:`MLC LLM ` is available via pip. It is always recommended to install it in an isolated conda virtual environment. **Download pre-quantized weights**. The comamnds below download the int4-quantized Llama2-7B from HuggingFace: @@ -209,6 +209,7 @@ It is recommended to have at least 4.5GB of free VRAM to run it. :hidden: install/tvm.rst + install/mlc_llm.rst install/conda.rst install/gpu.rst install/emcc.rst diff --git a/docs/install/gpu.rst b/docs/install/gpu.rst index 48ac7a5e1f..608c238265 100644 --- a/docs/install/gpu.rst +++ b/docs/install/gpu.rst @@ -105,7 +105,7 @@ After installation, you can run ``vulkaninfo`` in command line and see if you ca Vulkan SDK ---------- -Vulkan SDK is required for compiling models to Vulkan backend. To build TVM Unity compiler from source, you will need to install Vulkan SDK as a dependency, but our `pre-built wheels `__ already ships with Vulkan SDK. +Vulkan SDK is required for compiling models to Vulkan backend. To build TVM Unity compiler from source, you will need to install Vulkan SDK as a dependency, but our :doc:`pre-built wheels <../install/mlc_llm>` already ships with Vulkan SDK. Check Vulkan SDK installation guide according to your platform: diff --git a/docs/install/mlc_llm.rst b/docs/install/mlc_llm.rst new file mode 100644 index 0000000000..07da0378e4 --- /dev/null +++ b/docs/install/mlc_llm.rst @@ -0,0 +1,133 @@ +.. _install-mlc-packages: + +Install MLC LLM Python Package +============================== + +.. contents:: Table of Contents + :local: + :depth: 2 + +MLC LLM Python Package can be installed directly from a prebuilt developer package, or built from source. + +Option 1. Prebuilt Package +-------------------------- + +We provide nightly built pip wheels for MLC-LLM via pip. +Select your operating system/compute platform and run the command in your terminal: + +.. note:: + ❗ Whenever using Python, it is highly recommended to use **conda** to manage an isolated Python environment to avoid missing dependencies, incompatible versions, and package conflicts. + +.. tabs:: + + .. tab:: Linux + + .. tabs:: + + .. tab:: CPU + + .. code-block:: bash + + conda activate your-environment + python3 -m pip install --pre --force-reinstall -f https://mlc.ai/wheels mlc-chat-nightly mlc-ai-nightly + + .. tab:: CUDA 11.7 + + .. 
code-block:: bash + + conda activate your-environment + python3 -m pip install --pre --force-reinstall -f https://mlc.ai/wheels mlc-chat-nightly-cu117 mlc-ai-nightly-cu117 + + .. tab:: CUDA 11.8 + + .. code-block:: bash + + conda activate your-environment + python3 -m pip install --pre --force-reinstall -f https://mlc.ai/wheels mlc-chat-nightly-cu118 mlc-ai-nightly-cu118 + + .. tab:: CUDA 12.1 + + .. code-block:: bash + + conda activate your-environment + python3 -m pip install --pre --force-reinstall -f https://mlc.ai/wheels mlc-chat-nightly-cu121 mlc-ai-nightly-cu121 + + .. tab:: CUDA 12.2 + + .. code-block:: bash + + conda activate your-environment + python3 -m pip install --pre --force-reinstall -f https://mlc.ai/wheels mlc-chat-nightly-cu122 mlc-ai-nightly-cu122 + + .. tab:: ROCm 5.6 + + .. code-block:: bash + + conda activate your-environment + python3 -m pip install --pre --force-reinstall -f https://mlc.ai/wheels mlc-chat-nightly-rocm56 mlc-ai-nightly-rocm56 + + .. tab:: ROCm 5.7 + + .. code-block:: bash + + conda activate your-environment + python3 -m pip install --pre --force-reinstall -f https://mlc.ai/wheels mlc-chat-nightly-rocm57 mlc-ai-nightly-rocm57 + + .. tab:: Vulkan + + Supported in all Linux packages. + + .. note:: + + If encountering issues with GLIBC not found, please install the latest glibc in conda: + + .. code-block:: bash + + conda install -c conda-forge libgcc-ng + + .. tab:: macOS + + .. tabs:: + + .. tab:: CPU + Metal + + .. code-block:: bash + + conda activate your-environment + python3 -m pip install --pre --force-reinstall -f https://mlc.ai/wheels mlc-chat-nightly mlc-ai-nightly + + .. note:: + + Always check if conda is installed properly in macOS using the command below: + + .. code-block:: bash + + conda info | grep platform + + It should return "osx-64" for Mac with Intel chip, and "osx-arm64" for Mac with Apple chip. + + .. tab:: Windows + + .. tabs:: + + .. tab:: CPU + Vulkan + + .. code-block:: bash + + conda activate your-environment + python3 -m pip install --pre --force-reinstall -f https://mlc.ai/wheels mlc-chat-nightly mlc-ai-nightly + + .. note:: + If encountering the error below: + + .. code-block:: bash + + FileNotFoundError: Could not find module 'path\to\site-packages\tvm\tvm.dll' (or one of its dependencies). Try using the full path with constructor syntax. + + It is likely `zstd`, a dependency to LLVM, was missing. Please `download `__ the precompiled binary, rename it to `zstd.dll` and copy to the same folder as `tvm.dll`. + + +Option 2. Build from Source +--------------------------- + +Upcoming. \ No newline at end of file diff --git a/docs/install/tvm.rst b/docs/install/tvm.rst index 97ec1c9e40..ea97025abf 100644 --- a/docs/install/tvm.rst +++ b/docs/install/tvm.rst @@ -97,14 +97,7 @@ A nightly prebuilt Python package of Apache TVM Unity is provided. .. tabs:: - .. tab:: CPU - - .. code-block:: bash - - conda activate your-environment - python3 -m pip install --pre --force-reinstall -f https://mlc.ai/wheels mlc-ai-nightly - - .. tab:: Metal + .. tab:: CPU + Metal .. code-block:: bash @@ -125,17 +118,13 @@ A nightly prebuilt Python package of Apache TVM Unity is provided. .. tabs:: - .. tab:: CPU + .. tab:: CPU + Vulkan .. code-block:: bash conda activate your-environment python3 -m pip install --pre --force-reinstall -f https://mlc.ai/wheels mlc-ai-nightly - .. tab:: Vulkan - - Supported in all Windows packages. - .. 
note:: If encountering the error below: From bfaa5b9c0c38b7930e3394585ba1de878be70907 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Thu, 12 Oct 2023 17:40:54 -0400 Subject: [PATCH 017/116] Revert "[Transform] Apply split_rotary optimization on prefill (#1033)" (#1058) This reverts commit b9179cfdf02e041be15871b36e74400ab9001921 as elaborated here https://github.com/mlc-ai/mlc-llm/pull/1033#issuecomment-1760386712 --- mlc_llm/core.py | 3 +- .../transform/fuse_split_rotary_embedding.py | 460 ++++++++---------- 2 files changed, 203 insertions(+), 260 deletions(-) diff --git a/mlc_llm/core.py b/mlc_llm/core.py index b9280274c9..ddf93bf09a 100644 --- a/mlc_llm/core.py +++ b/mlc_llm/core.py @@ -402,11 +402,12 @@ def mod_transform_before_build( if max_seq_len: num_key_value_heads = config.get_num_key_value_heads() mod = fuse_split_rotary_embedding( + mod, config.num_attention_heads // args.num_shards, num_key_value_heads // args.num_shards, config.hidden_size // args.num_shards, config.position_embedding_base, - )(mod) + ) if args.target_kind == "cuda": patterns = [] diff --git a/mlc_llm/transform/fuse_split_rotary_embedding.py b/mlc_llm/transform/fuse_split_rotary_embedding.py index d04f37ee69..4ecc843f4a 100644 --- a/mlc_llm/transform/fuse_split_rotary_embedding.py +++ b/mlc_llm/transform/fuse_split_rotary_embedding.py @@ -1,5 +1,5 @@ -import tvm from tvm import relax +from tvm.script import tir as T from tvm.relax.dpl import ( PatternContext, is_op, @@ -10,292 +10,234 @@ TuplePattern, is_shape, ) -from tvm.script import relax as R, tir as T +from tvm.script import relax as R -def get_dynamic_split_rotary(): - """Implementation of R.split(rotary_embedding(fused_qkv)) +def get_split_rotary(num_attention_heads, head_dim, position_embedding_base): + hidden_size = num_attention_heads * head_dim - Implementation is generic over the number of query heads, - key/value heads, sequence length, head dimension, and position - embedding base. These parameters can be replaced with static - values using `PrimFunc.specialize`. 
- """ - - @T.prim_func(private=True) + @T.prim_func def split_rotary( - fused_qkv_handle: T.handle, - embedded_query_handle: T.handle, - embedded_key_handle: T.handle, - value_handle: T.handle, - rotary_offset: T.int64, - batch_size: T.int64, - seq_len: T.int64, - num_query_heads: T.int64, - num_kv_heads: T.int64, - head_dim: T.int64, - position_embedding_base: T.float32, + qkv: T.handle, + split_0: T.handle, + split_1: T.handle, + split_2: T.handle, + n: T.int64, ): - Fused_QKV = T.match_buffer( - fused_qkv_handle, - [batch_size, seq_len, num_query_heads + num_kv_heads * 2, head_dim], - dtype="float16", - ) - EmbeddedQuery = T.match_buffer( - embedded_query_handle, - [batch_size, seq_len, num_query_heads, head_dim], - dtype="float16", - ) - EmbeddedKey = T.match_buffer( - embedded_key_handle, - [batch_size, seq_len, num_kv_heads, head_dim], - dtype="float16", - ) - Value = T.match_buffer( - value_handle, - [batch_size, seq_len, num_kv_heads, head_dim], - dtype="float16", - ) + A = T.match_buffer(qkv, [1, 1, hidden_size * 3], dtype="float16") + T_split = T.match_buffer(split_0, [1, 1, hidden_size], dtype="float16") + T_split_1 = T.match_buffer(split_1, [1, 1, hidden_size], dtype="float16") + T_split_2 = T.match_buffer(split_2, [1, 1, hidden_size], dtype="float16") T.func_attr({"op_pattern": 2, "tir.noalias": T.bool(True)}) - - for iters in T.grid(batch_size, seq_len, num_query_heads + num_kv_heads * 2, head_dim): - with T.block("FusedRotaryEmbeddingAndSplitQKV"): - batch_i, seq_i, head_num, head_i = T.axis.remap("SSSS", iters) - pos: T.float32 = T.Cast("float32", rotary_offset + seq_i - seq_len) - + for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(hidden_size)): + with T.block("T_split"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads( + A[v_ax0, v_ax1, v_ax2], + A[v_ax0, v_ax1, v_ax2 + T.int64(hidden_size)], + A[v_ax0, v_ax1, v_ax2 + T.int64(hidden_size * 2)], + ) + T.writes( + T_split[v_ax0, v_ax1, v_ax2], + T_split_1[v_ax0, v_ax1, v_ax2], + T_split_2[v_ax0, v_ax1, v_ax2], + ) + pos: T.float32 = T.Cast("float32", n - T.int64(1)) inv_freq: T.float32 = T.float32(1) / T.pow( - position_embedding_base, - T.Cast("float32", (head_i * 2) % head_dim) / T.float32(head_dim), + T.float32(position_embedding_base), + T.Cast("float32", (v_ax2 * 2) % head_dim) / T.float32(head_dim), ) freq: T.float32 = pos * inv_freq cos_value: T.float16 = T.Cast("float16", T.cos(freq)) sin_value: T.float16 = T.Cast("float16", T.sin(freq)) - - input_value = Fused_QKV[batch_i, seq_i, head_num, head_i] - embedded_value = cos_value * input_value + sin_value * T.Select( - head_i < T.int64(head_dim // 2), - Fused_QKV[batch_i, seq_i, head_num, head_i + T.int64(head_dim // 2)] + T_split[v_ax0, v_ax1, v_ax2] = cos_value * A[ + v_ax0, v_ax1, v_ax2 + ] + sin_value * T.Select( + T.int64(head_dim // 2) <= v_ax2 % head_dim, + A[v_ax0, v_ax1, v_ax2 - T.int64(head_dim // 2)], + A[v_ax0, v_ax1, v_ax2 + T.int64(head_dim // 2)] * T.float16(-1), + ) + T_split_1[v_ax0, v_ax1, v_ax2] = cos_value * A[ + v_ax0, v_ax1, v_ax2 + T.int64(hidden_size) + ] + sin_value * T.Select( + T.int64(head_dim // 2) <= v_ax2 % head_dim, + A[v_ax0, v_ax1, v_ax2 + T.int64(hidden_size) - T.int64(head_dim // 2)], + A[v_ax0, v_ax1, v_ax2 + T.int64(hidden_size) + T.int64(head_dim // 2)] * T.float16(-1), - Fused_QKV[batch_i, seq_i, head_num, head_i - T.int64(head_dim // 2)], ) - if head_num < num_query_heads: - EmbeddedQuery[batch_i, seq_i, head_num, head_i] = embedded_value - elif head_num < num_query_heads + num_kv_heads: - 
EmbeddedKey[batch_i, seq_i, head_num - num_query_heads, head_i] = embedded_value - else: - Value[ - batch_i, seq_i, head_num - num_query_heads - num_kv_heads, head_i - ] = input_value - - param_sinfo = [] - for param in split_rotary.params: - if param in split_rotary.buffer_map: - buf = split_rotary.buffer_map[param] - sinfo = relax.TensorStructInfo(shape=buf.shape, dtype=buf.dtype) - else: - sinfo = relax.PrimStructInfo(param.dtype) - param_sinfo.append(sinfo) - - relax.expr._update_struct_info( - split_rotary, - tvm.relax.FuncStructInfo( - params=param_sinfo, - ret=relax.TupleStructInfo([]), - purity=False, - ), - ) + T_split_2[v_ax0, v_ax1, v_ax2] = A[v_ax0, v_ax1, v_ax2 + T.int64(hidden_size * 2)] return split_rotary -def fuse_split_rotary_embedding( - num_query_heads, num_kv_heads, hidden_size, position_embedding_base +def get_split_rotary_group_query_attention( + num_query_heads, num_kv_heads, head_dim, position_embedding_base ): - @tvm.ir.transform.module_pass(opt_level=0, name="fuse_split_rotary_embedding") - def ir_module_pass(mod: tvm.IRModule, _pass_context) -> tvm.IRModule: - head_dim = hidden_size // num_query_heads - split_rotary = get_dynamic_split_rotary() - - ( - dyn_batch_size, - dyn_seq_len, - dyn_num_query_heads, - dyn_num_kv_heads, - dyn_head_dim, - dyn_position_embedding_base, - ) = split_rotary.params[-6:] - - split_rotary = split_rotary.specialize( - { - # Static model parameters - dyn_batch_size: T.int64(1), - dyn_num_query_heads: T.int64(num_query_heads), - dyn_num_kv_heads: T.int64(num_kv_heads), - dyn_head_dim: T.int64(head_dim), - dyn_position_embedding_base: T.float32(position_embedding_base), - # Dynamic parameters, to be inferred from TIR Buffer shapes - dyn_seq_len: tvm.tir.Var("query_sequence_length", "int64"), - } - ) - - mod["split_rotary"] = split_rotary - - split_rotary_gvar = mod.get_global_var("split_rotary") - relax.expr._update_struct_info(split_rotary_gvar, mod["split_rotary"].struct_info) - - with PatternContext() as ctx: - # flat_qkv_tuple: R.Tuple( - # R.Tensor((batch_size, seq_len, 4096), dtype="float16"), - # R.Tensor((batch_size, seq_len, 4096), dtype="float16"), - # R.Tensor((batch_size, seq_len, 4096), dtype="float16"), - # ) = R.split(flat_fused_qkv, indices_or_sections=[4096, 8192], axis=2) - # - # flat_query: R.Tensor((batch_size, seq_len, 4096), dtype="float16") = flat_qkv_tuple[0] - # query: R.Tensor((batch_size, seq_len, 32, 128), dtype="float16") = R.reshape( - # flat_query, R.shape([batch_size, seq_len, 32, 128]) - # ) - # flat_key: R.Tensor((batch_size, seq_len, 4096), dtype="float16") = flat_qkv_tuple[1] - # key: R.Tensor((batch_size, seq_len, 32, 128), dtype="float16") = R.reshape( - # flat_key, R.shape([batch_size, seq_len, 32, 128]) - # ) - # flat_value: R.Tensor((batch_size, seq_len, 4096), dtype="float16") = flat_qkv_tuple[2] - # value: R.Tensor((batch_size, seq_len, 32, 128), dtype="float16") = R.reshape( - # flat_value, R.shape([batch_size, seq_len, 32, 128]) - # ) - # embedded_query = R.call_tir( - # cls.rotary_embedding1, - # [query], - # out_sinfo=R.Tensor((batch_size, seq_len, 32, 128), dtype="float16"), - # tir_vars=R.shape([n]), - # ) - # embedded_key = R.call_tir( - # cls.rotary_embedding1, - # [key], - # out_sinfo=R.Tensor((batch_size, seq_len, 32, 128), dtype="float16"), - # tir_vars=R.shape([n]), - # ) - - pat_rotary_embedding_gvar = GlobalVarPattern("rotary_embedding") | GlobalVarPattern( - "rotary_embedding1" - ) - - pat_flat_fused_qkv = wildcard() - offset = wildcard() - - # query_shape = is_shape([1, seq_len, 
num_query_heads, head_dim]) - pat_query_shape = wildcard() - # value_shape = is_shape([1, seq_len, num_kv_heads, head_dim]) - pat_key_shape = wildcard() - # value_shape = is_shape([1, seq_len, num_kv_heads, head_dim]) - pat_value_shape = wildcard() - - pat_flat_qkv_tuple = is_op("relax.split")(pat_flat_fused_qkv) - pat_flat_query = is_tuple_get_item(pat_flat_qkv_tuple, 0) - pat_query = is_op("relax.reshape")( - pat_flat_query, pat_query_shape, add_constraint=False - ) - pat_flat_query.used_by(pat_query) - pat_flat_key = is_tuple_get_item(pat_flat_qkv_tuple, 1) - pat_key = is_op("relax.reshape")(pat_flat_key, pat_key_shape, add_constraint=False) - pat_flat_key.used_by(pat_key) - pat_flat_value = is_tuple_get_item(pat_flat_qkv_tuple, 2) - pat_value = is_op("relax.reshape")( - pat_flat_value, pat_value_shape, add_constraint=False - ) - pat_flat_value.used_by(pat_value) - - pat_embedded_query = is_op("relax.call_tir")( - pat_rotary_embedding_gvar, TuplePattern([pat_query]), offset, add_constraint=False - ) - pat_embedded_key = is_op("relax.call_tir")( - pat_rotary_embedding_gvar, TuplePattern([pat_key]), offset, add_constraint=False - ) - - pat_flat_qkv_tuple.used_by(pat_flat_query) - pat_flat_qkv_tuple.used_by(pat_flat_key) - pat_flat_qkv_tuple.used_by(pat_flat_value) - pat_query.used_by(pat_embedded_query) - pat_key.used_by(pat_embedded_key) - - def rewriter(matchings, bindings): - # Extracting all the relax and TIR variables that we'll need - flat_fused_qkv = matchings[pat_flat_fused_qkv] - - flat_qkv_tuple = matchings[pat_flat_qkv_tuple] - - flat_query = matchings[pat_flat_query] - flat_key = matchings[pat_flat_key] - flat_value = matchings[pat_flat_value] - - query = matchings[pat_query] - key = matchings[pat_key] - value = matchings[pat_value] + query_hidden_size = num_query_heads * head_dim + kv_hidden_size = num_kv_heads * head_dim + total_size = query_hidden_size + kv_hidden_size * 2 - embedded_query = matchings[pat_embedded_query] - embedded_key = matchings[pat_embedded_key] + @T.prim_func + def split_rotary( + qkv: T.handle, + split_0: T.handle, + split_1: T.handle, + split_2: T.handle, + n: T.int64, + ): + A = T.match_buffer(qkv, [1, 1, total_size], dtype="float16") + T_split = T.match_buffer(split_0, [1, 1, query_hidden_size], dtype="float16") + T_split_1 = T.match_buffer(split_1, [1, 1, kv_hidden_size], dtype="float16") + T_split_2 = T.match_buffer(split_2, [1, 1, kv_hidden_size], dtype="float16") - rotary_embedding_offset = bindings[query].args[-1][1] + T.func_attr({"op_pattern": 2, "tir.noalias": T.bool(True)}) + for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(query_hidden_size)): + with T.block("T_split"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads( + A[v_ax0, v_ax1, v_ax2], + ) + T.writes(T_split[v_ax0, v_ax1, v_ax2]) + pos: T.float32 = T.Cast("float32", n - T.int64(1)) + inv_freq: T.float32 = T.float32(1) / T.pow( + T.float32(position_embedding_base), + T.Cast("float32", (v_ax2 * 2) % head_dim) / T.float32(head_dim), + ) + freq: T.float32 = pos * inv_freq + cos_value: T.float16 = T.Cast("float16", T.cos(freq)) + sin_value: T.float16 = T.Cast("float16", T.sin(freq)) + T_split[v_ax0, v_ax1, v_ax2] = cos_value * A[ + v_ax0, v_ax1, v_ax2 + ] + sin_value * T.Select( + T.int64(head_dim // 2) <= v_ax2 % head_dim, + A[v_ax0, v_ax1, v_ax2 - T.int64(head_dim // 2)], + A[v_ax0, v_ax1, v_ax2 + T.int64(head_dim // 2)] * T.float16(-1), + ) + for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(kv_hidden_size)): + with T.block("T_split"): + 
v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads( + A[v_ax0, v_ax1, v_ax2 + T.int64(query_hidden_size)], + A[v_ax0, v_ax1, v_ax2 + T.int64(query_hidden_size + kv_hidden_size)], + ) + T.writes( + T_split_1[v_ax0, v_ax1, v_ax2], + T_split_2[v_ax0, v_ax1, v_ax2], + ) + pos: T.float32 = T.Cast("float32", n - T.int64(1)) + inv_freq: T.float32 = T.float32(1) / T.pow( + T.float32(position_embedding_base), + T.Cast("float32", (v_ax2 * 2) % head_dim) / T.float32(head_dim), + ) + freq: T.float32 = pos * inv_freq + cos_value: T.float16 = T.Cast("float16", T.cos(freq)) + sin_value: T.float16 = T.Cast("float16", T.sin(freq)) + T_split_1[v_ax0, v_ax1, v_ax2] = cos_value * A[ + v_ax0, v_ax1, v_ax2 + T.int64(query_hidden_size) + ] + sin_value * T.Select( + T.int64(head_dim // 2) <= v_ax2 % head_dim, + A[v_ax0, v_ax1, v_ax2 + T.int64(query_hidden_size) - T.int64(head_dim // 2)], + A[v_ax0, v_ax1, v_ax2 + T.int64(query_hidden_size) + T.int64(head_dim // 2)] + * T.float16(-1), + ) + T_split_2[v_ax0, v_ax1, v_ax2] = A[ + v_ax0, v_ax1, v_ax2 + T.int64(query_hidden_size + kv_hidden_size) + ] - batch_size, seq_len, num_query_heads, head_dim = query.struct_info.shape - _batch_size, _seq_len, num_kv_heads, _head_dim = key.struct_info.shape + return split_rotary - # Rewriting along the new path - fused_qkv = relax.op.reshape( - flat_fused_qkv, [batch_size, seq_len, num_query_heads + 2 * num_kv_heads, head_dim] - ) +def fuse_split_rotary_embedding( + mod, num_query_heads, num_kv_heads, hidden_size, position_embedding_base +): + head_dim = hidden_size // num_query_heads + mod["split_rotary"] = ( + get_split_rotary(num_query_heads, head_dim, position_embedding_base) + if num_query_heads == num_kv_heads + else get_split_rotary_group_query_attention( + num_query_heads, num_kv_heads, head_dim, position_embedding_base + ) + ) - split_rotary_sinfo = [ - R.Tensor((batch_size, seq_len, num_query_heads, head_dim), dtype="float16"), - R.Tensor((batch_size, seq_len, num_kv_heads, head_dim), dtype="float16"), - R.Tensor((batch_size, seq_len, num_kv_heads, head_dim), dtype="float16"), - ] - qkv_tuple_new = R.call_tir( - split_rotary_gvar, - (fused_qkv,), - out_sinfo=split_rotary_sinfo, - tir_vars=[rotary_embedding_offset], - ) + gvar = mod.get_global_var("split_rotary") + relax.expr._update_struct_info(gvar, mod.get_global_var("rotary_embedding1").struct_info) - embedded_query_new = qkv_tuple_new[0] - embedded_key_new = qkv_tuple_new[1] - value_new = qkv_tuple_new[2] + with PatternContext() as ctx: + # lv3: R.Tuple(R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16")) = R.split(lv2, indices_or_sections=[4096, 8192], axis=2) - # Reproduce intermediates - # - # These will most likely be removed by DCE, as they were - # previously intermediate values. However, just in case - # something else was using them, we should define them. 
- flat_query_new = relax.op.reshape( - qkv_tuple_new[0], [batch_size, seq_len, num_query_heads * head_dim] - ) - flat_key_new = relax.op.reshape( - qkv_tuple_new[1], [batch_size, seq_len, num_kv_heads * head_dim] - ) - flat_value_new = relax.op.reshape( - qkv_tuple_new[2], [batch_size, seq_len, num_kv_heads * head_dim] - ) - flat_qkv_tuple_new = relax.Tuple([flat_query_new, flat_key_new, flat_value_new]) + # lv1521: R.Tensor((1, 1, 4096), dtype="float16") = lv3[0] + # lv1522: R.Tensor((1, 1, 32, 128), dtype="float16") = R.reshape(lv1521, R.shape([1, 1, 32, 128])) + # lv1524: R.Tensor((1, 1, 4096), dtype="float16") = lv3[1] + # lv1525: R.Tensor((1, 1, 32, 128), dtype="float16") = R.reshape(lv1524, R.shape([1, 1, 32, 128])) + # lv1527: R.Tensor((1, 1, 4096), dtype="float16") = lv3[2] + # lv1528: R.Tensor((1, 1, 32, 128), dtype="float16") = R.reshape(lv1527, R.shape([1, 1, 32, 128])) + # lv1530 = R.call_tir(cls.rotary_embedding1, (lv1525, cos_cached1, sin_cached1), out_sinfo=R.Tensor((1, 1, 32, 128), dtype="float16"), tir_vars=R.shape([n])) + # lv_1 = R.call_tir(cls.rotary_embedding1, (lv1522, cos_cached1, sin_cached1), out_sinfo=R.Tensor((1, 1, 32, 128), dtype="float16"), tir_vars=R.shape( - return { - flat_qkv_tuple: flat_qkv_tuple_new, - flat_query: flat_query_new, - flat_key: flat_key_new, - flat_value: flat_value_new, - value: value_new, - embedded_query: embedded_query_new, - embedded_key: embedded_key_new, - } + inp_pat = wildcard() + offset = wildcard() - new_mod = {} - for gvar, func in mod.functions.items(): - if isinstance(func, relax.Function): - func = rewrite_bindings(ctx, rewriter, func) - new_mod[gvar] = func + lv3 = is_op("relax.split")(inp_pat) + lv1521 = is_tuple_get_item(lv3, 0) + lv1522 = is_op("relax.reshape")( + lv1521, is_shape([1, 1, num_query_heads, head_dim]), add_constraint=False + ) + lv1521.used_by(lv1522) + lv1524 = is_tuple_get_item(lv3, 1) + lv1525 = is_op("relax.reshape")( + lv1524, is_shape([1, 1, num_kv_heads, head_dim]), add_constraint=False + ) + lv1524.used_by(lv1525) + lv1527 = is_tuple_get_item(lv3, 2) + V = is_op("relax.reshape")( + lv1527, is_shape([1, 1, num_kv_heads, head_dim]), add_constraint=False + ) + lv1527.used_by(V) - new_mod = tvm.IRModule(new_mod, mod.type_definitions, mod.attrs, mod.global_infos) - assert "split_rotary" in str(new_mod["prefill"]) - return new_mod + Q = is_op("relax.call_tir")( + GlobalVarPattern(), TuplePattern([lv1522]), offset, add_constraint=False + ) + K = is_op("relax.call_tir")( + GlobalVarPattern(), TuplePattern([lv1525]), offset, add_constraint=False + ) - return ir_module_pass + lv3.used_by(lv1521) + lv3.used_by(lv1524) + lv3.used_by(lv1527) + lv1522.used_by(Q) + lv1525.used_by(K) + + def rewriter(matchings, bindings): + inp = matchings[inp_pat] + call_tir = matchings[Q] + n = bindings[call_tir].args[-1] + out_sinfo = [ + R.Tensor((1, 1, num_query_heads * head_dim), dtype="float16"), + R.Tensor((1, 1, num_kv_heads * head_dim), dtype="float16"), + R.Tensor((1, 1, num_kv_heads * head_dim), dtype="float16"), + ] + lv3_new = R.call_tir( + mod.get_global_var("split_rotary"), (inp,), out_sinfo=out_sinfo, tir_vars=n + ) + lv1521_new = lv3_new[0] + lv1522_new = R.reshape(lv1521_new, R.shape([1, 1, num_query_heads, head_dim])) + lv1524_new = lv3_new[1] + lv1525_new = R.reshape(lv1524_new, R.shape([1, 1, num_kv_heads, head_dim])) + lv1527_new = lv3_new[2] + lv1528_new = R.reshape(lv1527_new, R.shape([1, 1, num_kv_heads, head_dim])) + + return { + matchings[lv3]: lv3_new, + matchings[lv1521]: lv1521_new, + 
matchings[lv1522]: lv1522_new, + matchings[lv1524]: lv1524_new, + matchings[lv1525]: lv1525_new, + matchings[lv1527]: lv1527_new, + matchings[V]: lv1528_new, + matchings[Q]: lv1522_new, + matchings[K]: lv1525_new, + } + + mod["decode"] = rewrite_bindings(ctx, rewriter, mod["decode"]) + return mod From ca8c11b6b304de0e3529a17bdc4ee1d8d2500fd3 Mon Sep 17 00:00:00 2001 From: Sunghyun Park Date: Fri, 13 Oct 2023 09:00:21 -0700 Subject: [PATCH 018/116] [BugFix] Set the right `max_sequence_length` for both Llama-1 and Llama-2 families (#1032) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix * reflect feedback --------- Co-authored-by: “Sunghyun --- mlc_llm/relax_model/llama.py | 42 +++++++++++++++++++++++++----------- 1 file changed, 29 insertions(+), 13 deletions(-) diff --git a/mlc_llm/relax_model/llama.py b/mlc_llm/relax_model/llama.py index 3e9f3ef7b4..bcd2a99004 100644 --- a/mlc_llm/relax_model/llama.py +++ b/mlc_llm/relax_model/llama.py @@ -817,26 +817,42 @@ def create_softmax_func(bb: relax.BlockBuilder, config: LlamaConfig) -> None: def get_model(args, hf_config): model_name = args.model dtype = args.quantization.model_dtype - max_seq_len = args.max_seq_len sep_embed = args.sep_embed position_embedding_base = 10000 max_position_embeddings = 2048 if "rope_theta" in hf_config: position_embedding_base = hf_config["rope_theta"] - if "max_position_embeddings" in hf_config: - max_position_embeddings = hf_config["max_position_embeddings"] - config = LlamaConfig( - **hf_config, - dtype=dtype, - position_embedding_base=position_embedding_base, - combine_matmul=True, - num_shards=args.num_shards, - build_model_only=args.build_model_only, - ) - if max_seq_len != -1: - config.max_sequence_length = max_seq_len + # Llama-2 variants use `max_position_embeddings` to encode maximum sequence length in their hf model cards, + # while Llama-1 variants use `max_sequence_length`. + # Thus, use `max_sequence_length` if defined. Otherwise, use `max_position_embeddings`. + # If none of them is defined, throw an error. + if "max_sequence_length" in hf_config: + config = LlamaConfig( + **hf_config, + dtype=dtype, + position_embedding_base=position_embedding_base, + combine_matmul=True, + num_shards=args.num_shards, + build_model_only=args.build_model_only, + ) + elif "max_position_embeddings" in hf_config: + config = LlamaConfig( + **hf_config, + dtype=dtype, + max_sequence_length=hf_config["max_position_embeddings"], + position_embedding_base=position_embedding_base, + combine_matmul=True, + num_shards=args.num_shards, + build_model_only=args.build_model_only, + ) + else: + raise Exception("The model config should contain information about maximum sequence length.") + + # If there is a user-provided maximum sequence length, override hf config. + if args.max_seq_len != -1: + config.max_sequence_length = args.max_seq_len param_manager = ParamManager() bb = relax.BlockBuilder() From edab9b57ec75ab66cc573afe024f91d47b631913 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Fri, 13 Oct 2023 09:57:46 -0700 Subject: [PATCH 019/116] [Doc] Use -U instead of --force-reinstall (#1062) `--force-reinstall` will reinstall all dependencies to a python package, which is unnecessary. `-U` is a better choice in this case. 
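For illustration, the practical difference between the two invocations (shown with the same
nightly wheel index used throughout the docs; the commands are taken verbatim from the diff):

    # --force-reinstall: reinstalls the named package and all of its dependencies,
    # even when everything is already up to date
    python3 -m pip install --pre --force-reinstall -f https://mlc.ai/wheels mlc-ai-nightly

    # -U (--upgrade): upgrades only the named package; already-satisfied dependencies
    # are left alone under pip's default "only-if-needed" upgrade strategy
    python3 -m pip install --pre -U -f https://mlc.ai/wheels mlc-ai-nightly

This keeps routine nightly updates cheap while still picking up the latest wheels.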
--- docs/install/mlc_llm.rst | 18 +++++++++--------- docs/install/tvm.rst | 18 +++++++++--------- 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/docs/install/mlc_llm.rst b/docs/install/mlc_llm.rst index 07da0378e4..f95cc3ee9c 100644 --- a/docs/install/mlc_llm.rst +++ b/docs/install/mlc_llm.rst @@ -29,49 +29,49 @@ Select your operating system/compute platform and run the command in your termin .. code-block:: bash conda activate your-environment - python3 -m pip install --pre --force-reinstall -f https://mlc.ai/wheels mlc-chat-nightly mlc-ai-nightly + python3 -m pip install --pre -U -f https://mlc.ai/wheels mlc-chat-nightly mlc-ai-nightly .. tab:: CUDA 11.7 .. code-block:: bash conda activate your-environment - python3 -m pip install --pre --force-reinstall -f https://mlc.ai/wheels mlc-chat-nightly-cu117 mlc-ai-nightly-cu117 + python3 -m pip install --pre -U -f https://mlc.ai/wheels mlc-chat-nightly-cu117 mlc-ai-nightly-cu117 .. tab:: CUDA 11.8 .. code-block:: bash conda activate your-environment - python3 -m pip install --pre --force-reinstall -f https://mlc.ai/wheels mlc-chat-nightly-cu118 mlc-ai-nightly-cu118 + python3 -m pip install --pre -U -f https://mlc.ai/wheels mlc-chat-nightly-cu118 mlc-ai-nightly-cu118 .. tab:: CUDA 12.1 .. code-block:: bash conda activate your-environment - python3 -m pip install --pre --force-reinstall -f https://mlc.ai/wheels mlc-chat-nightly-cu121 mlc-ai-nightly-cu121 + python3 -m pip install --pre -U -f https://mlc.ai/wheels mlc-chat-nightly-cu121 mlc-ai-nightly-cu121 .. tab:: CUDA 12.2 .. code-block:: bash conda activate your-environment - python3 -m pip install --pre --force-reinstall -f https://mlc.ai/wheels mlc-chat-nightly-cu122 mlc-ai-nightly-cu122 + python3 -m pip install --pre -U -f https://mlc.ai/wheels mlc-chat-nightly-cu122 mlc-ai-nightly-cu122 .. tab:: ROCm 5.6 .. code-block:: bash conda activate your-environment - python3 -m pip install --pre --force-reinstall -f https://mlc.ai/wheels mlc-chat-nightly-rocm56 mlc-ai-nightly-rocm56 + python3 -m pip install --pre -U -f https://mlc.ai/wheels mlc-chat-nightly-rocm56 mlc-ai-nightly-rocm56 .. tab:: ROCm 5.7 .. code-block:: bash conda activate your-environment - python3 -m pip install --pre --force-reinstall -f https://mlc.ai/wheels mlc-chat-nightly-rocm57 mlc-ai-nightly-rocm57 + python3 -m pip install --pre -U -f https://mlc.ai/wheels mlc-chat-nightly-rocm57 mlc-ai-nightly-rocm57 .. tab:: Vulkan @@ -94,7 +94,7 @@ Select your operating system/compute platform and run the command in your termin .. code-block:: bash conda activate your-environment - python3 -m pip install --pre --force-reinstall -f https://mlc.ai/wheels mlc-chat-nightly mlc-ai-nightly + python3 -m pip install --pre -U -f https://mlc.ai/wheels mlc-chat-nightly mlc-ai-nightly .. note:: @@ -115,7 +115,7 @@ Select your operating system/compute platform and run the command in your termin .. code-block:: bash conda activate your-environment - python3 -m pip install --pre --force-reinstall -f https://mlc.ai/wheels mlc-chat-nightly mlc-ai-nightly + python3 -m pip install --pre -U -f https://mlc.ai/wheels mlc-chat-nightly mlc-ai-nightly .. note:: If encountering the error below: diff --git a/docs/install/tvm.rst b/docs/install/tvm.rst index ea97025abf..0dc716258d 100644 --- a/docs/install/tvm.rst +++ b/docs/install/tvm.rst @@ -37,49 +37,49 @@ A nightly prebuilt Python package of Apache TVM Unity is provided. .. 
code-block:: bash conda activate your-environment - python3 -m pip install --pre --force-reinstall -f https://mlc.ai/wheels mlc-ai-nightly + python3 -m pip install --pre -U -f https://mlc.ai/wheels mlc-ai-nightly .. tab:: CUDA 11.7 .. code-block:: bash conda activate your-environment - python3 -m pip install --pre --force-reinstall -f https://mlc.ai/wheels mlc-ai-nightly-cu117 + python3 -m pip install --pre -U -f https://mlc.ai/wheels mlc-ai-nightly-cu117 .. tab:: CUDA 11.8 .. code-block:: bash conda activate your-environment - python3 -m pip install --pre --force-reinstall -f https://mlc.ai/wheels mlc-ai-nightly-cu118 + python3 -m pip install --pre -U -f https://mlc.ai/wheels mlc-ai-nightly-cu118 .. tab:: CUDA 12.1 .. code-block:: bash conda activate your-environment - python3 -m pip install --pre --force-reinstall -f https://mlc.ai/wheels mlc-ai-nightly-cu121 + python3 -m pip install --pre -U -f https://mlc.ai/wheels mlc-ai-nightly-cu121 .. tab:: CUDA 12.2 .. code-block:: bash conda activate your-environment - python3 -m pip install --pre --force-reinstall -f https://mlc.ai/wheels mlc-ai-nightly-cu122 + python3 -m pip install --pre -U -f https://mlc.ai/wheels mlc-ai-nightly-cu122 .. tab:: ROCm 5.6 .. code-block:: bash conda activate your-environment - python3 -m pip install --pre --force-reinstall -f https://mlc.ai/wheels mlc-ai-nightly-rocm56 + python3 -m pip install --pre -U -f https://mlc.ai/wheels mlc-ai-nightly-rocm56 .. tab:: ROCm 5.7 .. code-block:: bash conda activate your-environment - python3 -m pip install --pre --force-reinstall -f https://mlc.ai/wheels mlc-ai-nightly-rocm57 + python3 -m pip install --pre -U -f https://mlc.ai/wheels mlc-ai-nightly-rocm57 .. tab:: Vulkan @@ -102,7 +102,7 @@ A nightly prebuilt Python package of Apache TVM Unity is provided. .. code-block:: bash conda activate your-environment - python3 -m pip install --pre --force-reinstall -f https://mlc.ai/wheels mlc-ai-nightly + python3 -m pip install --pre -U -f https://mlc.ai/wheels mlc-ai-nightly .. note:: @@ -123,7 +123,7 @@ A nightly prebuilt Python package of Apache TVM Unity is provided. .. code-block:: bash conda activate your-environment - python3 -m pip install --pre --force-reinstall -f https://mlc.ai/wheels mlc-ai-nightly + python3 -m pip install --pre -U -f https://mlc.ai/wheels mlc-ai-nightly .. note:: If encountering the error below: From d8541050cb42c1c026265c878c94193057876a4d Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Fri, 13 Oct 2023 20:45:58 -0400 Subject: [PATCH 020/116] [Model] Initial batching support for Llama (#1048) This PR introduces the initial batched input support for llama models. To make the code managable, we keep both the single-sequence handling flow and the batching handling flow in the Llama modeling. Now, with `--enable-batching` as a build argument, we build Llama for the batched version. NOTE: The paged attention kernel/TIR func are not included in this PR, so currently the built library with batching enabled is not runnable. We will follow up with the attention kernel in the future. This PR guarantees that the existing single-sequence inference (Python API, CLI, etc.) is not broken. P.S.. The batching flow is subject to bug fixes as we integrate with the attention function and run the e2e flow in the future. 
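A rough sketch of the intended build workflow (the `mlc_llm.build` entry point and the exact
flag spellings here are assumptions for illustration, not part of this patch):

    # entry point and flag names are assumed -- adjust to the local build script
    # sep_embed is required when batching is enabled (see get_model() in llama.py)
    python3 -m mlc_llm.build --model Llama-2-7b-chat-hf --quantization q4f16_1 \
        --enable-batching --sep-embed

The resulting library exposes the `embed`, `prefill_with_embed` and `decode_with_embed`
entry points, which tests/debug/test_batching_llama.py drives directly through the Relax VM
once the attention kernel is integrated.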
--- cpp/llm_chat.cc | 15 +- mlc_llm/core.py | 35 +- mlc_llm/relax_model/llama.py | 453 ++++++++++++++++-- mlc_llm/transform/decode_take.py | 36 +- .../transform/fuse_split_rotary_embedding.py | 3 + tests/debug/test_batching_llama.py | 160 +++++++ 6 files changed, 614 insertions(+), 88 deletions(-) create mode 100644 tests/debug/test_batching_llama.py diff --git a/cpp/llm_chat.cc b/cpp/llm_chat.cc index 25f68203f2..7286cda7b0 100644 --- a/cpp/llm_chat.cc +++ b/cpp/llm_chat.cc @@ -481,7 +481,7 @@ class LLMChat { // Step 6. KV cache creation. this->kv_cache_ = ft_.create_kv_cache_func_(); // Step 7. Pre-allocate fixed size ndarray - this->temperature_arr_ = NDArray::Empty({}, DataType::Float(32), device_); + this->temperature_arr_ = NDArray::Empty({1}, DataType::Float(32), device_); float temperature = static_cast(this->temperature_); this->temperature_arr_.CopyFromBytes(&temperature, sizeof(float)); if (ft_.use_disco) { @@ -947,19 +947,18 @@ class LLMChat { // the generation_config will not override the original config // since is only used for this generation double gen_temperature; - NDArray gen_temperature_arr; double gen_repetition_penalty; double gen_top_p; if (generation_config.count("temperature")) { CHECK(generation_config["temperature"].is()); gen_temperature = generation_config["temperature"].get(); - - gen_temperature_arr = NDArray::Empty({}, DataType::Float(32), device_); - float temperature_cast = static_cast(gen_temperature); - gen_temperature_arr.CopyFromBytes(&temperature_cast, sizeof(float)); + if (gen_temperature != this->temperature_) { + this->temperature_ = gen_temperature; + float temperature_cast = static_cast(gen_temperature); + this->temperature_arr_.CopyFromBytes(&temperature_cast, sizeof(float)); + } } else { gen_temperature = this->temperature_; - gen_temperature_arr = this->temperature_arr_; } if (generation_config.count("repetition_penalty")) { CHECK(generation_config["repetition_penalty"].is()); @@ -979,7 +978,7 @@ class LLMChat { if (gen_temperature < 1e-6f) { this->UpdateLogitsOrProbOnCPUSync(logits_on_device); } else { - this->UpdateLogitsOrProbOnCPUSync(this->Softmax(logits_on_device, gen_temperature_arr)); + this->UpdateLogitsOrProbOnCPUSync(this->Softmax(logits_on_device, this->temperature_arr_)); } } else { this->UpdateLogitsOrProbOnCPUSync(logits_on_device); diff --git a/mlc_llm/core.py b/mlc_llm/core.py index ddf93bf09a..2e6eefad4d 100644 --- a/mlc_llm/core.py +++ b/mlc_llm/core.py @@ -79,6 +79,11 @@ class BuildArgs: Build with separated embedding layer, only applicable to LlaMa. This feature is in testing stage, and will be formally replaced after massive overhaul of embedding feature for all models and use cases. + enable_batching: bool + Build the model for batched inference. + This is a temporary flag used to control the model execution flow in single- + sequence and batching settings for now. We will eventually merge two flows + in the future and remove this flag then. """ model: str = field( default="auto", @@ -180,21 +185,29 @@ class BuildArgs: "action": "store_true", }, ) - no_cutlass_attn: bool = field( + enable_batching: bool = field( default=False, metadata={ "help": ( - "Disable offloading attention operations to CUTLASS." + "Build the model for batched inference." + "This is a temporary flag used to control the model execution flow in single-" + "sequence and batching settings for now. We will eventually merge two flows" + "in the future and remove this flag then." 
), "action": "store_true", }, ) + no_cutlass_attn: bool = field( + default=False, + metadata={ + "help": ("Disable offloading attention operations to CUTLASS."), + "action": "store_true", + }, + ) no_cutlass_norm: bool = field( default=False, metadata={ - "help": ( - "Disable offloading layer and RMS norm operations to CUTLASS." - ), + "help": ("Disable offloading layer and RMS norm operations to CUTLASS."), "action": "store_true", }, ) @@ -231,9 +244,7 @@ class BuildArgs: use_flash_attn_mqa: bool = field( default=False, metadata={ - "help": ( - "Offload multi-query attention workload to Flash Attention." - ), + "help": ("Offload multi-query attention workload to Flash Attention."), "action": "store_true", }, ) @@ -380,6 +391,8 @@ def mod_transform_before_build( ] if args.sep_embed: model_names = ["embed", "prefill_with_embed"] + model_names[1:] + if args.enable_batching: + model_names[2] = "decode_with_embed" if args.model.lower().startswith("rwkv-"): model_names += ["reset_kv_cache"] @@ -458,7 +471,7 @@ def mod_transform_before_build( ), annotate_workspace, relax.transform.AllocateWorkspace(), - relax.transform.RunCodegen(options, entry_functions=model_names) + relax.transform.RunCodegen(options, entry_functions=model_names), ] )(mod) @@ -558,7 +571,9 @@ def build(mod_deploy: tvm.IRModule, args: argparse.Namespace) -> None: with tvm.transform.PassContext(config={"relax.backend.use_cuda_graph": use_cuda_graph}): # The num_input attribute is needed to capture transformed weights passed as input # into a cuda graph. - mod_deploy["decode"] = mod_deploy["decode"].with_attr({"num_input": 3}) + # NOTE: CUDA graph for batching is not enabled and is left as a TODO item. + if not args.enable_batching: + mod_deploy["decode"] = mod_deploy["decode"].with_attr({"num_input": 3}) ex = relax.build(mod_deploy, args.target, system_lib=args.system_lib) output_filename = f"{args.model}-{args.quantization.name}-{target_kind}.{args.lib_format}" diff --git a/mlc_llm/relax_model/llama.py b/mlc_llm/relax_model/llama.py index bcd2a99004..eea1bb05bc 100644 --- a/mlc_llm/relax_model/llama.py +++ b/mlc_llm/relax_model/llama.py @@ -1,10 +1,10 @@ import math from dataclasses import dataclass -from typing import Any, List, Optional, Tuple +from typing import Any, List, Optional, Tuple, Union import numpy as np import tvm -from tvm import relax, te +from tvm import relax, te, tir from tvm.relax.op import ccl from tvm.relax.testing import nn from tvm.script import relax as R @@ -217,7 +217,7 @@ def rotary_compute(*idx): return q_embed, k_embed -class LlamaAttention(nn.Module): +class LlamaAttentionBase(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" def __init__(self, config: LlamaConfig): @@ -271,23 +271,14 @@ def __init__(self, config: LlamaConfig): def forward( self, hidden_states: relax.Expr, - all_seq_len_shape: relax.Expr, - past_key_value: Tuple[relax.Expr], + all_seq_len_shape: Optional[relax.Expr], + past_key_values: Union[relax.Expr, Tuple[relax.Expr]], + layer_id: int, attention_mask: Optional[relax.Expr] = None, - ) -> Tuple[relax.Expr, Optional[relax.Expr], Optional[Tuple[relax.Expr]]]: - from tvm.relax.op import ( - astype, - matmul, - maximum, - permute_dims, - reshape, - split, - squeeze, - ) - from tvm.relax.op.nn import softmax + ) -> Tuple[relax.Expr, Union[relax.Expr, Tuple[relax.Expr]]]: + from tvm.relax.op import reshape, split bsz, q_len, _ = hidden_states.struct_info.shape - assert bsz == 1, "Only support batch size 1 at this moment." 
if self.combine_matmul: qkv_states = nn.emit( @@ -327,7 +318,121 @@ def forward( ), ) - kv_seq_len = all_seq_len_shape.struct_info.values[0] + attn_output, past_key_values = self.attention_fwd( + query_states, + key_states, + value_states, + past_key_values, + bsz, + q_len, + layer_id=layer_id, + all_seq_len_shape=all_seq_len_shape, + attention_mask=attention_mask, + ) + + attn_output = nn.emit( + reshape(attn_output, (bsz, q_len, self.head_dim * self.num_query_heads)) + ) + attn_output = self.o_proj(attn_output) + return attn_output, past_key_values + + def attention_fwd( + self, + query_states: relax.Expr, + key_states: relax.Expr, + value_states: relax.Expr, + past_key_values: relax.Expr, + batch_size: tir.PrimExpr, + q_len: tir.PrimExpr, + **kwargs, + ): + raise NotImplementedError() + + +class LlamaPagedAttention(LlamaAttentionBase): + def __init__(self, config: LlamaConfig): + super().__init__(config) + ctx_mod = relax.BlockBuilder.current().get() + self.kv_cache_transpose_append = ctx_mod.get_global_var("kv_cache_transpose_append") + self.attention_compute = ctx_mod.get_global_var("attention") + + def attention_fwd( + self, + query_states: relax.Expr, + key_states: relax.Expr, + value_states: relax.Expr, + past_key_values: relax.Expr, + batch_size: tir.PrimExpr, + q_len: tir.PrimExpr, + **kwargs, + ) -> Tuple[relax.Expr, relax.Expr]: + assert "layer_id" in kwargs and isinstance(kwargs["layer_id"], int) + layer_id = kwargs["layer_id"] + + f_kv_cache_append = relax.extern("vm.builtin.paged_attention_kv_cache_append") + past_key_values = nn.emit( + relax.call_pure_packed( + f_kv_cache_append, + past_key_values, + self.kv_cache_transpose_append, + key_states, + value_states, + relax.PrimValue(layer_id), + sinfo_args=relax.ObjectStructInfo(), + ) + ) + + f_kv_cache_attention = relax.extern("vm.builtin.paged_attention_kv_cache_attention") + attn_output = nn.emit( + relax.call_dps_packed( + f_kv_cache_attention, + [ + past_key_values, + self.attention_compute, + query_states, + relax.PrimValue(layer_id), + True, + 1.0, + self.position_embedding_base, + ], + out_sinfo=relax.TensorStructInfo( + ((batch_size, q_len, self.num_query_heads, self.head_dim)), + query_states.struct_info.dtype, + ), + ) + ) + return attn_output, past_key_values + + +class LlamaAttention(LlamaAttentionBase): + def __init__(self, config: LlamaConfig): + super().__init__(config) + + def attention_fwd( + self, + query_states: relax.Expr, + key_states: relax.Expr, + value_states: relax.Expr, + past_key_values: relax.Expr, + batch_size: tir.PrimExpr, + q_len: tir.PrimExpr, + **kwargs, + ) -> Tuple[relax.Expr, Tuple[relax.Expr]]: + assert "attention_mask" in kwargs + assert "all_seq_len_shape" in kwargs + attention_mask = kwargs["attention_mask"] + kv_seq_len = kwargs["all_seq_len_shape"].struct_info.values[0] + + from tvm.relax.op import ( + astype, + matmul, + maximum, + permute_dims, + reshape, + squeeze, + ) + from tvm.relax.op.nn import softmax + offset = kv_seq_len - q_len query_states, key_states = apply_rotary_pos_emb( query_states, @@ -347,7 +452,7 @@ def forward( squeezed_key = nn.emit(squeeze(key_states, axis=0)) squeezed_value = nn.emit(squeeze(value_states, axis=0)) - k_cache, v_cache = past_key_value + k_cache, v_cache = past_key_values f_kv_cache_append = relax.extern("vm.builtin.attention_kv_cache_append") k_cache = nn.emit( relax.Call( @@ -363,7 +468,7 @@ def forward( sinfo_args=[relax.ObjectStructInfo()], ) ) - past_key_value = (k_cache, v_cache) + past_key_values = (k_cache, v_cache) f_kv_cache_view = 
relax.extern("vm.builtin.attention_kv_cache_view") k_cache = nn.emit( relax.Call( @@ -397,7 +502,7 @@ def forward( tvm.ir.assert_structural_equal( attention_mask.struct_info.shape.values, - (bsz, tvm.tir.IntImm("int64", 1), q_len, kv_seq_len), + (batch_size, tvm.tir.IntImm("int64", 1), q_len, kv_seq_len), ) attn_weights = nn.emit( @@ -420,18 +525,14 @@ def forward( attn_output = nn.emit(matmul(attn_weights, value_states)) attn_output = nn.emit(permute_dims(attn_output, [0, 2, 1, 3])) - attn_output = nn.emit( - reshape(attn_output, (bsz, q_len, self.head_dim * self.num_query_heads)) - ) - - attn_output = self.o_proj(attn_output) - return attn_output, ((None, None) if past_key_value is None else past_key_value) + return attn_output, past_key_values class LlamaDecoderLayer(nn.Module): - def __init__(self, config: LlamaConfig): + def __init__(self, config: LlamaConfig, enable_batching: bool): + attn_class = LlamaPagedAttention if enable_batching else LlamaAttention self.hidden_size = config.hidden_size - self.self_attn = LlamaAttention(config) + self.self_attn = attn_class(config) self.mlp = LlamaMLP(config) self.input_layernorm = LlamaRMSNorm( config.hidden_size, dtype=config.dtype, eps=config.rms_norm_eps @@ -443,8 +544,9 @@ def __init__(self, config: LlamaConfig): def forward( self, hidden_states: relax.Expr, - all_seq_len_shape: relax.Expr, - past_key_value: Tuple[relax.Expr], + all_seq_len_shape: Optional[relax.Expr], + past_key_values: Union[relax.Expr, Tuple[relax.Expr]], + layer_id: int, attention_mask: Optional[relax.Expr] = None, ) -> Tuple[relax.Expr, Optional[Tuple[relax.Expr, relax.Expr]]]: residual = hidden_states @@ -454,9 +556,10 @@ def forward( # Self Attention hidden_states, present_key_value = self.self_attn( hidden_states=hidden_states, - past_key_value=past_key_value, + past_key_values=past_key_values, attention_mask=attention_mask, all_seq_len_shape=all_seq_len_shape, + layer_id=layer_id, ) if self.self_attn.num_shards > 1: residual = nn.emit( @@ -481,7 +584,7 @@ def forward( def _make_causal_mask(input_ids_shape, dtype, src_len): - from tvm.relax.op import broadcast_to, full, triu + from tvm.relax.op import broadcast_to bsz, tgt_len = input_ids_shape @@ -530,8 +633,14 @@ def forward(self, input_ids: relax.Expr): return inputs_embeds -class LlamaModel(nn.Module): - def __init__(self, config: LlamaConfig, vocab_size_var: tvm.tir.Var, sep_embed: bool = False): +class LlamaModelBase(nn.Module): + def __init__( + self, + config: LlamaConfig, + vocab_size_var: tir.Var, + sep_embed: bool = False, + enable_batching: bool = False, + ): self.num_shards = config.num_shards self.padding_idx = config.pad_token_id self.embed_tokens = None @@ -540,10 +649,23 @@ def __init__(self, config: LlamaConfig, vocab_size_var: tvm.tir.Var, sep_embed: self.embed_tokens = Embedding(vocab_size_var, config.hidden_size, dtype=config.dtype) self.layers = ModuleList( - [LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)] + [LlamaDecoderLayer(config, enable_batching) for _ in range(config.num_hidden_layers)] ) self.norm = LlamaRMSNorm(config.hidden_size, dtype=config.dtype, eps=config.rms_norm_eps) + def forward( + self, + inputs: relax.Expr, + all_seq_len_shape: Optional[relax.Expr], + past_key_values: relax.Expr, + ): + raise NotImplementedError() + + +class LlamaModelForSingleSequence(LlamaModelBase): + def __init__(self, config: LlamaConfig, vocab_size_var: tvm.tir.Var, sep_embed: bool = False): + super().__init__(config, vocab_size_var, sep_embed, enable_batching=False) + def 
_prepare_decoder_attention_mask(self, input_shape, src_len, dtype): # create causal mask # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] @@ -566,7 +688,7 @@ def _prepare_decoder_attention_mask(self, input_shape, src_len, dtype): def forward( self, inputs: relax.Expr, - all_seq_len_shape: relax.Expr, + all_seq_len_shape: Optional[relax.Expr], past_key_values: relax.Expr, ): if self.num_shards > 1: @@ -597,8 +719,9 @@ def forward( hidden_states, key_value_cache = decoder_layer( hidden_states, attention_mask=attention_mask, - past_key_value=past_key_value, + past_key_values=past_key_value, all_seq_len_shape=all_seq_len_shape, + layer_id=idx, ) next_decoder_cache += key_value_cache @@ -608,9 +731,51 @@ def forward( return hidden_states, next_decoder_cache +class LlamaModelForBatching(LlamaModelBase): + def __init__(self, config: LlamaConfig, vocab_size_var: tvm.tir.Var, sep_embed: bool): + assert sep_embed + super().__init__(config, vocab_size_var, sep_embed=True, enable_batching=True) + + def forward( + self, + inputs: relax.Expr, + all_seq_len_shape: Optional[relax.Expr], + past_key_values: relax.Expr, + ): + assert all_seq_len_shape is None + if self.num_shards > 1: + inputs = nn.emit(ccl.broadcast_from_worker0(inputs)) + if self.embed_tokens: + inputs_embeds = self.embed_tokens(inputs) + else: + inputs_embeds = inputs + + hidden_states = inputs_embeds + + for idx, decoder_layer in enumerate(self.layers): + assert past_key_values is not None + hidden_states, past_key_values = decoder_layer( + hidden_states, + attention_mask=None, + past_key_values=past_key_values, + all_seq_len_shape=all_seq_len_shape, + layer_id=idx, + ) + + hidden_states = self.norm(hidden_states) + return hidden_states, past_key_values + + class LlamaForCausalLM(nn.Module): - def __init__(self, config: LlamaConfig, vocab_size_var: tvm.tir.Var, sep_embed: bool = False): - self.model = LlamaModel(config, vocab_size_var, sep_embed) + def __init__( + self, + config: LlamaConfig, + vocab_size_var: tvm.tir.Var, + sep_embed: bool = False, + enable_batching: bool = False, + ): + model_class = LlamaModelForBatching if enable_batching else LlamaModelForSingleSequence + self.model = model_class(config, vocab_size_var, sep_embed) self.lm_head = Linear(config.hidden_size, vocab_size_var, dtype=config.dtype, bias=False) ############ Rotary embedding constants ############ @@ -627,7 +792,7 @@ def __init__(self, config: LlamaConfig, vocab_size_var: tvm.tir.Var, sep_embed: def forward( self, inputs: relax.Expr, - all_seq_len_shape: relax.Expr, + all_seq_len_shape: Optional[relax.Expr], past_key_values: relax.Expr, ): hidden_states, key_value_cache = self.model( @@ -637,8 +802,9 @@ def forward( ) def te_slicing(x: te.Tensor): + assert x.ndim == 3 return te.compute( - shape=(1, 1, x.shape[-1]), + shape=(x.shape[0], 1, x.shape[2]), fcompute=lambda i, j, k: x[i, x.shape[1] - 1, k], name="slice", ) @@ -669,7 +835,7 @@ def create_embed_func( ) -> None: func_name = "embed" - bsz = 1 + bsz = tvm.tir.Var("nseq", "int64") seq_len = tvm.tir.Var("n", "int64") with bb.function(func_name): model = LlamaEmbedTokensWrapper(config, tvm.tir.Var("vocab_size", "int64")) @@ -687,7 +853,7 @@ def create_embed_func( bb.update_func(gv, mod[gv].with_attr("num_input", 1)) -def create_encoding_func( +def create_prefill_func_for_single_seq( bb: relax.BlockBuilder, param_manager: ParamManager, config: LlamaConfig, @@ -701,7 +867,9 @@ def create_encoding_func( all_seq_len = tvm.tir.Var("m", "int64") hidden_size = config.hidden_size with bb.function(func_name): 
- model = LlamaForCausalLM(config, tvm.tir.Var("vocab_size", "int64"), sep_embed) + model = LlamaForCausalLM( + config, tvm.tir.Var("vocab_size", "int64"), sep_embed, enable_batching=False + ) param_manager.register_params(model, func_name, quant_scheme, get_param_quant_kind) inputs = ( @@ -733,7 +901,43 @@ def create_encoding_func( bb.update_func(gv, mod[gv].with_attr("num_input", 3)) -def create_decoding_func( +def create_prefill_func_for_batching( + bb: relax.BlockBuilder, + param_manager: ParamManager, + config: LlamaConfig, + quant_scheme: QuantizationScheme, +) -> None: + func_name = "prefill_with_embed" + + bsz = 1 + seq_len = tvm.tir.Var("n", "int64") + hidden_size = config.hidden_size + with bb.function(func_name): + model = LlamaForCausalLM( + config, tvm.tir.Var("vocab_size", "int64"), sep_embed=True, enable_batching=True + ) + param_manager.register_params(model, func_name, quant_scheme, get_param_quant_kind) + + inputs = nn.Placeholder( + (bsz, seq_len, hidden_size), dtype=config.dtype, name="inputs_embeds" + ) + past_key_values = relax.Var("kv_cache", relax.ObjectStructInfo()) + with bb.dataflow(): + logits, key_value_cache = model( + inputs, + all_seq_len_shape=None, + past_key_values=past_key_values, + ) + params = [inputs, past_key_values] + model.parameters() + gv = bb.emit_output((logits, key_value_cache)) + bb.emit_func_output(gv, params) + + mod = bb.get() + gv = mod.get_global_var(func_name) + bb.update_func(gv, mod[gv].with_attr("num_input", 2)) + + +def create_decoding_func_for_single_seq( bb: relax.BlockBuilder, param_manager: ParamManager, config: LlamaConfig, @@ -773,6 +977,37 @@ def create_decoding_func( bb.update_func(gv, mod[gv].with_attr("num_input", 3)) +def create_decoding_func_for_batching( + bb: relax.BlockBuilder, + param_manager: ParamManager, + config: LlamaConfig, + quant_scheme: QuantizationScheme, +) -> None: + func_name = "decode_with_embed" + + bsz = tir.Var("nseq", "int64") + hidden_size = config.hidden_size + with bb.function(func_name): + model = LlamaForCausalLM( + config, tvm.tir.Var("vocab_size", "int64"), sep_embed=True, enable_batching=True + ) + param_manager.register_params(model, func_name, quant_scheme, get_param_quant_kind) + + inputs = nn.Placeholder((bsz, 1, hidden_size), dtype=config.dtype, name="inputs_embeds") + past_key_values = relax.Var("kv_cache", relax.ObjectStructInfo()) + with bb.dataflow(): + logits, key_value_cache = model( + inputs, all_seq_len_shape=None, past_key_values=past_key_values + ) + params = [inputs, past_key_values] + model.parameters() + gv = bb.emit_output((logits, key_value_cache)) + bb.emit_func_output(gv, params) + + mod = bb.get() + gv = mod.get_global_var(func_name) + bb.update_func(gv, mod[gv].with_attr("num_input", 2)) + + def create_kv_cache_func(bb: relax.BlockBuilder, config: LlamaConfig) -> None: num_key_value_heads = config.get_num_key_value_heads() // config.num_shards init_shape = relax.ShapeExpr( @@ -801,24 +1036,130 @@ def create_kv_cache_func(bb: relax.BlockBuilder, config: LlamaConfig) -> None: bb.emit_func_output(gv) +def create_paged_kv_cache_func(bb: relax.BlockBuilder, config: LlamaConfig) -> None: + head_dim = config.hidden_size // config.num_attention_heads + num_key_value_heads = config.get_num_key_value_heads() // config.num_shards + + page_size = tir.Var("page_size", "int64") + total_seq_len = tir.Var("total_seq_len", "int64") + reserved_nseq = tir.Var("reserved_nseq", "int64") + cache_config = relax.Var( + "cache_config", + relax.ShapeStructInfo([reserved_nseq, total_seq_len, 
page_size]), + ) + + with bb.function("create_kv_cache", [cache_config]): + with bb.dataflow(): + zeros = bb.emit(relax.op.zeros((), config.dtype)) + f_kv_cache_create = relax.extern("vm.builtin.paged_attention_kv_cache_create") + cache = bb.emit_output( + relax.Call( + f_kv_cache_create, + args=[ + cache_config, + relax.PrimValue(config.num_hidden_layers), + relax.PrimValue(num_key_value_heads), + relax.PrimValue(head_dim), + zeros, + ], + sinfo_args=[relax.ObjectStructInfo()], + ) + ) + bb.emit_func_output(cache) + + def create_softmax_func(bb: relax.BlockBuilder, config: LlamaConfig) -> None: with bb.function("softmax_with_temperature"): + bsz = tvm.tir.Var("nseq", "int64") logits = nn.Placeholder( - (1, 1, tvm.tir.Var("vocab_size", "int64")), dtype="float32", name="logits" + (bsz, 1, tvm.tir.Var("vocab_size", "int64")), + dtype="float32", + name="logits", ) - temperature = nn.Placeholder((), dtype="float32", name="temperature") + temperature = nn.Placeholder((bsz,), dtype="float32", name="temperature") with bb.dataflow(): - div = bb.emit(relax.op.divide(logits, temperature)) + t_reshaped = bb.emit(relax.op.reshape(temperature, (bsz, 1, 1))) + div = bb.emit(relax.op.divide(logits, t_reshaped)) softmax = bb.emit(relax.op.nn.softmax(div, axis=-1)) gv = bb.emit_output(softmax) bb.emit_func_output(gv, [logits, temperature]) +def emit_paged_kv_cache_op(bb: relax.BlockBuilder, dtype: str) -> None: + from tvm.script import tir as T + + # fmt: off + @T.prim_func + def kv_cache_transpose_append( + var_pages: T.handle, + var_k_data: T.handle, + var_v_data: T.handle, + var_page_table_indptr: T.handle, + var_page_table_values: T.handle, + var_last_page_offset: T.handle, + var_append_length_indptr: T.handle, + var_pos2seqidx: T.handle, + layer_id: T.int32, + ): + nseq = T.int32() + ntoken = T.int32() + nhead = T.int32() + nfeat = T.int32() + nlayer = T.int32() + npage = T.int32() + page_size = T.int32() + num_pages = T.int32() + + pages = T.match_buffer(var_pages, (num_pages, nlayer, 2, nhead, page_size, nfeat), dtype) + k_data = T.match_buffer(var_k_data, (ntoken, nhead, nfeat), dtype) + v_data = T.match_buffer(var_v_data, (ntoken, nhead, nfeat), dtype) + last_page_offset = T.match_buffer(var_last_page_offset, (nseq,), "int32") + page_table_indptr = T.match_buffer(var_page_table_indptr, (nseq + 1,), "int32") + page_table_values = T.match_buffer(var_page_table_values, (npage,), "int32") + append_length_indptr = T.match_buffer(var_append_length_indptr, (nseq + 1,), "int32") + pos2seqidx = T.match_buffer(var_pos2seqidx, (ntoken,), "int32") + + for global_pos, h, f in T.grid(ntoken, nhead, nfeat): + with T.block("k_transpose_append"): + vgpos, vh, vf = T.axis.remap("SSS", [global_pos, h, f]) + seq_idx = pos2seqidx[vgpos] + seqlen: T.int32 = (page_table_indptr[seq_idx + 1] - page_table_indptr[seq_idx] - 1) * page_size + last_page_offset[seq_idx] + pages[ + page_table_values[page_table_indptr[seq_idx] + T.floordiv(seqlen - (append_length_indptr[seq_idx + 1] - vgpos), page_size)], + layer_id, + 0, + vh, + T.floormod(seqlen - (append_length_indptr[seq_idx + 1] - vgpos), page_size), + vf, + ] = k_data[vgpos, vh, vf] + with T.block("v_transpose_append"): + vgpos, vh, vf = T.axis.remap("SSS", [global_pos, h, f]) + seq_idx = pos2seqidx[vgpos] + seqlen: T.int32 = (page_table_indptr[seq_idx + 1] - page_table_indptr[seq_idx] - 1) * page_size + last_page_offset[seq_idx] + pages[ + page_table_values[page_table_indptr[seq_idx] + T.floordiv(seqlen - (append_length_indptr[seq_idx + 1] - vgpos), page_size)], + 
layer_id, + 1, + vh, + T.floormod(seqlen - (append_length_indptr[seq_idx + 1] - vgpos), page_size), + vf, + ] = v_data[vgpos, vh, vf] + # fmt: on + + bb.add_func(kv_cache_transpose_append, "kv_cache_transpose_append") + # Todo: integrating attention TIR func/kernel. + bb.add_func(relax.extern("attention_func"), "attention") + + def get_model(args, hf_config): model_name = args.model dtype = args.quantization.model_dtype + enable_batching = args.enable_batching sep_embed = args.sep_embed + if enable_batching and not sep_embed: + raise ValueError("`sep_embed` is required when batching is enabled.") + position_embedding_base = 10000 max_position_embeddings = 2048 if "rope_theta" in hf_config: @@ -859,9 +1200,17 @@ def get_model(args, hf_config): if sep_embed: create_embed_func(bb, param_manager, config, args.quantization) - create_encoding_func(bb, param_manager, config, args.quantization, sep_embed) - create_decoding_func(bb, param_manager, config, args.quantization) - create_kv_cache_func(bb, config) + + if enable_batching: + emit_paged_kv_cache_op(bb, dtype) + create_prefill_func_for_batching(bb, param_manager, config, args.quantization) + create_decoding_func_for_batching(bb, param_manager, config, args.quantization) + create_paged_kv_cache_func(bb, config) + else: + create_prefill_func_for_single_seq(bb, param_manager, config, args.quantization, sep_embed) + create_decoding_func_for_single_seq(bb, param_manager, config, args.quantization) + create_kv_cache_func(bb, config) + create_softmax_func(bb, config) create_metadata_func( bb, diff --git a/mlc_llm/transform/decode_take.py b/mlc_llm/transform/decode_take.py index ece1c7ab23..cd09771126 100644 --- a/mlc_llm/transform/decode_take.py +++ b/mlc_llm/transform/decode_take.py @@ -17,7 +17,7 @@ def pattern_check(ctx: relax.transform.PatternCheckContext) -> bool: return "take" in take.args[0].name_hint and "decode" in decode.args[0].name_hint -def decode_take_pattern(n_aux_tensor: int): +def decode_take_pattern(n_aux_tensor: int, match_tir_vars: bool): aux_tensors = [wildcard(), wildcard(), wildcard()] decode = is_op("relax.call_tir")( GlobalVarPattern(), @@ -26,9 +26,10 @@ def decode_take_pattern(n_aux_tensor: int): ) indices = ~is_const() take_args = [decode, indices] - take = is_op("relax.call_tir")( - GlobalVarPattern(), TuplePattern(take_args), add_constraint=False - ) + call_tir_args_take = [GlobalVarPattern(), TuplePattern(take_args)] + if match_tir_vars: + call_tir_args_take.append(wildcard()) + take = is_op("relax.call_tir")(*call_tir_args_take, add_constraint=False) annotations = { "take": take, @@ -41,18 +42,17 @@ def decode_take_pattern(n_aux_tensor: int): @tvm.transform.module_pass(opt_level=0, name="FuseDecodeTake") class FuseDecodeTake: - def transform_module( - self, mod: IRModule, ctx: tvm.transform.PassContext - ) -> IRModule: + def transform_module(self, mod: IRModule, ctx: tvm.transform.PassContext) -> IRModule: for n_aux_tensor in [2, 3]: - mod = relax.transform.FuseOpsByPattern( - [ - ( - "decode_take", - *decode_take_pattern(n_aux_tensor), - ) - ] - )(mod) + for match_tir_vars in [False, True]: + mod = relax.transform.FuseOpsByPattern( + [ + ( + "decode_take", + *decode_take_pattern(n_aux_tensor, match_tir_vars), + ) + ] + )(mod) mod = relax.transform.FuseTIR()(mod) for gv, func in mod.functions.items(): @@ -61,9 +61,9 @@ def transform_module( if "fused_decode" not in gv.name_hint or "take" not in gv.name_hint: continue - downcasted_mod = tir.transform.ForceNarrowIndexToInt32()( - tvm.IRModule({"main": func}) - 
)["main"] + downcasted_mod = tir.transform.ForceNarrowIndexToInt32()(tvm.IRModule({"main": func}))[ + "main" + ] sch = tir.Schedule(downcasted_mod) sch.compute_inline("decode") mod[gv] = sch.mod["main"] diff --git a/mlc_llm/transform/fuse_split_rotary_embedding.py b/mlc_llm/transform/fuse_split_rotary_embedding.py index 4ecc843f4a..a7dbdf6c31 100644 --- a/mlc_llm/transform/fuse_split_rotary_embedding.py +++ b/mlc_llm/transform/fuse_split_rotary_embedding.py @@ -151,6 +151,9 @@ def split_rotary( def fuse_split_rotary_embedding( mod, num_query_heads, num_kv_heads, hidden_size, position_embedding_base ): + if "rotary_embedding1" not in [gv.name_hint for gv in mod.functions]: + return mod + head_dim = hidden_size // num_query_heads mod["split_rotary"] = ( get_split_rotary(num_query_heads, head_dim, position_embedding_base) diff --git a/tests/debug/test_batching_llama.py b/tests/debug/test_batching_llama.py new file mode 100644 index 0000000000..ff11188e4b --- /dev/null +++ b/tests/debug/test_batching_llama.py @@ -0,0 +1,160 @@ +# pylint: disable=invalid-name,missing-docstring +# Used as reference + +import argparse +import json +import os + +import numpy as np +import torch +import tvm +from transformers import LlamaTokenizer # type: ignore[import] +from tvm import relax +from tvm.runtime import ShapeTuple + +from mlc_llm import utils + +############################################################## +# Test file for e2e Llama with batching enabled by directly +# calling functions in VM. +# +# NOTE: the test will not be runnable until the attention +# compute function is integrated to Llama. This is left as +# an item that we will work on shortly in the future. +############################################################## + + +def _parse_args(): + args = argparse.ArgumentParser() + args.add_argument("--local-id", type=str, default="Llama-2-7b-chat-hf-q4f16_1") + args.add_argument("--device-name", type=str, default="auto") + args.add_argument("--artifact-path", type=str, default="dist") + args.add_argument("--prompt", type=str, default="What's the meaning of life?") + args.add_argument("--profile", action="store_true", default=False) + parsed = args.parse_args() + parsed.model, parsed.quantization = parsed.local_id.rsplit("-", 1) + utils.argparse_postproc_common(parsed) + parsed.artifact_path = os.path.join( + parsed.artifact_path, f"{parsed.model}-{parsed.quantization.name}" + ) + return parsed + + +def sample_from_logits(vm, logits, device): + temperature = 0.7 + top_p = 0.95 + + num_sequence = logits.shape[0] + temperature_arr = tvm.nd.array(np.full((num_sequence,), temperature, dtype="float32"), device) + probs = vm["softmax_with_temperature"](logits, temperature_arr).numpy() + + sampled_tokens = [] + fsample_top_p_from_prob = tvm.get_global_func("vm.builtin.sample_top_p_from_prob") + for seq_id in range(num_sequence): + token = fsample_top_p_from_prob(tvm.nd.array(probs[seq_id]), top_p, np.random.sample()) + sampled_tokens.append(token) + return sampled_tokens + + +def deploy_to_pipeline(args) -> None: # pylint: disable=too-many-locals + device = tvm.device(args.device_name) + const_params = utils.load_params(args.artifact_path, device) + ex = tvm.runtime.load_module( + os.path.join( + args.artifact_path, + f"{args.model}-{args.quantization.name}-{args.device_name}.so", + ) + ) + vm = relax.VirtualMachine(ex, device) + + with open( + os.path.join(args.artifact_path, "params", "mlc-chat-config.json"), + "r", + encoding="utf-8", + ) as f: + config = json.load(f) + + assert 
config["model_category"] == "llama" + tokenizer = LlamaTokenizer.from_pretrained( + os.path.join(args.artifact_path, "params"), trust_remote_code=True + ) + + num_sequences = 4 + generated_tokens = [[], [], [], []] + prompts = [ + "What's the meaning of life?", + "Introduce the history of Pittsburgh to me.", + "Write a three-day Seattle travel plan.", + "What is Alaska famous of?", + ] + num_decode_steps = 256 + + print("Create KV cache...") + max_total_seq_len = 16384 + page_size = 16 + kv_cache = vm["create_kv_cache"](ShapeTuple([num_sequences, max_total_seq_len, page_size])) + + fadd_sequence = tvm.get_global_func("vm.builtin.paged_attention_kv_cache_add_sequence") + freset_append_length = tvm.get_global_func( + "vm.builtin.paged_attention_kv_cache_reset_append_lengths" + ) + freserve = tvm.get_global_func( + "vm.builtin.paged_attention_kv_cache_reserve_extra_length_for_append" + ) + fsync = tvm.get_global_func("vm.builtin.paged_attention_kv_cache_sync_aux_array_to_device") + + for seq_id in range(num_sequences): + print(f"Process seq {seq_id} for prefill...") + inputs = tvm.nd.array( + tokenizer(prompts[seq_id], return_tensors="pt").input_ids.to(torch.int32).numpy(), + device, + ) + seq_length = inputs.shape[1] + embedding = vm["embed"](inputs, const_params) + + seq_id_in_cache = fadd_sequence(kv_cache) + assert seq_id_in_cache == seq_id + + freset_append_length(kv_cache) + freserve(kv_cache, seq_id, seq_length) + fsync(kv_cache) + + print(f"Prefilling seq {seq_id}...") + logits, _ = vm["prefill_with_embed"](embedding, kv_cache, const_params) + + tokens = sample_from_logits(vm, logits, device) + assert len(tokens) == 1 + generated_tokens[seq_id].append(tokens[0]) + + print("Decoding...") + for step in range(num_decode_steps): + inputs = tvm.nd.array( + np.array( + [[generated_tokens[seq_id][-1]] for seq_id in range(num_sequences)], dtype="int32" + ), + device, + ) + embedding = vm["embed"](inputs, const_params) + freset_append_length(kv_cache) + for seq_id in range(num_sequences): + freserve(kv_cache, seq_id, 1) + fsync(kv_cache) + + logits, _ = vm["decode_with_embed"](embedding, kv_cache, const_params) + tokens = sample_from_logits(vm, logits, device) + assert len(tokens) == num_sequences + + for seq_id in range(num_sequences): + generated_tokens[seq_id].append(tokens[seq_id]) + + for seq_id in range(num_sequences): + output = tokenizer.decode(generated_tokens[seq_id]) + print("====================================================================") + print(f"Prompt {seq_id}: {prompts[seq_id]}") + print(f"Output: {output}") + print("\n\n") + + +if __name__ == "__main__": + ARGS = _parse_args() + deploy_to_pipeline(ARGS) From c2b8cbcb7199629f3eccff475102402d6c871047 Mon Sep 17 00:00:00 2001 From: Jeethu Rao Date: Sat, 14 Oct 2023 06:32:05 +0100 Subject: [PATCH 021/116] Fix Stable LM 3B build (#1061) * [stablelm 3b] Rename dynamic vocab size from "v" to "vocab_size" * Add get_num_key_value_heads method to StableLM3bConfig --- mlc_llm/relax_model/stablelm_3b.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/mlc_llm/relax_model/stablelm_3b.py b/mlc_llm/relax_model/stablelm_3b.py index 4bb1beedeb..89c15a7955 100644 --- a/mlc_llm/relax_model/stablelm_3b.py +++ b/mlc_llm/relax_model/stablelm_3b.py @@ -66,6 +66,11 @@ def __init__( self.num_shards = 1 self.kwargs = kwargs + def get_num_key_value_heads(self): + if self.num_key_value_heads is None: + return self.num_attention_heads + return self.num_key_value_heads + class LayerNorm(nn.Module): def __init__( @@ 
-579,7 +584,7 @@ def create_embed_func( bsz = 1 seq_len = tvm.tir.Var("n", "int64") with bb.function(func_name): - model = StableLM3bEmbedTokensWrapper(config, tvm.tir.Var("v", "int64")) + model = StableLM3bEmbedTokensWrapper(config, tvm.tir.Var("vocab_size", "int64")) param_manager.register_params(model, func_name, quant_scheme, get_param_quant_kind) input_ids = nn.Placeholder((bsz, seq_len), dtype="int32", name="input_ids") @@ -608,7 +613,7 @@ def create_encoding_func( all_seq_len = tvm.tir.Var("m", "int64") hidden_size = config.hidden_size with bb.function(func_name): - model = StableLM3bForCausalLM(config, tvm.tir.Var("v", "int64"), sep_embed) + model = StableLM3bForCausalLM(config, tvm.tir.Var("vocab_size", "int64"), sep_embed) param_manager.register_params(model, func_name, quant_scheme, get_param_quant_kind) inputs = ( @@ -652,7 +657,7 @@ def create_decoding_func( all_seq_len = tvm.tir.Var("n", "int64") with bb.function(func_name): - model = StableLM3bForCausalLM(config, tvm.tir.Var("v", "int64")) + model = StableLM3bForCausalLM(config, tvm.tir.Var("vocab_size", "int64")) param_manager.register_params(model, func_name, quant_scheme, get_param_quant_kind) input_ids = nn.Placeholder((bsz, 1), dtype="int32", name="input_ids") @@ -714,7 +719,9 @@ def create_kv_cache_func(bb: relax.BlockBuilder, config: StableLM3bConfig) -> No def create_softmax_func(bb: relax.BlockBuilder, config: StableLM3bConfig) -> None: with bb.function("softmax_with_temperature"): - logits = nn.Placeholder((1, 1, tvm.tir.Var("v", "int64")), dtype="float32", name="logits") + logits = nn.Placeholder( + (1, 1, tvm.tir.Var("vocab_size", "int64")), dtype="float32", name="logits" + ) temperature = nn.Placeholder((), dtype="float32", name="temperature") with bb.dataflow(): div = bb.emit(relax.op.divide(logits, temperature)) From 481cd923e16730807a801381fcbbd1a1c2e3e18d Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Sat, 14 Oct 2023 00:32:36 -0500 Subject: [PATCH 022/116] [Core] Remove duplication in MODEL.get_model calls (#1054) This commit removes the `if`/`elif` chain in `core.py`, where the body of each conditional assigns the same `mod, param_manager, params, model_config`, and is identical except for the choice of model being built. --- mlc_llm/core.py | 44 +++++++++++++++++++--------------- mlc_llm/relax_model/minigpt.py | 2 +- 2 files changed, 26 insertions(+), 20 deletions(-) diff --git a/mlc_llm/core.py b/mlc_llm/core.py index 2e6eefad4d..7168c9d8b1 100644 --- a/mlc_llm/core.py +++ b/mlc_llm/core.py @@ -604,28 +604,34 @@ def build_model_from_args(args: argparse.Namespace): use_cache = args.use_cache and os.path.isfile(cache_path) if args.sep_embed and args.model_category != "llama": raise ValueError(f"separate embedding not supported on {args.model}") - if args.model_category != "minigpt": + + if args.model_cateogry == "minigpt": + # Special case for minigpt, which neither provides nor requires a configuration. 
+ config = {} + else: with open(os.path.join(args.model_path, "config.json"), encoding="utf-8") as i_f: config = json.load(i_f) + if not use_cache or args.convert_weight_only: - if args.model_category in ("llama", "mistral"): - mod, param_manager, params, model_config = llama.get_model(args, config) - elif args.model_category == "stablelm_epoch": - mod, param_manager, params, model_config = stablelm_3b.get_model(args, config) - elif args.model_category == "gpt_neox": - mod, param_manager, params, model_config = gpt_neox.get_model(args, config) - elif args.model_category == "gpt_bigcode": - mod, param_manager, params, model_config = gpt_bigcode.get_model(args, config) - elif args.model_category == "minigpt": - mod, param_manager, params, model_config = minigpt.get_model(args) - elif args.model_category == "gptj": - mod, param_manager, params, model_config = gptj.get_model(args, config) - elif args.model_category == "rwkv" or args.model_category == "rwkv_world": - mod, param_manager, params, model_config = rwkv.get_model(args, config) - elif args.model_category == "chatglm": - mod, param_manager, params, model_config = chatglm.get_model(args, config) - else: - raise ValueError(f"Model {args.model} not supported") + + model_generators = { + "llama": llama, + "mistral": llama, + "stablelm_epoch": stablelm_3b, + "gpt_neox": gpt_neox, + "gpt_bigcode": gpt_bigcode, + "minigpt": minigpt, + "gptj": gptj, + "rwkv": rwkv, + "rwkv_world": rwkv, + "chatglm": chatglm, + } + + assert args.model_category in model_generators, f"Model {args.model} not supported" + + mod, param_manager, params, model_config = model_generators[args.model_category].get_model( + args, config + ) for qspec_updater_class in param_manager.qspec_updater_classes: qspec_updater = qspec_updater_class(param_manager) diff --git a/mlc_llm/relax_model/minigpt.py b/mlc_llm/relax_model/minigpt.py index 7bd30e70ed..96126bbf5b 100644 --- a/mlc_llm/relax_model/minigpt.py +++ b/mlc_llm/relax_model/minigpt.py @@ -502,7 +502,7 @@ def create_embed_func( bb.update_func(gv, mod[gv].with_attr("num_input", 1)) -def get_model(args): +def get_model(args, _config): model_name = args.model model_path = args.model_path From 81844314c24fdd127625bdd1e50d4eb77e9e95cb Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Sat, 14 Oct 2023 00:33:15 -0500 Subject: [PATCH 023/116] [ParamManager] Cleanup creation of quantization IRModule (#1053) This commit replaces the single-parameter `relax_model.param_manager.create_quantize_func` function with a method on the `ParamManager`, `create_parameter_transformation`. This avoids potential typos between `param_manager` as the imported Python module `mlc_llm.relax_model.param_manager` and an instance of the `ParamManager` class named `param_manager`, and makes the functionality easier to find. This function also takes an optional `optimize_parameter_order` flag, defaulting to `True`, which applies the `ReorderTransformFunc` pass. Since the `ReorderTransformFunc` is intended to be used with several configuration objects owned by `ParamManager`, this simplifies the common path of producing an optimally-ordered parameter transformation module. 
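A sketch of how the call site in `utils.convert_weights` changes (surrounding code elided; `param_manager` below is the imported module and `param_mgr` the `ParamManager` instance, mirroring the naming collision that motivated this cleanup):

    # Before: module-level function plus a manually constructed reordering pass.
    mod_transform = param_manager.create_quantize_func(param_mgr)
    mod_transform = ReorderTransformFunc(
        param_mgr.pidx2pname,
        param_mgr.torch_pname2binname,
        param_mgr.f_convert_pname_fwd,
    )(mod_transform)

    # After: a single method on the instance; reordering is applied by default
    # and can be skipped with optimize_parameter_order=False.
    mod_transform = param_mgr.create_parameter_transformation()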
--- mlc_llm/core.py | 4 ++ mlc_llm/relax_model/param_manager.py | 55 +++++++++++++++++++++++++++- mlc_llm/utils.py | 25 ++----------- 3 files changed, 61 insertions(+), 23 deletions(-) diff --git a/mlc_llm/core.py b/mlc_llm/core.py index 7168c9d8b1..0787db7073 100644 --- a/mlc_llm/core.py +++ b/mlc_llm/core.py @@ -638,6 +638,10 @@ def build_model_from_args(args: argparse.Namespace): qspec_updater.visit_module(mod) if not args.build_model_only: + # Run pre-quantization if provided. + args.model_path = param_manager.run_pre_quantize(args.model_path) + param_manager.init_torch_pname_to_bin_name(args.use_safetensors) + new_params = utils.convert_weights(param_manager, params, args) utils.save_params(new_params, args.artifact_path) if args.model_category != "minigpt": diff --git a/mlc_llm/relax_model/param_manager.py b/mlc_llm/relax_model/param_manager.py index 138f04f769..590b60d76b 100644 --- a/mlc_llm/relax_model/param_manager.py +++ b/mlc_llm/relax_model/param_manager.py @@ -13,6 +13,7 @@ from .. import quantization from .modules import named_parameters +from ..transform import ReorderTransformFunc def f_default_compute_relax_param(relax_pname: str, torch_params: List[Any]) -> Any: @@ -274,6 +275,31 @@ def register_params( self.params_in_func[func_name].append(param) + def run_pre_quantize(self, model_path: str): + if self.f_run_prequantize is not None: + model_path = self.f_run_prequantize(model_path) + + self.model_path = model_path + return model_path + + def init_torch_pname_to_bin_name(self, use_safetensors: bool): + assert hasattr(self, "model_path"), ( + "Must call either set_param_loading_func or run_pre_quantize " + "before init_torch_pname_to_bin_name" + ) + + if self.pidx2pname: + mapping = load_torch_pname2binname_map( + self.model_path, + use_safetensors, + set(self.pidx2pname.values()), + self.f_convert_pname_fwd, + ) + else: + mapping = {} + + self.torch_pname2binname = mapping + def set_param_loading_func( self, model_path: str, @@ -726,6 +752,33 @@ def _dequantize( # Apply the dequantization function. return bb.emit(f_dequantize(bb, qparams)) + def create_parameter_transformation(self, optimize_parameter_order: bool = True): + """Produce an IRModule that can transform the parameters + + Parameters + ---------- + optimize_parameter_order: bool + + If true, reorder the parameter transformations to + prioritize operations that use a currently-open file. If + false, transform the parameters in their default order. + + Returns + ------- + tvm.IRModule + The transformation module + + """ + mod = _create_quantize_func(self) + if optimize_parameter_order: + reorder_pass = ReorderTransformFunc( + self.pidx2pname, + self.torch_pname2binname, + self.f_convert_pname_fwd, + ) + mod = reorder_pass(mod) + return mod + @mutator class ParamReplacer(PyExprMutator): @@ -868,7 +921,7 @@ def load_torch_pname2binname_map( return torch_pname2binname -def create_quantize_func(param_manager: ParamManager) -> tvm.IRModule: +def _create_quantize_func(param_manager: ParamManager) -> tvm.IRModule: """Construct the Relax function which computes quantization. This method is called by `transform_module` below, and is not directly invoked outside the class. 
diff --git a/mlc_llm/utils.py b/mlc_llm/utils.py index 9d8751e5d6..f356874d1d 100644 --- a/mlc_llm/utils.py +++ b/mlc_llm/utils.py @@ -10,7 +10,7 @@ from .quantization import quantization_schemes from .relax_model import param_manager -from .transform import ReorderTransformFunc + supported_model_types = set( ["llama", "gpt_neox", "gpt_bigcode", "minigpt", "moss", "rwkv", "gptj", "chatglm", "mistral", "stablelm_epoch"] @@ -192,31 +192,12 @@ def convert_weights( model_params: List[Optional[tvm.nd.NDArray]], args: argparse.Namespace, ): - # Run pre-quantization if provided. - if param_mgr.f_run_prequantize is not None: - args.model_path = param_mgr.f_run_prequantize(args.model_path) - param_mgr.model_path = args.model_path - param_mgr.torch_pname2binname = ( - param_manager.load_torch_pname2binname_map( - args.model_path, - args.use_safetensors, - set(param_mgr.pidx2pname.values()), - param_mgr.f_convert_pname_fwd, - ) - if len(param_mgr.pidx2pname) != 0 - else dict() - ) - # Create the quantization function. # We first create an initial one, then reorder it according to each # weight's location in the binary files, in the purpose of reducing # memory usage when loading torch weights as well as acceleration. - mod_transform = param_manager.create_quantize_func(param_mgr) - mod_transform = ReorderTransformFunc( - param_mgr.pidx2pname, - param_mgr.torch_pname2binname, - param_mgr.f_convert_pname_fwd, - )(mod_transform) + mod_transform = param_mgr.create_parameter_transformation() + # Remove the dataflow block inside the param transform function, # so that the LazyTransformParams pass can be applied. mod_transform = relax.transform.ToNonDataflow()(mod_transform) From 9010d48a69a56086d1c076478d57b3e5b6874ca8 Mon Sep 17 00:00:00 2001 From: Jeethu Rao Date: Sun, 15 Oct 2023 06:42:24 +0100 Subject: [PATCH 024/116] Minor typo fix (#1064) --- mlc_llm/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlc_llm/core.py b/mlc_llm/core.py index 0787db7073..628cc6e91b 100644 --- a/mlc_llm/core.py +++ b/mlc_llm/core.py @@ -605,7 +605,7 @@ def build_model_from_args(args: argparse.Namespace): if args.sep_embed and args.model_category != "llama": raise ValueError(f"separate embedding not supported on {args.model}") - if args.model_cateogry == "minigpt": + if args.model_category == "minigpt": # Special case for minigpt, which neither provides nor requires a configuration. config = {} else: From b0bfc88c7a4e3d329725a7b727f77567ca0e8e87 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Sun, 15 Oct 2023 00:24:24 -0700 Subject: [PATCH 025/116] Add links to Python API Reference (#1068) --- docs/index.rst | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/index.rst b/docs/index.rst index 28a7d103ac..345b5d9603 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -47,6 +47,8 @@ It is recommended to have at least 6GB free VRAM to run it. **Colab walkthrough.** A Jupyter notebook on `Colab `_ is provided with detailed walkthrough of the Python API. + **Documentation and tutorial.** Python API reference and its tutorials are `available online `_. + .. 
figure:: https://raw.githubusercontent.com/mlc-ai/web-data/main/images/mlc-llm/tutorials/python-api.jpg :width: 600 :align: center From 204860b786460247f77717053cd3a4e1e0069b06 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Sun, 15 Oct 2023 14:02:12 -0400 Subject: [PATCH 026/116] [Fix] ChatModule incorrect temperature buffer shape (#1070) PR #1048 updated the signature of softmax in the built model library and changed the temperature buffer shape in ChatModule. This causes some existing demo unable to run since we did not do a round of model library update. This PR reverts the ChatModule change, and adds back the softmax function in non-batching case. With this PR, the regression should be fixed. --- cpp/llm_chat.cc | 2 +- mlc_llm/relax_model/llama.py | 18 ++++++++++++++++-- 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/cpp/llm_chat.cc b/cpp/llm_chat.cc index 7286cda7b0..339e5429d1 100644 --- a/cpp/llm_chat.cc +++ b/cpp/llm_chat.cc @@ -481,7 +481,7 @@ class LLMChat { // Step 6. KV cache creation. this->kv_cache_ = ft_.create_kv_cache_func_(); // Step 7. Pre-allocate fixed size ndarray - this->temperature_arr_ = NDArray::Empty({1}, DataType::Float(32), device_); + this->temperature_arr_ = NDArray::Empty({}, DataType::Float(32), device_); float temperature = static_cast(this->temperature_); this->temperature_arr_.CopyFromBytes(&temperature, sizeof(float)); if (ft_.use_disco) { diff --git a/mlc_llm/relax_model/llama.py b/mlc_llm/relax_model/llama.py index eea1bb05bc..e45a4a3e20 100644 --- a/mlc_llm/relax_model/llama.py +++ b/mlc_llm/relax_model/llama.py @@ -1068,7 +1068,20 @@ def create_paged_kv_cache_func(bb: relax.BlockBuilder, config: LlamaConfig) -> N bb.emit_func_output(cache) -def create_softmax_func(bb: relax.BlockBuilder, config: LlamaConfig) -> None: +def create_softmax_func_for_single_seq(bb: relax.BlockBuilder, config: LlamaConfig) -> None: + with bb.function("softmax_with_temperature"): + logits = nn.Placeholder( + (1, 1, tvm.tir.Var("vocab_size", "int64")), dtype="float32", name="logits" + ) + temperature = nn.Placeholder((), dtype="float32", name="temperature") + with bb.dataflow(): + div = bb.emit(relax.op.divide(logits, temperature)) + softmax = bb.emit(relax.op.nn.softmax(div, axis=-1)) + gv = bb.emit_output(softmax) + bb.emit_func_output(gv, [logits, temperature]) + + +def create_softmax_func_for_batching(bb: relax.BlockBuilder, config: LlamaConfig) -> None: with bb.function("softmax_with_temperature"): bsz = tvm.tir.Var("nseq", "int64") logits = nn.Placeholder( @@ -1206,12 +1219,13 @@ def get_model(args, hf_config): create_prefill_func_for_batching(bb, param_manager, config, args.quantization) create_decoding_func_for_batching(bb, param_manager, config, args.quantization) create_paged_kv_cache_func(bb, config) + create_softmax_func_for_batching(bb, config) else: create_prefill_func_for_single_seq(bb, param_manager, config, args.quantization, sep_embed) create_decoding_func_for_single_seq(bb, param_manager, config, args.quantization) create_kv_cache_func(bb, config) + create_softmax_func_for_single_seq(bb, config) - create_softmax_func(bb, config) create_metadata_func( bb, model_name=model_name, From d2020770c55fa477e360613049d18a468d21831e Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Mon, 16 Oct 2023 08:06:26 -0500 Subject: [PATCH 027/116] [ParamManager] Added progress bar for get_item/set_item (#1063) --- mlc_llm/utils.py | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/mlc_llm/utils.py b/mlc_llm/utils.py index 
f356874d1d..bb19f45c4f 100644 --- a/mlc_llm/utils.py +++ b/mlc_llm/utils.py @@ -1,5 +1,6 @@ # pylint: disable=missing-docstring,invalid-name import argparse +import functools import json import os import shutil @@ -17,6 +18,24 @@ ) +def wrap_tqdm_counter(func, **tqdm_kwargs): + # tqdm isn't a hard requirement, so return the original function + # if it isn't available. + try: + from tqdm import tqdm + except ImportError: + return func + + pbar = tqdm(**tqdm_kwargs) + + @functools.wraps(func) + def inner(*args, **kwargs): + pbar.update(1) + return func(*args, **kwargs) + + return inner + + def argparse_postproc_common(args: argparse.Namespace) -> None: if hasattr(args, "device_name"): if args.device_name == "auto": @@ -198,6 +217,12 @@ def convert_weights( # memory usage when loading torch weights as well as acceleration. mod_transform = param_mgr.create_parameter_transformation() + # Save the number of parameters before we lower mod_transform, so + # we can use them in the progress bar. + transform_func = mod_transform["transform_params"] + num_original_params = len(transform_func.params[0].struct_info.fields) + num_transformed_params = len(transform_func.struct_info.ret.fields) + # Remove the dataflow block inside the param transform function, # so that the LazyTransformParams pass can be applied. mod_transform = relax.transform.ToNonDataflow()(mod_transform) @@ -227,6 +252,14 @@ def convert_weights( device, device_cpu, ) + + get_item = wrap_tqdm_counter( + get_item, desc="Get old param", position=0, unit="tensors", total=num_original_params + ) + set_item = wrap_tqdm_counter( + set_item, desc="Set new param", position=1, unit="tensors", total=num_transformed_params + ) + tvm.register_func(func_name="get_item", f=get_item, override=True) tvm.register_func(func_name="set_item", f=set_item, override=True) From 9872c48097522db490709eb539e571dbbd9c5265 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Mon, 16 Oct 2023 14:56:24 -0400 Subject: [PATCH 028/116] [Python] Extract common device str parse function in ChatModule (#1074) This PR lifts the device string parsing (just a few of lines) to a standalone function, so that on the serving side the serving can make use of this function as well. Tested Python API and it does not seem to incur regression. --- python/mlc_chat/chat_module.py | 40 ++++++++++++++++++++++++++++------ 1 file changed, 33 insertions(+), 7 deletions(-) diff --git a/python/mlc_chat/chat_module.py b/python/mlc_chat/chat_module.py index 97b7cb7670..db46c080f4 100644 --- a/python/mlc_chat/chat_module.py +++ b/python/mlc_chat/chat_module.py @@ -548,6 +548,38 @@ def _convert_generation_config_to_json_str(generation_config: Optional[Generatio return json.dumps(asdict(generation_config)) +def _parse_device_str(device: str): + """Parse the input device identifier into device name and id. + + Parameters + ---------- + device : str + The device identifier to parse. + It can be "device_name" (e.g., "cuda") or + "device_name:device_id" (e.g., "cuda:1"). + + Returns + ------- + device_name : str + The name of the device. + + device_id : int + The id of the device, or 0 if not specified in the input. + """ + device_err_msg = ( + f"Invalid device name: {device}. Please enter the device in the form " + "'device_name:device_id' or 'device_name', where 'device_name' needs to be " + "one of 'cuda', 'metal', 'vulkan', 'rocm', 'opencl', 'auto'." 
+ ) + device_args = device.split(":") + if len(device_args) == 1: + return device_args[0], 0 + elif len(device_args) == 2: + return device_args[0], int(device_args[1]) + elif len(device_args) > 2: + raise ValueError(device_err_msg) + + def _detect_local_device(device_id: int = 0): """Automatically detect the local device if user does not specify. @@ -647,13 +679,7 @@ def __init__( ) # 0. Retrieve device_name and device_id (if any, default 0) from device arg - device_args = device.split(":") - if len(device_args) == 1: - device_name, device_id = device_args[0], 0 - elif len(device_args) == 2: - device_name, device_id = device_args[0], int(device_args[1]) - elif len(device_args) > 2: - raise ValueError(device_err_msg) + device_name, device_id = _parse_device_str(device) # 1. Get self.device if device_name == "cuda": From 3aefd9f9f25debe6f55e92b4e181cf3096f1d5e0 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Mon, 16 Oct 2023 21:16:27 -0700 Subject: [PATCH 029/116] [Bugfix] Compilation Error in q4f32_1 (#1078) The pass `fuse-split-rotary` assumes the compute dtype is fp16, which usually is, but in certain cases, e.g. `q0f32` and `q4f32_1`, the compute is based on fp32 instead. This PR strengthens the check guard. --- mlc_llm/core.py | 1 + 1 file changed, 1 insertion(+) diff --git a/mlc_llm/core.py b/mlc_llm/core.py index 628cc6e91b..34e3041e2f 100644 --- a/mlc_llm/core.py +++ b/mlc_llm/core.py @@ -405,6 +405,7 @@ def mod_transform_before_build( hasattr(config, "num_attention_heads") and hasattr(config, "hidden_size") and hasattr(config, "position_embedding_base") + and getattr(config, "dtype", "float16") == "float16" ): max_seq_len = None if args.max_seq_len > 0: From 2625945ee161231057792c4490bfbbb665ad1612 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Thu, 19 Oct 2023 08:57:50 -0700 Subject: [PATCH 030/116] Establish `mlc_chat.compiler` (#1082) This PR establishes the compiler components in MLC-Chat Python API, which currently includes two primary components: models and parameters. The models are `nn.Module`-based definition of an LLM, which, as the very first stab, contains only `LlamaForCasualLM`. It is decomposed into three files: - `llama_config.py`: common configurations for Llama, where we define relevant configurations of its architecture, as well as include standard config file for Llama2-7B/13B/70B for convenient testing; - `llama.py`: the model architecture of Llama, based on the PyTorch-like `nn.Module` API; - `llama_parameter.py`: defines the mapping between MLC parameters and pytorch parameters. The parameters contains the basic functionality of parameter mapping, and the loaders that effectively convert parameters from PyTorch to MLC according to the mapping specified. Currently, only `HFTorchLoader` is implemented, but loaders like SafeTensor, GPTQ or AWQ should be quite straightforward according to the existing design. On top of this PR, on-the-fly quantization could be defined as a loading time transformation on MLC parameters, while pre-quantized parameter loading is effectively parameter loading after MLC's `nn.Module` is quantized. Two unittests examplify how the infrastructure works: - `./tests/python/model/test_llama.py` shows how to create an `nn.Module` using the new infra, and then convert it to TVM IRModule; - `./tests/python/parameter/hf_torch_loader.py` shows how to load parameters from HuggingFace PyTorch format. 
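Roughly, the flow exercised by the second unittest composes as follows. This is a sketch only, not the test verbatim; the checkpoint path is a placeholder and assumes a local HuggingFace Llama-2 download:

    from pathlib import Path

    from mlc_chat.compiler.model import llama_config, llama_parameter
    from mlc_chat.compiler.parameter import HFTorchLoader

    # Build the MLC-side parameter mapping from a predefined Llama-2 config.
    config = llama_config.LlamaConfig.from_predefined("llama2_7b")
    param_map = llama_parameter.hf_torch(config)

    # Stream parameters out of the PyTorch shards, already renamed/combined
    # into MLC's layout (e.g. fused qkv_proj / gate_up_proj weights).
    loader = HFTorchLoader(
        path=Path("/path/to/Llama-2-7b-hf/pytorch_model.bin.index.json"),
        extern_param_map=param_map,
    )
    mlc_params = {}
    for name, param in loader.load():  # yields (mlc_name, tvm.runtime.NDArray)
        mlc_params[name] = param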
Besides, `mlc_chat.support` is established for utility functions, which now contains two utils: - `config.py` which supports reading configurations into dataclasses from JSON file or Python dict. On top of Python dataclass, it throws irrelevant fields into `cls.kwargs`, which is helpful when loading HuggingFace configuration file; - `tqdm.py` which contains tqdm-related utilities, primarily redirecting logging and printing to work nicely with tqdm. --- mlc_llm/models/__init__.py | 3 - mlc_llm/param_loader/__init__.py | 6 - mlc_llm/param_loader/hf_torch_loader.py | 191 ------------- mlc_llm/param_loader/param_mapping.py | 36 --- python/mlc_chat/cli/benchmark.py | 26 +- python/mlc_chat/compiler/__init__.py | 5 + python/mlc_chat/compiler/model/__init__.py | 2 + .../mlc_chat/compiler/model}/llama.py | 34 +-- .../mlc_chat/compiler/model/llama_config.py | 108 ++++++++ .../compiler/model/llama_parameter.py | 20 +- .../mlc_chat/compiler/parameter/__init__.py | 6 + .../compiler/parameter/hf_torch_loader.py | 260 ++++++++++++++++++ python/mlc_chat/compiler/parameter/mapping.py | 79 ++++++ python/mlc_chat/support/__init__.py | 4 + .../mlc_chat/support/config.py | 21 +- python/mlc_chat/support/tqdm.py | 38 +++ tests/python/model/test_llama.py | 19 ++ .../python/parameter/test_hf_torch_loader.py | 42 +++ tests/python/test_model_llama.py | 70 ----- tests/python/test_param_loader_llama.py | 32 --- 20 files changed, 617 insertions(+), 385 deletions(-) delete mode 100644 mlc_llm/models/__init__.py delete mode 100644 mlc_llm/param_loader/__init__.py delete mode 100644 mlc_llm/param_loader/hf_torch_loader.py delete mode 100644 mlc_llm/param_loader/param_mapping.py create mode 100644 python/mlc_chat/compiler/__init__.py create mode 100644 python/mlc_chat/compiler/model/__init__.py rename {mlc_llm/models => python/mlc_chat/compiler/model}/llama.py (90%) create mode 100644 python/mlc_chat/compiler/model/llama_config.py rename mlc_llm/models/llama_param_map.py => python/mlc_chat/compiler/model/llama_parameter.py (78%) create mode 100644 python/mlc_chat/compiler/parameter/__init__.py create mode 100644 python/mlc_chat/compiler/parameter/hf_torch_loader.py create mode 100644 python/mlc_chat/compiler/parameter/mapping.py create mode 100644 python/mlc_chat/support/__init__.py rename mlc_llm/models/model_config_base.py => python/mlc_chat/support/config.py (65%) create mode 100644 python/mlc_chat/support/tqdm.py create mode 100644 tests/python/model/test_llama.py create mode 100644 tests/python/parameter/test_hf_torch_loader.py delete mode 100644 tests/python/test_model_llama.py delete mode 100644 tests/python/test_param_loader_llama.py diff --git a/mlc_llm/models/__init__.py b/mlc_llm/models/__init__.py deleted file mode 100644 index 380ea83505..0000000000 --- a/mlc_llm/models/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -"""Model definition using PyTorch-like nn.Module API""" -from . import llama, llama_param_map -from .model_config_base import ModelConfig diff --git a/mlc_llm/param_loader/__init__.py b/mlc_llm/param_loader/__init__.py deleted file mode 100644 index dfc748d3f6..0000000000 --- a/mlc_llm/param_loader/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -""" -Utilities for loading parameters from specific formats, for example, HuggingFace PyTorch, -HuggingFace SafeTensor, GGML, AutoGPTQ. 
-""" -from .hf_torch_loader import HFTorchLoader -from .param_mapping import ParameterMapping diff --git a/mlc_llm/param_loader/hf_torch_loader.py b/mlc_llm/param_loader/hf_torch_loader.py deleted file mode 100644 index 6c12af9181..0000000000 --- a/mlc_llm/param_loader/hf_torch_loader.py +++ /dev/null @@ -1,191 +0,0 @@ -"""A weight loader for HuggingFace's PyTorch format""" -import gc -import json -import logging -import time -from collections import defaultdict -from pathlib import Path -from typing import Dict, List - -import numpy as np - -from .param_mapping import ParameterMapping - -logger = logging.getLogger(__name__) - - -class HFTorchLoader: - """A loader loading HuggingFace's PyTorch format and converts them to MLC's parameters. - - Attributes - ---------- - param_map : ParameterMapping - The parameter mapping from MLC to HuggingFace PyTorch. - - torch_to_path : Dict[str, Path] - A mapping from PyTorch parameter name to the path of the file containing it. - - cached_files : Dict[Path, Dict[str, np.ndarray]] - A cache of the loaded files. The key is the path of the file, and the value is a mapping - from parameter name to the parameter value. - - stats_load_time_sec : float - The time spent on loading the files in seconds. - - stats_load_data_gb : float - The amount of data loaded in GB. - """ - - param_map: ParameterMapping - torch_to_path: Dict[str, Path] - cached_files: Dict[Path, Dict[str, np.ndarray]] - stats_load_time_sec: float - stats_load_data_gb: float - - def __init__(self, config_path: Path, param_map: ParameterMapping) -> None: - """Create a parameter loader from HuggingFace PyTorch format. - - Parameters - ---------- - config_path : pathlib.Path - Path to the torch indexing file, usually `pytorch_model.bin.index.json` in the repo. - param_map : ParameterMapping - The parameter mapping from MLC to HuggingFace PyTorch. - """ - with config_path.open("r", encoding="utf-8") as in_file: - torch_weight_map = json.load(in_file)["weight_map"] - self.param_map = param_map - self.torch_to_path = {} - for torch_name, path_str in torch_weight_map.items(): - path = config_path.parent / path_str - self.torch_to_path[torch_name] = path - self.cached_files = {} - self.stats_load_time_sec = 0.0 - self.stats_load_data_gb = 0.0 - - used_torch_names = sum(param_map.name_map.values(), ()) - # Check 1. All PyTorch parameters in the weight files are used unless explicitly specified - unused_torch_names = set(torch_weight_map) - set(used_torch_names) - param_map.unused_params - if unused_torch_names: - logger.warning( - "Unused torch parameters: %s", - ", ".join(sorted(unused_torch_names)), - ) - # Check 2. All PyTorch parameters required are stored in the weight files - nonexistent_torch_names = set(used_torch_names) - set(torch_weight_map) - if nonexistent_torch_names: - raise ValueError( - "The following torch parameters do not exist in the weight files:\n " - + "\n ".join(sorted(nonexistent_torch_names)), - ) - - def suggest_loading_order(self) -> List[str]: - """Suggest a loading order for MLC parameters. - - Returns - ------- - order : List[str] - A list of MLC parameters in the order that ensures file locality. - """ - # Step 1. Build a map from path to torch parameters - path_to_torch: Dict[Path, List[str]] = defaultdict(list) - for torch_name, path in self.torch_to_path.items(): - path_to_torch[path].append(torch_name) - # Step 2. 
Build a map from torch parameters to MLC parameters - torch_to_mlc = defaultdict(list) - for mlc_name, torch_names in self.param_map.name_map.items(): - for torch_name in torch_names: - torch_to_mlc[torch_name].append(mlc_name) - # Step 3. Construct the ordering that ensures file locality - order = [] - for _, torch_names in path_to_torch.items(): - for torch_name in torch_names: - for mlc_name in torch_to_mlc[torch_name]: - order.append(mlc_name) - return order - - def load_param(self, name: str) -> np.ndarray: - """Load a MLC parameter according to its name. - - Parameters - ---------- - name : str - The name of the MLC parameter. - - Returns - ------- - param : np.ndarray - The parameter value as a numpy array. Note that if the parameter is stored in bfloat16, - it will be converted to float32. - """ - mlc_name = name - torch_names = self.param_map.name_map[mlc_name] - files_required = {self.torch_to_path[p] for p in torch_names} - files_existing = set(self.cached_files.keys()) - files_to_load = files_required - files_existing - files_to_unload = files_existing - files_required - - # Step 1. When there is some file to unloaded: - # - If no pending file load: unloading is deferred as there is no gain in peak memory usage; - # - Need to load files: unload immediately to save memory and make space for the new files. - if files_to_load: - for path in files_to_unload: - self._unload_file(path) - # Step 2. Load all the files needed - for path in files_to_load: - self._load_file(path) - # Step 3. Collect all torch parameters in order - torch_names = [self._retrieve_torch_param_from_cache(name) for name in torch_names] - # Step 4. Apply the mapping function - map_func = self.param_map.map_func[mlc_name] - return map_func(*torch_names) - - def __enter__(self) -> "HFTorchLoader": - self.stats_load_time_sec = 0.0 - self.stats_load_data_gb = 0.0 - return self - - def __exit__(self, exc_type, exc_value, traceback) -> None: - cached_files = list(self.cached_files.keys()) - for path in cached_files: - self._unload_file(path) - logger.info( - "Time used in PyTorch loading: %.3f sec. 
Total %.3f GB loaded", - self.stats_load_time_sec, - self.stats_load_data_gb, - ) - - def _load_file(self, path: Path) -> None: - import torch # pylint: disable=import-outside-toplevel - - logging.info("Loading PyTorch parameters from: %s", path) - - start_time = time.time() - result = {} - for name, param in torch.load(path, map_location=torch.device("cpu")).items(): - param = param.detach().cpu() - dtype = str(param.dtype) - if dtype == "torch.bfloat16": - param = param.float() - param = param.numpy() - self.stats_load_data_gb += param.nbytes / (1024**3) - result[name] = param - logging.debug(' Parameter: "%s", shape: %s, dtype: %s', name, param.shape, dtype) - self.cached_files[path] = result - self.stats_load_time_sec += time.time() - start_time - - def _unload_file(self, path: Path) -> None: - logging.debug("Unloading PyTorch weight file: %s", path) - - start_time = time.time() - del self.cached_files[path] - gc.collect() - self.stats_load_time_sec += time.time() - start_time - - def _retrieve_torch_param_from_cache(self, name: str) -> np.ndarray: - assert name in self.torch_to_path - path = self.torch_to_path[name] - assert path in self.cached_files - cache = self.cached_files[path] - assert name in cache - return cache[name] diff --git a/mlc_llm/param_loader/param_mapping.py b/mlc_llm/param_loader/param_mapping.py deleted file mode 100644 index c378b30268..0000000000 --- a/mlc_llm/param_loader/param_mapping.py +++ /dev/null @@ -1,36 +0,0 @@ -"""Parameter mapping for converting different LLM implementations to MLC LLM.""" -import dataclasses -from typing import Callable, Dict, Set, Tuple - -import numpy as np - - -@dataclasses.dataclass -class ParameterMapping: - """Mapping from a parameter name in MLC LLM's model definition to its potential source, - for example, from MLC parameter "model.layers.2.post_attention_layernorm.weight" to PyTorch's - parameter correspondingly. - - Parameters - ---------- - name_map : Dict[str, Tuple[str, ...]] - A dictionary that maps the name of a parameter to its source. For example, - in Llama2, the source of MLC parameter "model.layers.0.self_attn.qkv_proj.weight" from - huggingface torch are: - - - "model.layers.0.self_attn.q_proj.weight" - - "model.layers.0.self_attn.k_proj.weight" - - "model.layers.0.self_attn.v_proj.weight" - - map_func: Dict[str, Callable[[np.ndarray, ...], np.ndarray]] - A dictionary that maps the name of a parameter to a function that combines the source - parameters into the MLC parameter. For example, for the above example, the function - would be: `lambda q, k, v: np.concatenate([q, k, v], axis=0)`. - - unused_params : Set[str] - Parameter names in the source weights that are not used in the MLC LLM model definition. - """ - - name_map: Dict[str, Tuple[str, ...]] - map_func: Dict[str, Callable[[np.ndarray, ...], np.ndarray]] - unused_params: Set[str] = dataclasses.field(default_factory=dict) diff --git a/python/mlc_chat/cli/benchmark.py b/python/mlc_chat/cli/benchmark.py index bcbb4eca53..308921e3d0 100644 --- a/python/mlc_chat/cli/benchmark.py +++ b/python/mlc_chat/cli/benchmark.py @@ -1,7 +1,7 @@ """A command line tool for benchmarking a chat model.""" import argparse -from mlc_chat import ChatModule +from mlc_chat import ChatConfig, ChatModule parser = argparse.ArgumentParser(description="Benchmark an MLC LLM ChatModule.") parser.add_argument( @@ -13,6 +13,21 @@ the model folder over possible paths.""", required=True, ) +parser.add_argument( + "--model-lib", + type=str, + help="""The compiled model library. 
In MLC LLM, an LLM is compiled to a shared or static + library (.so or .a), which contains GPU computation to efficiently run the LLM. MLC Chat, + as the runtime of MLC LLM, depends on the compiled model library to generate tokens. + """, + required=False, +) +parser.add_argument( + "--num-shards", + type=int, + help="Number of GPUs to be used.", + required=False, +) parser.add_argument( "--device", type=str, @@ -40,7 +55,14 @@ def main(): """The main function that runs the benchmarking.""" args = parser.parse_args() - chat_module = ChatModule(model=args.model, device=args.device) + chat_module = ChatModule( + model=args.model, + device=args.device, + chat_config=ChatConfig( + num_shards=args.num_shards, + ), + lib_path=args.model_lib, + ) output = chat_module.benchmark_generate(args.prompt, generate_length=args.generate_length) print(f"Generated text:\n{output}\n") print(f"Statistics: {chat_module.stats(verbose=True)}") diff --git a/python/mlc_chat/compiler/__init__.py b/python/mlc_chat/compiler/__init__.py new file mode 100644 index 0000000000..2206f480f6 --- /dev/null +++ b/python/mlc_chat/compiler/__init__.py @@ -0,0 +1,5 @@ +""" +A compiler for MLC Chat. By default, it is not imported to MLC Chat to avoid unnecessary dependency, +but users could optionally import it if they want to use the compiler. +""" +from . import model, parameter diff --git a/python/mlc_chat/compiler/model/__init__.py b/python/mlc_chat/compiler/model/__init__.py new file mode 100644 index 0000000000..b568bd84f7 --- /dev/null +++ b/python/mlc_chat/compiler/model/__init__.py @@ -0,0 +1,2 @@ +"""Model definition for the compiler.""" +from . import llama, llama_config, llama_parameter diff --git a/mlc_llm/models/llama.py b/python/mlc_chat/compiler/model/llama.py similarity index 90% rename from mlc_llm/models/llama.py rename to python/mlc_chat/compiler/model/llama.py index 40df48180e..663e6d93c2 100644 --- a/mlc_llm/models/llama.py +++ b/python/mlc_chat/compiler/model/llama.py @@ -1,41 +1,19 @@ -"""Implementation for Llama2 architecture""" -import dataclasses +""" +Implementation for Llama2 architecture. 
+TODO: add docstring +""" import math -from typing import Any, Dict, Optional +from typing import Optional from tvm import te, tir from tvm.relax.frontend import nn from tvm.relax.frontend.nn import Tensor, op -from .model_config_base import ModelConfig +from .llama_config import LlamaConfig # pylint: disable=invalid-name,missing-docstring -@dataclasses.dataclass -class LlamaConfig(ModelConfig): # pylint: disable=too-many-instance-attributes - hidden_act: str - hidden_size: int - intermediate_size: int - num_attention_heads: int - num_hidden_layers: int - rms_norm_eps: float - vocab_size: int - max_sequence_length: int = 2048 - position_embedding_base: int = 10000 - num_key_value_heads: int = 0 - kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict) - head_dim: int = 0 - - def __post_init__(self): - if self.num_key_value_heads == 0: - self.num_key_value_heads = self.num_attention_heads - if self.head_dim == 0: - self.head_dim = self.hidden_size // self.num_attention_heads - assert self.num_attention_heads % self.num_key_value_heads == 0 - assert self.head_dim * self.num_attention_heads == self.hidden_size - - class RotaryEmbedding(nn.Module): def __init__(self, config: LlamaConfig): super().__init__() diff --git a/python/mlc_chat/compiler/model/llama_config.py b/python/mlc_chat/compiler/model/llama_config.py new file mode 100644 index 0000000000..113acd456f --- /dev/null +++ b/python/mlc_chat/compiler/model/llama_config.py @@ -0,0 +1,108 @@ +"""Common configuration for Llama models.""" +import dataclasses +from typing import Any, Dict + +from ...support.config import ConfigBase + + +@dataclasses.dataclass +class LlamaConfig(ConfigBase): # pylint: disable=too-many-instance-attributes + """Configuration of the Llama model.""" + + hidden_act: str + hidden_size: int + intermediate_size: int + num_attention_heads: int + num_hidden_layers: int + rms_norm_eps: float + vocab_size: int + max_sequence_length: int = 2048 + position_embedding_base: int = 10000 + num_key_value_heads: int = 0 + kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict) + head_dim: int = 0 + + def __post_init__(self): + if self.num_key_value_heads == 0: + self.num_key_value_heads = self.num_attention_heads + if self.head_dim == 0: + self.head_dim = self.hidden_size // self.num_attention_heads + assert self.num_attention_heads % self.num_key_value_heads == 0 + assert self.head_dim * self.num_attention_heads == self.hidden_size + + @staticmethod + def from_predefined(name: str) -> "LlamaConfig": + """Create a LlamaConfig from a predefined configuration.""" + return LlamaConfig.from_dict(CONFIG[name]) + + +CONFIG = { + "llama2_7b": { + "architectures": ["LlamaForCausalLM"], + "bos_token_id": 1, + "eos_token_id": 2, + "hidden_act": "silu", + "hidden_size": 4096, + "initializer_range": 0.02, + "intermediate_size": 11008, + "max_position_embeddings": 2048, + "model_type": "llama", + "num_attention_heads": 32, + "num_hidden_layers": 32, + "num_key_value_heads": 32, + "pad_token_id": 0, + "pretraining_tp": 1, + "rms_norm_eps": 1e-05, + "rope_scaling": None, + "tie_word_embeddings": False, + "torch_dtype": "float16", + "transformers_version": "4.31.0.dev0", + "use_cache": True, + "vocab_size": 32000, + }, + "llama2_13b": { + "_name_or_path": "meta-llama/Llama-2-13b-hf", + "architectures": ["LlamaForCausalLM"], + "bos_token_id": 1, + "eos_token_id": 2, + "hidden_act": "silu", + "hidden_size": 5120, + "initializer_range": 0.02, + "intermediate_size": 13824, + "max_position_embeddings": 2048, + "model_type": 
"llama", + "num_attention_heads": 40, + "num_hidden_layers": 40, + "num_key_value_heads": 40, + "pad_token_id": 0, + "pretraining_tp": 2, + "rms_norm_eps": 1e-05, + "rope_scaling": None, + "tie_word_embeddings": False, + "torch_dtype": "float16", + "transformers_version": "4.31.0.dev0", + "use_cache": True, + "vocab_size": 32000, + }, + "llama2_70b": { + "architectures": ["LlamaForCausalLM"], + "bos_token_id": 1, + "eos_token_id": 2, + "hidden_act": "silu", + "hidden_size": 8192, + "initializer_range": 0.02, + "intermediate_size": 28672, + "max_position_embeddings": 2048, + "model_type": "llama", + "num_attention_heads": 64, + "num_hidden_layers": 80, + "num_key_value_heads": 8, + "pad_token_id": 0, + "rms_norm_eps": 1e-05, + "tie_word_embeddings": False, + "torch_dtype": "float16", + "transformers_version": "4.31.0.dev0", + "use_cache": True, + "vocab_size": 32000, + }, +} diff --git a/mlc_llm/models/llama_param_map.py b/python/mlc_chat/compiler/model/llama_parameter.py similarity index 78% rename from mlc_llm/models/llama_param_map.py rename to python/mlc_chat/compiler/model/llama_parameter.py index 3737893702..39a8921a05 100644 --- a/mlc_llm/models/llama_param_map.py +++ b/python/mlc_chat/compiler/model/llama_parameter.py @@ -4,14 +4,12 @@ """ import numpy as np -from mlc_llm.param_loader import ParameterMapping - +from ..parameter import ExternMapping from .llama import LlamaConfig, LlamaForCasualLM -def hf_torch(model_config: LlamaConfig) -> ParameterMapping: - """ - Returns a parameter mapping that maps from the names of MLC LLM parameters to +def hf_torch(model_config: LlamaConfig) -> ExternMapping: + """Returns a parameter mapping that maps from the names of MLC LLM parameters to the names of HuggingFace PyTorch parameters. Parameters @@ -21,14 +19,14 @@ def hf_torch(model_config: LlamaConfig) -> ParameterMapping: Returns ------- - param_map : ParameterMapping + param_map : ExternMapping The parameter mapping from MLC to HuggingFace PyTorch. 
""" model = LlamaForCasualLM(model_config) _, named_params = model.export_tvm(spec=model.get_default_spec()) parameter_names = {name for name, _ in named_params} - name_map = {} + param_map = {} map_func = {} unused_params = set() @@ -37,7 +35,7 @@ def hf_torch(model_config: LlamaConfig) -> ParameterMapping: attn = f"model.layers.{i}.self_attn" assert f"{attn}.qkv_proj.weight" in parameter_names map_func[f"{attn}.qkv_proj.weight"] = lambda q, k, v: np.concatenate([q, k, v], axis=0) - name_map[f"{attn}.qkv_proj.weight"] = ( + param_map[f"{attn}.qkv_proj.weight"] = ( f"{attn}.q_proj.weight", f"{attn}.k_proj.weight", f"{attn}.v_proj.weight", @@ -46,7 +44,7 @@ def hf_torch(model_config: LlamaConfig) -> ParameterMapping: mlp = f"model.layers.{i}.mlp" assert f"{mlp}.gate_up_proj.weight" in parameter_names map_func[f"{mlp}.gate_up_proj.weight"] = lambda gate, up: np.concatenate([gate, up], axis=0) - name_map[f"{mlp}.gate_up_proj.weight"] = ( + param_map[f"{mlp}.gate_up_proj.weight"] = ( f"{mlp}.gate_proj.weight", f"{mlp}.up_proj.weight", ) @@ -56,5 +54,5 @@ def hf_torch(model_config: LlamaConfig) -> ParameterMapping: for name in parameter_names: if name not in map_func: map_func[name] = lambda x: x - name_map[name] = (name,) - return ParameterMapping(name_map, map_func, unused_params) + param_map[name] = (name,) + return ExternMapping(param_map, map_func, unused_params) diff --git a/python/mlc_chat/compiler/parameter/__init__.py b/python/mlc_chat/compiler/parameter/__init__.py new file mode 100644 index 0000000000..3ea9a2b46e --- /dev/null +++ b/python/mlc_chat/compiler/parameter/__init__.py @@ -0,0 +1,6 @@ +""" +A subpackage of the compiler that represents mapping between external parameters, quantized +parameters and parameters in MLC-defined models. +""" +from .hf_torch_loader import HFTorchLoader +from .mapping import ExternMapping, QuantizeMapping diff --git a/python/mlc_chat/compiler/parameter/hf_torch_loader.py b/python/mlc_chat/compiler/parameter/hf_torch_loader.py new file mode 100644 index 0000000000..000642800e --- /dev/null +++ b/python/mlc_chat/compiler/parameter/hf_torch_loader.py @@ -0,0 +1,260 @@ +"""A weight loader for HuggingFace's PyTorch format""" +import dataclasses +import gc +import json +import logging +import time +from collections import OrderedDict, defaultdict +from contextlib import contextmanager +from pathlib import Path +from typing import Dict, Iterator, List, Set, Tuple + +import numpy as np +from tqdm import tqdm +from tvm.runtime import NDArray +from tvm.runtime.ndarray import array as as_ndarray + +from .mapping import ExternMapping + +logger = logging.getLogger(__name__) + + +@dataclasses.dataclass +class Stats: + """Statistics of the loading process of HuggingFace PyTorch loader. + + Attributes + ---------- + load_time_sec : float + Time used in loading the parameters. + + map_time_sec : float + Time used in applying the mapping function, i.e. `ExternMapping.map_func`. + + quant_time_sec : float + Time used in quantizing the parameters, i.e. `QuantizeMapping.quant_func`. + + current_memory_gb : float + The current RAM usage in GB. + + total_memory_gb : float + The total size data loaded from disk in GB. + + max_memory_gb : float + The maximum RAM usage in GB. 
+ """ + + load_time_sec: float = 0.0 + map_time_sec: float = 0.0 + quant_time_sec: float = 0.0 + + current_memory_gb: float = 0.0 + total_memory_gb: float = 0.0 + max_memory_gb: float = 0.0 + + def timer(self, attr): + """A context manager to time the scope and add the time to the attribute.""" + + @contextmanager + def timed_scope(): + start_time = time.time() + yield + elapsed_time = time.time() - start_time + setattr(self, attr, getattr(self, attr) + elapsed_time) + + return timed_scope() + + def mem_add(self, nbytes: int): + """Add the memory usage by the given number of bytes.""" + mem_gb = float(nbytes) / float(1024**3) + self.current_memory_gb += mem_gb + self.total_memory_gb += mem_gb + self.max_memory_gb = max(self.max_memory_gb, self.current_memory_gb) + + def mem_rm(self, nbytes: int): + """Remove the memory usage by the given number of bytes.""" + mem_gb = float(nbytes) / float(1024**3) + self.current_memory_gb -= mem_gb + + +class HFTorchLoader: # pylint: disable=too-few-public-methods + """A loader loading HuggingFace's PyTorch format and converts them to MLC's parameters. + + Attributes + ---------- + stats : Stats + Statistics of the loading process. + + extern_param_map : ExternMapping + The parameter mapping from MLC to HuggingFace PyTorch. + + torch_to_path : Dict[str, Path] + A mapping from PyTorch parameter name to the path of the file containing it, or the path + meaning all parameters are stored in a single file. + + cached_files : Dict[Path, Dict[str, np.ndarray]] + A cache of the loaded files. The key is the path of the file, and the value is a mapping + from parameter name to the parameter value. + """ + + stats: Stats + extern_param_map: ExternMapping + cached_files: Dict[Path, Dict[str, np.ndarray]] + torch_to_path: Dict[str, Path] + + def __init__( + self, + path: Path, + extern_param_map: ExternMapping, + ) -> None: + """Create a parameter loader from HuggingFace PyTorch format. + + Parameters + ---------- + path : pathlib.Path + Path to either a JSON indexing file, or a PyTorch bin file. + 1) For JSON indexing file, it is usually `pytorch_model.bin.index.json` in the repo, + which contains a `weight_map` that maps each PyTorch parameter to the file containing + the weight. 2) For PyTorch bin file, it is usually `pytorch_model.bin` in the repo, + which contains all the parameters. + + extern_param_map : ExternMapping + Maps an MLC parameter to a list of PyTorch parameters. 
+ """ + assert path.is_file() + self.stats = Stats() + self.extern_param_map = extern_param_map + self.cached_files = {} + self.torch_to_path = {} + if path.suffix == ".bin": + self._load_file(path) + for name in self.cached_files[path].keys(): + self.torch_to_path[name] = path + elif path.suffix == ".json": + with path.open("r", encoding="utf-8") as in_file: + torch_weight_map = json.load(in_file)["weight_map"] + for torch_name, path_str in torch_weight_map.items(): + self.torch_to_path[torch_name] = path.parent / path_str + else: + raise FileNotFoundError(f"Unknown file suffix: {path}") + _check_parameter_usage(extern_param_map, set(self.torch_to_path.keys())) + + def load(self) -> Iterator[Tuple[str, NDArray]]: + """Load the parameters and yield the MLC parameter and its value.""" + mlc_names = _loading_order(self.extern_param_map, self.torch_to_path) + for mlc_name in tqdm(mlc_names): + param = self._load_mlc_param(mlc_name) + yield mlc_name, param + cached_files = list(self.cached_files.keys()) + for path in cached_files: + self._unload_file(path) + + logger.info( + "Time used: " + "PyTorch loading: %.3f sec; " + "Pre-quantization mapping: %.3f sec; " + "Quantization: %.3f sec", + self.stats.load_time_sec, + self.stats.map_time_sec, + self.stats.quant_time_sec, + ) + logger.info( + "Memory usage: Total size loaded from disk: %.3f GB; Peak memory usage: %.3f GB", + self.stats.total_memory_gb, + self.stats.max_memory_gb, + ) + + def _load_mlc_param(self, mlc_name: str) -> np.ndarray: + torch_names = self.extern_param_map.param_map[mlc_name] + files_required = {self.torch_to_path[p] for p in torch_names} + files_existing = set(self.cached_files.keys()) + files_to_load = files_required - files_existing + files_to_unload = files_existing - files_required + + # Step 1. When there is some file to unloaded: + # - If no pending file load: unloading is deferred as there is no gain in peak memory usage; + # - Need to load files: unload immediately to save memory and make space for the new files. + if files_to_load: + for path in files_to_unload: + self._unload_file(path) + # Step 2. Load all the files needed + for path in files_to_load: + self._load_file(path) + # Step 3. Collect all torch parameters in order + torch_params = [self.cached_files[self.torch_to_path[i]][i] for i in torch_names] + # Step 4. Apply the mapping function + with self.stats.timer("map_time_sec"): + param = self.extern_param_map.map_func[mlc_name](*torch_params) + logger.info(' Parameter: "%s", shape: %s, dtype: %s', mlc_name, param.shape, param.dtype) + param = as_ndarray(param) + return param + + def _load_file(self, path: Path) -> None: + logger.info("Loading PyTorch parameters from: %s", path) + with self.stats.timer("load_time_sec"): + result = {} + for name, param in _load_torch_shard(path): + result[name] = param + self.stats.mem_add(param.nbytes) + self.cached_files[path] = result + + def _unload_file(self, path: Path) -> None: + logger.info("Unloading PyTorch weight file: %s", path) + with self.stats.timer("load_time_sec"): + for _, param in self.cached_files[path].items(): + self.stats.mem_rm(param.nbytes) + del self.cached_files[path] + gc.collect() + + +def _check_parameter_usage(param_map: ExternMapping, torch_weights: Set[str]): + used_torch_names = set(sum(param_map.param_map.values(), ())) + # Check 1. 
All PyTorch parameters in the weight files are used unless explicitly specified + unused_torch_names = torch_weights - used_torch_names - param_map.unused_params + if unused_torch_names: + logger.warning( + "Unused torch parameters: %s", + ", ".join(sorted(unused_torch_names)), + ) + # Check 2. All PyTorch parameters required are stored in the weight files + nonexistent_torch_names = used_torch_names - torch_weights + if nonexistent_torch_names: + raise ValueError( + "The following torch parameters do not exist in the weight files:\n " + + "\n ".join(sorted(nonexistent_torch_names)), + ) + + +def _load_torch_shard(path: Path): + import torch # pylint: disable=import-outside-toplevel + + for name, param in torch.load(path, map_location=torch.device("cpu")).items(): + param = param.detach().cpu() + dtype = str(param.dtype) + if dtype == "torch.bfloat16": + param = param.float() + param = param.numpy() + yield name, param + + +def _loading_order(param_map: ExternMapping, torch_to_path: Dict[str, Path]) -> List[str]: + # Step 1. Build a map from path to torch parameters + path_to_torch: Dict[Path, List[str]] = defaultdict(list) + for torch_name, path in torch_to_path.items(): + path_to_torch[path].append(torch_name) + # Step 2. Build a map from torch parameters to MLC parameters + torch_to_mlc = defaultdict(list) + for mlc_name, torch_names in param_map.param_map.items(): + for torch_name in torch_names: + torch_to_mlc[torch_name].append(mlc_name) + # Step 3. Construct the ordering that ensures file locality + order = OrderedDict() + for _, torch_names in path_to_torch.items(): + for torch_name in torch_names: + for mlc_name in torch_to_mlc[torch_name]: + if mlc_name not in order: + order[mlc_name] = 1 + return list(order.keys()) + + +__all__ = ["HFTorchLoader"] diff --git a/python/mlc_chat/compiler/parameter/mapping.py b/python/mlc_chat/compiler/parameter/mapping.py new file mode 100644 index 0000000000..3018c91ca3 --- /dev/null +++ b/python/mlc_chat/compiler/parameter/mapping.py @@ -0,0 +1,79 @@ +"""Parameter mapping for converting different LLM implementations to MLC LLM.""" +import dataclasses +from typing import Callable, Dict, List, Set + +import numpy as np +from tvm.runtime import NDArray + + +@dataclasses.dataclass +class ExternMapping: + """Mapping from a parameter name in MLC LLM's model definition to its potential source, + for example, from MLC parameter "model.layers.2.post_attention_layernorm.weight" to PyTorch's + parameter correspondingly. + + Parameters + ---------- + param_map : Dict[str, List[str]] + A dictionary that maps the name of a parameter to its source. For example, + in Llama2, the source of MLC parameter "model.layers.0.self_attn.qkv_proj.weight" from + huggingface torch are: + + - "model.layers.0.self_attn.q_proj.weight" + - "model.layers.0.self_attn.k_proj.weight" + - "model.layers.0.self_attn.v_proj.weight" + + map_func : Dict[str, Callable[[np.ndarray, ...], np.ndarray]] + A dictionary that maps the name of a parameter to a function that combines the source + parameters into the MLC parameter. For example, for the above example, the function + would be: `lambda q, k, v: np.concatenate([q, k, v], axis=0)`. + + unused_params : Set[str] + Parameter names in the source weights that are not used in the MLC LLM model definition. 
+    """
+
+    param_map: Dict[str, List[str]]
+    map_func: Dict[str, Callable[[np.ndarray, ...], np.ndarray]]
+    unused_params: Set[str] = dataclasses.field(default_factory=set)
+
+
+@dataclasses.dataclass
+class QuantizeMapping:
+    """Mapping from a parameter in MLC LLM's model definition to its eventual names and values after
+    quantization. In certain group quantization, for example, `qkv_proj.weight` is mapped to
+    `qkv_proj.weight_quantized` and `qkv_proj.weight_scale` respectively. If a parameter's name is
+    not in the mapping, it is assumed to be unchanged, i.e. not quantized.
+
+    Parameters
+    ----------
+    param_map : Dict[str, List[str]]
+        A dictionary that maps the name of a parameter to its destinations. For example,
+        in certain group quantization, the destinations of MLC parameter `qkv_proj.weight` are:
+
+        - "qkv_proj.weight_quantized"
+        - "qkv_proj.weight_scale"
+
+    map_func : Dict[str, Callable[[NDArray], List[NDArray]]]
+        A dictionary that maps the name of a parameter to a function that splits the MLC parameter
+        into the destination parameters.
+
+    Notes
+    -----
+    There are two forms of weight conversion in MLC LLM: one is A) on-the-fly quantization of the
+    raw fp16/bf16/fp32 weights from HuggingFace, and the other is B) loading pre-quantized weights
+    from an external framework, e.g. AutoGPTQ, AutoAWQ. From the perspective of parameter
+    correspondence:
+
+    - In case A), it is recommended that the weight loader take both `ExternMapping` and
+      `QuantizeMapping` as input, and do quantization on the fly as each raw parameter is
+      loaded into RAM;
+    - In case B), it is recommended to first run a pass over `nn.Module` that converts parameters
+      from their non-quantized form to the quantized one, and then use only `ExternMapping`
+      to convert the quantized parameters into the desired form.
+    """
+
+    param_map: Dict[str, Callable[[str], List[str]]]
+    map_func: Dict[str, Callable[[NDArray], List[NDArray]]]
+
+
+__all__ = ["ExternMapping", "QuantizeMapping"]
diff --git a/python/mlc_chat/support/__init__.py b/python/mlc_chat/support/__init__.py
new file mode 100644
index 0000000000..ca5d7a6b5b
--- /dev/null
+++ b/python/mlc_chat/support/__init__.py
@@ -0,0 +1,4 @@
+"""
+Common utilities used in the Python package. Do not import anything by default,
+as they may introduce unnecessary dependencies.
+"""
diff --git a/mlc_llm/models/model_config_base.py b/python/mlc_chat/support/config.py
similarity index 65%
rename from mlc_llm/models/model_config_base.py
rename to python/mlc_chat/support/config.py
index 85ac46dfc2..62270ffd9c 100644
--- a/mlc_llm/models/model_config_base.py
+++ b/python/mlc_chat/support/config.py
@@ -1,18 +1,24 @@
 """
-Utilities that handle model configuration. Model configuration is usually a JSON file in HuggingFace
-that contains the model's hyperparameters. For instance, Vicuna-13b-v1.5-16k contains the following
-config file: https://huggingface.co/lmsys/vicuna-13b-v1.5-16k/blob/main/config.json
+A common base class for configuration. A configuration could be initialized from its constructor,
+a JSON string or a JSON file, and irrelevant fields during initialization are automatically moved
+to the `kwargs` field.
+
+Take model configuration as an example: it is usually a JSON file in HuggingFace that contains
+the model's hyperparameters. For instance, Vicuna-13b-v1.5-16k contains the following
+[JSON file](https://huggingface.co/lmsys/vicuna-13b-v1.5-16k/blob/main/config.json).
+The base class allows us to load the configuration from this JSON file, moving irrelevant fields +into `kwargs`, such as `transformers_version` and `use_cache`. """ import dataclasses import json from pathlib import Path from typing import Any, Dict, Type, TypeVar -ConfigClass = TypeVar("ConfigClass", bound="ModelConfig") +ConfigClass = TypeVar("ConfigClass", bound="ConfigBase") -class ModelConfig: - """Base class for model configurations, providing a common interface for loading configs from a +class ConfigBase: + """Base class for configurations, providing a common interface for loading configs from a JSON file or a dict. It requires the subclasses to be dataclasses, and has an `kwargs` field that stores the extra fields that are not defined in the dataclass. """ @@ -55,3 +61,6 @@ def from_file(cls: Type[ConfigClass], source: Path) -> ConfigClass: """ with source.open("r", encoding="utf-8") as in_file: return cls.from_dict(json.load(in_file)) + + +__all__ = ["ConfigBase"] diff --git a/python/mlc_chat/support/tqdm.py b/python/mlc_chat/support/tqdm.py new file mode 100644 index 0000000000..9adceca480 --- /dev/null +++ b/python/mlc_chat/support/tqdm.py @@ -0,0 +1,38 @@ +"""Utils to better use tqdm""" +import contextlib +import inspect +import io + +from tqdm import tqdm +from tqdm.contrib.logging import logging_redirect_tqdm as _redirect_logging + + +@contextlib.contextmanager +def _redirect_print(): + old_print = print + + def new_print(*args, **kwargs): + with io.StringIO() as output: + kwargs["file"] = output + kwargs["end"] = "" + old_print(*args, **kwargs) + content = output.getvalue() + tqdm.write(content) + + try: + inspect.builtins.print = new_print + yield + finally: + inspect.builtins.print = old_print + + +@contextlib.contextmanager +def redirect(): + """Redirect tqdm output to logging and print.""" + + with _redirect_logging(): + with _redirect_print(): + yield + + +__all__ = ["tqdm", "redirect"] diff --git a/tests/python/model/test_llama.py b/tests/python/model/test_llama.py new file mode 100644 index 0000000000..e8757fd234 --- /dev/null +++ b/tests/python/model/test_llama.py @@ -0,0 +1,19 @@ +# pylint: disable=invalid-name,missing-docstring +import pytest +from mlc_chat.compiler.model.llama import LlamaConfig, LlamaForCasualLM + + +@pytest.mark.parametrize("model_name", ["llama2_7b", "llama2_13b", "llama2_70b"]) +def test_llama2_creation(model_name: str): + config = LlamaConfig.from_predefined(model_name) + model = LlamaForCasualLM(config) + mod, named_params = model.export_tvm(spec=model.get_default_spec()) + mod.show(black_format=False) + for name, param in named_params: + print(name, param.shape, param.dtype) + + +if __name__ == "__main__": + test_llama2_creation("llama2_7b") + test_llama2_creation("llama2_13b") + test_llama2_creation("llama2_70b") diff --git a/tests/python/parameter/test_hf_torch_loader.py b/tests/python/parameter/test_hf_torch_loader.py new file mode 100644 index 0000000000..745773b209 --- /dev/null +++ b/tests/python/parameter/test_hf_torch_loader.py @@ -0,0 +1,42 @@ +# pylint: disable=missing-docstring +import logging +from pathlib import Path + +import pytest +from mlc_chat.compiler.model.llama import LlamaConfig +from mlc_chat.compiler.model.llama_parameter import hf_torch +from mlc_chat.compiler.parameter import HFTorchLoader +from mlc_chat.support import tqdm + +logging.basicConfig( + level=logging.DEBUG, + style="{", + datefmt="%Y-%m-%d %H:%M:%S", + format="[{asctime}] {levelname} {filename}:{lineno}: {message}", +) + + +@pytest.mark.parametrize( 
+ "base_path", + [ + "./dist/models/Llama-2-7b-hf", + "./dist/models/Llama-2-13b-hf", + "./dist/models/Llama-2-70b-hf", + ], +) +def test_load_llama(base_path: str): + base_path = Path(base_path) + path_config = base_path / "config.json" + path_params = base_path / "pytorch_model.bin.index.json" + + config = LlamaConfig.from_file(path_config) + loader = HFTorchLoader(path=path_params, extern_param_map=hf_torch(config)) + with tqdm.redirect(): + for _name, _param in loader.load(): + ... + + +if __name__ == "__main__": + test_load_llama(base_path="./dist/models/Llama-2-7b-hf") + test_load_llama(base_path="./dist/models/Llama-2-13b-hf") + test_load_llama(base_path="./dist/models/Llama-2-70b-hf") diff --git a/tests/python/test_model_llama.py b/tests/python/test_model_llama.py deleted file mode 100644 index 3019b85bc0..0000000000 --- a/tests/python/test_model_llama.py +++ /dev/null @@ -1,70 +0,0 @@ -# pylint: disable=invalid-name,missing-docstring -import numpy as np -from tvm.relax.frontend.nn import spec - -from mlc_llm.models.llama import LlamaConfig, LlamaForCasualLM - - -def main(): - config = LlamaConfig( - hidden_act="silu", - hidden_size=256, - intermediate_size=688, - max_sequence_length=128, - num_attention_heads=8, - num_hidden_layers=8, - rms_norm_eps=1e-05, - vocab_size=4096, - position_embedding_base=10000, - ) - batch_size, total_seq_len, dtype = 1, 32, "float32" - - # Usecase 1. Define a model and export it to TVM's IRModule - model = LlamaForCasualLM(config) - model.to(dtype=dtype) - mod_spec = { - "prefill": { - "inputs": spec.Tensor([batch_size, "seq_len"], "int32"), - "total_seq_len": int, - }, - "decode": { - "inputs": spec.Tensor([batch_size, 1], "int32"), - "total_seq_len": int, - }, - "softmax_with_temperature": { - "logits": spec.Tensor([1, 1, config.vocab_size], "float32"), - "temperature": spec.Tensor([], "float32"), - }, - } - mod, _ = model.export_tvm(spec=mod_spec) - mod.show(black_format=False) - - # Usecase 2. JIT compile a model - for _, param in model.state_dict().items(): - param.data = np.random.rand(*param.shape).astype(param.dtype) - model = model.jit( - spec=mod_spec, - target="llvm", - device="cpu", - out_format="torch", - ) - - # Usecase 3. 
Run a model with PyTorch - import torch # pylint: disable=import-outside-toplevel - - result = model["prefill"]( - torch.from_numpy( - np.random.randint( - 0, - config.vocab_size, - size=(batch_size, total_seq_len), - dtype="int32", - ) - ), - total_seq_len, - ) - assert isinstance(result, torch.Tensor) - - -if __name__ == "__main__": - main() diff --git a/tests/python/test_param_loader_llama.py b/tests/python/test_param_loader_llama.py deleted file mode 100644 index 4c34c6964b..0000000000 --- a/tests/python/test_param_loader_llama.py +++ /dev/null @@ -1,32 +0,0 @@ -# pylint: disable=missing-docstring -import logging -from pathlib import Path - -from mlc_llm.models.llama import LlamaConfig -from mlc_llm.models.llama_param_map import hf_torch -from mlc_llm.param_loader import HFTorchLoader - -logging.basicConfig( - level=logging.DEBUG, - style="{", - datefmt="%Y-%m-%d %H:%M:%S", - format="{asctime} {levelname} {filename}:{lineno}: {message}", -) - - -def test_load_7b(): - prefix = Path("./dist/models/llama-2-7b-chat-hf/") - path_config = prefix / "config.json" - path_params = prefix / "pytorch_model.bin.index.json" - - model_config = LlamaConfig.from_file(path_config) - with HFTorchLoader( - config_path=path_params, - param_map=hf_torch(model_config), - ) as loader: - for name in loader.suggest_loading_order(): - loader.load_param(name=name) - - -if __name__ == "__main__": - test_load_7b() From 56a8004edfd783683a947e2e0f452b3195264960 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Thu, 19 Oct 2023 10:37:24 -0700 Subject: [PATCH 031/116] Update README.md for Multi-GPU (#1090) --- README.md | 16 ++ site/img/multi-gpu/figure-1.svg | 247 +++++++++++++++++++ site/img/multi-gpu/figure-2.svg | 418 ++++++++++++++++++++++++++++++++ site/img/multi-gpu/figure-3.svg | 167 +++++++++++++ 4 files changed, 848 insertions(+) create mode 100644 site/img/multi-gpu/figure-1.svg create mode 100644 site/img/multi-gpu/figure-2.svg create mode 100644 site/img/multi-gpu/figure-3.svg diff --git a/README.md b/README.md index 4ef145d4cd..f20d1c8a93 100644 --- a/README.md +++ b/README.md @@ -49,8 +49,24 @@
 <!-- model table header (context): "Supported Model Architecture" / "Model Variants with Prebuilts" / "Architecture" / "Prebuilt Model Variants" -->
+ +**Scalable.** MLC LLM scales universally on NVIDIA and AMD GPUs, cloud and gaming GPUs. Below +showcases our single batch decoding performance with prefilling = 1 and decoding = 256. + +Performance of 4-bit CodeLlama-34B and Llama2-70B on two NVIDIA RTX 4090 and two AMD Radeon 7900 XTX: +

+<p><em>(figure: dual-GPU decoding performance; see site/img/multi-gpu/)</em></p>

+
+Scaling of fp16 and 4-bit CodeLlama-34B and Llama2-70B on A100-80G-PCIe and A10G-24G-PCIe, up to 8 GPUs:
+

+<p><em>(figure: multi-GPU scaling up to 8 GPUs; see site/img/multi-gpu/)</em></p>
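+
+A minimal sketch of driving the same multi-GPU setup from the Python API. The model id, device
+string, and library path below are placeholders rather than shipped artifacts; the
+`ChatConfig(num_shards=...)` usage follows `python/mlc_chat/cli/benchmark.py`:
+
+```python
+from mlc_chat import ChatConfig, ChatModule
+
+# Assumes a model that has already been compiled and converted for 2-way sharding
+# (the model id and library path are illustrative).
+cm = ChatModule(
+    model="Llama-2-70b-chat-hf-q4f16_1",       # placeholder local model id
+    device="cuda",                             # or "rocm" on AMD GPUs
+    chat_config=ChatConfig(num_shards=2),      # shard the weights across two GPUs
+    # model_lib_path="dist/.../model-cuda.so", # optional: bypass the library auto-search
+)
+print(cm.benchmark_generate("Hello!", generate_length=256))
+```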

+ ## News +* [10/18/2023] [[Post]](https://blog.mlc.ai/2023/10/19/Scalable-Language-Model-Inference-on-Multiple-NVDIA-AMD-GPUs) Scalable multi-GPU support for CUDA and ROCm are official. * [09/02/2023] Prebuilt ROCm 5.7 and CUDA 12.2 package is [available](https://llm.mlc.ai/docs/install/tvm.html#option-1-prebuilt-package). * [08/25/2023] CodeLlama support is up. * [08/14/2023] [[Post]](https://blog.mlc.ai/2023/08/09/GPU-Accelerated-LLM-on-Orange-Pi) Mali GPU support is up on Orange Pi. diff --git a/site/img/multi-gpu/figure-1.svg b/site/img/multi-gpu/figure-1.svg new file mode 100644 index 0000000000..d3083cf775 --- /dev/null +++ b/site/img/multi-gpu/figure-1.svg @@ -0,0 +1,247 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/site/img/multi-gpu/figure-2.svg b/site/img/multi-gpu/figure-2.svg new file mode 100644 index 0000000000..70d35f5037 --- /dev/null +++ b/site/img/multi-gpu/figure-2.svg @@ -0,0 +1,418 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/site/img/multi-gpu/figure-3.svg b/site/img/multi-gpu/figure-3.svg new file mode 100644 index 0000000000..078231fae6 --- /dev/null +++ b/site/img/multi-gpu/figure-3.svg @@ -0,0 +1,167 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + From b0373d172e2ebd20ece84792d8db13134b9d1dfc Mon Sep 17 00:00:00 2001 From: Rick Zhou Date: Thu, 19 Oct 2023 12:09:12 -0700 Subject: [PATCH 032/116] Support lib_path override in C++. Improvements on docs and error messages (#1086) * Support lib_path option in C++ CLI. Disable ChatConfig.model_lib override in Python API. Improvements on helper messages and error messages * Update docs * Rename lib_path -> model_lib_path --- cpp/cli_main.cc | 64 +++++++++++++++++++++++----------- docs/deploy/cli.rst | 12 +++++++ docs/deploy/python.rst | 12 ++++++- python/mlc_chat/chat_module.py | 37 ++++++++++++-------- 4 files changed, 89 insertions(+), 36 deletions(-) diff --git a/cpp/cli_main.cc b/cpp/cli_main.cc index 4ad899214e..db35457b79 100644 --- a/cpp/cli_main.cc +++ b/cpp/cli_main.cc @@ -163,7 +163,7 @@ struct ModelPaths { */ std::filesystem::path lib; - static ModelPaths Find(const std::string& device_name, const std::string& local_id); + static ModelPaths Find(const std::string& device_name, const std::string& local_id, const std::string& user_lib_path); }; /*! @@ -337,7 +337,7 @@ std::string ReadStringFromJSONFile(const std::filesystem::path& config_path, return config[key].get(); } -ModelPaths ModelPaths::Find(const std::string& device_name, const std::string& local_id) { +ModelPaths ModelPaths::Find(const std::string& device_name, const std::string& local_id, const std::string &user_lib_path) { // Step 1. Find config path std::filesystem::path config_path; if (auto path = TryInferMLCChatConfig(local_id)) { @@ -368,10 +368,17 @@ ModelPaths ModelPaths::Find(const std::string& device_name, const std::string& l } std::cout << "Use model weights: " << params_json << std::endl; // Step 3. 
Find model lib path - std::string lib_local_id = ReadStringFromJSONFile(config_path, "model_lib"); - std::string lib_name = lib_local_id + "-" + device_name; std::filesystem::path lib_path; - if (auto path = FindFile({lib_local_id, + if (!user_lib_path.empty()) { + lib_path = user_lib_path; + if (!std::filesystem::exists(lib_path) || !std::filesystem::is_regular_file(lib_path)) { + LOG(FATAL) << "The `lib_path` you passed in is not a file: " << user_lib_path << "\n"; + exit(1); + } + } else { + std::string lib_local_id = ReadStringFromJSONFile(config_path, "model_lib"); + std::string lib_name = lib_local_id + "-" + device_name; + if (auto path = FindFile({lib_local_id, "dist/prebuilt/lib", // Using prebuilt workflow "dist/" + local_id, "dist/prebuilt/" + lib_local_id}, { @@ -379,15 +386,18 @@ ModelPaths ModelPaths::Find(const std::string& device_name, const std::string& l lib_name, }, GetLibSuffixes())) { - lib_path = path.value(); - } else { - LOG(FATAL) << "Cannot find the model library that corresponds to `" << lib_local_id << "`.\n" - << "We searched over the following possible paths: \n" - << "- " + lib_local_id << "\n" - << "- dist/prebuilt/lib \n" - << "- dist/" + local_id << "\n" - << "- dist/prebuilt/" + lib_local_id; - exit(1); + lib_path = path.value(); + } else { + LOG(FATAL) << "Cannot find the model library that corresponds to `" << lib_local_id << "`.\n" + << "We searched over the following possible paths: \n" + << "- " + lib_local_id << "\n" + << "- dist/prebuilt/lib \n" + << "- dist/" + local_id << "\n" + << "- dist/prebuilt/" + lib_local_id << "\n" + << "If you would like to directly specify the full model library path, you may " + << "consider passing in the `--model-lib-path` argument.\n"; + exit(1); + } } std::cout << "Use model library: " << lib_path << std::endl; return ModelPaths{config_path, params_json, lib_path}; @@ -427,8 +437,8 @@ void Converse(ChatModule* chat, const std::string& input, int stream_interval, * \param stream_interval The interval that should be used for streaming the response. */ void Chat(ChatModule* chat, const std::string& device_name, std::string local_id, - int stream_interval = 2) { - ModelPaths model = ModelPaths::Find(device_name, local_id); + std::string lib_path, int stream_interval = 2) { + ModelPaths model = ModelPaths::Find(device_name, local_id, lib_path); PrintSpecialCommands(); chat->Reload(model); chat->ProcessSystemPrompts(); @@ -456,7 +466,7 @@ void Chat(ChatModule* chat, const std::string& device_name, std::string local_id if (new_local_id.empty()) { new_local_id = local_id; } - model = ModelPaths::Find(device_name, new_local_id); + model = ModelPaths::Find(device_name, new_local_id, lib_path); chat->Reload(model); local_id = new_local_id; } else if (input.substr(0, 5) == "/help") { @@ -470,7 +480,17 @@ void Chat(ChatModule* chat, const std::string& device_name, std::string local_id int main(int argc, char* argv[]) { argparse::ArgumentParser args("mlc_chat"); - args.add_argument("--model"); + args.add_description("MLCChat CLI is the command line tool to run MLC-compiled LLMs out of the box.\n" + "Note: the --model argument is required. It can either be the model name with its " + "quantization scheme or a full path to the model folder. In the former case, the " + "provided name will be used to search for the model folder over possible paths. " + "--model-lib-path argument is optional. 
If unspecified, the --model argument will be used " + "to search for the library file over possible paths."); + + args.add_argument("--model") + .help("[required] the model to use"); + args.add_argument("--model-lib-path") + .help("[optional] the full path to the model library file to use"); args.add_argument("--device").default_value("auto"); args.add_argument("--evaluate").default_value(false).implicit_value(true); args.add_argument("--eval-prompt-len").default_value(128).scan<'i', int>(); @@ -485,6 +505,10 @@ int main(int argc, char* argv[]) { } std::string local_id = args.get("--model"); + std::string lib_path; + if (args.present("--model-lib-path")) { + lib_path = args.get("--model-lib-path"); + } auto [device_name, device_id] = DetectDevice(args.get("--device")); try { @@ -494,14 +518,14 @@ int main(int argc, char* argv[]) { // that are not supposed to be used in chat app setting int prompt_len = args.get("--eval-prompt-len"); int gen_len = args.get("--eval-gen-len"); - ModelPaths model = ModelPaths::Find(device_name, local_id); + ModelPaths model = ModelPaths::Find(device_name, local_id, lib_path); tvm::runtime::Module chat_mod = mlc::llm::CreateChatModule(GetDevice(device_name, device_id)); std::string model_path = model.config.parent_path().string(); tvm::runtime::Module lib = tvm::runtime::Module::LoadFromFile(model.lib.string()); chat_mod.GetFunction("reload")(lib, tvm::String(model_path)); chat_mod.GetFunction("evaluate")(prompt_len, gen_len); } else { - Chat(&chat, device_name, local_id); + Chat(&chat, device_name, local_id, lib_path); } } catch (const std::runtime_error& err) { std::cerr << err.what() << std::endl; diff --git a/docs/deploy/cli.rst b/docs/deploy/cli.rst index 79501b113d..460ac71c7d 100644 --- a/docs/deploy/cli.rst +++ b/docs/deploy/cli.rst @@ -111,6 +111,12 @@ Once ``mlc_chat_cli`` is installed, you are able to run any MLC-compiled model o - Model lib should be placed at ``./dist/prebuilt/lib/$(local_id)-$(arch).$(suffix)``. - Model weights and chat config are located under ``./dist/prebuilt/mlc-chat-$(local_id)/``. + .. note:: + Please make sure that you have the same directory structure as above, because the CLI tool + relies on it to automatically search for model lib and weights. If you would like to directly + provide a full model lib path to override the auto-search, you can pass in a ``--model-lib-path`` argument + to the CLI + .. collapse:: Example .. code:: shell @@ -134,6 +140,12 @@ Once ``mlc_chat_cli`` is installed, you are able to run any MLC-compiled model o - Model libraries should be placed at ``./dist/$(local_id)/$(local_id)-$(arch).$(suffix)``. - Model weights and chat config are located under ``./dist/$(local_id)/params/``. + .. note:: + Please make sure that you have the same directory structure as above, because the CLI tool + relies on it to automatically search for model lib and weights. If you would like to directly + provide a full model lib path to override the auto-search, you can pass in a ``--model-lib-path`` argument + to the CLI + .. collapse:: Example .. code:: shell diff --git a/docs/deploy/python.rst b/docs/deploy/python.rst index b27d8ff935..22e13702d2 100644 --- a/docs/deploy/python.rst +++ b/docs/deploy/python.rst @@ -51,6 +51,11 @@ If you do not have the MLC-compiled ``model`` ready: - Model lib should be placed at ``./dist/prebuilt/lib/$(model)-$(arch).$(suffix)``. - Model weights and chat config are located under ``./dist/prebuilt/mlc-chat-$(model)/``. + .. 
note:: + Please make sure that you have the same directory structure as above, because Python API + relies on it to automatically search for model lib and weights. If you would like to directly + provide a full model lib path to override the auto-search, you can specify ``ChatModule.model_lib_path`` + .. collapse:: Example .. code:: shell @@ -74,6 +79,11 @@ If you do not have the MLC-compiled ``model`` ready: - Model libraries should be placed at ``./dist/$(model)/$(model)-$(arch).$(suffix)``. - Model weights and chat config are located under ``./dist/$(model)/params/``. + .. note:: + Please make sure that you have the same directory structure as above, because Python API + relies on it to automatically search for model lib and weights. If you would like to directly + provide a full model lib path to override the auto-search, you can specify ``ChatModule.model_lib_path`` + .. collapse:: Example .. code:: shell @@ -157,7 +167,7 @@ You can also checkout the :doc:`/prebuilt_models` page to run other models. | .. note:: - You could also specify the address of ``model`` and ``lib_path`` explicitly. If + You could also specify the address of ``model`` and ``model_lib_path`` explicitly. If you only specify ``model`` as ``model_name`` and ``quantize_mode``, we will do a search for you. See more in the documentation of :meth:`mlc_chat.ChatModule.__init__`. diff --git a/python/mlc_chat/chat_module.py b/python/mlc_chat/chat_module.py index db46c080f4..2ab9334b67 100644 --- a/python/mlc_chat/chat_module.py +++ b/python/mlc_chat/chat_module.py @@ -5,6 +5,7 @@ import logging import os import sys +import warnings from dataclasses import asdict, dataclass, fields from enum import Enum from typing import List, Optional @@ -351,6 +352,12 @@ def _get_chat_config(config_file_path: str, user_chat_config: Optional[ChatConfi # We override using user's chat config for field in fields(user_chat_config): field_name = field.name + if field_name == 'model_lib': + warn_msg = ('WARNING: Do not override "model_lib" in ChatConfig. ' + 'This override will be ignored. ' + 'Please use ChatModule.model_lib_path to override the full model library path instead.') + warnings.warn(warn_msg) + continue field_value = getattr(user_chat_config, field_name) if field_value is not None: setattr(final_chat_config, field_name, field_value) @@ -389,7 +396,7 @@ def _get_lib_module_path( model: str, model_path: str, chat_config: ChatConfig, - lib_path: Optional[str], + model_lib_path: Optional[str], device_name: str, config_file_path: str, ) -> str: @@ -403,7 +410,7 @@ def _get_lib_module_path( Model path found by `_get_model_path`. chat_config : ChatConfig Chat config after potential overrides. Returned by ``_get_chat_config``. - lib_path : Optional[str] + model_lib_path : Optional[str] User's input. Supposedly a full path to model library. Prioritized to use. device_name : str User's input. Used to construct the library model file name. @@ -412,21 +419,21 @@ def _get_lib_module_path( Returns ------ - lib_path : str + model_lib_path : str The path pointing to the model library we find. Raises ------ FileNotFoundError: if we cannot find a valid model library file. """ - # 1. Use user's lib_path if provided - if lib_path is not None: - if os.path.isfile(lib_path): - logging.info(f"Using library model: {lib_path}") - return lib_path + # 1. 
Use user's model_lib_path if provided + if model_lib_path is not None: + if os.path.isfile(model_lib_path): + logging.info(f"Using library model: {model_lib_path}") + return model_lib_path else: err_msg = ( - f"The `lib_path` you passed in is not a file: {lib_path}.\nPlease checkout " + f"The `model_lib_path` you passed in is not a file: {model_lib_path}.\nPlease checkout " f"{_PYTHON_GET_STARTED_TUTORIAL_URL} for an example on how to load a model." ) raise FileNotFoundError(err_msg) @@ -482,7 +489,7 @@ def _get_lib_module_path( err_msg += f"- {candidate}\n" err_msg += ( "If you would like to directly specify the model library path, you may " - "consider passing in the `lib_path` parameter.\n" + "consider passing in the `ChatModule.model_lib_path` parameter.\n" f"Please checkout {_PYTHON_GET_STARTED_TUTORIAL_URL} for an example " "on how to load a model." ) @@ -659,7 +666,7 @@ class ChatModule: A ``ChatConfig`` instance partially filled. Will be used to override the ``mlc-chat-config.json``. - lib_path : Optional[str] + model_lib_path : Optional[str] The full path to the model library file to use (e.g. a ``.so`` file). If unspecified, we will use the provided ``model`` to search over possible paths. @@ -670,7 +677,7 @@ def __init__( model: str, device: str = "auto", chat_config: Optional[ChatConfig] = None, - lib_path: Optional[str] = None, + model_lib_path: Optional[str] = None, ): device_err_msg = ( f"Invalid device name: {device}. Please enter the device in the form " @@ -732,15 +739,15 @@ def __init__( self.chat_config = _get_chat_config(self.config_file_path, chat_config) # 5. Look up model library - self.lib_path = _get_lib_module_path( - model, self.model_path, self.chat_config, lib_path, device_name, self.config_file_path + self.model_lib_path = _get_lib_module_path( + model, self.model_path, self.chat_config, model_lib_path, device_name, self.config_file_path ) # 6. Call reload user_chat_config_json_str = _convert_chat_config_to_json_str( self.chat_config, self.chat_config.conv_template ) - self._reload(self.lib_path, self.model_path, user_chat_config_json_str) + self._reload(self.model_lib_path, self.model_path, user_chat_config_json_str) def generate( self, From 830656fa9779ecfb121b7eef218d04e1ad7e50bf Mon Sep 17 00:00:00 2001 From: Varshith Bathini Date: Fri, 20 Oct 2023 00:40:14 +0530 Subject: [PATCH 033/116] StreamIterator (#1057) Co-authored-by: Varshith --- docs/deploy/python.rst | 32 ++++++++++++++++++ examples/python/sample_chat_stream.py | 30 +++++++++++++++++ python/mlc_chat/callback.py | 47 +++++++++++++++++++++++++++ 3 files changed, 109 insertions(+) create mode 100644 examples/python/sample_chat_stream.py diff --git a/docs/deploy/python.rst b/docs/deploy/python.rst index 22e13702d2..1a046538f9 100644 --- a/docs/deploy/python.rst +++ b/docs/deploy/python.rst @@ -306,6 +306,38 @@ We provide an example below. Additionally, system prompts will not be run when instantiating a `mlc_chat.ChatModule`, unless explicitly given inside the prompt. +Stream Iterator in Python +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Stream Iterator gives users an option to stream generated text to the function that the API is called from, +instead of streaming to stdout, which could be a necessity when building services on top of MLC Chat. + +We provide an example below. + +.. 
code:: python + + from mlc_chat import ChatModule + from mlc_chat.callback import StreamIterator + + # Create a ChatModule instance + cm = ChatModule(model="Llama-2-7b-chat-hf-q4f16_1") + + # Stream to an Iterator + from threading import Thread + + stream = StreamIterator(callback_interval=2) + generation_thread = Thread( + target=cm.generate, + kwargs={"prompt": "What is the meaning of life?", "progress_callback": stream}, + ) + generation_thread.start() + + output = "" + for delta_message in stream: + output += delta_message + + generation_thread.join() + API Reference ------------- diff --git a/examples/python/sample_chat_stream.py b/examples/python/sample_chat_stream.py new file mode 100644 index 0000000000..980e833d20 --- /dev/null +++ b/examples/python/sample_chat_stream.py @@ -0,0 +1,30 @@ +from mlc_chat import ChatModule +from mlc_chat.callback import StreamToStdout, StreamIterator + +# From the mlc-llm directory, run +# $ python examples/python/sample_chat_stream.py + +# Create a ChatModule instance +cm = ChatModule(model="Llama-2-7b-chat-hf-q4f16_1") + +# Stream to Stdout +output = cm.generate( + prompt="What is the meaning of life?", + progress_callback=StreamToStdout(callback_interval=2), +) + +# Stream to an Iterator +from threading import Thread + +stream = StreamIterator(callback_interval=2) +generation_thread = Thread( + target=cm.generate, + kwargs={"prompt": "What is the meaning of life?", "progress_callback": stream}, +) +generation_thread.start() + +output = "" +for delta_message in stream: + output += delta_message + +generation_thread.join() diff --git a/python/mlc_chat/callback.py b/python/mlc_chat/callback.py index faf2dbd953..921d9c0052 100644 --- a/python/mlc_chat/callback.py +++ b/python/mlc_chat/callback.py @@ -1,6 +1,9 @@ """Namespace of callback functions in Python API.""" #! pylint: disable=unused-import, invalid-name, unnecessary-pass +from queue import Queue +from typing import Optional + from .base import get_delta_message @@ -74,3 +77,47 @@ def delta_callback(self, delta_message: str): def stopped_callback(self): r"""Stream an additional '\n' when generation ends.""" print() + + +class StreamIterator(DeltaCallback): + """Stream the output using an iterator. + A queue stores the delta messages""" + + def __init__(self, callback_interval: int = 2, timeout: Optional[float] = None): + r"""Initialize the callback class with callback interval and queue timeout. + + Parameters + ---------- + callback_interval : int + The refresh rate of the streaming process. + timeout : Optional[float] + Timeout for put and get from the delta messages queue + """ + super().__init__() + self.delta_messages = Queue() + self.callback_interval = callback_interval + self.timeout = timeout + + def delta_callback(self, delta_message: str): + r"""Stream the delta message to iterator (adding). + + Parameters + ---------- + delta_message : str + The delta message (the part that has not been added to queue yet). 
+ """ + self.delta_messages.put(delta_message, timeout=self.timeout) + + def stopped_callback(self): + """Using None as the stop signal for the iterator""" + self.delta_messages.put(None, timeout=self.timeout) + + def __iter__(self): + return self + + def __next__(self): + value = self.delta_messages.get(timeout=self.timeout) + if value: + return value + else: + raise StopIteration() From 9bf5723945ca6016a50eedffae295390b7c11ac6 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Thu, 19 Oct 2023 15:49:40 -0700 Subject: [PATCH 034/116] Update `benchmark.py` according to #1086 (#1091) Update `benchmark.py` --- python/mlc_chat/cli/benchmark.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/mlc_chat/cli/benchmark.py b/python/mlc_chat/cli/benchmark.py index 308921e3d0..0a4d5d97f3 100644 --- a/python/mlc_chat/cli/benchmark.py +++ b/python/mlc_chat/cli/benchmark.py @@ -61,7 +61,7 @@ def main(): chat_config=ChatConfig( num_shards=args.num_shards, ), - lib_path=args.model_lib, + model_lib_path=args.model_lib, ) output = chat_module.benchmark_generate(args.prompt, generate_length=args.generate_length) print(f"Generated text:\n{output}\n") From 62d0c031288b2be7cd1f7573cf1109ce96b033bc Mon Sep 17 00:00:00 2001 From: Lesheng Jin <34279105+LeshengJin@users.noreply.github.com> Date: Fri, 20 Oct 2023 15:33:51 -0700 Subject: [PATCH 035/116] Disable Disco for q4f16_ft and q8f16_ft quantization (#1094) --- mlc_llm/core.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/mlc_llm/core.py b/mlc_llm/core.py index 34e3041e2f..d187ebe453 100644 --- a/mlc_llm/core.py +++ b/mlc_llm/core.py @@ -597,6 +597,9 @@ def build_model_from_args(args: argparse.Namespace): "`num_shards` should be used together with " "`--build-model-only` and `--convert-weight-only`" ) + use_ft_quant = args.quantization.name in ["q4f16_ft", "q8f16_ft"] + if use_ft_quant: + raise ValueError("Multi-GPU deployments are not available for ft quantization.") os.makedirs(args.artifact_path, exist_ok=True) if args.debug_dump: os.makedirs(os.path.join(args.artifact_path, "debug"), exist_ok=True) @@ -614,7 +617,6 @@ def build_model_from_args(args: argparse.Namespace): config = json.load(i_f) if not use_cache or args.convert_weight_only: - model_generators = { "llama": llama, "mistral": llama, From cf39bf6f00c24e32bbd21cee18c6d4afa6202874 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Fri, 20 Oct 2023 15:38:49 -0700 Subject: [PATCH 036/116] [Format] Apply isort and black for `python/` (#1097) [Format] Apply isort and black on `python/` The commands I am using are: ``` isort --profile black python/ black python/ ``` It is always recommended to format the code before submission, given we don't have a linter CI yet. --- python/mlc_chat/__init__.py | 5 +---- python/mlc_chat/chat_module.py | 17 ++++++++++++----- python/mlc_chat/embeddings/openai.py | 22 +++++++++------------- python/mlc_chat/gradio.py | 14 ++++++++++---- python/mlc_chat/interface/openai_api.py | 21 ++++++++++++++++++--- python/mlc_chat/rest.py | 13 ++++++------- 6 files changed, 56 insertions(+), 36 deletions(-) diff --git a/python/mlc_chat/__init__.py b/python/mlc_chat/__init__.py index eb2bdeebc1..756c785bf6 100644 --- a/python/mlc_chat/__init__.py +++ b/python/mlc_chat/__init__.py @@ -2,8 +2,5 @@ MLC Chat is the app runtime of MLC LLM. 
""" +from .chat_module import ChatConfig, ChatModule, ConvConfig, GenerationConfig from .libinfo import __version__ -from .chat_module import ChatModule -from .chat_module import ConvConfig -from .chat_module import ChatConfig -from .chat_module import GenerationConfig diff --git a/python/mlc_chat/chat_module.py b/python/mlc_chat/chat_module.py index 2ab9334b67..3bc32309e7 100644 --- a/python/mlc_chat/chat_module.py +++ b/python/mlc_chat/chat_module.py @@ -352,10 +352,12 @@ def _get_chat_config(config_file_path: str, user_chat_config: Optional[ChatConfi # We override using user's chat config for field in fields(user_chat_config): field_name = field.name - if field_name == 'model_lib': - warn_msg = ('WARNING: Do not override "model_lib" in ChatConfig. ' - 'This override will be ignored. ' - 'Please use ChatModule.model_lib_path to override the full model library path instead.') + if field_name == "model_lib": + warn_msg = ( + 'WARNING: Do not override "model_lib" in ChatConfig. ' + "This override will be ignored. " + "Please use ChatModule.model_lib_path to override the full model library path instead." + ) warnings.warn(warn_msg) continue field_value = getattr(user_chat_config, field_name) @@ -740,7 +742,12 @@ def __init__( # 5. Look up model library self.model_lib_path = _get_lib_module_path( - model, self.model_path, self.chat_config, model_lib_path, device_name, self.config_file_path + model, + self.model_path, + self.chat_config, + model_lib_path, + device_name, + self.config_file_path, ) # 6. Call reload diff --git a/python/mlc_chat/embeddings/openai.py b/python/mlc_chat/embeddings/openai.py index 5795ed8158..ed8dd5ea93 100644 --- a/python/mlc_chat/embeddings/openai.py +++ b/python/mlc_chat/embeddings/openai.py @@ -1,17 +1,11 @@ from __future__ import annotations -from langchain.embeddings import OpenAIEmbeddings -from langchain.embeddings.openai import embed_with_retry, async_embed_with_retry - import logging -from typing import ( - List, - Optional, - Sequence, - Tuple, -) +from typing import List, Optional, Sequence, Tuple import numpy as np +from langchain.embeddings import OpenAIEmbeddings +from langchain.embeddings.openai import async_embed_with_retry, embed_with_retry logger = logging.getLogger(__name__) @@ -121,9 +115,9 @@ def _get_len_safe_embeddings( self, input="", **self._invocation_params, - )[ - "data" - ][0]["embedding"] + )["data"][ + 0 + ]["embedding"] for _result, num_tokens in zip(results, num_tokens_in_batch): if len(_result) == 0: average = empty_average @@ -155,7 +149,9 @@ async def _aget_len_safe_embeddings( input="", **self._invocation_params, ) - )["data"][0]["embedding"] + )[ + "data" + ][0]["embedding"] for _result, num_tokens in zip(results, num_tokens_in_batch): if len(_result) == 0: average = empty_average diff --git a/python/mlc_chat/gradio.py b/python/mlc_chat/gradio.py index 5975a8681d..8f0e16ab26 100644 --- a/python/mlc_chat/gradio.py +++ b/python/mlc_chat/gradio.py @@ -2,10 +2,11 @@ # pylint: disable=import-error, import-outside-toplevel, invalid-name, line-too-long, protected-access # too-many-instance-attributes, too-many-locals, unused-import -from typing import Dict import argparse -import os import glob +import os +from typing import Dict + import gradio as gr from .chat_module import ChatModule @@ -148,7 +149,12 @@ def gradio_stats(self): def launch_gradio( - artifact_path: str = "dist", device: str = "auto", port: int = 7860, share: bool = False, host: str = "127.0.0.1"): + artifact_path: str = "dist", + device: str = "auto", + port: int = 
7860, + share: bool = False, + host: str = "127.0.0.1", +): r"""Launch the gradio interface with a given port, creating a publically sharable link if specified.""" # create a gradio module @@ -230,7 +236,7 @@ def launch_gradio( stats_button.click(mod.gradio_stats, [], [stats_output]) # launch to the web - demo.launch(share=share, enable_queue=True, server_port=port,server_name=host) + demo.launch(share=share, enable_queue=True, server_port=port, server_name=host) if __name__ == "__main__": diff --git a/python/mlc_chat/interface/openai_api.py b/python/mlc_chat/interface/openai_api.py index a707608ab1..2a94607741 100644 --- a/python/mlc_chat/interface/openai_api.py +++ b/python/mlc_chat/interface/openai_api.py @@ -2,10 +2,11 @@ Adapted from FastChat's OpenAI protocol: https://github.com/lm-sys/FastChat/blob/main/fastchat/protocol/openai_api_protocol.py """ -from typing import Literal, Optional, List, Dict, Any, Union -from pydantic import BaseModel, Field -import shortuuid import time +from typing import Any, Dict, List, Literal, Optional, Union + +import shortuuid +from pydantic import BaseModel, Field class ChatMessage(BaseModel): @@ -13,6 +14,7 @@ class ChatMessage(BaseModel): content: str name: str | None = None + class ChatCompletionRequest(BaseModel): model: str messages: list[ChatMessage] @@ -35,16 +37,19 @@ class ChatCompletionRequest(BaseModel): # logit_bias # user: Optional[str] = None + class UsageInfo(BaseModel): prompt_tokens: int = 0 completion_tokens: int | None = 0 total_tokens: int = 0 + class ChatCompletionResponseChoice(BaseModel): index: int message: ChatMessage finish_reason: Literal["stop", "length"] | None = None + class ChatCompletionResponse(BaseModel): id: str = Field(default_factory=lambda: f"chatcmpl-{shortuuid.random()}") object: str = "chat.completion" @@ -53,21 +58,25 @@ class ChatCompletionResponse(BaseModel): # TODO: Implement support for the following fields usage: UsageInfo | None = None + class DeltaMessage(BaseModel): role: str | None = None content: str | None = None + class ChatCompletionResponseStreamChoice(BaseModel): index: int delta: DeltaMessage finish_reason: Literal["stop", "length"] | None = None + class ChatCompletionStreamResponse(BaseModel): id: str = Field(default_factory=lambda: f"chatcmpl-{shortuuid.random()}") object: str = "chat.completion.chunk" created: int = Field(default_factory=lambda: int(time.time())) choices: list[ChatCompletionResponseStreamChoice] + class CompletionRequest(BaseModel): model: str prompt: str | list[str] @@ -91,12 +100,14 @@ class CompletionRequest(BaseModel): # logit_bias # user: Optional[str] = None + class CompletionResponseChoice(BaseModel): index: int text: str logprobs: int | None = None finish_reason: Literal["stop", "length"] | None = None + class CompletionResponse(BaseModel): id: str = Field(default_factory=lambda: f"cmpl-{shortuuid.random()}") object: str = "text_completion" @@ -104,22 +115,26 @@ class CompletionResponse(BaseModel): choices: list[CompletionResponseChoice] usage: UsageInfo + class CompletionResponseStreamChoice(BaseModel): index: int text: str finish_reason: Optional[Literal["stop", "length"]] = None + class CompletionStreamResponse(BaseModel): id: str = Field(default_factory=lambda: f"cmpl-{shortuuid.random()}") object: str = "text_completion" created: int = Field(default_factory=lambda: int(time.time())) choices: List[CompletionResponseStreamChoice] + class EmbeddingsRequest(BaseModel): model: Optional[str] = None input: Union[str, List[Any]] user: Optional[str] = None + class 
EmbeddingsResponse(BaseModel): object: str = "list" data: List[Dict[str, Any]] diff --git a/python/mlc_chat/rest.py b/python/mlc_chat/rest.py index 1703f97826..486b20e965 100644 --- a/python/mlc_chat/rest.py +++ b/python/mlc_chat/rest.py @@ -1,22 +1,19 @@ import argparse import asyncio from contextlib import asynccontextmanager +from dataclasses import dataclass, field, fields -from mlc_chat.chat_module import GenerationConfig - +import numpy as np import uvicorn from fastapi import FastAPI -from fastapi.responses import StreamingResponse from fastapi.middleware.cors import CORSMiddleware - -from dataclasses import dataclass, field, fields +from fastapi.responses import StreamingResponse +from mlc_chat.chat_module import GenerationConfig from .base import set_global_random_seed from .chat_module import ChatModule from .interface.openai_api import * -import numpy as np - @dataclass class RestAPIArgs: @@ -327,6 +324,7 @@ async def read_stats(): """ return session["chat_mod"].stats() + @app.get("/verbose_stats") async def read_stats_verbose(): """ @@ -334,6 +332,7 @@ async def read_stats_verbose(): """ return session["chat_mod"].stats(verbose=True) + ARGS = convert_args_to_argparser().parse_args() if __name__ == "__main__": uvicorn.run("mlc_chat.rest:app", host=ARGS.host, port=ARGS.port, reload=False, access_log=False) From e9b85ce13c6817cc3250530f98a9d180dd8bea52 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Fri, 20 Oct 2023 21:36:28 -0700 Subject: [PATCH 037/116] More formatting (#1099) --- tests/{debug => python/legacy}/compare_lib.py | 31 ++---- .../legacy}/dump_intermediate.py | 10 +- tests/{ => python/legacy}/evaluate.py | 8 +- .../legacy}/test_batching_llama.py | 0 tests/python/{ => legacy}/test_build_args.py | 42 ++++---- .../test_build_model_from_args.py | 100 ++++++++++-------- 6 files changed, 92 insertions(+), 99 deletions(-) rename tests/{debug => python/legacy}/compare_lib.py (93%) rename tests/{debug => python/legacy}/dump_intermediate.py (95%) rename tests/{ => python/legacy}/evaluate.py (96%) rename tests/{debug => python/legacy}/test_batching_llama.py (100%) rename tests/python/{ => legacy}/test_build_args.py (89%) rename tests/python/{ => legacy}/test_build_model_from_args.py (64%) diff --git a/tests/debug/compare_lib.py b/tests/python/legacy/compare_lib.py similarity index 93% rename from tests/debug/compare_lib.py rename to tests/python/legacy/compare_lib.py index 9c2e35f014..5bcea1e699 100644 --- a/tests/debug/compare_lib.py +++ b/tests/python/legacy/compare_lib.py @@ -1,17 +1,14 @@ -from typing import List - import argparse -import os import json +import os +from typing import List -import tvm -from tvm import relax -from tvm import rpc -from tvm.relax.testing.lib_comparator import LibCompareVMInstrument import numpy as np - import torch +import tvm from transformers import AutoTokenizer, LlamaTokenizer +from tvm import relax, rpc +from tvm.relax.testing.lib_comparator import LibCompareVMInstrument from mlc_llm import utils @@ -53,7 +50,7 @@ def compare( if self.time_eval and name not in self.time_eval_results: res = self.mod.time_evaluator( - name, self.device, number=20, repeat=3#, cache_flush_bytes=256 * 10**6 + name, self.device, number=20, repeat=3 # , cache_flush_bytes=256 * 10**6 )(*new_args) self.time_eval_results[name] = (res.mean, 1) print(f"Time-eval result {name} on {self.device}: {res}") @@ -121,9 +118,7 @@ def __init__(self, args): ) ) self.cmp_device = tvm.device(args.cmp_device) - self.const_params_dict = utils.load_params( - args.artifact_path, 
self.primary_device - ) + self.const_params_dict = utils.load_params(args.artifact_path, self.primary_device) self.cmp_instrument = LibCompare( self.lib, self.cmp_device, @@ -134,9 +129,7 @@ def __init__(self, args): def deploy_to_pipeline(args) -> None: - with open( - os.path.join(args.artifact_path, "params", "mlc-chat-config.json"), "r" - ) as f: + with open(os.path.join(args.artifact_path, "params", "mlc-chat-config.json"), "r") as f: config = json.load(f) primary_device = tvm.device(args.primary_device) @@ -157,18 +150,14 @@ def deploy_to_pipeline(args) -> None: tokenizer(args.prompt, return_tensors="pt").input_ids.to(torch.int32).numpy(), primary_device, ) - first_sampled_token = tvm.nd.array( - np.array([[6234]]).astype("int32"), primary_device - ) + first_sampled_token = tvm.nd.array(np.array([[6234]]).astype("int32"), primary_device) seq_len_shape = tvm.runtime.ShapeTuple([inputs.shape[1]]) second_seq_len_shape = tvm.runtime.ShapeTuple([inputs.shape[1] + 1]) kv_caches = state.vm["create_kv_cache"]() print("Running inference...") print("======================= Starts Encoding =======================") - logits, kv_caches = state.vm["prefill"]( - inputs, seq_len_shape, kv_caches, const_params - ) + logits, kv_caches = state.vm["prefill"](inputs, seq_len_shape, kv_caches, const_params) print_as_table( sorted( state.cmp_instrument.time_eval_results.items(), diff --git a/tests/debug/dump_intermediate.py b/tests/python/legacy/dump_intermediate.py similarity index 95% rename from tests/debug/dump_intermediate.py rename to tests/python/legacy/dump_intermediate.py index 84cc8c74b1..52536ad760 100644 --- a/tests/debug/dump_intermediate.py +++ b/tests/python/legacy/dump_intermediate.py @@ -1,12 +1,12 @@ import argparse import os +import pickle import numpy as np import torch import tvm from transformers import AutoTokenizer from tvm import relax -import pickle from mlc_llm import utils @@ -77,12 +77,8 @@ def deploy_to_pipeline(args) -> None: ) print("Tokenizing...") - inputs = ( - tokenizer(args.prompt, return_tensors="pt").input_ids.to(torch.int32).numpy() - ) - first_sampled_token = tvm.nd.array( - np.array([[6234]]).astype("int32"), primary_device - ) + inputs = tokenizer(args.prompt, return_tensors="pt").input_ids.to(torch.int32).numpy() + first_sampled_token = tvm.nd.array(np.array([[6234]]).astype("int32"), primary_device) seq_len_shape = tvm.runtime.ShapeTuple([inputs.shape[1]]) second_seq_len_shape = tvm.runtime.ShapeTuple([inputs.shape[1] + 1]) kv_caches = state.vm["create_kv_cache"]() diff --git a/tests/evaluate.py b/tests/python/legacy/evaluate.py similarity index 96% rename from tests/evaluate.py rename to tests/python/legacy/evaluate.py index f37fdabf5f..4a370c517c 100644 --- a/tests/evaluate.py +++ b/tests/python/legacy/evaluate.py @@ -58,9 +58,7 @@ def compare( repeat=3, )(*new_args).mean shapes = [arg.shape for arg in new_args] - total_bytes = sum( - arg.numpy().size * arg.numpy().itemsize for arg in new_args - ) + total_bytes = sum(arg.numpy().size * arg.numpy().itemsize for arg in new_args) self.time_eval_results[name] = (res, 1, shapes, total_bytes) else: record = self.time_eval_results[name] @@ -177,9 +175,7 @@ def deploy_to_pipeline(args) -> None: # pylint: disable=too-many-locals print("Profiling...") kv_caches = vm["create_kv_cache"]() - logits, kv_caches = vm["prefill"]( - inputs, seq_len_shape, kv_caches, const_params - ) + logits, kv_caches = vm["prefill"](inputs, seq_len_shape, kv_caches, const_params) print("======================= Encoding Profiling 
=======================") print_as_table( sorted( diff --git a/tests/debug/test_batching_llama.py b/tests/python/legacy/test_batching_llama.py similarity index 100% rename from tests/debug/test_batching_llama.py rename to tests/python/legacy/test_batching_llama.py diff --git a/tests/python/test_build_args.py b/tests/python/legacy/test_build_args.py similarity index 89% rename from tests/python/test_build_args.py rename to tests/python/legacy/test_build_args.py index 3805b29199..8f32d123b6 100644 --- a/tests/python/test_build_args.py +++ b/tests/python/legacy/test_build_args.py @@ -3,11 +3,12 @@ import dataclasses import unittest -from mlc_llm import BuildArgs, utils, core +from mlc_llm import BuildArgs, core, utils + def old_make_args(): """The exact old way of creating `ArgumentParser`, used to test whether - `BuildArgs` is equivalent to this. """ + `BuildArgs` is equivalent to this.""" args = argparse.ArgumentParser() args.add_argument( "--model", @@ -17,7 +18,7 @@ def old_make_args(): 'The name of the model to build. If it is "auto", we will ' 'automatically set the model name according to "--model-path", ' '"hf-path" or the model folders under "--artifact-path/models"' - ) + ), ) args.add_argument( "--hf-path", @@ -30,19 +31,16 @@ def old_make_args(): type=str, choices=[*utils.quantization_schemes.keys()], default=list(utils.quantization_schemes.keys())[0], - help="The quantization mode we use to compile." + help="The quantization mode we use to compile.", ) args.add_argument( "--max-seq-len", type=int, default=-1, - help="The maximum allowed sequence length for the model." + help="The maximum allowed sequence length for the model.", ) args.add_argument( - "--target", - type=str, - default="auto", - help="The target platform to compile the model for." + "--target", type=str, default="auto", help="The target platform to compile the model for." ) args.add_argument( "--reuse-lib", @@ -51,10 +49,7 @@ def old_make_args(): help="Whether to reuse a previously generated lib.", ) args.add_argument( - "--artifact-path", - type=str, - default="dist", - help="Where to store the output." + "--artifact-path", type=str, default="dist", help="Where to store the output." ) args.add_argument( "--use-cache", @@ -66,13 +61,13 @@ def old_make_args(): "--debug-dump", action="store_true", default=False, - help="Whether to dump debugging files during compilation." + help="Whether to dump debugging files during compilation.", ) args.add_argument( "--debug-load-script", action="store_true", default=False, - help="Whether to load the script for debugging." + help="Whether to load the script for debugging.", ) args.add_argument( "--llvm-mingw", @@ -81,10 +76,7 @@ def old_make_args(): help="/path/to/llvm-mingw-root, use llvm-mingw to cross compile to windows.", ) args.add_argument( - "--system-lib", - action="store_true", - default=False, - help="A parameter to `relax.build`." + "--system-lib", action="store_true", default=False, help="A parameter to `relax.build`." 
) args.add_argument( "--sep-embed", @@ -99,17 +91,20 @@ def old_make_args(): return args + # Referred to HfArgumentParserTest from https://github.com/huggingface/ # transformers/blob/e84bf1f734f87aa2bedc41b9b9933d00fc6add98/tests/utils # /test_hf_argparser.py#L143 class BuildArgsTest(unittest.TestCase): """Tests whether BuildArgs reaches parity with regular ArgumentParser.""" - def argparsers_equal(self, parse_a: argparse.ArgumentParser, - parse_b: argparse.ArgumentParser): + + def argparsers_equal(self, parse_a: argparse.ArgumentParser, parse_b: argparse.ArgumentParser): """ Small helper to check pseudo-equality of parsed arguments on `ArgumentParser` instances. """ - self.assertEqual(len(parse_a._actions), len(parse_b._actions)) # pylint: disable=protected-access + self.assertEqual( + len(parse_a._actions), len(parse_b._actions) + ) # pylint: disable=protected-access for x, y in zip(parse_a._actions, parse_b._actions): # pylint: disable=protected-access xx = {k: v for k, v in vars(x).items() if k != "container"} yy = {k: v for k, v in vars(y).items() if k != "container"} @@ -175,5 +170,6 @@ def test_namespaces_are_equivalent_str_boolean_int(self): build_args_namespace = argparse.Namespace(**build_args_as_dict) self.assertNotEqual(build_args_namespace, parsed_args) -if __name__ == '__main__': + +if __name__ == "__main__": unittest.main() diff --git a/tests/python/test_build_model_from_args.py b/tests/python/legacy/test_build_model_from_args.py similarity index 64% rename from tests/python/test_build_model_from_args.py rename to tests/python/legacy/test_build_model_from_args.py index a5ce550e9f..c7990d63df 100644 --- a/tests/python/test_build_model_from_args.py +++ b/tests/python/legacy/test_build_model_from_args.py @@ -1,27 +1,25 @@ - import argparse import os import unittest from unittest.mock import MagicMock, mock_open, patch from mlc_llm import utils - from mlc_llm.core import build_model_from_args class MockMkdir(object): def __init__(self): self.received_args = None - + def __call__(self, *args): self.received_args = args -class BuildModelTest(unittest.TestCase): +class BuildModelTest(unittest.TestCase): def setUp(self): self._orig_mkdir = os.mkdir os.mkdir = MockMkdir() - + self.mock_args = argparse.Namespace() self.mock_args.quantization = utils.quantization_schemes["q8f16_1"] self.mock_args.debug_dump = False @@ -38,29 +36,36 @@ def setUp(self): self.mock_args.model = "/tmp/" self.mock_args.target_kind = "cuda" self.mock_args.max_seq_len = 2048 - + def tearDown(self): os.mkdir = self._orig_mkdir @patch("builtins.open", new_callable=mock_open, read_data="data") - @patch("json.load", MagicMock(side_effect = [ {} ])) + @patch("json.load", MagicMock(side_effect=[{}])) def test_llama_model(self, mock_file): self.mock_args.model_category = "llama" build_model_from_args(self.mock_args) @patch("builtins.open", new_callable=mock_open, read_data="data") - @patch("json.load", MagicMock(side_effect = [ { - "use_parallel_residual": False, - "hidden_size": 32, - "intermediate_size": 32, - "num_attention_heads": 32, - "num_hidden_layers": 28, - "vocab_size": 1024, - "rotary_pct": 1, - "rotary_emb_base": 1, - "layer_norm_eps": 1, - } ])) + @patch( + "json.load", + MagicMock( + side_effect=[ + { + "use_parallel_residual": False, + "hidden_size": 32, + "intermediate_size": 32, + "num_attention_heads": 32, + "num_hidden_layers": 28, + "vocab_size": 1024, + "rotary_pct": 1, + "rotary_emb_base": 1, + "layer_norm_eps": 1, + } + ] + ), + ) def test_gpt_neox_model(self, mock_file): 
self.mock_args.model_category = "gpt_neox" self.mock_args.model = "dolly-test" @@ -68,7 +73,7 @@ def test_gpt_neox_model(self, mock_file): build_model_from_args(self.mock_args) @patch("builtins.open", new_callable=mock_open, read_data="data") - @patch("json.load", MagicMock(side_effect = [ {} ])) + @patch("json.load", MagicMock(side_effect=[{}])) def test_gpt_bigcode_model(self, mock_file): self.mock_args.model_category = "gpt_bigcode" self.mock_args.model = "gpt_bigcode" @@ -76,51 +81,62 @@ def test_gpt_bigcode_model(self, mock_file): build_model_from_args(self.mock_args) @patch("builtins.open", new_callable=mock_open, read_data="data") - @patch("json.load", MagicMock(side_effect = [ {} ])) + @patch("json.load", MagicMock(side_effect=[{}])) def test_minigpt_model(self, mock_file): self.mock_args.model_category = "minigpt" self.mock_args.model = "minigpt4-7b" build_model_from_args(self.mock_args) - @patch("builtins.open", new_callable=mock_open, read_data="data") - @patch("json.load", MagicMock(side_effect = [ { - "vocab_size": 1024, - "n_embd": 32, - "n_inner": 32, - "n_head": 32, - "n_layer": 28, - "bos_token_id": 28, - "eos_token_id": 1, - "rotary_dim": 1, - "tie_word_embeddings": 1, - } ])) + @patch( + "json.load", + MagicMock( + side_effect=[ + { + "vocab_size": 1024, + "n_embd": 32, + "n_inner": 32, + "n_head": 32, + "n_layer": 28, + "bos_token_id": 28, + "eos_token_id": 1, + "rotary_dim": 1, + "tie_word_embeddings": 1, + } + ] + ), + ) def test_gptj_model(self, mock_file): self.mock_args.model_category = "gptj" self.mock_args.model = "gpt-j-" build_model_from_args(self.mock_args) - @patch("builtins.open", new_callable=mock_open, read_data="data") - @patch("json.load", MagicMock(side_effect = [ { - "num_hidden_layers": 16, - "vocab_size": 1024, - "hidden_size": 16, - "intermediate_size": 32, - } ])) + @patch( + "json.load", + MagicMock( + side_effect=[ + { + "num_hidden_layers": 16, + "vocab_size": 1024, + "hidden_size": 16, + "intermediate_size": 32, + } + ] + ), + ) def test_rwkv_model(self, mock_file): self.mock_args.model_category = "rwkv" self.mock_args.model = "rwkv-" build_model_from_args(self.mock_args) - @patch("builtins.open", new_callable=mock_open, read_data="data") - @patch("json.load", MagicMock(side_effect = [ { } ])) + @patch("json.load", MagicMock(side_effect=[{}])) def test_chatglm_model(self, mock_file): self.mock_args.model_category = "chatglm" self.mock_args.model = "chatglm2" - build_model_from_args(self.mock_args) \ No newline at end of file + build_model_from_args(self.mock_args) From 03c641ad693debdf23643013a26cc42e1b1ff71b Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Fri, 20 Oct 2023 21:51:51 -0700 Subject: [PATCH 038/116] Enable Python Linter (#1098) This PR enables two Python formatters "black" and "isort" on the following directory: - `./python/` - `./tests/python/` Enabling pylint and mypy is left for future work --- .github/workflows/python_lint.yml | 37 +++++++++++++++++++++++++++++++ ci/task/black.sh | 9 ++++++++ ci/task/isort.sh | 9 ++++++++ 3 files changed, 55 insertions(+) create mode 100644 .github/workflows/python_lint.yml create mode 100755 ci/task/black.sh create mode 100755 ci/task/isort.sh diff --git a/.github/workflows/python_lint.yml b/.github/workflows/python_lint.yml new file mode 100644 index 0000000000..1a255785ac --- /dev/null +++ b/.github/workflows/python_lint.yml @@ -0,0 +1,37 @@ +name: Python Lint + +on: [push, pull_request] + +env: + IMAGE: 'mlcaidev/ci-cpu:8a87699' + +jobs: + isort: + runs-on: ubuntu-latest + steps: + - uses: 
actions/checkout@v3 + with: + submodules: 'recursive' + - name: Version + run: | + wget https://raw.githubusercontent.com/mlc-ai/package/main/docker/bash.sh -O ./ci/bash.sh + chmod u+x ./ci/bash.sh + ./ci/bash.sh $IMAGE "conda env export --name ci-lint" + - name: Lint + run: | + ./ci/bash.sh $IMAGE bash ./ci/task/isort.sh + + black: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + with: + submodules: 'recursive' + - name: Version + run: | + wget https://raw.githubusercontent.com/mlc-ai/package/main/docker/bash.sh -O ./ci/bash.sh + chmod u+x ./ci/bash.sh + ./ci/bash.sh $IMAGE "conda env export --name ci-lint" + - name: Lint + run: | + ./ci/bash.sh $IMAGE bash ./ci/task/black.sh diff --git a/ci/task/black.sh b/ci/task/black.sh new file mode 100755 index 0000000000..0e8555cf63 --- /dev/null +++ b/ci/task/black.sh @@ -0,0 +1,9 @@ +#!/bin/bash +set -eo pipefail + +source ~/.bashrc +micromamba activate ci-lint +NUM_THREADS=$(nproc) + +black --check --workers $NUM_THREADS ./python/ +black --check --workers $NUM_THREADS ./tests/python diff --git a/ci/task/isort.sh b/ci/task/isort.sh new file mode 100755 index 0000000000..cdeb030cc6 --- /dev/null +++ b/ci/task/isort.sh @@ -0,0 +1,9 @@ +#!/bin/bash +set -eo pipefail + +source ~/.bashrc +micromamba activate ci-lint +NUM_THREADS=$(nproc) + +isort --check-only -j $NUM_THREADS --profile black ./python/ +isort --check-only -j $NUM_THREADS --profile black ./tests/python/ From 46d11e6133e90f0e4f4dd4ad43d661c56f3d93ef Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Fri, 20 Oct 2023 23:39:28 -0700 Subject: [PATCH 039/116] Add Basic Pylint and Mypy Tooling (#1100) Add pylint/mypy tooling into pyproject.toml This PR establishes the initial Python tooling infra with Pylint and Mypy. Currently only the newest modules, i.e. `mlc_chat.support` and `mlc_chat.compiler` are covered, and we expect to cover the entire package, as being tracked in #1101. 
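
For reference, the two new checks can be reproduced locally with roughly the following sketch. It assumes `mypy` and `pylint` (plus the TVM Unity dependency that CI installs separately) are available in the active environment, and that it is run from the repository root so the new `pyproject.toml` sections apply; CI instead runs `ci/task/mypy.sh` and `ci/task/pylint.sh` inside the `ci-lint` image:

    # Sketch mirroring ci/task/mypy.sh and ci/task/pylint.sh for local use.
    import os
    import subprocess

    env = dict(os.environ)
    env["PYTHONPATH"] = "./python" + os.pathsep + env.get("PYTHONPATH", "")
    targets = ["./python/mlc_chat/compiler", "./python/mlc_chat/support"]

    # mypy and pylint pick up their configuration from pyproject.toml in the
    # working directory; the CI tasks additionally cover the tests/python/ dirs.
    subprocess.run(["mypy", *targets], env=env, check=True)
    subprocess.run(["pylint", "--jobs", "0", *targets], env=env, check=True)  # 0 = all cores
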
--- .github/workflows/python_lint.yml | 34 +++++++++++++++++-- ci/task/black.sh | 3 +- ci/task/isort.sh | 3 +- ci/task/mypy.sh | 10 ++++++ ci/task/pylint.sh | 13 +++++++ pyproject.toml | 17 +++++++++- .../compiler/model/llama_parameter.py | 16 +++++---- python/mlc_chat/compiler/parameter/mapping.py | 18 +++++++--- python/mlc_chat/support/config.py | 4 +-- .../python/parameter/test_hf_torch_loader.py | 3 +- 10 files changed, 100 insertions(+), 21 deletions(-) create mode 100755 ci/task/mypy.sh create mode 100755 ci/task/pylint.sh diff --git a/.github/workflows/python_lint.yml b/.github/workflows/python_lint.yml index 1a255785ac..dab28b9261 100644 --- a/.github/workflows/python_lint.yml +++ b/.github/workflows/python_lint.yml @@ -1,9 +1,7 @@ name: Python Lint - on: [push, pull_request] - env: - IMAGE: 'mlcaidev/ci-cpu:8a87699' + IMAGE: 'mlcaidev/ci-cpu:2c03e7f' jobs: isort: @@ -35,3 +33,33 @@ jobs: - name: Lint run: | ./ci/bash.sh $IMAGE bash ./ci/task/black.sh + + mypy: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + with: + submodules: 'recursive' + - name: Version + run: | + wget https://raw.githubusercontent.com/mlc-ai/package/main/docker/bash.sh -O ./ci/bash.sh + chmod u+x ./ci/bash.sh + ./ci/bash.sh $IMAGE "conda env export --name ci-lint" + - name: Lint + run: | + ./ci/bash.sh $IMAGE bash ./ci/task/mypy.sh + + pylint: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + with: + submodules: 'recursive' + - name: Version + run: | + wget https://raw.githubusercontent.com/mlc-ai/package/main/docker/bash.sh -O ./ci/bash.sh + chmod u+x ./ci/bash.sh + ./ci/bash.sh $IMAGE "conda env export --name ci-lint" + - name: Lint + run: | + ./ci/bash.sh $IMAGE bash ./ci/task/pylint.sh diff --git a/ci/task/black.sh b/ci/task/black.sh index 0e8555cf63..9e17a4c37a 100755 --- a/ci/task/black.sh +++ b/ci/task/black.sh @@ -3,7 +3,8 @@ set -eo pipefail source ~/.bashrc micromamba activate ci-lint -NUM_THREADS=$(nproc) +export NUM_THREADS=$(nproc) +export PYTHONPATH="./python:$PYTHONPATH" black --check --workers $NUM_THREADS ./python/ black --check --workers $NUM_THREADS ./tests/python diff --git a/ci/task/isort.sh b/ci/task/isort.sh index cdeb030cc6..0cf5ef9144 100755 --- a/ci/task/isort.sh +++ b/ci/task/isort.sh @@ -3,7 +3,8 @@ set -eo pipefail source ~/.bashrc micromamba activate ci-lint -NUM_THREADS=$(nproc) +export NUM_THREADS=$(nproc) +export PYTHONPATH="./python:$PYTHONPATH" isort --check-only -j $NUM_THREADS --profile black ./python/ isort --check-only -j $NUM_THREADS --profile black ./tests/python/ diff --git a/ci/task/mypy.sh b/ci/task/mypy.sh new file mode 100755 index 0000000000..68713ac1ae --- /dev/null +++ b/ci/task/mypy.sh @@ -0,0 +1,10 @@ +#!/bin/bash +set -eo pipefail + +source ~/.bashrc +micromamba activate ci-lint +export NUM_THREADS=$(nproc) +export PYTHONPATH="./python:$PYTHONPATH" + +mypy ./python/mlc_chat/compiler ./python/mlc_chat/support +mypy ./tests/python/model ./tests/python/parameter diff --git a/ci/task/pylint.sh b/ci/task/pylint.sh new file mode 100755 index 0000000000..c29f5ad44e --- /dev/null +++ b/ci/task/pylint.sh @@ -0,0 +1,13 @@ +#!/bin/bash +set -eo pipefail + +source ~/.bashrc +micromamba activate ci-lint +export NUM_THREADS=$(nproc) +export PYTHONPATH="./python:$PYTHONPATH" + +# TVM Unity is a dependency to this testing +pip install --quiet --pre -U -f https://mlc.ai/wheels mlc-ai-nightly + +pylint --jobs $NUM_THREADS ./python/mlc_chat/compiler ./python/mlc_chat/support +pylint --jobs $NUM_THREADS --recursive=y ./tests/python/model 
./tests/python/parameter diff --git a/pyproject.toml b/pyproject.toml index 2310e9aa60..85ca20eb24 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,4 +19,19 @@ profile = "black" [tool.black] line-length = 100 -target-version = ['py310'] + +[tool.mypy] +ignore_missing_imports = true +show_column_numbers = true +show_error_context = true +follow_imports = "skip" +ignore_errors = false +strict_optional = false +install_types = true +non_interactive = true + +[tool.pylint.messages_control] +max-line-length = 100 +disable = """ +duplicate-code, +""" diff --git a/python/mlc_chat/compiler/model/llama_parameter.py b/python/mlc_chat/compiler/model/llama_parameter.py index 39a8921a05..b0fa867130 100644 --- a/python/mlc_chat/compiler/model/llama_parameter.py +++ b/python/mlc_chat/compiler/model/llama_parameter.py @@ -2,6 +2,8 @@ This file specifies how MLC's Llama parameter maps from other formats, for example HuggingFace PyTorch, HuggingFace safetensors. """ +from typing import Callable, Dict, List + import numpy as np from ..parameter import ExternMapping @@ -26,8 +28,8 @@ def hf_torch(model_config: LlamaConfig) -> ExternMapping: _, named_params = model.export_tvm(spec=model.get_default_spec()) parameter_names = {name for name, _ in named_params} - param_map = {} - map_func = {} + param_map: Dict[str, List[str]] = {} + map_func: Dict[str, Callable] = {} unused_params = set() for i in range(model_config.num_hidden_layers): @@ -35,24 +37,24 @@ def hf_torch(model_config: LlamaConfig) -> ExternMapping: attn = f"model.layers.{i}.self_attn" assert f"{attn}.qkv_proj.weight" in parameter_names map_func[f"{attn}.qkv_proj.weight"] = lambda q, k, v: np.concatenate([q, k, v], axis=0) - param_map[f"{attn}.qkv_proj.weight"] = ( + param_map[f"{attn}.qkv_proj.weight"] = [ f"{attn}.q_proj.weight", f"{attn}.k_proj.weight", f"{attn}.v_proj.weight", - ) + ] # Add gates in MLP mlp = f"model.layers.{i}.mlp" assert f"{mlp}.gate_up_proj.weight" in parameter_names map_func[f"{mlp}.gate_up_proj.weight"] = lambda gate, up: np.concatenate([gate, up], axis=0) - param_map[f"{mlp}.gate_up_proj.weight"] = ( + param_map[f"{mlp}.gate_up_proj.weight"] = [ f"{mlp}.gate_proj.weight", f"{mlp}.up_proj.weight", - ) + ] # inv_freq is not used in the model unused_params.add(f"{attn}.rotary_emb.inv_freq") for name in parameter_names: if name not in map_func: map_func[name] = lambda x: x - param_map[name] = (name,) + param_map[name] = [name] return ExternMapping(param_map, map_func, unused_params) diff --git a/python/mlc_chat/compiler/parameter/mapping.py b/python/mlc_chat/compiler/parameter/mapping.py index 3018c91ca3..6f63dce71a 100644 --- a/python/mlc_chat/compiler/parameter/mapping.py +++ b/python/mlc_chat/compiler/parameter/mapping.py @@ -1,10 +1,18 @@ """Parameter mapping for converting different LLM implementations to MLC LLM.""" import dataclasses -from typing import Callable, Dict, List, Set +from typing import Callable, Dict, List, Set, Union import numpy as np from tvm.runtime import NDArray +MapFuncVariadic = Union[ + Callable[[], np.ndarray], + Callable[[np.ndarray], np.ndarray], + Callable[[np.ndarray, np.ndarray], np.ndarray], + Callable[[np.ndarray, np.ndarray, np.ndarray], np.ndarray], + Callable[[np.ndarray, np.ndarray, np.ndarray, np.ndarray], np.ndarray], +] + @dataclasses.dataclass class ExternMapping: @@ -33,8 +41,8 @@ class ExternMapping: """ param_map: Dict[str, List[str]] - map_func: Dict[str, Callable[[np.ndarray, ...], np.ndarray]] - unused_params: Set[str] = dataclasses.field(default_factory=dict) + 
map_func: Dict[str, MapFuncVariadic] + unused_params: Set[str] = dataclasses.field(default_factory=set) @dataclasses.dataclass @@ -72,8 +80,8 @@ class QuantizeMapping: used to convert the quantized parameters into the desired form. """ - param_map: Dict[str, Callable[str, List[str]]] - map_func: Dict[str, Callable[NDArray, List[NDArray]]] + param_map: Dict[str, Callable[[str], List[str]]] + map_func: Dict[str, Callable[[NDArray], List[NDArray]]] __all__ = ["ExternMapping", "QuantizeMapping"] diff --git a/python/mlc_chat/support/config.py b/python/mlc_chat/support/config.py index 62270ffd9c..9e42b815bc 100644 --- a/python/mlc_chat/support/config.py +++ b/python/mlc_chat/support/config.py @@ -37,10 +37,10 @@ def from_dict(cls: Type[ConfigClass], source: Dict[str, Any]) -> ConfigClass: cfg : ConfigClass An instance of the config object. """ - field_names = [field.name for field in dataclasses.fields(cls)] + field_names = [field.name for field in dataclasses.fields(cls)] # type: ignore[arg-type] fields = {k: v for k, v in source.items() if k in field_names} kwargs = {k: v for k, v in source.items() if k not in field_names} - return cls(**fields, kwargs=kwargs) + return cls(**fields, kwargs=kwargs) # type: ignore[call-arg] @classmethod def from_file(cls: Type[ConfigClass], source: Path) -> ConfigClass: diff --git a/tests/python/parameter/test_hf_torch_loader.py b/tests/python/parameter/test_hf_torch_loader.py index 745773b209..9cc8d0ea6c 100644 --- a/tests/python/parameter/test_hf_torch_loader.py +++ b/tests/python/parameter/test_hf_torch_loader.py @@ -1,6 +1,7 @@ # pylint: disable=missing-docstring import logging from pathlib import Path +from typing import Union import pytest from mlc_chat.compiler.model.llama import LlamaConfig @@ -24,7 +25,7 @@ "./dist/models/Llama-2-70b-hf", ], ) -def test_load_llama(base_path: str): +def test_load_llama(base_path: Union[str, Path]): base_path = Path(base_path) path_config = base_path / "config.json" path_params = base_path / "pytorch_model.bin.index.json" From 6159cc4903a722dd1d9d5815eb0f13940ae78b5a Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Sun, 22 Oct 2023 02:22:55 -0700 Subject: [PATCH 040/116] [CI] Add clang-format (#1103) --- .../workflows/{python_lint.yml => lint.yml} | 42 +++++++-- android/src/cpp/tvm_runtime.h | 3 +- ci/bash.sh | 91 +++++++++++++++++++ ci/task/black.sh | 7 +- ci/task/clang-format.sh | 67 ++++++++++++++ ci/task/isort.sh | 7 +- ci/task/mypy.sh | 8 +- ci/task/pylint.sh | 2 + cpp/cli_main.cc | 60 ++++++------ cpp/conversation.h | 3 +- ios/MLCSwift/Sources/ObjC/LLMChat.mm | 17 ++-- ios/MLCSwift/Sources/ObjC/include/LLMChat.h | 11 ++- tests/cpp/conv_unittest.cc | 4 +- 13 files changed, 262 insertions(+), 60 deletions(-) rename .github/workflows/{python_lint.yml => lint.yml} (62%) create mode 100755 ci/bash.sh create mode 100755 ci/task/clang-format.sh diff --git a/.github/workflows/python_lint.yml b/.github/workflows/lint.yml similarity index 62% rename from .github/workflows/python_lint.yml rename to .github/workflows/lint.yml index dab28b9261..478f75e8fd 100644 --- a/.github/workflows/python_lint.yml +++ b/.github/workflows/lint.yml @@ -1,15 +1,16 @@ -name: Python Lint +name: Lint on: [push, pull_request] env: - IMAGE: 'mlcaidev/ci-cpu:2c03e7f' + IMAGE: 'mlcaidev/ci-cpu:caab922' jobs: isort: + name: Python / isort runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: - submodules: 'recursive' + submodules: '' - name: Version run: | wget 
https://raw.githubusercontent.com/mlc-ai/package/main/docker/bash.sh -O ./ci/bash.sh @@ -20,11 +21,12 @@ jobs: ./ci/bash.sh $IMAGE bash ./ci/task/isort.sh black: + name: Python / black runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: - submodules: 'recursive' + submodules: '' - name: Version run: | wget https://raw.githubusercontent.com/mlc-ai/package/main/docker/bash.sh -O ./ci/bash.sh @@ -35,11 +37,12 @@ jobs: ./ci/bash.sh $IMAGE bash ./ci/task/black.sh mypy: + name: Python / mypy runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: - submodules: 'recursive' + submodules: '' - name: Version run: | wget https://raw.githubusercontent.com/mlc-ai/package/main/docker/bash.sh -O ./ci/bash.sh @@ -50,11 +53,12 @@ jobs: ./ci/bash.sh $IMAGE bash ./ci/task/mypy.sh pylint: + name: Python / pylint runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: - submodules: 'recursive' + submodules: '' - name: Version run: | wget https://raw.githubusercontent.com/mlc-ai/package/main/docker/bash.sh -O ./ci/bash.sh @@ -63,3 +67,21 @@ jobs: - name: Lint run: | ./ci/bash.sh $IMAGE bash ./ci/task/pylint.sh + + clang-format: + name: C++ / clang-format + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + with: + submodules: '' + ref: ${{ github.event.pull_request.head.sha }} + fetch-depth: 0 + - name: Version + run: | + wget https://raw.githubusercontent.com/mlc-ai/package/main/docker/bash.sh -O ./ci/bash.sh + chmod u+x ./ci/bash.sh + ./ci/bash.sh $IMAGE "conda env export --name ci-lint" + - name: Lint + run: | + ./ci/bash.sh $IMAGE bash ./ci/task/clang-format.sh diff --git a/android/src/cpp/tvm_runtime.h b/android/src/cpp/tvm_runtime.h index 5a1267119d..2caaaaeb1a 100644 --- a/android/src/cpp/tvm_runtime.h +++ b/android/src/cpp/tvm_runtime.h @@ -1,13 +1,12 @@ #define DMLC_USE_LOGGING_LIBRARY #define TVM_USE_LIBBACKTRACE 0 +#include #include #include #include #include -#include - static_assert(TVM_LOG_CUSTOMIZE == 1, "TVM_LOG_CUSTOMIZE must be 1"); namespace tvm { diff --git a/ci/bash.sh b/ci/bash.sh new file mode 100755 index 0000000000..d54eae48ef --- /dev/null +++ b/ci/bash.sh @@ -0,0 +1,91 @@ +#!/usr/bin/env bash + +# +# Start a bash, mount /workspace to be current directory. +# +# Usage: docker/bash.sh +# Starts an interactive session +# +# Usage2: docker/bash.sh [COMMAND] +# Execute command in the docker image, non-interactive +# +if [ "$#" -lt 1 ]; then + echo "Usage: docker/bash.sh [--no-gpu] [COMMAND]" + exit -1 +fi + +if [ "$1" == "--no-gpu" ]; then + ENABLE_NV_DOCKER=0 + shift +else + ENABLE_NV_DOCKER=1 +fi + +DOCKER_IMAGE_NAME=("$1") + + +if [ "$#" -eq 1 ]; then + COMMAND="bash" + if [[ $(uname) == "Darwin" ]]; then + # Docker's host networking driver isn't supported on macOS. + # Use default bridge network and expose port for jupyter notebook. + DOCKER_EXTRA_PARAMS=("-it -p 8888:8888") + else + DOCKER_EXTRA_PARAMS=("-it --net=host") + fi +else + shift 1 + COMMAND=("$@") +fi + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +WORKSPACE="$(pwd)" + +# Use nvidia-docker if the container is GPU. +if [[ ! -z $CUDA_VISIBLE_DEVICES ]]; then + CUDA_ENV="-e CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES}" +else + CUDA_ENV="" +fi + +# If this is an wheel test command then pass the env var to docker. +if [[ ! 
-z $WHEEL_TEST ]]; then + WHEEL_TEST="-e WHEEL_TEST=${WHEEL_TEST}" +fi + +if [[ "${DOCKER_IMAGE_NAME}" == *"cu"* ]]; then + if [ "$ENABLE_NV_DOCKER" -eq 1 ]; then + if ! type "nvidia-docker" 1> /dev/null 2> /dev/null + then + DOCKER_BINARY="docker" + CUDA_ENV=" --gpus all "${CUDA_ENV} + else + DOCKER_BINARY="nvidia-docker" + fi + else + DOCKER_BINARY="docker" + fi +else + DOCKER_BINARY="docker" +fi + +# Print arguments. +echo "WORKSPACE: ${WORKSPACE}" +echo "DOCKER CONTAINER NAME: ${DOCKER_IMAGE_NAME}" +echo "" + +echo "Running '${COMMAND[@]}' inside ${DOCKER_IMAGE_NAME}..." + +# By default we cleanup - remove the container once it finish running (--rm) +# and share the PID namespace (--pid=host) so the process inside does not have +# pid 1 and SIGKILL is propagated to the process inside (jenkins can kill it). + +${DOCKER_BINARY} run --rm --pid=host\ + -v ${WORKSPACE}:/workspace \ + -v ${SCRIPT_DIR}:/docker \ + -w /workspace \ + ${CUDA_ENV} \ + ${WHEEL_TEST} \ + ${DOCKER_EXTRA_PARAMS[@]} \ + ${DOCKER_IMAGE_NAME} \ + ${COMMAND[@]} diff --git a/ci/task/black.sh b/ci/task/black.sh index 9e17a4c37a..dcc4b42555 100755 --- a/ci/task/black.sh +++ b/ci/task/black.sh @@ -6,5 +6,8 @@ micromamba activate ci-lint export NUM_THREADS=$(nproc) export PYTHONPATH="./python:$PYTHONPATH" -black --check --workers $NUM_THREADS ./python/ -black --check --workers $NUM_THREADS ./tests/python +set -x + +black --check --workers $NUM_THREADS \ + ./python/ \ + ./tests/python diff --git a/ci/task/clang-format.sh b/ci/task/clang-format.sh new file mode 100755 index 0000000000..54780ec4f9 --- /dev/null +++ b/ci/task/clang-format.sh @@ -0,0 +1,67 @@ +#!/bin/bash +set -eo pipefail + +source ~/.bashrc +micromamba activate ci-lint +export NUM_THREADS=$(nproc) +export PYTHONPATH="./python:$PYTHONPATH" + +set -x +git config --global --add safe.directory '*' + +INPLACE_FORMAT=${INPLACE_FORMAT:=false} +LINT_ALL_FILES=true +REVISION=$(git rev-list --max-parents=0 HEAD) + +while (($#)); do + case "$1" in + -i) + INPLACE_FORMAT=true + shift 1 + ;; + --rev) + LINT_ALL_FILES=false + REVISION=$2 + shift 2 + ;; + *) + echo "Usage: clang-format.sh [-i] [--rev ]" + echo "" + echo "Run clang-format on files that changed since or on all files in the repo" + echo "Examples:" + echo "- Compare last one commit: clang-format.sh --rev HEAD~1" + echo "- Compare against upstream/main: clang-format.sh --rev upstream/main" + echo "The -i will format files in-place instead of checking them." + exit 1 + ;; + esac +done + +cleanup() { + if [ -f /tmp/$$.clang-format.txt ]; then + echo "" + echo "---------clang-format log----------" + cat /tmp/$$.clang-format.txt + fi + rm -rf /tmp/$$.clang-format.txt +} +trap cleanup 0 + +if [[ "$INPLACE_FORMAT" == "true" ]]; then + echo "Running inplace git-clang-format against $REVISION" + git-clang-format --extensions h,hh,hpp,c,cc,cpp,mm "$REVISION" + exit 0 +fi + +if [[ "$LINT_ALL_FILES" == "true" ]]; then + echo "Running git-clang-format against all C++ files" + git-clang-format --diff --extensions h,hh,hpp,c,cc,cpp,mm "$REVISION" 1>/tmp/$$.clang-format.txt +else + echo "Running git-clang-format against $REVISION" + git-clang-format --diff --extensions h,hh,hpp,c,cc,cpp,mm "$REVISION" 1>/tmp/$$.clang-format.txt +fi + +if grep --quiet -E "diff" (); } -ModelPaths ModelPaths::Find(const std::string& device_name, const std::string& local_id, const std::string &user_lib_path) { +ModelPaths ModelPaths::Find(const std::string& device_name, const std::string& local_id, + const std::string& user_lib_path) { // Step 1. 
Find config path std::filesystem::path config_path; if (auto path = TryInferMLCChatConfig(local_id)) { @@ -372,31 +374,31 @@ ModelPaths ModelPaths::Find(const std::string& device_name, const std::string& l if (!user_lib_path.empty()) { lib_path = user_lib_path; if (!std::filesystem::exists(lib_path) || !std::filesystem::is_regular_file(lib_path)) { - LOG(FATAL) << "The `lib_path` you passed in is not a file: " << user_lib_path << "\n"; - exit(1); + LOG(FATAL) << "The `lib_path` you passed in is not a file: " << user_lib_path << "\n"; + exit(1); } } else { std::string lib_local_id = ReadStringFromJSONFile(config_path, "model_lib"); std::string lib_name = lib_local_id + "-" + device_name; if (auto path = FindFile({lib_local_id, - "dist/prebuilt/lib", // Using prebuilt workflow - "dist/" + local_id, "dist/prebuilt/" + lib_local_id}, - { - lib_name + GetArchSuffix(), - lib_name, - }, - GetLibSuffixes())) { - lib_path = path.value(); + "dist/prebuilt/lib", // Using prebuilt workflow + "dist/" + local_id, "dist/prebuilt/" + lib_local_id}, + { + lib_name + GetArchSuffix(), + lib_name, + }, + GetLibSuffixes())) { + lib_path = path.value(); } else { - LOG(FATAL) << "Cannot find the model library that corresponds to `" << lib_local_id << "`.\n" - << "We searched over the following possible paths: \n" - << "- " + lib_local_id << "\n" - << "- dist/prebuilt/lib \n" - << "- dist/" + local_id << "\n" - << "- dist/prebuilt/" + lib_local_id << "\n" - << "If you would like to directly specify the full model library path, you may " - << "consider passing in the `--model-lib-path` argument.\n"; - exit(1); + LOG(FATAL) << "Cannot find the model library that corresponds to `" << lib_local_id << "`.\n" + << "We searched over the following possible paths: \n" + << "- " + lib_local_id << "\n" + << "- dist/prebuilt/lib \n" + << "- dist/" + local_id << "\n" + << "- dist/prebuilt/" + lib_local_id << "\n" + << "If you would like to directly specify the full model library path, you may " + << "consider passing in the `--model-lib-path` argument.\n"; + exit(1); } } std::cout << "Use model library: " << lib_path << std::endl; @@ -480,15 +482,15 @@ void Chat(ChatModule* chat, const std::string& device_name, std::string local_id int main(int argc, char* argv[]) { argparse::ArgumentParser args("mlc_chat"); - args.add_description("MLCChat CLI is the command line tool to run MLC-compiled LLMs out of the box.\n" - "Note: the --model argument is required. It can either be the model name with its " - "quantization scheme or a full path to the model folder. In the former case, the " - "provided name will be used to search for the model folder over possible paths. " - "--model-lib-path argument is optional. If unspecified, the --model argument will be used " - "to search for the library file over possible paths."); + args.add_description( + "MLCChat CLI is the command line tool to run MLC-compiled LLMs out of the box.\n" + "Note: the --model argument is required. It can either be the model name with its " + "quantization scheme or a full path to the model folder. In the former case, the " + "provided name will be used to search for the model folder over possible paths. " + "--model-lib-path argument is optional. 
If unspecified, the --model argument will be used " + "to search for the library file over possible paths."); - args.add_argument("--model") - .help("[required] the model to use"); + args.add_argument("--model").help("[required] the model to use"); args.add_argument("--model-lib-path") .help("[optional] the full path to the model library file to use"); args.add_argument("--device").default_value("auto"); diff --git a/cpp/conversation.h b/cpp/conversation.h index 82332aede6..6211c24c25 100644 --- a/cpp/conversation.h +++ b/cpp/conversation.h @@ -283,7 +283,8 @@ class Conversation { /* place_in_prompt= */ place_in_prompt); } else { ICHECK(this->separator_style == SeparatorStyle::kLM || - this->separator_style == SeparatorStyle::kCodeCompletion) << "Unsupported separator_style"; + this->separator_style == SeparatorStyle::kCodeCompletion) + << "Unsupported separator_style"; // special handle LM, LM mode have no memory // and only returns last one if (this->messages.size() >= 2) { diff --git a/ios/MLCSwift/Sources/ObjC/LLMChat.mm b/ios/MLCSwift/Sources/ObjC/LLMChat.mm index da5edc177e..dcf57c5db2 100644 --- a/ios/MLCSwift/Sources/ObjC/LLMChat.mm +++ b/ios/MLCSwift/Sources/ObjC/LLMChat.mm @@ -23,7 +23,8 @@ // The input message is only the beginning part of a prompt, no role name and separator should be // appended after the message since there will be future messages appended after the message. kBegin, - // The input message is in the middle of a prompt, nothing should be appended before or after the message. + // The input message is in the middle of a prompt, nothing should be appended before or after the + // message. kMiddle, // The input message is the ending part of a prompt, no role name and separator should be appended // prior to it since the message is concatenated to some prior messages. 
@@ -118,7 +119,9 @@ - (void)unload { unload_func_(); } -- (void)reload:(NSString*)modelLib modelPath:(NSString*)modelPath appConfigJson:(NSString*)appConfigJson { +- (void)reload:(NSString*)modelLib + modelPath:(NSString*)modelPath + appConfigJson:(NSString*)appConfigJson { std::string lib_prefix = modelLib.UTF8String; std::string model_path = modelPath.UTF8String; std::string app_config_json = appConfigJson.UTF8String; @@ -194,7 +197,9 @@ - (void)resetImageModule { first_input_after_image = false; } -- (void)prefillImage:(UIImage*)image prevPlaceholder:(NSString*)prevPlaceholder postPlaceholder:(NSString*)postPlaceholder { +- (void)prefillImage:(UIImage*)image + prevPlaceholder:(NSString*)prevPlaceholder + postPlaceholder:(NSString*)postPlaceholder { // prefill the previous placeholder string std::string prev_placeholder = prevPlaceholder.UTF8String; prefill_func_(prev_placeholder, false, (int)PlaceInPrompt::kBegin); @@ -206,9 +211,9 @@ - (void)prefillImage:(UIImage*)image prevPlaceholder:(NSString*)prevPlaceholder NSUInteger bytesPerPixel = 4; NSUInteger bytesPerRow = bytesPerPixel * image_width; NSUInteger bitsPerComponent = 8; - CGContextRef context = CGBitmapContextCreate(image_data.data(), image_width, image_height, - bitsPerComponent, bytesPerRow, colorSpace, - kCGImageAlphaPremultipliedLast | kCGBitmapByteOrder32Big); + CGContextRef context = CGBitmapContextCreate( + image_data.data(), image_width, image_height, bitsPerComponent, bytesPerRow, colorSpace, + kCGImageAlphaPremultipliedLast | kCGBitmapByteOrder32Big); CGColorSpaceRelease(colorSpace); CGContextDrawImage(context, CGRectMake(0, 0, image_width, image_height), imageRef); CGContextRelease(context); diff --git a/ios/MLCSwift/Sources/ObjC/include/LLMChat.h b/ios/MLCSwift/Sources/ObjC/include/LLMChat.h index a996eaa55f..0aab17adb1 100644 --- a/ios/MLCSwift/Sources/ObjC/include/LLMChat.h +++ b/ios/MLCSwift/Sources/ObjC/include/LLMChat.h @@ -40,9 +40,12 @@ * * @param modelLib The name of the modelLib * @param modelPath The path to the model artifacts. - * @param appConfigJson The partial config that is used to partially override the model configuration. + * @param appConfigJson The partial config that is used to partially override the model + * configuration. */ -- (void)reload:(NSString*)modelLib modelPath:(NSString*)modelPath appConfigJson:(NSString*)appConfigJson; +- (void)reload:(NSString*)modelLib + modelPath:(NSString*)modelPath + appConfigJson:(NSString*)appConfigJson; /** * Reset the current chat session. @@ -118,5 +121,7 @@ * @param prevPlaceholder The previous placeholder in the prompt, i.e. . * @param postPlaceholder The post placeholder in the prompt, i.e. . 
*/ -- (void)prefillImage:(UIImage*)image prevPlaceholder:(NSString*)prevPlaceholder postPlaceholder:(NSString*)postPlaceholder; +- (void)prefillImage:(UIImage*)image + prevPlaceholder:(NSString*)prevPlaceholder + postPlaceholder:(NSString*)postPlaceholder; @end diff --git a/tests/cpp/conv_unittest.cc b/tests/cpp/conv_unittest.cc index 214736320d..98d01a58ba 100644 --- a/tests/cpp/conv_unittest.cc +++ b/tests/cpp/conv_unittest.cc @@ -24,6 +24,4 @@ TEST(ConversationTest, ConversationJSONRoundTripTest) { _TestConversationJSONRoundTrip("LM"); } -TEST(ConversationTest, ConversationPartialUpdateTest) { - _TestConversationPartialUpdate(); -} +TEST(ConversationTest, ConversationPartialUpdateTest) { _TestConversationPartialUpdate(); } From 16dd2aeaf2da7f6b3d7188dc3f0778eb2c53ac13 Mon Sep 17 00:00:00 2001 From: Lesheng Jin <34279105+LeshengJin@users.noreply.github.com> Date: Sun, 22 Oct 2023 19:51:10 -0700 Subject: [PATCH 041/116] [Slim-LM] Smart path finding for config and weight (#1088) --- python/mlc_chat/support/auto_config.py | 34 +++++++ python/mlc_chat/support/auto_weight.py | 125 +++++++++++++++++++++++++ tests/python/test_auto_config.py | 44 +++++++++ tests/python/test_auto_weight.py | 104 ++++++++++++++++++++ 4 files changed, 307 insertions(+) create mode 100644 python/mlc_chat/support/auto_config.py create mode 100644 python/mlc_chat/support/auto_weight.py create mode 100644 tests/python/test_auto_config.py create mode 100644 tests/python/test_auto_weight.py diff --git a/python/mlc_chat/support/auto_config.py b/python/mlc_chat/support/auto_config.py new file mode 100644 index 0000000000..1a4d9bf765 --- /dev/null +++ b/python/mlc_chat/support/auto_config.py @@ -0,0 +1,34 @@ +"""Help function for detecting the model configuration file `config.json`""" +import logging +from pathlib import Path + +logger = logging.getLogger(__name__) + + +def detect_config(config_path: Path) -> Path: + """Detect and return the path that points to config.json. If config_path is a directory, + it looks for config.json below it. + + Parameters + --------- + config_path : pathlib.Path + The path to config.json or the directory containing config.json. + + Returns + ------- + config_json_path : pathlib.Path + The path points to config.json. + """ + if not config_path.exists(): + raise ValueError(f"{config_path} does not exist.") + + if config_path.is_dir(): + # search config.json under config_path + config_json_path = config_path / "config.json" + if not config_json_path.exists(): + raise ValueError(f"Fail to find config.json under {config_path}.") + else: + config_json_path = config_path + + logger.info("Found config.json: %s", config_json_path) + return config_json_path diff --git a/python/mlc_chat/support/auto_weight.py b/python/mlc_chat/support/auto_weight.py new file mode 100644 index 0000000000..74e8a8b8c0 --- /dev/null +++ b/python/mlc_chat/support/auto_weight.py @@ -0,0 +1,125 @@ +"""Help functions for detecting weight paths and weight formats.""" +import json +import logging +from pathlib import Path +from typing import Tuple + +logger = logging.getLogger(__name__) + + +def detect_weight( + weight_path: Path, config_json_path: Path, weight_format: str = "auto" +) -> Tuple[Path, str]: + """Detect the weight directory, and detect the weight format. + + Parameters + --------- + weight_path : pathlib.Path + The path to weight files. If `weight_path` is not None, check if it exists. Otherwise, find + `weight_path` in `config.json` or use the same directory as `config.json`. 
+ + config_json_path: pathlib.Path + The path to `config.json`. + + weight_format : str + The hint for the weight format. If it is "auto", guess the weight format. + Otherwise, check the weights are in that format. + Available weight formats: + - auto (guess the weight format) + - PyTorch (validate via checking pytorch_model.bin.index.json) + - SafeTensor (validate via checking model.safetensors.index.json) + - AWQ + - GGML/GGUF + + Returns + ------- + weight_path : pathlib.Path + The path that points to the weights. + + weight_format : str + The valid weight format. + """ + if weight_path is None: + assert ( + config_json_path is not None and config_json_path.exists() + ), "Please provide config.json path." + + # 1. Find the weight_path in config.json + with open(config_json_path, encoding="utf-8") as i_f: + config = json.load(i_f) + if "weight_path" in config: + weight_path = Path(config["weight_path"]) + logger.info('Found "weight_path" in config.json: %s', weight_path) + if not weight_path.exists(): + raise ValueError(f"weight_path doesn't exist: {weight_path}") + else: + # 2. Find the weights file in the same directory as config.json + weight_path = config_json_path.parent + else: + if not weight_path.exists(): + raise ValueError(f"weight_path doesn't exist: {weight_path}") + + logger.info("Loading weights from directory: %s", weight_path) + + # check weight format + # weight_format = "auto", guess the weight format. + # otherwise, check the weight format is valid. + if weight_format == "auto": + weight_format = _guess_weight_format(weight_path) + + if weight_format not in AVAILABLE_WEIGHT_FORMAT: + raise ValueError( + f"Available weight format list: {AVAILABLE_WEIGHT_FORMAT}, but got {weight_format}" + ) + if weight_format in CHECK_FORMAT_METHODS: + check_func = CHECK_FORMAT_METHODS[weight_format] + if not check_func(weight_path): + raise ValueError(f"The weight is not in {weight_format} format.") + return weight_path, weight_format + + +def _guess_weight_format(weight_path: Path): + possible_formats = [] + for weight_format, check_func in CHECK_FORMAT_METHODS.items(): + if check_func(weight_path): + possible_formats.append(weight_format) + + if len(possible_formats) == 0: + raise ValueError( + "Fail to detect weight format. Use `--weight-format` to manually specify the format." + ) + + selected_format = possible_formats[0] + logging.info( + "Using %s format now. 
Use `--weight-format` to manually specify the format.", + selected_format, + ) + return selected_format + + +def _check_pytorch(weight_path: Path): + pytorch_json_path = weight_path / "pytorch_model.bin.index.json" + result = pytorch_json_path.exists() + if result: + logger.info("[Y] Found Huggingface PyTorch: %s", pytorch_json_path) + else: + logger.info("[X] Not found: Huggingface PyTorch") + return result + + +def _check_safetensor(weight_path: Path): + safetensor_json_path = weight_path / "model.safetensors.index.json" + result = safetensor_json_path.exists() + if result: + logger.info("[Y] Found SafeTensor: %s", safetensor_json_path) + else: + logger.info("[X] Not found: SafeTensor") + return result + + +CHECK_FORMAT_METHODS = { + "PyTorch": _check_pytorch, + "SafeTensor": _check_safetensor, +} + +AVAILABLE_WEIGHT_FORMAT = ["PyTorch", "SafeTensor", "GGML", "GGUF", "AWQ"] diff --git a/tests/python/test_auto_config.py b/tests/python/test_auto_config.py new file mode 100644 index 0000000000..6209186c32 --- /dev/null +++ b/tests/python/test_auto_config.py @@ -0,0 +1,44 @@ +# pylint: disable=missing-docstring +import json +import logging +import tempfile +from pathlib import Path + +import pytest +from mlc_chat.support.auto_config import detect_config + +logging.basicConfig( + level=logging.INFO, + style="{", + datefmt="%Y-%m-%d %H:%M:%S", + format="{asctime} {levelname} {filename}:{lineno}: {message}", +) + + +def _create_json_file(json_path, data): + with open(json_path, "w", encoding="utf-8") as i_f: + json.dump(data, i_f) + + +def test_detect_config(): + with tempfile.TemporaryDirectory() as tmpdir: + base_path = Path(tmpdir) + config_json_path = base_path / "config.json" + _create_json_file(config_json_path, {}) + + assert detect_config(base_path) == config_json_path + assert detect_config(config_json_path) == config_json_path + + +def test_detect_config_fail(): + with pytest.raises(ValueError): + detect_config(Path("do/not/exist")) + + with tempfile.TemporaryDirectory() as tmpdir: + base_path = Path(tmpdir) + with pytest.raises(ValueError): + assert detect_config(base_path) + + +if __name__ == "__main__": + pass diff --git a/tests/python/test_auto_weight.py b/tests/python/test_auto_weight.py new file mode 100644 index 0000000000..a0363ed1c4 --- /dev/null +++ b/tests/python/test_auto_weight.py @@ -0,0 +1,104 @@ +# pylint: disable=missing-docstring +import json +import logging +import os +import tempfile +from pathlib import Path + +import pytest +from mlc_chat.support.auto_weight import detect_weight + +logging.basicConfig( + level=logging.INFO, + style="{", + datefmt="%Y-%m-%d %H:%M:%S", + format="{asctime} {levelname} {filename}:{lineno}: {message}", +) + + +def _create_json_file(json_path, data): + with open(json_path, "w", encoding="utf-8") as i_f: + json.dump(data, i_f) + + +@pytest.mark.parametrize( + "weight_format, index_filename, result", + [ + ("PyTorch", "pytorch_model.bin.index.json", "PyTorch"), + ("SafeTensor", "model.safetensors.index.json", "SafeTensor"), + ("GGML", None, "GGML"), + ("GGUF", None, "GGUF"), + ("AWQ", None, "AWQ"), + ("auto", "pytorch_model.bin.index.json", "PyTorch"), + ("auto", "model.safetensors.index.json", "SafeTensor"), + ], +) +def test_detect_weight(weight_format, index_filename, result): + with tempfile.TemporaryDirectory() as tmpdir: + base_path = Path(tmpdir) + if index_filename is not None: + weight_index_file = base_path / index_filename + _create_json_file(weight_index_file, {}) + assert detect_weight(base_path, None, weight_format) == 
(base_path, result) + + +@pytest.mark.parametrize( + "weight_format, index_filename, result", + [ + ("PyTorch", "pytorch_model.bin.index.json", "PyTorch"), + ("SafeTensor", "model.safetensors.index.json", "SafeTensor"), + ("GGML", None, "GGML"), + ("GGUF", None, "GGUF"), + ("AWQ", None, "AWQ"), + ("auto", "pytorch_model.bin.index.json", "PyTorch"), + ("auto", "model.safetensors.index.json", "SafeTensor"), + ], +) +def test_detect_weight_in_config_json(weight_format, index_filename, result): + with tempfile.TemporaryDirectory() as config_dir, tempfile.TemporaryDirectory() as weight_dir: + config_path = Path(config_dir) + weight_path = Path(weight_dir) + config_json_path = config_path / "config.json" + _create_json_file(config_json_path, {"weight_path": weight_dir}) + if index_filename is not None: + weight_index_file = weight_path / index_filename + _create_json_file(weight_index_file, {}) + + assert detect_weight(None, config_json_path, weight_format) == (weight_path, result) + + +@pytest.mark.parametrize( + "weight_format, index_filename, result", + [ + ("PyTorch", "pytorch_model.bin.index.json", "PyTorch"), + ("SafeTensor", "model.safetensors.index.json", "SafeTensor"), + ("GGML", None, "GGML"), + ("GGUF", None, "GGUF"), + ("AWQ", None, "AWQ"), + ("auto", "pytorch_model.bin.index.json", "PyTorch"), + ("auto", "model.safetensors.index.json", "SafeTensor"), + ], +) +def test_detect_weight_same_dir_config_json(weight_format, index_filename, result): + with tempfile.TemporaryDirectory() as tmpdir: + base_path = Path(tmpdir) + config_json_path = base_path / "config.json" + _create_json_file(config_json_path, {}) + if index_filename is not None: + weight_index_file = os.path.join(tmpdir, index_filename) + _create_json_file(weight_index_file, {}) + assert detect_weight(None, config_json_path, weight_format) == (base_path, result) + + +def test_find_weight_fail(): + with tempfile.TemporaryDirectory() as tmpdir: + base_path = Path(tmpdir) + with pytest.raises(ValueError): + detect_weight(Path("do/not/exist"), base_path, "AWQ") + + with pytest.raises(AssertionError): + detect_weight(None, Path("do/not/exist"), "AWQ") + + +if __name__ == "__main__": + pass From f57c9c9ccd80d9afb21928a561cf0feb4565e2f0 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Mon, 23 Oct 2023 13:31:24 -0500 Subject: [PATCH 042/116] [Transform] Provide IRModule transform for rewrite_attention (#1052) Prior to this commit, `mlc_llm.transform.rewrite_attention` updated a single function. This commit modifies it to instead be a transform operating on any pattern matches within an `IRModule`. 
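
As an illustration of the new interface, here is a minimal sketch that applies the pass to a toy Relax IRModule (which contains no attention pattern, so it is returned unchanged). It assumes `rewrite_attention` is re-exported from `mlc_llm.transform`, matching how the updated `mlc_llm/core.py` below invokes it:

    from tvm.script import ir as I
    from tvm.script import relax as R

    from mlc_llm.transform import rewrite_attention  # assumed import path

    @I.ir_module
    class ToyModule:
        @R.function
        def main(x: R.Tensor((2, 4), "float32")) -> R.Tensor((2, 4), "float32"):
            y = R.add(x, x)
            return y

    # rewrite_attention(...) now returns an IRModule-to-IRModule pass: every Relax
    # function in the module is scanned for the attention pattern, instead of the
    # caller rewriting "prefill"/"decode" one function at a time.
    mod = rewrite_attention(use_flash_mqa=False)(ToyModule)
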
--- mlc_llm/core.py | 8 +--- mlc_llm/transform/rewrite_attention.py | 59 +++++++++++++++----------- 2 files changed, 36 insertions(+), 31 deletions(-) diff --git a/mlc_llm/core.py b/mlc_llm/core.py index d187ebe453..a2ffef49a5 100644 --- a/mlc_llm/core.py +++ b/mlc_llm/core.py @@ -429,13 +429,7 @@ def mod_transform_before_build( has_cutlass = tvm.get_global_func("relax.ext.cutlass", True) if has_cutlass and not args.no_cutlass_attn: - if args.use_flash_attn_mqa: - mod["prefill"] = rewrite_attention(mod["prefill"], use_flash_mqa=True) - mod["decode"] = rewrite_attention(mod["decode"], use_flash_mqa=True) - - mod["prefill"] = rewrite_attention(mod["prefill"], use_flash_mqa=False) - mod["decode"] = rewrite_attention(mod["decode"], use_flash_mqa=False) - + mod = rewrite_attention(use_flash_mqa=args.use_flash_attn_mqa)(mod) patterns += get_patterns_with_prefix("cutlass.attention") if has_cutlass and not args.no_cutlass_norm: diff --git a/mlc_llm/transform/rewrite_attention.py b/mlc_llm/transform/rewrite_attention.py index b6d2a493ab..d6d5693762 100644 --- a/mlc_llm/transform/rewrite_attention.py +++ b/mlc_llm/transform/rewrite_attention.py @@ -1,35 +1,46 @@ +import tvm from tvm.relax.dpl import PatternContext, is_const, is_op, rewrite_call, wildcard from tvm.script import relax as R -def rewrite_attention(f, use_flash_mqa=False): - Q = wildcard() - K = wildcard() - V = wildcard() +def rewrite_attention(use_flash_mqa=False): + @tvm.ir.transform.module_pass(opt_level=0, name="mlc_llm.transform.rewrite_attention") + def ir_module_transform(mod: tvm.IRModule, context) -> tvm.IRModule: + Q = wildcard() + K = wildcard() + V = wildcard() - Q_BNSH = is_op("relax.permute_dims")(Q) + Q_BNSH = is_op("relax.permute_dims")(Q) - if use_flash_mqa: - K_BNSH = is_op("relax.permute_dims")(is_op("relax.repeat")(K)) - V_BNSH = is_op("relax.permute_dims")(is_op("relax.repeat")(V)) - else: - K_BNSH = is_op("relax.permute_dims")(K) - V_BNSH = is_op("relax.permute_dims")(V) + if use_flash_mqa: + K_BNSH = is_op("relax.permute_dims")(is_op("relax.repeat")(K)) + V_BNSH = is_op("relax.permute_dims")(is_op("relax.repeat")(V)) + else: + K_BNSH = is_op("relax.permute_dims")(K) + V_BNSH = is_op("relax.permute_dims")(V) - K_BNSH_T = is_op("relax.permute_dims")(K_BNSH) + K_BNSH_T = is_op("relax.permute_dims")(K_BNSH) - matmul1 = is_op("relax.matmul")(Q_BNSH, K_BNSH_T) - divide = is_op("relax.divide")(matmul1, is_const()) - max = is_op("relax.maximum")(divide, is_const()) - min = is_op("relax.minimum")(max, wildcard()) - softmax = is_op("relax.nn.softmax")(is_op("relax.astype")(min)) - matmul2 = is_op("relax.matmul")(is_op("relax.astype")(softmax), V_BNSH) + matmul1 = is_op("relax.matmul")(Q_BNSH, K_BNSH_T) + divide = is_op("relax.divide")(matmul1, is_const()) + max = is_op("relax.maximum")(divide, is_const()) + min = is_op("relax.minimum")(max, wildcard()) + softmax = is_op("relax.nn.softmax")(is_op("relax.astype")(min)) + matmul2 = is_op("relax.matmul")(is_op("relax.astype")(softmax), V_BNSH) - pattern = is_op("relax.permute_dims")(matmul2) + pattern = is_op("relax.permute_dims")(matmul2) - def callback(_, matchings): - return R.nn.attention( - matchings[Q], matchings[K], matchings[V], causal_mask="BottomRight" - ) + def callback(_, matchings): + return R.nn.attention( + matchings[Q], matchings[K], matchings[V], causal_mask="BottomRight" + ) - return rewrite_call(pattern, callback, f) + new_module = {} + for gvar, func in mod.functions.items(): + if isinstance(func, tvm.relax.Function): + func = rewrite_call(pattern, callback, 
func) + new_module[gvar] = func + + return tvm.IRModule(new_module, mod.type_definitions, mod.attrs, mod.global_infos) + + return ir_module_transform From e5927cee3b932b6e3116b43778008a3aa11ef0a3 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Mon, 23 Oct 2023 13:31:44 -0500 Subject: [PATCH 043/116] [ParamManager] Use BundleModelParams for transform_dequantize (#1056) * [ParamManager] Use BundleModelParams for transform_quantize Prior to this commit, `ParamManager.transform_quantize` function took as input functions with separate parameters for each weight tensor, and produced output functions with a tuple parameter for all weights. Because `LiftTransformParams` had the same convention, neither could be applied as part of the same build flow. This commit updates `ParamManager.transform_quantize` pass to produce outputs with separate tensor parameters, using the `BundleModelParams` transform to later combine them into a single tuple parameter. The analogous change was also performed for `LiftTransformParams` as part of https://github.com/apache/tvm/pull/15657. In addition, prior to this commit, the `ParamManager.transform_dequantize` function operated directly on a `IRModule` object. As a result, any debug instrumentation (e.g. before/after printouts for each pass, before/after verification with `relax.analysis.well_formed`, etc.) did not apply to this `transform_dequantize`. This commit updates `ParamManager.transform_dequantize` to return a `ir.transform.Pass`. * Correct type annotation --- mlc_llm/core.py | 3 +- mlc_llm/relax_model/param_manager.py | 114 +++++++++++++-------------- 2 files changed, 55 insertions(+), 62 deletions(-) diff --git a/mlc_llm/core.py b/mlc_llm/core.py index a2ffef49a5..81de89b7cb 100644 --- a/mlc_llm/core.py +++ b/mlc_llm/core.py @@ -396,7 +396,8 @@ def mod_transform_before_build( if args.model.lower().startswith("rwkv-"): model_names += ["reset_kv_cache"] - mod = param_manager.transform_dequantize(mod) + mod = param_manager.transform_dequantize()(mod) + mod = relax.transform.BundleModelParams()(mod) use_ft_quant = args.quantization.name in ["q4f16_ft", "q8f16_ft"] mod = mlc_llm.transform.FuseDecodeTranspose(skip_gemm=not use_ft_quant)(mod) diff --git a/mlc_llm/relax_model/param_manager.py b/mlc_llm/relax_model/param_manager.py index 590b60d76b..f20b526fff 100644 --- a/mlc_llm/relax_model/param_manager.py +++ b/mlc_llm/relax_model/param_manager.py @@ -369,7 +369,7 @@ def set_param_loading_func( else: self.pidx2pname = dict() - def transform_dequantize(self, mod: tvm.IRModule) -> tvm.IRModule: + def transform_dequantize(self) -> tvm.ir.transform.Pass: """Apply dequantization to the input IRModule. Parameters @@ -386,38 +386,48 @@ def transform_dequantize(self, mod: tvm.IRModule) -> tvm.IRModule: The IRModule updated with the dequantization computation. """ - # For each Relax function in the input IRModule (e.g., "prefill"), - # we create its input relax.Var of all the quantized data, and - # store the mapping from function name to the var. 
- func2param_var: Dict[str, relax.Var] = {} - for gv, func in mod.functions.items(): - if not isinstance(func, relax.Function): - continue - if func.attrs is None or not "num_input" in func.attrs: - continue - func2param_var[gv.name_hint] = relax.Var( - "params", self.get_quantized_param_info(gv.name_hint) - ) + @tvm.ir.transform.module_pass(opt_level=0, name="ParamManager.transform_dequantize") + def transform_func(mod: tvm.IRModule, _context) -> tvm.IRModule: + # For each Relax function in the input IRModule (e.g., "prefill"), + # we create its input relax.Var of all the quantized data, and + # store the mapping from function name to the var. + func_name_to_quantized_params: Dict[str, List[relax.Var]] = {} - # Cache mapping to avoid duplicate dequantization. - dequantized_cache: Dict[relax.Var, relax.Var] = {} + for gv, func in mod.functions.items(): + if isinstance(func, relax.Function) and func.attrs and "num_input" in func.attrs: + quantized_param_info = self.get_quantized_param_info(gv.name_hint) + param_vars = [ + relax.Var(f"param_{i}", info) + for i, info in enumerate(quantized_param_info.fields) + ] + func_name_to_quantized_params[gv.name_hint] = param_vars - # Define a var replacement function for applying dequantization. - def f_replace(var: relax.Var, bb: relax.BlockBuilder, func_name: str) -> relax.Var: - if var in dequantized_cache: - return dequantized_cache[var] - assert var in self.func_raw_param_map - func_name, param = self.func_raw_param_map[var] - dequantized = self._dequantize(param, func2param_var[func_name], bb, func_name) - dequantized_cache[var] = dequantized - return dequantized + # Cache mapping to avoid duplicate dequantization. + dequantized_cache: Dict[relax.Var, relax.Var] = {} - # Create the function mutator for applying dequantization. - replacer = ParamReplacer(mod, func2param_var, f_replace) - # Update the input IRModule with dequantization. - mod = replacer.transform() + # Define a var replacement function for applying dequantization. + def f_replace(var: relax.Var, bb: relax.BlockBuilder) -> relax.Var: + if var in dequantized_cache: + return dequantized_cache[var] + assert var in self.func_raw_param_map - return mod + func_name, param = self.func_raw_param_map[var] + quantized_params = func_name_to_quantized_params[func_name] + relevant_quantized_params = [quantized_params[i] for i in self.param2qrange[param]] + + dequantized = self._dequantize(param, relevant_quantized_params, bb, func_name) + + dequantized_cache[var] = dequantized + return dequantized + + # Create the function mutator for applying dequantization. + replacer = ParamReplacer(mod, func_name_to_quantized_params, f_replace) + # Update the input IRModule with dequantization. + mod = replacer.transform() + + return mod + + return transform_func def get_quantized_param_info(self, func_name: str) -> List[relax.TensorStructInfo]: bb = relax.BlockBuilder() @@ -697,10 +707,9 @@ def _register_param( def _dequantize( self, param: Parameter, - quantized_tuple: relax.Var, + qparams: List[relax.Var], bb: relax.BlockBuilder, func_name: str, - qparams: List[relax.Var] = None, ) -> relax.Var: """Applying dequantization to the input parameter. This method is called by `transform_module` below, and is not @@ -711,30 +720,13 @@ def _dequantize( param : Parameter The parameter whose quantized tensors are to be dequantized. - quantized_tuple : relax.Var - The relax.Var of the quantized tensors of all parameters in the model. 
- - bb : relax.BlockBuilder - The Relax BlockBuilder used for inserting the dequantization computations. - - func_name : str - The name of the function which dequantization is applied to. - qparams : List[relax.Var] - The quantized parts of the parameter. - By default it is `None`, in which case we will get the quantized parts - from `quantized_tuple`. + The relax.Var of the quantized tensors of all parameters in the model. Returns ------- The dequantized parameter, in the form of a relax.Var. """ - if not qparams: - # Get the corresponding Relax vars of the quantized tensors of this parameter. - qparams: List[relax.Var] = [] - for qparam_idx in self.param2qrange[param]: - qparams.append(bb.emit(relax.TupleGetItem(quantized_tuple, qparam_idx))) - # Get the dequantization function of this parameter. f_dequantize = param.quant_spec.get_dequantize_func( param_info=param.param_info_dict[func_name], @@ -789,7 +781,7 @@ class ParamReplacer(PyExprMutator): mod : tvm.IRModule The IRModule of the model to be updated. - func2param_var : Dict[str, relax.Var] + func_name_to_quantized_params : Dict[str, List[relax.Var]] The mapping from each function name to its input var of quantized data tuple. f_replace : Callable[[relax.Var, relax.BlockBuilder], relax.Var] @@ -801,7 +793,7 @@ class ParamReplacer(PyExprMutator): """ mod: tvm.IRModule - func2param_var: Dict[str, relax.Var] + func_name_to_quantized_params: Dict[str, List[relax.Var]] f_replace: Callable[[relax.Var, relax.BlockBuilder], relax.Var] param_set: Set[relax.Var] @@ -810,12 +802,12 @@ class ParamReplacer(PyExprMutator): def __init__( self, mod: tvm.IRModule, - func2param_var: Dict[str, relax.Var], + func_name_to_quantized_params: Dict[str, relax.Var], f_replace: Callable[[relax.Var, relax.BlockBuilder], relax.Var], ): super().__init__(mod) self.mod = mod - self.func2param_var = func2param_var + self.func_name_to_quantized_params = func_name_to_quantized_params self.f_replace = f_replace self.cur_func_name = "" @@ -827,21 +819,20 @@ def transform(self) -> tvm.IRModule: continue assert ( - gv.name_hint in self.func2param_var - ), f"{gv.name_hint} not in {self.func2param_var}" - self.cur_func_name = gv.name_hint - updated_func = self.rewrite_func(func, self.func2param_var[gv.name_hint]) + gv.name_hint in self.func_name_to_quantized_params + ), f"{gv.name_hint} not in {self.func_name_to_quantized_params}" + updated_func = self.rewrite_func(func, self.func_name_to_quantized_params[gv.name_hint]) updated_func = remove_all_unused(updated_func) self.builder_.update_func(gv, updated_func) return self.builder_.get() - def rewrite_func(self, func: Function, param_var: relax.Var) -> relax.Function: + def rewrite_func(self, func: Function, quantized_params: List[relax.Var]) -> relax.Function: num_input = int(func.attrs["num_input"]) self.param_set = set(func.params[num_input:]) body = self.visit_expr(func.body) return relax.Function( - params=func.params[:num_input] + [param_var], + params=func.params[:num_input] + quantized_params, body=body, ret_struct_info=func.ret_struct_info, is_pure=func.is_pure, @@ -849,9 +840,10 @@ def rewrite_func(self, func: Function, param_var: relax.Var) -> relax.Function: ).without_attr("num_input") def visit_var_(self, var: Var) -> Expr: - if var not in self.param_set: + if var in self.param_set: + return self.f_replace(var, self.builder_) + else: return super().visit_var_(var) - return self.f_replace(var, self.builder_, self.cur_func_name) ################################################################## From 
7ae8c6dccf643f8ccb7ff4416f61c36afb58cc5c Mon Sep 17 00:00:00 2001 From: Lesheng Jin <34279105+LeshengJin@users.noreply.github.com> Date: Mon, 23 Oct 2023 15:33:00 -0700 Subject: [PATCH 044/116] [Slim-LM] Introduce HFLoad for loading Pytorch and SafeTensor weights (#1113) --- .../mlc_chat/compiler/parameter/__init__.py | 2 +- .../{hf_torch_loader.py => hf_loader.py} | 148 +++--------------- python/mlc_chat/compiler/parameter/stats.py | 86 ++++++++++ python/mlc_chat/compiler/parameter/utils.py | 52 ++++++ tests/python/parameter/test_hf_loader.py | 66 ++++++++ .../python/parameter/test_hf_torch_loader.py | 43 ----- 6 files changed, 231 insertions(+), 166 deletions(-) rename python/mlc_chat/compiler/parameter/{hf_torch_loader.py => hf_loader.py} (53%) create mode 100644 python/mlc_chat/compiler/parameter/stats.py create mode 100644 python/mlc_chat/compiler/parameter/utils.py create mode 100644 tests/python/parameter/test_hf_loader.py delete mode 100644 tests/python/parameter/test_hf_torch_loader.py diff --git a/python/mlc_chat/compiler/parameter/__init__.py b/python/mlc_chat/compiler/parameter/__init__.py index 3ea9a2b46e..9976b8e336 100644 --- a/python/mlc_chat/compiler/parameter/__init__.py +++ b/python/mlc_chat/compiler/parameter/__init__.py @@ -2,5 +2,5 @@ A subpackage of the compiler that represents mapping between external parameters, quantized parameters and parameters in MLC-defined models. """ -from .hf_torch_loader import HFTorchLoader +from .hf_loader import HFLoader from .mapping import ExternMapping, QuantizeMapping diff --git a/python/mlc_chat/compiler/parameter/hf_torch_loader.py b/python/mlc_chat/compiler/parameter/hf_loader.py similarity index 53% rename from python/mlc_chat/compiler/parameter/hf_torch_loader.py rename to python/mlc_chat/compiler/parameter/hf_loader.py index 000642800e..29c4f2dc1f 100644 --- a/python/mlc_chat/compiler/parameter/hf_torch_loader.py +++ b/python/mlc_chat/compiler/parameter/hf_loader.py @@ -1,13 +1,11 @@ """A weight loader for HuggingFace's PyTorch format""" -import dataclasses + import gc import json import logging -import time from collections import OrderedDict, defaultdict -from contextlib import contextmanager from pathlib import Path -from typing import Dict, Iterator, List, Set, Tuple +from typing import Dict, Iterator, List, Tuple import numpy as np from tqdm import tqdm @@ -15,70 +13,15 @@ from tvm.runtime.ndarray import array as as_ndarray from .mapping import ExternMapping +from .stats import Stats +from .utils import check_parameter_usage, load_safetensor_shard, load_torch_shard logger = logging.getLogger(__name__) -@dataclasses.dataclass -class Stats: - """Statistics of the loading process of HuggingFace PyTorch loader. - - Attributes - ---------- - load_time_sec : float - Time used in loading the parameters. - - map_time_sec : float - Time used in applying the mapping function, i.e. `ExternMapping.map_func`. - - quant_time_sec : float - Time used in quantizing the parameters, i.e. `QuantizeMapping.quant_func`. - - current_memory_gb : float - The current RAM usage in GB. - - total_memory_gb : float - The total size data loaded from disk in GB. - - max_memory_gb : float - The maximum RAM usage in GB. 
- """ - - load_time_sec: float = 0.0 - map_time_sec: float = 0.0 - quant_time_sec: float = 0.0 - - current_memory_gb: float = 0.0 - total_memory_gb: float = 0.0 - max_memory_gb: float = 0.0 - - def timer(self, attr): - """A context manager to time the scope and add the time to the attribute.""" - - @contextmanager - def timed_scope(): - start_time = time.time() - yield - elapsed_time = time.time() - start_time - setattr(self, attr, getattr(self, attr) + elapsed_time) - - return timed_scope() - - def mem_add(self, nbytes: int): - """Add the memory usage by the given number of bytes.""" - mem_gb = float(nbytes) / float(1024**3) - self.current_memory_gb += mem_gb - self.total_memory_gb += mem_gb - self.max_memory_gb = max(self.max_memory_gb, self.current_memory_gb) - - def mem_rm(self, nbytes: int): - """Remove the memory usage by the given number of bytes.""" - mem_gb = float(nbytes) / float(1024**3) - self.current_memory_gb -= mem_gb - - -class HFTorchLoader: # pylint: disable=too-few-public-methods - """A loader loading HuggingFace's PyTorch format and converts them to MLC's parameters. +class HFLoader: # pylint: disable=too-few-public-methods + """A loader loading HuggingFace's PyTorch/SafeTensor format and converts them + to MLC's parameters. Attributes ---------- @@ -86,11 +29,11 @@ class HFTorchLoader: # pylint: disable=too-few-public-methods Statistics of the loading process. extern_param_map : ExternMapping - The parameter mapping from MLC to HuggingFace PyTorch. + The parameter mapping from MLC to HuggingFace PyTorch/SafeTensor. torch_to_path : Dict[str, Path] - A mapping from PyTorch parameter name to the path of the file containing it, or the path - meaning all parameters are stored in a single file. + A mapping from PyTorch/SafeTensor parameter name to the path of the file containing it, + or the path meaning all parameters are stored in a single file. cached_files : Dict[Path, Dict[str, np.ndarray]] A cache of the loaded files. The key is the path of the file, and the value is a mapping @@ -113,20 +56,23 @@ def __init__( ---------- path : pathlib.Path Path to either a JSON indexing file, or a PyTorch bin file. - 1) For JSON indexing file, it is usually `pytorch_model.bin.index.json` in the repo, - which contains a `weight_map` that maps each PyTorch parameter to the file containing - the weight. 2) For PyTorch bin file, it is usually `pytorch_model.bin` in the repo, + 1) For JSON indexing file, it is usually `pytorch_model.bin.index.json` + or `model.safetensors.index.json` in the repo, which contains a `weight_map` that + maps each PyTorch parameter to the file containing the weight. + 2) For PyTorch bin file, it is usually `pytorch_model.bin` in the repo, + which contains all the parameters. + 3) For safetensor file, it is usually `model.safetensors` in the repo, which contains all the parameters. extern_param_map : ExternMapping - Maps an MLC parameter to a list of PyTorch parameters. + Maps an MLC parameter to a list of PyTorch/SafeTensor parameters. 
""" assert path.is_file() self.stats = Stats() self.extern_param_map = extern_param_map self.cached_files = {} self.torch_to_path = {} - if path.suffix == ".bin": + if path.suffix in (".bin", ".safetensors"): self._load_file(path) for name in self.cached_files[path].keys(): self.torch_to_path[name] = path @@ -137,7 +83,7 @@ def __init__( self.torch_to_path[torch_name] = path.parent / path_str else: raise FileNotFoundError(f"Unknown file suffix: {path}") - _check_parameter_usage(extern_param_map, set(self.torch_to_path.keys())) + check_parameter_usage(extern_param_map, set(self.torch_to_path.keys())) def load(self) -> Iterator[Tuple[str, NDArray]]: """Load the parameters and yield the MLC parameter and its value.""" @@ -148,21 +94,8 @@ def load(self) -> Iterator[Tuple[str, NDArray]]: cached_files = list(self.cached_files.keys()) for path in cached_files: self._unload_file(path) - - logger.info( - "Time used: " - "PyTorch loading: %.3f sec; " - "Pre-quantization mapping: %.3f sec; " - "Quantization: %.3f sec", - self.stats.load_time_sec, - self.stats.map_time_sec, - self.stats.quant_time_sec, - ) - logger.info( - "Memory usage: Total size loaded from disk: %.3f GB; Peak memory usage: %.3f GB", - self.stats.total_memory_gb, - self.stats.max_memory_gb, - ) + self.stats.log_time_info("HF") + self.stats.log_mem_usage() def _load_mlc_param(self, mlc_name: str) -> np.ndarray: torch_names = self.extern_param_map.param_map[mlc_name] @@ -190,16 +123,17 @@ def _load_mlc_param(self, mlc_name: str) -> np.ndarray: return param def _load_file(self, path: Path) -> None: - logger.info("Loading PyTorch parameters from: %s", path) + logger.info("Loading HF parameters from: %s", path) + load_func = load_safetensor_shard if path.suffix == ".safetensors" else load_torch_shard with self.stats.timer("load_time_sec"): result = {} - for name, param in _load_torch_shard(path): + for name, param in load_func(path): result[name] = param self.stats.mem_add(param.nbytes) self.cached_files[path] = result def _unload_file(self, path: Path) -> None: - logger.info("Unloading PyTorch weight file: %s", path) + logger.info("Unloading HF weight file: %s", path) with self.stats.timer("load_time_sec"): for _, param in self.cached_files[path].items(): self.stats.mem_rm(param.nbytes) @@ -207,36 +141,6 @@ def _unload_file(self, path: Path) -> None: gc.collect() -def _check_parameter_usage(param_map: ExternMapping, torch_weights: Set[str]): - used_torch_names = set(sum(param_map.param_map.values(), ())) - # Check 1. All PyTorch parameters in the weight files are used unless explicitly specified - unused_torch_names = torch_weights - used_torch_names - param_map.unused_params - if unused_torch_names: - logger.warning( - "Unused torch parameters: %s", - ", ".join(sorted(unused_torch_names)), - ) - # Check 2. All PyTorch parameters required are stored in the weight files - nonexistent_torch_names = used_torch_names - torch_weights - if nonexistent_torch_names: - raise ValueError( - "The following torch parameters do not exist in the weight files:\n " - + "\n ".join(sorted(nonexistent_torch_names)), - ) - - -def _load_torch_shard(path: Path): - import torch # pylint: disable=import-outside-toplevel - - for name, param in torch.load(path, map_location=torch.device("cpu")).items(): - param = param.detach().cpu() - dtype = str(param.dtype) - if dtype == "torch.bfloat16": - param = param.float() - param = param.numpy() - yield name, param - - def _loading_order(param_map: ExternMapping, torch_to_path: Dict[str, Path]) -> List[str]: # Step 1. 
Build a map from path to torch parameters path_to_torch: Dict[Path, List[str]] = defaultdict(list) @@ -257,4 +161,4 @@ def _loading_order(param_map: ExternMapping, torch_to_path: Dict[str, Path]) -> return list(order.keys()) -__all__ = ["HFTorchLoader"] +__all__ = ["HFLoader"] diff --git a/python/mlc_chat/compiler/parameter/stats.py b/python/mlc_chat/compiler/parameter/stats.py new file mode 100644 index 0000000000..9f5d1e16fa --- /dev/null +++ b/python/mlc_chat/compiler/parameter/stats.py @@ -0,0 +1,86 @@ +"""Statistics of the loading process of parameter loaders""" +import dataclasses +import logging +import time +from contextlib import contextmanager + +logger = logging.getLogger(__name__) + + +@dataclasses.dataclass +class Stats: + """Statistics of the loading process of parameter loaders. + + Attributes + ---------- + load_time_sec : float + Time used in loading the parameters. + + map_time_sec : float + Time used in applying the mapping function, i.e. `ExternMapping.map_func`. + + quant_time_sec : float + Time used in quantizing the parameters, i.e. `QuantizeMapping.quant_func`. + + current_memory_gb : float + The current RAM usage in GB. + + total_memory_gb : float + The total size data loaded from disk in GB. + + max_memory_gb : float + The maximum RAM usage in GB. + """ + + load_time_sec: float = 0.0 + map_time_sec: float = 0.0 + quant_time_sec: float = 0.0 + + current_memory_gb: float = 0.0 + total_memory_gb: float = 0.0 + max_memory_gb: float = 0.0 + + def timer(self, attr): + """A context manager to time the scope and add the time to the attribute.""" + + @contextmanager + def timed_scope(): + start_time = time.time() + yield + elapsed_time = time.time() - start_time + setattr(self, attr, getattr(self, attr) + elapsed_time) + + return timed_scope() + + def mem_add(self, nbytes: int): + """Add the memory usage by the given number of bytes.""" + mem_gb = float(nbytes) / float(1024**3) + self.current_memory_gb += mem_gb + self.total_memory_gb += mem_gb + self.max_memory_gb = max(self.max_memory_gb, self.current_memory_gb) + + def mem_rm(self, nbytes: int): + """Remove the memory usage by the given number of bytes.""" + mem_gb = float(nbytes) / float(1024**3) + self.current_memory_gb -= mem_gb + + def log_time_info(self, weight_format: str): + """Log the time used in loading, pre-quantization and quantization.""" + logger.info( + "Time used: " + "%s loading: %.3f sec; " + "Pre-quantization mapping: %.3f sec; " + "Quantization: %.3f sec", + weight_format, + self.load_time_sec, + self.map_time_sec, + self.quant_time_sec, + ) + + def log_mem_usage(self): + """Log the Memory usage information.""" + logger.info( + "Memory usage: Total size loaded from disk: %.3f GB; Peak memory usage: %.3f GB", + self.total_memory_gb, + self.max_memory_gb, + ) diff --git a/python/mlc_chat/compiler/parameter/utils.py b/python/mlc_chat/compiler/parameter/utils.py new file mode 100644 index 0000000000..596941aaca --- /dev/null +++ b/python/mlc_chat/compiler/parameter/utils.py @@ -0,0 +1,52 @@ +"""Common utilities for loading parameters""" +import logging +from pathlib import Path +from typing import Iterator, Set, Tuple + +import numpy as np + +from .mapping import ExternMapping + +logger = logging.getLogger(__name__) + + +def check_parameter_usage(param_map: ExternMapping, extern_weights: Set[str]): + """Check that all external parameters have been used and are stored in the weights file.""" + used_extern_names = set(sum(param_map.param_map.values(), [])) + # Check 1. 
All extern parameters in the weight files are used unless explicitly specified + unused_extern_names = extern_weights - used_extern_names - param_map.unused_params + if unused_extern_names: + logger.warning( + "Unused extern parameters: %s", + ", ".join(sorted(unused_extern_names)), + ) + # Check 2. All extern parameters required are stored in the weight files + nonexistent_extern_names = used_extern_names - extern_weights + if nonexistent_extern_names: + raise ValueError( + "The following extern parameters do not exist in the weight files:\n " + + "\n ".join(sorted(nonexistent_extern_names)), + ) + + +def load_torch_shard(path: Path) -> Iterator[Tuple[str, np.ndarray]]: + """Load and yield PyTorch format parameters.""" + import torch # pylint: disable=import-outside-toplevel + + for name, param in torch.load(path, map_location=torch.device("cpu")).items(): + param = param.detach().cpu() + dtype = str(param.dtype) + if dtype == "torch.bfloat16": + param = param.float() + param = param.numpy() + yield name, param + + +def load_safetensor_shard(path: Path) -> Iterator[Tuple[str, np.ndarray]]: + """Load and yield SafeTensor format parameters.""" + import safetensors # pylint: disable=import-outside-toplevel,import-error + + with safetensors.safe_open(path, framework="numpy", device="cpu") as in_file: + for name in in_file.keys(): + param = in_file.get_tensor(name) + yield name, param diff --git a/tests/python/parameter/test_hf_loader.py b/tests/python/parameter/test_hf_loader.py new file mode 100644 index 0000000000..4e983d83dd --- /dev/null +++ b/tests/python/parameter/test_hf_loader.py @@ -0,0 +1,66 @@ +# pylint: disable=missing-docstring +import logging +from pathlib import Path +from typing import Union + +import pytest +from mlc_chat.compiler.model.llama import LlamaConfig +from mlc_chat.compiler.model.llama_parameter import hf_torch +from mlc_chat.compiler.parameter import HFLoader +from mlc_chat.support import tqdm + +logging.basicConfig( + level=logging.DEBUG, + style="{", + datefmt="%Y-%m-%d %H:%M:%S", + format="[{asctime}] {levelname} {filename}:{lineno}: {message}", +) + + +@pytest.mark.parametrize( + "base_path", + [ + "./dist/models/Llama-2-7b-hf", + "./dist/models/Llama-2-13b-hf", + "./dist/models/Llama-2-70b-hf", + ], +) +def test_load_torch_llama(base_path: Union[str, Path]): + base_path = Path(base_path) + path_config = base_path / "config.json" + path_params = base_path / "pytorch_model.bin.index.json" + + config = LlamaConfig.from_file(path_config) + loader = HFLoader(path=path_params, extern_param_map=hf_torch(config)) + with tqdm.redirect(): + for _name, _param in loader.load(): + ... + + +@pytest.mark.parametrize( + "base_path", + [ + "./dist/models/Llama-2-7b-hf", + "./dist/models/Llama-2-13b-hf", + "./dist/models/Llama-2-70b-hf", + ], +) +def test_load_safetensor_llama(base_path: Union[str, Path]): + base_path = Path(base_path) + path_config = base_path / "config.json" + path_params = base_path / "model.safetensors.index.json" + + config = LlamaConfig.from_file(path_config) + loader = HFLoader(path=path_params, extern_param_map=hf_torch(config)) + with tqdm.redirect(): + for _name, _param in loader.load(): + ... 
+ + +if __name__ == "__main__": + test_load_torch_llama(base_path="./dist/models/Llama-2-7b-hf") + test_load_torch_llama(base_path="./dist/models/Llama-2-13b-hf") + test_load_torch_llama(base_path="./dist/models/Llama-2-70b-hf") + test_load_safetensor_llama(base_path="./dist/models/Llama-2-7b-hf") + test_load_safetensor_llama(base_path="./dist/models/Llama-2-13b-hf") + test_load_safetensor_llama(base_path="./dist/models/Llama-2-70b-hf") diff --git a/tests/python/parameter/test_hf_torch_loader.py b/tests/python/parameter/test_hf_torch_loader.py deleted file mode 100644 index 9cc8d0ea6c..0000000000 --- a/tests/python/parameter/test_hf_torch_loader.py +++ /dev/null @@ -1,43 +0,0 @@ -# pylint: disable=missing-docstring -import logging -from pathlib import Path -from typing import Union - -import pytest -from mlc_chat.compiler.model.llama import LlamaConfig -from mlc_chat.compiler.model.llama_parameter import hf_torch -from mlc_chat.compiler.parameter import HFTorchLoader -from mlc_chat.support import tqdm - -logging.basicConfig( - level=logging.DEBUG, - style="{", - datefmt="%Y-%m-%d %H:%M:%S", - format="[{asctime}] {levelname} {filename}:{lineno}: {message}", -) - - -@pytest.mark.parametrize( - "base_path", - [ - "./dist/models/Llama-2-7b-hf", - "./dist/models/Llama-2-13b-hf", - "./dist/models/Llama-2-70b-hf", - ], -) -def test_load_llama(base_path: Union[str, Path]): - base_path = Path(base_path) - path_config = base_path / "config.json" - path_params = base_path / "pytorch_model.bin.index.json" - - config = LlamaConfig.from_file(path_config) - loader = HFTorchLoader(path=path_params, extern_param_map=hf_torch(config)) - with tqdm.redirect(): - for _name, _param in loader.load(): - ... - - -if __name__ == "__main__": - test_load_llama(base_path="./dist/models/Llama-2-7b-hf") - test_load_llama(base_path="./dist/models/Llama-2-13b-hf") - test_load_llama(base_path="./dist/models/Llama-2-70b-hf") From 5a7dcd823773ca9420a988eba79552ec1317292e Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Tue, 24 Oct 2023 00:00:41 -0400 Subject: [PATCH 045/116] [WINDOWS] reduce noise in windows build (#1115) --- 3rdparty/tvm | 2 +- mlc_llm/relax_model/param_manager.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/3rdparty/tvm b/3rdparty/tvm index e5ca38dd73..631f37b6bf 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit e5ca38dd735ba4d30782a4a58bf6195861642eb0 +Subproject commit 631f37b6bf8b101d16ecc55de7e6a749a3588570 diff --git a/mlc_llm/relax_model/param_manager.py b/mlc_llm/relax_model/param_manager.py index f20b526fff..c6729b41e0 100644 --- a/mlc_llm/relax_model/param_manager.py +++ b/mlc_llm/relax_model/param_manager.py @@ -269,8 +269,8 @@ def register_params( relax_param, getattr(quantization_scheme, quant_kind.name), func_name, - getattr(relax_param, "shard_dim", None), - getattr(relax_param, "shard_strategy", None), + relax_param.__dict__.get("shard_dim", None), + relax_param.__dict__.get("shard_strategy", None), ) self.params_in_func[func_name].append(param) From 61179a0ec924c4714ef4f4bce08a9130f11acc53 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Mon, 23 Oct 2023 23:58:01 -0700 Subject: [PATCH 046/116] Add CLI commands for compilation (#1109) --- python/mlc_chat/cli/compile.py | 141 +++++++++ python/mlc_chat/compiler/compile.py | 53 ++++ python/mlc_chat/compiler/model/__init__.py | 2 +- .../model/{llama.py => llama_model.py} | 4 +- .../compiler/model/llama_parameter.py | 5 +- python/mlc_chat/compiler/model/model.py | 39 +++ 
.../mlc_chat/compiler/parameter/__init__.py | 2 +- .../{hf_loader.py => huggingface_loader.py} | 4 +- python/mlc_chat/support/auto_config.py | 6 +- python/mlc_chat/support/auto_target.py | 293 ++++++++++++++++++ python/mlc_chat/support/auto_weight.py | 4 +- python/mlc_chat/support/style.py | 62 ++++ tests/python/model/test_llama.py | 3 +- ...{test_hf_loader.py => test_huggingface.py} | 10 +- 14 files changed, 612 insertions(+), 16 deletions(-) create mode 100644 python/mlc_chat/cli/compile.py create mode 100644 python/mlc_chat/compiler/compile.py rename python/mlc_chat/compiler/model/{llama.py => llama_model.py} (98%) create mode 100644 python/mlc_chat/compiler/model/model.py rename python/mlc_chat/compiler/parameter/{hf_loader.py => huggingface_loader.py} (98%) create mode 100644 python/mlc_chat/support/auto_target.py create mode 100644 python/mlc_chat/support/style.py rename tests/python/parameter/{test_hf_loader.py => test_huggingface.py} (82%) diff --git a/python/mlc_chat/cli/compile.py b/python/mlc_chat/cli/compile.py new file mode 100644 index 0000000000..bf976ef7f2 --- /dev/null +++ b/python/mlc_chat/cli/compile.py @@ -0,0 +1,141 @@ +"""Command line entrypoint of compilation.""" +import argparse +import json +import logging +from pathlib import Path +from typing import Union + +from mlc_chat.compiler.compile import compile # pylint: disable=redefined-builtin +from mlc_chat.compiler.model import MODELS, Model + +from ..support.auto_config import detect_config +from ..support.auto_target import detect_target_and_host + +logging.basicConfig( + level=logging.DEBUG, + style="{", + datefmt="%Y-%m-%d %H:%M:%S", + format="[{asctime}] {levelname} {filename}:{lineno}: {message}", +) + + +def _parse_config(path: Union[str, Path]) -> Path: + try: + return detect_config(Path(path)) + except ValueError as err: + raise argparse.ArgumentTypeError(f"No valid config.json in: {path}. Error: {err}") + + +def _parse_output(path: Union[str, Path]) -> Path: + path = Path(path) + parent = path.parent + if not parent.is_dir(): + raise argparse.ArgumentTypeError(f"Directory does not exist: {parent}") + return path + + +def _parse_model_type(model_type: str, config: Path) -> Model: + if model_type == "auto": + with open(config, "r", encoding="utf-8") as config_file: + cfg = json.load(config_file) + if "model_type" not in cfg: + raise ValueError( + f"'model_type' not found in: {config}. " + f"Please explicitly specify `--model-type` instead" + ) + model_type = cfg["model_type"] + if model_type not in MODELS: + raise ValueError(f"Unknown model type: {model_type}. Available ones: {list(MODELS.keys())}") + return MODELS[model_type] + + +def main(): + """Parse command line argumennts and call `mlc_llm.compiler.compile`.""" + parser = argparse.ArgumentParser("MLC LLM Compiler") + parser.add_argument( + "--config", + type=_parse_config, + required=True, + help="Path to config.json file or to the directory that contains config.json, which is " + "a HuggingFace standard that defines model architecture, for example, " + "https://huggingface.co/codellama/CodeLlama-7b-Instruct-hf/blob/main/config.json", + ) + parser.add_argument( + "--quantization", + type=str, + required=True, + choices=[ + "q0f16", + "q0f32", + "q3f16_1", + "q3f32_1", + "q4f16_1", + "q4f16_ft", + "q4f32_1", + ], + help="The quantization format. TBD", + ) + parser.add_argument( + "--model-type", + type=str, + default="auto", + choices=["auto"] + list(MODELS.keys()), + help="Model architecture, for example, llama. 
If not set, it is inferred " + "from the config.json file.", + ) + parser.add_argument( + "--device", + type=str, + default="auto", + help="The GPU device to compile the model to. If not set, it is inferred from locally " + "available GPUs.", + ) + parser.add_argument( + "--host", + type=str, + default="auto", + choices=[ + "auto", + "arm", + "arm64", + "aarch64", + "x86-64", + ], + help="The host CPU ISA to compile the model to. If not set, it is inferred from the " + "local CPU.", + ) + parser.add_argument( + "--opt", + type=str, + default="", + help="Optimization flags.", + ) + parser.add_argument( + "--output", + "-o", + type=_parse_output, + required=True, + help="The name of the output file. The suffix determines if the output file is a " + "shared library or a static library. Available suffixes: " + "1) Linux: .so (shared), .a (static); " + "2) macOS: .dylib (shared), .a (static); " + "3) Windows: .dll (shared), .lib (static); " + "4) Android, iOS: .tar (static); " + "5) Web: .wasm (web assembly)", + ) + parsed = parser.parse_args() + target, build_func = detect_target_and_host(parsed.device, parsed.host) + parsed.model_type = _parse_model_type(parsed.model_type, parsed.config) + compile( + config=parsed.config, + quantization=parsed.quantization, + model_type=parsed.model_type, + target=target, + opt=parsed.opt, + build_func=build_func, + output=parsed.output, + ) + + +if __name__ == "__main__": + main() diff --git a/python/mlc_chat/compiler/compile.py b/python/mlc_chat/compiler/compile.py new file mode 100644 index 0000000000..2d68154bf3 --- /dev/null +++ b/python/mlc_chat/compiler/compile.py @@ -0,0 +1,53 @@ +"""Python entrypoint of compilation.""" +import dataclasses +import logging +from io import StringIO +from pathlib import Path +from typing import Callable + +from mlc_chat.compiler.model import Model +from tvm import IRModule # pylint: disable=wrong-import-order +from tvm.target import Target # pylint: disable=wrong-import-order + +from ..support.style import bold + +logger = logging.getLogger(__name__) + + +@dataclasses.dataclass +class CompileArgs: + """Arguments to MLC LLM's compiler.""" + + config: Path + quantization: str + model_type: Model + target: Target + opt: str + build_func: Callable[[IRModule, "CompileArgs"], None] + output: Path + + +def _echo_args(args: CompileArgs) -> None: + out = StringIO() + print(f"{bold('Compiling with arguments:')}", file=out) + print(f" {bold('--config'):<25} {args.config}", file=out) + print(f" {bold('--quantization'):<25} {args.quantization}", file=out) + print(f" {bold('--model-type'):<25} {args.model_type.name}", file=out) + print(f" {bold('--target'):<25} {args.target.export()}", file=out) + print(f" {bold('--opt'):<25} {args.opt}", file=out) + print(f" {bold('--output'):<25} {args.output}", file=out) + print(out.getvalue().rstrip()) + + +def compile( # pylint: disable=too-many-arguments,redefined-builtin + config: Path, + quantization, + model_type: Model, + target: Target, + opt, + build_func: Callable[[IRModule, CompileArgs], None], + output: Path, +): + """Compile a model given its configuration and quantization format to a specific target.""" + args = CompileArgs(config, quantization, model_type, target, opt, build_func, output) + _echo_args(args) diff --git a/python/mlc_chat/compiler/model/__init__.py b/python/mlc_chat/compiler/model/__init__.py index b568bd84f7..8bb4879e7d 100644 --- a/python/mlc_chat/compiler/model/__init__.py +++ b/python/mlc_chat/compiler/model/__init__.py @@ -1,2 +1,2 @@ """Model definition for the 
compiler.""" -from . import llama, llama_config, llama_parameter +from .model import MODELS, Model diff --git a/python/mlc_chat/compiler/model/llama.py b/python/mlc_chat/compiler/model/llama_model.py similarity index 98% rename from python/mlc_chat/compiler/model/llama.py rename to python/mlc_chat/compiler/model/llama_model.py index 663e6d93c2..49e947f741 100644 --- a/python/mlc_chat/compiler/model/llama.py +++ b/python/mlc_chat/compiler/model/llama_model.py @@ -156,11 +156,11 @@ def forward(self, inputs: Tensor, total_seq_len: tir.Var, attention_mask: Tensor class LlamaForCasualLM(nn.Module): - def __init__(self, config: LlamaConfig, dtype: str = "float32"): + def __init__(self, config: LlamaConfig): self.model = LlamaModel(config) self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) self.vocab_size = config.vocab_size - self.dtype = dtype + self.dtype = "float32" def to(self, dtype: Optional[str] = None): super().to(dtype=dtype) diff --git a/python/mlc_chat/compiler/model/llama_parameter.py b/python/mlc_chat/compiler/model/llama_parameter.py index b0fa867130..4c68fdc899 100644 --- a/python/mlc_chat/compiler/model/llama_parameter.py +++ b/python/mlc_chat/compiler/model/llama_parameter.py @@ -7,10 +7,11 @@ import numpy as np from ..parameter import ExternMapping -from .llama import LlamaConfig, LlamaForCasualLM +from .llama_config import LlamaConfig +from .llama_model import LlamaForCasualLM -def hf_torch(model_config: LlamaConfig) -> ExternMapping: +def huggingface(model_config: LlamaConfig, _) -> ExternMapping: """Returns a parameter mapping that maps from the names of MLC LLM parameters to the names of HuggingFace PyTorch parameters. diff --git a/python/mlc_chat/compiler/model/model.py b/python/mlc_chat/compiler/model/model.py new file mode 100644 index 0000000000..9c9db8c1f9 --- /dev/null +++ b/python/mlc_chat/compiler/model/model.py @@ -0,0 +1,39 @@ +"""A centralized registry of all existing model architures and their configurations.""" +import dataclasses +from pathlib import Path +from typing import Any, Callable, Dict, Optional + +from tvm.relax.frontend import nn + +from ..parameter import ExternMapping, QuantizeMapping +from . import llama_config, llama_model, llama_parameter + +ModelConfig = Any +QuantizeConfig = Any + +LoaderType = Callable[[ModelConfig, QuantizeConfig], ExternMapping] +QuantizerType = Callable[[ModelConfig, QuantizeConfig], QuantizeMapping] + + +@dataclasses.dataclass +class Model: + """All about a model architecture: its configuration, its parameter loader and quantization.""" + + name: str + model: Callable[[ModelConfig], nn.Module] + config: Callable[[Path], ModelConfig] + source_loader_huggingface: Optional[LoaderType] = None + source_loader_awq: Optional[LoaderType] = None + quantizer_group_quant: Optional[QuantizerType] = None + + +MODELS: Dict[str, Model] = { + "llama": Model( + name="llama", + model=llama_model.LlamaForCasualLM, + config=llama_config.LlamaConfig.from_file, + source_loader_huggingface=llama_parameter.huggingface, + source_loader_awq=None, + quantizer_group_quant=None, + ) +} diff --git a/python/mlc_chat/compiler/parameter/__init__.py b/python/mlc_chat/compiler/parameter/__init__.py index 9976b8e336..f119b01f91 100644 --- a/python/mlc_chat/compiler/parameter/__init__.py +++ b/python/mlc_chat/compiler/parameter/__init__.py @@ -2,5 +2,5 @@ A subpackage of the compiler that represents mapping between external parameters, quantized parameters and parameters in MLC-defined models. 
""" -from .hf_loader import HFLoader +from .huggingface_loader import HuggingFaceLoader from .mapping import ExternMapping, QuantizeMapping diff --git a/python/mlc_chat/compiler/parameter/hf_loader.py b/python/mlc_chat/compiler/parameter/huggingface_loader.py similarity index 98% rename from python/mlc_chat/compiler/parameter/hf_loader.py rename to python/mlc_chat/compiler/parameter/huggingface_loader.py index 29c4f2dc1f..fa6beb40eb 100644 --- a/python/mlc_chat/compiler/parameter/hf_loader.py +++ b/python/mlc_chat/compiler/parameter/huggingface_loader.py @@ -19,7 +19,7 @@ logger = logging.getLogger(__name__) -class HFLoader: # pylint: disable=too-few-public-methods +class HuggingFaceLoader: # pylint: disable=too-few-public-methods """A loader loading HuggingFace's PyTorch/SafeTensor format and converts them to MLC's parameters. @@ -161,4 +161,4 @@ def _loading_order(param_map: ExternMapping, torch_to_path: Dict[str, Path]) -> return list(order.keys()) -__all__ = ["HFLoader"] +__all__ = ["HuggingFaceLoader"] diff --git a/python/mlc_chat/support/auto_config.py b/python/mlc_chat/support/auto_config.py index 1a4d9bf765..f5c33acb9f 100644 --- a/python/mlc_chat/support/auto_config.py +++ b/python/mlc_chat/support/auto_config.py @@ -2,8 +2,12 @@ import logging from pathlib import Path +from .style import green + logger = logging.getLogger(__name__) +FOUND = green("Found") + def detect_config(config_path: Path) -> Path: """Detect and return the path that points to config.json. If config_path is a directory, @@ -30,5 +34,5 @@ def detect_config(config_path: Path) -> Path: else: config_json_path = config_path - logger.info("Found config.json: %s", config_json_path) + logger.info("%s model configuration: %s", FOUND, config_json_path) return config_json_path diff --git a/python/mlc_chat/support/auto_target.py b/python/mlc_chat/support/auto_target.py new file mode 100644 index 0000000000..d64c8f9daa --- /dev/null +++ b/python/mlc_chat/support/auto_target.py @@ -0,0 +1,293 @@ +"""Helper functioms for target auto-detection.""" +import logging +from typing import TYPE_CHECKING, Callable, Optional, Tuple + +from tvm import IRModule, relax +from tvm._ffi import register_func +from tvm.contrib import tar, xcode +from tvm.target import Target + +from .style import green, red + +if TYPE_CHECKING: + from mlc_chat.compiler.compile import CompileArgs + + +logger = logging.getLogger(__name__) + +# TODO: add help message on how to specify the target manually # pylint: disable=fixme +# TODO: revisit system_lib_prefix handling # pylint: disable=fixme +# TODO: include host detection logic below after the new TVM build is done. # pylint: disable=fixme +HELP_MSG = """TBD""" +FOUND = green("Found") +NOT_FOUND = red("Not found") +BuildFunc = Callable[[IRModule, "CompileArgs"], None] + + +def detect_target_and_host(target_hint: str, host_hint: str) -> Tuple[Target, BuildFunc]: + """Detect the configuration for the target device and its host, for example, target GPU and + the host CPU. + + Parameters + ---------- + target_hint : str + The hint for the target device. + + host_hint : str + The hint for the host CPU. 
+ """ + target, build_func = _detect_target_gpu(target_hint) + if target.host is None: + target = Target(target, host=_detect_target_host(host_hint)) + return target, build_func + + +def _detect_target_gpu(hint: str) -> Tuple[Target, BuildFunc]: + if hint in ["iphone", "android", "webgpu", "mali", "opencl"]: + hint += ":generic" + if hint == "auto": + logger.info("Detecting potential target devices: %s", ", ".join(AUTO_DETECT_DEVICES)) + target: Optional[Target] = None + for device in AUTO_DETECT_DEVICES: + device_target = _detect_target_from_device(device + ":0") + if device_target is not None and target is None: + target = device_target + if target is None: + raise ValueError("No GPU target detected. Please specify explicitly") + return target, _build_default() + if hint in AUTO_DETECT_DEVICES: + target = _detect_target_from_device(hint + ":0") + if target is None: + raise ValueError(f"No GPU target detected from device: {hint}") + return target, _build_default() + if hint in PRESET: + preset = PRESET[hint] + target = Target(preset["target"]) # type: ignore[index] + build = preset.get("build", _build_default) # type: ignore[attr-defined] + return target, build() + if _is_device(hint): + logger.info("Detecting target device: %s", hint) + target = Target.from_device(hint) + logger.info("%s target: %s", FOUND, target.export()) + return target, _build_default() + try: + logger.info("Try creating device target from string: %s", hint) + target = Target(hint) + logger.info("%s target: %s", FOUND, target.export()) + return target, _build_default() + except Exception as err: + logger.info("%s: Failed to create target", NOT_FOUND) + raise ValueError(f"Invalid target: {hint}") from err + + +def _detect_target_host(hint: str) -> Target: + """Detect the host CPU architecture.""" + # cpu = codegen.llvm_get_system_cpu() + # triple = codegen.llvm_get_system_triple() + # vendor = codegen.llvm_get_system_x86_vendor() + if hint == "auto": + hint = "x86-64" + if hint == "x86-64": + hint = "x86_64" + return Target({"kind": "llvm", "mtriple": f"{hint}-unknown-unknown"}) + + +def _is_device(device: str): + if " " in device: + return False + if device.count(":") != 1: + return False + return True + + +def _detect_target_from_device(device: str) -> Optional[Target]: + try: + target = Target.from_device(device) + except ValueError: + logger.info("%s: target device: %s", NOT_FOUND, device) + return None + logger.info( + '%s configuration of target device "%s": %s', + FOUND, + device, + target.export(), + ) + return target + + +def _build_metal_x86_64(): + def build(mod: IRModule, args: "CompileArgs"): + output = args.output + assert output.suffix == ".dylib" + relax.build( + mod, + target=args.target, + ).export_library( + str(output), + fcompile=xcode.create_dylib, + sdk="macosx", + arch="x86_64", + ) + + return build + + +def _build_iphone(): + @register_func("tvm_callback_metal_compile", override=True) + def compile_metal(src, target): + if target.libs: + return xcode.compile_metal(src, sdk=target.libs[0]) + return xcode.compile_metal(src) + + def build(mod: IRModule, args: "CompileArgs"): + output = args.output + system_lib_prefix = f"{args.model_type}_{args.quantization}_".replace("-", "_") + assert output.suffix == ".tar" + relax.build( + mod.with_attr("system_lib_prefix", system_lib_prefix), + target=args.target, + system_lib=True, + ).export_library( + str(output), + fcompile=tar.tar, + ) + + return build + + +def _build_android(): + def build(mod: IRModule, args: "CompileArgs"): + output = args.output + 
system_lib_prefix = f"{args.model_type}_{args.quantization}_".replace("-", "_") + assert output.suffix == ".tar" + relax.build( + mod.with_attr("system_lib_prefix", system_lib_prefix), + target=args.target, + system_lib=True, + ).export_library( + str(output), + fcompile=tar.tar, + ) + + return build + + +def _build_webgpu(): + def build(mod: IRModule, args: "CompileArgs"): + output = args.output + assert output.suffix == ".wasm" + relax.build( + mod, + target=args.target, + system_lib=True, + ).export_library( + str(output), + ) + + return build + + +def _build_default(): + def build(mod: IRModule, args: "CompileArgs"): + output = args.output + if output.suffix in [".a", ".lib"]: + system_lib = True + elif output.suffix in [".so", ".dylib", ".dll"]: + system_lib = False + else: + logger.warning("Unknown output suffix: %s. Assuming shared library.", output.suffix) + system_lib = False + relax.build( + mod, + target=args.target, + system_lib=system_lib, + ).export_library( + str(output), + ) + + return build + + +AUTO_DETECT_DEVICES = ["cuda", "rocm", "metal", "vulkan"] + +PRESET = { + "iphone:generic": { + "target": { + "kind": "metal", + "max_threads_per_block": 256, + "max_shared_memory_per_block": 32768, + "thread_warp_size": 1, + "libs": ["iphoneos"], + "host": { + "kind": "llvm", + "mtriple": "arm64-apple-darwin", + }, + }, + "build": _build_iphone, + }, + "android:generic": { + "target": { + "kind": "opencl", + "host": { + "kind": "llvm", + "mtriple": "aarch64-linux-android", + }, + }, + "build": _build_android, + }, + "metal:x86-64": { + "target": { + "kind": "metal", + "max_threads_per_block": 256, + "max_shared_memory_per_block": 32768, + "thread_warp_size": 1, + }, + "build": _build_metal_x86_64, + }, + "webgpu:generic": { + "target": { + "kind": "webgpu", + "host": { + "kind": "llvm", + "mtriple": "wasm32-unknown-unknown-wasm", + }, + }, + "build": _build_webgpu, + }, + "opencl:generic": { + "target": { + "kind": "opencl", + }, + }, + "mali:generic": { + "target": { + "kind": "opencl", + "host": { + "kind": "llvm", + "mtriple": "aarch64-linux-gnu", + }, + }, + }, + "metal:generic": { + "target": { + "kind": "metal", + "max_threads_per_block": 256, + "max_shared_memory_per_block": 32768, + "thread_warp_size": 1, + }, + }, + "vulkan:generic": { + "target": { + "kind": "vulkan", + "max_threads_per_block": 256, + "max_shared_memory_per_block": 32768, + "thread_warp_size": 1, + "supports_float16": 1, + "supports_int16": 1, + "supports_int8": 1, + "supports_8bit_buffer": 1, + "supports_16bit_buffer": 1, + "supports_storage_buffer_storage_class": 1, + }, + }, +} diff --git a/python/mlc_chat/support/auto_weight.py b/python/mlc_chat/support/auto_weight.py index 74e8a8b8c0..b19ec6b07a 100644 --- a/python/mlc_chat/support/auto_weight.py +++ b/python/mlc_chat/support/auto_weight.py @@ -8,7 +8,9 @@ def detect_weight( - weight_path: Path, config_json_path: Path, weight_format: str = "auto" + weight_path: Path, + config_json_path: Path, + weight_format: str = "auto", ) -> Tuple[Path, str]: """Detect the weight directory, and detect the weight format. diff --git a/python/mlc_chat/support/style.py b/python/mlc_chat/support/style.py new file mode 100644 index 0000000000..5b2272e1a0 --- /dev/null +++ b/python/mlc_chat/support/style.py @@ -0,0 +1,62 @@ +"""Printing styles.""" + +from enum import Enum + + +class Styles(Enum): + """Predefined set of styles to be used. 
+ + Reference: + - https://en.wikipedia.org/wiki/ANSI_escape_code#3-bit_and_4-bit + - https://stackoverflow.com/a/17303428 + """ + + RED = "\033[91m" + GREEN = "\033[92m" + YELLOW = "\033[93m" + BLUE = "\033[94m" + PURPLE = "\033[95m" + CYAN = "\033[96m" + BOLD = "\033[1m" + UNDERLINE = "\033[4m" + END = "\033[0m" + + +def red(text: str) -> str: + """Return red text.""" + return f"{Styles.RED.value}{text}{Styles.END.value}" + + +def green(text: str) -> str: + """Return green text.""" + return f"{Styles.GREEN.value}{text}{Styles.END.value}" + + +def yellow(text: str) -> str: + """Return yellow text.""" + return f"{Styles.YELLOW.value}{text}{Styles.END.value}" + + +def blue(text: str) -> str: + """Return blue text.""" + return f"{Styles.BLUE.value}{text}{Styles.END.value}" + + +def purple(text: str) -> str: + """Return purple text.""" + return f"{Styles.PURPLE.value}{text}{Styles.END.value}" + + +def cyan(text: str) -> str: + """Return cyan text.""" + return f"{Styles.CYAN.value}{text}{Styles.END.value}" + + +def bold(text: str) -> str: + """Return bold text.""" + return f"{Styles.BOLD.value}{text}{Styles.END.value}" + + +def underline(text: str) -> str: + """Return underlined text.""" + return f"{Styles.UNDERLINE.value}{text}{Styles.END.value}" diff --git a/tests/python/model/test_llama.py b/tests/python/model/test_llama.py index e8757fd234..cb77a59b71 100644 --- a/tests/python/model/test_llama.py +++ b/tests/python/model/test_llama.py @@ -1,6 +1,7 @@ # pylint: disable=invalid-name,missing-docstring import pytest -from mlc_chat.compiler.model.llama import LlamaConfig, LlamaForCasualLM +from mlc_chat.compiler.model.llama_config import LlamaConfig +from mlc_chat.compiler.model.llama_model import LlamaForCasualLM @pytest.mark.parametrize("model_name", ["llama2_7b", "llama2_13b", "llama2_70b"]) diff --git a/tests/python/parameter/test_hf_loader.py b/tests/python/parameter/test_huggingface.py similarity index 82% rename from tests/python/parameter/test_hf_loader.py rename to tests/python/parameter/test_huggingface.py index 4e983d83dd..851bd42aea 100644 --- a/tests/python/parameter/test_hf_loader.py +++ b/tests/python/parameter/test_huggingface.py @@ -4,9 +4,9 @@ from typing import Union import pytest -from mlc_chat.compiler.model.llama import LlamaConfig -from mlc_chat.compiler.model.llama_parameter import hf_torch -from mlc_chat.compiler.parameter import HFLoader +from mlc_chat.compiler.model.llama_config import LlamaConfig +from mlc_chat.compiler.model.llama_parameter import huggingface +from mlc_chat.compiler.parameter import HuggingFaceLoader from mlc_chat.support import tqdm logging.basicConfig( @@ -31,7 +31,7 @@ def test_load_torch_llama(base_path: Union[str, Path]): path_params = base_path / "pytorch_model.bin.index.json" config = LlamaConfig.from_file(path_config) - loader = HFLoader(path=path_params, extern_param_map=hf_torch(config)) + loader = HuggingFaceLoader(path=path_params, extern_param_map=huggingface(config, None)) with tqdm.redirect(): for _name, _param in loader.load(): ... @@ -51,7 +51,7 @@ def test_load_safetensor_llama(base_path: Union[str, Path]): path_params = base_path / "model.safetensors.index.json" config = LlamaConfig.from_file(path_config) - loader = HFLoader(path=path_params, extern_param_map=hf_torch(config)) + loader = HuggingFaceLoader(path=path_params, extern_param_map=huggingface(config, None)) with tqdm.redirect(): for _name, _param in loader.load(): ... 
From 8ce77931a5fb7512982b363e2b11febeb771b96b Mon Sep 17 00:00:00 2001 From: Git bot Date: Tue, 24 Oct 2023 07:30:53 +0000 Subject: [PATCH 047/116] Auto updated submodule references --- 3rdparty/tvm | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/3rdparty/tvm b/3rdparty/tvm index 631f37b6bf..30b4fa3c13 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit 631f37b6bf8b101d16ecc55de7e6a749a3588570 +Subproject commit 30b4fa3c13fc80d5c9151a9dc445d22c57ced3e0 From 488017da110555a67f55b947ea8bd41974f129fc Mon Sep 17 00:00:00 2001 From: SingLi Date: Tue, 24 Oct 2023 08:19:31 -0500 Subject: [PATCH 048/116] fix mismatched argument name (#1117) fix error introduced by recent code changes fixes #1116 --- python/mlc_chat/rest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/mlc_chat/rest.py b/python/mlc_chat/rest.py index 486b20e965..2c26edfa9a 100644 --- a/python/mlc_chat/rest.py +++ b/python/mlc_chat/rest.py @@ -115,7 +115,7 @@ async def lifespan(app: FastAPI): chat_mod = ChatModule( model=ARGS.model, device=ARGS.device, - lib_path=ARGS.lib_path, + model_lib_path=ARGS.lib_path, ) session["chat_mod"] = chat_mod From 206103b57a9215b84fc1a730b4517a454c553fcc Mon Sep 17 00:00:00 2001 From: Charlie Ruan <53290280+CharlieFRuan@users.noreply.github.com> Date: Tue, 24 Oct 2023 11:54:01 -0400 Subject: [PATCH 049/116] [Docs] Add doc for max and mean gen len, shift factor; and buildArgs (#1119) * Add doc for max and mean gen len, shift factor * Update python docs for BuildArgs --- mlc_llm/core.py | 26 +++++++++++++++++++++++++- python/mlc_chat/chat_module.py | 13 +++++++++++-- 2 files changed, 36 insertions(+), 3 deletions(-) diff --git a/mlc_llm/core.py b/mlc_llm/core.py index 81de89b7cb..e720d19542 100644 --- a/mlc_llm/core.py +++ b/mlc_llm/core.py @@ -79,11 +79,35 @@ class BuildArgs: Build with separated embedding layer, only applicable to LlaMa. This feature is in testing stage, and will be formally replaced after massive overhaul of embedding feature for all models and use cases. + cc_path: str + ``/path/to/cross_compiler_path``; currently only used for cross-compile + for nvidia/jetson device. + use_safetensors: bool + Specifies whether to use ``.safetensors`` instead of the default ``.bin`` + when loading in model weights. enable_batching: bool Build the model for batched inference. This is a temporary flag used to control the model execution flow in single- sequence and batching settings for now. We will eventually merge two flows in the future and remove this flag then. + no_cutlass_attn: bool + Disable offloading attention operations to CUTLASS. + no_cutlass_norm: bool + Disable offloading layer and RMS norm operations to CUTLASS. + no_cublas: bool + Disable the step that offloads matmul to cuBLAS. Without this flag, + matmul will be offloaded to cuBLAS if quantization mode is ``q0f16`` or + ``q0f32``, target is CUDA and TVM has been built with cuBLAS enabled. + use_cuda_graph: bool + Specifies whether to enable CUDA Graph for the decoder. MLP and QKV + projection between two attention layers are put into a graph. + num_shards: int + Number of shards to split the model into in tensor parallelism multi-gpu + inference. Only useful when ``build_model_only`` is set. + use_flash_attn_mqa: bool + Offload multi-query attention workload to Flash Attention. + pdb: bool + If set, drop into a pdb debugger on error. 
""" model: str = field( default="auto", @@ -217,7 +241,7 @@ class BuildArgs: "help": ( "Disable the step that offloads matmul to cuBLAS. Without this flag, " "matmul will be offloaded to cuBLAS if quantization mode is q0f16 or q0f32, " - "target is CUDA and TVM has been built with cuBLAS enbaled." + "target is CUDA and TVM has been built with cuBLAS enabled." ), "action": "store_true", }, diff --git a/python/mlc_chat/chat_module.py b/python/mlc_chat/chat_module.py index 3bc32309e7..0e8e871534 100644 --- a/python/mlc_chat/chat_module.py +++ b/python/mlc_chat/chat_module.py @@ -91,7 +91,7 @@ class ChatConfig: :class:`mlc_chat.ChatModule` instance to override the default setting in ``mlc-chat-config.json`` under the model folder. - Since the configuraiton is partial, everything will be ``Optional``. + Since the configuration is partial, everything will be ``Optional``. Note that we will exploit this class to also represent ``mlc-chat-config.json`` during intermediate processing. @@ -131,14 +131,19 @@ class ChatConfig: For additional information on top-p sampling, please refer to this blog post: https://huggingface.co/blog/how-to-generate#top-p-nucleus-sampling. mean_gen_len : Optional[int] + The approximated average number of generated tokens in each round. Used + to determine whether the maximum window size would be exceeded. max_gen_len : Optional[int] + The maximum number of tokens to be generated in each round. Would simply + stop generating after this number is exceeded. shift_fill_factor : Optional[float] + The fraction of maximum window size to shift when it is exceeded. tokenizer_files : Optional[List[str]] List of tokenizer files of the model. conv_config : Optional[ConvConfig] The partial overriding configuration for conversation template. Will first load the predefined template with the name specified in ``conv_template`` - and then override some of the configuraitons specified in ``conv_config``. + and then override some of the configurations specified in ``conv_config``. model_category : Optional[str] The category of the model's architecture (e.g. ``llama``, ``gpt_neox``, ``rwkv``). model_name : Optional[str] @@ -216,7 +221,11 @@ class GenerationConfig: For additional information on top-p sampling, please refer to this blog post: https://huggingface.co/blog/how-to-generate#top-p-nucleus-sampling. mean_gen_len : Optional[int] + The approximated average number of generated tokens in each round. Used + to determine whether the maximum window size would be exceeded. max_gen_len : Optional[int] + The maximum number of tokens to be generated in each round. Would simply + stop generating after this number is exceeded. """ temperature: Optional[float] = None From 2aa6809e57c4cc8543f0e0d55792f2f0bebac401 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Tue, 24 Oct 2023 09:03:38 -0700 Subject: [PATCH 050/116] Revert "[ParamManager] Use BundleModelParams for transform_dequantize" (#1120) Revert "[ParamManager] Use BundleModelParams for transform_dequantize (#1056)" This reverts commit e5927cee3b932b6e3116b43778008a3aa11ef0a3. This causes a regression impacting all MLC LLM nightlies as it violates the existing calling convention in MLC Chat runtime. 
An example: https://github.com/mlc-ai/mlc-llm/issues/1060#issuecomment-1776761032 --- mlc_llm/core.py | 3 +- mlc_llm/relax_model/param_manager.py | 114 ++++++++++++++------------- 2 files changed, 62 insertions(+), 55 deletions(-) diff --git a/mlc_llm/core.py b/mlc_llm/core.py index e720d19542..8c3a75c374 100644 --- a/mlc_llm/core.py +++ b/mlc_llm/core.py @@ -420,8 +420,7 @@ def mod_transform_before_build( if args.model.lower().startswith("rwkv-"): model_names += ["reset_kv_cache"] - mod = param_manager.transform_dequantize()(mod) - mod = relax.transform.BundleModelParams()(mod) + mod = param_manager.transform_dequantize(mod) use_ft_quant = args.quantization.name in ["q4f16_ft", "q8f16_ft"] mod = mlc_llm.transform.FuseDecodeTranspose(skip_gemm=not use_ft_quant)(mod) diff --git a/mlc_llm/relax_model/param_manager.py b/mlc_llm/relax_model/param_manager.py index c6729b41e0..04f56a5152 100644 --- a/mlc_llm/relax_model/param_manager.py +++ b/mlc_llm/relax_model/param_manager.py @@ -369,7 +369,7 @@ def set_param_loading_func( else: self.pidx2pname = dict() - def transform_dequantize(self) -> tvm.ir.transform.Pass: + def transform_dequantize(self, mod: tvm.IRModule) -> tvm.IRModule: """Apply dequantization to the input IRModule. Parameters @@ -386,48 +386,38 @@ def transform_dequantize(self) -> tvm.ir.transform.Pass: The IRModule updated with the dequantization computation. """ - @tvm.ir.transform.module_pass(opt_level=0, name="ParamManager.transform_dequantize") - def transform_func(mod: tvm.IRModule, _context) -> tvm.IRModule: - # For each Relax function in the input IRModule (e.g., "prefill"), - # we create its input relax.Var of all the quantized data, and - # store the mapping from function name to the var. - func_name_to_quantized_params: Dict[str, List[relax.Var]] = {} - - for gv, func in mod.functions.items(): - if isinstance(func, relax.Function) and func.attrs and "num_input" in func.attrs: - quantized_param_info = self.get_quantized_param_info(gv.name_hint) - param_vars = [ - relax.Var(f"param_{i}", info) - for i, info in enumerate(quantized_param_info.fields) - ] - func_name_to_quantized_params[gv.name_hint] = param_vars - - # Cache mapping to avoid duplicate dequantization. - dequantized_cache: Dict[relax.Var, relax.Var] = {} - - # Define a var replacement function for applying dequantization. - def f_replace(var: relax.Var, bb: relax.BlockBuilder) -> relax.Var: - if var in dequantized_cache: - return dequantized_cache[var] - assert var in self.func_raw_param_map - - func_name, param = self.func_raw_param_map[var] - quantized_params = func_name_to_quantized_params[func_name] - relevant_quantized_params = [quantized_params[i] for i in self.param2qrange[param]] - - dequantized = self._dequantize(param, relevant_quantized_params, bb, func_name) + # For each Relax function in the input IRModule (e.g., "prefill"), + # we create its input relax.Var of all the quantized data, and + # store the mapping from function name to the var. + func2param_var: Dict[str, relax.Var] = {} + for gv, func in mod.functions.items(): + if not isinstance(func, relax.Function): + continue + if func.attrs is None or not "num_input" in func.attrs: + continue + func2param_var[gv.name_hint] = relax.Var( + "params", self.get_quantized_param_info(gv.name_hint) + ) - dequantized_cache[var] = dequantized - return dequantized + # Cache mapping to avoid duplicate dequantization. + dequantized_cache: Dict[relax.Var, relax.Var] = {} - # Create the function mutator for applying dequantization. 
- replacer = ParamReplacer(mod, func_name_to_quantized_params, f_replace) - # Update the input IRModule with dequantization. - mod = replacer.transform() + # Define a var replacement function for applying dequantization. + def f_replace(var: relax.Var, bb: relax.BlockBuilder, func_name: str) -> relax.Var: + if var in dequantized_cache: + return dequantized_cache[var] + assert var in self.func_raw_param_map + func_name, param = self.func_raw_param_map[var] + dequantized = self._dequantize(param, func2param_var[func_name], bb, func_name) + dequantized_cache[var] = dequantized + return dequantized - return mod + # Create the function mutator for applying dequantization. + replacer = ParamReplacer(mod, func2param_var, f_replace) + # Update the input IRModule with dequantization. + mod = replacer.transform() - return transform_func + return mod def get_quantized_param_info(self, func_name: str) -> List[relax.TensorStructInfo]: bb = relax.BlockBuilder() @@ -707,9 +697,10 @@ def _register_param( def _dequantize( self, param: Parameter, - qparams: List[relax.Var], + quantized_tuple: relax.Var, bb: relax.BlockBuilder, func_name: str, + qparams: List[relax.Var] = None, ) -> relax.Var: """Applying dequantization to the input parameter. This method is called by `transform_module` below, and is not @@ -720,13 +711,30 @@ def _dequantize( param : Parameter The parameter whose quantized tensors are to be dequantized. - qparams : List[relax.Var] + quantized_tuple : relax.Var The relax.Var of the quantized tensors of all parameters in the model. + bb : relax.BlockBuilder + The Relax BlockBuilder used for inserting the dequantization computations. + + func_name : str + The name of the function which dequantization is applied to. + + qparams : List[relax.Var] + The quantized parts of the parameter. + By default it is `None`, in which case we will get the quantized parts + from `quantized_tuple`. + Returns ------- The dequantized parameter, in the form of a relax.Var. """ + if not qparams: + # Get the corresponding Relax vars of the quantized tensors of this parameter. + qparams: List[relax.Var] = [] + for qparam_idx in self.param2qrange[param]: + qparams.append(bb.emit(relax.TupleGetItem(quantized_tuple, qparam_idx))) + # Get the dequantization function of this parameter. f_dequantize = param.quant_spec.get_dequantize_func( param_info=param.param_info_dict[func_name], @@ -781,7 +789,7 @@ class ParamReplacer(PyExprMutator): mod : tvm.IRModule The IRModule of the model to be updated. - func_name_to_quantized_params : Dict[str, List[relax.Var]] + func2param_var : Dict[str, relax.Var] The mapping from each function name to its input var of quantized data tuple. 
f_replace : Callable[[relax.Var, relax.BlockBuilder], relax.Var] @@ -793,7 +801,7 @@ class ParamReplacer(PyExprMutator): """ mod: tvm.IRModule - func_name_to_quantized_params: Dict[str, List[relax.Var]] + func2param_var: Dict[str, relax.Var] f_replace: Callable[[relax.Var, relax.BlockBuilder], relax.Var] param_set: Set[relax.Var] @@ -802,12 +810,12 @@ class ParamReplacer(PyExprMutator): def __init__( self, mod: tvm.IRModule, - func_name_to_quantized_params: Dict[str, relax.Var], + func2param_var: Dict[str, relax.Var], f_replace: Callable[[relax.Var, relax.BlockBuilder], relax.Var], ): super().__init__(mod) self.mod = mod - self.func_name_to_quantized_params = func_name_to_quantized_params + self.func2param_var = func2param_var self.f_replace = f_replace self.cur_func_name = "" @@ -819,20 +827,21 @@ def transform(self) -> tvm.IRModule: continue assert ( - gv.name_hint in self.func_name_to_quantized_params - ), f"{gv.name_hint} not in {self.func_name_to_quantized_params}" - updated_func = self.rewrite_func(func, self.func_name_to_quantized_params[gv.name_hint]) + gv.name_hint in self.func2param_var + ), f"{gv.name_hint} not in {self.func2param_var}" + self.cur_func_name = gv.name_hint + updated_func = self.rewrite_func(func, self.func2param_var[gv.name_hint]) updated_func = remove_all_unused(updated_func) self.builder_.update_func(gv, updated_func) return self.builder_.get() - def rewrite_func(self, func: Function, quantized_params: List[relax.Var]) -> relax.Function: + def rewrite_func(self, func: Function, param_var: relax.Var) -> relax.Function: num_input = int(func.attrs["num_input"]) self.param_set = set(func.params[num_input:]) body = self.visit_expr(func.body) return relax.Function( - params=func.params[:num_input] + quantized_params, + params=func.params[:num_input] + [param_var], body=body, ret_struct_info=func.ret_struct_info, is_pure=func.is_pure, @@ -840,10 +849,9 @@ def rewrite_func(self, func: Function, quantized_params: List[relax.Var]) -> rel ).without_attr("num_input") def visit_var_(self, var: Var) -> Expr: - if var in self.param_set: - return self.f_replace(var, self.builder_) - else: + if var not in self.param_set: return super().visit_var_(var) + return self.f_replace(var, self.builder_, self.cur_func_name) ################################################################## From 9cb8e8ed317ce98c650b49a0cb878ef8b3a009c7 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Tue, 24 Oct 2023 09:04:45 -0700 Subject: [PATCH 051/116] Remove inaccurate warning message (#1121) This PR removes an inaccurate warning from #1086, which warns about `model_lib` overriding regardless of whether or not it's actually overridden. With this commit, we only warn if its value is not None. 
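
Sketched as a standalone toy, the behavior after this change is: user-supplied config fields replace the defaults only when they are not None, and the `model_lib` warning fires only when a caller actually set it. `PartialConfig` and `merge_config` below are illustrative names, not the ChatModule API; the real change is in `_get_chat_config` in python/mlc_chat/chat_module.py shown below.

import warnings
from dataclasses import dataclass, fields, replace
from typing import Optional


@dataclass
class PartialConfig:  # hypothetical stand-in for the partial chat config
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    model_lib: Optional[str] = None


def merge_config(base: PartialConfig, override: Optional[PartialConfig]) -> PartialConfig:
    """Apply the non-None fields of `override` on top of `base`."""
    merged = replace(base)
    if override is None:
        return merged
    for field in fields(override):
        value = getattr(override, field.name)
        if value is None:
            continue  # unset fields never clobber the base config
        if field.name == "model_lib":
            # Warn only when the caller actually tried to override model_lib.
            warnings.warn("'model_lib' override is ignored; set the library path directly.")
            continue
        setattr(merged, field.name, value)
    return merged


if __name__ == "__main__":
    base = PartialConfig(temperature=0.7, top_p=0.95)
    print(merge_config(base, PartialConfig(temperature=0.2)))    # no warning emitted
    print(merge_config(base, PartialConfig(model_lib="my_lib")))  # warning fires here
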
--- python/mlc_chat/chat_module.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/python/mlc_chat/chat_module.py b/python/mlc_chat/chat_module.py index 0e8e871534..38eb8f1f33 100644 --- a/python/mlc_chat/chat_module.py +++ b/python/mlc_chat/chat_module.py @@ -356,22 +356,22 @@ def _get_chat_config(config_file_path: str, user_chat_config: Optional[ChatConfi final_chat_config = None with open(config_file_path, mode="rt", encoding="utf-8") as f: json_object = json.load(f) - final_chat_config = ChatConfig._from_json(json_object) + final_chat_config = ChatConfig._from_json(json_object) # pylint: disable=protected-access if user_chat_config is not None: # We override using user's chat config for field in fields(user_chat_config): field_name = field.name - if field_name == "model_lib": - warn_msg = ( - 'WARNING: Do not override "model_lib" in ChatConfig. ' - "This override will be ignored. " - "Please use ChatModule.model_lib_path to override the full model library path instead." - ) - warnings.warn(warn_msg) - continue field_value = getattr(user_chat_config, field_name) if field_value is not None: - setattr(final_chat_config, field_name, field_value) + if field_name == "model_lib": + warn_msg = ( + 'WARNING: Do not override "model_lib" in ChatConfig. ' + "This override will be ignored. Please use ChatModule.model_lib_path to " + "override the full model library path instead." + ) + warnings.warn(warn_msg) + else: + setattr(final_chat_config, field_name, field_value) return final_chat_config From 9166edbf844bf039314e4453ff9f441c4738c8a6 Mon Sep 17 00:00:00 2001 From: Kartik Khandelwal Date: Tue, 24 Oct 2023 15:07:23 -0400 Subject: [PATCH 052/116] [REST] OpenAI compatible Rest API (#1107) * add presence and frequency penalty * Added support for passing conversation history in /v1/chat/completions endpoint * Added support for RestAPI parameters max_gen_len, n, and stop_str * * add presence and frequency penalty to generation config * refactor generation config * Added documentation for parameters * replace lib_path with model_lib_path in rest.py * fixed black isort issues * fix lib_path --- cpp/llm_chat.cc | 232 +++++++++++++++--------- python/mlc_chat/chat_module.py | 138 +++++++++++--- python/mlc_chat/interface/openai_api.py | 11 +- python/mlc_chat/rest.py | 37 ++-- 4 files changed, 288 insertions(+), 130 deletions(-) diff --git a/cpp/llm_chat.cc b/cpp/llm_chat.cc index 339e5429d1..7b869854d6 100644 --- a/cpp/llm_chat.cc +++ b/cpp/llm_chat.cc @@ -650,7 +650,7 @@ class LLMChat { this->ResetRuntimeStats(); } output_ids_.clear(); - appeared_token_ids_.clear(); + appeared_token_freq_.clear(); output_message_.clear(); stop_triggered_ = false; if (append_conversation) { @@ -672,12 +672,8 @@ class LLMChat { PlaceInPrompt place_in_prompt = PlaceInPrompt::kAll, String generation_config_str = "") { // process generation settings - picojson::object generation_config = picojson::object(); - if (!generation_config_str.empty()) { - picojson::value generation_config_json; - picojson::parse(generation_config_json, generation_config_str); - generation_config = generation_config_json.get(); - } + picojson::object generation_config = + this->LoadGenerationConfigFromString(generation_config_str); std::vector prompt_tokens = PrepareBeforeEmbedding(inp, append_conversation, place_in_prompt, generation_config); @@ -729,13 +725,8 @@ class LLMChat { return; } - // process generation settings - picojson::object generation_config = picojson::object(); - if 
(!generation_config_str.empty()) { - picojson::value generation_config_json; - picojson::parse(generation_config_json, generation_config_str); - generation_config = generation_config_json.get(); - } + picojson::object generation_config = + this->LoadGenerationConfigFromString(generation_config_str); int32_t next_token = this->SampleTokenFromLogits(logits_on_device, generation_config); @@ -743,7 +734,7 @@ class LLMChat { this->prefill_total_time += static_cast((tend - tstart).count()) / 1e9; this->prefill_total_tokens += token_len; - this->ProcessNextToken(next_token); + this->ProcessNextToken(next_token, generation_config); } /*! @@ -768,13 +759,8 @@ class LLMChat { return; } - // process generation settings - picojson::object generation_config = picojson::object(); - if (!generation_config_str.empty()) { - picojson::value generation_config_json; - picojson::parse(generation_config_json, generation_config_str); - generation_config = generation_config_json.get(); - } + picojson::object generation_config = + this->LoadGenerationConfigFromString(generation_config_str); std::vector prompt_tokens = this->PrepareBeforeEmbedding(inp, append_conversation, place_in_prompt, generation_config); @@ -803,17 +789,12 @@ class LLMChat { this->prefill_total_time += static_cast((tend - tstart).count()) / 1e9; this->prefill_total_tokens += token_len; - this->ProcessNextToken(next_token); + this->ProcessNextToken(next_token, generation_config); } void DecodeStep(String generation_config_str = "") { - // process generation settings - picojson::object generation_config = picojson::object(); - if (!generation_config_str.empty()) { - picojson::value generation_config_json; - picojson::parse(generation_config_json, generation_config_str); - generation_config = generation_config_json.get(); - } + picojson::object generation_config = + this->LoadGenerationConfigFromString(generation_config_str); ICHECK(!output_ids_.empty()); int32_t last_token = output_ids_.back(); @@ -830,7 +811,7 @@ class LLMChat { this->decode_total_time += static_cast((tend - tstart).count()) / 1e9; this->decode_total_tokens += 1; - this->ProcessNextToken(next_token); + this->ProcessNextToken(next_token, generation_config); } bool Stopped() { return stop_triggered_; } @@ -931,6 +912,8 @@ class LLMChat { picojson::object config; config["temperature"] = picojson::value(this->temperature_); config["repetition_penalty"] = picojson::value(this->repetition_penalty_); + config["presence_penalty"] = picojson::value(this->presence_penalty_); + config["frequency_penalty"] = picojson::value(this->frequency_penalty_); config["top_p"] = picojson::value(this->top_p_); config["mean_gen_len"] = picojson::value(this->mean_gen_len_); config["max_gen_len"] = picojson::value(this->max_gen_len_); @@ -938,54 +921,100 @@ class LLMChat { config["conv_config"] = this->conversation_.SerializeToJSON(); return picojson::value(config); } - /*! 
- * \brief Sample output token from logits on device - */ - int32_t SampleTokenFromLogits(NDArray logits_on_device, - picojson::object generation_config = picojson::object()) { - // prepare generation settings - // the generation_config will not override the original config - // since is only used for this generation - double gen_temperature; - double gen_repetition_penalty; - double gen_top_p; + + picojson::object LoadGenerationConfigFromString(const std::string& generation_config_str) { + picojson::object generation_config = picojson::object(); + if (!generation_config_str.empty()) { + picojson::value generation_config_json; + picojson::parse(generation_config_json, generation_config_str); + generation_config = generation_config_json.get(); + } + return generation_config; + } + + void ReadGenerationConfig(picojson::object generation_config, double* gen_temperature, + NDArray* gen_temperature_arr, double* gen_repetition_penalty, + double* gen_presence_penalty, double* gen_frequency_penalty, + double* gen_top_p) { if (generation_config.count("temperature")) { CHECK(generation_config["temperature"].is()); - gen_temperature = generation_config["temperature"].get(); - if (gen_temperature != this->temperature_) { - this->temperature_ = gen_temperature; - float temperature_cast = static_cast(gen_temperature); - this->temperature_arr_.CopyFromBytes(&temperature_cast, sizeof(float)); - } + *gen_temperature = generation_config["temperature"].get(); + + *gen_temperature_arr = NDArray::Empty({}, DataType::Float(32), device_); + float temperature_cast = static_cast(*gen_temperature); + gen_temperature_arr->CopyFromBytes(&temperature_cast, sizeof(float)); } else { - gen_temperature = this->temperature_; + *gen_temperature = this->temperature_; + *gen_temperature_arr = this->temperature_arr_; } if (generation_config.count("repetition_penalty")) { CHECK(generation_config["repetition_penalty"].is()); - gen_repetition_penalty = generation_config["repetition_penalty"].get(); + CHECK(generation_config["repetition_penalty"].get() > 0) + << "Repetition penalty must be a positive number!"; + *gen_repetition_penalty = generation_config["repetition_penalty"].get(); + } else { + *gen_repetition_penalty = this->repetition_penalty_; + } + if (generation_config.count("presence_penalty")) { + CHECK(generation_config["presence_penalty"].is()); + CHECK(abs(generation_config["presence_penalty"].get()) <= 2) + << "Presence penalty must be in the range -2 to 2!"; + *gen_presence_penalty = generation_config["presence_penalty"].get(); + } else { + *gen_presence_penalty = this->presence_penalty_; + } + if (generation_config.count("frequency_penalty")) { + CHECK(generation_config["frequency_penalty"].is()); + CHECK(abs(generation_config["frequency_penalty"].get()) <= 2) + << "Frequency penalty must be in the range -2 to 2!"; + *gen_frequency_penalty = generation_config["frequency_penalty"].get(); } else { - gen_repetition_penalty = this->repetition_penalty_; + *gen_frequency_penalty = this->frequency_penalty_; } if (generation_config.count("top_p")) { CHECK(generation_config["top_p"].is()); - gen_top_p = generation_config["top_p"].get(); + *gen_top_p = generation_config["top_p"].get(); } else { - gen_top_p = this->top_p_; + *gen_top_p = this->top_p_; } + } + + /*! 
+ * \brief Sample output token from logits on device + */ + int32_t SampleTokenFromLogits(NDArray logits_on_device, + picojson::object generation_config = picojson::object()) { + // prepare generation settings + // the generation_config will not override the original config + // since is only used for this generation + double gen_temperature; + double gen_repetition_penalty; + double gen_presence_penalty; + double gen_frequency_penalty; + double gen_top_p; + this->ReadGenerationConfig(generation_config, &gen_temperature, &this->temperature_arr_, + &gen_repetition_penalty, &gen_presence_penalty, + &gen_frequency_penalty, &gen_top_p); // update logits - if (gen_repetition_penalty == 1.0f) { - if (gen_temperature < 1e-6f) { - this->UpdateLogitsOrProbOnCPUSync(logits_on_device); - } else { - this->UpdateLogitsOrProbOnCPUSync(this->Softmax(logits_on_device, this->temperature_arr_)); + if (gen_presence_penalty != 0.0f || gen_frequency_penalty != 0.0f) { + this->UpdateLogitsOrProbOnCPUSync(logits_on_device); + this->ApplyPresenceAndFrequencyPenaltyOnCPU(gen_presence_penalty, gen_presence_penalty); + if (gen_temperature >= 1e-6f) { + this->ApplySoftmaxWithTemperatureOnCPU(gen_temperature); } - } else { + } else if (gen_repetition_penalty != 1.0f) { this->UpdateLogitsOrProbOnCPUSync(logits_on_device); this->ApplyRepetitionPenaltyOnCPU(gen_repetition_penalty); if (gen_temperature >= 1e-6f) { this->ApplySoftmaxWithTemperatureOnCPU(gen_temperature); } + } else { + if (gen_temperature < 1e-6f) { + this->UpdateLogitsOrProbOnCPUSync(logits_on_device); + } else { + this->UpdateLogitsOrProbOnCPUSync(this->Softmax(logits_on_device, this->temperature_arr_)); + } } // perform sampling @@ -1018,6 +1047,25 @@ class LLMChat { gen_max_gen_len = this->max_gen_len_; } + std::vector gen_stop_strs; + gen_stop_strs.push_back(conversation_.stop_str); + + if (generation_config.count("stop")) { + if (!generation_config["stop"].is()) { + CHECK(generation_config["stop"].is() || + generation_config["stop"].is()); + if (generation_config["stop"].is()) { + gen_stop_strs.push_back(generation_config["stop"].get()); + } else { + picojson::array gen_stop_strs_arr = generation_config["stop"].get(); + for (const picojson::value& v : gen_stop_strs_arr) { + CHECK(v.is()); + gen_stop_strs.push_back(v.get()); + } + } + } + } + ICHECK(!stop_triggered_) << "Cannot call process when it is stopped"; stop_triggered_ = @@ -1026,27 +1074,35 @@ class LLMChat { if (!stop_triggered_) { output_ids_.push_back(next_token); - appeared_token_ids_.insert(next_token); + if (appeared_token_freq_.find(next_token) != appeared_token_freq_.end()) { + appeared_token_freq_[next_token] += 1; + } else { + appeared_token_freq_[next_token] = 1; + } } output_message_ = tokenizer_->Decode(output_ids_); - if (!conversation_.stop_str.empty()) { - size_t stop_pos = output_message_.rfind(conversation_.stop_str); - if (stop_pos != std::string::npos) { - stop_triggered_ = true; - if (ft_.support_backtracking_kv_) { - // back tracking, find the first set of token that is smaller - // than the length - size_t backoff = 0; - for (; (output_ids_.size() > 0) && (output_message_.length() > stop_pos); ++backoff) { - output_ids_.pop_back(); - output_message_ = tokenizer_->Decode(output_ids_); - } - // resize kv to remove the context - ft_.fkvcache_array_popn_(kv_cache_, backoff); - total_seq_len_ -= backoff; + size_t stop_pos = std::string::npos; + for (const std::string& stop_str : gen_stop_strs) { + if (!stop_str.empty()) { + stop_pos = std::min(stop_pos, 
output_message_.rfind(stop_str)); + } + } + + if (stop_pos != std::string::npos) { + stop_triggered_ = true; + if (ft_.support_backtracking_kv_) { + // back tracking, find the first set of token that is smaller + // than the length + size_t backoff = 0; + for (; (output_ids_.size() > 0) && (output_message_.length() > stop_pos); ++backoff) { + output_ids_.pop_back(); + output_message_ = tokenizer_->Decode(output_ids_); } + // resize kv to remove the context + ft_.fkvcache_array_popn_(kv_cache_, backoff); + total_seq_len_ -= backoff; } } @@ -1113,15 +1169,25 @@ class LLMChat { CHECK(logits_on_cpu_.defined()) << "Logits on CPU not defined!"; CHECK(logits_on_cpu_.DataType() == DataType::Float(32)) << "Logits data type is not float32!"; float* logits_raw_data = static_cast(logits_on_cpu_->data); - for (const int32_t& token_id : this->appeared_token_ids_) { - if (logits_raw_data[token_id] <= 0) { - logits_raw_data[token_id] *= repetition_penalty; + for (const auto& token_freq : this->appeared_token_freq_) { + if (logits_raw_data[token_freq.first] <= 0) { + logits_raw_data[token_freq.first] *= repetition_penalty; } else { // logits > 0 - logits_raw_data[token_id] /= repetition_penalty; + logits_raw_data[token_freq.first] /= repetition_penalty; } } } + void ApplyPresenceAndFrequencyPenaltyOnCPU(float presence_penalty, float frequency_penalty) { + CHECK(logits_on_cpu_.defined()) << "Logits on CPU not defined!"; + CHECK(logits_on_cpu_.DataType() == DataType::Float(32)) << "Logits data type is not float32!"; + float* logits_raw_data = static_cast(logits_on_cpu_->data); + for (const auto& token_freq : this->appeared_token_freq_) { + logits_raw_data[token_freq.first] -= + (token_freq.second * frequency_penalty + presence_penalty); + } + } + void ApplySoftmaxWithTemperatureOnCPU(float temperature) { CHECK(logits_on_cpu_.defined()) << "Logits on CPU not defined!"; CHECK(logits_on_cpu_.DataType() == DataType::Float(32)) << "Logits data type is not float32!"; @@ -1211,12 +1277,16 @@ class LLMChat { NDArray temperature_arr_; // repetition penalty double repetition_penalty_{1.0}; + // presence penalty + double presence_penalty_{0.0}; + // frequency penalty + double frequency_penalty_{0.0}; // top_p double top_p_{0.95}; // output ids till now (refresh after encoding step) std::vector output_ids_; - // appeared token ids till now (refresh after encoding step) - std::unordered_set appeared_token_ids_; + // frequency of appeared token ids till now (refresh after encoding step) + std::unordered_map appeared_token_freq_; // output message till now (refresh after encoding step) std::string output_message_; // Whether encounter stop str diff --git a/python/mlc_chat/chat_module.py b/python/mlc_chat/chat_module.py index 38eb8f1f33..02625f4ef4 100644 --- a/python/mlc_chat/chat_module.py +++ b/python/mlc_chat/chat_module.py @@ -8,12 +8,13 @@ import warnings from dataclasses import asdict, dataclass, fields from enum import Enum -from typing import List, Optional +from typing import List, Optional, Union import tvm from tvm.runtime import disco from . import callback +from .interface.openai_api import ChatMessage # pylint: disable=line-too-long _PYTHON_GET_STARTED_TUTORIAL_URL = "https://github.com/mlc-ai/notebooks/blob/main/mlc-llm/tutorial_chat_module_getting_started.ipynb" @@ -202,13 +203,24 @@ class GenerationConfig: The temperature applied to logits before sampling. The default value is ``0.7``. A higher temperature encourages more diverse outputs, while a lower temperature produces more deterministic outputs. 
+ presence_penalty : Optional[float] + Number between -2.0 and 2.0. Positive values penalize new tokens based on + whether they appear in the text so far, increasing the model's likelihood + to talk about new topics. Negative values can increase the likelihood of + repetition. + frequency_penalty : Optional[float] + Number between -2.0 and 2.0. Positive values penalize new tokens based on their + existing frequency in the text so far, decreasing the model's likelihood to + repeat the same line verbatim. Negative values can increase the likelihood of + repetition. repetition_penalty : Optional[float] The repetition penalty controls the likelihood of the model generating repeated texts. The default value is set to ``1.0``, indicating that no repetition penalty is applied. Increasing the value reduces the likelihood of repeat text generation. However, setting a high ``repetition_penalty`` may result in the model generating meaningless - texts. The ideal choice of repetition penalty may vary among models. + texts. The ideal choice of repetition penalty may vary among models. Only + Active when presence_penalty and frequency_penalty are both 0.0. For more details on how repetition penalty controls text generation, please check out the CTRL paper (https://arxiv.org/pdf/1909.05858.pdf). @@ -224,8 +236,17 @@ class GenerationConfig: The approximated average number of generated tokens in each round. Used to determine whether the maximum window size would be exceeded. max_gen_len : Optional[int] - The maximum number of tokens to be generated in each round. Would simply - stop generating after this number is exceeded. + This parameter determines the maximum length of the generated text. If it is + not set, the model will generate text until it encounters a stop token. + n : Optional[int] + This parameter determines the number of text samples to generate. The default + value is ``1``. Note that this parameter is only used when ``stream`` is set to + ``False``. + stop : Optional[Union[str, List[str]]] + When ``stop`` is encountered, the model will stop generating output. + It can be a string or a list of strings. If it is a list of strings, the model + will stop generating output when any of the strings in the list is encountered. + Note that this parameter does not override the default stop string of the model. """ temperature: Optional[float] = None @@ -233,6 +254,10 @@ class GenerationConfig: top_p: Optional[float] = None mean_gen_len: Optional[int] = None max_gen_len: Optional[int] = None + presence_penalty: Optional[float] = 0.0 + frequency_penalty: Optional[float] = 0.0 + n: Optional[int] = None + stop: Optional[Union[str, List[str]]] = None @classmethod def _from_chat_config(generation_config_cls, chat_config_obj: ChatConfig): @@ -767,18 +792,24 @@ def __init__( def generate( self, - prompt: str, + prompt: Union[str, List[ChatMessage]], generation_config: Optional[GenerationConfig] = None, progress_callback=None, - ) -> str: + ) -> Union[str, List[str]]: r"""A high-level method that returns the full response from the chat module given a user prompt. User can optionally specify which callback method to use upon receiving the response. By default, no callback will be applied. Parameters ---------- - prompt : str + prompt : Union[str, List[ChatMessage]] The user input prompt, i.e. a question to ask the chat module. 
+ It can also be the whole conversation history (list of messages with role and content) + eg: ```[ + ChatMessage(role="user", content="Hello, how are you?"), + ChatMessage(role="assistant", content="I'm fine, thank you. How about you?"), + ChatMessage(role="user", content="I'm good too."), + ]``` generation_config: Optional[GenerationConfig] The generation config object to override the ChatConfig generation settings. progress_callback: object @@ -808,25 +839,36 @@ def generate( ) print(output) """ - self._prefill(prompt, generation_config=generation_config) - - if not progress_callback: - while not self._stopped(): - self._decode(generation_config=generation_config) - new_msg = self._get_message() - return new_msg - - # apply callback with a rate of callback_interval - i, new_msg = 0, "" - while not self._stopped(): - self._decode(generation_config=generation_config) - if i % progress_callback.callback_interval == 0 or self._stopped(): - new_msg = self._get_message() - progress_callback(new_msg) - i += 1 - progress_callback(stopped=True) + new_msgs = [] + num_return_sequences = 1 + return_str = True + if (generation_config is not None) and (generation_config.n is not None): + num_return_sequences = generation_config.n + return_str = False + else: + num_return_sequences = 1 - return new_msg + for _ in range(num_return_sequences): + self.reset_chat() + self._prefill(prompt, generation_config=generation_config) + + if not progress_callback: + while not self._stopped(): + self._decode(generation_config=generation_config) + new_msg = self._get_message() + new_msgs.append(new_msg) + else: + # apply callback with a rate of callback_interval + i, new_msg = 0, "" + while not self._stopped(): + self._decode(generation_config=generation_config) + if i % progress_callback.callback_interval == 0 or self._stopped(): + new_msg = self._get_message() + progress_callback(new_msg) + i += 1 + progress_callback(stopped=True) + new_msgs.append(new_msg) + return new_msgs[0] if return_str else new_msgs def reset_chat(self, chat_config: Optional[ChatConfig] = None): r"""Reset the chat session, clear all chat history, and potentially @@ -964,7 +1006,7 @@ def _unload(self): def _prefill( self, - input: str, + input: Union[str, List[ChatMessage]], decode_next_token: bool = True, place_in_prompt: PlaceInPrompt = PlaceInPrompt.All, generation_config: Optional[GenerationConfig] = None, @@ -974,8 +1016,14 @@ def _prefill( Parameters ---------- - input : str - The user input string. + input : Union[str, List[ChatMessage]] + The user input prompt, i.e. a question to ask the chat module. + It can also be the whole conversation history (list of messages with role and content) + eg: ```[ + ChatMessage(role="user", content="Hello, how are you?"), + ChatMessage(role="assistant", content="I'm fine, thank you. How about you?"), + ChatMessage(role="user", content="I'm good too."), + ]``` decode_next_token : bool Whether to decode the next token after prefilling. 
place_in_prompt: PlaceInPrompt @@ -986,7 +1034,38 @@ def _prefill( generation_config = _get_generation_config(self.chat_config, generation_config) generation_config_str = _convert_generation_config_to_json_str(generation_config) - self._prefill_func(input, decode_next_token, place_in_prompt.value, generation_config_str) + if isinstance(input, list): + # Populate conversation.messages using load_json_override + if len(input) > 1: + conv_config = json.loads(self._get_config_json())["conv_config"] + messages = [] + role0 = self._get_role_0() + role1 = self._get_role_1() + for idx, msg in enumerate(input[:-1]): + role = msg.role + content = msg.content + if role == "user": + messages.append([role0, content]) + elif role == "assistant": + messages.append([role1, content]) + else: + raise ValueError("Only user and assistant roles are supported.") + if not input[-1].role == "user": + raise ValueError("Last message should be from user.") + conv_config["messages"] = messages + conv_config[ + "offset" + ] = 0 # Otherwise, the offset will be set to the length of the conversation, which means history will be retained even after calling reset_chat + self._load_json_override( + json.dumps({"conv_config": conv_config}), partial_update=True + ) + input_str = input[-1].content + else: + input_str = input + + self._prefill_func( + input_str, decode_next_token, place_in_prompt.value, generation_config_str + ) def _embed( self, @@ -1050,7 +1129,6 @@ def _decode(self, generation_config: Optional[GenerationConfig] = None): """ generation_config = _get_generation_config(self.chat_config, generation_config) generation_config_str = _convert_generation_config_to_json_str(generation_config) - self._decode_func(generation_config_str) def _stopped(self) -> bool: diff --git a/python/mlc_chat/interface/openai_api.py b/python/mlc_chat/interface/openai_api.py index 2a94607741..ed08c75b0a 100644 --- a/python/mlc_chat/interface/openai_api.py +++ b/python/mlc_chat/interface/openai_api.py @@ -26,14 +26,15 @@ class ChatCompletionRequest(BaseModel): mean_gen_len: int = None # TODO: replace by max_tokens max_gen_len: int = None + presence_penalty: float = None + frequency_penalty: float = None + n: int = None + stop: Union[str, List[str]] = None # TODO: Implement support for the OpenAI API parameters # function [] # function_call - # n: Optional[int] = 1 # stop: Optional[Union[str, List[str]]] = None # max_tokens: Optional[int] - # presence_penalty: Optional[float] = 0.0 - # frequency_penalty: Optional[float] = 0.0 # logit_bias # user: Optional[str] = None @@ -87,6 +88,8 @@ class CompletionRequest(BaseModel): mean_gen_len: int = None # TODO: replace by max_tokens max_gen_len: int = None + presence_penalty: float = None + frequency_penalty: float = None # TODO: Implement support for the OpenAI API parameters # suffix # max_tokens: Optional[int] @@ -94,8 +97,6 @@ class CompletionRequest(BaseModel): # logprobs # echo # stop: Optional[Union[str, List[str]]] = None - # presence_penalty: Optional[float] = 0.0 - # frequency_penalty: Optional[float] = 0.0 # best_of # logit_bias # user: Optional[str] = None diff --git a/python/mlc_chat/rest.py b/python/mlc_chat/rest.py index 2c26edfa9a..d48316845d 100644 --- a/python/mlc_chat/rest.py +++ b/python/mlc_chat/rest.py @@ -167,28 +167,32 @@ async def __anext__(self): async def request_completion(request: ChatCompletionRequest): """ Creates model response for the given chat conversation. + The messages field contains a list of messages (describing the conversation history). 
eg: + ```"messages": [{"role": "user", "content": "What's my name?"}, + {"role": "assistant", "content": "Your name is Llama."}, + {"role": "user", "content": "No, that's your name. My name is X."}, + {"role": "assistant", "content": "Ah, my apologies! Your name is X! "}, + {"role": "user", "content": "What is the meaning of life?"}, + ] + ``` + ] """ - generation_config = GenerationConfig( temperature=request.temperature, repetition_penalty=request.repetition_penalty, + presence_penalty=request.presence_penalty, + frequency_penalty=request.frequency_penalty, top_p=request.top_p, mean_gen_len=request.mean_gen_len, max_gen_len=request.max_gen_len, + n=request.n, + stop=request.stop, ) - if len(request.messages) > 1: - raise ValueError( - """ - The /v1/chat/completions endpoint currently only supports single message prompts. - Please ensure your request contains only one message - """ - ) + session["chat_mod"].reset_chat() # Reset previous history, KV cache, etc. if request.stream: - session["chat_mod"]._prefill( - input=request.messages[0].content, generation_config=generation_config - ) + session["chat_mod"]._prefill(input=request.messages, generation_config=generation_config) async def iter_response(): prev_txt = "" @@ -211,15 +215,18 @@ async def iter_response(): return StreamingResponse(iter_response(), media_type="text/event-stream") else: msg = session["chat_mod"].generate( - prompt=request.messages[0].content, generation_config=generation_config + prompt=request.messages, generation_config=generation_config ) + if isinstance(msg, str): + msg = [msg] return ChatCompletionResponse( choices=[ ChatCompletionResponseChoice( - index=0, - message=ChatMessage(role="assistant", content=msg), + index=index, + message=ChatMessage(role="assistant", content=msg[index]), finish_reason="stop", ) + for index in range(len(msg)) ], # TODO: Fill in correct usage info usage=UsageInfo(prompt_tokens=0, completion_tokens=0, total_tokens=0), @@ -235,6 +242,8 @@ async def request_completion(request: CompletionRequest): generation_config = GenerationConfig( temperature=request.temperature, repetition_penalty=request.repetition_penalty, + presence_penalty=request.presence_penalty, + frequency_penalty=request.frequency_penalty, top_p=request.top_p, mean_gen_len=request.mean_gen_len, max_gen_len=request.max_gen_len, From a4279e37319b9a3a0a670754a6800a2c02fe9de7 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Tue, 24 Oct 2023 21:05:24 -0700 Subject: [PATCH 053/116] Add --opt flag parsing to CLI (#1123) --- ci/task/mypy.sh | 3 +- ci/task/pylint.sh | 2 +- python/mlc_chat/cli/compile.py | 87 +++++++++---------- python/mlc_chat/compiler/__init__.py | 9 +- python/mlc_chat/compiler/compile.py | 86 +++++++++++++++++- python/mlc_chat/compiler/model/model.py | 52 ++++++++--- .../compiler/quantization/__init__.py | 2 + .../compiler/quantization/quantization.py | 22 +++++ python/mlc_chat/support/auto_config.py | 44 ++++++++++ python/mlc_chat/support/auto_target.py | 27 ++++-- python/mlc_chat/support/auto_weight.py | 17 ++-- tests/python/model/test_llama.py | 8 +- tests/python/parameter/test_huggingface.py | 26 ++++-- .../python/{ => support}/test_auto_config.py | 3 +- .../python/{ => support}/test_auto_weight.py | 26 +++++- 15 files changed, 319 insertions(+), 95 deletions(-) create mode 100644 python/mlc_chat/compiler/quantization/__init__.py create mode 100644 python/mlc_chat/compiler/quantization/quantization.py rename tests/python/{ => support}/test_auto_config.py (95%) rename tests/python/{ => 
support}/test_auto_weight.py (69%) diff --git a/ci/task/mypy.sh b/ci/task/mypy.sh index abeec67b20..f241cf2c3c 100755 --- a/ci/task/mypy.sh +++ b/ci/task/mypy.sh @@ -11,4 +11,5 @@ set -x mypy ./python/mlc_chat/compiler \ ./python/mlc_chat/support \ ./tests/python/model \ - ./tests/python/parameter + ./tests/python/parameter \ + ./tests/python/support diff --git a/ci/task/pylint.sh b/ci/task/pylint.sh index 652b7f63a9..9dae28767d 100755 --- a/ci/task/pylint.sh +++ b/ci/task/pylint.sh @@ -12,4 +12,4 @@ set -x pip install --quiet --pre -U -f https://mlc.ai/wheels mlc-ai-nightly pylint --jobs $NUM_THREADS ./python/mlc_chat/compiler ./python/mlc_chat/support -pylint --jobs $NUM_THREADS --recursive=y ./tests/python/model ./tests/python/parameter +pylint --jobs $NUM_THREADS --recursive=y ./tests/python/model ./tests/python/parameter ./tests/python/support/ diff --git a/python/mlc_chat/cli/compile.py b/python/mlc_chat/cli/compile.py index bf976ef7f2..17b53797f4 100644 --- a/python/mlc_chat/cli/compile.py +++ b/python/mlc_chat/cli/compile.py @@ -5,10 +5,14 @@ from pathlib import Path from typing import Union -from mlc_chat.compiler.compile import compile # pylint: disable=redefined-builtin -from mlc_chat.compiler.model import MODELS, Model +from mlc_chat.compiler import ( # pylint: disable=redefined-builtin + MODELS, + QUANT, + OptimizationFlags, + compile, +) -from ..support.auto_config import detect_config +from ..support.auto_config import detect_config, detect_model_type from ..support.auto_target import detect_target_and_host logging.basicConfig( @@ -19,38 +23,22 @@ ) -def _parse_config(path: Union[str, Path]) -> Path: - try: - return detect_config(Path(path)) - except ValueError as err: - raise argparse.ArgumentTypeError(f"No valid config.json in: {path}. Error: {err}") - - -def _parse_output(path: Union[str, Path]) -> Path: - path = Path(path) - parent = path.parent - if not parent.is_dir(): - raise argparse.ArgumentTypeError(f"Directory does not exist: {parent}") - return path - +def main(): + """Parse command line argumennts and call `mlc_llm.compiler.compile`.""" -def _parse_model_type(model_type: str, config: Path) -> Model: - if model_type == "auto": - with open(config, "r", encoding="utf-8") as config_file: - cfg = json.load(config_file) - if "model_type" not in cfg: - raise ValueError( - f"'model_type' not found in: {config}. " - f"Please explicitly specify `--model-type` instead" - ) - model_type = cfg["model_type"] - if model_type not in MODELS: - raise ValueError(f"Unknown model type: {model_type}. Available ones: {list(MODELS.keys())}") - return MODELS[model_type] + def _parse_config(path: Union[str, Path]) -> Path: + try: + return detect_config(Path(path)) + except ValueError as err: + raise argparse.ArgumentTypeError(f"No valid config.json in: {path}. Error: {err}") + def _parse_output(path: Union[str, Path]) -> Path: + path = Path(path) + parent = path.parent + if not parent.is_dir(): + raise argparse.ArgumentTypeError(f"Directory does not exist: {parent}") + return path -def main(): - """Parse command line argumennts and call `mlc_llm.compiler.compile`.""" parser = argparse.ArgumentParser("MLC LLM Compiler") parser.add_argument( "--config", @@ -64,16 +52,8 @@ def main(): "--quantization", type=str, required=True, - choices=[ - "q0f16", - "q0f32", - "q3f16_1", - "q3f32_1", - "q4f16_1", - "q4f16_ft", - "q4f32_1", - ], - help="The quantization format. 
TBD", + choices=list(QUANT.keys()), + help="Quantization format.", ) parser.add_argument( "--model-type", @@ -106,9 +86,21 @@ def main(): ) parser.add_argument( "--opt", + type=OptimizationFlags.from_str, + default="", + help="Optimization flags. MLC LLM maintains a predefined set of optimization flags, " + "denoted as O0, O1, O2, O3, where O0 means no optimization, O2 means majority of them, " + "and O3 represents extreme optimization that could potentially break the system. " + "Meanwhile, optimization flags could be explicitly specified via details knobs, e.g. " + '--opt="cutlass_attn=1;cutlass_norm=0;cublas_gemm=0;cudagraph=0"', + ) + parser.add_argument( + "--prefix-symbols", type=str, default="", - help="Optimization flags.", + help='Adding a prefix to all symbols exported. Similar to "objcopy --prefix-symbols". ' + "This is useful when compiling multiple models into a single library to avoid symbol " + "conflicts. Differet from objcopy, this takes no effect for shared library.", ) parser.add_argument( "--output", @@ -117,15 +109,15 @@ def main(): required=True, help="The name of the output file. The suffix determines if the output file is a " "shared library or a static library. Available suffixes: " - "1) Linux: .so (shared), .a (static); " - "2) macOS: .dylib (shared), .a (static); " - "3) Windows: .dll (shared), .lib (static); " + "1) Linux: .so (shared), .tar (static); " + "2) macOS: .dylib (shared), .tar (static); " + "3) Windows: .dll (shared), .tar (static); " "4) Android, iOS: .tar (static); " "5) Web: .wasm (web assembly)", ) parsed = parser.parse_args() target, build_func = detect_target_and_host(parsed.device, parsed.host) - parsed.model_type = _parse_model_type(parsed.model_type, parsed.config) + parsed.model_type = detect_model_type(parsed.model_type, parsed.config) compile( config=parsed.config, quantization=parsed.quantization, @@ -133,6 +125,7 @@ def main(): target=target, opt=parsed.opt, build_func=build_func, + prefix_symbols=parsed.prefix_symbols, output=parsed.output, ) diff --git a/python/mlc_chat/compiler/__init__.py b/python/mlc_chat/compiler/__init__.py index 2206f480f6..c0f6c7e51b 100644 --- a/python/mlc_chat/compiler/__init__.py +++ b/python/mlc_chat/compiler/__init__.py @@ -2,4 +2,11 @@ A compiler for MLC Chat. By default, it is not imported to MLC Chat to avoid unnecessary dependency, but users could optionally import it if they want to use the compiler. """ -from . 
import model, parameter +from .compile import ( # pylint: disable=redefined-builtin + CompileArgs, + OptimizationFlags, + compile, +) +from .model import MODELS, Model +from .parameter import ExternMapping, HuggingFaceLoader, QuantizeMapping +from .quantization import QUANT diff --git a/python/mlc_chat/compiler/compile.py b/python/mlc_chat/compiler/compile.py index 2d68154bf3..8415ca21b8 100644 --- a/python/mlc_chat/compiler/compile.py +++ b/python/mlc_chat/compiler/compile.py @@ -1,4 +1,5 @@ """Python entrypoint of compilation.""" +import argparse import dataclasses import logging from io import StringIO @@ -15,15 +16,61 @@ @dataclasses.dataclass -class CompileArgs: +class OptimizationFlags: + """Optiization flags""" + + cutlass_attn: bool = True + cutlass_norm: bool = True + cublas_gemm: bool = False + cudagraph: bool = False + + def __repr__(self) -> str: + out = StringIO() + print(f"cutlass_attn={int(self.cutlass_attn)}", file=out, end="") + print(f";cutlass_norm={int(self.cutlass_norm)}", file=out, end="") + print(f";cublas_gemm={int(self.cublas_gemm)}", file=out, end="") + print(f";cudagraph={int(self.cudagraph)}", file=out, end="") + return out.getvalue().rstrip() + + @staticmethod + def from_str(source: str) -> "OptimizationFlags": + """Parse optimization flags from a string.""" + + if source in OPT_FLAG_PRESET: + return OPT_FLAG_PRESET[source] + + def boolean(value: str) -> bool: + if value == "0": + return False + if value == "1": + return True + raise ValueError(f"Invalid boolean value: {value}") + + parser = argparse.ArgumentParser(description="optimization flags") + parser.add_argument("--cutlass_attn", type=boolean, default=True) + parser.add_argument("--cutlass_norm", type=boolean, default=True) + parser.add_argument("--cublas_gemm", type=boolean, default=False) + parser.add_argument("--cudagraph", type=boolean, default=False) + results = parser.parse_args([f"--{i}" for i in source.split(";") if i]) + return OptimizationFlags( + cutlass_attn=results.cutlass_attn, + cutlass_norm=results.cutlass_norm, + cublas_gemm=results.cublas_gemm, + cudagraph=results.cudagraph, + ) + + +@dataclasses.dataclass +class CompileArgs: # pylint: disable=too-many-instance-attributes """Arguments to MLC LLM's compiler.""" config: Path quantization: str model_type: Model target: Target - opt: str + opt: OptimizationFlags build_func: Callable[[IRModule, "CompileArgs"], None] + prefix_symbols: str output: Path @@ -44,10 +91,41 @@ def compile( # pylint: disable=too-many-arguments,redefined-builtin quantization, model_type: Model, target: Target, - opt, + opt: OptimizationFlags, build_func: Callable[[IRModule, CompileArgs], None], + prefix_symbols: str, output: Path, ): """Compile a model given its configuration and quantization format to a specific target.""" - args = CompileArgs(config, quantization, model_type, target, opt, build_func, output) + args = CompileArgs( + config, quantization, model_type, target, opt, build_func, prefix_symbols, output + ) _echo_args(args) + + +OPT_FLAG_PRESET = { + "O0": OptimizationFlags( + cutlass_attn=False, + cutlass_norm=False, + cublas_gemm=False, + cudagraph=False, + ), + "O1": OptimizationFlags( + cutlass_attn=False, + cutlass_norm=True, + cublas_gemm=False, + cudagraph=False, + ), + "O2": OptimizationFlags( + cutlass_attn=True, + cutlass_norm=True, + cublas_gemm=False, + cudagraph=False, + ), + "O3": OptimizationFlags( + cutlass_attn=True, + cutlass_norm=True, + cublas_gemm=False, + cudagraph=True, + ), +} diff --git a/python/mlc_chat/compiler/model/model.py 
b/python/mlc_chat/compiler/model/model.py index 9c9db8c1f9..8fd041ef32 100644 --- a/python/mlc_chat/compiler/model/model.py +++ b/python/mlc_chat/compiler/model/model.py @@ -1,39 +1,63 @@ """A centralized registry of all existing model architures and their configurations.""" import dataclasses -from pathlib import Path -from typing import Any, Callable, Dict, Optional +from typing import Any, Callable, Dict from tvm.relax.frontend import nn from ..parameter import ExternMapping, QuantizeMapping +from ..quantization.quantization import QuantizeConfig from . import llama_config, llama_model, llama_parameter ModelConfig = Any -QuantizeConfig = Any +"""A ModelConfig is an object that represents a model architecture. It is required to have +a class method `from_file` with the following signature: -LoaderType = Callable[[ModelConfig, QuantizeConfig], ExternMapping] -QuantizerType = Callable[[ModelConfig, QuantizeConfig], QuantizeMapping] + def from_file(cls, path: Path) -> ModelConfig: + ... +""" + +FuncGetExternMap = Callable[[ModelConfig, QuantizeConfig], ExternMapping] +FuncGetQuantMap = Callable[[ModelConfig, QuantizeConfig], QuantizeMapping] @dataclasses.dataclass class Model: - """All about a model architecture: its configuration, its parameter loader and quantization.""" + """All about a model architecture: its configuration, its parameter loader and quantization. + + Parameters + ---------- + name : str + The name of the model. + + model : Callable[[ModelConfig], nn.Module] + A method that creates the `nn.Module` that represents the model from `ModelConfig`. + + config : ModelConfig + A class that has a `from_file` class method, whose signature is "Path -> ModelConfig". + + source : Dict[str, FuncGetExternMap] + A dictionary that maps the name of a source format to parameter mapping. + + quantize: Dict[str, FuncGetQuantMap] + A dictionary that maps the name of a quantization method to quantization mapping. 
+ """ name: str + config: ModelConfig model: Callable[[ModelConfig], nn.Module] - config: Callable[[Path], ModelConfig] - source_loader_huggingface: Optional[LoaderType] = None - source_loader_awq: Optional[LoaderType] = None - quantizer_group_quant: Optional[QuantizerType] = None + source: Dict[str, FuncGetExternMap] + quantize: Dict[str, FuncGetQuantMap] MODELS: Dict[str, Model] = { "llama": Model( name="llama", model=llama_model.LlamaForCasualLM, - config=llama_config.LlamaConfig.from_file, - source_loader_huggingface=llama_parameter.huggingface, - source_loader_awq=None, - quantizer_group_quant=None, + config=llama_config.LlamaConfig, + source={ + "huggingface-torch": llama_parameter.huggingface, + "huggingface-safetensor": llama_parameter.huggingface, + }, + quantize={}, ) } diff --git a/python/mlc_chat/compiler/quantization/__init__.py b/python/mlc_chat/compiler/quantization/__init__.py new file mode 100644 index 0000000000..ab352fc6c2 --- /dev/null +++ b/python/mlc_chat/compiler/quantization/__init__.py @@ -0,0 +1,2 @@ +"""A subpackage for quantization and dequantization algorithms""" +from .quantization import QUANT diff --git a/python/mlc_chat/compiler/quantization/quantization.py b/python/mlc_chat/compiler/quantization/quantization.py new file mode 100644 index 0000000000..c1ba794063 --- /dev/null +++ b/python/mlc_chat/compiler/quantization/quantization.py @@ -0,0 +1,22 @@ +"""A centralized registry of all existing quantization methods and their configurations.""" +from typing import Any, Dict + +QuantizeConfig = Any +"""A QuantizeConfig is an object that represents an quantization algorithm. It is required to +have the following fields: + + name : str + The name of the quantization algorithm, for example, "q4f16_1". + + kind : str + The kind of quantization algorithm, for example, "group_quant", "faster_transformer". + +It is also required to have the following method: + + def quantize(self, module: nn.Module) -> nn.Module: + ... +""" + +QUANT: Dict[str, QuantizeConfig] = { + "q4f16_1": None, +} diff --git a/python/mlc_chat/support/auto_config.py b/python/mlc_chat/support/auto_config.py index f5c33acb9f..165c0a0f20 100644 --- a/python/mlc_chat/support/auto_config.py +++ b/python/mlc_chat/support/auto_config.py @@ -1,9 +1,14 @@ """Help function for detecting the model configuration file `config.json`""" +import json import logging from pathlib import Path +from typing import TYPE_CHECKING from .style import green +if TYPE_CHECKING: + from mlc_chat.compiler import Model # pylint: disable=unused-import + logger = logging.getLogger(__name__) FOUND = green("Found") @@ -36,3 +41,42 @@ def detect_config(config_path: Path) -> Path: logger.info("%s model configuration: %s", FOUND, config_json_path) return config_json_path + + +def detect_model_type(model_type: str, config: Path) -> "Model": + """Detect the model type from the configuration file. If `model_type` is "auto", it will be + inferred from the configuration file. Otherwise, it will be used as the model type, and sanity + check will be performed. + + Parameters + ---------- + model_type : str + The model type, for example, "llama". + + config : pathlib.Path + The path to config.json. + + Returns + ------- + model : mlc_chat.compiler.Model + The model type. 
+ """ + + from mlc_chat.compiler import ( # pylint: disable=import-outside-toplevel + MODELS, + Model, + ) + + if model_type == "auto": + with open(config, "r", encoding="utf-8") as config_file: + cfg = json.load(config_file) + if "model_type" not in cfg: + raise ValueError( + f"'model_type' not found in: {config}. " + f"Please explicitly specify `--model-type` instead" + ) + model_type = cfg["model_type"] + logger.info("%s Model type: %s", FOUND, model_type) + if model_type not in MODELS: + raise ValueError(f"Unknown model type: {model_type}. Available ones: {list(MODELS.keys())}") + return MODELS[model_type] diff --git a/python/mlc_chat/support/auto_target.py b/python/mlc_chat/support/auto_target.py index d64c8f9daa..f31e813410 100644 --- a/python/mlc_chat/support/auto_target.py +++ b/python/mlc_chat/support/auto_target.py @@ -16,7 +16,6 @@ logger = logging.getLogger(__name__) # TODO: add help message on how to specify the target manually # pylint: disable=fixme -# TODO: revisit system_lib_prefix handling # pylint: disable=fixme # TODO: include host detection logic below after the new TVM build is done. # pylint: disable=fixme HELP_MSG = """TBD""" FOUND = green("Found") @@ -100,6 +99,19 @@ def _is_device(device: str): return True +def _add_prefix_symbol(mod: IRModule, prefix: str, is_system_lib: bool) -> IRModule: + if is_system_lib and prefix: + mod = mod.with_attr("system_lib_prefix", prefix) + elif is_system_lib: + logger.warning("--prefix-symbols is not specified when building a static library") + elif prefix: + logger.warning( + "--prefix-symbols is specified, but it will not take any effect " + "when building the shared library" + ) + return mod + + def _detect_target_from_device(device: str) -> Optional[Target]: try: target = Target.from_device(device) @@ -118,6 +130,7 @@ def _detect_target_from_device(device: str) -> Optional[Target]: def _build_metal_x86_64(): def build(mod: IRModule, args: "CompileArgs"): output = args.output + mod = _add_prefix_symbol(mod, args.prefix_symbols, is_system_lib=False) assert output.suffix == ".dylib" relax.build( mod, @@ -141,10 +154,10 @@ def compile_metal(src, target): def build(mod: IRModule, args: "CompileArgs"): output = args.output - system_lib_prefix = f"{args.model_type}_{args.quantization}_".replace("-", "_") + mod = _add_prefix_symbol(mod, args.prefix_symbols, is_system_lib=True) assert output.suffix == ".tar" relax.build( - mod.with_attr("system_lib_prefix", system_lib_prefix), + mod, target=args.target, system_lib=True, ).export_library( @@ -158,10 +171,10 @@ def build(mod: IRModule, args: "CompileArgs"): def _build_android(): def build(mod: IRModule, args: "CompileArgs"): output = args.output - system_lib_prefix = f"{args.model_type}_{args.quantization}_".replace("-", "_") + mod = _add_prefix_symbol(mod, args.prefix_symbols, is_system_lib=True) assert output.suffix == ".tar" relax.build( - mod.with_attr("system_lib_prefix", system_lib_prefix), + mod, target=args.target, system_lib=True, ).export_library( @@ -175,6 +188,7 @@ def build(mod: IRModule, args: "CompileArgs"): def _build_webgpu(): def build(mod: IRModule, args: "CompileArgs"): output = args.output + mod = _add_prefix_symbol(mod, args.prefix_symbols, is_system_lib=True) assert output.suffix == ".wasm" relax.build( mod, @@ -190,13 +204,14 @@ def build(mod: IRModule, args: "CompileArgs"): def _build_default(): def build(mod: IRModule, args: "CompileArgs"): output = args.output - if output.suffix in [".a", ".lib"]: + if output.suffix in [".tar", ".lib"]: system_lib = True elif 
output.suffix in [".so", ".dylib", ".dll"]: system_lib = False else: logger.warning("Unknown output suffix: %s. Assuming shared library.", output.suffix) system_lib = False + mod = _add_prefix_symbol(mod, args.prefix_symbols, is_system_lib=system_lib) relax.build( mod, target=args.target, diff --git a/python/mlc_chat/support/auto_weight.py b/python/mlc_chat/support/auto_weight.py index b19ec6b07a..042e7b5366 100644 --- a/python/mlc_chat/support/auto_weight.py +++ b/python/mlc_chat/support/auto_weight.py @@ -4,8 +4,13 @@ from pathlib import Path from typing import Tuple +from .style import green, red + logger = logging.getLogger(__name__) +FOUND = green("Found") +NOT_FOUND = red("Not found") + def detect_weight( weight_path: Path, @@ -61,7 +66,7 @@ def detect_weight( if not weight_path.exists(): raise ValueError(f"weight_path doesn't exist: {weight_path}") - logger.info("Loading weights from directory: %s", weight_path) + logger.info("%s weights from directory: %s", FOUND, weight_path) # check weight format # weight_format = "auto", guess the weight format. @@ -92,7 +97,7 @@ def _guess_weight_format(weight_path: Path): ) selected_format = possible_formats[0] - logging.info( + logger.info( "Using %s format now. Use `--weight-format` to manually specify the format.", selected_format, ) @@ -103,9 +108,9 @@ def _check_pytorch(weight_path: Path): pytorch_json_path = weight_path / "pytorch_model.bin.index.json" result = pytorch_json_path.exists() if result: - logger.info("[Y] Found Huggingface PyTorch: %s", pytorch_json_path) + logger.info("%s Huggingface PyTorch: %s", FOUND, pytorch_json_path) else: - logger.info("[X] Not found: Huggingface PyTorch") + logger.info("%s Huggingface PyTorch", NOT_FOUND) return result @@ -113,9 +118,9 @@ def _check_safetensor(weight_path: Path): safetensor_json_path = weight_path / "model.safetensors.index.json" result = safetensor_json_path.exists() if result: - logger.info("[Y] Found SafeTensor: %s", safetensor_json_path) + logger.info("%s SafeTensor: %s", FOUND, safetensor_json_path) else: - logger.info("[X] Not found: SafeTensor") + logger.info("%s SafeTensor", NOT_FOUND) return result diff --git a/tests/python/model/test_llama.py b/tests/python/model/test_llama.py index cb77a59b71..9e75247c32 100644 --- a/tests/python/model/test_llama.py +++ b/tests/python/model/test_llama.py @@ -1,13 +1,13 @@ # pylint: disable=invalid-name,missing-docstring import pytest -from mlc_chat.compiler.model.llama_config import LlamaConfig -from mlc_chat.compiler.model.llama_model import LlamaForCasualLM +from mlc_chat.compiler import MODELS @pytest.mark.parametrize("model_name", ["llama2_7b", "llama2_13b", "llama2_70b"]) def test_llama2_creation(model_name: str): - config = LlamaConfig.from_predefined(model_name) - model = LlamaForCasualLM(config) + model_info = MODELS["llama"] + config = model_info.config.from_predefined(model_name) + model = model_info.model(config) mod, named_params = model.export_tvm(spec=model.get_default_spec()) mod.show(black_format=False) for name, param in named_params: diff --git a/tests/python/parameter/test_huggingface.py b/tests/python/parameter/test_huggingface.py index 851bd42aea..ecd8e16455 100644 --- a/tests/python/parameter/test_huggingface.py +++ b/tests/python/parameter/test_huggingface.py @@ -4,8 +4,10 @@ from typing import Union import pytest -from mlc_chat.compiler.model.llama_config import LlamaConfig -from mlc_chat.compiler.model.llama_parameter import huggingface +from mlc_chat.compiler import MODELS + +# from 
mlc_chat.compiler.model.llama_config import LlamaConfig +# from mlc_chat.compiler.model.llama_parameter import huggingface from mlc_chat.compiler.parameter import HuggingFaceLoader from mlc_chat.support import tqdm @@ -30,11 +32,15 @@ def test_load_torch_llama(base_path: Union[str, Path]): path_config = base_path / "config.json" path_params = base_path / "pytorch_model.bin.index.json" - config = LlamaConfig.from_file(path_config) - loader = HuggingFaceLoader(path=path_params, extern_param_map=huggingface(config, None)) + model = MODELS["llama"] + config = model.config.from_file(path_config) + loader = HuggingFaceLoader( + path=path_params, + extern_param_map=model.source["huggingface-torch"](config, None), + ) with tqdm.redirect(): for _name, _param in loader.load(): - ... + return # To reduce the time of the test @pytest.mark.parametrize( @@ -50,11 +56,15 @@ def test_load_safetensor_llama(base_path: Union[str, Path]): path_config = base_path / "config.json" path_params = base_path / "model.safetensors.index.json" - config = LlamaConfig.from_file(path_config) - loader = HuggingFaceLoader(path=path_params, extern_param_map=huggingface(config, None)) + model = MODELS["llama"] + config = model.config.from_file(path_config) + loader = HuggingFaceLoader( + path=path_params, + extern_param_map=model.source["huggingface-safetensor"](config, None), + ) with tqdm.redirect(): for _name, _param in loader.load(): - ... + return # To reduce the time of the test if __name__ == "__main__": diff --git a/tests/python/test_auto_config.py b/tests/python/support/test_auto_config.py similarity index 95% rename from tests/python/test_auto_config.py rename to tests/python/support/test_auto_config.py index 6209186c32..540c544c22 100644 --- a/tests/python/test_auto_config.py +++ b/tests/python/support/test_auto_config.py @@ -41,4 +41,5 @@ def test_detect_config_fail(): if __name__ == "__main__": - pass + test_detect_config() + test_detect_config_fail() diff --git a/tests/python/test_auto_weight.py b/tests/python/support/test_auto_weight.py similarity index 69% rename from tests/python/test_auto_weight.py rename to tests/python/support/test_auto_weight.py index a0363ed1c4..2987135267 100644 --- a/tests/python/test_auto_weight.py +++ b/tests/python/support/test_auto_weight.py @@ -95,10 +95,32 @@ def test_find_weight_fail(): base_path = Path(tmpdir) with pytest.raises(ValueError): detect_weight(Path("do/not/exist"), base_path, "AWQ") - with pytest.raises(AssertionError): detect_weight(None, Path("do/not/exist"), "AWQ") if __name__ == "__main__": - pass + test_detect_weight("PyTorch", "pytorch_model.bin.index.json", "PyTorch") + test_detect_weight("SafeTensor", "model.safetensors.index.json", "SafeTensor") + test_detect_weight("GGML", None, "GGML") + test_detect_weight("GGUF", None, "GGUF") + test_detect_weight("AWQ", None, "AWQ") + test_detect_weight("auto", "pytorch_model.bin.index.json", "PyTorch") + test_detect_weight("auto", "model.safetensors.index.json", "SafeTensor") + test_detect_weight_in_config_json("PyTorch", "pytorch_model.bin.index.json", "PyTorch") + test_detect_weight_in_config_json("SafeTensor", "model.safetensors.index.json", "SafeTensor") + test_detect_weight_in_config_json("GGML", None, "GGML") + test_detect_weight_in_config_json("GGUF", None, "GGUF") + test_detect_weight_in_config_json("AWQ", None, "AWQ") + test_detect_weight_in_config_json("auto", "pytorch_model.bin.index.json", "PyTorch") + test_detect_weight_in_config_json("auto", "model.safetensors.index.json", "SafeTensor") + 
test_detect_weight_same_dir_config_json("PyTorch", "pytorch_model.bin.index.json", "PyTorch") + test_detect_weight_same_dir_config_json( + "SafeTensor", "model.safetensors.index.json", "SafeTensor" + ) + test_detect_weight_same_dir_config_json("GGML", None, "GGML") + test_detect_weight_same_dir_config_json("GGUF", None, "GGUF") + test_detect_weight_same_dir_config_json("AWQ", None, "AWQ") + test_detect_weight_same_dir_config_json("auto", "pytorch_model.bin.index.json", "PyTorch") + test_detect_weight_same_dir_config_json("auto", "model.safetensors.index.json", "SafeTensor") + test_find_weight_fail() From 973f9fcdbd7810f7034442a23e6af940f5524117 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 25 Oct 2023 10:14:46 -0500 Subject: [PATCH 054/116] [ParamManager][Redo] Use BundleModelParams for transform_dequantize (#1127) Prior to this commit, `ParamManager.transform_quantize` function took as input functions with separate parameters for each weight tensor, and produced output functions with a tuple parameter for all weights. Because `LiftTransformParams` had the same convention, neither could be applied as part of the same build flow. This commit updates `ParamManager.transform_quantize` pass to produce outputs with separate tensor parameters, using the `BundleModelParams` transform to later combine them into a single tuple parameter. The analogous change was also performed for `LiftTransformParams` as part of https://github.com/apache/tvm/pull/15657. In addition, prior to this commit, the `ParamManager.transform_dequantize` function operated directly on a `IRModule` object. As a result, any debug instrumentation (e.g. before/after printouts for each pass, before/after verification with `relax.analysis.well_formed`, etc.) did not apply to this `transform_dequantize`. This commit updates `ParamManager.transform_dequantize` to return a `ir.transform.Pass`. This commit is a repeat of the reverted PR https://github.com/mlc-ai/mlc-llm/pull/1056. This PR resolves the bug in the earlier implementation by removing the call to `.without_attr("num_input")` in `ParamReplacer.rewrite_func`. This follows an analogous update in `LiftTransformParams`, preserving the `"num_input"` attribute for use in `BundleModelParams`. --- mlc_llm/core.py | 3 +- mlc_llm/relax_model/param_manager.py | 116 +++++++++++++-------------- 2 files changed, 56 insertions(+), 63 deletions(-) diff --git a/mlc_llm/core.py b/mlc_llm/core.py index 8c3a75c374..e720d19542 100644 --- a/mlc_llm/core.py +++ b/mlc_llm/core.py @@ -420,7 +420,8 @@ def mod_transform_before_build( if args.model.lower().startswith("rwkv-"): model_names += ["reset_kv_cache"] - mod = param_manager.transform_dequantize(mod) + mod = param_manager.transform_dequantize()(mod) + mod = relax.transform.BundleModelParams()(mod) use_ft_quant = args.quantization.name in ["q4f16_ft", "q8f16_ft"] mod = mlc_llm.transform.FuseDecodeTranspose(skip_gemm=not use_ft_quant)(mod) diff --git a/mlc_llm/relax_model/param_manager.py b/mlc_llm/relax_model/param_manager.py index 04f56a5152..7f0751b2a0 100644 --- a/mlc_llm/relax_model/param_manager.py +++ b/mlc_llm/relax_model/param_manager.py @@ -369,7 +369,7 @@ def set_param_loading_func( else: self.pidx2pname = dict() - def transform_dequantize(self, mod: tvm.IRModule) -> tvm.IRModule: + def transform_dequantize(self) -> tvm.ir.transform.Pass: """Apply dequantization to the input IRModule. 
Parameters @@ -386,38 +386,48 @@ def transform_dequantize(self, mod: tvm.IRModule) -> tvm.IRModule: The IRModule updated with the dequantization computation. """ - # For each Relax function in the input IRModule (e.g., "prefill"), - # we create its input relax.Var of all the quantized data, and - # store the mapping from function name to the var. - func2param_var: Dict[str, relax.Var] = {} - for gv, func in mod.functions.items(): - if not isinstance(func, relax.Function): - continue - if func.attrs is None or not "num_input" in func.attrs: - continue - func2param_var[gv.name_hint] = relax.Var( - "params", self.get_quantized_param_info(gv.name_hint) - ) + @tvm.ir.transform.module_pass(opt_level=0, name="ParamManager.transform_dequantize") + def transform_func(mod: tvm.IRModule, _context) -> tvm.IRModule: + # For each Relax function in the input IRModule (e.g., "prefill"), + # we create its input relax.Var of all the quantized data, and + # store the mapping from function name to the var. + func_name_to_quantized_params: Dict[str, List[relax.Var]] = {} - # Cache mapping to avoid duplicate dequantization. - dequantized_cache: Dict[relax.Var, relax.Var] = {} + for gv, func in mod.functions.items(): + if isinstance(func, relax.Function) and func.attrs and "num_input" in func.attrs: + quantized_param_info = self.get_quantized_param_info(gv.name_hint) + param_vars = [ + relax.Var(f"param_{i}", info) + for i, info in enumerate(quantized_param_info.fields) + ] + func_name_to_quantized_params[gv.name_hint] = param_vars - # Define a var replacement function for applying dequantization. - def f_replace(var: relax.Var, bb: relax.BlockBuilder, func_name: str) -> relax.Var: - if var in dequantized_cache: - return dequantized_cache[var] - assert var in self.func_raw_param_map - func_name, param = self.func_raw_param_map[var] - dequantized = self._dequantize(param, func2param_var[func_name], bb, func_name) - dequantized_cache[var] = dequantized - return dequantized + # Cache mapping to avoid duplicate dequantization. + dequantized_cache: Dict[relax.Var, relax.Var] = {} - # Create the function mutator for applying dequantization. - replacer = ParamReplacer(mod, func2param_var, f_replace) - # Update the input IRModule with dequantization. - mod = replacer.transform() + # Define a var replacement function for applying dequantization. + def f_replace(var: relax.Var, bb: relax.BlockBuilder) -> relax.Var: + if var in dequantized_cache: + return dequantized_cache[var] + assert var in self.func_raw_param_map - return mod + func_name, param = self.func_raw_param_map[var] + quantized_params = func_name_to_quantized_params[func_name] + relevant_quantized_params = [quantized_params[i] for i in self.param2qrange[param]] + + dequantized = self._dequantize(param, relevant_quantized_params, bb, func_name) + + dequantized_cache[var] = dequantized + return dequantized + + # Create the function mutator for applying dequantization. + replacer = ParamReplacer(mod, func_name_to_quantized_params, f_replace) + # Update the input IRModule with dequantization. + mod = replacer.transform() + + return mod + + return transform_func def get_quantized_param_info(self, func_name: str) -> List[relax.TensorStructInfo]: bb = relax.BlockBuilder() @@ -697,10 +707,9 @@ def _register_param( def _dequantize( self, param: Parameter, - quantized_tuple: relax.Var, + qparams: List[relax.Var], bb: relax.BlockBuilder, func_name: str, - qparams: List[relax.Var] = None, ) -> relax.Var: """Applying dequantization to the input parameter. 
This method is called by `transform_module` below, and is not @@ -711,30 +720,13 @@ def _dequantize( param : Parameter The parameter whose quantized tensors are to be dequantized. - quantized_tuple : relax.Var - The relax.Var of the quantized tensors of all parameters in the model. - - bb : relax.BlockBuilder - The Relax BlockBuilder used for inserting the dequantization computations. - - func_name : str - The name of the function which dequantization is applied to. - qparams : List[relax.Var] - The quantized parts of the parameter. - By default it is `None`, in which case we will get the quantized parts - from `quantized_tuple`. + The relax.Var of the quantized tensors of all parameters in the model. Returns ------- The dequantized parameter, in the form of a relax.Var. """ - if not qparams: - # Get the corresponding Relax vars of the quantized tensors of this parameter. - qparams: List[relax.Var] = [] - for qparam_idx in self.param2qrange[param]: - qparams.append(bb.emit(relax.TupleGetItem(quantized_tuple, qparam_idx))) - # Get the dequantization function of this parameter. f_dequantize = param.quant_spec.get_dequantize_func( param_info=param.param_info_dict[func_name], @@ -789,7 +781,7 @@ class ParamReplacer(PyExprMutator): mod : tvm.IRModule The IRModule of the model to be updated. - func2param_var : Dict[str, relax.Var] + func_name_to_quantized_params : Dict[str, List[relax.Var]] The mapping from each function name to its input var of quantized data tuple. f_replace : Callable[[relax.Var, relax.BlockBuilder], relax.Var] @@ -801,7 +793,7 @@ class ParamReplacer(PyExprMutator): """ mod: tvm.IRModule - func2param_var: Dict[str, relax.Var] + func_name_to_quantized_params: Dict[str, List[relax.Var]] f_replace: Callable[[relax.Var, relax.BlockBuilder], relax.Var] param_set: Set[relax.Var] @@ -810,12 +802,12 @@ class ParamReplacer(PyExprMutator): def __init__( self, mod: tvm.IRModule, - func2param_var: Dict[str, relax.Var], + func_name_to_quantized_params: Dict[str, relax.Var], f_replace: Callable[[relax.Var, relax.BlockBuilder], relax.Var], ): super().__init__(mod) self.mod = mod - self.func2param_var = func2param_var + self.func_name_to_quantized_params = func_name_to_quantized_params self.f_replace = f_replace self.cur_func_name = "" @@ -827,31 +819,31 @@ def transform(self) -> tvm.IRModule: continue assert ( - gv.name_hint in self.func2param_var - ), f"{gv.name_hint} not in {self.func2param_var}" - self.cur_func_name = gv.name_hint - updated_func = self.rewrite_func(func, self.func2param_var[gv.name_hint]) + gv.name_hint in self.func_name_to_quantized_params + ), f"{gv.name_hint} not in {self.func_name_to_quantized_params}" + updated_func = self.rewrite_func(func, self.func_name_to_quantized_params[gv.name_hint]) updated_func = remove_all_unused(updated_func) self.builder_.update_func(gv, updated_func) return self.builder_.get() - def rewrite_func(self, func: Function, param_var: relax.Var) -> relax.Function: + def rewrite_func(self, func: Function, quantized_params: List[relax.Var]) -> relax.Function: num_input = int(func.attrs["num_input"]) self.param_set = set(func.params[num_input:]) body = self.visit_expr(func.body) return relax.Function( - params=func.params[:num_input] + [param_var], + params=func.params[:num_input] + quantized_params, body=body, ret_struct_info=func.ret_struct_info, is_pure=func.is_pure, attrs=func.attrs, - ).without_attr("num_input") + ) def visit_var_(self, var: Var) -> Expr: - if var not in self.param_set: + if var in self.param_set: + return 
self.f_replace(var, self.builder_) + else: return super().visit_var_(var) - return self.f_replace(var, self.builder_, self.cur_func_name) ################################################################## From 24f795e0a9e184ece8df9888abc64ede6238de3f Mon Sep 17 00:00:00 2001 From: Goutham Tamilselvan Date: Fri, 27 Oct 2023 03:25:59 -0400 Subject: [PATCH 055/116] added details to windows installation (#1133) 32bit version of the zstd.dll library was causing issues, so updated the doc to be more specific and download the 64bit version. --- docs/install/mlc_llm.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/install/mlc_llm.rst b/docs/install/mlc_llm.rst index f95cc3ee9c..13fc373dbf 100644 --- a/docs/install/mlc_llm.rst +++ b/docs/install/mlc_llm.rst @@ -124,10 +124,10 @@ Select your operating system/compute platform and run the command in your termin FileNotFoundError: Could not find module 'path\to\site-packages\tvm\tvm.dll' (or one of its dependencies). Try using the full path with constructor syntax. - It is likely `zstd`, a dependency to LLVM, was missing. Please `download `__ the precompiled binary, rename it to `zstd.dll` and copy to the same folder as `tvm.dll`. + It is likely `zstd`, a dependency to LLVM, was missing. Please `download `__ the 64 bit version of precompiled binary, rename it to `zstd.dll` and copy to the same folder as `tvm.dll`. Option 2. Build from Source --------------------------- -Upcoming. \ No newline at end of file +Upcoming. From 2c492e54d76a34cd159b0ed5ffa9c661edc41a6e Mon Sep 17 00:00:00 2001 From: S A G A R <110724849+tmsagarofficial@users.noreply.github.com> Date: Sun, 29 Oct 2023 03:43:15 +0530 Subject: [PATCH 056/116] Grammatical and Typographical improvements (#1139) * Update faq.rst * Update guideline.rst * Update compile_models.rst * Update distribute_compiled_models.rst * Update get-vicuna-weight.rst * Update python.rst * Update android.rst * Update cli.rst * Update ios.rst * Update javascript.rst * Update python.rst * Update rest.rst --- docs/community/faq.rst | 2 +- docs/community/guideline.rst | 20 +++++++++---------- docs/compilation/compile_models.rst | 20 +++++++++---------- .../distribute_compiled_models.rst | 6 +++--- docs/compilation/get-vicuna-weight.rst | 8 ++++---- docs/compilation/python.rst | 8 ++++---- docs/deploy/android.rst | 10 +++++----- docs/deploy/cli.rst | 12 +++++------ docs/deploy/ios.rst | 16 +++++++-------- docs/deploy/javascript.rst | 4 ++-- docs/deploy/python.rst | 12 +++++------ docs/deploy/rest.rst | 4 ++-- 12 files changed, 61 insertions(+), 61 deletions(-) diff --git a/docs/community/faq.rst b/docs/community/faq.rst index f426a0c624..3913dd9639 100644 --- a/docs/community/faq.rst +++ b/docs/community/faq.rst @@ -5,7 +5,7 @@ Frequently Asked Questions This is a list of Frequently Asked Questions (FAQ) about the MLC-LLM. Feel free to suggest new entries! -... How can I customize the temperature, repetition penalty of models? +... How can I customize the temperature, and repetition penalty of models? Please check our :doc:`/get_started/mlc_chat_config` tutorial. ... What's the quantization algorithm MLC-LLM using? diff --git a/docs/community/guideline.rst b/docs/community/guideline.rst index eac77101e9..38a03e463e 100644 --- a/docs/community/guideline.rst +++ b/docs/community/guideline.rst @@ -42,11 +42,11 @@ Ready to contribute to MLC-LLM? Awesome! 
We are excited to see you are ready to The standard way to make changes to MLC-LLM code base is through creating a `pull-request `__, and we will review your code and merge it to the code base when it is ready. -The first step to become a developer is to `fork `__ the repository to your own +The first step to becoming a developer is to `fork `__ the repository to your own github account, you will notice a repository under ``https://github.com/username/mlc-llm`` where ``username`` is your github user name. You can clone your fork to your local machine and commit changes, or edit the contents of your fork (in the case you are just fixing typos) -on github directly. Once your update is complete, you can click the ``contribute`` button and open a pull request to the main repository. +on GitHub directly. Once your update is complete, you can click the ``contribute`` button and open a pull request to the main repository. .. _contribute-new-models: @@ -86,14 +86,14 @@ Fo your convenience, you can use `clang-format `__ to acknowledge contributors, -please let us know if you contribute to the project and your name is not included in the list. +please let us know if you contribute to the project and if your name is not included in the list. diff --git a/docs/compilation/compile_models.rst b/docs/compilation/compile_models.rst index b5f1044b75..98c7f2d156 100644 --- a/docs/compilation/compile_models.rst +++ b/docs/compilation/compile_models.rst @@ -4,14 +4,14 @@ Compile Models via MLC ====================== This page describes how to compile a model with MLC LLM. Model compilation takes model inputs, produces quantized model weights, -and optimized model lib for a given platform. It enables users to bring their own new model weights, try different quantization modes, +and optimizes model lib for a given platform. It enables users to bring their own new model weights, try different quantization modes, and customize the overall model optimization flow. .. note:: Before you proceed, please make sure that you have :ref:`install-tvm-unity` correctly installed on your machine. TVM-Unity is the necessary foundation for us to compile models with MLC LLM. If you want to build webgpu, please also complete :ref:`install-web-build`. - Please also follow the instruction in :ref:`deploy-cli` to obtain the CLI app that can be used to chat with the compiled model. + Please also follow the instructions in :ref:`deploy-cli` to obtain the CLI app that can be used to chat with the compiled model. Finally, we strongly recommend you read :ref:`project-overview` first to get familiarized with the high-level terminologies. @@ -25,7 +25,7 @@ Install MLC-LLM Package Work with Source Code ^^^^^^^^^^^^^^^^^^^^^ -The easiest way is to use MLC-LLM is to clone the repository, and compile models under the root directory of the repository. +The easiest way to use MLC-LLM is to clone the repository, and compile models under the root directory of the repository. .. code:: bash @@ -106,7 +106,7 @@ your personal computer. xcrun: error: unable to find utility "metallib", not a developer tool or in PATH , please check and make sure you have Command Line Tools for Xcode installed correctly. - You can use ``xcrun metal`` to validate: when it prints ``metal: error: no input files``, it means the Command Line Tools for Xcode is installed and can be found, and you can proceed the model compiling. 
+ You can use ``xcrun metal`` to validate: when it prints ``metal: error: no input files``, it means the Command Line Tools for Xcode is installed and can be found, and you can proceed with the model compiling. .. group-tab:: Android @@ -172,7 +172,7 @@ We can check the output with the commands below: tokenizer_config.json We now chat with the model using the command line interface (CLI) app. - Follow the build from source instruction + Follow the build from the source instruction .. code:: shell @@ -271,7 +271,7 @@ We can check the output with the commands below: tokenizer_config.json The model lib ``dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1/RedPajama-INCITE-Chat-3B-v1-q4f16_1-webgpu.wasm`` - can be uploaded to internet. You can pass a ``model_lib_map`` field to WebLLM app config to use this library. + can be uploaded to the internet. You can pass a ``model_lib_map`` field to WebLLM app config to use this library. Each compilation target produces a specific model library for the given platform. The model weight is shared across @@ -311,7 +311,7 @@ In other cases you need to specify the model via ``--model``. - ``dist/models/MODEL_NAME_OR_PATH`` (e.g., ``--model Llama-2-7b-chat-hf``), - ``MODEL_NAME_OR_PATH`` (e.g., ``--model /my-model/Llama-2-7b-chat-hf``). - When running the compile command using ``--model``, please make sure you have placed the model to compile under ``dist/models/`` or other location on the disk. + When running the compile command using ``--model``, please make sure you have placed the model to compile under ``dist/models/`` or another location on the disk. --hf-path HUGGINGFACE_NAME The name of the model's Hugging Face repository. We will download the model to ``dist/models/HUGGINGFACE_NAME`` and load the model from this directory. @@ -336,11 +336,11 @@ The following arguments are optional: we will use the maximum sequence length from the ``config.json`` in the model directory. --reuse-lib LIB_NAME Specifies the previously generated library to reuse. This is useful when building the same model architecture with different weights. - You can refer to the :ref:`model distribution ` page for detail of this argument. + You can refer to the :ref:`model distribution ` page for details of this argument. --use-cache When ``--use-cache=0`` is specified, the model compilation will not use cached file from previous builds, and will compile the model from the very start. - Using cache can help reduce the time needed to compile. + Using a cache can help reduce the time needed to compile. --debug-dump Specifies whether to dump debugging files during compilation. --use-safetensors Specifies whether to use ``.safetensors`` instead of the default ``.bin`` when loading in model weights. @@ -354,7 +354,7 @@ This section lists compile commands for more models that you can try out. .. tab:: Model: Llama-2-7B Please `request for access `_ to the Llama-2 weights from Meta first. - After granted the access, please create directory ``dist/models`` and download the model to the directory. + After granted access, please create directory ``dist/models`` and download the model to the directory. For example, you can run the following code: .. 
code:: shell diff --git a/docs/compilation/distribute_compiled_models.rst b/docs/compilation/distribute_compiled_models.rst index 96ac5a09a3..69dc0e847d 100644 --- a/docs/compilation/distribute_compiled_models.rst +++ b/docs/compilation/distribute_compiled_models.rst @@ -67,7 +67,7 @@ You can **optionally** customize the chat config file ``dist/RedPajama-INCITE-Instruct-3B-v1-q4f16_1/params/mlc-chat-config.json`` (checkout :ref:`configure-mlc-chat-json` for more detailed instructions). You can also simply use the default configuration and skip this step. -For demonstration purpose, we update ``mean_gen_len`` to 32 and ``max_gen_len`` to 64. +For demonstration purposes, we update ``mean_gen_len`` to 32 and ``max_gen_len`` to 64. We also update ``conv_template`` to ``"LM"`` because the model is instruction-tuned. @@ -160,7 +160,7 @@ Download the Distributed Models and Run in iOS App -------------------------------------------------- For iOS app, model libraries are statically packed into the app at the time of app building. -Therefore, the iOS app supports running any models whose model libraries are integrated into the app. +Therefore, the iOS app supports running any model whose model libraries are integrated into the app. You can check the :ref:`list of supported model libraries `. To download and run the compiled RedPajama-3B instruct model on iPhone, we need to reuse the integrated ``RedPajama-INCITE-Chat-3B-v1-q4f16_1`` model library. @@ -198,7 +198,7 @@ Now we can download the model weights in iOS app and run the model by following .. tab:: Step 4 - When the download is finished, click into the model and enjoy. + When the download is finished, click on the model and enjoy. .. image:: https://raw.githubusercontent.com/mlc-ai/web-data/main/images/mlc-llm/tutorials/iPhone-distribute-4.jpeg :align: center diff --git a/docs/compilation/get-vicuna-weight.rst b/docs/compilation/get-vicuna-weight.rst index 0cc42380e9..2ea4ba5d97 100644 --- a/docs/compilation/get-vicuna-weight.rst +++ b/docs/compilation/get-vicuna-weight.rst @@ -5,7 +5,7 @@ Getting Vicuna Weights :local: :depth: 2 -`Vicuna `_ is a open-source chatbot trained by fine-tuning `LLaMA `_ on `ShartGPT `_ data. +`Vicuna `_ is an open-source chatbot trained by fine-tuning `LLaMA `_ on `ShartGPT `_ data. Please note that the official Vicuna weights are delta weights applied to the LLaMA weights in order to comply with the LLaMA license. Users are responsible for applying these delta weights themselves. @@ -14,7 +14,7 @@ In this tutorial, we will show how to apply the delta weights to LLaMA weights t Install FastChat ---------------- -FastChat offers convenient utility functions for applying delta to LLaMA weights. You can easily install it using pip. +FastChat offers convenient utility functions for applying the delta to LLaMA weights. You can easily install it using pip. .. code-block:: bash @@ -38,14 +38,14 @@ Then download the weights (both the LLaMA weight and Vicuna delta weight): git clone https://huggingface.co/lmsys/vicuna-7b-delta-v1.1 -There is a name mis-alignment issue in the LLaMA weights and Vicuna delta weights. +There is a name misalignment issue in the LLaMA weights and Vicuna delta weights. Please follow these steps to modify the content of the "config.json" file: .. code-block:: bash sed -i 's/LLaMAForCausalLM/LlamaForCausalLM/g' llama-7b-hf/config.json -Then use ``fschat`` to apply delta to LLaMA weights +Then use ``fschat`` to apply the delta to LLaMA weights .. 
code-block:: bash diff --git a/docs/compilation/python.rst b/docs/compilation/python.rst index 99486a751b..98e4f934e7 100644 --- a/docs/compilation/python.rst +++ b/docs/compilation/python.rst @@ -5,8 +5,8 @@ Python API for Model Compilation :local: :depth: 2 -We expose Python API for compiling/building model in the package :py:mod:`mlc_llm`, so -that users may build model in any directory in their program (i.e. not just +We expose Python API for compiling/building models in the package :py:mod:`mlc_llm`, so +that users may build a model in any directory in their program (i.e. not just within the mlc-llm repo). Install MLC-LLM as a Package @@ -44,7 +44,7 @@ After installing the package, you can build the model using :meth:`mlc_llm.build which takes in an instance of :class:`BuildArgs` (a dataclass that represents the arguments for building a model). -For detailed instruction with code, please refer to `the python notebook +For detailed instructions with code, please refer to `the Python notebook `_ (executable in Colab), where we walk you through compiling Llama-2 with :py:mod:`mlc_llm` in Python. @@ -56,7 +56,7 @@ API Reference In order to use the python API :meth:`mlc_llm.build_model`, users need to create an instance of the dataclass :class:`BuildArgs`. The corresponding arguments in -command line shown in :ref:`compile-command-specification` are automatically +the command line shown in :ref:`compile-command-specification` are automatically converted from the definition of :class:`BuildArgs` and are equivalent. Then with an instantiated :class:`BuildArgs`, users can call the build API diff --git a/docs/deploy/android.rst b/docs/deploy/android.rst index c26e9f3445..0c2ed8535f 100644 --- a/docs/deploy/android.rst +++ b/docs/deploy/android.rst @@ -98,11 +98,11 @@ To deploy models on Android with reasonable performance, one has to cross-compil --model ./dist/models/$MODEL_NAME \ --quantization $QUANTIZATION -This generates directory ``./dist/$MODEL_NAME-$QUANTIZATION`` which contains the necessary components to run the model, as explained below. +This generates the directory ``./dist/$MODEL_NAME-$QUANTIZATION`` which contains the necessary components to run the model, as explained below. **Expected output format**. By default models are placed under ``./dist/${MODEL_NAME}-${QUANTIZATION}``, and the result consists of 3 major components: -- Runtime configuration: It configures conversation templates including system prompts, repetition repetition penalty, sampling including temperature and top-p probability, maximum sequence length, etc. It is usually named as ``mlc-chat-config.json`` under ``params/`` along side with tokenizer configurations. +- Runtime configuration: It configures conversation templates including system prompts, repetition repetition penalty, sampling including temperature and top-p probability, maximum sequence length, etc. It is usually named as ``mlc-chat-config.json`` under ``params/``alongside with tokenizer configurations. - Model lib: The compiled library that uses mobile GPU. It is usually named as ``${MODEL_NAME}-${QUANTIZATION}-android.tar``, for example, ``Llama-2-7b-chat-hf-q4f16_0-android.tar``. - Model weights: the model weights are sharded as ``params_shard_*.bin`` under ``params/`` and the metadata is stored in ``ndarray-cache.json``. @@ -144,16 +144,16 @@ The model execution logic in mobile GPUs is incorporated into ``libtvm4j_runtime **Build the Android app**. Open folder ``./android/MLCChat`` as an Android Studio Project. 
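As a companion to the `python.rst` hunk above, the following is a rough sketch of how the documented `BuildArgs`/`build_model` pair is intended to be used; the field values are placeholders, and the exact set of accepted fields should be checked against `mlc_llm.core.BuildArgs` rather than taken from this sketch.

    from mlc_llm import BuildArgs, build_model

    # Placeholder values: any model placed under dist/models/ and any supported
    # quantization/target combination can be substituted here.
    args = BuildArgs(
        model="Llama-2-7b-chat-hf",
        quantization="q4f16_1",
        target="cuda",
    )
    build_model(args)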
Connect your Android device to your machine. In the menu bar of Android Studio, click "Build → Make Project". Once the build is finished, click "Run → Run 'app'" and you will see the app launched on your phone. .. note:: - ❗ This app cannot be run in an emulator and thus a physical phone is required, because MLC LLM needs an actual mobile GPU to meaningfully run at accelerated speed. + ❗ This app cannot be run in an emulator and thus a physical phone is required, because MLC LLM needs an actual mobile GPU to meaningfully run at an accelerated speed. Incorporate Model Weights ------------------------- Instructions have been provided to build an Android App with MLC LLM in previous sections, but it requires run-time weight downloading from HuggingFace, as configured in `app-config.json` in previous steps under `model_url`. However, it could be desirable to bundle weights together into the app to avoid downloading over the network. In this section, we provide a simple ADB-based walkthrough that hopefully helps with further development. -**Generating APK**. Enter Android Studio, click "Build → Generate Signed Bundle/APK" to build an APK for release. If it is the first time you generate an APK, you will need to create a key according to `the official guide from Android `_. This APK will be placed under ``android/MLCChat/app/release/app-release.apk``. +**Generating APK**. Enter Android Studio, and click "Build → Generate Signed Bundle/APK" to build an APK for release. If it is the first time you generate an APK, you will need to create a key according to `the official guide from Android `_. This APK will be placed under ``android/MLCChat/app/release/app-release.apk``. -**Install ADB and USB debugging**. Enable "USB debugging" in the developer mode in your phone settings. In SDK manager, install `Android SDK Platform-Tools `_. Add the path to platform-tool path to environment variable ``PATH``. Run the following commands, and if ADB is installed correctly, your phone will appear as a device: +**Install ADB and USB debugging**. Enable "USB debugging" in the developer mode in your phone settings. In SDK manager, install `Android SDK Platform-Tools `_. Add the path to platform-tool path to the environment variable ``PATH``. Run the following commands, and if ADB is installed correctly, your phone will appear as a device: .. code-block:: bash diff --git a/docs/deploy/cli.rst b/docs/deploy/cli.rst index 460ac71c7d..2f0951686d 100644 --- a/docs/deploy/cli.rst +++ b/docs/deploy/cli.rst @@ -3,7 +3,7 @@ CLI and C++ API =============== -MLCChat CLI is the command line tool to run MLC-compiled LLMs out of the box. You may install it from the prebuilt package we provide, or compile it from source. +MLCChat CLI is the command line tool to run MLC-compiled LLMs out of the box. You may install it from the prebuilt package we provide, or compile it from the source. .. contents:: Table of Contents :local: @@ -25,16 +25,16 @@ To use other GPU runtimes, e.g. CUDA, please instead :ref:`build it from source After installation, activating ``mlc-chat-venv`` environment in Conda will give the ``mlc_chat_cli`` command available. .. note:: - The prebuilt package supports **Metal** on macOS and **Vulkan** on Linux and Windows. It is possible to use other GPU runtimes such as **CUDA** by compiling MLCChat CLI from source. + The prebuilt package supports **Metal** on macOS and **Vulkan** on Linux and Windows. It is possible to use other GPU runtimes such as **CUDA** by compiling MLCChat CLI from the source. .. 
_mlcchat_build_from_source: Option 2. Build MLC Runtime from Source ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -We also provid options to build mlc runtime libraries and ``mlc_chat_cli`` from source. +We also provide options to build mlc runtime libraries and ``mlc_chat_cli`` from source. This step is useful when you want to directly obtain a version of mlc runtime library -and the cli. Please click the details below to see the instruction. +and the cli. Please click the details below to see the instructions. .. collapse:: Details @@ -63,7 +63,7 @@ and the cli. Please click the details below to see the instruction. conda activate mlc-chat-venv .. note:: - :doc:`TVM Unity ` compiler is not a dependency to MLCChat CLI. Only its runtime is required, which is automatically included in `3rdparty/tvm `_. + :doc:`TVM Unity ` compiler is not a dependency on MLCChat CLI. Only its runtime is required, which is automatically included in `3rdparty/tvm `_. **Step 2. Configure and build.** A standard git-based workflow is recommended to download MLC LLM, after which you can specify build requirements with our lightweight config generation tool: @@ -96,7 +96,7 @@ and the cli. Please click the details below to see the instruction. Run Models through MLCChat CLI ------------------------------ -Once ``mlc_chat_cli`` is installed, you are able to run any MLC-compiled model on command line. +Once ``mlc_chat_cli`` is installed, you are able to run any MLC-compiled model on the command line. **Ensure Model Exists.** As the input to ``mlc_chat_cli``, it is always good to double check if the compiled model exists. diff --git a/docs/deploy/ios.rst b/docs/deploy/ios.rst index 597f594bfb..b6e8e7b55a 100644 --- a/docs/deploy/ios.rst +++ b/docs/deploy/ios.rst @@ -7,9 +7,9 @@ iOS App and Swift API :local: :depth: 2 -The MLC LLM iOS app can be installed in two ways: through the pre-built package or by building from source. +The MLC LLM iOS app can be installed in two ways: through the pre-built package or by building from the source. If you are an iOS user looking to try out the models, the pre-built package is recommended. If you are a -developer seeking to integrate new features into the package, building the iOS package from source is required. +developer seeking to integrate new features into the package, building the iOS package from the source is required. Use Pre-built iOS App --------------------- @@ -23,7 +23,7 @@ The MLC Chat app is now available in App Store at no cost. You can download and Build iOS App from Source ------------------------- -This section shows how we can build the app from source. +This section shows how we can build the app from the source. Step 1. Install Build Dependencies ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -134,7 +134,7 @@ Ensure that all the necessary dependencies and configurations are correctly set up in the Xcode project. Once you have made the necessary changes, build the iOS app using Xcode. -If you have an Apple Silicon Mac, you can select target "My Mac (designed for ipad)" +If you have an Apple Silicon Mac, you can select target "My Mac (designed for iPad)" to run on your Mac. You can also directly run it on your iPad or iPhone. .. image:: https://raw.githubusercontent.com/mlc-ai/web-data/main/images/mlc-llm/tutorials/xcode-build.jpg @@ -163,7 +163,7 @@ controls the list of model URLs and model libs to be packaged into the app. Additionally, the app prepackages the models under ``./ios/dist``. This built-in list can be controlled by editing ``prepare_params.sh``. 
-You can package new prebuilt models or compiled models by changing the above fields and then repeat the steps above. +You can package new prebuilt models or compiled models by changing the above fields and then repeating the steps above. Build Apps with MLC Swift API @@ -193,8 +193,8 @@ your own app. The package is located under `ios/MLCSwift`. -ltokenizers_c -You can then can import the `MLCSwift` package in your app. -The following code shows an illustrative example about how to use the chat module. +You can then import the `MLCSwift` package into your app. +The following code shows an illustrative example of how to use the chat module. .. code:: swift @@ -221,7 +221,7 @@ The following code shows an illustrative example about how to use the chat modul Because the chat module makes heavy use of GPU and thread-local resources, it needs to run on a dedicated background thread. Therefore, **avoid using** `DispatchQueue`, which can cause context switching to - different threads and segfaults due to thread-safety issue. + different threads and segfaults due to thread-safety issues. Use the `ThreadWorker` class to launch all the jobs related to the chat module. You can check out the source code of the MLCChat app for a complete example. diff --git a/docs/deploy/javascript.rst b/docs/deploy/javascript.rst index cdbd4cc79e..08dd2cde26 100644 --- a/docs/deploy/javascript.rst +++ b/docs/deploy/javascript.rst @@ -6,9 +6,9 @@ WebLLM and Javascript API :depth: 2 WebLLM is a MLC chat web runtime (`WebLLM `_) -that allows you to build chat applications directly in browser. +that allows you to build chat applications directly in the browser. -Try out Prebuilt Webpage +Try out the Prebuilt Webpage ------------------------ To get started, you can try out `WebLLM prebuilt webpage `__. diff --git a/docs/deploy/python.rst b/docs/deploy/python.rst index 1a046538f9..3df5a08241 100644 --- a/docs/deploy/python.rst +++ b/docs/deploy/python.rst @@ -28,7 +28,7 @@ that supports other GPU runtime than the prebuilt version. Please refer our :ref Get Started ^^^^^^^^^^^ After confirming that the package ``mlc_chat`` is installed, we can follow the steps -below to chat with a MLC-compiled model in Python. +below to chat with an MLC-compiled model in Python. First, let us make sure that the MLC-compiled ``model`` we want to chat with already exists. @@ -99,7 +99,7 @@ If you do not have the MLC-compiled ``model`` ready: params_shard_*.bin ... -After making sure that the files exist, using the conda environment you used +After making sure that the files exist, use the conda environment you used to install ``mlc_chat``, from the ``mlc-llm`` directory, you can create a Python file ``sample_mlc_chat.py`` and paste the following lines: @@ -253,7 +253,7 @@ We provide an example below. fields specified. It is also worth noting that ``ConvConfig`` itself is overriding the original conversation template - specified by the field ``conv_template`` in chat configuration. Learn more about it in + specified by the field ``conv_template`` in the chat configuration. Learn more about it in :ref:`Configure MLCChat in JSON`. Raw Text Generation in Python @@ -272,7 +272,7 @@ We provide an example below. 
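Since the hunks above only touch fragments of the example, a compact, self-contained sketch of the raw-text-generation flow being edited is shown here for orientation; the model name is a placeholder, and the authoritative example remains the one in `docs/deploy/python.rst` itself.

    from mlc_chat import ChatModule, ChatConfig, ConvConfig

    # The "LM" conversation template performs plain completion, so stop tokens
    # and the stop string are supplied explicitly through ConvConfig.
    conv_config = ConvConfig(stop_tokens=[2], add_bos=True, stop_str="[INST]")
    chat_config = ChatConfig(conv_config=conv_config, conv_template="LM")

    cm = ChatModule(
        model="Llama-2-7b-chat-hf-q4f16_1",  # placeholder: any MLC-compiled model
        chat_config=chat_config,
    )
    print(cm.generate(prompt="The capital of Canada is"))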
# Use a `ConvConfig` to define the generation settings # Since the "LM" template only supports raw text generation, - # system prompts will not be executed even if provided + # System prompts will not be executed even if provided conv_config = ConvConfig(stop_tokens=[2,], add_bos=True, stop_str="[INST]") # Note that `conv_config` is an optional subfield of `chat_config` @@ -367,7 +367,7 @@ The :class:`mlc_chat.ChatModule` class provides the following methods: Gradio Frontend --------------- -The gradio frontend provides a web interface for the MLC-Chat model, which allows user to interact with the model in a more user-friendly way and switch between different models to compare performance. +The gradio frontend provides a web interface for the MLC-Chat model, which allows users to interact with the model in a more user-friendly way and switch between different models to compare performance. To use gradio frontend, you need to install gradio first: .. code-block:: bash @@ -385,7 +385,7 @@ Then you can run the following code to start the interface: --port The port number to run gradio. The default value is ``7860``. --share Whether to create a publicly shareable link for the interface. -After setting up properly, you are expected to see the following interface in your browser: +After setting it up properly, you are expected to see the following interface in your browser: .. image:: https://raw.githubusercontent.com/mlc-ai/web-data/main/images/mlc-llm/tutorials/gradio-interface.png :width: 100% diff --git a/docs/deploy/rest.rst b/docs/deploy/rest.rst index 95d57f491e..8451624fdb 100644 --- a/docs/deploy/rest.rst +++ b/docs/deploy/rest.rst @@ -6,7 +6,7 @@ Rest API :depth: 2 We provide `REST API `_ -for user to interact with MLC-Chat in their own programs. +for a user to interact with MLC-Chat in their own programs. Install MLC-Chat Package ------------------------ @@ -33,7 +33,7 @@ of mlc chat runtime. You only need to do this if you choose not to use the prebu First, make sure you install TVM unity (following the instruction in :ref:`install-tvm-unity`). You can choose to only pip install `mlc-ai-nightly` that comes with the tvm unity but skip `mlc-chat-nightly`. -Then please follow the instruction in :ref:`mlcchat_build_from_source` to build the necessary libraries. +Then please follow the instructions in :ref:`mlcchat_build_from_source` to build the necessary libraries. You can now use ``mlc_chat`` package by including the `python` directory to ``PYTHONPATH`` environment variable. From 2ec0cc8370c309c304fd1a0f02c2233ae784efd3 Mon Sep 17 00:00:00 2001 From: Yuchen Jin Date: Sat, 28 Oct 2023 15:13:48 -0700 Subject: [PATCH 057/116] Minor enhancements to `ChatModule` (#1132) Some minor enhancements to `ChatModule`, mainly handle the device parsing solely in `_parse_device_str` instead of handling it both in the member function and the `__init__` function to avoid redundancy; and some type annotation fix. --- cpp/llm_chat.cc | 2 +- python/mlc_chat/chat_module.py | 82 +++++++++++++++++----------------- 2 files changed, 43 insertions(+), 41 deletions(-) diff --git a/cpp/llm_chat.cc b/cpp/llm_chat.cc index 7b869854d6..35a8d1f41e 100644 --- a/cpp/llm_chat.cc +++ b/cpp/llm_chat.cc @@ -437,7 +437,7 @@ class LLMChat { /*! * \brief Reload model, tokenizers and configurations from the specified model path. - * \param executable The module to reload. + * \param reload_lib The module to reload, it can either be a path to the library or a tvm Module. * \param model_path The path to search for models. 
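To make the refactor described in this commit message concrete, a hedged illustration of the consolidated helper's contract follows; the actual implementation appears in the `chat_module.py` hunk below, and the exact fallback behavior is whatever that code does, not what these comments assert.

    # _parse_device_str is module-private; imported here purely for illustration.
    from mlc_chat.chat_module import _parse_device_str

    dev, name = _parse_device_str("cuda:1")   # -> (tvm.cuda(1), "cuda")
    dev, name = _parse_device_str("metal")    # device id defaults to 0
    dev, name = _parse_device_str("auto")     # falls back to _detect_local_device()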
* \param app_config_json The JSON string used to partially override the configuration loaded from * disk, default to empty string. diff --git a/python/mlc_chat/chat_module.py b/python/mlc_chat/chat_module.py index 02625f4ef4..b2e0ec126c 100644 --- a/python/mlc_chat/chat_module.py +++ b/python/mlc_chat/chat_module.py @@ -532,14 +532,16 @@ def _get_lib_module_path( raise FileNotFoundError(err_msg) -def _convert_chat_config_to_json_str(chat_config: Optional[ChatConfig], conv_template: str) -> str: +def _convert_chat_config_to_json_str( + chat_config: Optional[ChatConfig], conv_template: Optional[str] +) -> str: """Convert user's input ChatConfig to a json string, omitting ``None`` fields. Parameters ---------- chat_config : Optional[ChatConfig] User's input. A partial ChatConfig for overriding ``mlc-chat-config.json``. - conv_template : str + conv_template : Optional[str] The ``conv_template`` that will be used after considering potential override. Returns @@ -591,7 +593,7 @@ def _convert_generation_config_to_json_str(generation_config: Optional[Generatio return json.dumps(asdict(generation_config)) -def _parse_device_str(device: str): +def _parse_device_str(device: str) -> (tvm.runtime.Device, str): """Parse the input device identifier into device name and id. Parameters @@ -603,11 +605,11 @@ def _parse_device_str(device: str): Returns ------- + dev : tvm.runtime.Device + The device. + device_name : str The name of the device. - - device_id : int - The id of the device, or 0 if not specified in the input. """ device_err_msg = ( f"Invalid device name: {device}. Please enter the device in the form " @@ -616,14 +618,32 @@ def _parse_device_str(device: str): ) device_args = device.split(":") if len(device_args) == 1: - return device_args[0], 0 + device_name, device_id = device_args[0], 0 elif len(device_args) == 2: - return device_args[0], int(device_args[1]) + device_name, device_id = device_args[0], int(device_args[1]) elif len(device_args) > 2: raise ValueError(device_err_msg) + if device_name == "cuda": + device = tvm.cuda(device_id) + elif device_name == "metal": + device = tvm.metal(device_id) + elif device_name == "vulkan": + device = tvm.vulkan(device_id) + elif device_name == "rocm": + device = tvm.rocm(device_id) + elif device_name == "opencl": + device = tvm.opencl(device_id) + elif device_name == "auto": + device, device_name = _detect_local_device(device_id) + logging.info(f"System automatically detected device: {device_name}") + else: + raise ValueError(device_err_msg) + + return device, device_name -def _detect_local_device(device_id: int = 0): + +def _detect_local_device(device_id: int = 0) -> (tvm.runtime.Device, str): """Automatically detect the local device if user does not specify. Parameters @@ -633,8 +653,11 @@ def _detect_local_device(device_id: int = 0): Returns ------ - dev : Device + dev : tvm.runtime.Device The local device. + + device_name : str + The name of the device. """ if tvm.metal().exist: return tvm.metal(device_id), "metal" @@ -715,34 +738,13 @@ def __init__( chat_config: Optional[ChatConfig] = None, model_lib_path: Optional[str] = None, ): - device_err_msg = ( - f"Invalid device name: {device}. Please enter the device in the form " - "'device_name:device_id' or 'device_name', where 'device_name' needs to be " - "one of 'cuda', 'metal', 'vulkan', 'rocm', 'opencl', 'auto'." - ) - - # 0. Retrieve device_name and device_id (if any, default 0) from device arg - device_name, device_id = _parse_device_str(device) - - # 1. 
Get self.device - if device_name == "cuda": - self.device = tvm.cuda(device_id) - elif device_name == "metal": - self.device = tvm.metal(device_id) - elif device_name == "vulkan": - self.device = tvm.vulkan(device_id) - elif device_name == "rocm": - self.device = tvm.rocm(device_id) - elif device_name == "opencl": - self.device = tvm.opencl(device_id) - elif device_name == "auto": - self.device, device_name = _detect_local_device(device_id) - logging.info(f"System automatically detected device: {device_name}") - else: - raise ValueError(device_err_msg) + # 0. Get device: + # Retrieve device_name and device_id (if any, default 0) from device arg + self.device, device_name = _parse_device_str(device) device_type = self.device.device_type + device_id = self.device.device_id - # 2. Populate chat module and their functions + # 1. Populate chat module and their functions fcreate_chat_mod = tvm.get_global_func("mlc.llm_chat_create") assert fcreate_chat_mod is not None chat_mod = fcreate_chat_mod(device_type, device_id) @@ -768,13 +770,13 @@ def __init__( self._get_role0_func = chat_mod["get_role0"] self._get_role1_func = chat_mod["get_role1"] - # 3. Look up model_path + # 2. Look up model_path self.model_path, self.config_file_path = _get_model_path(model) - # 4. Instantiate chat_config + # 3. Instantiate chat_config self.chat_config = _get_chat_config(self.config_file_path, chat_config) - # 5. Look up model library + # 4. Look up model library self.model_lib_path = _get_lib_module_path( model, self.model_path, @@ -784,7 +786,7 @@ def __init__( self.config_file_path, ) - # 6. Call reload + # 5. Call reload user_chat_config_json_str = _convert_chat_config_to_json_str( self.chat_config, self.chat_config.conv_template ) From 27ac5ac90bd7d7705b4258682273f19929a48cf1 Mon Sep 17 00:00:00 2001 From: DavidSharma <68979667+David-Sharma@users.noreply.github.com> Date: Sat, 28 Oct 2023 20:18:16 -0400 Subject: [PATCH 058/116] Updating tvm install docs (#1143) Updating the tvm install docs to assist a user in finding and copying zstd.dll to the correct folder. --- docs/install/tvm.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/install/tvm.rst b/docs/install/tvm.rst index 0dc716258d..c2b7998ada 100644 --- a/docs/install/tvm.rst +++ b/docs/install/tvm.rst @@ -132,7 +132,7 @@ A nightly prebuilt Python package of Apache TVM Unity is provided. FileNotFoundError: Could not find module 'path\to\site-packages\tvm\tvm.dll' (or one of its dependencies). Try using the full path with constructor syntax. - It is likely `zstd`, a dependency to LLVM, was missing. Please `download `__ the precompiled binary, rename it to `zstd.dll` and copy to the same folder as `tvm.dll`. + It is likely `zstd`, a dependency to LLVM, was missing. Please `download `__ the precompiled binary, rename it to `zstd.dll` and copy to the same folder as `tvm.dll`. Hint - To locate the "tvm.dll" file in Conda, navigate to your user home directory (e.g., "/users/xxxx"). Search for "tvm.dll" and find the folder whose path contains the name of the current environment, such as "mlc-chat-venv." Once located, copy "zstd.dll" to that specific folder. .. _tvm-unity-build-from-source: From 2b6d832ff7f8dab03d901993d42253502f6b6dc8 Mon Sep 17 00:00:00 2001 From: fennecJ Date: Sun, 29 Oct 2023 14:59:10 +0800 Subject: [PATCH 059/116] Make the help info consistent with program name (#1137) When user use command `mlc_chat_cli --help`, the output will be something like Usage: mlc_chat [--help] ... 
That's because the program name specified in `cli_main.cc` is "mlc_chat". It will be less confusing if the output of help info shows Usage: mlc_chat_cli [--help] ... --- cpp/cli_main.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/cli_main.cc b/cpp/cli_main.cc index 18bbb7ccea..f23c98bfff 100644 --- a/cpp/cli_main.cc +++ b/cpp/cli_main.cc @@ -480,7 +480,7 @@ void Chat(ChatModule* chat, const std::string& device_name, std::string local_id } int main(int argc, char* argv[]) { - argparse::ArgumentParser args("mlc_chat"); + argparse::ArgumentParser args("mlc_chat_cli"); args.add_description( "MLCChat CLI is the command line tool to run MLC-compiled LLMs out of the box.\n" From 878ae84b2fd95687481edee4e301c52841919503 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Sun, 29 Oct 2023 00:19:20 -0700 Subject: [PATCH 060/116] Support parameter packing (#1146) --- pyproject.toml | 1 - python/mlc_chat/compiler/compile.py | 12 ++++++++++-- python/mlc_chat/compiler/model/llama_model.py | 12 ++++++++++++ 3 files changed, 22 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 85ca20eb24..ccf754554f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,7 +28,6 @@ follow_imports = "skip" ignore_errors = false strict_optional = false install_types = true -non_interactive = true [tool.pylint.messages_control] max-line-length = 100 diff --git a/python/mlc_chat/compiler/compile.py b/python/mlc_chat/compiler/compile.py index 8415ca21b8..cc6b61b1c2 100644 --- a/python/mlc_chat/compiler/compile.py +++ b/python/mlc_chat/compiler/compile.py @@ -66,7 +66,7 @@ class CompileArgs: # pylint: disable=too-many-instance-attributes config: Path quantization: str - model_type: Model + model: Model target: Target opt: OptimizationFlags build_func: Callable[[IRModule, "CompileArgs"], None] @@ -79,7 +79,7 @@ def _echo_args(args: CompileArgs) -> None: print(f"{bold('Compiling with arguments:')}", file=out) print(f" {bold('--config'):<25} {args.config}", file=out) print(f" {bold('--quantization'):<25} {args.quantization}", file=out) - print(f" {bold('--model-type'):<25} {args.model_type.name}", file=out) + print(f" {bold('--model-type'):<25} {args.model.name}", file=out) print(f" {bold('--target'):<25} {args.target.export()}", file=out) print(f" {bold('--opt'):<25} {args.opt}", file=out) print(f" {bold('--output'):<25} {args.output}", file=out) @@ -101,6 +101,14 @@ def compile( # pylint: disable=too-many-arguments,redefined-builtin config, quantization, model_type, target, opt, build_func, prefix_symbols, output ) _echo_args(args) + model_config = args.model.config.from_file(args.config) + model = args.model.model(model_config) + mod, named_params = model.export_tvm( + spec=model.get_default_spec(), # type: ignore + ) + mod.show(black_format=False) + for name, param in named_params: + print(f"{name}: {param.shape} {param.dtype}") OPT_FLAG_PRESET = { diff --git a/python/mlc_chat/compiler/model/llama_model.py b/python/mlc_chat/compiler/model/llama_model.py index 49e947f741..6bf7647ff1 100644 --- a/python/mlc_chat/compiler/model/llama_model.py +++ b/python/mlc_chat/compiler/model/llama_model.py @@ -217,14 +217,26 @@ def get_default_spec(self): "prefill": { "inputs": nn.spec.Tensor([batch_size, "seq_len"], "int32"), "total_seq_len": int, + "$": { + "param_mode": "packed", + "effect_mode": "packed", + }, }, "decode": { "inputs": nn.spec.Tensor([batch_size, 1], "int32"), "total_seq_len": int, + "$": { + "param_mode": "packed", + "effect_mode": "packed", + }, }, 
"softmax_with_temperature": { "logits": nn.spec.Tensor([1, 1, "vocab_size"], "float32"), "temperature": nn.spec.Tensor([], "float32"), + "$": { + "param_mode": "none", + "effect_mode": "none", + }, }, } return nn.spec.ModuleSpec.from_raw(mod_spec, self) From c0c3a8d6eb2dab6542ac7ad0837fcb4c1ff8b40f Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Sun, 29 Oct 2023 13:16:46 -0700 Subject: [PATCH 061/116] [Slim-LM] Enable Group Quant (#1129) * Enable group quant via new interface. * Minor fix. * Linting. * Fix isort. * Fix mypy. * TE compute working. * Skip embed. * Support cpu+gpu quantization. * Add target option to tests. * Linting. --- .../compiler/model/llama_quantization.py | 101 +++++++++++ .../compiler/parameter/huggingface_loader.py | 39 ++++- python/mlc_chat/compiler/parameter/mapping.py | 2 +- python/mlc_chat/compiler/parameter/utils.py | 38 ++++- .../compiler/quantization/__init__.py | 2 +- .../compiler/quantization/group_quantizer.py | 70 ++++++++ .../python/parameter/test_group_quantizer.py | 157 ++++++++++++++++++ 7 files changed, 401 insertions(+), 8 deletions(-) create mode 100644 python/mlc_chat/compiler/model/llama_quantization.py create mode 100644 python/mlc_chat/compiler/quantization/group_quantizer.py create mode 100644 tests/python/parameter/test_group_quantizer.py diff --git a/python/mlc_chat/compiler/model/llama_quantization.py b/python/mlc_chat/compiler/model/llama_quantization.py new file mode 100644 index 0000000000..dbf360c31d --- /dev/null +++ b/python/mlc_chat/compiler/model/llama_quantization.py @@ -0,0 +1,101 @@ +""" +Quantization specs for Llama2 architecture. +TODO: add docstring +""" +from typing import Callable, Dict, List, Optional + +import tvm +from tvm.runtime import NDArray + +from ..parameter import QuantizeMapping +from ..quantization import QuantizeConfig +from ..quantization.group_quantizer import te_quantize as te_group_quantize +from .llama_config import LlamaConfig +from .llama_model import LlamaForCasualLM + + +def huggingface_group_quantize( + model_config: LlamaConfig, + quantize_config: QuantizeConfig, + target: Optional[tvm.target.Target] = None, +) -> QuantizeMapping: + """Returns a parameter mapping that maps a parameter in MLC LLM's model + definition to its eventual names and values after quantization. + + Parameters + ---------- + model_config : LlamaConfig + The configuration of the Llama model. + quantize_config : GroupQuantizeConfig + The configuration of the group quantization. + target : Optional[tvm.target.Target] + The target device to run the quantization on, by default None, which + means the quantization will be run on CPU. + + Returns + ------- + quantize_map : QuantizeMapping + The parameter mapping from a parameter in MLC LLM's model definition to + its eventual names and values after quantization. 
+ """ + + def group_quantize( + param: NDArray, config: QuantizeConfig, target: Optional[tvm.target.Target] = None + ): + if target is None or target.kind.name == "llvm": + target = tvm.target.Target("llvm") + device = tvm.cpu() + elif target.kind.name == "cuda": + device = tvm.cuda() + else: + raise ValueError(f"Invalid target device: {target}") + param_tensor = tvm.te.placeholder(param.shape, dtype=param.dtype, name="param") + weight_compute, scale_compute, other_computes = te_group_quantize( # type: ignore + param_tensor, config + ) + s = tvm.te.create_schedule( + [compute.op for compute in [weight_compute, scale_compute] + other_computes] + ) + if target.kind.name == "cuda": + # thread_binding for cuda + for compute in [weight_compute, scale_compute] + other_computes: + xo, xi = s[compute].split(compute.op.axis[0], 256) + s[compute].bind(xo, tvm.te.thread_axis("blockIdx.x")) + s[compute].bind(xi, tvm.te.thread_axis("threadIdx.x")) + f_quantize = tvm.build( + s, [param_tensor, weight_compute, scale_compute], name="group_quantize", target=target + ) + weight = tvm.nd.empty(weight_compute.shape, weight_compute.dtype, device=device) + scale = tvm.nd.empty(scale_compute.shape, scale_compute.dtype, device=device) + f_quantize(param.copyto(device), weight, scale) + return weight, scale + + # Param check + assert ( + quantize_config.kind == "group_quantize" + ), f"Invalid quantization config: group quantization expected but got {quantize_config.kind}" + assert ( + quantize_config.name == "q4f16_1" + ), """Only support q4f16_1 quantization scheme for now.""" + + # Fetch model parameter & names + model = LlamaForCasualLM(model_config) + _, named_params = model.export_tvm(spec=model.get_default_spec()) + parameter_names = {name for name, _ in named_params} + + # Init mappings + param_map: Dict[str, List[str]] = {} + map_func: Dict[str, Callable] = {} + + # Dispatch quantization scheme + # Also see https://github.com/mlc-ai/mlc-llm/blob/main/mlc_llm/quantization/__init__.py + for name in parameter_names: + if "norm.weight" not in name and "embed" not in name: + param_map[name] = [f"{name}_quantized", f"{name}_scale"] + map_func[name] = lambda x: group_quantize(x, quantize_config, target=target) + else: + # skip these parameters + param_map[name] = [name] + map_func[name] = lambda x: [x] + + return QuantizeMapping(param_map, map_func) diff --git a/python/mlc_chat/compiler/parameter/huggingface_loader.py b/python/mlc_chat/compiler/parameter/huggingface_loader.py index fa6beb40eb..ed91255c81 100644 --- a/python/mlc_chat/compiler/parameter/huggingface_loader.py +++ b/python/mlc_chat/compiler/parameter/huggingface_loader.py @@ -5,16 +5,21 @@ import logging from collections import OrderedDict, defaultdict from pathlib import Path -from typing import Dict, Iterator, List, Tuple +from typing import Dict, Iterator, List, Optional, Tuple import numpy as np from tqdm import tqdm from tvm.runtime import NDArray from tvm.runtime.ndarray import array as as_ndarray -from .mapping import ExternMapping +from .mapping import ExternMapping, QuantizeMapping from .stats import Stats -from .utils import check_parameter_usage, load_safetensor_shard, load_torch_shard +from .utils import ( + ParamQuantizer, + check_parameter_usage, + load_safetensor_shard, + load_torch_shard, +) logger = logging.getLogger(__name__) @@ -38,17 +43,22 @@ class HuggingFaceLoader: # pylint: disable=too-few-public-methods cached_files : Dict[Path, Dict[str, np.ndarray]] A cache of the loaded files. 
The key is the path of the file, and the value is a mapping from parameter name to the parameter value. + + quantize_param_map : Optional[QuantizeMapping] + The quantization mapping from MLC to quantized MLC parameters. """ stats: Stats - extern_param_map: ExternMapping cached_files: Dict[Path, Dict[str, np.ndarray]] torch_to_path: Dict[str, Path] + extern_param_map: ExternMapping + quantize_param_map: Optional[QuantizeMapping] def __init__( self, path: Path, extern_param_map: ExternMapping, + quantize_param_map: Optional[QuantizeMapping] = None, ) -> None: """Create a parameter loader from HuggingFace PyTorch format. @@ -66,12 +76,17 @@ def __init__( extern_param_map : ExternMapping Maps an MLC parameter to a list of PyTorch/SafeTensor parameters. + + quantize_param_map: Optional[QuantizeMapping] + The quantization mapping from MLC to quantized MLC parameters, default to None, which + means no quantization. """ assert path.is_file() self.stats = Stats() self.extern_param_map = extern_param_map self.cached_files = {} self.torch_to_path = {} + self.quantize_param_map = quantize_param_map if path.suffix in (".bin", ".safetensors"): self._load_file(path) for name in self.cached_files[path].keys(): @@ -90,7 +105,21 @@ def load(self) -> Iterator[Tuple[str, NDArray]]: mlc_names = _loading_order(self.extern_param_map, self.torch_to_path) for mlc_name in tqdm(mlc_names): param = self._load_mlc_param(mlc_name) - yield mlc_name, param + if self.quantize_param_map: + with self.stats.timer("quant_time_sec"): + quantized_params = ParamQuantizer(self.quantize_param_map).quantize( + mlc_name, param + ) + for quantized_name, quantized_param in quantized_params: + logger.info( + ' Quantized Parameter: "%s", shape: %s, dtype: %s', + quantized_name, + quantized_param.shape, + quantized_param.dtype, + ) + yield quantized_name, quantized_param + else: + yield mlc_name, param cached_files = list(self.cached_files.keys()) for path in cached_files: self._unload_file(path) diff --git a/python/mlc_chat/compiler/parameter/mapping.py b/python/mlc_chat/compiler/parameter/mapping.py index 6f63dce71a..aab674cfa8 100644 --- a/python/mlc_chat/compiler/parameter/mapping.py +++ b/python/mlc_chat/compiler/parameter/mapping.py @@ -80,7 +80,7 @@ class QuantizeMapping: used to convert the quantized parameters into the desired form. 
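To ground the docstring above, here is an illustrative (not prescriptive) example of a single entry in such a mapping for one group-quantized weight; the parameter name is hypothetical, and `group_quantize`/`quantize_config` stand in for whatever quantization callable and config the caller supplies, as `huggingface_group_quantize` does above.

    # One source parameter expands into a quantized-weight tensor plus a scale
    # tensor; map_func holds the callable that produces both from the original.
    param_map = {"linear.weight": ["linear.weight_quantized", "linear.weight_scale"]}
    map_func = {"linear.weight": lambda w: group_quantize(w, quantize_config)}
    quantize_map = QuantizeMapping(param_map, map_func)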
""" - param_map: Dict[str, Callable[[str], List[str]]] + param_map: Dict[str, List[str]] map_func: Dict[str, Callable[[NDArray], List[NDArray]]] diff --git a/python/mlc_chat/compiler/parameter/utils.py b/python/mlc_chat/compiler/parameter/utils.py index 596941aaca..a2789cee55 100644 --- a/python/mlc_chat/compiler/parameter/utils.py +++ b/python/mlc_chat/compiler/parameter/utils.py @@ -1,15 +1,51 @@ """Common utilities for loading parameters""" +# pylint: disable=too-few-public-methods import logging from pathlib import Path -from typing import Iterator, Set, Tuple +from typing import TYPE_CHECKING, Iterator, Set, Tuple import numpy as np from .mapping import ExternMapping +if TYPE_CHECKING: + from tvm.runtime import NDArray + + from ..parameter import QuantizeMapping + logger = logging.getLogger(__name__) +class ParamQuantizer: + """A parameter quantizer that quantizes given mlc-llm parameters""" + + quantize_map: "QuantizeMapping" + + def __init__(self, quantize_map: "QuantizeMapping") -> None: + self.quantize_map = quantize_map + + def quantize(self, name: str, param: "NDArray") -> Iterator[Tuple[str, "NDArray"]]: + """Apply quantization to the given parameters + + Parameters + ---------- + name : str + The name of the parameter + param : NDArray + The parameter to be quantized + + Returns + ------- + List[Tuple[str, NDArray]] + The quantized parameters, each with its name + """ + + assert name in self.quantize_map.param_map + quantized_names = self.quantize_map.param_map[name] + quantized_params = self.quantize_map.map_func[name](param) + return zip(quantized_names, quantized_params) + + def check_parameter_usage(param_map: ExternMapping, extern_weights: Set[str]): """Check that all external parameters have been used and are stored in the weights file.""" used_extern_names = set(sum(param_map.param_map.values(), [])) diff --git a/python/mlc_chat/compiler/quantization/__init__.py b/python/mlc_chat/compiler/quantization/__init__.py index ab352fc6c2..a932119f9c 100644 --- a/python/mlc_chat/compiler/quantization/__init__.py +++ b/python/mlc_chat/compiler/quantization/__init__.py @@ -1,2 +1,2 @@ """A subpackage for quantization and dequantization algorithms""" -from .quantization import QUANT +from .quantization import QUANT, QuantizeConfig diff --git a/python/mlc_chat/compiler/quantization/group_quantizer.py b/python/mlc_chat/compiler/quantization/group_quantizer.py new file mode 100644 index 0000000000..418617dd70 --- /dev/null +++ b/python/mlc_chat/compiler/quantization/group_quantizer.py @@ -0,0 +1,70 @@ +"""A group quantizer for on the fly parameter quantization""" +# pylint: disable=too-few-public-methods + +from typing import List, Tuple + +from tvm import te, tir + +from .quantization import QuantizeConfig + + +def te_quantize( + weight: te.Tensor, config: QuantizeConfig +) -> Tuple[te.Tensor, te.Tensor, List[te.Tensor]]: + """Group quantization for weight tensor, defined in tensor expression.""" + # pylint: disable=too-many-locals + assert len(weight.shape) == 2 + n, m = weight.shape + # compute scale per group + r = te.reduce_axis((0, config.group_size), name="r") + num_group = tir.ceildiv(m, config.group_size) + scale_shape = (n, num_group) + max_abs = te.compute( + shape=scale_shape, + fcompute=lambda i, j: te.max( + tir.if_then_else( + j * config.group_size + r < weight.shape[1], + te.abs(weight[i, j * config.group_size + r]), + tir.const(1e-4, config.weight_dtype), + ), + axis=r, + ), + name="max_abs_value", + ) + scale = te.compute( + (n, m), + lambda i, j: max_abs[i, j] / 
tir.const(config.max_int_value, dtype=config.weight_dtype), + name="scale", + ) + + # compute scaled weight + tir_max_int = tir.const(config.max_int_value, config.weight_dtype) + tir_zero = tir.const(0, config.weight_dtype) + tir_max_int_2 = tir.const(config.max_int_value * 2, config.weight_dtype) + scaled_weight = te.compute( + shape=weight.shape, + fcompute=lambda i, j: tir.min( + tir.max( + tir.round(weight[i, j] / scale[i, j // config.group_size] + tir_max_int), + tir_zero, + ), + tir_max_int_2, + ).astype(config.storage_dtype), + ) + + # compute quantized weight per storage + r = te.reduce_axis((0, config.num_elem_per_storage), name="r") + num_storage = config.num_storage_per_group * num_group + quantized_weight_shape = (n, num_storage) + quantized_weight = te.compute( + shape=quantized_weight_shape, + fcompute=lambda i, j: tir.sum( + scaled_weight[i, j * config.num_elem_per_storage + r] + << (r * config.quantize_dtype_bits), + axis=r, + where=j * config.num_elem_per_storage + r < m, + ), + name="weight", + ) + return quantized_weight, scale, [max_abs, scaled_weight] + # pylint: enable=too-many-locals diff --git a/tests/python/parameter/test_group_quantizer.py b/tests/python/parameter/test_group_quantizer.py new file mode 100644 index 0000000000..b0e4b6522f --- /dev/null +++ b/tests/python/parameter/test_group_quantizer.py @@ -0,0 +1,157 @@ +# pylint: disable=missing-docstring,too-many-instance-attributes +import logging +from dataclasses import dataclass +from pathlib import Path +from typing import TYPE_CHECKING, Tuple, Union + +import numpy as np +import tvm +from mlc_chat.compiler import MODELS +from mlc_chat.compiler.model.llama_config import LlamaConfig +from mlc_chat.compiler.model.llama_quantization import huggingface_group_quantize +from mlc_chat.compiler.parameter import HuggingFaceLoader +from mlc_chat.support import tqdm +from tvm.runtime import NDArray + +if TYPE_CHECKING: + from tvm.relax.frontend import nn + +logging.basicConfig( + level=logging.DEBUG, + style="{", + datefmt="%Y-%m-%d %H:%M:%S", + format="[{asctime}] {levelname} {filename}:{lineno}: {message}", +) + + +def test_load_torch_llama_group_quantize(base_path: Union[str, Path], target: str = "llvm"): + @dataclass + class TestGroupQuantizeConfig: + name: str = "q4f16_1" + kind: str = "group_quantize" + group_size: int = 32 + weight_dtype: str = "float16" + max_int_value: int = 7 + storage_dtype: str = "uint32" + num_elem_per_storage: int = 8 + num_storage_per_group: int = 4 + quantize_dtype_bits: int = 4 + + def quantize(self, _: "nn.Module") -> "nn.Module": + raise NotImplementedError + + base_path = Path(base_path) + path_config = base_path / "config.json" + path_params = base_path / "pytorch_model.bin.index.json" + + model = MODELS["llama"] + model_config = LlamaConfig.from_file(path_config) + quantize_config = TestGroupQuantizeConfig() + loader = HuggingFaceLoader( + path=path_params, + extern_param_map=model.source["huggingface-torch"](model_config, None), + quantize_param_map=huggingface_group_quantize( + model_config, + quantize_config, + target=tvm.target.Target(target), + ), + ) + with tqdm.redirect(): + for _name, _param in loader.load(): + ... 
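The reduction at the end of te_quantize above packs num_elem_per_storage low-bit elements into one storage word by shifting element r left by r * quantize_dtype_bits and summing; since every element has already been clipped below 2 ** quantize_dtype_bits, the shifted values cannot overlap and the sum behaves like a bitwise OR. A minimal self-contained sketch of that packing and its inverse, with hypothetical element values:

    bits = 4                            # quantize_dtype_bits for a 4-bit quantize dtype
    elems = [3, 0, 14, 7, 1, 9, 4, 12]  # already clipped to [0, 2 * max_int_value]
    packed = 0
    for r, v in enumerate(elems):       # mirrors tir.sum(scaled_weight[...] << (r * bits))
        packed += v << (r * bits)

    # Unpacking reverses the shifts with a (2**bits - 1) mask.
    unpacked = [(packed >> (r * bits)) & (2**bits - 1) for r in range(len(elems))]
    assert unpacked == elems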
+ + +def test_group_quantize_vs_numpy(): + bits = { + "int4": 4, + "int8": 8, + "fp16": 16, + "fp32": 32, + "int32": 32, + "uint32": 32, + } + + # pylint: disable=unused-variable + def group_quantize_np( + w: NDArray, + quantize_dtype: str = "int4", + storage_dtype: str = "uint32", + group_size: int = 32, + # symmetric: bool = True, + # transpose: bool = False, + ) -> Tuple[NDArray, NDArray]: + # pylint: disable=too-many-locals + def _pad_axis_by_factor(tensor: np.ndarray, axis: int, factor: int) -> np.ndarray: + dim = int(tensor.shape[axis]) + if dim % factor == 0: + return tensor + pad_width = [[0, 0] for i in tensor.shape] + pad_width[axis][1] = factor - (dim % factor) + return np.pad(tensor, pad_width, mode="constant", constant_values=0) + + def _clip( + x: np.ndarray, + x_min: int, + x_max: int, + dtype: str, + ) -> np.ndarray: + return np.clip(x, a_min=x_min, a_max=x_max).astype(dtype) + + num_elem_per_storage = bits[storage_dtype] // bits[quantize_dtype] + assert group_size % num_elem_per_storage == 0 + num_storage_units = (group_size + num_elem_per_storage - 1) // num_elem_per_storage + + # using numpy for now + w = w.numpy() + + # Step 1. Tile `w`: [n, k'] -> [n, k, group_size] + w = _pad_axis_by_factor(w, axis=1, factor=group_size) + n, k = [int(v) for v in w.shape] # pylint: disable=invalid-name + assert k % group_size == 0, "Padding is not working properly" + k = k // group_size + w = w.reshape([n, k, group_size]) + + # Step 2. Calculate + if quantize_dtype.startswith("int"): + max_int_value = (2 ** (bits[quantize_dtype] - 1)) - 1 + # 1) `scale`: [n, k, group_size] -> [n, k] + scale = np.maximum(np.amax(w, axis=-1), 1e-4) / max_int_value + # 2) `w`: w / scale + + w = _clip( + np.round(w / scale[:, :, np.newaxis]).astype("int") + max_int_value, + x_min=0, + x_max=max_int_value * 2, + dtype=storage_dtype, + ) + else: + raise NotImplementedError + + # Step 3. 
Compress `w` to every `num_elem_per_storage` elements + res = np.zeros((n, k, num_storage_units), dtype=np.uint32) + for i in range(n): + for j in range(k): + for m in range(num_storage_units): + for k in range(num_elem_per_storage): + res[i, j, m] += w[i, j, m * num_elem_per_storage + k] * 2**k + return tvm.nd.array(res), tvm.nd.array(scale) + # pylint: enable=too-many-locals + + +if __name__ == "__main__": + test_load_torch_llama_group_quantize( + base_path="./dist/models/Llama-2-7b-hf", + target="llvm", + ) + test_load_torch_llama_group_quantize( + base_path="./dist/models/Llama-2-7b-hf", + target="nvidia/nvidia-a100", + ) + test_load_torch_llama_group_quantize( + base_path="./dist/models/Llama-2-13b-hf", + target="llvm", + ) + test_load_torch_llama_group_quantize( + base_path="./dist/models/Llama-2-13b-hf", + target="nvidia/nvidia-a100", + ) From 2193767fa20b9f08a530c29b402bf4237cc84561 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Sun, 29 Oct 2023 16:35:07 -0700 Subject: [PATCH 062/116] Enable Mypy and Pylint in mlc_chat Python Package (#1149) --- ci/task/mypy.sh | 6 +- ci/task/pylint.sh | 4 +- python/mlc_chat/base.py | 4 +- python/mlc_chat/callback.py | 8 +- python/mlc_chat/chat_module.py | 186 +++++++++--------- python/mlc_chat/cli/compile.py | 1 - python/mlc_chat/embeddings/openai.py | 40 ++-- python/mlc_chat/gradio.py | 31 +-- python/mlc_chat/interface/openai_api.py | 4 +- python/mlc_chat/rest.py | 138 +++++++------ python/setup.py | 113 ++++++----- .../legacy => legacy-python}/compare_lib.py | 0 .../dump_intermediate.py | 0 .../legacy => legacy-python}/evaluate.py | 0 .../test_batching_llama.py | 0 .../test_build_args.py | 0 .../test_build_model_from_args.py | 0 17 files changed, 280 insertions(+), 255 deletions(-) rename tests/{python/legacy => legacy-python}/compare_lib.py (100%) rename tests/{python/legacy => legacy-python}/dump_intermediate.py (100%) rename tests/{python/legacy => legacy-python}/evaluate.py (100%) rename tests/{python/legacy => legacy-python}/test_batching_llama.py (100%) rename tests/{python/legacy => legacy-python}/test_build_args.py (100%) rename tests/{python/legacy => legacy-python}/test_build_model_from_args.py (100%) diff --git a/ci/task/mypy.sh b/ci/task/mypy.sh index f241cf2c3c..52da13da5f 100755 --- a/ci/task/mypy.sh +++ b/ci/task/mypy.sh @@ -8,8 +8,4 @@ export PYTHONPATH="./python:$PYTHONPATH" set -x -mypy ./python/mlc_chat/compiler \ - ./python/mlc_chat/support \ - ./tests/python/model \ - ./tests/python/parameter \ - ./tests/python/support +mypy ./python/ ./tests/python/ diff --git a/ci/task/pylint.sh b/ci/task/pylint.sh index 9dae28767d..7d2a0d326b 100755 --- a/ci/task/pylint.sh +++ b/ci/task/pylint.sh @@ -11,5 +11,5 @@ set -x # TVM Unity is a dependency to this testing pip install --quiet --pre -U -f https://mlc.ai/wheels mlc-ai-nightly -pylint --jobs $NUM_THREADS ./python/mlc_chat/compiler ./python/mlc_chat/support -pylint --jobs $NUM_THREADS --recursive=y ./tests/python/model ./tests/python/parameter ./tests/python/support/ +pylint --jobs $NUM_THREADS ./python/ +pylint --jobs $NUM_THREADS --recursive=y ./tests/python/ diff --git a/python/mlc_chat/base.py b/python/mlc_chat/base.py index e8393eecf7..8980330977 100644 --- a/python/mlc_chat/base.py +++ b/python/mlc_chat/base.py @@ -1,5 +1,4 @@ """Load MLC LLM library and _ffi_api functions.""" - import ctypes import os import sys @@ -15,7 +14,9 @@ def _load_mlc_llm_lib(): if sys.platform.startswith("win32") and sys.version_info >= (3, 8): for path in libinfo.get_dll_directories(): 
os.add_dll_directory(path) + # pylint: disable=protected-access lib_name = "mlc_llm" if tvm._ffi.base._RUNTIME_ONLY else "mlc_llm_module" + # pylint: enable=protected-access lib_path = libinfo.find_lib_path(lib_name, optional=False) return ctypes.CDLL(lib_path[0]), lib_path[0] @@ -46,6 +47,7 @@ def get_delta_message(curr_message: str, new_message: str) -> str: def set_global_random_seed(seed): + """Set global random seed for python, numpy, torch and tvm.""" if "numpy" in sys.modules: sys.modules["numpy"].random.seed(seed) if "torch" in sys.modules: diff --git a/python/mlc_chat/callback.py b/python/mlc_chat/callback.py index 921d9c0052..0ef3fe580b 100644 --- a/python/mlc_chat/callback.py +++ b/python/mlc_chat/callback.py @@ -1,6 +1,5 @@ """Namespace of callback functions in Python API.""" -#! pylint: disable=unused-import, invalid-name, unnecessary-pass - +# pylint: disable=unused-import, invalid-name, unnecessary-pass from queue import Queue from typing import Optional @@ -94,7 +93,7 @@ def __init__(self, callback_interval: int = 2, timeout: Optional[float] = None): Timeout for put and get from the delta messages queue """ super().__init__() - self.delta_messages = Queue() + self.delta_messages: Queue[str] = Queue() self.callback_interval = callback_interval self.timeout = timeout @@ -119,5 +118,4 @@ def __next__(self): value = self.delta_messages.get(timeout=self.timeout) if value: return value - else: - raise StopIteration() + raise StopIteration() diff --git a/python/mlc_chat/chat_module.py b/python/mlc_chat/chat_module.py index b2e0ec126c..058557c182 100644 --- a/python/mlc_chat/chat_module.py +++ b/python/mlc_chat/chat_module.py @@ -1,5 +1,5 @@ """The Python API for MLC chat.""" -#! pylint: disable=unused-import, invalid-name +#! pylint: disable=too-many-lines import inspect import json import logging @@ -8,12 +8,11 @@ import warnings from dataclasses import asdict, dataclass, fields from enum import Enum -from typing import List, Optional, Union +from typing import List, Optional, Tuple, Union import tvm -from tvm.runtime import disco +from tvm.runtime import disco # pylint: disable=unused-import -from . import callback from .interface.openai_api import ChatMessage # pylint: disable=line-too-long @@ -22,7 +21,7 @@ @dataclass -class ConvConfig: +class ConvConfig: # pylint: disable=too-many-instance-attributes r"""A dataclass that represents user-defined partial configuration for conversation template. This is an attribute of :class:`mlc_chat.ChatConfig`, which can then be passed in to the @@ -84,7 +83,7 @@ def __post_init__(self): @dataclass -class ChatConfig: +class ChatConfig: # pylint: disable=too-many-instance-attributes r"""A dataclass that represents user-defined partial configuration for the chat config file. @@ -172,18 +171,12 @@ class ChatConfig: max_window_size: Optional[int] = None @classmethod - def _from_json(chat_config_cls, json_obj: dict): - return chat_config_cls( - **{ - k: v - for k, v in json_obj.items() - if k in inspect.signature(chat_config_cls).parameters - } - ) + def _from_json(cls, json_obj: dict): + return cls(**{k: v for k, v in json_obj.items() if k in inspect.signature(cls).parameters}) @dataclass -class GenerationConfig: +class GenerationConfig: # pylint: disable=too-many-instance-attributes r"""A dataclass that represents user-defined generation configuration. 
An instance of ``GenerationConfig`` can be passed in to the generate function @@ -256,16 +249,16 @@ class GenerationConfig: max_gen_len: Optional[int] = None presence_penalty: Optional[float] = 0.0 frequency_penalty: Optional[float] = 0.0 - n: Optional[int] = None + n: Optional[int] = None # pylint: disable=invalid-name stop: Optional[Union[str, List[str]]] = None @classmethod - def _from_chat_config(generation_config_cls, chat_config_obj: ChatConfig): - return generation_config_cls( + def _from_chat_config(cls, chat_config_obj: ChatConfig): + return cls( **{ f.name: getattr(chat_config_obj, f.name) for f in fields(chat_config_obj) - if f.name in inspect.signature(generation_config_cls).parameters + if f.name in inspect.signature(cls).parameters } ) @@ -273,20 +266,21 @@ def _from_chat_config(generation_config_cls, chat_config_obj: ChatConfig): class PlaceInPrompt(Enum): """The place of an input message in a prompt.""" - # The input message should have role names and corresponding seperators appended both prior to it and after it, - # making it a complete prompt. - All = 0 - # The input message is only the beginning part of a prompt, no role name and separator should be appended after - # the message since there will be future messages appended after the message. - Begin = 1 - # The input message is in the middle of a prompt, nothing should be appended before or after the message. - Middle = 2 - # The input message is the ending part of a prompt, no role name and separator should be appended prior to it - # since the message is concatenated to some prior messages. - End = 3 - - -def _get_model_path(model: str) -> (str, str): + # The input message should have role names and corresponding seperators appended both prior to + # it and after it, making it a complete prompt. + All = 0 # pylint: disable=invalid-name + # The input message is only the beginning part of a prompt, no role name and separator should + # be appended after the message since there will be future messages appended after the message. + Begin = 1 # pylint: disable=invalid-name + # The input message is in the middle of a prompt, nothing should be appended before or after + # the message. + Middle = 2 # pylint: disable=invalid-name + # The input message is the ending part of a prompt, no role name and separator should be + # appended prior to it since the message is concatenated to some prior messages. + End = 3 # pylint: disable=invalid-name + + +def _get_model_path(model: str) -> Tuple[str, str]: """Use user-provided argument ``model`` to search for a valid model path. We define "valid" as having an ``mlc-chat-config.json`` right under the folder. 
@@ -320,8 +314,8 @@ def _get_model_path(model: str) -> (str, str): for candidate in candidate_paths: chat_file = os.path.join(candidate, "mlc-chat-config.json") if os.path.isfile(chat_file): - logging.info(f"Using model folder: {os.path.abspath(candidate)}") - logging.info(f"Using mlc chat config: {os.path.abspath(chat_file)}") + logging.info("Using model folder: %s", os.path.abspath(candidate)) + logging.info("Using mlc chat config: %s", os.path.abspath(chat_file)) return candidate, chat_file # Failed to find a valid model_path, analyzing error for user @@ -336,7 +330,7 @@ def _get_model_path(model: str) -> (str, str): if found_folder: # Error 1: there is a folder, but not an mlc-llm model folder (E1) - err_msg = ( + raise FileNotFoundError( "The model folder provided does not seem to refer to a valid mlc-llm model folder.\n" "Specifically, we cannot find `mlc-chat-config.json`, a required file. You should " "provide a path that contains the file.\n" @@ -346,21 +340,16 @@ def _get_model_path(model: str) -> (str, str): f"Please checkout {_PYTHON_GET_STARTED_TUTORIAL_URL} for an example on " "how to load a model." ) - raise FileNotFoundError(err_msg) - else: - # Error 2: cannot find a folder (E0) - all_paths_str = "" - for path in candidate_paths: - all_paths_str += f"- {path}\n" - err_msg = ( - "Cannot find the model folder. We searched over the following possible paths:\n" - f"{all_paths_str}" - "You can try to pass in `model=/path/to/your-model-path`, and confirm " - "that it contains `mlc-chat-config.json`, among other essential files.\n" - f"Please checkout {_PYTHON_GET_STARTED_TUTORIAL_URL} for an " - "example on how to load a model." - ) - raise FileNotFoundError(err_msg) + # Error 2: cannot find a folder (E0) + all_paths_str = "".join(f"- {path}\n" for path in candidate_paths) + raise FileNotFoundError( + "Cannot find the model folder. We searched over the following possible paths:\n" + f"{all_paths_str}" + "You can try to pass in `model=/path/to/your-model-path`, and confirm " + "that it contains `mlc-chat-config.json`, among other essential files.\n" + f"Please checkout {_PYTHON_GET_STARTED_TUTORIAL_URL} for an " + "example on how to load a model." + ) def _get_chat_config(config_file_path: str, user_chat_config: Optional[ChatConfig]) -> ChatConfig: @@ -379,8 +368,8 @@ def _get_chat_config(config_file_path: str, user_chat_config: Optional[ChatConfi ``ChatConfig`` corresponding to ``config_file_path``, overriden by ``user_chat_config``. """ final_chat_config = None - with open(config_file_path, mode="rt", encoding="utf-8") as f: - json_object = json.load(f) + with open(config_file_path, mode="rt", encoding="utf-8") as file: + json_object = json.load(file) final_chat_config = ChatConfig._from_json(json_object) # pylint: disable=protected-access if user_chat_config is not None: # We override using user's chat config @@ -415,9 +404,12 @@ def _get_generation_config( Returns ------ final_generation_config : GenerationConfig - ``GenerationConfig`` corresponding to ``user_chat_config``, overriden by ``user_generation_config``. + ``GenerationConfig`` corresponding to ``user_chat_config``, overriden by + ``user_generation_config``. 
""" + # pylint: disable=protected-access final_generation_config = GenerationConfig._from_chat_config(user_chat_config) + # pylint: enable=protected-access if user_generation_config is not None: # We override using user's chat config for field in fields(user_generation_config): @@ -428,7 +420,7 @@ def _get_generation_config( return final_generation_config -def _get_lib_module_path( +def _get_lib_module_path( # pylint: disable=too-many-arguments model: str, model_path: str, chat_config: ChatConfig, @@ -465,14 +457,12 @@ def _get_lib_module_path( # 1. Use user's model_lib_path if provided if model_lib_path is not None: if os.path.isfile(model_lib_path): - logging.info(f"Using library model: {model_lib_path}") + logging.info("Using library model: %s", model_lib_path) return model_lib_path - else: - err_msg = ( - f"The `model_lib_path` you passed in is not a file: {model_lib_path}.\nPlease checkout " - f"{_PYTHON_GET_STARTED_TUTORIAL_URL} for an example on how to load a model." - ) - raise FileNotFoundError(err_msg) + raise FileNotFoundError( + f"The `model_lib_path` you passed in is not a file: {model_lib_path}.\n" + f"Please refer to {_PYTHON_GET_STARTED_TUTORIAL_URL} as tutorial on model loading." + ) # 2. Generate all possible file names according to OS candidate_lib_names = [] @@ -511,7 +501,7 @@ def _get_lib_module_path( # 4. Search for model library for candidate in candidate_paths: if os.path.isfile(candidate): - logging.info(f"Using library model: {os.path.abspath(candidate)}\n") + logging.info("Using library model: %s", os.path.abspath(candidate)) return candidate # 5. Error @@ -558,18 +548,17 @@ def _convert_chat_config_to_json_str( # Only want to keep entries that are not None; otherwise, we would override things to None assert hasattr(ChatConfig, "conv_config") # in case dataclass attribute name changes chat_dict = {} - for k, v in asdict(chat_config).items(): - if k == "conv_config" and v is not None: + for key, value in asdict(chat_config).items(): + if key == "conv_config" and value is not None: # conv template is another dict, do the same thing conv_dict = {} - for conv_k, conv_v in v.items(): + for conv_k, conv_v in value.items(): if conv_v is not None: conv_dict[conv_k] = conv_v - chat_dict[k] = conv_dict + chat_dict[key] = conv_dict continue - - if v is not None: - chat_dict[k] = v + if value is not None: + chat_dict[key] = value return json.dumps(chat_dict) @@ -593,7 +582,7 @@ def _convert_generation_config_to_json_str(generation_config: Optional[Generatio return json.dumps(asdict(generation_config)) -def _parse_device_str(device: str) -> (tvm.runtime.Device, str): +def _parse_device_str(device: str) -> Tuple[tvm.runtime.Device, str]: """Parse the input device identifier into device name and id. Parameters @@ -636,14 +625,14 @@ def _parse_device_str(device: str) -> (tvm.runtime.Device, str): device = tvm.opencl(device_id) elif device_name == "auto": device, device_name = _detect_local_device(device_id) - logging.info(f"System automatically detected device: {device_name}") + logging.info("System automatically detected device: %s", device_name) else: raise ValueError(device_err_msg) return device, device_name -def _detect_local_device(device_id: int = 0) -> (tvm.runtime.Device, str): +def _detect_local_device(device_id: int = 0) -> Tuple[tvm.runtime.Device, str]: """Automatically detect the local device if user does not specify. 
Parameters @@ -669,14 +658,14 @@ def _detect_local_device(device_id: int = 0) -> (tvm.runtime.Device, str): return tvm.vulkan(device_id), "vulkan" if tvm.opencl().exist: return tvm.opencl(device_id), "opencl" - logging.info( - "None of the following device is detected: metal, rocm, cuda, vulkan, opencl. Switch to llvm instead." + "None of the following device is detected: metal, rocm, cuda, vulkan, opencl. " + "Switch to llvm instead." ) return tvm.cpu(device_id), "llvm" -class ChatModule: +class ChatModule: # pylint: disable=too-many-instance-attributes r"""The ChatModule for MLC LLM. Examples @@ -798,9 +787,9 @@ def generate( generation_config: Optional[GenerationConfig] = None, progress_callback=None, ) -> Union[str, List[str]]: - r"""A high-level method that returns the full response from the chat module given a user prompt. - User can optionally specify which callback method to use upon receiving the response. By default, - no callback will be applied. + r"""A high-level method that returns the full response from the chat module given a user + prompt. User can optionally specify which callback method to use upon receiving the + response. By default, no callback will be applied. Parameters ---------- @@ -815,9 +804,10 @@ def generate( generation_config: Optional[GenerationConfig] The generation config object to override the ChatConfig generation settings. progress_callback: object - The optional callback method used upon receiving a newly generated message from the chat module. - See `mlc_chat/callback.py` for a full list of available callback classes. Currently, only - streaming to stdout callback method is supported, see `Examples` for more detailed usage. + The optional callback method used upon receiving a newly generated message from the + chat module. See `mlc_chat/callback.py` for a full list of available callback classes. + Currently, only streaming to stdout callback method is supported, see `Examples` for + more detailed usage. Returns ------- @@ -899,7 +889,7 @@ def reset_chat(self, chat_config: Optional[ChatConfig] = None): # Second argument is `partial_update = True` self._load_json_override_func(user_chat_config_json_str, True) - def embed_text(self, input: str): + def embed_text(self, input: str): # pylint: disable=redefined-builtin r"""Given a text input, returns its embedding in the LLM. Parameters @@ -923,7 +913,7 @@ def embed_text(self, input: str): return self._embed_func(input, PlaceInPrompt.Middle.value) def stats(self, verbose=False) -> str: - r"""Get the runtime stats of the encoding step, decoding step, (and embedding step if exists) + r"""Get the runtime stats of the encoding step, decoding step (and embedding step if exists) of the chat module in text form. 
Returns @@ -933,8 +923,7 @@ def stats(self, verbose=False) -> str: """ if verbose: return self._verbose_runtime_stats_text_func() - else: - return self._runtime_stats_text_func() + return self._runtime_stats_text_func() def benchmark_generate(self, prompt: str, generate_length: int) -> str: r"""Controlled generation with input prompt and fixed number of @@ -1008,7 +997,7 @@ def _unload(self): def _prefill( self, - input: Union[str, List[ChatMessage]], + input: Union[str, List[ChatMessage]], # pylint: disable=redefined-builtin decode_next_token: bool = True, place_in_prompt: PlaceInPrompt = PlaceInPrompt.All, generation_config: Optional[GenerationConfig] = None, @@ -1043,7 +1032,7 @@ def _prefill( messages = [] role0 = self._get_role_0() role1 = self._get_role_1() - for idx, msg in enumerate(input[:-1]): + for _, msg in enumerate(input[:-1]): role = msg.role content = msg.content if role == "user": @@ -1055,11 +1044,12 @@ def _prefill( if not input[-1].role == "user": raise ValueError("Last message should be from user.") conv_config["messages"] = messages - conv_config[ - "offset" - ] = 0 # Otherwise, the offset will be set to the length of the conversation, which means history will be retained even after calling reset_chat + conv_config["offset"] = 0 + # Otherwise, the offset will be set to the length of the conversation, + # which means history will be retained even after calling reset_chat self._load_json_override( - json.dumps({"conv_config": conv_config}), partial_update=True + json.dumps({"conv_config": conv_config}), + partial_update=True, ) input_str = input[-1].content else: @@ -1071,13 +1061,13 @@ def _prefill( def _embed( self, - input: str, + input: str, # pylint: disable=redefined-builtin place_in_prompt: PlaceInPrompt = PlaceInPrompt.All, generation_config: Optional[GenerationConfig] = None, ): - r"""A more fine-grained embedding API. Given a text input, get the embedding of the tokenized prompt. - User can decide where to place the input in the prompt. This functionality usually aids the subsequent - call to :func:`_prefill_with_embed`. + r"""A more fine-grained embedding API. Given a text input, get the embedding of the + tokenized prompt. User can decide where to place the input in the prompt. This functionality + usually aids the subsequent call to :func:`_prefill_with_embed`. Parameters ---------- @@ -1174,8 +1164,8 @@ def _load_json_override(self, config_str: str, partial_update: bool = False): config_str : str A json config string that partially specifies some of the options. partial_update : bool - Whether it's a partial update or full update, if set to true, we perform a partial update - on some of the provided options; if set to false, all options must be provided. + Whether it's a partial update or full update. If set to true, we perform a partial + update on some of the provided options; if set to false, all options must be provided. 
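Putting the pieces documented above together, a minimal end-to-end usage sketch (the model directory below is hypothetical and must point at a compiled MLC artifact on disk):

    from mlc_chat.chat_module import ChatModule, GenerationConfig

    cm = ChatModule(model="dist/Llama-2-7b-chat-hf-q4f16_1", device="auto")
    output = cm.generate(
        prompt="Briefly explain group quantization.",
        generation_config=GenerationConfig(temperature=0.7, max_gen_len=128),
    )
    print(output)
    print(cm.stats())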
""" self._load_json_override_func(config_str, partial_update) diff --git a/python/mlc_chat/cli/compile.py b/python/mlc_chat/cli/compile.py index 17b53797f4..8a41ab5bbb 100644 --- a/python/mlc_chat/cli/compile.py +++ b/python/mlc_chat/cli/compile.py @@ -1,6 +1,5 @@ """Command line entrypoint of compilation.""" import argparse -import json import logging from pathlib import Path from typing import Union diff --git a/python/mlc_chat/embeddings/openai.py b/python/mlc_chat/embeddings/openai.py index ed8dd5ea93..ad6b750b0b 100644 --- a/python/mlc_chat/embeddings/openai.py +++ b/python/mlc_chat/embeddings/openai.py @@ -1,11 +1,15 @@ +# pylint: disable=missing-docstring from __future__ import annotations import logging -from typing import List, Optional, Sequence, Tuple +from typing import Iterable, List, Optional, Sequence, Tuple import numpy as np -from langchain.embeddings import OpenAIEmbeddings -from langchain.embeddings.openai import async_embed_with_retry, embed_with_retry +from langchain.embeddings import OpenAIEmbeddings # pylint: disable=import-error +from langchain.embeddings.openai import ( # pylint: disable=import-error + async_embed_with_retry, + embed_with_retry, +) logger = logging.getLogger(__name__) @@ -19,13 +23,13 @@ def _chunk_tokens(self, texts: Sequence[str]) -> Tuple[List[List], List[int]]: ) try: - import tiktoken - except ImportError: + import tiktoken # pylint: disable=import-outside-toplevel + except ImportError as err: raise ImportError( "Could not import tiktoken python package. " "This is needed in order to for OpenAIEmbeddings. " "Please install it with `pip install tiktoken`." - ) + ) from err tokens = [] indices = [] @@ -56,10 +60,10 @@ def _batch_embed( ) -> List[List[float]]: batched_embeddings: List[List[float]] = [] _chunk_size = chunk_size or self.chunk_size - _iter = range(0, len(inputs), _chunk_size) + _iter: Iterable = range(0, len(inputs), _chunk_size) if self.show_progress_bar: try: - from tqdm.auto import tqdm + from tqdm import tqdm # pylint: disable=import-outside-toplevel _iter = tqdm(_iter) except ImportError: @@ -79,10 +83,10 @@ async def _abatch_embed( ) -> List[List[float]]: batched_embeddings: List[List[float]] = [] _chunk_size = chunk_size or self.chunk_size - _iter = range(0, len(inputs), _chunk_size) + _iter: Iterable = range(0, len(inputs), _chunk_size) if self.show_progress_bar: try: - from tqdm.auto import tqdm + from tqdm import tqdm # pylint: disable=import-outside-toplevel _iter = tqdm(_iter) except ImportError: @@ -99,8 +103,12 @@ async def _abatch_embed( # please refer to # https://github.com/openai/openai-cookbook/blob/main/examples/Embedding_long_inputs.ipynb - def _get_len_safe_embeddings( - self, texts: List[str], *, engine: str, chunk_size: Optional[int] = None + def _get_len_safe_embeddings( # pylint: disable=too-many-locals,unused-argument + self, + texts: List[str], + *, + engine: str, + chunk_size: Optional[int] = None, ) -> List[List[float]]: tokens, indices = self._chunk_tokens(texts) batched_embeddings = self._batch_embed(tokens, chunk_size=chunk_size) @@ -130,8 +138,12 @@ def _get_len_safe_embeddings( # please refer to # https://github.com/openai/openai-cookbook/blob/main/examples/Embedding_long_inputs.ipynb - async def _aget_len_safe_embeddings( - self, texts: List[str], *, engine: str, chunk_size: Optional[int] = None + async def _aget_len_safe_embeddings( # pylint: disable=too-many-locals,unused-argument + self, + texts: List[str], + *, + engine: str, + chunk_size: Optional[int] = None, ) -> 
List[List[float]]: tokens, indices = self._chunk_tokens(texts) batched_embeddings = await self._abatch_embed(tokens, chunk_size=chunk_size) diff --git a/python/mlc_chat/gradio.py b/python/mlc_chat/gradio.py index 8f0e16ab26..1ab6ae6dc0 100644 --- a/python/mlc_chat/gradio.py +++ b/python/mlc_chat/gradio.py @@ -1,11 +1,9 @@ """Gradio interface for MLC Chat.""" -# pylint: disable=import-error, import-outside-toplevel, invalid-name, line-too-long, protected-access -# too-many-instance-attributes, too-many-locals, unused-import - +# pylint: disable=import-error,invalid-name,too-many-instance-attributes,too-many-locals import argparse import glob import os -from typing import Dict +from typing import Dict, Optional import gradio as gr @@ -48,17 +46,19 @@ def _get_all_available_models_under_dir(artifact_path: str) -> Dict[str, str]: Note ---- We only search for folders under the artifact_path, without recursive search for subfolders. - For each folder, we count it as a valid MLC model folder if either it contains a `mlc-chat-config.json` - file, or it contains a `params` folder which contains a `mlc-chat-config.json` file. We will map - the name of a valid folder to its full path to the folder containing `mlc-chat-config.json`. + For each folder, we count it as a valid MLC model folder if either it contains an + `mlc-chat-config.json` file, or it contains a `params` folder which contains an + `mlc-chat-config.json` file. We will map the name of a valid folder to its full path to the + folder containing `mlc-chat-config.json`. """ # step 0. retrieve the absolute path of artifact_path search_dir = os.path.abspath(artifact_path) if not os.path.exists(search_dir): err_msg = ( - f"The artifact path {artifact_path} you provided is neither a valid full path nor a valid path ", - "relative to the current working directory. Please provide a correct artifact path.", + f"The artifact path {artifact_path} you provided is neither a valid full path nor a " + "valid path relative to the current working directory. Please provide a correct " + "artifact path.", ) raise FileNotFoundError(err_msg) @@ -78,9 +78,9 @@ def _get_all_available_models_under_dir(artifact_path: str) -> Dict[str, str]: class GradioModule: - r"""The Gradio module for MLC Chat. Different from ChatModule Python API, Gradio module allows users - to load in a directory of models, watch the streaming in web browser, and switch between models more - easily to compare performance. + r"""The Gradio module for MLC Chat. Different from ChatModule Python API, Gradio module allows + users to load in a directory of models, watch the streaming in web browser, and switch between + models more easily to compare performance. Note: Multimodality will be supported soon, i.e. allowing users to upload an image to chat. 
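For reference, a minimal way to start the interface from Python, assuming a local dist directory that contains at least one compiled model; the artifact_path and device parameter names are inferred from the function body below and the call is only a sketch:

    from mlc_chat.gradio import launch_gradio

    launch_gradio(artifact_path="dist", device="auto", share=False)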
""" @@ -88,7 +88,7 @@ class GradioModule: def __init__(self, artifact_path: str = "dist", device: str = "auto"): self.artifact_path = artifact_path self.device_str = device - self.chat_mod = None + self.chat_mod: Optional[ChatModule] = None self.model_dict = _get_all_available_models_under_dir(artifact_path) def gradio_reload_model(self, model_name: str): @@ -133,6 +133,7 @@ def gradio_answer(self, chatbot, stream_interval): Note: Below is a low-level implementation of generate() API, since it's easier to yield without delta callback.""" prompt = chatbot[-1][0] + # pylint: disable=protected-access self.chat_mod._prefill(prompt) i, new_msg = 0, "" while not self.chat_mod._stopped(): @@ -142,6 +143,7 @@ def gradio_answer(self, chatbot, stream_interval): chatbot[-1][1] = new_msg yield chatbot i += 1 + # pylint: enable=protected-access def gradio_stats(self): """Get runtime statistics.""" @@ -155,7 +157,8 @@ def launch_gradio( share: bool = False, host: str = "127.0.0.1", ): - r"""Launch the gradio interface with a given port, creating a publically sharable link if specified.""" + r"""Launch the gradio interface with a given port, creating a publically sharable link if + specified.""" # create a gradio module mod = GradioModule(artifact_path, device) diff --git a/python/mlc_chat/interface/openai_api.py b/python/mlc_chat/interface/openai_api.py index ed08c75b0a..654b1646bc 100644 --- a/python/mlc_chat/interface/openai_api.py +++ b/python/mlc_chat/interface/openai_api.py @@ -1,5 +1,7 @@ +# pylint: disable=missing-docstring,fixme,import-error,too-few-public-methods """ -Adapted from FastChat's OpenAI protocol: https://github.com/lm-sys/FastChat/blob/main/fastchat/protocol/openai_api_protocol.py +Adapted from FastChat's OpenAI protocol: +https://github.com/lm-sys/FastChat/blob/main/fastchat/protocol/openai_api_protocol.py """ import time diff --git a/python/mlc_chat/rest.py b/python/mlc_chat/rest.py index d48316845d..a1cc57c4e9 100644 --- a/python/mlc_chat/rest.py +++ b/python/mlc_chat/rest.py @@ -1,7 +1,9 @@ +# pylint: disable=missing-docstring,fixme,import-error import argparse import asyncio +import dataclasses from contextlib import asynccontextmanager -from dataclasses import dataclass, field, fields +from typing import Dict import numpy as np import uvicorn @@ -12,14 +14,31 @@ from .base import set_global_random_seed from .chat_module import ChatModule -from .interface.openai_api import * +from .interface.openai_api import ( + ChatCompletionRequest, + ChatCompletionResponse, + ChatCompletionResponseChoice, + ChatCompletionResponseStreamChoice, + ChatCompletionStreamResponse, + ChatMessage, + CompletionRequest, + CompletionResponse, + CompletionResponseChoice, + CompletionResponseStreamChoice, + CompletionStreamResponse, + DeltaMessage, + EmbeddingsRequest, + EmbeddingsResponse, + UsageInfo, +) -@dataclass +@dataclasses.dataclass class RestAPIArgs: - """RestAPIArgs is the dataclass that organizes the arguments used for starting a REST API server.""" + """RestAPIArgs is the dataclass that organizes the arguments used for starting a REST API + server.""" - model: str = field( + model: str = dataclasses.field( metadata={ "help": ( """ @@ -32,7 +51,7 @@ class RestAPIArgs: ) } ) - lib_path: str = field( + lib_path: str = dataclasses.field( default=None, metadata={ "help": ( @@ -42,7 +61,7 @@ class RestAPIArgs: ) }, ) - device: str = field( + device: str = dataclasses.field( default="auto", metadata={ "help": ( @@ -56,7 +75,7 @@ class RestAPIArgs: ) }, ) - host: str = field( + host: str = 
dataclasses.field( default="127.0.0.1", metadata={ "help": ( @@ -66,7 +85,7 @@ class RestAPIArgs: ) }, ) - port: int = field( + port: int = dataclasses.field( default=8000, metadata={ "help": ( @@ -76,7 +95,7 @@ class RestAPIArgs: ) }, ) - random_seed: int = field( + random_seed: int = dataclasses.field( default=None, metadata={ "help": ( @@ -92,7 +111,7 @@ class RestAPIArgs: def convert_args_to_argparser() -> argparse.ArgumentParser: """Convert from RestAPIArgs to an equivalent ArgumentParser.""" args = argparse.ArgumentParser("MLC Chat REST API") - for field in fields(RestAPIArgs): + for field in dataclasses.fields(RestAPIArgs): name = field.name.replace("_", "-") field_name = f"--{name}" # `kwargs` contains `help`, `choices`, and `action` @@ -105,11 +124,11 @@ def convert_args_to_argparser() -> argparse.ArgumentParser: return args -session = {} +session: Dict[str, ChatModule] = {} @asynccontextmanager -async def lifespan(app: FastAPI): +async def lifespan(_app: FastAPI): if ARGS.random_seed is not None: set_global_random_seed(ARGS.random_seed) chat_mod = ChatModule( @@ -118,19 +137,14 @@ async def lifespan(app: FastAPI): model_lib_path=ARGS.lib_path, ) session["chat_mod"] = chat_mod - yield - session.clear() -app = FastAPI(lifespan=lifespan) - -origins = [ - "*", -] +origins = ["*"] -app.add_middleware( +APP = FastAPI(lifespan=lifespan) +APP.add_middleware( CORSMiddleware, allow_origins=origins, allow_credentials=True, @@ -147,24 +161,24 @@ def __aiter__(self): return self async def get_next_msg(self): + # pylint: disable=protected-access if not session["chat_mod"]._stopped(): session["chat_mod"]._decode(generation_config=self.generation_config) msg = session["chat_mod"]._get_message() return msg - else: - raise StopAsyncIteration + # pylint: enable=protected-access + raise StopAsyncIteration async def __anext__(self): if not session["chat_mod"]._stopped(): task = asyncio.create_task(self.get_next_msg()) msg = await task return msg - else: - raise StopAsyncIteration + raise StopAsyncIteration -@app.post("/v1/chat/completions") -async def request_completion(request: ChatCompletionRequest): +@APP.post("/v1/chat/completions") +async def request_chat_completion(request: ChatCompletionRequest): """ Creates model response for the given chat conversation. The messages field contains a list of messages (describing the conversation history). eg: @@ -192,7 +206,10 @@ async def request_completion(request: ChatCompletionRequest): session["chat_mod"].reset_chat() # Reset previous history, KV cache, etc. 
if request.stream: - session["chat_mod"]._prefill(input=request.messages, generation_config=generation_config) + session["chat_mod"]._prefill( # pylint: disable=protected-access + input=request.messages, + generation_config=generation_config, + ) async def iter_response(): prev_txt = "" @@ -213,27 +230,24 @@ async def iter_response(): yield f"data: {chunk.json(exclude_unset=True)}\n\n" return StreamingResponse(iter_response(), media_type="text/event-stream") - else: - msg = session["chat_mod"].generate( - prompt=request.messages, generation_config=generation_config - ) - if isinstance(msg, str): - msg = [msg] - return ChatCompletionResponse( - choices=[ - ChatCompletionResponseChoice( - index=index, - message=ChatMessage(role="assistant", content=msg[index]), - finish_reason="stop", - ) - for index in range(len(msg)) - ], - # TODO: Fill in correct usage info - usage=UsageInfo(prompt_tokens=0, completion_tokens=0, total_tokens=0), - ) + msg = session["chat_mod"].generate(prompt=request.messages, generation_config=generation_config) + if isinstance(msg, str): + msg = [msg] + return ChatCompletionResponse( + choices=[ + ChatCompletionResponseChoice( + index=index, + message=ChatMessage(role="assistant", content=msg[index]), + finish_reason="stop", + ) + for index in range(len(msg)) + ], + # TODO: Fill in correct usage info + usage=UsageInfo(prompt_tokens=0, completion_tokens=0, total_tokens=0), + ) -@app.post("/v1/completions") +@APP.post("/v1/completions") async def request_completion(request: CompletionRequest): """ Creates a completion for a given prompt. @@ -264,7 +278,10 @@ async def request_completion(request: CompletionRequest): prompt = request.prompt if request.stream: - session["chat_mod"]._prefill(input=prompt, generation_config=generation_config) + session["chat_mod"]._prefill( # pylint: disable=protected-access + input=prompt, + generation_config=generation_config, + ) async def iter_response(): prev_txt = "" @@ -283,24 +300,23 @@ async def iter_response(): yield f"data: {chunk.json(exclude_unset=True)}\n\n" return StreamingResponse(iter_response(), media_type="text/event-stream") - else: - msg = session["chat_mod"].generate(prompt=prompt, generation_config=generation_config) - return CompletionResponse( - choices=[CompletionResponseChoice(index=0, text=msg)], - # TODO: Fill in correct usage info - usage=UsageInfo(prompt_tokens=0, completion_tokens=0, total_tokens=0), - ) + msg = session["chat_mod"].generate(prompt=prompt, generation_config=generation_config) + return CompletionResponse( + choices=[CompletionResponseChoice(index=0, text=msg)], + # TODO: Fill in correct usage info + usage=UsageInfo(prompt_tokens=0, completion_tokens=0, total_tokens=0), + ) -@app.post("/v1/embeddings") +@APP.post("/v1/embeddings") async def request_embeddings(request: EmbeddingsRequest): """ Gets embedding for some text. """ inps = [] - if type(request.input) == str: + if isinstance(request.input, str): inps.append(request.input) - elif type(request.input) == list: + elif isinstance(request.input, list): inps = request.input else: assert f"Invalid input type {type(request.input)}" @@ -318,7 +334,7 @@ async def request_embeddings(request: EmbeddingsRequest): ) -@app.post("/chat/reset") +@APP.post("/chat/reset") async def reset(): """ Reset the chat for the currently initialized model. @@ -326,7 +342,7 @@ async def reset(): session["chat_mod"].reset_chat() -@app.get("/stats") +@APP.get("/stats") async def read_stats(): """ Get the runtime stats. 
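For reference, a minimal client-side sketch against a server started from this module on the default host and port (the model name is hypothetical; field names follow the OpenAI-style schema in interface/openai_api.py):

    import requests

    base = "http://127.0.0.1:8000"
    resp = requests.post(
        f"{base}/v1/chat/completions",
        json={
            "model": "Llama-2-7b-chat-hf-q4f16_1",  # hypothetical; the server answers with its loaded model
            "messages": [{"role": "user", "content": "What is group quantization?"}],
            "stream": False,
        },
        timeout=120,
    )
    print(resp.json()["choices"][0]["message"]["content"])
    print(requests.get(f"{base}/stats", timeout=120).text)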
@@ -334,7 +350,7 @@ async def read_stats(): return session["chat_mod"].stats() -@app.get("/verbose_stats") +@APP.get("/verbose_stats") async def read_stats_verbose(): """ Get the verbose runtime stats. diff --git a/python/setup.py b/python/setup.py index aa7394f1d2..af471c19c0 100644 --- a/python/setup.py +++ b/python/setup.py @@ -2,7 +2,6 @@ """Setup MLC LLM package.""" import os import shutil -import sys from setuptools import find_packages, setup from setuptools.dist import Distribution @@ -16,7 +15,8 @@ def get_lib_path(): # Directly exec libinfo to get the right setup libinfo_py = os.path.join(CURRENT_DIR, "./mlc_chat/libinfo.py") libinfo = {"__file__": libinfo_py} - exec(compile(open(libinfo_py, "rb").read(), libinfo_py, "exec"), libinfo, libinfo) + with open(libinfo_py, "rb") as f: + exec(compile(f.read(), libinfo_py, "exec"), libinfo, libinfo) version = libinfo["__version__"] # conda installs libraries into env instead of packaging with pip @@ -35,10 +35,11 @@ def git_describe_version(original_version): """Get git describe version.""" ver_py = os.path.join(CURRENT_DIR, "..", "version.py") libver = {"__file__": ver_py} - exec(compile(open(ver_py, "rb").read(), ver_py, "exec"), libver, libver) + with open(ver_py, "rb") as f: + exec(compile(f.read(), ver_py, "exec"), libver, libver) _, gd_version = libver["git_describe_version"]() if gd_version is not None and gd_version != original_version: - print("Use git describe based version %s" % gd_version) + print(f"Use git describe based version {gd_version}") return gd_version @@ -47,60 +48,66 @@ def git_describe_version(original_version): class BinaryDistribution(Distribution): + """This class is needed in order to create OS specific wheels.""" + def has_ext_modules(self): + """Return True for binary distribution.""" return True def is_pure(self): + """Return False for binary distribution.""" return False -setup_kwargs = {} -if not CONDA_BUILD: - with open("MANIFEST.in", "w") as fo: - for path in LIB_LIST: +def main(): + """The main entrypoint.""" + setup_kwargs = {} + if not CONDA_BUILD: + with open("MANIFEST.in", "w", encoding="utf-8") as fo: + for path in LIB_LIST: + if os.path.isfile(path): + shutil.copy(path, os.path.join(CURRENT_DIR, "mlc_chat")) + _, libname = os.path.split(path) + fo.write(f"include mlc_chat/{libname}\n") + setup_kwargs = {"include_package_data": True} + + setup( + name="mlc_chat", + version=__version__, + description="MLC Chat: an universal runtime running LLMs", + url="https://llm.mlc.ai/", + author="MLC LLM Contributors", + license="Apache 2.0", + # See https://pypi.org/classifiers/ + classifiers=[ + "License :: OSI Approved :: Apache Software License", + "Development Status :: 4 - Beta", + "Intended Audience :: Developers", + "Intended Audience :: Education", + "Intended Audience :: Science/Research", + ], + keywords="machine learning", + zip_safe=False, + packages=find_packages(), + package_dir={"mlc_chat": "mlc_chat"}, + install_requires=["fastapi", "uvicorn", "shortuuid"], + distclass=BinaryDistribution, + **setup_kwargs, + ) + + def _remove_path(path): + if os.path.exists(path): if os.path.isfile(path): - shutil.copy(path, os.path.join(CURRENT_DIR, "mlc_chat")) - _, libname = os.path.split(path) - fo.write(f"include mlc_chat/{libname}\n") - setup_kwargs = {"include_package_data": True} - - -setup( - name="mlc_chat", - version=__version__, - description="MLC Chat: an universal runtime running LLMs", - url="https://llm.mlc.ai/", - author="MLC LLM Contributors", - license="Apache 2.0", - # See 
https://pypi.org/classifiers/ - classifiers=[ - "License :: OSI Approved :: Apache Software License", - "Development Status :: 4 - Beta", - "Intended Audience :: Developers", - "Intended Audience :: Education", - "Intended Audience :: Science/Research", - ], - keywords="machine learning", - zip_safe=False, - packages=find_packages(), - package_dir={"mlc_chat": "mlc_chat"}, - install_requires=["fastapi", "uvicorn", "shortuuid"], - distclass=BinaryDistribution, - **setup_kwargs, -) - - -def _remove_path(path): - if os.path.exists(path): - if os.path.isfile(path): - os.remove(path) - elif os.path.isdir(path): - shutil.rmtree(path) - - -if not CONDA_BUILD: - # Wheel cleanup - os.remove("MANIFEST.in") - for path in LIB_LIST: - _, libname = os.path.split(path) - _remove_path(f"mlc_chat/{libname}") + os.remove(path) + elif os.path.isdir(path): + shutil.rmtree(path) + + if not CONDA_BUILD: + # Wheel cleanup + os.remove("MANIFEST.in") + for path in LIB_LIST: + _, libname = os.path.split(path) + _remove_path(f"mlc_chat/{libname}") + + +main() diff --git a/tests/python/legacy/compare_lib.py b/tests/legacy-python/compare_lib.py similarity index 100% rename from tests/python/legacy/compare_lib.py rename to tests/legacy-python/compare_lib.py diff --git a/tests/python/legacy/dump_intermediate.py b/tests/legacy-python/dump_intermediate.py similarity index 100% rename from tests/python/legacy/dump_intermediate.py rename to tests/legacy-python/dump_intermediate.py diff --git a/tests/python/legacy/evaluate.py b/tests/legacy-python/evaluate.py similarity index 100% rename from tests/python/legacy/evaluate.py rename to tests/legacy-python/evaluate.py diff --git a/tests/python/legacy/test_batching_llama.py b/tests/legacy-python/test_batching_llama.py similarity index 100% rename from tests/python/legacy/test_batching_llama.py rename to tests/legacy-python/test_batching_llama.py diff --git a/tests/python/legacy/test_build_args.py b/tests/legacy-python/test_build_args.py similarity index 100% rename from tests/python/legacy/test_build_args.py rename to tests/legacy-python/test_build_args.py diff --git a/tests/python/legacy/test_build_model_from_args.py b/tests/legacy-python/test_build_model_from_args.py similarity index 100% rename from tests/python/legacy/test_build_model_from_args.py rename to tests/legacy-python/test_build_model_from_args.py From 0a253741dbd647660acbe95b152263fc7025ac28 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Sun, 29 Oct 2023 21:17:38 -0700 Subject: [PATCH 063/116] Migrate Compiler Passes (#1150) --- python/mlc_chat/compiler/__init__.py | 1 + python/mlc_chat/compiler/compile.py | 106 ++-------- .../compiler/compiler_pass/__init__.py | 2 + .../compiler_pass/clean_up_tir_attrs.py | 31 +++ .../compiler_pass/fuse_decode_matmul_ewise.py | 81 ++++++++ .../compiler_pass/fuse_decode_take.py | 83 ++++++++ .../compiler_pass/fuse_decode_transpose.py | 109 ++++++++++ .../compiler_pass/fuse_transpose_matmul.py | 153 ++++++++++++++ .../compiler_pass/lift_global_buffer_alloc.py | 196 ++++++++++++++++++ .../compiler/compiler_pass/pipeline.py | 49 +++++ .../mlc_chat/compiler/flags_optimization.py | 77 +++++++ .../compiler/model/llama_quantization.py | 4 +- .../compiler/quantization/group_quantizer.py | 6 +- .../python/parameter/test_group_quantizer.py | 6 +- 14 files changed, 808 insertions(+), 96 deletions(-) create mode 100644 python/mlc_chat/compiler/compiler_pass/__init__.py create mode 100644 python/mlc_chat/compiler/compiler_pass/clean_up_tir_attrs.py create mode 100644 
python/mlc_chat/compiler/compiler_pass/fuse_decode_matmul_ewise.py create mode 100644 python/mlc_chat/compiler/compiler_pass/fuse_decode_take.py create mode 100644 python/mlc_chat/compiler/compiler_pass/fuse_decode_transpose.py create mode 100644 python/mlc_chat/compiler/compiler_pass/fuse_transpose_matmul.py create mode 100644 python/mlc_chat/compiler/compiler_pass/lift_global_buffer_alloc.py create mode 100644 python/mlc_chat/compiler/compiler_pass/pipeline.py create mode 100644 python/mlc_chat/compiler/flags_optimization.py diff --git a/python/mlc_chat/compiler/__init__.py b/python/mlc_chat/compiler/__init__.py index c0f6c7e51b..2aa7dfcf98 100644 --- a/python/mlc_chat/compiler/__init__.py +++ b/python/mlc_chat/compiler/__init__.py @@ -2,6 +2,7 @@ A compiler for MLC Chat. By default, it is not imported to MLC Chat to avoid unnecessary dependency, but users could optionally import it if they want to use the compiler. """ +from . import compiler_pass from .compile import ( # pylint: disable=redefined-builtin CompileArgs, OptimizationFlags, diff --git a/python/mlc_chat/compiler/compile.py b/python/mlc_chat/compiler/compile.py index cc6b61b1c2..5b77a94f81 100644 --- a/python/mlc_chat/compiler/compile.py +++ b/python/mlc_chat/compiler/compile.py @@ -1,63 +1,15 @@ """Python entrypoint of compilation.""" -import argparse import dataclasses -import logging from io import StringIO from pathlib import Path from typing import Callable -from mlc_chat.compiler.model import Model -from tvm import IRModule # pylint: disable=wrong-import-order -from tvm.target import Target # pylint: disable=wrong-import-order +from tvm import IRModule, relax +from tvm.target import Target +from ..compiler.model import Model from ..support.style import bold - -logger = logging.getLogger(__name__) - - -@dataclasses.dataclass -class OptimizationFlags: - """Optiization flags""" - - cutlass_attn: bool = True - cutlass_norm: bool = True - cublas_gemm: bool = False - cudagraph: bool = False - - def __repr__(self) -> str: - out = StringIO() - print(f"cutlass_attn={int(self.cutlass_attn)}", file=out, end="") - print(f";cutlass_norm={int(self.cutlass_norm)}", file=out, end="") - print(f";cublas_gemm={int(self.cublas_gemm)}", file=out, end="") - print(f";cudagraph={int(self.cudagraph)}", file=out, end="") - return out.getvalue().rstrip() - - @staticmethod - def from_str(source: str) -> "OptimizationFlags": - """Parse optimization flags from a string.""" - - if source in OPT_FLAG_PRESET: - return OPT_FLAG_PRESET[source] - - def boolean(value: str) -> bool: - if value == "0": - return False - if value == "1": - return True - raise ValueError(f"Invalid boolean value: {value}") - - parser = argparse.ArgumentParser(description="optimization flags") - parser.add_argument("--cutlass_attn", type=boolean, default=True) - parser.add_argument("--cutlass_norm", type=boolean, default=True) - parser.add_argument("--cublas_gemm", type=boolean, default=False) - parser.add_argument("--cudagraph", type=boolean, default=False) - results = parser.parse_args([f"--{i}" for i in source.split(";") if i]) - return OptimizationFlags( - cutlass_attn=results.cutlass_attn, - cutlass_norm=results.cutlass_norm, - cublas_gemm=results.cublas_gemm, - cudagraph=results.cudagraph, - ) +from .flags_optimization import OptimizationFlags @dataclasses.dataclass @@ -86,6 +38,19 @@ def _echo_args(args: CompileArgs) -> None: print(out.getvalue().rstrip()) +def _compile(args: CompileArgs): + model_config = args.model.config.from_file(args.config) + model = 
args.model.model(model_config) + mod, named_params = model.export_tvm( + spec=model.get_default_spec(), # type: ignore + ) + with args.target: + mod = relax.get_pipeline("mlc_llm")(mod) + mod.show(black_format=False) + for name, param in named_params: + print(f"{name}: {param.shape} {param.dtype}") + + def compile( # pylint: disable=too-many-arguments,redefined-builtin config: Path, quantization, @@ -101,39 +66,4 @@ def compile( # pylint: disable=too-many-arguments,redefined-builtin config, quantization, model_type, target, opt, build_func, prefix_symbols, output ) _echo_args(args) - model_config = args.model.config.from_file(args.config) - model = args.model.model(model_config) - mod, named_params = model.export_tvm( - spec=model.get_default_spec(), # type: ignore - ) - mod.show(black_format=False) - for name, param in named_params: - print(f"{name}: {param.shape} {param.dtype}") - - -OPT_FLAG_PRESET = { - "O0": OptimizationFlags( - cutlass_attn=False, - cutlass_norm=False, - cublas_gemm=False, - cudagraph=False, - ), - "O1": OptimizationFlags( - cutlass_attn=False, - cutlass_norm=True, - cublas_gemm=False, - cudagraph=False, - ), - "O2": OptimizationFlags( - cutlass_attn=True, - cutlass_norm=True, - cublas_gemm=False, - cudagraph=False, - ), - "O3": OptimizationFlags( - cutlass_attn=True, - cutlass_norm=True, - cublas_gemm=False, - cudagraph=True, - ), -} + _compile(args) diff --git a/python/mlc_chat/compiler/compiler_pass/__init__.py b/python/mlc_chat/compiler/compiler_pass/__init__.py new file mode 100644 index 0000000000..762ba8c1e0 --- /dev/null +++ b/python/mlc_chat/compiler/compiler_pass/__init__.py @@ -0,0 +1,2 @@ +"""Compiler passes used in MLC LLM.""" +from . import pipeline as _pipeline diff --git a/python/mlc_chat/compiler/compiler_pass/clean_up_tir_attrs.py b/python/mlc_chat/compiler/compiler_pass/clean_up_tir_attrs.py new file mode 100644 index 0000000000..71848ba546 --- /dev/null +++ b/python/mlc_chat/compiler/compiler_pass/clean_up_tir_attrs.py @@ -0,0 +1,31 @@ +"""A compiler pass that cleans up undesired TIR attrs.""" +from typing import List + +import tvm +from tvm.ir.module import IRModule + + +@tvm.transform.module_pass(opt_level=0, name="CleanUpTIRAttrs") +class CleanUpTIRAttrs: # pylint: disable=too-few-public-methods + """A compiler pass that cleans up undesired TIR attrs.""" + + def __init__(self, attrs: List[str]): + self.attrs = attrs + + def transform_module( + self, + mod: IRModule, + _ctx: tvm.transform.PassContext, + ) -> IRModule: + """IRModule-level transformation""" + for g_var in list(mod.functions): + func = mod[g_var] + changed = False + for attr in self.attrs: + if func.attrs is not None and attr in func.attrs: + func = func.without_attr(attr) + changed = True + break + if changed: + mod[g_var] = func + return mod diff --git a/python/mlc_chat/compiler/compiler_pass/fuse_decode_matmul_ewise.py b/python/mlc_chat/compiler/compiler_pass/fuse_decode_matmul_ewise.py new file mode 100644 index 0000000000..0e02f2ae5a --- /dev/null +++ b/python/mlc_chat/compiler/compiler_pass/fuse_decode_matmul_ewise.py @@ -0,0 +1,81 @@ +"""A compiler pass that fuses decode + matmul + elementwise.""" +import tvm +from tvm import IRModule, relax +from tvm.relax.dpl.pattern import GlobalVarPattern, TuplePattern, is_op, wildcard + + +@tvm.transform.module_pass(opt_level=0, name="FuseDecodeMatmulEwise") +class FuseDecodeMatmulEwise: # pylint: disable=too-few-public-methods + """A compiler pass that fuses decode + matmul + elementwise.""" + + def transform_module( + self, + mod: 
IRModule, + _ctx: tvm.transform.PassContext, + ) -> IRModule: + """IRModule-level transformation""" + for n_aux_tensor in [1, 2, 3, 4]: + for match_ewise in [0, 1, 2, 6]: + if match_ewise == 6 and n_aux_tensor != 4: + continue + mod = relax.transform.FuseOpsByPattern( + [ + ( + "decode_matmul", + *_pattern(match_ewise, n_aux_tensor), + ) + ] + )(mod) + mod = relax.transform.FuseTIR()(mod) + return mod + + +def _pattern(match_ewise: int, n_aux_tensor: int): + # pylint: disable=invalid-name + w_scaled = wildcard() + x = wildcard() + w = is_op("relax.call_tir")( + GlobalVarPattern(), + TuplePattern([w_scaled] + [wildcard() for _ in range(n_aux_tensor)]), + add_constraint=False, + ) + matmul = is_op("relax.call_tir")( + GlobalVarPattern(), + TuplePattern([x, w] + [wildcard() for _ in range(match_ewise)]), + add_constraint=False, + ) + # pylint: enable=invalid-name + annotations = { + "w_scaled": w_scaled, + "x": x, + "w": w, + "matmul": matmul, + } + + def _check_decoding(ctx: relax.transform.PatternCheckContext) -> bool: + call = ctx.annotated_expr["w"] + if not isinstance(call, relax.Call): + return False + g_var = call.args[0] + if not isinstance(g_var, relax.GlobalVar): + return False + return g_var.name_hint.startswith("decode") or g_var.name_hint.startswith("fused_decode") + + def _check_matmul(ctx: relax.transform.PatternCheckContext) -> bool: + call = ctx.annotated_expr["matmul"] + if not isinstance(call, relax.Call): + return False + g_var = call.args[0] + if not isinstance(g_var, relax.GlobalVar): + return False + return ( + g_var.name_hint.startswith("matmul") + or g_var.name_hint.startswith("fused_matmul") + or g_var.name_hint.startswith("NT_matmul") + or g_var.name_hint.startswith("fused_NT_matmul") + ) + + def _check(ctx: relax.transform.PatternCheckContext) -> bool: + return _check_decoding(ctx) and _check_matmul(ctx) + + return matmul, annotations, _check diff --git a/python/mlc_chat/compiler/compiler_pass/fuse_decode_take.py b/python/mlc_chat/compiler/compiler_pass/fuse_decode_take.py new file mode 100644 index 0000000000..96678fa951 --- /dev/null +++ b/python/mlc_chat/compiler/compiler_pass/fuse_decode_take.py @@ -0,0 +1,83 @@ +"""A compiler pass that fuses decode + take.""" +import tvm +from tvm import IRModule, relax, tir +from tvm.relax.dpl.pattern import ( + GlobalVarPattern, + TuplePattern, + is_const, + is_op, + wildcard, +) + + +@tvm.transform.module_pass(opt_level=0, name="FuseDecodeTake") +class FuseDecodeTake: # pylint: disable=too-few-public-methods + """A compiler pass that fuses decode + take.""" + + def transform_module( + self, + mod: IRModule, + _ctx: tvm.transform.PassContext, + ) -> IRModule: + """IRModule-level transformation""" + for n_aux_tensor in [2, 3]: + for match_tir_vars in [False, True]: + mod = relax.transform.FuseOpsByPattern( + [ + ( + "decode_take", + *_pattern(n_aux_tensor, match_tir_vars), + ) + ] + )(mod) + mod = relax.transform.FuseTIR()(mod) + for g_var, func in mod.functions.items(): + name = g_var.name_hint + if isinstance(func, tir.PrimFunc) and (("fused_decode" in name) and ("take" in name)): + mod = tvm.IRModule({"main": func}) + sch = tir.Schedule(mod) + sch.compute_inline("decode") + mod[g_var] = sch.mod["main"] + return mod + + +def _pattern(n_aux_tensor: int, match_tir_vars: bool): + decode = is_op("relax.call_tir")( + GlobalVarPattern(), + TuplePattern([wildcard() for _ in range(n_aux_tensor)]), + add_constraint=False, + ) + indices = ~is_const() + if match_tir_vars: + call_tir_args_take = [ + GlobalVarPattern(), + 
TuplePattern([decode, indices]), + wildcard(), + ] + else: + call_tir_args_take = [ + GlobalVarPattern(), + TuplePattern([decode, indices]), + ] + take = is_op("relax.call_tir")( + *call_tir_args_take, + add_constraint=False, + ) + annotations = { + "take": take, + "decode": decode, + "indices": indices, + } + + def _check(ctx: relax.transform.PatternCheckContext) -> bool: + take = ctx.annotated_expr["take"] + decode = ctx.annotated_expr["decode"] + if not isinstance(decode, relax.expr.Call): + return False + if not isinstance(take.args[0], relax.GlobalVar) or not isinstance( + decode.args[0], relax.GlobalVar + ): + return False + return "take" in take.args[0].name_hint and "decode" in decode.args[0].name_hint + + return take, annotations, _check diff --git a/python/mlc_chat/compiler/compiler_pass/fuse_decode_transpose.py b/python/mlc_chat/compiler/compiler_pass/fuse_decode_transpose.py new file mode 100644 index 0000000000..99bcb1b602 --- /dev/null +++ b/python/mlc_chat/compiler/compiler_pass/fuse_decode_transpose.py @@ -0,0 +1,109 @@ +"""A compiler pass that fuses transpose + dequantize.""" +import tvm +from tvm import relax, tir +from tvm.ir.module import IRModule +from tvm.relax.analysis import remove_all_unused +from tvm.relax.expr_functor import PyExprMutator, mutator + + +@tvm.transform.module_pass(opt_level=0, name="FuseDecodeTranspose") +class FuseDecodeTranspose: # pylint: disable=too-few-public-methods + """A compiler pass that fuses transpose + dequantize.""" + + def __init__(self, skip_gemm: bool) -> None: + self.skip_gemm = skip_gemm + + def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule: + """IRModule-level transformation""" + return _DecodeTransposeFuser(mod, skip_gemm=self.skip_gemm).transform() + + +@mutator +class _DecodeTransposeFuser(PyExprMutator): # pylint: disable=abstract-method + def __init__( + self, + mod: IRModule, + skip_gemm: bool, + ): + super().__init__(mod) + self.mod = mod + self.skip_gemm = skip_gemm + + def transform(self) -> IRModule: + """Entry point""" + for g_var, func in self.mod.functions.items(): + if isinstance(func, relax.Function): + updated_func = self.visit_expr(func) + updated_func = remove_all_unused(updated_func) + self.builder_.update_func(g_var, updated_func) + return self.builder_.get() + + def visit_call_( # pylint: disable=arguments-renamed + self, + call: relax.Call, + ) -> relax.Expr: + call = self.visit_expr_post_order(call) + if call.op != tvm.ir.Op.get("relax.matmul"): + return call + # Do not fuse decode-transpose for GeMM + if self.skip_gemm and ( + call.args[0].struct_info.ndim < 2 + or not isinstance(call.args[0].struct_info.shape[-2], tir.IntImm) + or call.args[0].struct_info.shape[-2].value != 1 + ): + return call + + matmul_rhs = self.lookup_binding(call.args[1]) + if ( + not isinstance(matmul_rhs, relax.Call) + or matmul_rhs.op != tvm.ir.Op.get("relax.permute_dims") + or matmul_rhs.args[0].struct_info.ndim != 2 + or matmul_rhs.attrs.axes is not None + ): + return call + + transpose_input = self.lookup_binding(matmul_rhs.args[0]) + if ( + not isinstance(transpose_input, relax.Call) + or transpose_input.op != tvm.ir.Op.get("relax.call_tir") + or not transpose_input.args[0].name_hint.startswith("decode") + or not isinstance(transpose_input.struct_info, relax.TensorStructInfo) + ): + return call + + decode_tir_func = self.mod[transpose_input.args[0]] + assert isinstance(decode_tir_func, tir.PrimFunc) + if ( # pylint: disable=too-many-boolean-expressions + 
len(decode_tir_func.body.block.alloc_buffers) != 1 + or not isinstance(decode_tir_func.body.block.body, tir.SeqStmt) + or len(decode_tir_func.body.block.body) != 2 + or not isinstance(decode_tir_func.body.block.body[1], tir.For) + or not isinstance(decode_tir_func.body.block.body[1].body.body, tir.BlockRealize) + or decode_tir_func.body.block.body[1].body.body.block.name_hint != "T_transpose" + ): + return call + + new_func_buffers = [decode_tir_func.buffer_map[var] for var in decode_tir_func.params] + new_func_buffers[-1] = decode_tir_func.body.block.alloc_buffers[0] + new_func = tir.PrimFunc( + params=new_func_buffers, + body=tir.BlockRealize( + iter_values=[], + predicate=True, + block=tir.Block( + iter_vars=[], + reads=[], + writes=[], + name_hint="root", + body=decode_tir_func.body.block.body[0], + ), + ), + ) + # Call `renew_defs` for deep-copy to avoid IR node duplication in + # different PrimFuncs of an IRModule. + new_func = tir.stmt_functor.renew_defs(new_func) + g_var = self.builder_.add_func(new_func, func_name="decode") + decoded_matmul_rhs = self.builder_.emit( + relax.call_tir(g_var, transpose_input.args[1], out_sinfo=matmul_rhs.struct_info) + ) + return relax.op.matmul(call.args[0], decoded_matmul_rhs, out_dtype=call.attrs.out_dtype) diff --git a/python/mlc_chat/compiler/compiler_pass/fuse_transpose_matmul.py b/python/mlc_chat/compiler/compiler_pass/fuse_transpose_matmul.py new file mode 100644 index 0000000000..ac1de41377 --- /dev/null +++ b/python/mlc_chat/compiler/compiler_pass/fuse_transpose_matmul.py @@ -0,0 +1,153 @@ +"""A compiler pass that fuses transpose + matmul.""" +import tvm +from tvm import IRModule, relax, te, tir +from tvm.relax.dpl.pattern import is_op, wildcard +from tvm.relax.expr_functor import PyExprMutator, mutator + + +@tvm.transform.module_pass(opt_level=0, name="FuseTransposeMatmul") +class FuseTransposeMatmul: # pylint: disable=too-few-public-methods + """A compiler pass that fuses transpose + matmul.""" + + def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule: + """IRModule-level transformation""" + mod = relax.transform.FuseOpsByPattern( + [ + ( + "transpose_matmul_fuse", + *_pattern(), + ), + ] + )(mod) + + transpose_matmul_codegen = _TransposeMatmulFuser(mod) + for g_var in mod.functions: + func = mod[g_var] + if isinstance(func, relax.Function): + func = transpose_matmul_codegen.visit_expr(func) + transpose_matmul_codegen.builder_.update_func(g_var, func) + return transpose_matmul_codegen.builder_.get() + + +def _pattern(): + """Pattern for transpose + matmul.""" + # pylint: disable=invalid-name + w = wildcard() + x = wildcard() + wT = is_op("relax.permute_dims")(w) + o = is_op("relax.matmul")(x, wT) + # pylint: enable=invalid-name + annotations = {"o": o, "w": w, "x": x, "wT": wT} + + def _check(context: relax.transform.PatternCheckContext) -> bool: + transpose_call = context.annotated_expr["wT"] + ndim = transpose_call.args[0].struct_info.ndim + if ndim == -1: + return False + if ndim == 2 and transpose_call.attrs.axes is None: + return True + axes = list(range(ndim)) + axes[-1], axes[-2] = axes[-2], axes[-1] + return list(transpose_call.attrs.axes) == axes + + return o, annotations, _check + + +# pylint: disable=missing-docstring,invalid-name + + +@mutator +class _TransposeMatmulFuser(PyExprMutator): # pylint: disable=abstract-method + def __init__(self, mod): + super().__init__(mod) + + def visit_call_( # pylint: disable=arguments-renamed + self, + call: relax.Call, + ) -> relax.Expr: + out_dtype = None + 
+ def te_transposed_matmul(a: te.Tensor, b: te.Tensor) -> te.Tensor: + nonlocal out_dtype + a_shape = list(a.shape) + b_shape = list(b.shape) + a_prepended = False + b_appended = False + if len(a_shape) == 1: + a_prepended = True + a_shape.insert(0, 1) + if len(b_shape) == 1: + b_appended = True + b_shape.append(1) + + is_a_larger = len(a_shape) > len(b_shape) + offset = len(a_shape) - len(b_shape) if is_a_larger else len(b_shape) - len(a_shape) + + a_relax = relax.Var("a", relax.TensorStructInfo(a.shape)) + bT_shape = list(b.shape) + bT_shape[-1], bT_shape[-2] = bT_shape[-2], bT_shape[-1] + bT_relax = relax.Var("b", relax.TensorStructInfo(bT_shape)) + output_shape = self.builder_.normalize( + relax.op.matmul(a_relax, bT_relax) + ).struct_info.shape + + def matmul_compute(*idx_spatial): + k = te.reduce_axis((0, a_shape[-1]), name="k") + + def multiply_compute(idx_reduce): + a_indices = [] + b_indices = [] + + for i in range(offset): + if is_a_larger: + a_indices.append(idx_spatial[i]) + else: + b_indices.append(idx_spatial[i]) + for i in range(offset, len(output_shape) - (2 - a_prepended - b_appended)): + a_dim = a_shape[i if is_a_larger else i - offset] + b_dim = b_shape[i if not is_a_larger else i - offset] + dim_equal = a_dim == b_dim + if not isinstance(dim_equal, tir.IntImm) or dim_equal == 0: + a_dim_is_one = isinstance(a_dim, tir.IntImm) and a_dim == 1 + b_dim_is_one = isinstance(b_dim, tir.IntImm) and b_dim == 1 + a_indices.append(0 if a_dim_is_one else idx_spatial[i]) + b_indices.append(0 if b_dim_is_one else idx_spatial[i]) + else: + a_indices.append(idx_spatial[i]) + b_indices.append(idx_spatial[i]) + + if not a_prepended: + a_indices.append(idx_spatial[-2 + b_appended]) + a_indices.append(idx_reduce) + if not b_appended: + b_indices.append(idx_spatial[-1]) + b_indices.append(idx_reduce) + + dtype = out_dtype + if dtype != "": + return a(*a_indices).astype(dtype) * b(*b_indices).astype(dtype) + return a(*a_indices) * b(*b_indices) + + return te.sum(multiply_compute(k), axis=k) + + return te.compute( + output_shape, + lambda *idx: matmul_compute(*idx), # pylint: disable=unnecessary-lambda + name="NT_matmul", + ) + + if isinstance(call.op, relax.GlobalVar): + function = self.builder_.get()[call.op] + if ( + "Composite" in function.attrs + and function.attrs["Composite"] == "transpose_matmul_fuse" + ): + out_dtype = function.ret_struct_info.dtype + return self.builder_.call_te( + te_transposed_matmul, + call.args[1], + call.args[0], + primfunc_name_hint="NT_matmul", + ) + + return super().visit_call_(call) diff --git a/python/mlc_chat/compiler/compiler_pass/lift_global_buffer_alloc.py b/python/mlc_chat/compiler/compiler_pass/lift_global_buffer_alloc.py new file mode 100644 index 0000000000..dc8eaa5bdc --- /dev/null +++ b/python/mlc_chat/compiler/compiler_pass/lift_global_buffer_alloc.py @@ -0,0 +1,196 @@ +"""A compiler pass that lifts TIR-level global allocation to Relax.""" +from typing import Dict, List, Tuple + +import tvm +from tvm import relax, tir +from tvm.ir.module import IRModule +from tvm.relax.analysis import remove_all_unused +from tvm.relax.expr_functor import PyExprMutator, mutator + + +@tvm.transform.module_pass(opt_level=0, name="LiftTIRGlobalBufferAlloc") +class LiftTIRGlobalBufferAlloc: # pylint: disable=too-few-public-methods + """A compiler pass that lifts TIR-level global allocation to Relax.""" + + def transform_module( + self, + mod: IRModule, + _ctx: tvm.transform.PassContext, + ) -> IRModule: + """IRModule-level transformation""" + return 
_TIRGlobalAllocRewriter(mod).transform() + + +@mutator +class _TIRGlobalAllocRewriter(PyExprMutator): # pylint: disable=abstract-method + def __init__(self, mod: IRModule): + super().__init__(mod) + self.mod = mod + self.gv2new_tensor_sinfo: Dict[ + tvm.ir.GlobalVar, Tuple[List[relax.TensorStructInfo], tir.PrimFunc] + ] = {} + + def transform(self) -> IRModule: + """Entry point of the transformation""" + for g_var, func in self.mod.functions.items(): + if isinstance(func, tir.PrimFunc): + updated_func, tensor_sinfo_list = remove_global_buf_alloc(func) + if len(tensor_sinfo_list) > 0: + self.gv2new_tensor_sinfo[g_var] = (tensor_sinfo_list, func) + self.builder_.update_func(g_var, updated_func) + + self.mod = self.builder_.get() + for g_var, func in self.mod.functions.items(): + if not isinstance(func, relax.Function): + continue + updated_func = self.visit_expr(func) + updated_func = remove_all_unused(updated_func) + self.builder_.update_func(g_var, updated_func) + return self.builder_.get() + + def visit_call_(self, call: relax.Call): # pylint: disable=arguments-renamed + call = self.visit_expr_post_order(call) + if ( + call.op != tvm.ir.Op.get("relax.call_tir") + or call.args[0] not in self.gv2new_tensor_sinfo + ): + return call + + g_var = call.args[0] + tensor_sinfo, func_before_update = self.gv2new_tensor_sinfo[g_var] + + assert len(call.sinfo_args) == 1 + if any(_has_symbolic_var(sinfo) for sinfo in tensor_sinfo): + tensor_sinfo, success = _resolve_tir_var_mapping(func_before_update, call, tensor_sinfo) + if not success: + # Cannot resolve TIR var mapping. Fall back to no lifting. + self.builder_.update_func(g_var, func_before_update) + self.gv2new_tensor_sinfo.pop(g_var) + return call + + if isinstance(call.sinfo_args[0], relax.TensorStructInfo): + new_call = relax.Call( + call.op, + args=call.args, + sinfo_args=[relax.TupleStructInfo(list(call.sinfo_args) + tensor_sinfo)], + attrs=call.attrs, + ) + emitted_tuple = self.builder_.emit(new_call) + return relax.TupleGetItem(emitted_tuple, 0) + assert isinstance(call.sinfo_args[0], relax.TupleStructInfo) + return relax.Call( + call.op, + args=call.args, + sinfo_args=[relax.TupleStructInfo(list(call.sinfo_args[0].fields) + tensor_sinfo)], + attrs=call.attrs, + ) + + +def remove_global_buf_alloc( + func: tir.PrimFunc, +) -> Tuple[tir.PrimFunc, List[relax.TensorStructInfo]]: + """Remove the global buffer allocation for a given TIR PrimFunc.""" + assert isinstance(func.body, tir.BlockRealize) + params = list(func.params) + buffer_map = dict(func.buffer_map) + tensor_sinfo = [] + alloc_buffers = [] + + insertion_point = len(params) + while params[insertion_point - 1].dtype != "handle": + insertion_point -= 1 + assert insertion_point >= 1 + + prev_root_block = func.body.block + for buf_alloc in func.body.block.alloc_buffers: + if buf_alloc.scope() == "global": + param = tir.Var("var_" + buf_alloc.name, "handle") + params.insert(insertion_point, param) + insertion_point += 1 + buffer_map[param] = buf_alloc + tensor_sinfo.append(relax.TensorStructInfo(buf_alloc.shape, buf_alloc.dtype)) + else: + alloc_buffers.append(buf_alloc) + + if len(tensor_sinfo) == 0: + return func, [] + + assert len(prev_root_block.iter_vars) == 0 + assert len(prev_root_block.reads) == 0 + assert len(prev_root_block.writes) == 0 + assert len(prev_root_block.match_buffers) == 0 + assert prev_root_block.name_hint == "root" + assert prev_root_block.init is None + root_block = tir.Block( + iter_vars=[], + reads=[], + writes=[], + name_hint="root", + body=prev_root_block.body, 
+ alloc_buffers=alloc_buffers, + annotations=prev_root_block.annotations, + ) + + updated_func = tir.PrimFunc( + params=params, + body=tir.BlockRealize(iter_values=[], predicate=True, block=root_block), + ret_type=func.ret_type, + buffer_map=buffer_map, + attrs=func.attrs, + ) + return updated_func, tensor_sinfo + + +def _has_symbolic_var(tensor_sinfo: relax.TensorStructInfo) -> bool: + assert isinstance(tensor_sinfo.shape, relax.ShapeExpr) + for dim in tensor_sinfo.shape.values: + if not isinstance(dim, tir.IntImm): + return True + return False + + +def _resolve_tir_var_mapping( # pylint: disable=too-many-locals + func: tir.PrimFunc, + call: relax.Call, + tensor_sinfo: List[relax.TensorStructInfo], +) -> Tuple[List[relax.TensorStructInfo], bool]: + """Resolve the TIR symbolic var relationship across sides of PrimFunc and Relax Function""" + var_map: Dict[tir.Var, tir.PrimExpr] = {} + + n_arg = len(call.args[1].fields) + for i in range(n_arg): + buffer_shape = func.buffer_map[func.params[i]].shape + arg_shape = call.args[1][i].struct_info.shape.values + assert len(buffer_shape) == len(arg_shape) + for v_l, v_r in zip(buffer_shape, arg_shape): + if isinstance(v_l, tir.Var): + var_map[v_l] = v_r + elif not isinstance(v_l, tir.IntImm): + return [], False + + ret_tensors = call.sinfo_args[0] + ret_tensors = ( + [ret_tensors] + if isinstance(ret_tensors, relax.TensorStructInfo) + else list(ret_tensors.fields) + ) + for i, ret_tensor in enumerate(ret_tensors): + buffer_shape = func.buffer_map[func.params[n_arg + i]].shape + ret_tensor_shape = ret_tensor.shape.values + assert len(buffer_shape) == len(ret_tensor_shape) + for v_l, v_r in zip(buffer_shape, ret_tensor_shape): + if isinstance(v_l, tir.Var): + var_map[v_l] = v_r + elif not isinstance(v_l, tir.IntImm): + return [], False + + updated_tensor_sinfo = [] + for sinfo in tensor_sinfo: + if not _has_symbolic_var(sinfo): + updated_tensor_sinfo.append(sinfo) + continue + new_shape = [] + for dim in sinfo.shape.values: + new_shape.append(tir.stmt_functor.substitute(dim, var_map)) + updated_tensor_sinfo.append(relax.TensorStructInfo(new_shape, sinfo.dtype)) + return updated_tensor_sinfo, True diff --git a/python/mlc_chat/compiler/compiler_pass/pipeline.py b/python/mlc_chat/compiler/compiler_pass/pipeline.py new file mode 100644 index 0000000000..349a5af0f0 --- /dev/null +++ b/python/mlc_chat/compiler/compiler_pass/pipeline.py @@ -0,0 +1,49 @@ +"""The compilation pipeline for LLM applications.""" +import tvm +from tvm import dlight as dl +from tvm.relax import register_pipeline # pylint: disable=no-name-in-module + +from .clean_up_tir_attrs import CleanUpTIRAttrs +from .fuse_decode_matmul_ewise import FuseDecodeMatmulEwise +from .fuse_decode_take import FuseDecodeTake +from .fuse_decode_transpose import FuseDecodeTranspose +from .fuse_transpose_matmul import FuseTransposeMatmul +from .lift_global_buffer_alloc import LiftTIRGlobalBufferAlloc + + +@register_pipeline("mlc_llm") +def _mlc_llm_pipeline(): + @tvm.transform.module_pass(opt_level=0) + def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.IRModule: + seq = tvm.transform.Sequential( + [ + # Phase 1. Passes on high-level operator graph + FuseDecodeTranspose(skip_gemm=False), + FuseTransposeMatmul(), + # Phase 2. 
Lowering to TIR, inherited TVM Relax's official "zero" pipeline + tvm.relax.transform.LegalizeOps(), + tvm.relax.transform.AnnotateTIROpPattern(), + tvm.relax.transform.FoldConstant(), + tvm.relax.transform.FuseOps(), + tvm.relax.transform.FuseTIR(), + # Phase 3. Passes on TIR + FuseDecodeMatmulEwise(), + FuseDecodeTake(), + tvm.relax.transform.DeadCodeElimination(), + CleanUpTIRAttrs(["op_pattern"]), + # Phase 4. Low-level Optimizations + dl.ApplyDefaultSchedule( + dl.gpu.Matmul(), + dl.gpu.GEMV(), + dl.gpu.Reduction(), + dl.gpu.GeneralReduction(), + dl.gpu.Fallback(), + ), + LiftTIRGlobalBufferAlloc(), + tvm.tir.transform.ForceNarrowIndexToInt32(), + ] + ) + mod = seq(mod._move()) # pylint: disable=protected-access + return mod + + return _pipeline diff --git a/python/mlc_chat/compiler/flags_optimization.py b/python/mlc_chat/compiler/flags_optimization.py new file mode 100644 index 0000000000..704903b419 --- /dev/null +++ b/python/mlc_chat/compiler/flags_optimization.py @@ -0,0 +1,77 @@ +"""Optimization flags""" +import argparse +import dataclasses +from io import StringIO + + +@dataclasses.dataclass +class OptimizationFlags: + """Optiization flags""" + + cutlass_attn: bool = True + cutlass_norm: bool = True + cublas_gemm: bool = False + cudagraph: bool = False + + def __repr__(self) -> str: + out = StringIO() + print(f"cutlass_attn={int(self.cutlass_attn)}", file=out, end="") + print(f";cutlass_norm={int(self.cutlass_norm)}", file=out, end="") + print(f";cublas_gemm={int(self.cublas_gemm)}", file=out, end="") + print(f";cudagraph={int(self.cudagraph)}", file=out, end="") + return out.getvalue().rstrip() + + @staticmethod + def from_str(source: str) -> "OptimizationFlags": + """Parse optimization flags from a string.""" + + if source in OPT_FLAG_PRESET: + return OPT_FLAG_PRESET[source] + + def boolean(value: str) -> bool: + if value == "0": + return False + if value == "1": + return True + raise ValueError(f"Invalid boolean value: {value}") + + parser = argparse.ArgumentParser(description="optimization flags") + parser.add_argument("--cutlass_attn", type=boolean, default=True) + parser.add_argument("--cutlass_norm", type=boolean, default=True) + parser.add_argument("--cublas_gemm", type=boolean, default=False) + parser.add_argument("--cudagraph", type=boolean, default=False) + results = parser.parse_args([f"--{i}" for i in source.split(";") if i]) + return OptimizationFlags( + cutlass_attn=results.cutlass_attn, + cutlass_norm=results.cutlass_norm, + cublas_gemm=results.cublas_gemm, + cudagraph=results.cudagraph, + ) + + +OPT_FLAG_PRESET = { + "O0": OptimizationFlags( + cutlass_attn=False, + cutlass_norm=False, + cublas_gemm=False, + cudagraph=False, + ), + "O1": OptimizationFlags( + cutlass_attn=False, + cutlass_norm=True, + cublas_gemm=False, + cudagraph=False, + ), + "O2": OptimizationFlags( + cutlass_attn=True, + cutlass_norm=True, + cublas_gemm=False, + cudagraph=False, + ), + "O3": OptimizationFlags( + cutlass_attn=True, + cutlass_norm=True, + cublas_gemm=False, + cudagraph=True, + ), +} diff --git a/python/mlc_chat/compiler/model/llama_quantization.py b/python/mlc_chat/compiler/model/llama_quantization.py index dbf360c31d..a263ba0c4d 100644 --- a/python/mlc_chat/compiler/model/llama_quantization.py +++ b/python/mlc_chat/compiler/model/llama_quantization.py @@ -53,13 +53,13 @@ def group_quantize( weight_compute, scale_compute, other_computes = te_group_quantize( # type: ignore param_tensor, config ) - s = tvm.te.create_schedule( + s = tvm.te.create_schedule( # pylint: 
disable=invalid-name [compute.op for compute in [weight_compute, scale_compute] + other_computes] ) if target.kind.name == "cuda": # thread_binding for cuda for compute in [weight_compute, scale_compute] + other_computes: - xo, xi = s[compute].split(compute.op.axis[0], 256) + xo, xi = s[compute].split(compute.op.axis[0], 256) # pylint: disable=invalid-name s[compute].bind(xo, tvm.te.thread_axis("blockIdx.x")) s[compute].bind(xi, tvm.te.thread_axis("threadIdx.x")) f_quantize = tvm.build( diff --git a/python/mlc_chat/compiler/quantization/group_quantizer.py b/python/mlc_chat/compiler/quantization/group_quantizer.py index 418617dd70..b95c946abd 100644 --- a/python/mlc_chat/compiler/quantization/group_quantizer.py +++ b/python/mlc_chat/compiler/quantization/group_quantizer.py @@ -14,9 +14,9 @@ def te_quantize( """Group quantization for weight tensor, defined in tensor expression.""" # pylint: disable=too-many-locals assert len(weight.shape) == 2 - n, m = weight.shape + n, m = weight.shape # pylint: disable=invalid-name # compute scale per group - r = te.reduce_axis((0, config.group_size), name="r") + r = te.reduce_axis((0, config.group_size), name="r") # pylint: disable=invalid-name num_group = tir.ceildiv(m, config.group_size) scale_shape = (n, num_group) max_abs = te.compute( @@ -53,7 +53,7 @@ def te_quantize( ) # compute quantized weight per storage - r = te.reduce_axis((0, config.num_elem_per_storage), name="r") + r = te.reduce_axis((0, config.num_elem_per_storage), name="r") # pylint: disable=invalid-name num_storage = config.num_storage_per_group * num_group quantized_weight_shape = (n, num_storage) quantized_weight = te.compute( diff --git a/tests/python/parameter/test_group_quantizer.py b/tests/python/parameter/test_group_quantizer.py index b0e4b6522f..4c16548b64 100644 --- a/tests/python/parameter/test_group_quantizer.py +++ b/tests/python/parameter/test_group_quantizer.py @@ -73,7 +73,7 @@ def test_group_quantize_vs_numpy(): # pylint: disable=unused-variable def group_quantize_np( - w: NDArray, + w: NDArray, # pylint: disable=invalid-name quantize_dtype: str = "int4", storage_dtype: str = "uint32", group_size: int = 32, @@ -90,7 +90,7 @@ def _pad_axis_by_factor(tensor: np.ndarray, axis: int, factor: int) -> np.ndarra return np.pad(tensor, pad_width, mode="constant", constant_values=0) def _clip( - x: np.ndarray, + x: np.ndarray, # pylint: disable=invalid-name x_min: int, x_max: int, dtype: str, @@ -131,7 +131,7 @@ def _clip( res = np.zeros((n, k, num_storage_units), dtype=np.uint32) for i in range(n): for j in range(k): - for m in range(num_storage_units): + for m in range(num_storage_units): # pylint: disable=invalid-name for k in range(num_elem_per_storage): res[i, j, m] += w[i, j, m * num_elem_per_storage + k] * 2**k return tvm.nd.array(res), tvm.nd.array(scale) From 1a79a5381b7a3580b59d1305b43b1cd0d53a049a Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Sun, 29 Oct 2023 21:51:36 -0700 Subject: [PATCH 064/116] Compile Model Preset without External `config.json` (#1151) This PR adds support for compiling a preset of models without having to provide a `config.json` on disk using the commands below: ```diff python -m mlc_chat.cli.compile \ --quantization q4f16_1 -o /tmp/1.so \ - --config /models/Llama-2-7b-chat-hf + --config llama2_7b ``` This allows easier testing and binary distribution without having to depend on external model directory. 
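For reference, the preset handling amounts to resolving the `--config` value against a table of built-in model configurations before falling back to the filesystem. A minimal standalone sketch of that resolution step (the `MODEL_PRESETS` table and `resolve_config` helper below are illustrative stand-ins, not the exact API added by this PR):

```python
import json
import tempfile
from pathlib import Path
from typing import Union

# Hypothetical stand-in for the preset table shipped with the compiler.
MODEL_PRESETS = {
    "llama2_7b": {"hidden_size": 4096, "num_hidden_layers": 32, "num_attention_heads": 32},
}


def resolve_config(config: Union[str, Path]) -> Path:
    """Return a path to config.json, materializing known presets into a temp file."""
    if isinstance(config, str) and config in MODEL_PRESETS:
        # Dump the in-memory preset to a temporary JSON file so the rest of
        # the pipeline can keep treating the config as a file on disk.
        tmp = tempfile.NamedTemporaryFile(suffix=".json", delete=False)
        Path(tmp.name).write_text(json.dumps(MODEL_PRESETS[config], indent=2), encoding="utf-8")
        return Path(tmp.name)
    # Anything else is treated as a path: either config.json itself, or a
    # directory expected to contain it.
    path = Path(config)
    return path / "config.json" if path.is_dir() else path


print(resolve_config("llama2_7b"))  # path to a temp file dumped from the preset
```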
--- python/mlc_chat/cli/compile.py | 2 +- python/mlc_chat/compiler/__init__.py | 9 +++---- python/mlc_chat/compiler/model/__init__.py | 2 +- python/mlc_chat/compiler/model/model.py | 2 ++ python/mlc_chat/support/auto_config.py | 30 +++++++++++++++++----- 5 files changed, 31 insertions(+), 14 deletions(-) diff --git a/python/mlc_chat/cli/compile.py b/python/mlc_chat/cli/compile.py index 8a41ab5bbb..31b639a68f 100644 --- a/python/mlc_chat/cli/compile.py +++ b/python/mlc_chat/cli/compile.py @@ -27,7 +27,7 @@ def main(): def _parse_config(path: Union[str, Path]) -> Path: try: - return detect_config(Path(path)) + return detect_config(path) except ValueError as err: raise argparse.ArgumentTypeError(f"No valid config.json in: {path}. Error: {err}") diff --git a/python/mlc_chat/compiler/__init__.py b/python/mlc_chat/compiler/__init__.py index 2aa7dfcf98..4905e8ac91 100644 --- a/python/mlc_chat/compiler/__init__.py +++ b/python/mlc_chat/compiler/__init__.py @@ -3,11 +3,8 @@ but users could optionally import it if they want to use the compiler. """ from . import compiler_pass -from .compile import ( # pylint: disable=redefined-builtin - CompileArgs, - OptimizationFlags, - compile, -) -from .model import MODELS, Model +from .compile import CompileArgs, compile # pylint: disable=redefined-builtin +from .flags_optimization import OptimizationFlags +from .model import MODEL_PRESETS, MODELS, Model from .parameter import ExternMapping, HuggingFaceLoader, QuantizeMapping from .quantization import QUANT diff --git a/python/mlc_chat/compiler/model/__init__.py b/python/mlc_chat/compiler/model/__init__.py index 8bb4879e7d..a42dda9a09 100644 --- a/python/mlc_chat/compiler/model/__init__.py +++ b/python/mlc_chat/compiler/model/__init__.py @@ -1,2 +1,2 @@ """Model definition for the compiler.""" -from .model import MODELS, Model +from .model import MODEL_PRESETS, MODELS, Model diff --git a/python/mlc_chat/compiler/model/model.py b/python/mlc_chat/compiler/model/model.py index 8fd041ef32..3027a39500 100644 --- a/python/mlc_chat/compiler/model/model.py +++ b/python/mlc_chat/compiler/model/model.py @@ -61,3 +61,5 @@ class Model: quantize={}, ) } + +MODEL_PRESETS: Dict[str, Dict[str, Any]] = llama_config.CONFIG diff --git a/python/mlc_chat/support/auto_config.py b/python/mlc_chat/support/auto_config.py index 165c0a0f20..2ff650dec4 100644 --- a/python/mlc_chat/support/auto_config.py +++ b/python/mlc_chat/support/auto_config.py @@ -1,8 +1,9 @@ """Help function for detecting the model configuration file `config.json`""" import json import logging +import tempfile from pathlib import Path -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Union from .style import green @@ -14,25 +15,42 @@ FOUND = green("Found") -def detect_config(config_path: Path) -> Path: - """Detect and return the path that points to config.json. If config_path is a directory, +def detect_config(config: Union[str, Path]) -> Path: + """Detect and return the path that points to config.json. If `config` is a directory, it looks for config.json below it. Parameters --------- - config_path : pathlib.Path - The path to config.json or the directory containing config.json. + config : Union[str, pathlib.Path] + The preset name of the model, or the path to `config.json`, or the directory containing + `config.json`. Returns ------- config_json_path : pathlib.Path The path points to config.json. 
""" + from mlc_chat.compiler import ( # pylint: disable=import-outside-toplevel + MODEL_PRESETS, + ) + + if isinstance(config, str) and config in MODEL_PRESETS: + content = MODEL_PRESETS[config] + temp_file = tempfile.NamedTemporaryFile( # pylint: disable=consider-using-with + suffix=".json", + delete=False, + ) + logger.info("%s preset model configuration: %s", FOUND, temp_file.name) + config_path = Path(temp_file.name) + with config_path.open("w", encoding="utf-8") as config_file: + json.dump(content, config_file, indent=2) + else: + config_path = Path(config) if not config_path.exists(): raise ValueError(f"{config_path} does not exist.") if config_path.is_dir(): - # search config.json under config_path + # search config.json under config path config_json_path = config_path / "config.json" if not config_json_path.exists(): raise ValueError(f"Fail to find config.json under {config_path}.") From ba678358d35a2548cca2c854b1e24d42c7e8f43e Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Sun, 29 Oct 2023 23:54:12 -0700 Subject: [PATCH 065/116] Update attention layer (#1153) Existing dlight optimization only works for NT matmul, but not NN. As a result, the new `nn.Module`-based implementation, which uses NN matmul, fails compilation at HEAD for now. This PR fixes this issue by tweaking `k` to the preferred layout. The following commands now work with the new compilation pipeline: ```bash python -m mlc_chat.cli.compile --config llama2_7b --quantization q4f16_1 -o /tmp/1.so python -m mlc_chat.cli.compile --config llama2_13b --quantization q4f16_1 -o /tmp/1.so python -m mlc_chat.cli.compile --config llama2_70b --quantization q4f16_1 -o /tmp/1.so ``` Note that the quantization algorithm per se, `q4f16_1`, has not been implemented yet, meaning this code path is not yet ready for use so far. 
--- python/mlc_chat/compiler/model/llama_model.py | 17 ++++++++--------- python/mlc_chat/support/auto_config.py | 3 ++- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/python/mlc_chat/compiler/model/llama_model.py b/python/mlc_chat/compiler/model/llama_model.py index 6bf7647ff1..1106b38c56 100644 --- a/python/mlc_chat/compiler/model/llama_model.py +++ b/python/mlc_chat/compiler/model/llama_model.py @@ -95,14 +95,16 @@ def forward( # pylint: disable=too-many-locals self.k_cache.append(op.squeeze(k, axis=0)) self.v_cache.append(op.squeeze(v, axis=0)) - k = op.reshape(self.k_cache.view(total_seq_len), (t, b, h_kv, d)) - v = op.reshape(self.v_cache.view(total_seq_len), (t, b, h_kv, d)) + k = op.reshape(self.k_cache.view(total_seq_len), (b, t, h_kv, d)) + v = op.reshape(self.v_cache.view(total_seq_len), (b, t, h_kv, d)) if h_kv != h_q: k = k.repeat(h_q // h_kv, axis=2) v = v.repeat(h_q // h_kv, axis=2) - attn_weights = op.matmul( # [b, h, s, t] - q.permute_dims([0, 2, 1, 3]), # [b, h, s, d] - k.permute_dims([1, 2, 3, 0]), # [b, h, d, t] + q = q.permute_dims([0, 2, 1, 3]) # [b, h, s, d] + k = k.permute_dims([0, 2, 1, 3]) # [b, h, t, d] + v = v.permute_dims([0, 2, 1, 3]) # [b, h, t, d] + attn_weights = op.matmul( + q, k.permute_dims([0, 1, 3, 2]) # [b, h, s, d] x [b, h, d, t] = [b, h, s, t] ) / math.sqrt(d) dtype = attn_weights.dtype attn_weights = attn_weights.maximum(tir.min_value(dtype)).minimum(attention_mask) @@ -111,10 +113,7 @@ def forward( # pylint: disable=too-many-locals else: attn_weights = op.softmax(attn_weights.astype("float32"), axis=-1).astype(dtype) return self.o_proj( - op.matmul( # [b, h, s, d] - attn_weights, # [b, h, s, t] - v.permute_dims([1, 2, 0, 3]), # [b, h, t, d] - ) + op.matmul(attn_weights, v) # [b, h, s, t] x [b, h, t, d] = [b, h, s, d] .permute_dims([0, 2, 1, 3]) # [b, s, h, d] .reshape((b, s, h_q * d)) ) diff --git a/python/mlc_chat/support/auto_config.py b/python/mlc_chat/support/auto_config.py index 2ff650dec4..61a84b4041 100644 --- a/python/mlc_chat/support/auto_config.py +++ b/python/mlc_chat/support/auto_config.py @@ -35,12 +35,13 @@ def detect_config(config: Union[str, Path]) -> Path: ) if isinstance(config, str) and config in MODEL_PRESETS: + logger.info("%s preset model: %s", FOUND, config) content = MODEL_PRESETS[config] temp_file = tempfile.NamedTemporaryFile( # pylint: disable=consider-using-with suffix=".json", delete=False, ) - logger.info("%s preset model configuration: %s", FOUND, temp_file.name) + logger.info("Dumping config to: %s", temp_file.name) config_path = Path(temp_file.name) with config_path.open("w", encoding="utf-8") as config_file: json.dump(content, config_file, indent=2) From fee2cb543459aaa65b95e50fcc6c0a3d70a74419 Mon Sep 17 00:00:00 2001 From: masahi Date: Tue, 31 Oct 2023 01:32:06 +0900 Subject: [PATCH 066/116] Add batched Llama model definition using vLLM paged attention (#1134) * Add batched Llama model with vllm paged attention * update core.py * doc * minor * add e2e test * mv file * clean * Check if TVM has been built with USE_VLLM * update BuildArgs docstring --- examples/python/run_llama_batched_vllm.py | 448 +++++++++++++++ mlc_llm/core.py | 41 +- mlc_llm/relax_model/llama.py | 362 ++++++------ mlc_llm/relax_model/llama_batched_vllm.py | 661 ++++++++++++++++++++++ 4 files changed, 1347 insertions(+), 165 deletions(-) create mode 100644 examples/python/run_llama_batched_vllm.py create mode 100644 mlc_llm/relax_model/llama_batched_vllm.py diff --git a/examples/python/run_llama_batched_vllm.py 
b/examples/python/run_llama_batched_vllm.py new file mode 100644 index 0000000000..a290eb892c --- /dev/null +++ b/examples/python/run_llama_batched_vllm.py @@ -0,0 +1,448 @@ +import argparse +import math +import os +import json +from collections import defaultdict +from typing import List +from dataclasses import dataclass + +import numpy as np + +import tvm +from tvm import relax +from tvm.runtime import disco as di + +import torch +from transformers import AutoTokenizer + +from mlc_llm.relax_model.llama import LlamaConfig +from mlc_llm import utils + + +class KVCache: + def __init__(self, num_blocks, block_size, num_layers, num_heads, head_size, disco_session): + if disco_session: + init_cache_func = disco_session.get_global_func("tvm.contrib.vllm.allocate_kv_cache") + else: + init_cache_func = tvm.get_global_func("tvm.contrib.vllm.allocate_kv_cache") + + self.cache = init_cache_func(head_size, num_layers, num_heads, block_size, num_blocks) + + self.block_tables = defaultdict(list) + self.slot_mappings = defaultdict(list) + self.block_size = block_size + + +class CacheManager: + block_size: int = 16 + + def __init__( + self, num_blocks, num_layers, num_heads, head_size, disco_session=None, sliding_window=None + ): + self.num_blocks = num_blocks + self.free_blocks = list(range(num_blocks)) + self.kv_cache = KVCache( + num_blocks, self.block_size, num_layers, num_heads, head_size, disco_session + ) + + if sliding_window: + assert sliding_window % self.kv_cache.block_size == 0 + self.block_sliding_window = sliding_window // self.kv_cache.block_size + else: + self.block_sliding_window = None + + def set_size(self, request_ids: List[int], target_sizes: List[int]): + for id, size in zip(request_ids, target_sizes): + num_needed_block = math.ceil(size / self.block_size) + + if self.block_sliding_window: + num_needed_block = min(num_needed_block, self.block_sliding_window) + + if id in self.kv_cache.block_tables and size == 0: + self.free_blocks.extend(self.kv_cache.block_tables[id]) + del self.kv_cache.block_tables[id] + del self.kv_cache.slot_mappings[id] + + elif id in self.kv_cache.block_tables: + # Decoding + if len(self.kv_cache.block_tables[id]) < num_needed_block: + # Need to allocate a new block for this request + assert len(self.kv_cache.block_tables[id]) + 1 == num_needed_block + self.kv_cache.block_tables[id].append(self.free_blocks.pop()) + + pos = size - 1 + block_number = self.kv_cache.block_tables[id][-1] + + if self.block_sliding_window: + block_number = self.kv_cache.block_tables[id][ + (pos // self.block_size) % self.block_sliding_window + ] + else: + block_number = self.kv_cache.block_tables[id][-1] + + block_offset = pos % self.block_size + slot = block_number * self.block_size + block_offset + self.kv_cache.slot_mappings[id].append(slot) + + elif id not in self.kv_cache.block_tables: + assert len(self.free_blocks) >= num_needed_block, "Not enough free blocks." 
+ + for _ in range(num_needed_block): + self.kv_cache.block_tables[id].append(self.free_blocks.pop()) + + for i in range(size): + block_idx = i // self.block_size + + if self.block_sliding_window: + block_idx %= self.block_sliding_window + + block_number = self.kv_cache.block_tables[id][block_idx] + block_offset = i % self.block_size + slot = block_number * self.block_size + block_offset + self.kv_cache.slot_mappings[id].append(slot) + + def get(self): + return self.kv_cache + + +@dataclass +class SequenceGenerationRequest: + request_id: int + token_ids: List[int] + + +@dataclass +class SequenceGenerationResponse: + request_id: int + token_id: int + + +def sample(logits): + logits = torch.from_dlpack(logits) + return torch.argmax(logits, -1).cpu().numpy() + + +def load_params_disco(artifact_path, lib_path, num_shards): + sess = di.ProcessSession(num_workers=num_shards) + devices = range(num_shards) + sess.init_ccl("nccl", *devices) + module = sess.load_vm_module(lib_path) + + loader_create = sess.get_global_func("runtime.disco.ShardLoader") + metadata_path = os.path.join(artifact_path, "params", "ndarray-cache.json") + with open(metadata_path, "r", encoding="utf-8") as f: + ndarray_cache_metadata = f.read() + + loader = loader_create(metadata_path, ndarray_cache_metadata, "", module) + loader_load = sess.get_global_func("runtime.disco.ShardLoaderLoadAll") + params = loader_load(loader) + + return module, params, sess + + +def copy_to_worker_0(sess: di.Session, host_array): + x_array = sess.empty(host_array.shape, host_array.dtype) + sess.copy_to_worker_0(host_array, x_array) + return x_array + + +def get_tvm_model(artifact_path, model, quantization, num_shards, dev): + lib_path = os.path.join(artifact_path, f"{model}-{quantization}-cuda.so") + + if num_shards == 1: + ex = tvm.runtime.load_module(lib_path) + vm = relax.VirtualMachine(ex, dev) + params = utils.load_params(artifact_path, dev) + return vm.module, params, None + + return load_params_disco(artifact_path, lib_path, num_shards) + + +def _prepare_inputs( + requests, + all_slot_mappings, + all_block_tables, + sliding_window, + dev, + is_prefill, +): + block_tables = [] + seq_lens = [] + input_ids = [] + slot_mapping = [] + positions = [] + max_num_blocks_per_seq = 0 + indices_within_window = [] + start_idx = 0 + + for request in requests: + request_id = request.request_id + token_ids = request.token_ids + + if is_prefill: + input_ids += token_ids + prompt_len = len(token_ids) + seq_lens.append(prompt_len) + positions += range(prompt_len) + slot_mapping += all_slot_mappings[request_id] + + if sliding_window: + indices_within_window += range( + start_idx + max(0, prompt_len - sliding_window), + start_idx + prompt_len, + ) + start_idx += prompt_len + + else: + input_ids.append(token_ids[-1]) + pos = len(token_ids) - 1 + positions.append(pos) + block_table = all_block_tables[request_id] + max_num_blocks_per_seq = max(max_num_blocks_per_seq, len(block_table)) + block_tables.append(block_table) + slot_mapping.append(all_slot_mappings[request_id][-1]) + + if sliding_window: + seq_lens.append(min(len(token_ids), sliding_window)) + else: + seq_lens.append(len(token_ids)) + + input_ids = tvm.nd.array(np.array(input_ids, dtype="int32"), dev) + positions = tvm.nd.array(np.array(positions, dtype="int32"), dev) + seq_lens = tvm.nd.array(np.array(seq_lens, dtype="int32"), dev) + slot_mapping = tvm.nd.array(np.array(slot_mapping, dtype="int32"), dev) + + if is_prefill and sliding_window: + indices_within_window = 
tvm.nd.array(np.array(indices_within_window, dtype="int32"), dev) + else: + indices_within_window = None + + if not is_prefill: + + def _pad_to_max(x: List[int], max_len: int) -> List[int]: + return x + [0] * (max_len - len(x)) + + padded_block_tables = [ + _pad_to_max(block_table, max_num_blocks_per_seq) for block_table in block_tables + ] + + block_tables_np = np.vstack(padded_block_tables).astype("int32") + block_tables = tvm.nd.array(np.array(block_tables_np, dtype="int32"), dev) + else: + block_tables = None + + return ( + input_ids, + positions, + seq_lens, + slot_mapping, + indices_within_window, + block_tables, + ) + + +class Model: + def __init__( + self, artifact_path, model_name, quant, vocab_size, num_shards, dev, sliding_window + ): + self.mod, self.params, self.disco_session = get_tvm_model( + artifact_path, model_name, quant, num_shards, dev + ) + self.dev = dev + self.vocab_size = vocab_size + self.sliding_window = sliding_window + + if sliding_window: + self.block_sliding_window = sliding_window // CacheManager.block_size + else: + self.block_sliding_window = None + + def generate( + self, requests: List[SequenceGenerationRequest], cache: KVCache, is_prefill: bool + ) -> List[SequenceGenerationResponse]: + ( + input_ids, + positions, + seq_lens, + slot_mapping, + indices_within_window, + block_tables, + ) = _prepare_inputs( + requests, + cache.slot_mappings, + cache.block_tables, + self.sliding_window, + self.dev, + is_prefill, + ) + + if self.disco_session: + input_ids = copy_to_worker_0(self.disco_session, input_ids) + positions = copy_to_worker_0(self.disco_session, positions) + seq_lens = copy_to_worker_0(self.disco_session, seq_lens) + slot_mapping = copy_to_worker_0(self.disco_session, slot_mapping) + + kv_cache = cache.cache + + if is_prefill: + if self.sliding_window: + if self.disco_session: + indices_within_window = copy_to_worker_0( + self.disco_session, indices_within_window + ) + + out = self.mod["prefill"]( + input_ids, + positions, + seq_lens, + kv_cache, + slot_mapping, + indices_within_window, + self.params, + ) + else: + out = self.mod["prefill"]( + input_ids, positions, seq_lens, kv_cache, slot_mapping, self.params + ) + + if self.disco_session: + logits, _ = out.debug_get_from_remote(0) + else: + logits = out[0] # Ignore returned KV cache since it is updated in-place anyway. 
+ else: + if self.disco_session: + block_tables = copy_to_worker_0(self.disco_session, block_tables) + + out = self.mod["decode"]( + input_ids, + positions, + seq_lens, + kv_cache, + slot_mapping, + block_tables, + self.params, + ) + + if self.disco_session: + logits, _ = out.debug_get_from_remote(0) + else: + logits = out[0] + + next_tokens = sample(logits) + + return [ + SequenceGenerationResponse(request.request_id, new_token) + for request, new_token in zip(requests, next_tokens) + ] + + +def parse_args(): + # Example + # python build.py --model vicuna-v1-7b --quantization q4f16_ft --use-cache=0 --max-seq-len 768 --enable-batching --use-vllm-attention + # python examples/python/run_llama_batched_vllm.py --local-id vicuna-v1-7b-q4f16_ft + # + # For Disco: + # python build.py --model vicuna-v1-7b --quantization q0f16 --use-cache=0 --max-seq-len 768 --enable-batching --use-vllm-attention --build-model-only --num-shards 2 + # python build.py --model vicuna-v1-7b --quantization q0f16 --use-cache=0 --max-seq-len 768 --enable-batching --use-vllm-attention --convert-weight-only + # CUDA_VISIBLE_DEVICES=0,1 python examples/python/run_llama_batched_vllm.py --local-id vicuna-v1-7b-q0f16 --num-shards 2 + + args = argparse.ArgumentParser() + args.add_argument("--local-id", type=str, required=True) + args.add_argument("--artifact-path", type=str, default="dist") + args.add_argument("--num-shards", type=int, default=1) + args.add_argument("--num-decode-steps", type=int, default=20) + parsed = args.parse_args() + parsed.model, parsed.quantization = parsed.local_id.rsplit("-", 1) + utils.argparse_postproc_common(parsed) + parsed.artifact_path = os.path.join( + parsed.artifact_path, f"{parsed.model}-{parsed.quantization.name}" + ) + return parsed + + +def run(args): + quantization = args.quantization.name + artifact_path = args.artifact_path + model_name = args.model + model_path = f"dist/models/{model_name}" + + dev = tvm.device("cuda", 0) + + with open(os.path.join(model_path, "config.json"), encoding="utf-8") as i_f: + config = LlamaConfig(**json.load(i_f)) + + model = Model( + artifact_path, + model_name, + quantization, + config.vocab_size, + args.num_shards, + dev, + config.sliding_window, + ) + + tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=False) + + num_kv_heads = config.get_num_key_value_heads() // args.num_shards + head_size = config.hidden_size // config.num_attention_heads + num_blocks = 500 + + cache_manager = CacheManager( + num_blocks, + config.num_hidden_layers, + num_kv_heads, + head_size, + model.disco_session, + sliding_window=config.sliding_window, + ) + cache = cache_manager.get() + + model.block_sliding_window = cache_manager.block_sliding_window + + prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] + + batched_token_ids = [tokenizer.encode(p) for p in prompts] + prompts_len = [len(ids) for ids in batched_token_ids] + request_ids = list(range(len(prompts))) + target_sizes = [] + requests = [] + + for token_ids, request_id in zip(batched_token_ids, request_ids): + request_ids.append(request_id) + target_sizes.append(len(token_ids)) + requests.append(SequenceGenerationRequest(request_id, token_ids)) + + cache_manager.set_size(request_ids, target_sizes) + + out = model.generate(requests, cache, True) + + for _ in range(args.num_decode_steps): + for i, response in enumerate(out): + new_token_id = response.token_id + requests[i].token_ids.append(new_token_id) + 
target_sizes[i] += 1 + + cache_manager.set_size(request_ids, target_sizes) + + out = model.generate(requests, cache, False) + + output_tokens = [ + tokenizer.convert_ids_to_tokens( + requests[i].token_ids[prompts_len[i] :], skip_special_tokens=True + ) + for i in range(len(requests)) + ] + + generated = [tokenizer.convert_tokens_to_string(tokens) for tokens in output_tokens] + + for p, g in zip(prompts, generated): + print("Prompt = '{}', generated text = '{}'".format(p, g)) + + +if __name__ == "__main__": + run(parse_args()) diff --git a/mlc_llm/core.py b/mlc_llm/core.py index e720d19542..0b7d1c8c39 100644 --- a/mlc_llm/core.py +++ b/mlc_llm/core.py @@ -22,6 +22,7 @@ gpt_neox, gptj, llama, + llama_batched_vllm, minigpt, param_manager, rwkv, @@ -96,7 +97,7 @@ class BuildArgs: Disable offloading layer and RMS norm operations to CUTLASS. no_cublas: bool Disable the step that offloads matmul to cuBLAS. Without this flag, - matmul will be offloaded to cuBLAS if quantization mode is ``q0f16`` or + matmul will be offloaded to cuBLAS if quantization mode is ``q0f16`` or ``q0f32``, target is CUDA and TVM has been built with cuBLAS enabled. use_cuda_graph: bool Specifies whether to enable CUDA Graph for the decoder. MLP and QKV @@ -108,6 +109,8 @@ class BuildArgs: Offload multi-query attention workload to Flash Attention. pdb: bool If set, drop into a pdb debugger on error. + use_vllm_attention: bool + Use vLLM paged KV cache and attention kernel, only relevant when enable_batching=True. """ model: str = field( default="auto", @@ -279,6 +282,15 @@ class BuildArgs: "action": "store_true", }, ) + use_vllm_attention: bool = field( + default=False, + metadata={ + "help": ( + "Use vLLM paged KV cache and attention kernel, only relevant when enable_batching=True." + ), + "action": "store_true", + }, + ) def convert_build_args_to_argparser() -> argparse.ArgumentParser: @@ -315,6 +327,11 @@ def _parse_args(parsed) -> argparse.Namespace: utils.parse_target(parsed) utils.argparse_postproc_common(parsed) + if parsed.use_vllm_attention: + assert parsed.enable_batching, "--enable_batching is required for using vLLM attention." + assert parsed.target_kind == "cuda", "vLLM attention is only supported for CUDA." + assert tvm.get_global_func("tvm.contrib.vllm.single_query_cached_kv_attention", True), "TVM needs to be built with -DUSE_VLLM=ON." + parsed.artifact_path = os.path.join( parsed.artifact_path, f"{parsed.model}-{parsed.quantization.name}" ) @@ -409,10 +426,19 @@ def mod_transform_before_build( model_names = [ "prefill", "decode", - "create_kv_cache", - "softmax_with_temperature", - "get_metadata", ] + + if not args.use_vllm_attention: + model_names += [ + "create_kv_cache", + "softmax_with_temperature", + "get_metadata", + ] + else: + # This is equivalent to prefill but without KV cache. It is used for + # determining the number of paged cache blocks that can be allocated. 
+ model_names.append("evaluate") + if args.sep_embed: model_names = ["embed", "prefill_with_embed"] + model_names[1:] if args.enable_batching: @@ -427,7 +453,8 @@ def mod_transform_before_build( mod = mlc_llm.transform.FuseDecodeTranspose(skip_gemm=not use_ft_quant)(mod) if ( - hasattr(config, "num_attention_heads") + not args.enable_batching + and hasattr(config, "num_attention_heads") and hasattr(config, "hidden_size") and hasattr(config, "position_embedding_base") and getattr(config, "dtype", "float16") == "float16" @@ -649,6 +676,10 @@ def build_model_from_args(args: argparse.Namespace): "chatglm": chatglm, } + if args.use_vllm_attention: + model_generators["llama"] = llama_batched_vllm + model_generators["mistral"] = llama_batched_vllm + assert args.model_category in model_generators, f"Model {args.model} not supported" mod, param_manager, params, model_config = model_generators[args.model_category].get_model( diff --git a/mlc_llm/relax_model/llama.py b/mlc_llm/relax_model/llama.py index e45a4a3e20..8294313324 100644 --- a/mlc_llm/relax_model/llama.py +++ b/mlc_llm/relax_model/llama.py @@ -38,6 +38,7 @@ def __init__( combine_matmul=True, build_model_only=False, num_shards=1, + sliding_window=None, **kwargs, ): self.dtype = dtype @@ -57,6 +58,8 @@ def __init__( self.tie_word_embeddings = tie_word_embeddings self.position_embedding_base = position_embedding_base self.combine_matmul = combine_matmul + self.sliding_window = sliding_window + if build_model_only and num_shards > 1: self.num_shards = num_shards else: @@ -120,30 +123,50 @@ def f_rms_norm(x, weight): def f_square(x): return tir.Cast("float32", x) * tir.Cast("float32", x) if not is_float32 else x * x - k = te.reduce_axis((0, x.shape[2]), name="k") - square_sum = te.compute( - (x.shape[0], x.shape[1]), - lambda bsz, i: te.sum(f_square(x[bsz, i, k]), axis=k), - name=x.op.name + "red_temp", - ) + def f_mul_cast(x, y): + value = x * y + if not is_float32: + value = tir.Cast(x.dtype, value) + return value - def f_div_cast(bsz, i, k): + def f_div_cast_2d(i, k): + x_val = x[i, k] + if not is_float32: + x_val = tir.Cast("float32", x_val) + return x_val / tir.sqrt(square_sum[i] / x.shape[1] + self.variance_epsilon) + + def f_div_cast_3d(bsz, i, k): x_val = x[bsz, i, k] if not is_float32: x_val = tir.Cast("float32", x_val) return x_val / tir.sqrt(square_sum[bsz, i] / x.shape[2] + self.variance_epsilon) - def f_mul_cast(x, y): - value = x * y - if not is_float32: - value = tir.Cast(x.dtype, value) - return value + k = te.reduce_axis((0, x.shape[-1]), name="k") - return te.compute( - x.shape, - lambda bsz, i, k: f_mul_cast(weight(k), f_div_cast(bsz, i, k)), - name="rms_norm", - ) + if len(x.shape) == 2: + square_sum = te.compute( + (x.shape[0],), + lambda i: te.sum(f_square(x[i, k]), axis=k), + name=x.op.name + "red_temp", + ) + + return te.compute( + x.shape, + lambda i, k: f_mul_cast(weight(k), f_div_cast_2d(i, k)), + name="rms_norm", + ) + else: + square_sum = te.compute( + (x.shape[0], x.shape[1]), + lambda bsz, i: te.sum(f_square(x[bsz, i, k]), axis=k), + name=x.op.name + "red_temp", + ) + + return te.compute( + x.shape, + lambda bsz, i, k: f_mul_cast(weight(k), f_div_cast_3d(bsz, i, k)), + name="rms_norm", + ) return nn.emit_te(f_rms_norm, hidden_states, self.weight, primfunc_name_hint="rms_norm") @@ -186,28 +209,36 @@ def forward(self, x): return result +def rotary_modulate_by_freq(tensor, idx, pos, position_embedding_base): + head_dim = tensor.shape[-1] + dtype = tensor.dtype + n_feat_half = head_dim // 2 + feat_idx = idx[-1] + 
inv_freq = te.const(1, "float32") / ( + te.power( + te.const(position_embedding_base, "float32"), + ((2 * feat_idx) % head_dim).astype("float32") / head_dim.astype("float32"), + ) + ) + freq = pos * inv_freq + left_indices = idx[:-1] + (feat_idx - n_feat_half,) + right_indices = idx[:-1] + (feat_idx + n_feat_half,) + return te.cos(freq).astype(dtype) * tensor(*idx) + te.sin(freq).astype(dtype) * tvm.tir.Select( + feat_idx >= n_feat_half, + tensor[(*left_indices,)], + -tensor[(*right_indices,)], + ) + + def apply_rotary_pos_emb(q, k, position_embedding_base, offset: int = 0): def f_rotary_embedding(tensor, offset): - dtype = tensor.dtype - head_dim = tensor.shape[-1] - n_feat_half = tensor.shape[-1] // 2 - def rotary_compute(*idx): - i, j = idx[-3], idx[-1] - pos = (offset + i).astype("float32") - inv_freq = te.const(1, "float32") / ( - te.power( - te.const(position_embedding_base, "float32"), - ((2 * j) % head_dim).astype("float32") / head_dim.astype("float32"), - ) - ) - freq = pos * inv_freq - return te.cos(freq).astype(dtype) * tensor(*idx) + te.sin(freq).astype( - dtype - ) * tvm.tir.Select( - j >= n_feat_half, - tensor[idx[0], i, idx[2], j - n_feat_half], - -tensor[idx[0], i, idx[2], j + n_feat_half], + pos = (offset + idx[-3]).astype("float32") + return rotary_modulate_by_freq( + tensor, + idx, + pos, + position_embedding_base, ) return tvm.te.compute(tensor.shape, rotary_compute, name="rotary") @@ -268,18 +299,9 @@ def __init__(self, config: LlamaConfig): self.o_proj.weight.shard_dim = 1 self.o_proj.weight.shard_strategy = "shard_o_proj_k" - def forward( - self, - hidden_states: relax.Expr, - all_seq_len_shape: Optional[relax.Expr], - past_key_values: Union[relax.Expr, Tuple[relax.Expr]], - layer_id: int, - attention_mask: Optional[relax.Expr] = None, - ) -> Tuple[relax.Expr, Union[relax.Expr, Tuple[relax.Expr]]]: + def project_qkv(self, hidden_states, query_output_shape, kv_output_shape): from tvm.relax.op import reshape, split - bsz, q_len, _ = hidden_states.struct_info.shape - if self.combine_matmul: qkv_states = nn.emit( split( @@ -300,24 +322,35 @@ def forward( value_states = self.v_proj(hidden_states) query_states = nn.emit( - reshape( - query_states, - (bsz, q_len, self.num_query_heads, self.head_dim), - ), + reshape(query_states, query_output_shape), ) key_states = nn.emit( - reshape( - key_states, - (bsz, q_len, self.num_key_value_heads, self.head_dim), - ), + reshape(key_states, kv_output_shape), ) value_states = nn.emit( - reshape( - value_states, - (bsz, q_len, self.num_key_value_heads, self.head_dim), - ), + reshape(value_states, kv_output_shape), ) + return query_states, key_states, value_states + + def forward( + self, + hidden_states: relax.Expr, + all_seq_len_shape: Optional[relax.Expr], + past_key_values: Union[relax.Expr, Tuple[relax.Expr]], + layer_id: int, + attention_mask: Optional[relax.Expr] = None, + ) -> Tuple[relax.Expr, Union[relax.Expr, Tuple[relax.Expr]]]: + bsz, q_len, _ = hidden_states.struct_info.shape + + query_states, key_states, value_states = self.project_qkv( + hidden_states, + (bsz, q_len, self.num_query_heads, self.head_dim), + (bsz, q_len, self.num_key_value_heads, self.head_dim), + ) + + from tvm.relax.op import reshape + attn_output, past_key_values = self.attention_fwd( query_states, key_states, @@ -541,6 +574,29 @@ def __init__(self, config: LlamaConfig, enable_batching: bool): config.hidden_size, dtype=config.dtype, eps=config.rms_norm_eps ) + def post_self_attn(self, hidden_states, residual): + if self.self_attn.num_shards > 1: + 
residual = nn.emit( + residual / R.const(self.self_attn.num_shards, dtype=residual.struct_info.dtype) + ) + hidden_states = nn.emit(residual + hidden_states) + if self.self_attn.num_shards > 1: + hidden_states = nn.emit(ccl.allreduce(hidden_states, "sum")) + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + if self.mlp.num_shards > 1: + residual = nn.emit( + residual / R.const(self.mlp.num_shards, dtype=residual.struct_info.dtype) + ) + hidden_states = nn.emit(residual + hidden_states) + if self.mlp.num_shards > 1: + hidden_states = nn.emit(ccl.allreduce(hidden_states, "sum")) + + return hidden_states + def forward( self, hidden_states: relax.Expr, @@ -561,25 +617,7 @@ def forward( all_seq_len_shape=all_seq_len_shape, layer_id=layer_id, ) - if self.self_attn.num_shards > 1: - residual = nn.emit( - residual / R.const(self.self_attn.num_shards, dtype=residual.struct_info.dtype) - ) - hidden_states = nn.emit(residual + hidden_states) - if self.self_attn.num_shards > 1: - hidden_states = nn.emit(ccl.allreduce(hidden_states, "sum")) - - # Fully Connected - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - if self.mlp.num_shards > 1: - residual = nn.emit( - residual / R.const(self.mlp.num_shards, dtype=residual.struct_info.dtype) - ) - hidden_states = nn.emit(residual + hidden_states) - if self.mlp.num_shards > 1: - hidden_states = nn.emit(ccl.allreduce(hidden_states, "sum")) + hidden_states = self.post_self_attn(hidden_states, residual) return hidden_states, present_key_value @@ -1164,6 +1202,91 @@ def kv_cache_transpose_append( bb.add_func(relax.extern("attention_func"), "attention") +def setup_params(mod, param_manager, dtype, config, args): + def f_convert_pname_fwd(pname: str) -> List[str]: + if not config.combine_matmul: + return [pname] + + qkv_str = "query_key_value_proj" + gate_up_str = "gate_up_proj" + if qkv_str in pname: + return [ + pname.replace(qkv_str, "q_proj"), + pname.replace(qkv_str, "k_proj"), + pname.replace(qkv_str, "v_proj"), + ] + elif gate_up_str in pname: + return [ + pname.replace(gate_up_str, "gate_proj"), + pname.replace(gate_up_str, "up_proj"), + ] + else: + return [pname] + + def f_convert_param_bkwd(torch_pname: str, torch_param): + if not config.combine_matmul: + return [(torch_pname, torch_param.astype(dtype))] + + combined_layers = ["q_proj", "k_proj", "v_proj", "gate_proj", "up_proj"] + if any([name in torch_pname for name in combined_layers]): + return None + return [(torch_pname, torch_param.astype(dtype))] + + def f_compute_relax_param(relax_pname: str, torch_params: List[Any]): + # Expected to enter this function only for the combined linear matmul weights. + # Other weights are supposed to be loaded in `f_convert_param_bkwd` since + # each other relax param has a unique corresponding torch param. + if not config.combine_matmul: + # When matmul combination is not turned on, each relax param has a unique + # corresponding torch param, and this function is not expected to be entered. 
+ raise NotImplementedError( + "Matmul combination is not turned on, and the function " + "is not expected to be entered" + ) + hidden_size = config.hidden_size + head_dim = config.hidden_size // config.num_attention_heads + + if "query_key_value_proj" in relax_pname: + q_heads = config.num_attention_heads + kv_heads = config.get_num_key_value_heads() + q, k, v = torch_params + assert q.shape == (q_heads * head_dim, hidden_size) + assert k.shape == (kv_heads * head_dim, hidden_size) + assert v.shape == (kv_heads * head_dim, hidden_size) + qkv = np.concatenate([q, k, v], axis=0).astype(dtype) + return qkv + if "gate_up_proj" in relax_pname: + gate, up = torch_params + gate_up = np.concatenate([gate, up], axis=0).astype(dtype) + return gate_up + raise ValueError("Unexpected param loading") + + param_manager.set_param_loading_func( + args.model_path, + args.use_safetensors, + f_convert_pname_fwd, + f_convert_param_bkwd, + f_compute_relax_param, + ) + + device = tvm.cpu() + param_list = [None] * param_manager.nparam_to_load + + head_dim = config.hidden_size / config.num_attention_heads + inv_freq = 1.0 / ( + config.position_embedding_base ** (np.arange(0, head_dim, 2).astype("float32") / head_dim) + ) + + # The following cos/sin values can be removed but **are kept for compatibility issues**. + t = np.arange(2048, dtype=inv_freq.dtype) + freqs = np.einsum("i,j->ij", t, inv_freq) + emb = np.concatenate((freqs, freqs), axis=-1) + param_list[-2] = tvm.nd.array(np.cos(emb).astype(config.dtype), device) + param_list[-1] = tvm.nd.array(np.sin(emb).astype(config.dtype), device) + + return mod, param_manager, param_list, config + + def get_model(args, hf_config): model_name = args.model dtype = args.quantization.model_dtype @@ -1174,7 +1297,7 @@ def get_model(args, hf_config): raise ValueError("`sep_embed` is required when batching is enabled.") position_embedding_base = 10000 - max_position_embeddings = 2048 + if "rope_theta" in hf_config: position_embedding_base = hf_config["rope_theta"] @@ -1249,85 +1372,4 @@ def get_model(args, hf_config): if args.build_model_only: return mod, param_manager, None, config - def f_convert_pname_fwd(pname: str) -> List[str]: - if not config.combine_matmul: - return [pname] - - qkv_str = "query_key_value_proj" - gate_up_str = "gate_up_proj" - if qkv_str in pname: - return [ - pname.replace(qkv_str, "q_proj"), - pname.replace(qkv_str, "k_proj"), - pname.replace(qkv_str, "v_proj"), - ] - elif gate_up_str in pname: - return [ - pname.replace(gate_up_str, "gate_proj"), - pname.replace(gate_up_str, "up_proj"), - ] - else: - return [pname] - - def f_convert_param_bkwd(torch_pname: str, torch_param): - if not config.combine_matmul: - return [(torch_pname, torch_param.astype(dtype))] - - combined_layers = ["q_proj", "k_proj", "v_proj", "gate_proj", "up_proj"] - if any([name in torch_pname for name in combined_layers]): - return None - return [(torch_pname, torch_param.astype(dtype))] - - def f_compute_relax_param(relax_pname: str, torch_params: List[Any]): - # Expected to enter this function only for the combined linear matmul weights. - # Other weights are supposed to be loaded in `f_convert_param_bkwd` since - # each other relax param has a unique corresponding torch param. - if not config.combine_matmul: - # When matmul combination is not turned on, each relax param has a unique - # corresponding torch param, and this function is not expected to be entered. 
- raise NotImplementedError( - "Matmul combination is not turned on, and the function " - "is not expected to be entered" - ) - hidden_size = config.hidden_size - head_dim = config.hidden_size // config.num_attention_heads - - if "query_key_value_proj" in relax_pname: - q_heads = config.num_attention_heads - kv_heads = config.get_num_key_value_heads() - q, k, v = torch_params - assert q.shape == (q_heads * head_dim, hidden_size) - assert k.shape == (kv_heads * head_dim, hidden_size) - assert v.shape == (kv_heads * head_dim, hidden_size) - qkv = np.concatenate([q, k, v], axis=0).astype(dtype) - return qkv - if "gate_up_proj" in relax_pname: - gate, up = torch_params - gate_up = np.concatenate([gate, up], axis=0).astype(dtype) - return gate_up - raise ValueError("Unexpected param loading") - - param_manager.set_param_loading_func( - args.model_path, - args.use_safetensors, - f_convert_pname_fwd, - f_convert_param_bkwd, - f_compute_relax_param, - ) - - device = tvm.cpu() - param_list = [None] * param_manager.nparam_to_load - - head_dim = config.hidden_size / config.num_attention_heads - inv_freq = 1.0 / ( - config.position_embedding_base ** (np.arange(0, head_dim, 2).astype("float32") / head_dim) - ) - - # The following cos/sin values can be removed but **are kept for compatibility issues**. - t = np.arange(2048, dtype=inv_freq.dtype) - freqs = np.einsum("i,j->ij", t, inv_freq) - emb = np.concatenate((freqs, freqs), axis=-1) - param_list[-2] = tvm.nd.array(np.cos(emb).astype(config.dtype), device) - param_list[-1] = tvm.nd.array(np.sin(emb).astype(config.dtype), device) - - return mod, param_manager, param_list, config + return setup_params(mod, param_manager, dtype, config, args) diff --git a/mlc_llm/relax_model/llama_batched_vllm.py b/mlc_llm/relax_model/llama_batched_vllm.py new file mode 100644 index 0000000000..2309bdd92e --- /dev/null +++ b/mlc_llm/relax_model/llama_batched_vllm.py @@ -0,0 +1,661 @@ +from typing import Optional, Tuple + +import numpy as np +import tvm +from tvm import relax, te +from tvm.relax.op import ccl, reshape, expand_dims, concat, zeros, repeat, take +from tvm.relax.op.nn import attention_var_len +from tvm.relax.testing import nn +from tvm.ir import VDevice +from tvm.script import relax as R +from tvm.script.ir_builder import tir as T + +from ..quantization import QuantizationScheme +from .modules import ModuleList +from .param_manager import ParamManager +from .llama import ( + LlamaConfig, + Linear, + Embedding, + LlamaRMSNorm, + LlamaAttentionBase, + LlamaDecoderLayer, + get_param_quant_kind, + setup_params, + rotary_modulate_by_freq, +) + + +def apply_rotary_pos_emb(q, k, positions, position_embedding_base): + def f_rotary_embedding(tensor, pos_tensor): + def rotary_compute(*idx): + pos = pos_tensor[idx[0]].astype("float32") + return rotary_modulate_by_freq( + tensor, + idx, + pos, + position_embedding_base, + ) + + return tvm.te.compute(tensor.shape, rotary_compute, name="rotary") + + q_embed = nn.emit_te(f_rotary_embedding, q, positions, primfunc_name_hint="rotary_embedding") + k_embed = nn.emit_te(f_rotary_embedding, k, positions, primfunc_name_hint="rotary_embedding") + return q_embed, k_embed + + +class LlamaAttentionBatched(LlamaAttentionBase): + def __init__(self, config: LlamaConfig, head_mapping: relax.Constant): + super().__init__(config) + self.head_mapping = head_mapping # (num_heads,), used by vLLM for multi-query attention + self.sliding_window = None + + if config.sliding_window: + self.sliding_window = T.IntImm("int32", config.sliding_window) 
+ + def forward( + self, + hidden_states: relax.Expr, # (num_token, hidden_size) + positions: relax.Expr, # (num_token,), for batched RoPE + seq_lens: relax.Expr, # (num_seq,) + kv_cache: Optional[Tuple[relax.Expr, relax.Expr]], + slot_mapping: Optional[relax.Expr], # (num_token,) + max_seqlen: Optional[relax.Expr], # (), must be on CPU + seqstart: Optional[relax.Expr], # (num_seq + 1,), for prefill + block_tables: Optional[relax.Expr], # (num_seq, max_num_blocks_per_seq), for decode + indices_within_window: Optional[ + relax.Expr + ], # (num_cached_total,), for prefill with sliding-window attention + ): + num_tokens, _ = hidden_states.struct_info.shape + + queries, keys, values = self.project_qkv( + hidden_states, + (num_tokens, self.num_query_heads, self.head_dim), + (num_tokens, self.num_key_value_heads, self.head_dim), + ) + + queries, keys = apply_rotary_pos_emb(queries, keys, positions, self.position_embedding_base) + + if kv_cache: + # Paged KV cache update + k_cache, v_cache = kv_cache + + if self.sliding_window is None or block_tables: + # For decode or prefill without sliding window, cache all keys / values. + keys_to_cache = keys + values_to_cache = values + else: + # Cache only the most recent keys and values within the window. + keys_to_cache = nn.emit(take(keys, indices_within_window, axis=0)) + values_to_cache = nn.emit(take(values, indices_within_window, axis=0)) + slot_mapping = nn.emit(take(slot_mapping, indices_within_window, axis=0)) + + # kv caches are updated inplace, but make it look like a pure operation + kv = nn.emit( + relax.op.call_pure_packed( + "tvm.contrib.vllm.reshape_and_cache", + keys_to_cache, + values_to_cache, + k_cache, + v_cache, + slot_mapping, + sinfo_args=[k_cache.struct_info, v_cache.struct_info], + ) + ) + + k_cache, v_cache = kv[0], kv[1] + else: + k_cache = v_cache = None + + if seqstart: + # Prefill, batched attention over variable sequence lengths + attn_output = nn.emit( + attention_var_len( + nn.emit(expand_dims(queries, axis=0)), + nn.emit(expand_dims(keys, axis=0)), + nn.emit(expand_dims(values, axis=0)), + seqstart_q=seqstart, + max_seqlen_q=max_seqlen, + causal_mask="BottomRight", + window_size=self.sliding_window, + ) + ) + else: + # Decode, using vLLM kernel + attn_output = nn.emit( + relax.op.call_dps_packed( + "tvm.contrib.vllm.single_query_cached_kv_attention", + [ + queries, + k_cache, + v_cache, + self.head_mapping, + block_tables, + seq_lens, + 16, # block_size + max_seqlen, + ], + out_sinfo=queries.struct_info, + ) + ) + + attn_output = nn.emit( + reshape(attn_output, (num_tokens, self.num_query_heads * self.head_dim)) + ) + attn_output = self.o_proj(attn_output) + + return attn_output, (k_cache, v_cache) + + +class LlamaDecoderLayerBatched(LlamaDecoderLayer): + def __init__(self, config: LlamaConfig, head_mapping: relax.Constant): + super().__init__(config, False) + self.self_attn = LlamaAttentionBatched(config, head_mapping) + + def forward( + self, + hidden_states: relax.Expr, + positions: relax.Expr, + seq_lens: relax.Expr, + kv_cache: Optional[Tuple[relax.Expr, relax.Expr]], + slot_mapping: Optional[relax.Expr], + max_seqlen: Optional[relax.Expr], + seqstart: Optional[relax.Expr], + block_tables: Optional[relax.Expr], + indices_within_window: Optional[relax.Expr], + ) -> Tuple[relax.Expr, Optional[Tuple[relax.Expr, relax.Expr]]]: + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, new_kv = self.self_attn( + hidden_states=hidden_states, + 
positions=positions, + seq_lens=seq_lens, + kv_cache=kv_cache, + slot_mapping=slot_mapping, + max_seqlen=max_seqlen, + seqstart=seqstart, + block_tables=block_tables, + indices_within_window=indices_within_window, + ) + + hidden_states = self.post_self_attn(hidden_states, residual) + + return hidden_states, new_kv + + +class LlamaModel(nn.Module): + def __init__( + self, + config: LlamaConfig, + cpu_device: VDevice, + vocab_size_var: tvm.tir.Var, + sep_embed: bool = False, + ): + self.padding_idx = config.pad_token_id + self.embed_tokens = None + + num_query_heads = config.num_attention_heads // config.num_shards + num_key_value_heads = config.get_num_key_value_heads() // config.num_shards + num_queries_per_kv = num_query_heads // num_key_value_heads + head_mapping = relax.const( + tvm.nd.array( + np.repeat(np.arange(num_key_value_heads, dtype="int32"), num_queries_per_kv) + ) + ) + + if not sep_embed: + self.embed_tokens = Embedding(vocab_size_var, config.hidden_size, dtype=config.dtype) + + self.layers = ModuleList( + [ + LlamaDecoderLayerBatched(config, head_mapping) + for _ in range(config.num_hidden_layers) + ] + ) + self.norm = LlamaRMSNorm(config.hidden_size, dtype=config.dtype, eps=config.rms_norm_eps) + + self.cpu_device = cpu_device + + def forward( + self, + inputs: relax.Expr, + positions: relax.Expr, + seq_lens: relax.Expr, + kv_caches: Optional[relax.Expr], + slot_mapping: Optional[relax.Expr], + seqstart: Optional[relax.Expr], + block_tables: Optional[relax.Expr], + indices_within_window: Optional[relax.Expr], + ): + if self.embed_tokens: + inputs_embeds = self.embed_tokens(inputs) + else: + inputs_embeds = inputs + + hidden_states = inputs_embeds + + # max_seqlen needs to be on CPU, so that vLLM and Flash Attention can directly get the + # integer length by max_seqlen->data[0]. Otherwise, we need to repeatedly do cudaMemcpy + # of a single int32. + max_seqlen = R.to_vdevice(R.max(seq_lens), self.cpu_device) + + new_kvs = () + + for idx, decoder_layer in enumerate(self.layers): + if kv_caches: + cache = (kv_caches[2 * idx], kv_caches[2 * idx + 1]) + else: + cache = None + + hidden_states, new_kv = decoder_layer( + hidden_states, + positions, + seq_lens, + cache, + slot_mapping, + max_seqlen, + seqstart, + block_tables, + indices_within_window, + ) + new_kvs += new_kv + + return self.norm(hidden_states), new_kvs + + +class LlamaForCausalLM(nn.Module): + def __init__( + self, + config: LlamaConfig, + cpu_device: VDevice, + vocab_size_var: tvm.tir.Var, + sep_embed: bool = False, + ): + self.num_shards = config.num_shards + self.model = LlamaModel(config, cpu_device, vocab_size_var, sep_embed) + self.lm_head = Linear(config.hidden_size, vocab_size_var, dtype=config.dtype, bias=False) + + ############ Rotary embedding constants ############ + assert config.hidden_size % config.num_attention_heads == 0 + head_dim = config.hidden_size // config.num_attention_heads + + # Set the cached sin/cos to the maximum of 2048 and max seq len. + # This will be eliminated further with online rotary embedding calculation. 
+ cache_len = te.var("cache_len", "int64") + self.cos_cached = nn.Parameter((cache_len, head_dim), dtype=config.dtype, name="cos_cached") + self.sin_cached = nn.Parameter((cache_len, head_dim), dtype=config.dtype, name="sin_cached") + ############ End ############ + + def forward( + self, + input_ids: relax.Expr, # (num_token,) + positions: relax.Expr, # (num_token,), for batched RoPE + seq_lens: relax.Expr, # (num_seq,) + kv_caches: Optional[relax.Expr], # For prefill and decode, not needed for evaluate + slot_mapping: Optional[ + relax.Expr + ], # (num_token,), for prefill and decode, not needed for evaluate + block_tables: Optional[relax.Expr], # (num_seq, max_num_blocks_per_seq), for decode + indices_within_window: Optional[ + relax.Expr + ], # (num_cached_total,), for prefill with sliding-window attention + ): + """ + In vLLM, the paged KV cache is simply a pair of tensors, one for keys and the other + for values. The tensor has shape (num_blocks, num_kv_heads, head_size, block_size). + (In practice, the key cache has a slightly different shape for an efficiency reason, + but that's not important.) + + The mapping between sequences / tokens to blocks is specified by two inputs. + - block_tables: A list of block IDs allocated for the sequence. + - slot_mapping: A linear index into the 2D grid (num_blocks, block_size), for each token. + + Support for sliding-window attention is realized by making a block table a circular buffer. + So the length of a block table for each sequence is at most ceil(window_size / block_size). + + With sliding window, not all past K / V values need to be cached during prefill. + The last input, indices_within_window, tells which tokens among (num_token,) need to have + their K / V values cached. + """ + if self.num_shards > 1: + input_ids = nn.emit(ccl.broadcast_from_worker0(input_ids)) + positions = nn.emit(ccl.broadcast_from_worker0(positions)) + seq_lens = nn.emit(ccl.broadcast_from_worker0(seq_lens)) + + if slot_mapping: + slot_mapping = nn.emit(ccl.broadcast_from_worker0(slot_mapping)) + + if block_tables: + block_tables = nn.emit(ccl.broadcast_from_worker0(block_tables)) + + if indices_within_window: + indices_within_window = nn.emit(ccl.broadcast_from_worker0(indices_within_window)) + + is_prompt = block_tables is None + + if is_prompt: # prefill and evaluate + # https://github.com/apache/tvm/issues/15851 for why we need to use Thrust + cumsum = nn.emit( + relax.op.call_dps_packed( + "tvm.contrib.thrust.sum_scan", seq_lens, out_sinfo=seq_lens.struct_info + ) + ) + seqstart = nn.emit(concat([zeros((1,), "int32"), cumsum])) + else: + seqstart = None + + hidden_states, new_kvs = self.model( + input_ids, + positions, + seq_lens, + kv_caches, + slot_mapping, + seqstart, + block_tables, + indices_within_window, + ) + + if is_prompt: + # Extract logits for the last token in each sequence + + def get_logits_last_tokens(x, seq_len_tensor, seqstart): + return te.compute( + shape=(seq_len_tensor.shape[0], x.shape[-1]), + fcompute=lambda i, j: x[seqstart[i] + seq_len_tensor[i] - 1, j], + name="get_logits_last_tokens", + ) + + logits = self.lm_head( + nn.emit_te( + get_logits_last_tokens, + hidden_states, + seq_lens, + seqstart, + primfunc_name_hint="get_logits_last_tokens", + ) + ) + else: + logits = self.lm_head(hidden_states) + + if logits.struct_info.dtype != "float32": + logits = nn.emit(relax.op.astype(logits, "float32")) + + return logits, new_kvs + + +def get_inputs( + num_token, num_seq, config, max_num_blocks_per_seq=None, sep_embed=False, 
need_cache=True +): + hidden_size = config.hidden_size + + inputs = ( + nn.Placeholder((num_token, hidden_size), dtype=config.dtype, name="inputs_embeds") + if sep_embed + else nn.Placeholder((num_token,), dtype="int32", name="input_ids") + ) + + seq_lens = nn.Placeholder((num_seq,), dtype="int32", name="seq_lens") + positions = nn.Placeholder((num_token,), dtype="int32", name="positions") + + if need_cache: + num_blocks = tvm.tir.Var("num_blocks", "int64") + block_size = 16 + + vec_size = 8 # 128 bit, fp16 x 8 + num_key_value_heads = config.get_num_key_value_heads() // config.num_shards + head_size = hidden_size // config.num_attention_heads + + k_cache_shape = ( + num_blocks, + num_key_value_heads, + head_size // vec_size, + block_size, + vec_size, + ) + v_cache_shape = (num_blocks, num_key_value_heads, head_size, block_size) + + get_cache_sinfo = lambda i: relax.TensorStructInfo( + k_cache_shape if i % 2 == 0 else v_cache_shape, dtype="float16" + ) + + past_key_values = relax.Var( + "kv_cache", + relax.TupleStructInfo( + [get_cache_sinfo(i) for i in range(config.num_hidden_layers * 2)] + ), + ) + slot_mapping = nn.Placeholder((num_token,), dtype="int32", name="slot_mapping") + else: + past_key_values = None + slot_mapping = None + block_tables = None + + if max_num_blocks_per_seq is None: + block_tables = None + else: + block_tables = nn.Placeholder( + (num_seq, max_num_blocks_per_seq), dtype="int32", name="block_tables" + ) + + return inputs, positions, seq_lens, past_key_values, slot_mapping, block_tables + + +def create_evaluate_func( + bb: relax.BlockBuilder, + param_manager: ParamManager, + config: LlamaConfig, + cpu_dev: VDevice, + quant_scheme: QuantizationScheme, + sep_embed: bool = False, +) -> None: + """Evaluate logits for the last token in each sequence. Same as prefill but without KV cache.""" + func_name = "evaluate" + + num_token = tvm.tir.Var("num_token", "int64") + num_seq = tvm.tir.Var("num_seq", "int64") + + with bb.function(func_name): + model = LlamaForCausalLM(config, cpu_dev, tvm.tir.Var("vocab_size", "int64"), sep_embed) + param_manager.register_params(model, func_name, quant_scheme, get_param_quant_kind) + + inputs, positions, seq_lens, _, _, _ = get_inputs( + num_token, num_seq, config, sep_embed=sep_embed + ) + + with bb.dataflow(): + logits, _ = model( + inputs, + positions, + seq_lens, + kv_caches=None, + slot_mapping=None, + block_tables=None, + indices_within_window=None, + ) + params = [ + inputs, + positions, + seq_lens, + ] + model.parameters() + gv = bb.emit_output(logits) + bb.emit_func_output(gv, params) + + mod = bb.get() + gv = mod.get_global_var(func_name) + bb.update_func(gv, mod[gv].with_attr("num_input", 3)) + + +def create_encoding_func( + bb: relax.BlockBuilder, + param_manager: ParamManager, + config: LlamaConfig, + cpu_dev: VDevice, + quant_scheme: QuantizationScheme, + sep_embed: bool = False, +) -> None: + """Batched prefill with vLLM paged KV cache. + + The batched attention op is intended to be offloaded to CUTLASS or Flash Attention + via BYOC. 
+ """ + func_name = "prefill_with_embed" if sep_embed else "prefill" + + num_token = tvm.tir.Var("num_token", "int64") + num_seq = tvm.tir.Var("num_seq", "int64") + + num_inputs = 5 + + with bb.function(func_name): + model = LlamaForCausalLM(config, cpu_dev, tvm.tir.Var("vocab_size", "int64"), sep_embed) + param_manager.register_params(model, func_name, quant_scheme, get_param_quant_kind) + + input_ids, positions, seq_lens, past_key_values, slot_mapping, _ = get_inputs( + num_token, num_seq, config, sep_embed=sep_embed + ) + + with bb.dataflow(): + params = [ + input_ids, + positions, + seq_lens, + past_key_values, + slot_mapping, + ] + + inputs = [ + input_ids, + positions, + seq_lens, + past_key_values, + slot_mapping, + None, # block_tables + ] + + if config.sliding_window: + num_inputs += 1 + # The value of num_cached_total is between + # num_token (if seq_len < sliding_window for all seq) and + # num_seq * config.sliding_window (if seq_len > sliding_window for all seq) + num_cached_total = tvm.tir.Var("num_cached_total", "int64") + indices_within_window = nn.Placeholder( + (num_cached_total,), dtype="int32", name="indices_within_window" + ) + inputs.append(indices_within_window) + params.append(indices_within_window) + else: + inputs.append(None) + + logits, new_kvs = model(*inputs) + gv = bb.emit_output((logits, relax.Tuple(new_kvs))) + + bb.emit_func_output(gv, params + model.parameters()) + + mod = bb.get() + gv = mod.get_global_var(func_name) + bb.update_func(gv, mod[gv].with_attr("num_input", num_inputs)) + + +def create_decoding_func( + bb: relax.BlockBuilder, + param_manager: ParamManager, + config: LlamaConfig, + cpu_dev: VDevice, + quant_scheme: QuantizationScheme, +) -> None: + """Batched decoding with vLLM paged KV cache.""" + func_name = "decode" + + num_seq = tvm.tir.Var("num_seq", "int64") + max_num_blocks_per_seq = tvm.tir.Var("max_num_blocks_per_seq", "int64") + + with bb.function(func_name): + inputs, positions, seq_lens, past_key_values, slot_mapping, block_tables = get_inputs( + num_seq, num_seq, config, max_num_blocks_per_seq + ) + + with bb.dataflow(): + model = LlamaForCausalLM(config, cpu_dev, tvm.tir.Var("vocab_size", "int64")) + param_manager.register_params(model, func_name, quant_scheme, get_param_quant_kind) + + logits, new_kvs = model( + inputs, positions, seq_lens, past_key_values, slot_mapping, block_tables, None + ) + params = [ + inputs, + positions, + seq_lens, + past_key_values, + slot_mapping, + block_tables, + ] + model.parameters() + gv = bb.emit_output((logits, relax.Tuple(new_kvs))) + bb.emit_func_output(gv, params) + + mod = bb.get() + gv = mod.get_global_var(func_name) + bb.update_func(gv, mod[gv].with_attr("num_input", 6)) + + +def get_model(args, hf_config): + dtype = args.quantization.model_dtype + sep_embed = False + + position_embedding_base = 10000 + + if "rope_theta" in hf_config: + position_embedding_base = hf_config["rope_theta"] + + # Llama-2 variants use `max_position_embeddings` to encode maximum sequence length in their hf model cards, + # while Llama-1 variants use `max_sequence_length`. + # Thus, use `max_sequence_length` if defined. Otherwise, use `max_position_embeddings`. + # If none of them is defined, throw an error. 
+ if "max_sequence_length" in hf_config: + config = LlamaConfig( + **hf_config, + dtype=dtype, + position_embedding_base=position_embedding_base, + combine_matmul=True, + num_shards=args.num_shards, + build_model_only=args.build_model_only, + ) + elif "max_position_embeddings" in hf_config: + config = LlamaConfig( + **hf_config, + dtype=dtype, + max_sequence_length=hf_config["max_position_embeddings"], + position_embedding_base=position_embedding_base, + combine_matmul=True, + num_shards=args.num_shards, + build_model_only=args.build_model_only, + ) + else: + raise Exception( + "The model config should contain information about maximum sequence length." + ) + + # If there is a user-provided maximum sequence length, override hf config. + if args.max_seq_len != -1: + config.max_sequence_length = args.max_seq_len + + param_manager = ParamManager() + bb = relax.BlockBuilder() + + # The CPU device to copy the result of relax.op.max(seq_lens) to CPU. + cpu_dev = VDevice("llvm", 0, "global") + + create_evaluate_func(bb, param_manager, config, cpu_dev, args.quantization, sep_embed) + create_encoding_func(bb, param_manager, config, cpu_dev, args.quantization, sep_embed) + create_decoding_func(bb, param_manager, config, cpu_dev, args.quantization) + + mod = bb.get() + + mod.update_global_info("vdevice", [cpu_dev]) + + if args.build_model_only: + return mod, param_manager, None, config + + return setup_params(mod, param_manager, dtype, config, args) From ece97b16916c28043072cf64ddd503f43b9e8691 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Mon, 30 Oct 2023 13:58:08 -0500 Subject: [PATCH 067/116] [Transform][Redo] Apply split_rotary optimization on prefill (#1125) Prior to this commit, the `transform.fuse_split_rotary_embedding` function was only applicable to the `decode` function of a Llama-type model. This was due to the sequence length being restricted to one, both in the pattern-match rule and in the `split_rotary` function, and the function being restricted to operate only on the `decode` function. This commit updates the `transform.fuse_split_rotary_embedding` pass to be a `tvm.ir.transform.Pass`, operating on all applicable matched in the `IRModule`. The `split_rotary` function is now produced as a fully-generic function, with static parameters substituted in afterwards. At this stage, the sequence length is retained as a dynamic parameter, such that it can be used by the `prefill` function. This commit reapplies the reverted commit https://github.com/mlc-ai/mlc-llm/pull/1033. The error in the previous implementation was in the definition of `rotary_embedding_offset`, which provided the `query_sequence_length` instead of `kv_sequence_length`. This was able to pass the validity tests described [here](https://github.com/mlc-ai/mlc-llm/pull/1058#issuecomment-1761622534), as these two sequence lengths are identical for the first call. 
--- mlc_llm/core.py | 11 +- .../transform/fuse_split_rotary_embedding.py | 446 ++++++++++-------- 2 files changed, 247 insertions(+), 210 deletions(-) diff --git a/mlc_llm/core.py b/mlc_llm/core.py index 0b7d1c8c39..17f695b7a7 100644 --- a/mlc_llm/core.py +++ b/mlc_llm/core.py @@ -468,12 +468,11 @@ def mod_transform_before_build( if max_seq_len: num_key_value_heads = config.get_num_key_value_heads() mod = fuse_split_rotary_embedding( - mod, - config.num_attention_heads // args.num_shards, - num_key_value_heads // args.num_shards, - config.hidden_size // args.num_shards, - config.position_embedding_base, - ) + config.num_attention_heads // args.num_shards, + num_key_value_heads // args.num_shards, + config.hidden_size // args.num_shards, + config.position_embedding_base, + )(mod) if args.target_kind == "cuda": patterns = [] diff --git a/mlc_llm/transform/fuse_split_rotary_embedding.py b/mlc_llm/transform/fuse_split_rotary_embedding.py index a7dbdf6c31..ed19a7095c 100644 --- a/mlc_llm/transform/fuse_split_rotary_embedding.py +++ b/mlc_llm/transform/fuse_split_rotary_embedding.py @@ -1,5 +1,5 @@ +import tvm from tvm import relax -from tvm.script import tir as T from tvm.relax.dpl import ( PatternContext, is_op, @@ -10,237 +10,275 @@ TuplePattern, is_shape, ) -from tvm.script import relax as R +from tvm.script import relax as R, tir as T -def get_split_rotary(num_attention_heads, head_dim, position_embedding_base): - hidden_size = num_attention_heads * head_dim +def get_dynamic_split_rotary(): + """Implementation of R.split(rotary_embedding(fused_qkv)) - @T.prim_func + Implementation is generic over the number of query heads, + key/value heads, sequence length, head dimension, and position + embedding base. These parameters can be replaced with static + values using `PrimFunc.specialize`. 
+ """ + + @T.prim_func(private=True) def split_rotary( - qkv: T.handle, - split_0: T.handle, - split_1: T.handle, - split_2: T.handle, - n: T.int64, + fused_qkv_handle: T.handle, + embedded_query_handle: T.handle, + embedded_key_handle: T.handle, + value_handle: T.handle, + rotary_offset: T.int64, + batch_size: T.int64, + seq_len: T.int64, + num_query_heads: T.int64, + num_kv_heads: T.int64, + head_dim: T.int64, + position_embedding_base: T.float32, ): - A = T.match_buffer(qkv, [1, 1, hidden_size * 3], dtype="float16") - T_split = T.match_buffer(split_0, [1, 1, hidden_size], dtype="float16") - T_split_1 = T.match_buffer(split_1, [1, 1, hidden_size], dtype="float16") - T_split_2 = T.match_buffer(split_2, [1, 1, hidden_size], dtype="float16") + Fused_QKV = T.match_buffer( + fused_qkv_handle, + [batch_size, seq_len, num_query_heads + num_kv_heads * 2, head_dim], + dtype="float16", + ) + EmbeddedQuery = T.match_buffer( + embedded_query_handle, + [batch_size, seq_len, num_query_heads, head_dim], + dtype="float16", + ) + EmbeddedKey = T.match_buffer( + embedded_key_handle, + [batch_size, seq_len, num_kv_heads, head_dim], + dtype="float16", + ) + Value = T.match_buffer( + value_handle, + [batch_size, seq_len, num_kv_heads, head_dim], + dtype="float16", + ) T.func_attr({"op_pattern": 2, "tir.noalias": T.bool(True)}) - for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(hidden_size)): - with T.block("T_split"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads( - A[v_ax0, v_ax1, v_ax2], - A[v_ax0, v_ax1, v_ax2 + T.int64(hidden_size)], - A[v_ax0, v_ax1, v_ax2 + T.int64(hidden_size * 2)], - ) - T.writes( - T_split[v_ax0, v_ax1, v_ax2], - T_split_1[v_ax0, v_ax1, v_ax2], - T_split_2[v_ax0, v_ax1, v_ax2], - ) - pos: T.float32 = T.Cast("float32", n - T.int64(1)) + + for iters in T.grid(batch_size, seq_len, num_query_heads + num_kv_heads * 2, head_dim): + with T.block("FusedRotaryEmbeddingAndSplitQKV"): + batch_i, seq_i, head_num, head_i = T.axis.remap("SSSS", iters) + pos: T.float32 = T.Cast("float32", rotary_offset + seq_i - seq_len) + inv_freq: T.float32 = T.float32(1) / T.pow( - T.float32(position_embedding_base), - T.Cast("float32", (v_ax2 * 2) % head_dim) / T.float32(head_dim), + position_embedding_base, + T.Cast("float32", (head_i * 2) % head_dim) / T.float32(head_dim), ) freq: T.float32 = pos * inv_freq cos_value: T.float16 = T.Cast("float16", T.cos(freq)) sin_value: T.float16 = T.Cast("float16", T.sin(freq)) - T_split[v_ax0, v_ax1, v_ax2] = cos_value * A[ - v_ax0, v_ax1, v_ax2 - ] + sin_value * T.Select( - T.int64(head_dim // 2) <= v_ax2 % head_dim, - A[v_ax0, v_ax1, v_ax2 - T.int64(head_dim // 2)], - A[v_ax0, v_ax1, v_ax2 + T.int64(head_dim // 2)] * T.float16(-1), - ) - T_split_1[v_ax0, v_ax1, v_ax2] = cos_value * A[ - v_ax0, v_ax1, v_ax2 + T.int64(hidden_size) - ] + sin_value * T.Select( - T.int64(head_dim // 2) <= v_ax2 % head_dim, - A[v_ax0, v_ax1, v_ax2 + T.int64(hidden_size) - T.int64(head_dim // 2)], - A[v_ax0, v_ax1, v_ax2 + T.int64(hidden_size) + T.int64(head_dim // 2)] + + input_value = Fused_QKV[batch_i, seq_i, head_num, head_i] + embedded_value = cos_value * input_value + sin_value * T.Select( + head_i < T.int64(head_dim // 2), + Fused_QKV[batch_i, seq_i, head_num, head_i + T.int64(head_dim // 2)] * T.float16(-1), + Fused_QKV[batch_i, seq_i, head_num, head_i - T.int64(head_dim // 2)], ) - T_split_2[v_ax0, v_ax1, v_ax2] = A[v_ax0, v_ax1, v_ax2 + T.int64(hidden_size * 2)] + if head_num < num_query_heads: + EmbeddedQuery[batch_i, seq_i, head_num, head_i] = 
embedded_value + elif head_num < num_query_heads + num_kv_heads: + EmbeddedKey[batch_i, seq_i, head_num - num_query_heads, head_i] = embedded_value + else: + Value[ + batch_i, seq_i, head_num - num_query_heads - num_kv_heads, head_i + ] = input_value + + param_sinfo = [] + for param in split_rotary.params: + if param in split_rotary.buffer_map: + buf = split_rotary.buffer_map[param] + sinfo = relax.TensorStructInfo(shape=buf.shape, dtype=buf.dtype) + else: + sinfo = relax.PrimStructInfo(param.dtype) + param_sinfo.append(sinfo) + + relax.expr._update_struct_info( + split_rotary, + tvm.relax.FuncStructInfo( + params=param_sinfo, + ret=relax.TupleStructInfo([]), + purity=False, + ), + ) return split_rotary -def get_split_rotary_group_query_attention( - num_query_heads, num_kv_heads, head_dim, position_embedding_base +def fuse_split_rotary_embedding( + num_query_heads, num_kv_heads, hidden_size, position_embedding_base ): - query_hidden_size = num_query_heads * head_dim - kv_hidden_size = num_kv_heads * head_dim - total_size = query_hidden_size + kv_hidden_size * 2 + @tvm.ir.transform.module_pass(opt_level=0, name="fuse_split_rotary_embedding") + def ir_module_pass(mod: tvm.IRModule, _pass_context) -> tvm.IRModule: + head_dim = hidden_size // num_query_heads + split_rotary = get_dynamic_split_rotary() - @T.prim_func - def split_rotary( - qkv: T.handle, - split_0: T.handle, - split_1: T.handle, - split_2: T.handle, - n: T.int64, - ): - A = T.match_buffer(qkv, [1, 1, total_size], dtype="float16") - T_split = T.match_buffer(split_0, [1, 1, query_hidden_size], dtype="float16") - T_split_1 = T.match_buffer(split_1, [1, 1, kv_hidden_size], dtype="float16") - T_split_2 = T.match_buffer(split_2, [1, 1, kv_hidden_size], dtype="float16") + ( + dyn_batch_size, + dyn_seq_len, + dyn_num_query_heads, + dyn_num_kv_heads, + dyn_head_dim, + dyn_position_embedding_base, + ) = split_rotary.params[-6:] - T.func_attr({"op_pattern": 2, "tir.noalias": T.bool(True)}) - for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(query_hidden_size)): - with T.block("T_split"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads( - A[v_ax0, v_ax1, v_ax2], - ) - T.writes(T_split[v_ax0, v_ax1, v_ax2]) - pos: T.float32 = T.Cast("float32", n - T.int64(1)) - inv_freq: T.float32 = T.float32(1) / T.pow( - T.float32(position_embedding_base), - T.Cast("float32", (v_ax2 * 2) % head_dim) / T.float32(head_dim), - ) - freq: T.float32 = pos * inv_freq - cos_value: T.float16 = T.Cast("float16", T.cos(freq)) - sin_value: T.float16 = T.Cast("float16", T.sin(freq)) - T_split[v_ax0, v_ax1, v_ax2] = cos_value * A[ - v_ax0, v_ax1, v_ax2 - ] + sin_value * T.Select( - T.int64(head_dim // 2) <= v_ax2 % head_dim, - A[v_ax0, v_ax1, v_ax2 - T.int64(head_dim // 2)], - A[v_ax0, v_ax1, v_ax2 + T.int64(head_dim // 2)] * T.float16(-1), - ) - for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(kv_hidden_size)): - with T.block("T_split"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads( - A[v_ax0, v_ax1, v_ax2 + T.int64(query_hidden_size)], - A[v_ax0, v_ax1, v_ax2 + T.int64(query_hidden_size + kv_hidden_size)], - ) - T.writes( - T_split_1[v_ax0, v_ax1, v_ax2], - T_split_2[v_ax0, v_ax1, v_ax2], - ) - pos: T.float32 = T.Cast("float32", n - T.int64(1)) - inv_freq: T.float32 = T.float32(1) / T.pow( - T.float32(position_embedding_base), - T.Cast("float32", (v_ax2 * 2) % head_dim) / T.float32(head_dim), - ) - freq: T.float32 = pos * inv_freq - cos_value: T.float16 = T.Cast("float16", T.cos(freq)) - sin_value: 
T.float16 = T.Cast("float16", T.sin(freq)) - T_split_1[v_ax0, v_ax1, v_ax2] = cos_value * A[ - v_ax0, v_ax1, v_ax2 + T.int64(query_hidden_size) - ] + sin_value * T.Select( - T.int64(head_dim // 2) <= v_ax2 % head_dim, - A[v_ax0, v_ax1, v_ax2 + T.int64(query_hidden_size) - T.int64(head_dim // 2)], - A[v_ax0, v_ax1, v_ax2 + T.int64(query_hidden_size) + T.int64(head_dim // 2)] - * T.float16(-1), - ) - T_split_2[v_ax0, v_ax1, v_ax2] = A[ - v_ax0, v_ax1, v_ax2 + T.int64(query_hidden_size + kv_hidden_size) - ] + split_rotary = split_rotary.specialize( + { + # Static model parameters + dyn_batch_size: T.int64(1), + dyn_num_query_heads: T.int64(num_query_heads), + dyn_num_kv_heads: T.int64(num_kv_heads), + dyn_head_dim: T.int64(head_dim), + dyn_position_embedding_base: T.float32(position_embedding_base), + # Dynamic parameters, to be inferred from TIR Buffer shapes + dyn_seq_len: tvm.tir.Var("query_sequence_length", "int64"), + } + ) - return split_rotary + mod["split_rotary"] = split_rotary + split_rotary_gvar = mod.get_global_var("split_rotary") + relax.expr._update_struct_info(split_rotary_gvar, mod["split_rotary"].struct_info) -def fuse_split_rotary_embedding( - mod, num_query_heads, num_kv_heads, hidden_size, position_embedding_base -): - if "rotary_embedding1" not in [gv.name_hint for gv in mod.functions]: - return mod - - head_dim = hidden_size // num_query_heads - mod["split_rotary"] = ( - get_split_rotary(num_query_heads, head_dim, position_embedding_base) - if num_query_heads == num_kv_heads - else get_split_rotary_group_query_attention( - num_query_heads, num_kv_heads, head_dim, position_embedding_base - ) - ) + with PatternContext() as ctx: + # flat_qkv_tuple: R.Tuple( + # R.Tensor((batch_size, seq_len, 4096), dtype="float16"), + # R.Tensor((batch_size, seq_len, 4096), dtype="float16"), + # R.Tensor((batch_size, seq_len, 4096), dtype="float16"), + # ) = R.split(flat_fused_qkv, indices_or_sections=[4096, 8192], axis=2) + # + # flat_query: R.Tensor((batch_size, seq_len, 4096), dtype="float16") = flat_qkv_tuple[0] + # query: R.Tensor((batch_size, seq_len, 32, 128), dtype="float16") = R.reshape( + # flat_query, R.shape([batch_size, seq_len, 32, 128]) + # ) + # flat_key: R.Tensor((batch_size, seq_len, 4096), dtype="float16") = flat_qkv_tuple[1] + # key: R.Tensor((batch_size, seq_len, 32, 128), dtype="float16") = R.reshape( + # flat_key, R.shape([batch_size, seq_len, 32, 128]) + # ) + # flat_value: R.Tensor((batch_size, seq_len, 4096), dtype="float16") = flat_qkv_tuple[2] + # value: R.Tensor((batch_size, seq_len, 32, 128), dtype="float16") = R.reshape( + # flat_value, R.shape([batch_size, seq_len, 32, 128]) + # ) + # embedded_query = R.call_tir( + # cls.rotary_embedding1, + # [query], + # out_sinfo=R.Tensor((batch_size, seq_len, 32, 128), dtype="float16"), + # tir_vars=R.shape([n]), + # ) + # embedded_key = R.call_tir( + # cls.rotary_embedding1, + # [key], + # out_sinfo=R.Tensor((batch_size, seq_len, 32, 128), dtype="float16"), + # tir_vars=R.shape([n]), + # ) - gvar = mod.get_global_var("split_rotary") - relax.expr._update_struct_info(gvar, mod.get_global_var("rotary_embedding1").struct_info) + pat_rotary_embedding_gvar = GlobalVarPattern() - with PatternContext() as ctx: - # lv3: R.Tuple(R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16"), R.Tensor((1, 1, 4096), dtype="float16")) = R.split(lv2, indices_or_sections=[4096, 8192], axis=2) + pat_flat_fused_qkv = wildcard() + pat_offset = wildcard() - # lv1521: R.Tensor((1, 1, 4096), dtype="float16") = lv3[0] - # 
lv1522: R.Tensor((1, 1, 32, 128), dtype="float16") = R.reshape(lv1521, R.shape([1, 1, 32, 128])) - # lv1524: R.Tensor((1, 1, 4096), dtype="float16") = lv3[1] - # lv1525: R.Tensor((1, 1, 32, 128), dtype="float16") = R.reshape(lv1524, R.shape([1, 1, 32, 128])) - # lv1527: R.Tensor((1, 1, 4096), dtype="float16") = lv3[2] - # lv1528: R.Tensor((1, 1, 32, 128), dtype="float16") = R.reshape(lv1527, R.shape([1, 1, 32, 128])) - # lv1530 = R.call_tir(cls.rotary_embedding1, (lv1525, cos_cached1, sin_cached1), out_sinfo=R.Tensor((1, 1, 32, 128), dtype="float16"), tir_vars=R.shape([n])) - # lv_1 = R.call_tir(cls.rotary_embedding1, (lv1522, cos_cached1, sin_cached1), out_sinfo=R.Tensor((1, 1, 32, 128), dtype="float16"), tir_vars=R.shape( + # query_shape = is_shape([1, seq_len, num_query_heads, head_dim]) + pat_query_shape = wildcard() + # value_shape = is_shape([1, seq_len, num_kv_heads, head_dim]) + pat_key_shape = wildcard() + # value_shape = is_shape([1, seq_len, num_kv_heads, head_dim]) + pat_value_shape = wildcard() - inp_pat = wildcard() - offset = wildcard() + pat_flat_qkv_tuple = is_op("relax.split")(pat_flat_fused_qkv) + pat_flat_query = is_tuple_get_item(pat_flat_qkv_tuple, 0) + pat_query = is_op("relax.reshape")( + pat_flat_query, pat_query_shape, add_constraint=False + ) + pat_flat_query.used_by(pat_query) + pat_flat_key = is_tuple_get_item(pat_flat_qkv_tuple, 1) + pat_key = is_op("relax.reshape")(pat_flat_key, pat_key_shape, add_constraint=False) + pat_flat_key.used_by(pat_key) + pat_flat_value = is_tuple_get_item(pat_flat_qkv_tuple, 2) + pat_value = is_op("relax.reshape")( + pat_flat_value, pat_value_shape, add_constraint=False + ) + pat_flat_value.used_by(pat_value) - lv3 = is_op("relax.split")(inp_pat) - lv1521 = is_tuple_get_item(lv3, 0) - lv1522 = is_op("relax.reshape")( - lv1521, is_shape([1, 1, num_query_heads, head_dim]), add_constraint=False - ) - lv1521.used_by(lv1522) - lv1524 = is_tuple_get_item(lv3, 1) - lv1525 = is_op("relax.reshape")( - lv1524, is_shape([1, 1, num_kv_heads, head_dim]), add_constraint=False - ) - lv1524.used_by(lv1525) - lv1527 = is_tuple_get_item(lv3, 2) - V = is_op("relax.reshape")( - lv1527, is_shape([1, 1, num_kv_heads, head_dim]), add_constraint=False - ) - lv1527.used_by(V) + pat_embedded_query = is_op("relax.call_tir")( + pat_rotary_embedding_gvar, + TuplePattern([pat_query]), + pat_offset, + add_constraint=False, + ) + pat_embedded_key = is_op("relax.call_tir")( + pat_rotary_embedding_gvar, + TuplePattern([pat_key]), + pat_offset, + add_constraint=False, + ) - Q = is_op("relax.call_tir")( - GlobalVarPattern(), TuplePattern([lv1522]), offset, add_constraint=False - ) - K = is_op("relax.call_tir")( - GlobalVarPattern(), TuplePattern([lv1525]), offset, add_constraint=False - ) + pat_flat_qkv_tuple.used_by(pat_flat_query) + pat_flat_qkv_tuple.used_by(pat_flat_key) + pat_flat_qkv_tuple.used_by(pat_flat_value) + pat_query.used_by(pat_embedded_query) + pat_key.used_by(pat_embedded_key) - lv3.used_by(lv1521) - lv3.used_by(lv1524) - lv3.used_by(lv1527) - lv1522.used_by(Q) - lv1525.used_by(K) - - def rewriter(matchings, bindings): - inp = matchings[inp_pat] - call_tir = matchings[Q] - n = bindings[call_tir].args[-1] - out_sinfo = [ - R.Tensor((1, 1, num_query_heads * head_dim), dtype="float16"), - R.Tensor((1, 1, num_kv_heads * head_dim), dtype="float16"), - R.Tensor((1, 1, num_kv_heads * head_dim), dtype="float16"), - ] - lv3_new = R.call_tir( - mod.get_global_var("split_rotary"), (inp,), out_sinfo=out_sinfo, tir_vars=n - ) - lv1521_new = lv3_new[0] - 
lv1522_new = R.reshape(lv1521_new, R.shape([1, 1, num_query_heads, head_dim])) - lv1524_new = lv3_new[1] - lv1525_new = R.reshape(lv1524_new, R.shape([1, 1, num_kv_heads, head_dim])) - lv1527_new = lv3_new[2] - lv1528_new = R.reshape(lv1527_new, R.shape([1, 1, num_kv_heads, head_dim])) - - return { - matchings[lv3]: lv3_new, - matchings[lv1521]: lv1521_new, - matchings[lv1522]: lv1522_new, - matchings[lv1524]: lv1524_new, - matchings[lv1525]: lv1525_new, - matchings[lv1527]: lv1527_new, - matchings[V]: lv1528_new, - matchings[Q]: lv1522_new, - matchings[K]: lv1525_new, - } - - mod["decode"] = rewrite_bindings(ctx, rewriter, mod["decode"]) - return mod + def rewriter(matchings, bindings): + # Extracting all the relax and TIR variables that we'll need + flat_fused_qkv = matchings[pat_flat_fused_qkv] + flat_qkv_tuple = matchings[pat_flat_qkv_tuple] + + flat_query = matchings[pat_flat_query] + flat_key = matchings[pat_flat_key] + flat_value = matchings[pat_flat_value] + + query = matchings[pat_query] + key = matchings[pat_key] + value = matchings[pat_value] + + embedded_query = matchings[pat_embedded_query] + embedded_key = matchings[pat_embedded_key] + + # rotary_embedding_offset = bindings[query].args[-1][1] + rotary_embedding_offset = bindings[embedded_query].args[-1][0] + + batch_size, seq_len, num_query_heads, head_dim = query.struct_info.shape + _batch_size, _seq_len, num_kv_heads, _head_dim = key.struct_info.shape + + # Rewriting along the new path + + fused_qkv = relax.op.reshape( + flat_fused_qkv, [batch_size, seq_len, num_query_heads + 2 * num_kv_heads, head_dim] + ) + + split_rotary_sinfo = [ + R.Tensor((batch_size, seq_len, num_query_heads, head_dim), dtype="float16"), + R.Tensor((batch_size, seq_len, num_kv_heads, head_dim), dtype="float16"), + R.Tensor((batch_size, seq_len, num_kv_heads, head_dim), dtype="float16"), + ] + qkv_tuple_new = R.call_tir( + split_rotary_gvar, + (fused_qkv,), + out_sinfo=split_rotary_sinfo, + tir_vars=[rotary_embedding_offset], + ) + + embedded_query_new = qkv_tuple_new[0] + embedded_key_new = qkv_tuple_new[1] + value_new = qkv_tuple_new[2] + + return { + value: value_new, + embedded_query: embedded_query_new, + embedded_key: embedded_key_new, + } + + new_mod = {} + for gvar, func in mod.functions.items(): + if isinstance(func, relax.Function): + func = rewrite_bindings(ctx, rewriter, func) + new_mod[gvar] = func + + new_mod = tvm.IRModule(new_mod, mod.type_definitions, mod.attrs, mod.global_infos) + return new_mod + + return ir_module_pass From b1905787d37261f4f35384b89d1f0d6b9492e8ec Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Mon, 30 Oct 2023 13:58:18 -0500 Subject: [PATCH 068/116] Apply rewrite for normal attention and MQA (#1138) Fixes a bug introduced in https://github.com/mlc-ai/mlc-llm/pull/1052, where use of the `--use-flash-attn-mqa` flag on a model that doesn't use MQA would prevent the use of CUTLASS attention at all. 
--- mlc_llm/core.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/mlc_llm/core.py b/mlc_llm/core.py index 17f695b7a7..6b993c07b5 100644 --- a/mlc_llm/core.py +++ b/mlc_llm/core.py @@ -480,7 +480,9 @@ def mod_transform_before_build( has_cutlass = tvm.get_global_func("relax.ext.cutlass", True) if has_cutlass and not args.no_cutlass_attn: - mod = rewrite_attention(use_flash_mqa=args.use_flash_attn_mqa)(mod) + if args.use_flash_attn_mqa: + mod = rewrite_attention(use_flash_mqa=True)(mod) + mod = rewrite_attention(use_flash_mqa=False)(mod) patterns += get_patterns_with_prefix("cutlass.attention") if has_cutlass and not args.no_cutlass_norm: From 8ca0176d41b90b109246a0d5c22efc06d6f41352 Mon Sep 17 00:00:00 2001 From: Yuchen Jin Date: Mon, 30 Oct 2023 12:26:58 -0700 Subject: [PATCH 069/116] [Rest] Fix emoji handling in Rest API. (#1142) --- python/mlc_chat/rest.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/python/mlc_chat/rest.py b/python/mlc_chat/rest.py index a1cc57c4e9..b6762a5b4d 100644 --- a/python/mlc_chat/rest.py +++ b/python/mlc_chat/rest.py @@ -215,18 +215,19 @@ async def iter_response(): prev_txt = "" async for content in AsyncCompletionStream(generation_config=generation_config): if content: + valid_content = content.replace("�", "") chunk = ChatCompletionStreamResponse( choices=[ ChatCompletionResponseStreamChoice( index=0, delta=DeltaMessage( - role="assistant", content=content[len(prev_txt) :] + role="assistant", content=valid_content[len(prev_txt) :] ), finish_reason="stop", ) ] ) - prev_txt = content + prev_txt = valid_content yield f"data: {chunk.json(exclude_unset=True)}\n\n" return StreamingResponse(iter_response(), media_type="text/event-stream") From 3cf560566b5d28ef250cda8aff341eb04a88a688 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Mon, 30 Oct 2023 14:43:51 -0500 Subject: [PATCH 070/116] [Utility] Check for isinstance(exc, Exception) before entering pdb (#1095) This is a follow-up to #1017, which added a `--pdb` flag to enter a debugger on exit. This commit checks the type of the raised exception, and only enters the debugger if it is a subclass of `Exception`. This ensures that implementation-details, such as a thrown `SystemExit` or `KeyboardInterrupt`, do not cause an erroneous entry to pdb. --- mlc_llm/build.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mlc_llm/build.py b/mlc_llm/build.py index c90da542b8..b7619aa963 100644 --- a/mlc_llm/build.py +++ b/mlc_llm/build.py @@ -10,7 +10,8 @@ def debug_on_except(): try: yield finally: - if sys.exc_info() == (None, None, None): + raised_exception = sys.exc_info()[1] + if not isinstance(raised_exception, Exception): return import traceback From 0a9d6c7a351fac9394a38b0d66f060cb4e325bbe Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Mon, 30 Oct 2023 14:44:44 -0500 Subject: [PATCH 071/116] [Utils] Remove conversion to numpy array in utils.save_params (#1083) Prior to this commit, each parameter was converted to a numpy-owned array as part of a total size computation. This commit computes the size directly, removing the conversion. 
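For a concrete sense of the arithmetic (not taken from the patch): the byte size of a
parameter follows directly from its shape and dtype, so no device-to-host copy is needed.

    import math
    import numpy as np

    # e.g. a hypothetical float16 weight of shape (4096, 11008):
    size_bytes = math.prod((4096, 11008)) * np.dtype("float16").itemsize
    size_mib = size_bytes / (1024 ** 2)   # exactly 86.0 MiB
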
--- mlc_llm/utils.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/mlc_llm/utils.py b/mlc_llm/utils.py index bb19f45c4f..1bcf1e8816 100644 --- a/mlc_llm/utils.py +++ b/mlc_llm/utils.py @@ -2,10 +2,13 @@ import argparse import functools import json +import math import os import shutil from typing import Any, Dict, List, Optional, Set +import numpy as np + import tvm from tvm import relax @@ -283,11 +286,12 @@ def save_params(params: List[tvm.nd.NDArray], artifact_path: str) -> None: meta_data["ParamSize"] = len(params) total_size = 0.0 for i, nd in enumerate(params): + assert nd is not None, f"Missing parameter at index {i}" param_dict[f"param_{i}"] = nd - np_nd = nd.numpy() - total_size += np_nd.size * np_nd.dtype.itemsize - total_size = total_size / 1024.0 / 1024.0 / 1024.0 - print(f"Total param size: {total_size} GB") + + total_size_bytes = sum(math.prod(param.shape) * np.dtype(param.dtype).itemsize for param in params) + total_size_gb = total_size_bytes / (1024 ** 3) + print(f"Total param size: {total_size_gb} GB") tvmjs.dump_ndarray_cache( param_dict, f"{artifact_path}/params", meta_data=meta_data, encode_format="raw" ) From 425a2cb09870bfa8edde20cf4670b9f32bf76172 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Mon, 30 Oct 2023 15:00:24 -0700 Subject: [PATCH 072/116] [Fix][REST] Use lowered-cased "app" (#1159) --- python/mlc_chat/rest.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/python/mlc_chat/rest.py b/python/mlc_chat/rest.py index b6762a5b4d..2816d9bec7 100644 --- a/python/mlc_chat/rest.py +++ b/python/mlc_chat/rest.py @@ -143,8 +143,8 @@ async def lifespan(_app: FastAPI): origins = ["*"] -APP = FastAPI(lifespan=lifespan) -APP.add_middleware( +app = FastAPI(lifespan=lifespan) +app.add_middleware( CORSMiddleware, allow_origins=origins, allow_credentials=True, @@ -177,7 +177,7 @@ async def __anext__(self): raise StopAsyncIteration -@APP.post("/v1/chat/completions") +@app.post("/v1/chat/completions") async def request_chat_completion(request: ChatCompletionRequest): """ Creates model response for the given chat conversation. @@ -248,7 +248,7 @@ async def iter_response(): ) -@APP.post("/v1/completions") +@app.post("/v1/completions") async def request_completion(request: CompletionRequest): """ Creates a completion for a given prompt. @@ -309,7 +309,7 @@ async def iter_response(): ) -@APP.post("/v1/embeddings") +@app.post("/v1/embeddings") async def request_embeddings(request: EmbeddingsRequest): """ Gets embedding for some text. @@ -335,7 +335,7 @@ async def request_embeddings(request: EmbeddingsRequest): ) -@APP.post("/chat/reset") +@app.post("/chat/reset") async def reset(): """ Reset the chat for the currently initialized model. @@ -343,7 +343,7 @@ async def reset(): session["chat_mod"].reset_chat() -@APP.get("/stats") +@app.get("/stats") async def read_stats(): """ Get the runtime stats. @@ -351,7 +351,7 @@ async def read_stats(): return session["chat_mod"].stats() -@APP.get("/verbose_stats") +@app.get("/verbose_stats") async def read_stats_verbose(): """ Get the verbose runtime stats. From 9076d0134f314d5cbf11909b8c0f954c399e8f62 Mon Sep 17 00:00:00 2001 From: Yuchen Jin Date: Mon, 30 Oct 2023 22:40:53 -0700 Subject: [PATCH 073/116] [Rest] Document emoji handling (#1160) Followup PR of #1142 to document the emoji handling. 
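For context on the behaviour being documented (illustrative snippet, not from this patch):
a U+FFFD replacement character shows up whenever a multi-byte UTF-8 sequence is cut
mid-way, which is exactly what happens when only part of an emoji has been streamed.

    emoji_bytes = "👋".encode("utf-8")                         # 4 bytes: f0 9f 91 8b
    partial = emoji_bytes[:2].decode("utf-8", errors="replace")
    # `partial` contains U+FFFD ("�") instead of the emoji, so the REST streaming
    # endpoint strips it until the remaining bytes arrive in a later chunk.
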
--- python/mlc_chat/rest.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/python/mlc_chat/rest.py b/python/mlc_chat/rest.py index 2816d9bec7..e92c2824d3 100644 --- a/python/mlc_chat/rest.py +++ b/python/mlc_chat/rest.py @@ -215,6 +215,10 @@ async def iter_response(): prev_txt = "" async for content in AsyncCompletionStream(generation_config=generation_config): if content: + # Remove the replacement character (U+FFFD) from the response + # This is to handle emojis. An emoji might be made up of multiple tokens. + # In the Rest streaming setting, if an emoji gets truncated in the middle of + # its encoded byte sequence, a replacement character will appear. valid_content = content.replace("�", "") chunk = ChatCompletionStreamResponse( choices=[ From b5bfa5be8363ffdbb26d901582962bb3882c1dcb Mon Sep 17 00:00:00 2001 From: Yaxing Cai Date: Tue, 31 Oct 2023 11:39:44 -0700 Subject: [PATCH 074/116] Enable group quant transform with nn.Module (#1154) * Enable group quant transform with nn.Module This PR completes the group quantization support for `nn.Module` based model. * remove deprecated tests * Update * wip * remove deprecated test * fix lint * fix lint * fix lint --------- Co-authored-by: Junru Shao --- python/mlc_chat/cli/compile.py | 6 +- python/mlc_chat/compiler/__init__.py | 2 +- python/mlc_chat/compiler/compile.py | 16 +- python/mlc_chat/compiler/model/llama_model.py | 7 +- .../compiler/model/llama_quantization.py | 109 ++------ python/mlc_chat/compiler/model/model.py | 21 +- .../compiler/quantization/__init__.py | 3 +- .../quantization/group_quantization.py | 237 ++++++++++++++++++ .../compiler/quantization/group_quantizer.py | 70 ------ .../compiler/quantization/quantization.py | 24 +- .../python/parameter/test_group_quantizer.py | 157 ------------ .../quantization/test_group_quantization.py | 85 +++++++ 12 files changed, 388 insertions(+), 349 deletions(-) create mode 100644 python/mlc_chat/compiler/quantization/group_quantization.py delete mode 100644 python/mlc_chat/compiler/quantization/group_quantizer.py delete mode 100644 tests/python/parameter/test_group_quantizer.py create mode 100644 tests/python/quantization/test_group_quantization.py diff --git a/python/mlc_chat/cli/compile.py b/python/mlc_chat/cli/compile.py index 31b639a68f..e3ff778487 100644 --- a/python/mlc_chat/cli/compile.py +++ b/python/mlc_chat/cli/compile.py @@ -6,7 +6,7 @@ from mlc_chat.compiler import ( # pylint: disable=redefined-builtin MODELS, - QUANT, + QUANTIZATION, OptimizationFlags, compile, ) @@ -51,7 +51,7 @@ def _parse_output(path: Union[str, Path]) -> Path: "--quantization", type=str, required=True, - choices=list(QUANT.keys()), + choices=list(QUANTIZATION.keys()), help="Quantization format.", ) parser.add_argument( @@ -119,7 +119,7 @@ def _parse_output(path: Union[str, Path]) -> Path: parsed.model_type = detect_model_type(parsed.model_type, parsed.config) compile( config=parsed.config, - quantization=parsed.quantization, + quantization=QUANTIZATION[parsed.quantization], model_type=parsed.model_type, target=target, opt=parsed.opt, diff --git a/python/mlc_chat/compiler/__init__.py b/python/mlc_chat/compiler/__init__.py index 4905e8ac91..cf68426f8e 100644 --- a/python/mlc_chat/compiler/__init__.py +++ b/python/mlc_chat/compiler/__init__.py @@ -7,4 +7,4 @@ from .flags_optimization import OptimizationFlags from .model import MODEL_PRESETS, MODELS, Model from .parameter import ExternMapping, HuggingFaceLoader, QuantizeMapping -from .quantization import QUANT +from .quantization import QUANTIZATION diff 
--git a/python/mlc_chat/compiler/compile.py b/python/mlc_chat/compiler/compile.py index 5b77a94f81..88f33b03af 100644 --- a/python/mlc_chat/compiler/compile.py +++ b/python/mlc_chat/compiler/compile.py @@ -7,9 +7,10 @@ from tvm import IRModule, relax from tvm.target import Target -from ..compiler.model import Model from ..support.style import bold from .flags_optimization import OptimizationFlags +from .model import Model +from .quantization import Quantization @dataclasses.dataclass @@ -17,7 +18,7 @@ class CompileArgs: # pylint: disable=too-many-instance-attributes """Arguments to MLC LLM's compiler.""" config: Path - quantization: str + quantization: Quantization model: Model target: Target opt: OptimizationFlags @@ -40,20 +41,19 @@ def _echo_args(args: CompileArgs) -> None: def _compile(args: CompileArgs): model_config = args.model.config.from_file(args.config) - model = args.model.model(model_config) - mod, named_params = model.export_tvm( + quantization = args.quantization + model, _ = args.model.quantize[quantization.kind](model_config, quantization) + mod, _named_params = model.export_tvm( spec=model.get_default_spec(), # type: ignore ) with args.target: mod = relax.get_pipeline("mlc_llm")(mod) - mod.show(black_format=False) - for name, param in named_params: - print(f"{name}: {param.shape} {param.dtype}") + args.build_func(mod, args) def compile( # pylint: disable=too-many-arguments,redefined-builtin config: Path, - quantization, + quantization: Quantization, model_type: Model, target: Target, opt: OptimizationFlags, diff --git a/python/mlc_chat/compiler/model/llama_model.py b/python/mlc_chat/compiler/model/llama_model.py index 1106b38c56..0c9d2f45ab 100644 --- a/python/mlc_chat/compiler/model/llama_model.py +++ b/python/mlc_chat/compiler/model/llama_model.py @@ -22,6 +22,8 @@ def __init__(self, config: LlamaConfig): def forward(self, q: Tensor, k: Tensor, offset: tir.Var): def te_op(x: te.Tensor, offset: tir.Var): + dtype = x.dtype + def compute(b: tir.Var, s: tir.Var, h: tir.Var, d: tir.Var): head_dim = tir.const(self.head_dim, "int32") position_embedding_base = tir.const(self.position_embedding_base, "float32") @@ -30,11 +32,13 @@ def compute(b: tir.Var, s: tir.Var, h: tir.Var, d: tir.Var): (d * 2 % head_dim).astype("float32") / head_dim, ) freq = (offset + s) / freq - return tir.cos(freq) * x[b, s, h, d] + tir.sin(freq) * tir.if_then_else( + cos = tir.cos(freq).astype(dtype) * x[b, s, h, d] + sin = tir.sin(freq).astype(dtype) * tir.if_then_else( d < self.head_dim // 2, -x[b, s, h, d + self.head_dim // 2], x[b, s, h, d - self.head_dim // 2], ) + return cos + sin return te.compute(x.shape, compute, name="rotary") @@ -87,6 +91,7 @@ def forward( # pylint: disable=too-many-locals d, h_q, h_kv, t = self.head_dim, self.num_q_heads, self.num_kv_heads, total_seq_len b, s, _ = hidden_states.shape assert b == 1, "Only support batch size 1 at this moment." + q, k, v = self.qkv_proj(hidden_states) q = op.reshape(q, (b, s, h_q, d)) k = op.reshape(k, (b, s, h_kv, d)) diff --git a/python/mlc_chat/compiler/model/llama_quantization.py b/python/mlc_chat/compiler/model/llama_quantization.py index a263ba0c4d..a5f8f0b0df 100644 --- a/python/mlc_chat/compiler/model/llama_quantization.py +++ b/python/mlc_chat/compiler/model/llama_quantization.py @@ -1,101 +1,24 @@ -""" -Quantization specs for Llama2 architecture. 
-TODO: add docstring -""" -from typing import Callable, Dict, List, Optional +"""Quantization specs for Llama.""" +from typing import Tuple -import tvm -from tvm.runtime import NDArray +from tvm.relax.frontend import nn from ..parameter import QuantizeMapping -from ..quantization import QuantizeConfig -from ..quantization.group_quantizer import te_quantize as te_group_quantize +from ..quantization import GroupQuantize from .llama_config import LlamaConfig from .llama_model import LlamaForCasualLM -def huggingface_group_quantize( +def group_quant( model_config: LlamaConfig, - quantize_config: QuantizeConfig, - target: Optional[tvm.target.Target] = None, -) -> QuantizeMapping: - """Returns a parameter mapping that maps a parameter in MLC LLM's model - definition to its eventual names and values after quantization. - - Parameters - ---------- - model_config : LlamaConfig - The configuration of the Llama model. - quantize_config : GroupQuantizeConfig - The configuration of the group quantization. - target : Optional[tvm.target.Target] - The target device to run the quantization on, by default None, which - means the quantization will be run on CPU. - - Returns - ------- - quantize_map : QuantizeMapping - The parameter mapping from a parameter in MLC LLM's model definition to - its eventual names and values after quantization. - """ - - def group_quantize( - param: NDArray, config: QuantizeConfig, target: Optional[tvm.target.Target] = None - ): - if target is None or target.kind.name == "llvm": - target = tvm.target.Target("llvm") - device = tvm.cpu() - elif target.kind.name == "cuda": - device = tvm.cuda() - else: - raise ValueError(f"Invalid target device: {target}") - param_tensor = tvm.te.placeholder(param.shape, dtype=param.dtype, name="param") - weight_compute, scale_compute, other_computes = te_group_quantize( # type: ignore - param_tensor, config - ) - s = tvm.te.create_schedule( # pylint: disable=invalid-name - [compute.op for compute in [weight_compute, scale_compute] + other_computes] - ) - if target.kind.name == "cuda": - # thread_binding for cuda - for compute in [weight_compute, scale_compute] + other_computes: - xo, xi = s[compute].split(compute.op.axis[0], 256) # pylint: disable=invalid-name - s[compute].bind(xo, tvm.te.thread_axis("blockIdx.x")) - s[compute].bind(xi, tvm.te.thread_axis("threadIdx.x")) - f_quantize = tvm.build( - s, [param_tensor, weight_compute, scale_compute], name="group_quantize", target=target - ) - weight = tvm.nd.empty(weight_compute.shape, weight_compute.dtype, device=device) - scale = tvm.nd.empty(scale_compute.shape, scale_compute.dtype, device=device) - f_quantize(param.copyto(device), weight, scale) - return weight, scale - - # Param check - assert ( - quantize_config.kind == "group_quantize" - ), f"Invalid quantization config: group quantization expected but got {quantize_config.kind}" - assert ( - quantize_config.name == "q4f16_1" - ), """Only support q4f16_1 quantization scheme for now.""" - - # Fetch model parameter & names - model = LlamaForCasualLM(model_config) - _, named_params = model.export_tvm(spec=model.get_default_spec()) - parameter_names = {name for name, _ in named_params} - - # Init mappings - param_map: Dict[str, List[str]] = {} - map_func: Dict[str, Callable] = {} - - # Dispatch quantization scheme - # Also see https://github.com/mlc-ai/mlc-llm/blob/main/mlc_llm/quantization/__init__.py - for name in parameter_names: - if "norm.weight" not in name and "embed" not in name: - param_map[name] = [f"{name}_quantized", 
f"{name}_scale"] - map_func[name] = lambda x: group_quantize(x, quantize_config, target=target) - else: - # skip these parameters - param_map[name] = [name] - map_func[name] = lambda x: [x] - - return QuantizeMapping(param_map, map_func) + quantization: GroupQuantize, +) -> Tuple[nn.Module, QuantizeMapping]: + """Quantize a Llama2 model using group quantization.""" + model: nn.Module = LlamaForCasualLM(model_config) + quant_map = QuantizeMapping({}, {}) + model = quantization.quantize_model( + model, + quant_map, + "model", + ) + return model, quant_map diff --git a/python/mlc_chat/compiler/model/model.py b/python/mlc_chat/compiler/model/model.py index 3027a39500..74159cc188 100644 --- a/python/mlc_chat/compiler/model/model.py +++ b/python/mlc_chat/compiler/model/model.py @@ -1,12 +1,12 @@ """A centralized registry of all existing model architures and their configurations.""" import dataclasses -from typing import Any, Callable, Dict +from typing import Any, Callable, Dict, Tuple from tvm.relax.frontend import nn from ..parameter import ExternMapping, QuantizeMapping -from ..quantization.quantization import QuantizeConfig -from . import llama_config, llama_model, llama_parameter +from ..quantization.quantization import Quantization +from . import llama_config, llama_model, llama_parameter, llama_quantization ModelConfig = Any """A ModelConfig is an object that represents a model architecture. It is required to have @@ -16,8 +16,8 @@ def from_file(cls, path: Path) -> ModelConfig: ... """ -FuncGetExternMap = Callable[[ModelConfig, QuantizeConfig], ExternMapping] -FuncGetQuantMap = Callable[[ModelConfig, QuantizeConfig], QuantizeMapping] +FuncGetExternMap = Callable[[ModelConfig, Quantization], ExternMapping] +FuncQuantization = Callable[[ModelConfig, Quantization], Tuple[nn.Module, QuantizeMapping]] @dataclasses.dataclass @@ -38,15 +38,16 @@ class Model: source : Dict[str, FuncGetExternMap] A dictionary that maps the name of a source format to parameter mapping. - quantize: Dict[str, FuncGetQuantMap] - A dictionary that maps the name of a quantization method to quantization mapping. + quantize: Dict[str, FuncQuantization] + A dictionary that maps the name of a quantization method to quantized model and the + quantization parameter mapping. 
""" name: str config: ModelConfig model: Callable[[ModelConfig], nn.Module] source: Dict[str, FuncGetExternMap] - quantize: Dict[str, FuncGetQuantMap] + quantize: Dict[str, FuncQuantization] MODELS: Dict[str, Model] = { @@ -58,7 +59,9 @@ class Model: "huggingface-torch": llama_parameter.huggingface, "huggingface-safetensor": llama_parameter.huggingface, }, - quantize={}, + quantize={ + "group-quant": llama_quantization.group_quant, + }, ) } diff --git a/python/mlc_chat/compiler/quantization/__init__.py b/python/mlc_chat/compiler/quantization/__init__.py index a932119f9c..3df96ce18a 100644 --- a/python/mlc_chat/compiler/quantization/__init__.py +++ b/python/mlc_chat/compiler/quantization/__init__.py @@ -1,2 +1,3 @@ """A subpackage for quantization and dequantization algorithms""" -from .quantization import QUANT, QuantizeConfig +from .group_quantization import GroupQuantize +from .quantization import QUANTIZATION, Quantization diff --git a/python/mlc_chat/compiler/quantization/group_quantization.py b/python/mlc_chat/compiler/quantization/group_quantization.py new file mode 100644 index 0000000000..5bfaf084b2 --- /dev/null +++ b/python/mlc_chat/compiler/quantization/group_quantization.py @@ -0,0 +1,237 @@ +"""The group quantization config""" +from dataclasses import dataclass +from typing import Any, List, Optional, Tuple + +from tvm import DataType, DataTypeCode, device +from tvm import dlight as dl +from tvm import relax, te, tir +from tvm.relax.frontend import nn +from tvm.runtime import NDArray +from tvm.target import Target + +from ..parameter import QuantizeMapping + + +@dataclass +class GroupQuantize: # pylint: disable=too-many-instance-attributes + """Configuration for group quantization""" + + name: str + kind: str + group_size: int + quantize_dtype: str # "int3", "int4", "int8" + storage_dtype: str # "uint32" + model_dtype: str # "float16", "float32" + + num_elem_per_storage: int = 0 + num_storage_per_group: int = 0 + max_int_value: int = 0 + + def __post_init__(self): + assert self.kind == "group-quant" + quantize_dtype = DataType(self.quantize_dtype) + storage_dtype = DataType(self.storage_dtype) + model_dtype = DataType(self.model_dtype) + assert quantize_dtype.type_code == DataTypeCode.INT + assert storage_dtype.type_code == DataTypeCode.UINT + assert model_dtype.type_code == DataTypeCode.FLOAT + if storage_dtype.bits < quantize_dtype.bits: + raise ValueError("Storage unit should be greater or equal to quantized element") + + self.num_elem_per_storage = storage_dtype.bits // quantize_dtype.bits + if self.group_size % self.num_elem_per_storage != 0: + raise ValueError("Group size should be divisible by numbers of elements per storage") + self.num_storage_per_group = self.group_size // self.num_elem_per_storage + self.max_int_value = (2 ** (quantize_dtype.bits - 1)) - 1 + + def quantize_model( + self, + model: nn.Module, + quant_map: QuantizeMapping, + name_prefix: str, + ) -> nn.Module: + """Quantize model with group quantization""" + + class _Mutator(nn.Mutator): + def __init__(self, config: GroupQuantize, quant_map: QuantizeMapping) -> None: + super().__init__() + self.config = config + self.quant_map = quant_map + + def visit_module(self, name: str, node: nn.Module) -> Any: + if isinstance(node, nn.Linear): + weight_name = f"{name}.weight" + self.quant_map.param_map[weight_name] = [f"{name}.q_weight", f"{name}.q_scale"] + # self.quant_map.map_func[weight_name] = self.config.quantize + return GroupQuantizeLinear.from_linear(node, self.config) + return self.visit(name, node) + 
+ model.to(dtype=self.model_dtype) + mutator = _Mutator(self, quant_map) + model = mutator.visit(name_prefix, model) + return model + + def _dequantize( + self, + weight: te.Tensor, + scale: te.Tensor, + out_shape: Optional[List[tir.PrimExpr]] = None, + ): + quantize_dtype = DataType(self.quantize_dtype) + storage_dtype = DataType(self.storage_dtype) + tir_bin_mask = tir.const((2**quantize_dtype.bits) - 1, self.storage_dtype) + tir_max_int = tir.const(self.max_int_value, self.model_dtype) + dequantized_weight = te.compute( + shape=[weight.shape[0], weight.shape[1] * self.num_elem_per_storage] + if out_shape is None + else out_shape, + fcompute=lambda i, j: tir.multiply( + tir.subtract( + tir.bitwise_and( + tir.shift_right( + weight[i, j // self.num_elem_per_storage], + (j % self.num_elem_per_storage) * storage_dtype.bits, + ), + tir_bin_mask, + ), + tir_max_int, + ), + scale[i, j // self.group_size], + ), + ) + return dequantized_weight + + def quantize_weight(self, weight: NDArray) -> List[NDArray]: + """Quantize weight with group quantization""" + assert weight.dtype == self.model_dtype + assert len(weight.shape) == 2 + bb = relax.BlockBuilder() + weight_var = relax.Var("weight", relax.TensorStructInfo(weight.shape, self.model_dtype)) + with bb.function(name="quantize", params=[weight_var]): + with bb.dataflow(): + lv = bb.emit_te(self._quantize, weight_var) + gv = bb.emit_output(lv) + bb.emit_func_output(gv) + mod = bb.get() + with Target("cuda"): + mod = dl.ApplyDefaultSchedule( # pylint: disable=not-callable + dl.gpu.Reduction(), dl.gpu.GeneralReduction(), dl.gpu.Fallback() + )(mod) + ex = relax.build(mod, "cuda") + dev = device("cuda", 0) + vm = relax.VirtualMachine(ex, dev) + return vm["quantize"](weight) + + def _quantize( # pylint: disable=too-many-locals + self, weight: te.Tensor + ) -> Tuple[te.Tensor, te.Tensor]: + """Group quantization for weight tensor, defined in tensor expression.""" + assert len(weight.shape) == 2 + n, k = weight.shape # pylint: disable=invalid-name + quantize_dtype = DataType(self.quantize_dtype) + # compute scale per group + r = te.reduce_axis((0, self.group_size), name="r") # pylint: disable=invalid-name + num_group = tir.ceildiv(k, self.group_size) + scale_shape = (n, num_group) + max_abs = te.compute( + shape=scale_shape, + fcompute=lambda i, j: te.max( + te.abs(weight[i, j * self.group_size + r]), + where=j * self.group_size + r < k, + axis=r, + ), + name="max_abs_value", + ) + scale = te.compute( + scale_shape, + lambda i, j: max_abs[i, j] / tir.const(self.max_int_value, self.model_dtype), + name="scale", + ) + + # compute scaled weight + tir_max_int = tir.const(self.max_int_value, self.model_dtype) + tir_zero = tir.const(0, self.model_dtype) + tir_max_int_2 = tir.const(self.max_int_value * 2, self.model_dtype) + scaled_weight = te.compute( + shape=weight.shape, + fcompute=lambda i, j: tir.min( + tir.max( + tir.round(weight[i, j] / scale[i, j // self.group_size] + tir_max_int), + tir_zero, + ), + tir_max_int_2, + ).astype(self.storage_dtype), + ) + + # compute quantized weight per storage + r = te.reduce_axis((0, self.num_elem_per_storage), name="r") + num_storage = self.num_storage_per_group * num_group + quantized_weight_shape = (n, num_storage) + quantized_weight = te.compute( + shape=quantized_weight_shape, + fcompute=lambda i, j: tir.sum( + scaled_weight[i, j * self.num_elem_per_storage + r] << (r * quantize_dtype.bits), + axis=r, + where=j * self.num_elem_per_storage + r < k, + ), + name="weight", + ) + return quantized_weight, scale + + 
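# --- Illustration only (not part of the patch): a minimal numpy sketch of what the
# --- _quantize/_dequantize kernels above compute for the "q4f16_1" setting
# --- (group_size=32, int4 codes packed into uint32 storage, float16 weights).
# --- All names below are hypothetical and exist only to make the TE math easier to follow.
import numpy as np

group_size, max_int, elem_per_storage = 32, 7, 8           # 2**(4-1)-1; 32 bits / 4 bits per code
w = np.random.uniform(-1, 1, (4, 64)).astype("float16")    # (n, k)

w_grouped = w.reshape(w.shape[0], -1, group_size)                       # (n, k/32, 32)
scale = np.maximum(np.abs(w_grouped).max(axis=-1), 1e-4) / max_int      # per-group scale
q = np.clip(np.round(w_grouped / scale[..., None]) + max_int, 0, 2 * max_int)
q = q.astype("uint32").reshape(w.shape[0], -1, elem_per_storage)        # (n, k/8, 8)
shifts = np.arange(elem_per_storage, dtype="uint32") * 4                # 4 bits per code
packed = (q << shifts).sum(axis=-1, dtype="uint32")                     # q_weight: (n, k/8)

# Dequantize (mirrors _dequantize): unpack the 4-bit codes, recenter, and rescale.
codes = ((packed[..., None] >> shifts) & 0xF).reshape(w.shape)
w_hat = (codes.astype("float16") - max_int) * np.repeat(scale.astype("float16"), group_size, axis=1)
# w_hat now approximates w to within the 4-bit quantization error.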
+class GroupQuantizeLinear(nn.Module): + """An nn.Linear module with group quantization""" + + def __init__( # pylint: disable=too-many-arguments + self, + in_features: int, + out_features: int, + config: GroupQuantize, + bias: bool = True, + out_dtype: Optional[str] = None, + ) -> None: + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.out_dtype = out_dtype + self.config = config + n_group = tir.ceildiv(in_features, config.group_size) + self.weight = nn.Parameter( + (out_features, n_group * config.num_elem_per_storage), + config.storage_dtype, + ) + self.scale = nn.Parameter((out_features, n_group), config.model_dtype) + if bias: + self.bias = nn.Parameter((out_features,), config.model_dtype) + else: + self.bias = None + + @staticmethod + def from_linear(linear: nn.Linear, config: GroupQuantize): + """Converts a non-quantized nn.Linear to a quantized GroupQuantizeLinear""" + return GroupQuantizeLinear( + in_features=linear.in_features, + out_features=linear.out_features, + config=config, + bias=getattr(linear, "bias", None) is not None, + out_dtype=linear.out_dtype, + ) + + def forward(self, x: nn.Tensor) -> nn.Tensor: # pylint: disable=invalid-name,missing-docstring + w = nn.op.tensor_expr_op( # pylint: disable=invalid-name + lambda weight, scale: self.config._dequantize( # pylint: disable=protected-access + weight, + scale, + [ + tir.IntImm("int64", self.out_features), + tir.IntImm("int64", self.in_features), + ], + ), + name_hint="decode", + args=[self.weight, self.scale], + ) + w = nn.op.permute_dims(w) # pylint: disable=invalid-name + x = nn.op.matmul(x, w, out_dtype=self.out_dtype) + if self.bias is not None: + x = x + self.bias + return x diff --git a/python/mlc_chat/compiler/quantization/group_quantizer.py b/python/mlc_chat/compiler/quantization/group_quantizer.py deleted file mode 100644 index b95c946abd..0000000000 --- a/python/mlc_chat/compiler/quantization/group_quantizer.py +++ /dev/null @@ -1,70 +0,0 @@ -"""A group quantizer for on the fly parameter quantization""" -# pylint: disable=too-few-public-methods - -from typing import List, Tuple - -from tvm import te, tir - -from .quantization import QuantizeConfig - - -def te_quantize( - weight: te.Tensor, config: QuantizeConfig -) -> Tuple[te.Tensor, te.Tensor, List[te.Tensor]]: - """Group quantization for weight tensor, defined in tensor expression.""" - # pylint: disable=too-many-locals - assert len(weight.shape) == 2 - n, m = weight.shape # pylint: disable=invalid-name - # compute scale per group - r = te.reduce_axis((0, config.group_size), name="r") # pylint: disable=invalid-name - num_group = tir.ceildiv(m, config.group_size) - scale_shape = (n, num_group) - max_abs = te.compute( - shape=scale_shape, - fcompute=lambda i, j: te.max( - tir.if_then_else( - j * config.group_size + r < weight.shape[1], - te.abs(weight[i, j * config.group_size + r]), - tir.const(1e-4, config.weight_dtype), - ), - axis=r, - ), - name="max_abs_value", - ) - scale = te.compute( - (n, m), - lambda i, j: max_abs[i, j] / tir.const(config.max_int_value, dtype=config.weight_dtype), - name="scale", - ) - - # compute scaled weight - tir_max_int = tir.const(config.max_int_value, config.weight_dtype) - tir_zero = tir.const(0, config.weight_dtype) - tir_max_int_2 = tir.const(config.max_int_value * 2, config.weight_dtype) - scaled_weight = te.compute( - shape=weight.shape, - fcompute=lambda i, j: tir.min( - tir.max( - tir.round(weight[i, j] / scale[i, j // config.group_size] + tir_max_int), - tir_zero, - ), - 
tir_max_int_2, - ).astype(config.storage_dtype), - ) - - # compute quantized weight per storage - r = te.reduce_axis((0, config.num_elem_per_storage), name="r") # pylint: disable=invalid-name - num_storage = config.num_storage_per_group * num_group - quantized_weight_shape = (n, num_storage) - quantized_weight = te.compute( - shape=quantized_weight_shape, - fcompute=lambda i, j: tir.sum( - scaled_weight[i, j * config.num_elem_per_storage + r] - << (r * config.quantize_dtype_bits), - axis=r, - where=j * config.num_elem_per_storage + r < m, - ), - name="weight", - ) - return quantized_weight, scale, [max_abs, scaled_weight] - # pylint: enable=too-many-locals diff --git a/python/mlc_chat/compiler/quantization/quantization.py b/python/mlc_chat/compiler/quantization/quantization.py index c1ba794063..2efad4beb4 100644 --- a/python/mlc_chat/compiler/quantization/quantization.py +++ b/python/mlc_chat/compiler/quantization/quantization.py @@ -1,22 +1,34 @@ """A centralized registry of all existing quantization methods and their configurations.""" from typing import Any, Dict -QuantizeConfig = Any -"""A QuantizeConfig is an object that represents an quantization algorithm. It is required to +from .group_quantization import GroupQuantize + +Quantization = Any +"""Quantization is an object that represents an quantization algorithm. It is required to have the following fields: name : str The name of the quantization algorithm, for example, "q4f16_1". kind : str - The kind of quantization algorithm, for example, "group_quant", "faster_transformer". + The kind of quantization algorithm, for example, "group-quant", "faster-transformer". It is also required to have the following method: - def quantize(self, module: nn.Module) -> nn.Module: + def quantize_model(self, module: nn.Module) -> nn.Module: + ... + + def quantize_weight(self, weight: tvm.runtime.NDArray) -> List[tvm.runtime.NDArray]: ... 
""" -QUANT: Dict[str, QuantizeConfig] = { - "q4f16_1": None, +QUANTIZATION: Dict[str, Quantization] = { + "q4f16_1": GroupQuantize( + name="q4f16_1", + kind="group-quant", + group_size=32, + quantize_dtype="int4", + storage_dtype="uint32", + model_dtype="float16", + ), } diff --git a/tests/python/parameter/test_group_quantizer.py b/tests/python/parameter/test_group_quantizer.py deleted file mode 100644 index 4c16548b64..0000000000 --- a/tests/python/parameter/test_group_quantizer.py +++ /dev/null @@ -1,157 +0,0 @@ -# pylint: disable=missing-docstring,too-many-instance-attributes -import logging -from dataclasses import dataclass -from pathlib import Path -from typing import TYPE_CHECKING, Tuple, Union - -import numpy as np -import tvm -from mlc_chat.compiler import MODELS -from mlc_chat.compiler.model.llama_config import LlamaConfig -from mlc_chat.compiler.model.llama_quantization import huggingface_group_quantize -from mlc_chat.compiler.parameter import HuggingFaceLoader -from mlc_chat.support import tqdm -from tvm.runtime import NDArray - -if TYPE_CHECKING: - from tvm.relax.frontend import nn - -logging.basicConfig( - level=logging.DEBUG, - style="{", - datefmt="%Y-%m-%d %H:%M:%S", - format="[{asctime}] {levelname} {filename}:{lineno}: {message}", -) - - -def test_load_torch_llama_group_quantize(base_path: Union[str, Path], target: str = "llvm"): - @dataclass - class TestGroupQuantizeConfig: - name: str = "q4f16_1" - kind: str = "group_quantize" - group_size: int = 32 - weight_dtype: str = "float16" - max_int_value: int = 7 - storage_dtype: str = "uint32" - num_elem_per_storage: int = 8 - num_storage_per_group: int = 4 - quantize_dtype_bits: int = 4 - - def quantize(self, _: "nn.Module") -> "nn.Module": - raise NotImplementedError - - base_path = Path(base_path) - path_config = base_path / "config.json" - path_params = base_path / "pytorch_model.bin.index.json" - - model = MODELS["llama"] - model_config = LlamaConfig.from_file(path_config) - quantize_config = TestGroupQuantizeConfig() - loader = HuggingFaceLoader( - path=path_params, - extern_param_map=model.source["huggingface-torch"](model_config, None), - quantize_param_map=huggingface_group_quantize( - model_config, - quantize_config, - target=tvm.target.Target(target), - ), - ) - with tqdm.redirect(): - for _name, _param in loader.load(): - ... 
- - -def test_group_quantize_vs_numpy(): - bits = { - "int4": 4, - "int8": 8, - "fp16": 16, - "fp32": 32, - "int32": 32, - "uint32": 32, - } - - # pylint: disable=unused-variable - def group_quantize_np( - w: NDArray, # pylint: disable=invalid-name - quantize_dtype: str = "int4", - storage_dtype: str = "uint32", - group_size: int = 32, - # symmetric: bool = True, - # transpose: bool = False, - ) -> Tuple[NDArray, NDArray]: - # pylint: disable=too-many-locals - def _pad_axis_by_factor(tensor: np.ndarray, axis: int, factor: int) -> np.ndarray: - dim = int(tensor.shape[axis]) - if dim % factor == 0: - return tensor - pad_width = [[0, 0] for i in tensor.shape] - pad_width[axis][1] = factor - (dim % factor) - return np.pad(tensor, pad_width, mode="constant", constant_values=0) - - def _clip( - x: np.ndarray, # pylint: disable=invalid-name - x_min: int, - x_max: int, - dtype: str, - ) -> np.ndarray: - return np.clip(x, a_min=x_min, a_max=x_max).astype(dtype) - - num_elem_per_storage = bits[storage_dtype] // bits[quantize_dtype] - assert group_size % num_elem_per_storage == 0 - num_storage_units = (group_size + num_elem_per_storage - 1) // num_elem_per_storage - - # using numpy for now - w = w.numpy() - - # Step 1. Tile `w`: [n, k'] -> [n, k, group_size] - w = _pad_axis_by_factor(w, axis=1, factor=group_size) - n, k = [int(v) for v in w.shape] # pylint: disable=invalid-name - assert k % group_size == 0, "Padding is not working properly" - k = k // group_size - w = w.reshape([n, k, group_size]) - - # Step 2. Calculate - if quantize_dtype.startswith("int"): - max_int_value = (2 ** (bits[quantize_dtype] - 1)) - 1 - # 1) `scale`: [n, k, group_size] -> [n, k] - scale = np.maximum(np.amax(w, axis=-1), 1e-4) / max_int_value - # 2) `w`: w / scale - - w = _clip( - np.round(w / scale[:, :, np.newaxis]).astype("int") + max_int_value, - x_min=0, - x_max=max_int_value * 2, - dtype=storage_dtype, - ) - else: - raise NotImplementedError - - # Step 3. 
Compress `w` to every `num_elem_per_storage` elements - res = np.zeros((n, k, num_storage_units), dtype=np.uint32) - for i in range(n): - for j in range(k): - for m in range(num_storage_units): # pylint: disable=invalid-name - for k in range(num_elem_per_storage): - res[i, j, m] += w[i, j, m * num_elem_per_storage + k] * 2**k - return tvm.nd.array(res), tvm.nd.array(scale) - # pylint: enable=too-many-locals - - -if __name__ == "__main__": - test_load_torch_llama_group_quantize( - base_path="./dist/models/Llama-2-7b-hf", - target="llvm", - ) - test_load_torch_llama_group_quantize( - base_path="./dist/models/Llama-2-7b-hf", - target="nvidia/nvidia-a100", - ) - test_load_torch_llama_group_quantize( - base_path="./dist/models/Llama-2-13b-hf", - target="llvm", - ) - test_load_torch_llama_group_quantize( - base_path="./dist/models/Llama-2-13b-hf", - target="nvidia/nvidia-a100", - ) diff --git a/tests/python/quantization/test_group_quantization.py b/tests/python/quantization/test_group_quantization.py new file mode 100644 index 0000000000..75e754e147 --- /dev/null +++ b/tests/python/quantization/test_group_quantization.py @@ -0,0 +1,85 @@ +# pylint: disable=missing-docstring +from typing import List + +import numpy as np +import tvm +import tvm.testing +from mlc_chat.compiler import QUANTIZATION +from mlc_chat.compiler.quantization import GroupQuantize +from tvm import DataType + + +def quantize_np(config: GroupQuantize, weight: np.ndarray): + n, k = weight.shape + weight_padded = np.pad( + weight, ((0, 0), (0, (config.group_size - k % config.group_size) % config.group_size)) + ) + n, k = weight_padded.shape + weight_reshaped = np.reshape(weight_padded, (n, k // config.group_size, config.group_size)) + max_abs = np.maximum(np.max(np.abs(weight_reshaped), axis=-1), 1e-4) + scale = np.divide(max_abs, config.max_int_value) + scale_reshaped = np.reshape(scale, (*scale.shape, 1)) + weight_scaled_reshaped = np.clip( + np.add( + np.round(np.divide(weight_reshaped, scale_reshaped)), + config.max_int_value, + ), + 0, + config.max_int_value * 2, + ).astype(config.storage_dtype) + weight_scaled = np.reshape( + weight_scaled_reshaped, (n, k // config.num_elem_per_storage, config.num_elem_per_storage) + ) + indice_k = np.indices(weight_scaled.shape, dtype=config.storage_dtype)[-1] + quantized_weight = np.sum( + np.left_shift(weight_scaled, indice_k * DataType(config.quantize_dtype).bits), + axis=-1, + dtype=config.storage_dtype, + ) + return quantized_weight, scale + + +def dequantize_np( + config: GroupQuantize, + weight: np.ndarray, + scale: np.ndarray, + out_shape: List[int] = None, +): + bin_mask = (1 << DataType(config.quantize_dtype).bits) - 1 + max_int = config.max_int_value + out_shape = ( + [weight.shape[0], weight.shape[1] * config.num_elem_per_storage] + if out_shape is None + else out_shape + ) + weight_repeated = np.repeat(weight, config.num_elem_per_storage, axis=-1) + scale_repeated = np.repeat(scale, config.group_size, axis=-1) + indice_j = np.indices(weight_repeated.shape)[1] + weight_bin = np.bitwise_and( + np.right_shift( + weight_repeated, + (indice_j % config.num_elem_per_storage) * DataType(config.storage_dtype).bits, + ), + bin_mask, + ) + return ((weight_bin - max_int) * scale_repeated)[: out_shape[0]][: out_shape[1]] + + +def test_quantize(quant_name: str, shape: List[int], dtype: str): + config = QUANTIZATION[quant_name] + assert isinstance(config, GroupQuantize) + weight_np = np.random.random(shape).astype(dtype) + output = config.quantize_weight(tvm.nd.array(weight_np, 
device=tvm.device("cuda"))) + quantized_weight, scale = output[0].numpy(), output[1].numpy() + quantized_weight_ref, scale_ref = quantize_np(config, weight_np) + tvm.testing.assert_allclose(scale, scale_ref, rtol=1e-3, atol=1e-3) + tvm.testing.assert_allclose( + dequantize_np(config, quantized_weight, scale, shape), + dequantize_np(config, quantized_weight_ref, scale_ref, shape), + rtol=1e-3, + atol=0.2, + ) + + +if __name__ == "__main__": + test_quantize("q4f16_1", [64, 4096], "float16") From 8438b2777878510ed865d6a79ef41d3db94942d0 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Tue, 31 Oct 2023 12:09:16 -0700 Subject: [PATCH 075/116] Misc Cleanups of Compilation Pipeline (#1165) --- python/mlc_chat/cli/compile.py | 2 +- python/mlc_chat/compiler/compile.py | 8 +++++++ .../compiler_pass/clean_up_tir_attrs.py | 3 +-- .../compiler_pass/fuse_decode_matmul_ewise.py | 23 ++++++++++-------- .../compiler_pass/fuse_decode_take.py | 24 +++++++++++-------- .../compiler_pass/fuse_decode_transpose.py | 2 +- .../compiler_pass/fuse_transpose_matmul.py | 4 +--- .../compiler_pass/lift_global_buffer_alloc.py | 13 +++++----- .../compiler/compiler_pass/pipeline.py | 23 ++++++++++++++++++ 9 files changed, 68 insertions(+), 34 deletions(-) diff --git a/python/mlc_chat/cli/compile.py b/python/mlc_chat/cli/compile.py index e3ff778487..d4c648c097 100644 --- a/python/mlc_chat/cli/compile.py +++ b/python/mlc_chat/cli/compile.py @@ -15,7 +15,7 @@ from ..support.auto_target import detect_target_and_host logging.basicConfig( - level=logging.DEBUG, + level=logging.INFO, style="{", datefmt="%Y-%m-%d %H:%M:%S", format="[{asctime}] {levelname} {filename}:{lineno}: {message}", diff --git a/python/mlc_chat/compiler/compile.py b/python/mlc_chat/compiler/compile.py index 88f33b03af..6cb204b000 100644 --- a/python/mlc_chat/compiler/compile.py +++ b/python/mlc_chat/compiler/compile.py @@ -1,5 +1,6 @@ """Python entrypoint of compilation.""" import dataclasses +import logging from io import StringIO from pathlib import Path from typing import Callable @@ -12,6 +13,8 @@ from .model import Model from .quantization import Quantization +logger = logging.getLogger(__name__) + @dataclasses.dataclass class CompileArgs: # pylint: disable=too-many-instance-attributes @@ -40,15 +43,20 @@ def _echo_args(args: CompileArgs) -> None: def _compile(args: CompileArgs): + logger.info("Creating model from: %s", args.config) model_config = args.model.config.from_file(args.config) quantization = args.quantization model, _ = args.model.quantize[quantization.kind](model_config, quantization) + logger.info("Exporting the model to TVM Unity compiler") mod, _named_params = model.export_tvm( spec=model.get_default_spec(), # type: ignore ) + logger.info("Running optimizations using TVM Unity") with args.target: mod = relax.get_pipeline("mlc_llm")(mod) + logger.info("Generating code using TVM Unity") args.build_func(mod, args) + logger.info("Code dumped to: %s", args.output) def compile( # pylint: disable=too-many-arguments,redefined-builtin diff --git a/python/mlc_chat/compiler/compiler_pass/clean_up_tir_attrs.py b/python/mlc_chat/compiler/compiler_pass/clean_up_tir_attrs.py index 71848ba546..f7c9ad2f48 100644 --- a/python/mlc_chat/compiler/compiler_pass/clean_up_tir_attrs.py +++ b/python/mlc_chat/compiler/compiler_pass/clean_up_tir_attrs.py @@ -18,8 +18,7 @@ def transform_module( _ctx: tvm.transform.PassContext, ) -> IRModule: """IRModule-level transformation""" - for g_var in list(mod.functions): - func = mod[g_var] + for g_var, func in 
mod.functions_items(): changed = False for attr in self.attrs: if func.attrs is not None and attr in func.attrs: diff --git a/python/mlc_chat/compiler/compiler_pass/fuse_decode_matmul_ewise.py b/python/mlc_chat/compiler/compiler_pass/fuse_decode_matmul_ewise.py index 0e02f2ae5a..ddc71818ff 100644 --- a/python/mlc_chat/compiler/compiler_pass/fuse_decode_matmul_ewise.py +++ b/python/mlc_chat/compiler/compiler_pass/fuse_decode_matmul_ewise.py @@ -14,20 +14,23 @@ def transform_module( _ctx: tvm.transform.PassContext, ) -> IRModule: """IRModule-level transformation""" + seq = [] for n_aux_tensor in [1, 2, 3, 4]: for match_ewise in [0, 1, 2, 6]: if match_ewise == 6 and n_aux_tensor != 4: continue - mod = relax.transform.FuseOpsByPattern( - [ - ( - "decode_matmul", - *_pattern(match_ewise, n_aux_tensor), - ) - ] - )(mod) - mod = relax.transform.FuseTIR()(mod) - return mod + seq.append( + relax.transform.FuseOpsByPattern( + [ + ( + "decode_matmul", + *_pattern(match_ewise, n_aux_tensor), + ) + ] + ) + ) + seq.append(relax.transform.FuseTIR()) + return tvm.transform.Sequential(seq)(mod) def _pattern(match_ewise: int, n_aux_tensor: int): diff --git a/python/mlc_chat/compiler/compiler_pass/fuse_decode_take.py b/python/mlc_chat/compiler/compiler_pass/fuse_decode_take.py index 96678fa951..9468f4f425 100644 --- a/python/mlc_chat/compiler/compiler_pass/fuse_decode_take.py +++ b/python/mlc_chat/compiler/compiler_pass/fuse_decode_take.py @@ -20,18 +20,22 @@ def transform_module( _ctx: tvm.transform.PassContext, ) -> IRModule: """IRModule-level transformation""" + seq = [] for n_aux_tensor in [2, 3]: for match_tir_vars in [False, True]: - mod = relax.transform.FuseOpsByPattern( - [ - ( - "decode_take", - *_pattern(n_aux_tensor, match_tir_vars), - ) - ] - )(mod) - mod = relax.transform.FuseTIR()(mod) - for g_var, func in mod.functions.items(): + seq.append( + relax.transform.FuseOpsByPattern( + [ + ( + "decode_take", + *_pattern(n_aux_tensor, match_tir_vars), + ) + ] + ) + ) + seq.append(relax.transform.FuseTIR()) + mod = tvm.transform.Sequential(seq)(mod) + for g_var, func in mod.functions_items(): name = g_var.name_hint if isinstance(func, tir.PrimFunc) and (("fused_decode" in name) and ("take" in name)): mod = tvm.IRModule({"main": func}) diff --git a/python/mlc_chat/compiler/compiler_pass/fuse_decode_transpose.py b/python/mlc_chat/compiler/compiler_pass/fuse_decode_transpose.py index 99bcb1b602..e2a826a1fb 100644 --- a/python/mlc_chat/compiler/compiler_pass/fuse_decode_transpose.py +++ b/python/mlc_chat/compiler/compiler_pass/fuse_decode_transpose.py @@ -31,7 +31,7 @@ def __init__( def transform(self) -> IRModule: """Entry point""" - for g_var, func in self.mod.functions.items(): + for g_var, func in self.mod.functions_items(): if isinstance(func, relax.Function): updated_func = self.visit_expr(func) updated_func = remove_all_unused(updated_func) diff --git a/python/mlc_chat/compiler/compiler_pass/fuse_transpose_matmul.py b/python/mlc_chat/compiler/compiler_pass/fuse_transpose_matmul.py index ac1de41377..5b3ecec860 100644 --- a/python/mlc_chat/compiler/compiler_pass/fuse_transpose_matmul.py +++ b/python/mlc_chat/compiler/compiler_pass/fuse_transpose_matmul.py @@ -19,10 +19,8 @@ def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IR ), ] )(mod) - transpose_matmul_codegen = _TransposeMatmulFuser(mod) - for g_var in mod.functions: - func = mod[g_var] + for g_var, func in mod.functions_items(): if isinstance(func, relax.Function): func = transpose_matmul_codegen.visit_expr(func) 
transpose_matmul_codegen.builder_.update_func(g_var, func) diff --git a/python/mlc_chat/compiler/compiler_pass/lift_global_buffer_alloc.py b/python/mlc_chat/compiler/compiler_pass/lift_global_buffer_alloc.py index dc8eaa5bdc..ebf8f27acf 100644 --- a/python/mlc_chat/compiler/compiler_pass/lift_global_buffer_alloc.py +++ b/python/mlc_chat/compiler/compiler_pass/lift_global_buffer_alloc.py @@ -32,7 +32,7 @@ def __init__(self, mod: IRModule): def transform(self) -> IRModule: """Entry point of the transformation""" - for g_var, func in self.mod.functions.items(): + for g_var, func in self.mod.functions_items(): if isinstance(func, tir.PrimFunc): updated_func, tensor_sinfo_list = remove_global_buf_alloc(func) if len(tensor_sinfo_list) > 0: @@ -40,12 +40,11 @@ def transform(self) -> IRModule: self.builder_.update_func(g_var, updated_func) self.mod = self.builder_.get() - for g_var, func in self.mod.functions.items(): - if not isinstance(func, relax.Function): - continue - updated_func = self.visit_expr(func) - updated_func = remove_all_unused(updated_func) - self.builder_.update_func(g_var, updated_func) + for g_var, func in self.mod.functions_items(): + if isinstance(func, relax.Function): + updated_func = self.visit_expr(func) + updated_func = remove_all_unused(updated_func) + self.builder_.update_func(g_var, updated_func) return self.builder_.get() def visit_call_(self, call: relax.Call): # pylint: disable=arguments-renamed diff --git a/python/mlc_chat/compiler/compiler_pass/pipeline.py b/python/mlc_chat/compiler/compiler_pass/pipeline.py index 349a5af0f0..43fc8f131c 100644 --- a/python/mlc_chat/compiler/compiler_pass/pipeline.py +++ b/python/mlc_chat/compiler/compiler_pass/pipeline.py @@ -1,5 +1,8 @@ """The compilation pipeline for LLM applications.""" +import logging + import tvm +from tvm import IRModule from tvm import dlight as dl from tvm.relax import register_pipeline # pylint: disable=no-name-in-module @@ -10,6 +13,21 @@ from .fuse_transpose_matmul import FuseTransposeMatmul from .lift_global_buffer_alloc import LiftTIRGlobalBufferAlloc +logger = logging.getLogger(__name__) + + +@tvm.transform.module_pass(opt_level=0, name="_LogProgress") +class _LogProgress: # pylint: disable=too-few-public-methods + """A dummy compiler pass that does nothing but logging.""" + + def __init__(self, *args): + self.args = args + + def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule: + """A dummy transformation""" + logger.info(*self.args) + return mod + @register_pipeline("mlc_llm") def _mlc_llm_pipeline(): @@ -18,20 +36,24 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.I seq = tvm.transform.Sequential( [ # Phase 1. Passes on high-level operator graph + _LogProgress("Running TVM Relax graph-level optimizations"), FuseDecodeTranspose(skip_gemm=False), FuseTransposeMatmul(), # Phase 2. Lowering to TIR, inherited TVM Relax's official "zero" pipeline + _LogProgress("Lowering to TVM TIR kernels"), tvm.relax.transform.LegalizeOps(), tvm.relax.transform.AnnotateTIROpPattern(), tvm.relax.transform.FoldConstant(), tvm.relax.transform.FuseOps(), tvm.relax.transform.FuseTIR(), # Phase 3. Passes on TIR + _LogProgress("Running TVM TIR-level optimizations"), FuseDecodeMatmulEwise(), FuseDecodeTake(), tvm.relax.transform.DeadCodeElimination(), CleanUpTIRAttrs(["op_pattern"]), # Phase 4. 
Low-level Optimizations + _LogProgress("Running TVM Dlight low-level optimizations"), dl.ApplyDefaultSchedule( dl.gpu.Matmul(), dl.gpu.GEMV(), @@ -39,6 +61,7 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.I dl.gpu.GeneralReduction(), dl.gpu.Fallback(), ), + _LogProgress("Running memory optimizations"), LiftTIRGlobalBufferAlloc(), tvm.tir.transform.ForceNarrowIndexToInt32(), ] From 02d1e57404b3362362f1ec44e48558e5af637d19 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Tue, 31 Oct 2023 12:43:17 -0700 Subject: [PATCH 076/116] Support CUDA Multi-Arch Compilation (#1166) --- python/mlc_chat/compiler/compile.py | 2 +- python/mlc_chat/support/auto_target.py | 36 +++++++++++++++++++++++++- 2 files changed, 36 insertions(+), 2 deletions(-) diff --git a/python/mlc_chat/compiler/compile.py b/python/mlc_chat/compiler/compile.py index 6cb204b000..02842a1903 100644 --- a/python/mlc_chat/compiler/compile.py +++ b/python/mlc_chat/compiler/compile.py @@ -56,7 +56,7 @@ def _compile(args: CompileArgs): mod = relax.get_pipeline("mlc_llm")(mod) logger.info("Generating code using TVM Unity") args.build_func(mod, args) - logger.info("Code dumped to: %s", args.output) + logger.info("Code dumped to: %s", bold(str(args.output))) def compile( # pylint: disable=too-many-arguments,redefined-builtin diff --git a/python/mlc_chat/support/auto_target.py b/python/mlc_chat/support/auto_target.py index f31e813410..29328f8813 100644 --- a/python/mlc_chat/support/auto_target.py +++ b/python/mlc_chat/support/auto_target.py @@ -1,5 +1,6 @@ """Helper functioms for target auto-detection.""" import logging +import os from typing import TYPE_CHECKING, Callable, Optional, Tuple from tvm import IRModule, relax @@ -7,7 +8,7 @@ from tvm.contrib import tar, xcode from tvm.target import Target -from .style import green, red +from .style import bold, green, red if TYPE_CHECKING: from mlc_chat.compiler.compile import CompileArgs @@ -38,6 +39,8 @@ def detect_target_and_host(target_hint: str, host_hint: str) -> Tuple[Target, Bu target, build_func = _detect_target_gpu(target_hint) if target.host is None: target = Target(target, host=_detect_target_host(host_hint)) + if target.kind.name == "cuda": + _register_cuda_hook(target) return target, build_func @@ -223,6 +226,37 @@ def build(mod: IRModule, args: "CompileArgs"): return build +def _register_cuda_hook(target: Target): + env_multi_arch = os.environ.get("MLC_MULTI_ARCH", None) + if env_multi_arch is None: + default_arch = target.attrs.get("arch", None) + logger.info("Generating code for CUDA architecture: %s", bold(default_arch)) + logger.info( + "To produce multi-arch fatbin, set environment variable %s. 
" + "Example: MLC_MULTI_ARCH=70,72,75,80,86,87,89,90", + bold("MLC_MULTI_ARCH"), + ) + multi_arch = None + else: + logger.info("%s %s: %s", FOUND, bold("MLC_MULTI_ARCH"), env_multi_arch) + multi_arch = [int(x.strip()) for x in env_multi_arch.split(",")] + logger.info("Generating code for CUDA architecture: %s", multi_arch) + + @register_func("tvm_callback_cuda_compile", override=True) + def tvm_callback_cuda_compile(code, target): # pylint: disable=unused-argument + """use nvcc to generate fatbin code for better optimization""" + from tvm.contrib import nvcc # pylint: disable=import-outside-toplevel + + if multi_arch is None: + ptx = nvcc.compile_cuda(code, target_format="fatbin") + else: + arch = [] + for compute_version in multi_arch: + arch += ["-gencode", f"arch=compute_{compute_version},code=sm_{compute_version}"] + ptx = nvcc.compile_cuda(code, target_format="fatbin", arch=arch) + return ptx + + AUTO_DETECT_DEVICES = ["cuda", "rocm", "metal", "vulkan"] PRESET = { From e0cd3f6ef919f53f69b235c6542e7154797da0fb Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Tue, 31 Oct 2023 12:56:28 -0700 Subject: [PATCH 077/116] [Bugfix] Cannot find global function `mlc.llm_chat_create` (#1167) --- python/mlc_chat/chat_module.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/mlc_chat/chat_module.py b/python/mlc_chat/chat_module.py index 058557c182..9f306c14b6 100644 --- a/python/mlc_chat/chat_module.py +++ b/python/mlc_chat/chat_module.py @@ -13,6 +13,7 @@ import tvm from tvm.runtime import disco # pylint: disable=unused-import +from .base import _LIB # pylint: disable=unused-import from .interface.openai_api import ChatMessage # pylint: disable=line-too-long From f5b2e885c83add48c68ebf370c2b0a9459c574f7 Mon Sep 17 00:00:00 2001 From: Xiaoyu Zhang <35585791+BBuf@users.noreply.github.com> Date: Wed, 1 Nov 2023 12:23:40 +0800 Subject: [PATCH 078/116] Fix RWKV Support (#1136) I successfully ran the rwkv-world-3b fp16 model on my Xiaomi phone. This PR is to fix a bug on the main branch where the rwkv model outputs only one word and then stop. 
![image](https://github.com/mlc-ai/mlc-llm/assets/35585791/6514d6ef-c93c-4ad2-8e76-8ffa0663080f) --- cpp/llm_chat.cc | 5 ++++- mlc_llm/core.py | 7 ++++++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/cpp/llm_chat.cc b/cpp/llm_chat.cc index 35a8d1f41e..f8c7ef0986 100644 --- a/cpp/llm_chat.cc +++ b/cpp/llm_chat.cc @@ -1108,7 +1108,10 @@ class LLMChat { if (static_cast(output_ids_.size()) >= gen_max_gen_len) { stop_triggered_ = true; - } else if (total_seq_len_ >= max_window_size_) { + } + // max_window_size_ != -1 to handle + // https://github.com/mlc-ai/mlc-llm/blob/main/mlc_llm/relax_model/rwkv.py#L588-L589 + else if (max_window_size_ != -1 && total_seq_len_ >= max_window_size_) { stop_triggered_ = true; } if (stop_triggered_) { diff --git a/mlc_llm/core.py b/mlc_llm/core.py index 6b993c07b5..d914364b9c 100644 --- a/mlc_llm/core.py +++ b/mlc_llm/core.py @@ -546,6 +546,7 @@ def dump_mlc_chat_config( mean_gen_len: int = 128, max_gen_len: int = 512, shift_fill_factor: float = 0.3, + rwkv_world=False, ): args.params_path = os.path.join(args.artifact_path, "params") config: Dict[str, Any] = {} @@ -567,7 +568,10 @@ def dump_mlc_chat_config( config["max_window_size"] = max_window_size config["num_shards"] = args.num_shards config["shift_fill_factor"] = shift_fill_factor - config["tokenizer_files"] = utils.get_tokenizer_files(args.params_path) + if rwkv_world: + config["tokenizer_files"] = ["tokenizer_model"] + else: + config["tokenizer_files"] = utils.get_tokenizer_files(args.params_path) config["model_category"] = args.model_category config["model_name"] = args.model config["vocab_size"] = vocab_size @@ -709,6 +713,7 @@ def build_model_from_args(args: argparse.Namespace): top_p=0.6, temperature=1.2, repetition_penalty=0.996, + rwkv_world=True, ) else: dump_mlc_chat_config( From 200653a82d025be7d58d0d7f04442f85aee52c98 Mon Sep 17 00:00:00 2001 From: Git bot Date: Wed, 1 Nov 2023 14:53:54 +0000 Subject: [PATCH 079/116] Auto updated submodule references --- 3rdparty/tvm | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/3rdparty/tvm b/3rdparty/tvm index 30b4fa3c13..3001b20b0d 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit 30b4fa3c13fc80d5c9151a9dc445d22c57ced3e0 +Subproject commit 3001b20b0dd114cad23fccb25cbb055ce80a224e From 9831135cd2df570dc376749f548e45e2ed98d75e Mon Sep 17 00:00:00 2001 From: Animesh Bohara Date: Wed, 1 Nov 2023 15:16:09 -0400 Subject: [PATCH 080/116] Fix Android app Permission denied error on Android 10 (#1175) Use scoped storage instead of Downloads directory Co-authored-by: Animesh Bohara --- .../MLCChat/app/src/main/java/ai/mlc/mlcchat/AppViewModel.kt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/android/MLCChat/app/src/main/java/ai/mlc/mlcchat/AppViewModel.kt b/android/MLCChat/app/src/main/java/ai/mlc/mlcchat/AppViewModel.kt index f51d56ec10..6a760efde3 100644 --- a/android/MLCChat/app/src/main/java/ai/mlc/mlcchat/AppViewModel.kt +++ b/android/MLCChat/app/src/main/java/ai/mlc/mlcchat/AppViewModel.kt @@ -171,7 +171,7 @@ class AppViewModel(application: Application) : AndroidViewModel(application) { val url = URL("${modelUrl}${ModelUrlSuffix}${ModelConfigFilename}") val tempId = UUID.randomUUID().toString() val tempFile = File( - Environment.getExternalStoragePublicDirectory(Environment.DIRECTORY_DOWNLOADS), + application.getExternalFilesDir(Environment.DIRECTORY_DOWNLOADS), tempId ) url.openStream().use { From 1757777591f931a29ad9490c7686cd8a5ec49788 Mon Sep 17 00:00:00 2001 From: 
Yaxing Cai Date: Wed, 1 Nov 2023 15:52:19 -0700 Subject: [PATCH 081/116] [SLM] Fix group quantization (#1172) This PR fixes the group quantization and add related unit tests. --- .../quantization/group_quantization.py | 324 ++++++++++++++++-- tests/python/model/test_llama_quantization.py | 32 ++ .../quantization/test_group_quantization.py | 77 ++++- 3 files changed, 399 insertions(+), 34 deletions(-) create mode 100644 tests/python/model/test_llama_quantization.py diff --git a/python/mlc_chat/compiler/quantization/group_quantization.py b/python/mlc_chat/compiler/quantization/group_quantization.py index 5bfaf084b2..9a5abed4ad 100644 --- a/python/mlc_chat/compiler/quantization/group_quantization.py +++ b/python/mlc_chat/compiler/quantization/group_quantization.py @@ -1,8 +1,9 @@ """The group quantization config""" -from dataclasses import dataclass -from typing import Any, List, Optional, Tuple +from dataclasses import dataclass, field +from typing import Any, Callable, Dict, List, Optional, Tuple -from tvm import DataType, DataTypeCode, device +import numpy as np +from tvm import DataType, DataTypeCode from tvm import dlight as dl from tvm import relax, te, tir from tvm.relax.frontend import nn @@ -27,6 +28,10 @@ class GroupQuantize: # pylint: disable=too-many-instance-attributes num_storage_per_group: int = 0 max_int_value: int = 0 + prebuilt_quantize_func: Dict[str, Callable[[NDArray], NDArray]] = field( + default_factory=lambda: {} + ) + def __post_init__(self): assert self.kind == "group-quant" quantize_dtype = DataType(self.quantize_dtype) @@ -50,7 +55,25 @@ def quantize_model( quant_map: QuantizeMapping, name_prefix: str, ) -> nn.Module: - """Quantize model with group quantization""" + """ + Quantize model with group quantization + + Parameters + ---------- + model : nn.Module + The non-quantized nn.Module. + + quant_map : QuantizeMapping + The quantize mapping with name mapping and func mapping. + + name_prefix : str + The name prefix for visited weight. + + Returns + ------- + ret : nn.Module + The quantized nn.Module. + """ class _Mutator(nn.Mutator): def __init__(self, config: GroupQuantize, quant_map: QuantizeMapping) -> None: @@ -59,11 +82,37 @@ def __init__(self, config: GroupQuantize, quant_map: QuantizeMapping) -> None: self.quant_map = quant_map def visit_module(self, name: str, node: nn.Module) -> Any: + """ + The visiting method for group quantization of nn.Module nodes. + + Parameters + ---------- + name : str + The name of the current node. + + node : nn.Module + The current node of nn.Module to mutate. + + Returns + ------ + ret_node: Any + The new node to replace current node. 
+ """ if isinstance(node, nn.Linear): weight_name = f"{name}.weight" self.quant_map.param_map[weight_name] = [f"{name}.q_weight", f"{name}.q_scale"] - # self.quant_map.map_func[weight_name] = self.config.quantize + self.quant_map.map_func[weight_name] = self.config.quantize_weight return GroupQuantizeLinear.from_linear(node, self.config) + if isinstance(node, nn.MultiLinear): + weight_name = f"{name}.weight" + self.quant_map.param_map[weight_name] = [f"{name}.q_weight", f"{name}.q_scale"] + self.quant_map.map_func[weight_name] = self.config.quantize_weight + return GroupQuantizeMultiLinear.from_multilinear(node, self.config) + if isinstance(node, nn.Embedding): + weight_name = f"{name}.weight" + self.quant_map.param_map[weight_name] = [f"{name}.q_weight", f"{name}.q_scale"] + self.quant_map.map_func[weight_name] = self.config.quantize_weight + return GroupQuantizeEmbedding.from_embedding(node, self.config) return self.visit(name, node) model.to(dtype=self.model_dtype) @@ -77,9 +126,7 @@ def _dequantize( scale: te.Tensor, out_shape: Optional[List[tir.PrimExpr]] = None, ): - quantize_dtype = DataType(self.quantize_dtype) - storage_dtype = DataType(self.storage_dtype) - tir_bin_mask = tir.const((2**quantize_dtype.bits) - 1, self.storage_dtype) + tir_bin_mask = tir.const((1 << DataType(self.quantize_dtype).bits) - 1, self.storage_dtype) tir_max_int = tir.const(self.max_int_value, self.model_dtype) dequantized_weight = te.compute( shape=[weight.shape[0], weight.shape[1] * self.num_elem_per_storage] @@ -90,7 +137,11 @@ def _dequantize( tir.bitwise_and( tir.shift_right( weight[i, j // self.num_elem_per_storage], - (j % self.num_elem_per_storage) * storage_dtype.bits, + tir.Cast( + self.storage_dtype, + (j % self.num_elem_per_storage) + * DataType(self.quantize_dtype).bits, + ), ), tir_bin_mask, ), @@ -102,24 +153,50 @@ def _dequantize( return dequantized_weight def quantize_weight(self, weight: NDArray) -> List[NDArray]: - """Quantize weight with group quantization""" + """ + Quantize weight with group quantization + + Parameters + ---------- + weight : NDArray + The original weight. + + Returns + ------ + ret: List[NDArray] + The list of group quantized weights. 
+ """ assert weight.dtype == self.model_dtype assert len(weight.shape) == 2 - bb = relax.BlockBuilder() + dev = weight.device + device_type = dev.MASK2STR[dev.device_type] + key = str((int(weight.shape[0]), int(weight.shape[1]), weight.dtype, device_type)) + if key in self.prebuilt_quantize_func: + return self.prebuilt_quantize_func[key](weight) + bb = relax.BlockBuilder() # pylint: disable=invalid-name weight_var = relax.Var("weight", relax.TensorStructInfo(weight.shape, self.model_dtype)) with bb.function(name="quantize", params=[weight_var]): with bb.dataflow(): - lv = bb.emit_te(self._quantize, weight_var) - gv = bb.emit_output(lv) + lv = bb.emit_te(self._quantize, weight_var) # pylint: disable=invalid-name + gv = bb.emit_output(lv) # pylint: disable=invalid-name bb.emit_func_output(gv) mod = bb.get() - with Target("cuda"): - mod = dl.ApplyDefaultSchedule( # pylint: disable=not-callable - dl.gpu.Reduction(), dl.gpu.GeneralReduction(), dl.gpu.Fallback() - )(mod) - ex = relax.build(mod, "cuda") - dev = device("cuda", 0) - vm = relax.VirtualMachine(ex, dev) + if device_type in ["cuda", "rocm", "metal", "vulkan"]: + target = Target.current() + if target is None: + target = Target.from_device(dev) + with target: + mod = dl.ApplyDefaultSchedule( # pylint: disable=not-callable + dl.gpu.Reduction(), dl.gpu.GeneralReduction(), dl.gpu.Fallback() + )(mod) + elif device_type == "cpu": + target = "llvm" + mod = relax.transform.LegalizeOps()(mod) + else: + raise NotImplementedError + ex = relax.build(mod, target) + vm = relax.VirtualMachine(ex, dev) # pylint: disable=invalid-name + self.prebuilt_quantize_func[key] = vm["quantize"] return vm["quantize"](weight) def _quantize( # pylint: disable=too-many-locals @@ -164,7 +241,7 @@ def _quantize( # pylint: disable=too-many-locals ) # compute quantized weight per storage - r = te.reduce_axis((0, self.num_elem_per_storage), name="r") + r = te.reduce_axis((0, self.num_elem_per_storage), name="r") # pylint: disable=invalid-name num_storage = self.num_storage_per_group * num_group quantized_weight_shape = (n, num_storage) quantized_weight = te.compute( @@ -195,20 +272,36 @@ def __init__( # pylint: disable=too-many-arguments self.out_features = out_features self.out_dtype = out_dtype self.config = config - n_group = tir.ceildiv(in_features, config.group_size) self.weight = nn.Parameter( - (out_features, n_group * config.num_elem_per_storage), + (out_features, tir.ceildiv(in_features, config.num_elem_per_storage)), config.storage_dtype, ) - self.scale = nn.Parameter((out_features, n_group), config.model_dtype) + self.scale = nn.Parameter( + (out_features, tir.ceildiv(in_features, config.group_size)), config.model_dtype + ) if bias: self.bias = nn.Parameter((out_features,), config.model_dtype) else: self.bias = None @staticmethod - def from_linear(linear: nn.Linear, config: GroupQuantize): - """Converts a non-quantized nn.Linear to a quantized GroupQuantizeLinear""" + def from_linear(linear: nn.Linear, config: GroupQuantize) -> "GroupQuantizeLinear": + """ + Converts a non-quantized nn.Linear to a group quantized GroupQuantizeLinear + + Parameters + ---------- + linear : nn.Linear + The non-quantized nn.Linear. + + config : GroupQuantize + The group quantization config. + + Returns + ------- + ret : GroupQuantizeLinear + The group quantized GroupQuantizeLinear layer. 
+ """ return GroupQuantizeLinear( in_features=linear.in_features, out_features=linear.out_features, @@ -217,21 +310,194 @@ def from_linear(linear: nn.Linear, config: GroupQuantize): out_dtype=linear.out_dtype, ) - def forward(self, x: nn.Tensor) -> nn.Tensor: # pylint: disable=invalid-name,missing-docstring + def forward(self, x: nn.Tensor) -> nn.Tensor: # pylint: disable=invalid-name + """ + Forward method for group quantized linear layer. + + Parameters + ---------- + x : nn.Tensor + The input tensor. + + Returns + ------- + ret : nn.Tensor + The output tensor for the group quantized linear layer. + """ + w = nn.op.tensor_expr_op( # pylint: disable=invalid-name + lambda weight, scale: self.config._dequantize( # pylint: disable=protected-access + weight, + scale, + [tir.IntImm("int64", self.out_features), tir.IntImm("int64", self.in_features)], + ), + name_hint="decode", + args=[self.weight, self.scale], + ) + w = nn.op.permute_dims(w) # pylint: disable=invalid-name + x = nn.op.matmul(x, w, out_dtype=self.out_dtype) + if self.bias is not None: + x = x + self.bias + return x + + +class GroupQuantizeMultiLinear(nn.Module): # pylint: disable=too-many-instance-attributes + """An nn.MultiLinear module with group quantization""" + + def __init__( # pylint: disable=too-many-arguments + self, + in_features: int, + out_features: nn.Sequence[int], + config: GroupQuantize, + bias: bool = True, + out_dtype: Optional[str] = None, + ): + assert len(out_features) > 0 + self.total_out_features = sum(out_features) + + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.out_dtype = out_dtype + self.config = config + self.weight = nn.Parameter( + (self.total_out_features, tir.ceildiv(in_features, config.num_elem_per_storage)), + config.storage_dtype, + ) + self.scale = nn.Parameter( + (self.total_out_features, tir.ceildiv(in_features, config.group_size)), + config.model_dtype, + ) + if bias: + self.bias = nn.Parameter((self.total_out_features,), config.model_dtype) + else: + self.bias = None + + @staticmethod + def from_multilinear( + multi_linear: nn.MultiLinear, config: GroupQuantize + ) -> "GroupQuantizeMultiLinear": + """ + Converts a non-quantized nn.MultiLinear to a group quantized GroupQuantizeLinear + + Parameters + ---------- + linear : nn.Linear + The non-quantized nn.Linear. + + config : GroupQuantize + The group quantization config. + + Returns + ------- + ret : GroupQuantizeMultiLinear + The group quantized GroupQuantizeMultiLinear layer. + """ + return GroupQuantizeMultiLinear( + in_features=multi_linear.in_features, + out_features=multi_linear.out_features, + config=config, + bias=getattr(multi_linear, "bias", None) is not None, + out_dtype=multi_linear.out_dtype, + ) + + def forward(self, x: nn.Tensor) -> nn.Tensor: # pylint: disable=invalid-name + """ + Forward method for multi linear layer. + + Parameters + ---------- + x : Tensor + The input tensor. + + Returns + ------- + ret : Tensor + The output tensor for the multi linear layer. 
+ """ + sections = list(np.cumsum(self.out_features)[:-1]) w = nn.op.tensor_expr_op( # pylint: disable=invalid-name lambda weight, scale: self.config._dequantize( # pylint: disable=protected-access weight, scale, [ - tir.IntImm("int64", self.out_features), + tir.IntImm("int64", self.total_out_features), tir.IntImm("int64", self.in_features), ], ), name_hint="decode", args=[self.weight, self.scale], ) + # x: [*B, in_features] + # w: [in_features, out_features] w = nn.op.permute_dims(w) # pylint: disable=invalid-name + # x: [*B, out_features] x = nn.op.matmul(x, w, out_dtype=self.out_dtype) if self.bias is not None: x = x + self.bias - return x + results = nn.op.split(x, sections, axis=-1) + return results + + +class GroupQuantizeEmbedding(nn.Module): + """An nn.Embedding module with group quantization""" + + def __init__(self, num: int, dim: int, config: GroupQuantize): + self.num = num + self.dim = dim + self.config = config + n_group = tir.ceildiv(dim, config.group_size) + self.weight = nn.Parameter( + (num, n_group * config.num_elem_per_storage), config.storage_dtype + ) + self.scale = nn.Parameter((num, n_group), config.model_dtype) + + @staticmethod + def from_embedding(embedding: nn.Embedding, config: GroupQuantize) -> "GroupQuantizeEmbedding": + """ + Converts a non-quantized nn.Embedding to a group quantized GroupQuantizeEmbedding + + Parameters + ---------- + linear : nn.Embedding + The non-quantized nn.Embedding. + + config : GroupQuantize + The group quantization config. + + Returns + ------- + ret : GroupQuantizeEmbedding + The group quantized GroupQuantizeEmbedding layer. + """ + num, dim = embedding.weight.shape + return GroupQuantizeEmbedding(num, dim, config) + + def forward(self, x: nn.Tensor): # pylint: disable=invalid-name + """ + Forward method for group quantized embedding layer. + + Parameters + ---------- + x : nn.Tensor + The input tensor. + + Returns + ------- + ret : nn.Tensor + The output tensor for the embedding layer. 
+ """ + w = nn.op.tensor_expr_op( # pylint: disable=invalid-name + lambda weight, scale: self.config._dequantize( # pylint: disable=protected-access + weight, + scale, + [tir.IntImm("int64", self.num), tir.IntImm("int64", self.dim)], + ), + name_hint="decode", + args=[self.weight, self.scale], + ) + if x.ndim == 1: + return nn.op.take(w, x, axis=0) + return nn.op.reshape( + nn.op.take(w, nn.op.reshape(x, shape=[-1]), axis=0), + shape=[*x.shape, self.dim], + ) diff --git a/tests/python/model/test_llama_quantization.py b/tests/python/model/test_llama_quantization.py new file mode 100644 index 0000000000..92e2b4c1d6 --- /dev/null +++ b/tests/python/model/test_llama_quantization.py @@ -0,0 +1,32 @@ +# pylint: disable=invalid-name,missing-docstring +from mlc_chat.compiler import MODELS, QUANTIZATION +from mlc_chat.compiler.quantization.group_quantization import ( + GroupQuantizeEmbedding, + GroupQuantizeLinear, + GroupQuantizeMultiLinear, +) + + +def test_llama2_group_quantization(model_name: str, quant_name: str): + model_info = MODELS["llama"] + config = model_info.config.from_predefined(model_name) + model, quant_map = model_info.quantize["group-quant"](config, QUANTIZATION[quant_name]) + assert "model.model.embed_tokens.weight" in quant_map.param_map + assert isinstance(model.model.embed_tokens, GroupQuantizeEmbedding) + assert "model.lm_head.weight" in quant_map.param_map + assert isinstance(model.lm_head, GroupQuantizeLinear) + for i in range(config.num_hidden_layers): + assert f"model.model.layers.{i}.self_attn.qkv_proj.weight" in quant_map.param_map + assert isinstance(model.model.layers[i].self_attn.qkv_proj, GroupQuantizeMultiLinear) + assert f"model.model.layers.{i}.self_attn.o_proj.weight" in quant_map.param_map + assert isinstance(model.model.layers[i].self_attn.o_proj, GroupQuantizeLinear) + assert f"model.model.layers.{i}.mlp.gate_up_proj.weight" in quant_map.param_map + assert isinstance(model.model.layers[i].mlp.gate_up_proj, GroupQuantizeMultiLinear) + assert f"model.model.layers.{i}.mlp.down_proj.weight" in quant_map.param_map + assert isinstance(model.model.layers[i].mlp.down_proj, GroupQuantizeLinear) + + +if __name__ == "__main__": + test_llama2_group_quantization("llama2_7b", "q4f16_1") + test_llama2_group_quantization("llama2_13b", "q4f16_1") + test_llama2_group_quantization("llama2_70b", "q4f16_1") diff --git a/tests/python/quantization/test_group_quantization.py b/tests/python/quantization/test_group_quantization.py index 75e754e147..663f7b8e78 100644 --- a/tests/python/quantization/test_group_quantization.py +++ b/tests/python/quantization/test_group_quantization.py @@ -1,12 +1,19 @@ -# pylint: disable=missing-docstring +# pylint: disable=invalid-name,missing-docstring from typing import List import numpy as np +import torch import tvm import tvm.testing from mlc_chat.compiler import QUANTIZATION +from mlc_chat.compiler.parameter import QuantizeMapping from mlc_chat.compiler.quantization import GroupQuantize +from mlc_chat.compiler.quantization.group_quantization import ( + GroupQuantizeEmbedding, + GroupQuantizeLinear, +) from tvm import DataType +from tvm.relax.frontend import nn def quantize_np(config: GroupQuantize, weight: np.ndarray): @@ -58,18 +65,18 @@ def dequantize_np( weight_bin = np.bitwise_and( np.right_shift( weight_repeated, - (indice_j % config.num_elem_per_storage) * DataType(config.storage_dtype).bits, + (indice_j % config.num_elem_per_storage) * DataType(config.quantize_dtype).bits, ), bin_mask, ) return ((weight_bin - max_int) * scale_repeated)[: 
out_shape[0]][: out_shape[1]] -def test_quantize(quant_name: str, shape: List[int], dtype: str): +def test_quantize_weight(quant_name: str, shape: List[int], dtype: str, device: str): config = QUANTIZATION[quant_name] assert isinstance(config, GroupQuantize) weight_np = np.random.random(shape).astype(dtype) - output = config.quantize_weight(tvm.nd.array(weight_np, device=tvm.device("cuda"))) + output = config.quantize_weight(tvm.nd.array(weight_np, device=tvm.device(device))) quantized_weight, scale = output[0].numpy(), output[1].numpy() quantized_weight_ref, scale_ref = quantize_np(config, weight_np) tvm.testing.assert_allclose(scale, scale_ref, rtol=1e-3, atol=1e-3) @@ -81,5 +88,65 @@ def test_quantize(quant_name: str, shape: List[int], dtype: str): ) +def test_dequantize_weight(quant_name: str, shape: List[int], dtype: str): + class Test(nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear = nn.Linear(shape[1], shape[0], bias=False, dtype=dtype) + + def forward(self, x: nn.Tensor): + return self.linear(x) + + config = QUANTIZATION[quant_name] + assert isinstance(config, GroupQuantize) + weight_np = np.random.randint( + np.iinfo(config.storage_dtype).min, + np.iinfo(config.storage_dtype).max, + (shape[0], shape[1] // config.num_elem_per_storage), + ).astype(config.storage_dtype) + scale_np = np.random.random((shape[0], shape[1] // config.group_size)).astype( + config.model_dtype + ) + mod = config.quantize_model(Test(), QuantizeMapping({}, {}), "") + mod.linear.weight.data = weight_np + mod.linear.scale.data = scale_np + model = mod.jit(spec={"forward": {"x": nn.spec.Tensor((shape[1], shape[1]), dtype)}}) + out = model["forward"]( + torch.from_numpy(np.diag(np.ones(shape[1]).astype(dtype))) # pylint: disable=no-member + ) + ref = dequantize_np(config, weight_np, scale_np).T + tvm.testing.assert_allclose(out, ref, rtol=1e-3, atol=1e-3) + + +def test_quantize_model(quant_name: str, shape: List[int], dtype: str): + class Test(nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear = nn.Linear(shape[0], shape[1], dtype=dtype) + self.embedding = nn.Embedding(shape[0], shape[1], dtype=dtype) + + def forward(self, x: nn.Tensor): + return self.linear(x) + + config = QUANTIZATION[quant_name] + assert isinstance(config, GroupQuantize) + quant_map = QuantizeMapping({}, {}) + mod = config.quantize_model(Test(), quant_map, "model") + assert quant_map.param_map["model.linear.weight"] == [ + "model.linear.q_weight", + "model.linear.q_scale", + ] + assert quant_map.map_func["model.linear.weight"] == config.quantize_weight + assert isinstance(mod.linear, GroupQuantizeLinear) + assert quant_map.param_map["model.embedding.weight"] == [ + "model.embedding.q_weight", + "model.embedding.q_scale", + ] + assert quant_map.map_func["model.embedding.weight"] == config.quantize_weight + assert isinstance(mod.embedding, GroupQuantizeEmbedding) + + if __name__ == "__main__": - test_quantize("q4f16_1", [64, 4096], "float16") + test_quantize_weight("q4f16_1", [16, 128], "float16", "llvm") + test_quantize_model("q4f16_1", [16, 128], "float16") + test_dequantize_weight("q4f16_1", [16, 128], "float16") From 2ca7d15b765205cc6ad0ea8b8982f1b5477952f0 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Thu, 2 Nov 2023 11:30:28 -0700 Subject: [PATCH 082/116] [Fix] TIR block name of dequantization (#1177) --- .../compiler/compiler_pass/fuse_decode_take.py | 5 +++-- .../compiler/quantization/group_quantization.py | 14 ++++++++------ python/mlc_chat/support/auto_config.py | 2 +- 3 files 
changed, 12 insertions(+), 9 deletions(-) diff --git a/python/mlc_chat/compiler/compiler_pass/fuse_decode_take.py b/python/mlc_chat/compiler/compiler_pass/fuse_decode_take.py index 9468f4f425..f2022c1161 100644 --- a/python/mlc_chat/compiler/compiler_pass/fuse_decode_take.py +++ b/python/mlc_chat/compiler/compiler_pass/fuse_decode_take.py @@ -38,8 +38,9 @@ def transform_module( for g_var, func in mod.functions_items(): name = g_var.name_hint if isinstance(func, tir.PrimFunc) and (("fused_decode" in name) and ("take" in name)): - mod = tvm.IRModule({"main": func}) - sch = tir.Schedule(mod) + sch_mod = tvm.IRModule({"main": func}) + sch_mod = tir.transform.ForceNarrowIndexToInt32()(sch_mod) + sch = tir.Schedule(sch_mod) sch.compute_inline("decode") mod[g_var] = sch.mod["main"] return mod diff --git a/python/mlc_chat/compiler/quantization/group_quantization.py b/python/mlc_chat/compiler/quantization/group_quantization.py index 9a5abed4ad..47ad82ac11 100644 --- a/python/mlc_chat/compiler/quantization/group_quantization.py +++ b/python/mlc_chat/compiler/quantization/group_quantization.py @@ -1,6 +1,6 @@ """The group quantization config""" from dataclasses import dataclass, field -from typing import Any, Callable, Dict, List, Optional, Tuple +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple import numpy as np from tvm import DataType, DataTypeCode @@ -128,7 +128,7 @@ def _dequantize( ): tir_bin_mask = tir.const((1 << DataType(self.quantize_dtype).bits) - 1, self.storage_dtype) tir_max_int = tir.const(self.max_int_value, self.model_dtype) - dequantized_weight = te.compute( + return te.compute( shape=[weight.shape[0], weight.shape[1] * self.num_elem_per_storage] if out_shape is None else out_shape, @@ -149,8 +149,8 @@ def _dequantize( ), scale[i, j // self.group_size], ), + name="decode", ) - return dequantized_weight def quantize_weight(self, weight: NDArray) -> List[NDArray]: """ @@ -186,8 +186,10 @@ def quantize_weight(self, weight: NDArray) -> List[NDArray]: if target is None: target = Target.from_device(dev) with target: - mod = dl.ApplyDefaultSchedule( # pylint: disable=not-callable - dl.gpu.Reduction(), dl.gpu.GeneralReduction(), dl.gpu.Fallback() + mod = dl.ApplyDefaultSchedule( # type: ignore # pylint: disable=not-callable + dl.gpu.Reduction(), + dl.gpu.GeneralReduction(), + dl.gpu.Fallback(), )(mod) elif device_type == "cpu": target = "llvm" @@ -400,7 +402,7 @@ def from_multilinear( out_dtype=multi_linear.out_dtype, ) - def forward(self, x: nn.Tensor) -> nn.Tensor: # pylint: disable=invalid-name + def forward(self, x: nn.Tensor) -> Sequence[nn.Tensor]: # pylint: disable=invalid-name """ Forward method for multi linear layer. diff --git a/python/mlc_chat/support/auto_config.py b/python/mlc_chat/support/auto_config.py index 61a84b4041..0546e49252 100644 --- a/python/mlc_chat/support/auto_config.py +++ b/python/mlc_chat/support/auto_config.py @@ -95,7 +95,7 @@ def detect_model_type(model_type: str, config: Path) -> "Model": f"Please explicitly specify `--model-type` instead" ) model_type = cfg["model_type"] - logger.info("%s Model type: %s", FOUND, model_type) + logger.info("%s model type: %s", FOUND, model_type) if model_type not in MODELS: raise ValueError(f"Unknown model type: {model_type}. 
Available ones: {list(MODELS.keys())}") return MODELS[model_type] From 53060af2367b38320ad10057d843b5737b14970d Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Thu, 2 Nov 2023 13:08:11 -0700 Subject: [PATCH 083/116] [SLM][AutoLLM] Enable Command Line Weight Conversion (#1170) This PR enables weight conversion in command line. Sample command: `python3 -m mlc_chat.cli.convert_weight --config dist/models/llama-2-13b-chat-hf/ --quantization "q4f16_1" --output dist/test/` --- python/mlc_chat/cli/convert_weight.py | 150 ++++++++++++++++++ .../compiler/model/llama_quantization.py | 2 +- .../compiler/parameter/huggingface_loader.py | 48 ++++-- python/mlc_chat/compiler/parameter/utils.py | 15 +- .../quantization/group_quantization.py | 2 +- python/mlc_chat/support/auto_target.py | 4 +- python/mlc_chat/support/auto_weight.py | 66 ++++---- tests/python/support/test_auto_weight.py | 92 +++++------ 8 files changed, 276 insertions(+), 103 deletions(-) create mode 100644 python/mlc_chat/cli/convert_weight.py diff --git a/python/mlc_chat/cli/convert_weight.py b/python/mlc_chat/cli/convert_weight.py new file mode 100644 index 0000000000..45fe0fa286 --- /dev/null +++ b/python/mlc_chat/cli/convert_weight.py @@ -0,0 +1,150 @@ +"""Command line entrypoint of weight conversion.""" +import argparse +import logging +from pathlib import Path +from typing import Union + +import tvm +from mlc_chat.compiler import MODELS, QUANTIZATION +from mlc_chat.compiler.parameter import HuggingFaceLoader +from mlc_chat.support import tqdm +from tvm.contrib import tvmjs + +from ..support.auto_config import detect_config, detect_model_type +from ..support.auto_target import detect_target_and_host +from ..support.auto_weight import detect_weight + +logging.basicConfig( + level=logging.INFO, + style="{", + datefmt="%Y-%m-%d %H:%M:%S", + format="[{asctime}] {levelname} {filename}:{lineno}: {message}", +) + + +def main(): + """Parse command line argumennts and apply quantization.""" + + def _parse_config(path: Union[str, Path]) -> Path: + try: + return detect_config(path) + except ValueError as err: + raise argparse.ArgumentTypeError(f"No valid config.json in: {path}. 
Error: {err}") + + def _parse_source(path: Union[str, Path], config_path: Path) -> Path: + if path == "auto": + return config_path.parent + path = Path(path) + if not path.is_dir(): + raise argparse.ArgumentTypeError(f"Directory does not exist: {path}") + return path + + def _parse_output(path: Union[str, Path]) -> Path: + path = Path(path) + if not path.is_dir(): + path.mkdir(parents=True, exist_ok=True) + return path + + parser = argparse.ArgumentParser("MLC AutoLLM Quantization Framework") + parser.add_argument( + "--config", + type=_parse_config, + required=True, + help="Path to config.json file or to the directory that contains config.json, which is " + "a HuggingFace standard that defines model architecture, for example, " + "https://huggingface.co/codellama/CodeLlama-7b-Instruct-hf/blob/main/config.json", + ) + parser.add_argument( + "--source", + type=str, + required=False, + default="auto", + help="The path to original model weight, infer from `config` if missing", + ) + parser.add_argument( + "--source-format", + type=str, + required=False, + choices=["auto", "huggingface-torch", "huggingface-safetensor"], + default="auto", + help="The format of source model weight, infer from `config` if missing", + ) + parser.add_argument( + "--quantization", + type=str, + required=True, + choices=list(QUANTIZATION.keys()), + help="Quantization format, for example `q4f16_1`.", + ) + parser.add_argument( + "--model-type", + type=str, + default="auto", + choices=["auto"] + list(MODELS.keys()), + help="Model architecture, for example, llama. If not set, it is inferred " + "from the config.json file.", + ) + parser.add_argument( + "--device", + type=str, + default="auto", + help="The device used to do quantization, \ + for example `auto` / `cuda:0` / `cuda --arch sm86`", + ) + parser.add_argument( + "--output", + "-o", + type=_parse_output, + required=True, + help="The output directory to save the quantized model weight, " + "will contain `params_shard_*.bin` and `ndarray-cache.json`.", + ) + + # parse arguments + parsed = parser.parse_args() + parsed.source = _parse_source(parsed.source, parsed.config) + parsed.params, parsed.source_format = detect_weight( + parsed.source, parsed.config, weight_format=parsed.source_format + ) + model = detect_model_type(parsed.model_type, parsed.config) + + # detect quantization target + quantization_target, _ = detect_target_and_host(parsed.device) + if parsed.device != "auto": + device = tvm.runtime.device(parsed.device.split(" ")[0]) + else: + if quantization_target.kind.name == "cuda": + device = tvm.cuda(0) + else: + device = tvm.cpu(0) + + # model config & quantization config + model_config = model.config.from_file(parsed.config) + quantization_config = QUANTIZATION[parsed.quantization] + _, quantize_map = model.quantize[quantization_config.kind](model_config, quantization_config) + + # loader setup + if parsed.source_format in ("huggingface-torch", "huggingface-safetensor"): + loader = HuggingFaceLoader( + path=parsed.params, + extern_param_map=model.source[parsed.source_format](model_config, None), + quantize_param_map=quantize_map, + ) + else: + raise ValueError(f"Unsupported loader source format: {parsed.source_format}") + + # load and quantize + with quantization_target, tqdm.redirect(): + param_dict = dict(loader.load(device=device)) + + # dump to output directory + tvmjs.dump_ndarray_cache( + param_dict, + f"{parsed.output}/params", + meta_data={"ParamSize": len(param_dict)}, + encode_format="raw", + ) + + +if __name__ == "__main__": + main() diff 
--git a/python/mlc_chat/compiler/model/llama_quantization.py b/python/mlc_chat/compiler/model/llama_quantization.py index a5f8f0b0df..02376ab9db 100644 --- a/python/mlc_chat/compiler/model/llama_quantization.py +++ b/python/mlc_chat/compiler/model/llama_quantization.py @@ -19,6 +19,6 @@ def group_quant( model = quantization.quantize_model( model, quant_map, - "model", + "", ) return model, quant_map diff --git a/python/mlc_chat/compiler/parameter/huggingface_loader.py b/python/mlc_chat/compiler/parameter/huggingface_loader.py index ed91255c81..550dec3071 100644 --- a/python/mlc_chat/compiler/parameter/huggingface_loader.py +++ b/python/mlc_chat/compiler/parameter/huggingface_loader.py @@ -9,7 +9,7 @@ import numpy as np from tqdm import tqdm -from tvm.runtime import NDArray +from tvm.runtime import Device, NDArray from tvm.runtime.ndarray import array as as_ndarray from .mapping import ExternMapping, QuantizeMapping @@ -100,24 +100,45 @@ def __init__( raise FileNotFoundError(f"Unknown file suffix: {path}") check_parameter_usage(extern_param_map, set(self.torch_to_path.keys())) - def load(self) -> Iterator[Tuple[str, NDArray]]: - """Load the parameters and yield the MLC parameter and its value.""" + def load(self, device: Optional[Device] = None) -> Iterator[Tuple[str, NDArray]]: + """Load the parameters and yield the MLC parameter and its value. + + Parameters + ---------- + device : Optional[Device] + The device to store the parameter, default to None, which means using CPU. + + Yields + ------ + Tuple[str, NDArray] + The MLC parameter name and its value, quantized if quantization mapping is provided. + """ mlc_names = _loading_order(self.extern_param_map, self.torch_to_path) for mlc_name in tqdm(mlc_names): - param = self._load_mlc_param(mlc_name) + param = self._load_mlc_param(mlc_name, device=device) + if self.quantize_param_map: with self.stats.timer("quant_time_sec"): quantized_params = ParamQuantizer(self.quantize_param_map).quantize( mlc_name, param ) - for quantized_name, quantized_param in quantized_params: + if not quantized_params: logger.info( - ' Quantized Parameter: "%s", shape: %s, dtype: %s', - quantized_name, - quantized_param.shape, - quantized_param.dtype, + ' Skipped Quantizing Parameter: "%s", shape: %s, dtype: %s', + mlc_name, + param.shape, + param.dtype, ) - yield quantized_name, quantized_param + yield mlc_name, param + else: + for quantized_name, quantized_param in quantized_params: + logger.info( + ' Quantized Parameter: "%s", shape: %s, dtype: %s', + quantized_name, + quantized_param.shape, + quantized_param.dtype, + ) + yield quantized_name, quantized_param else: yield mlc_name, param cached_files = list(self.cached_files.keys()) @@ -126,7 +147,7 @@ def load(self) -> Iterator[Tuple[str, NDArray]]: self.stats.log_time_info("HF") self.stats.log_mem_usage() - def _load_mlc_param(self, mlc_name: str) -> np.ndarray: + def _load_mlc_param(self, mlc_name: str, device: Optional[Device]) -> NDArray: torch_names = self.extern_param_map.param_map[mlc_name] files_required = {self.torch_to_path[p] for p in torch_names} files_existing = set(self.cached_files.keys()) @@ -148,8 +169,9 @@ def _load_mlc_param(self, mlc_name: str) -> np.ndarray: with self.stats.timer("map_time_sec"): param = self.extern_param_map.map_func[mlc_name](*torch_params) logger.info(' Parameter: "%s", shape: %s, dtype: %s', mlc_name, param.shape, param.dtype) - param = as_ndarray(param) - return param + if device: + return as_ndarray(param, device=device) + return as_ndarray(param) def _load_file(self, 
path: Path) -> None: logger.info("Loading HF parameters from: %s", path) diff --git a/python/mlc_chat/compiler/parameter/utils.py b/python/mlc_chat/compiler/parameter/utils.py index a2789cee55..f297e2f0dc 100644 --- a/python/mlc_chat/compiler/parameter/utils.py +++ b/python/mlc_chat/compiler/parameter/utils.py @@ -2,7 +2,7 @@ # pylint: disable=too-few-public-methods import logging from pathlib import Path -from typing import TYPE_CHECKING, Iterator, Set, Tuple +from typing import TYPE_CHECKING, Iterator, Optional, Set, Tuple import numpy as np @@ -24,7 +24,7 @@ class ParamQuantizer: def __init__(self, quantize_map: "QuantizeMapping") -> None: self.quantize_map = quantize_map - def quantize(self, name: str, param: "NDArray") -> Iterator[Tuple[str, "NDArray"]]: + def quantize(self, name: str, param: "NDArray") -> Optional[Iterator[Tuple[str, "NDArray"]]]: """Apply quantization to the given parameters Parameters @@ -36,11 +36,14 @@ def quantize(self, name: str, param: "NDArray") -> Iterator[Tuple[str, "NDArray" Returns ------- - List[Tuple[str, NDArray]] - The quantized parameters, each with its name + Optional[Iterator[Tuple[str, "NDArray"]]] + The quantized parameters, each with its name, returns None if the parameter is not + quantized. """ - - assert name in self.quantize_map.param_map + name = f".{name}" + if name not in self.quantize_map.param_map: + return None + assert name in self.quantize_map.map_func, f"Quantization function for {name} not found." quantized_names = self.quantize_map.param_map[name] quantized_params = self.quantize_map.map_func[name](param) return zip(quantized_names, quantized_params) diff --git a/python/mlc_chat/compiler/quantization/group_quantization.py b/python/mlc_chat/compiler/quantization/group_quantization.py index 47ad82ac11..6e28b72a97 100644 --- a/python/mlc_chat/compiler/quantization/group_quantization.py +++ b/python/mlc_chat/compiler/quantization/group_quantization.py @@ -195,7 +195,7 @@ def quantize_weight(self, weight: NDArray) -> List[NDArray]: target = "llvm" mod = relax.transform.LegalizeOps()(mod) else: - raise NotImplementedError + raise NotImplementedError(f"Device type {device_type} is not supported") ex = relax.build(mod, target) vm = relax.VirtualMachine(ex, dev) # pylint: disable=invalid-name self.prebuilt_quantize_func[key] = vm["quantize"] diff --git a/python/mlc_chat/support/auto_target.py b/python/mlc_chat/support/auto_target.py index 29328f8813..491402b008 100644 --- a/python/mlc_chat/support/auto_target.py +++ b/python/mlc_chat/support/auto_target.py @@ -24,7 +24,7 @@ BuildFunc = Callable[[IRModule, "CompileArgs"], None] -def detect_target_and_host(target_hint: str, host_hint: str) -> Tuple[Target, BuildFunc]: +def detect_target_and_host(target_hint: str, host_hint: str = "auto") -> Tuple[Target, BuildFunc]: """Detect the configuration for the target device and its host, for example, target GPU and the host CPU. @@ -34,7 +34,7 @@ def detect_target_and_host(target_hint: str, host_hint: str) -> Tuple[Target, Bu The hint for the target device. host_hint : str - The hint for the host CPU. + The hint for the host CPU, default is "auto". 
""" target, build_func = _detect_target_gpu(target_hint) if target.host is None: diff --git a/python/mlc_chat/support/auto_weight.py b/python/mlc_chat/support/auto_weight.py index 042e7b5366..96ca55bfcb 100644 --- a/python/mlc_chat/support/auto_weight.py +++ b/python/mlc_chat/support/auto_weight.py @@ -2,7 +2,7 @@ import json import logging from pathlib import Path -from typing import Tuple +from typing import List, Optional, Tuple from .style import green, red @@ -33,15 +33,16 @@ def detect_weight( Otherwise, check the weights are in that format. Available weight formats: - auto (guess the weight format) - - PyTorch (validate via checking pytorch_model.bin.index.json) - - SafeTensor (validate via checking model.safetensors.index.json) - - AWQ - - GGML/GGUF + - huggingface-torch (validate via checking pytorch_model.bin.index.json) + - huggingface-safetensor (validate via checking model.safetensors.index.json) + - awq + - ggml + - gguf Returns ------- - weight_path : pathlib.Path - The path that points to the weights. + weight_config_path : pathlib.Path + The path that points to the weights config file or the weights directory. weight_format : str The valid weight format. @@ -72,7 +73,7 @@ def detect_weight( # weight_format = "auto", guess the weight format. # otherwise, check the weight format is valid. if weight_format == "auto": - weight_format = _guess_weight_format(weight_path) + return _guess_weight_format(weight_path) if weight_format not in AVAILABLE_WEIGHT_FORMAT: raise ValueError( @@ -80,53 +81,54 @@ def detect_weight( ) if weight_format in CHECK_FORMAT_METHODS: check_func = CHECK_FORMAT_METHODS[weight_format] - if not check_func(weight_path): + weight_config_path = check_func(weight_path) + if not weight_config_path: raise ValueError(f"The weight is not in {weight_format} format.") - return weight_path, weight_format + return weight_config_path, weight_format -def _guess_weight_format(weight_path: Path): - possible_formats = [] +def _guess_weight_format(weight_path: Path) -> Tuple[Path, str]: + possible_formats: List[Tuple[Path, str]] = [] for weight_format, check_func in CHECK_FORMAT_METHODS.items(): - if check_func(weight_path): - possible_formats.append(weight_format) + weight_config_path = check_func(weight_path) + if weight_config_path: + possible_formats.append((weight_config_path, weight_format)) if len(possible_formats) == 0: raise ValueError( "Fail to detect weight format. Use `--weight-format` to manually specify the format." ) - selected_format = possible_formats[0] + weight_config_path, selected_format = possible_formats[0] logger.info( "Using %s format now. 
Use `--weight-format` to manually specify the format.", selected_format, ) - return selected_format + return weight_config_path, selected_format -def _check_pytorch(weight_path: Path): +def _check_pytorch(weight_path: Path) -> Optional[Path]: pytorch_json_path = weight_path / "pytorch_model.bin.index.json" - result = pytorch_json_path.exists() - if result: + if pytorch_json_path.exists(): logger.info("%s Huggingface PyTorch: %s", FOUND, pytorch_json_path) - else: - logger.info("%s Huggingface PyTorch", NOT_FOUND) - return result + return pytorch_json_path + logger.info("%s Huggingface PyTorch", NOT_FOUND) + return None -def _check_safetensor(weight_path: Path): +def _check_safetensor(weight_path: Path) -> Optional[Path]: safetensor_json_path = weight_path / "model.safetensors.index.json" - result = safetensor_json_path.exists() - if result: - logger.info("%s SafeTensor: %s", FOUND, safetensor_json_path) - else: - logger.info("%s SafeTensor", NOT_FOUND) - return result + if safetensor_json_path.exists(): + logger.info("%s Huggingface Safetensor: %s", FOUND, safetensor_json_path) + return safetensor_json_path + logger.info("%s Huggingface Safetensor", NOT_FOUND) + return None CHECK_FORMAT_METHODS = { - "PyTorch": _check_pytorch, - "SafeTensor": _check_safetensor, + "huggingface-torch": _check_pytorch, + "huggingface-safetensor": _check_safetensor, } -AVAILABLE_WEIGHT_FORMAT = ["PyTorch", "SafeTensor", "GGML", "GGUF", "AWQ"] +# "awq", "ggml", "gguf" are not supported yet. +AVAILABLE_WEIGHT_FORMAT = ["huggingface-torch", "huggingface-safetensor"] diff --git a/tests/python/support/test_auto_weight.py b/tests/python/support/test_auto_weight.py index 2987135267..5776791df1 100644 --- a/tests/python/support/test_auto_weight.py +++ b/tests/python/support/test_auto_weight.py @@ -24,13 +24,10 @@ def _create_json_file(json_path, data): @pytest.mark.parametrize( "weight_format, index_filename, result", [ - ("PyTorch", "pytorch_model.bin.index.json", "PyTorch"), - ("SafeTensor", "model.safetensors.index.json", "SafeTensor"), - ("GGML", None, "GGML"), - ("GGUF", None, "GGUF"), - ("AWQ", None, "AWQ"), - ("auto", "pytorch_model.bin.index.json", "PyTorch"), - ("auto", "model.safetensors.index.json", "SafeTensor"), + ("huggingface-torch", "pytorch_model.bin.index.json", "huggingface-torch"), + ("huggingface-safetensor", "model.safetensors.index.json", "huggingface-safetensor"), + ("auto", "pytorch_model.bin.index.json", "huggingface-torch"), + ("auto", "model.safetensors.index.json", "huggingface-safetensor"), ], ) def test_detect_weight(weight_format, index_filename, result): @@ -39,19 +36,16 @@ def test_detect_weight(weight_format, index_filename, result): if index_filename is not None: weight_index_file = base_path / index_filename _create_json_file(weight_index_file, {}) - assert detect_weight(base_path, None, weight_format) == (base_path, result) + assert detect_weight(base_path, None, weight_format) == (weight_index_file, result) @pytest.mark.parametrize( "weight_format, index_filename, result", [ - ("PyTorch", "pytorch_model.bin.index.json", "PyTorch"), - ("SafeTensor", "model.safetensors.index.json", "SafeTensor"), - ("GGML", None, "GGML"), - ("GGUF", None, "GGUF"), - ("AWQ", None, "AWQ"), - ("auto", "pytorch_model.bin.index.json", "PyTorch"), - ("auto", "model.safetensors.index.json", "SafeTensor"), + ("huggingface-torch", "pytorch_model.bin.index.json", "huggingface-torch"), + ("huggingface-safetensor", "model.safetensors.index.json", "huggingface-safetensor"), + ("auto", 
"pytorch_model.bin.index.json", "huggingface-torch"), + ("auto", "model.safetensors.index.json", "huggingface-safetensor"), ], ) def test_detect_weight_in_config_json(weight_format, index_filename, result): @@ -64,19 +58,16 @@ def test_detect_weight_in_config_json(weight_format, index_filename, result): weight_index_file = weight_path / index_filename _create_json_file(weight_index_file, {}) - assert detect_weight(None, config_json_path, weight_format) == (weight_path, result) + assert detect_weight(None, config_json_path, weight_format) == (weight_index_file, result) @pytest.mark.parametrize( "weight_format, index_filename, result", [ - ("PyTorch", "pytorch_model.bin.index.json", "PyTorch"), - ("SafeTensor", "model.safetensors.index.json", "SafeTensor"), - ("GGML", None, "GGML"), - ("GGUF", None, "GGUF"), - ("AWQ", None, "AWQ"), - ("auto", "pytorch_model.bin.index.json", "PyTorch"), - ("auto", "model.safetensors.index.json", "SafeTensor"), + ("huggingface-torch", "pytorch_model.bin.index.json", "huggingface-torch"), + ("huggingface-safetensor", "model.safetensors.index.json", "huggingface-safetensor"), + ("auto", "pytorch_model.bin.index.json", "huggingface-torch"), + ("auto", "model.safetensors.index.json", "huggingface-safetensor"), ], ) def test_detect_weight_same_dir_config_json(weight_format, index_filename, result): @@ -85,42 +76,47 @@ def test_detect_weight_same_dir_config_json(weight_format, index_filename, resul config_json_path = base_path / "config.json" _create_json_file(config_json_path, {}) if index_filename is not None: - weight_index_file = os.path.join(tmpdir, index_filename) + weight_index_file = Path(os.path.join(tmpdir, index_filename)) _create_json_file(weight_index_file, {}) - assert detect_weight(None, config_json_path, weight_format) == (base_path, result) + assert detect_weight(None, config_json_path, weight_format) == (weight_index_file, result) def test_find_weight_fail(): with tempfile.TemporaryDirectory() as tmpdir: base_path = Path(tmpdir) with pytest.raises(ValueError): - detect_weight(Path("do/not/exist"), base_path, "AWQ") + detect_weight(Path("do/not/exist"), base_path, "awq") with pytest.raises(AssertionError): - detect_weight(None, Path("do/not/exist"), "AWQ") + detect_weight(None, Path("do/not/exist"), "awq") if __name__ == "__main__": - test_detect_weight("PyTorch", "pytorch_model.bin.index.json", "PyTorch") - test_detect_weight("SafeTensor", "model.safetensors.index.json", "SafeTensor") - test_detect_weight("GGML", None, "GGML") - test_detect_weight("GGUF", None, "GGUF") - test_detect_weight("AWQ", None, "AWQ") - test_detect_weight("auto", "pytorch_model.bin.index.json", "PyTorch") - test_detect_weight("auto", "model.safetensors.index.json", "SafeTensor") - test_detect_weight_in_config_json("PyTorch", "pytorch_model.bin.index.json", "PyTorch") - test_detect_weight_in_config_json("SafeTensor", "model.safetensors.index.json", "SafeTensor") - test_detect_weight_in_config_json("GGML", None, "GGML") - test_detect_weight_in_config_json("GGUF", None, "GGUF") - test_detect_weight_in_config_json("AWQ", None, "AWQ") - test_detect_weight_in_config_json("auto", "pytorch_model.bin.index.json", "PyTorch") - test_detect_weight_in_config_json("auto", "model.safetensors.index.json", "SafeTensor") - test_detect_weight_same_dir_config_json("PyTorch", "pytorch_model.bin.index.json", "PyTorch") + test_detect_weight("huggingface-torch", "pytorch_model.bin.index.json", "huggingface-torch") + test_detect_weight( + "huggingface-safetensor", "model.safetensors.index.json", 
"huggingface-safetensor" + ) + test_detect_weight("auto", "pytorch_model.bin.index.json", "huggingface-torch") + test_detect_weight("auto", "model.safetensors.index.json", "huggingface-safetensor") + test_detect_weight_in_config_json( + "huggingface-torch", "pytorch_model.bin.index.json", "huggingface-torch" + ) + test_detect_weight_in_config_json( + "huggingface-safetensor", "model.safetensors.index.json", "huggingface-safetensor" + ) + test_detect_weight_in_config_json("auto", "pytorch_model.bin.index.json", "huggingface-torch") + test_detect_weight_in_config_json( + "auto", "model.safetensors.index.json", "huggingface-safetensor" + ) + test_detect_weight_same_dir_config_json( + "huggingface-torch", "pytorch_model.bin.index.json", "huggingface-torch" + ) + test_detect_weight_same_dir_config_json( + "huggingface-safetensor", "model.safetensors.index.json", "huggingface-safetensor" + ) + test_detect_weight_same_dir_config_json( + "auto", "pytorch_model.bin.index.json", "huggingface-torch" + ) test_detect_weight_same_dir_config_json( - "SafeTensor", "model.safetensors.index.json", "SafeTensor" + "auto", "model.safetensors.index.json", "huggingface-safetensor" ) - test_detect_weight_same_dir_config_json("GGML", None, "GGML") - test_detect_weight_same_dir_config_json("GGUF", None, "GGUF") - test_detect_weight_same_dir_config_json("AWQ", None, "AWQ") - test_detect_weight_same_dir_config_json("auto", "pytorch_model.bin.index.json", "PyTorch") - test_detect_weight_same_dir_config_json("auto", "model.safetensors.index.json", "SafeTensor") test_find_weight_fail() From 2dc81830f5a9daecd0ed8ec65f18190a48cdf64f Mon Sep 17 00:00:00 2001 From: Lesheng Jin <34279105+LeshengJin@users.noreply.github.com> Date: Fri, 3 Nov 2023 00:36:52 -0700 Subject: [PATCH 084/116] [Fix][SLM] Update q4f16 quantization with the new mutator name rule (#1178) [Fix] Update q4f16 quantization with the new mutator name rule --- python/mlc_chat/compiler/parameter/utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/mlc_chat/compiler/parameter/utils.py b/python/mlc_chat/compiler/parameter/utils.py index f297e2f0dc..3c6c0d1476 100644 --- a/python/mlc_chat/compiler/parameter/utils.py +++ b/python/mlc_chat/compiler/parameter/utils.py @@ -40,7 +40,6 @@ def quantize(self, name: str, param: "NDArray") -> Optional[Iterator[Tuple[str, The quantized parameters, each with its name, returns None if the parameter is not quantized. """ - name = f".{name}" if name not in self.quantize_map.param_map: return None assert name in self.quantize_map.map_func, f"Quantization function for {name} not found." 
From 6ae02dd0746e51184a4c5041dcca91ea9183c6aa Mon Sep 17 00:00:00 2001 From: Charlie Ruan <53290280+CharlieFRuan@users.noreply.github.com> Date: Fri, 3 Nov 2023 15:34:29 -0400 Subject: [PATCH 085/116] [Model Support][SWA] Add support for sliding window attention for Mistral (#1087) * mistral base * Add sliding window mask making and its tests * Small changes for sliding window mask * Clean up mask making * Remove kv_seq_len * Add prefill chunking, handle max window size in SWA * Add interleave kv * Temporary fix for kv seq len * Pass in more shapes to SWA prefill and decode in runtime * mistral var fix * Small changes regarding shape passing * Small fix on chunk size * Add build args, fix mlc chat config dump * mistral system prompt --------- Co-authored-by: David Pissarra Co-authored-by: David Pissarra <61968959+davidpissarra@users.noreply.github.com> --- cpp/conv_templates.cc | 4 + cpp/llm_chat.cc | 89 +- mlc_llm/core.py | 75 +- mlc_llm/relax_model/mistral.py | 1062 +++++++++++++++++ .../support/test_sliding_window_mask.py | 339 ++++++ 5 files changed, 1542 insertions(+), 27 deletions(-) create mode 100644 mlc_llm/relax_model/mistral.py create mode 100644 tests/python/support/test_sliding_window_mask.py diff --git a/cpp/conv_templates.cc b/cpp/conv_templates.cc index ae91bf2070..dd90a67fb5 100644 --- a/cpp/conv_templates.cc +++ b/cpp/conv_templates.cc @@ -49,6 +49,10 @@ Conversation Llama2() { Conversation MistralDefault() { Conversation conv; conv.name = "mistral_default"; + conv.system = + ("[INST] Always assist with care, respect, and truth. Respond with utmost utility yet " + "securely. Avoid harmful, unethical, prejudiced, or negative content. Ensure replies " + "promote fairness and positivity."); conv.roles = {"[INST]", "[/INST]"}; conv.messages = {}; conv.offset = 0; diff --git a/cpp/llm_chat.cc b/cpp/llm_chat.cc index f8c7ef0986..70d89db348 100644 --- a/cpp/llm_chat.cc +++ b/cpp/llm_chat.cc @@ -295,6 +295,9 @@ class LLMChat { if (ft_.use_disco) { return false; } + if (this->sliding_window_ != -1) { + return false; + } PackedFunc fget_metadata = ft_.mod_get_func("get_metadata"); if (fget_metadata == nullptr) { return false; @@ -369,6 +372,16 @@ class LLMChat { this->max_window_size_ = std::min(this->max_window_size_, config["max_window_size"].get()); } + if (config.count("sliding_window")) { + CHECK(config["sliding_window"].is()); + CHECK(!config.count("max_window_size")) + << "Cannot specify both sliding_window and max_window_size."; + this->sliding_window_ = config["sliding_window"].get(); + } + if (config.count("sliding_window_chunk_size")) { + CHECK(config["sliding_window_chunk_size"].is()); + this->sliding_window_chunk_size_ = config["sliding_window_chunk_size"].get(); + } if (config.count("model_name")) { CHECK(config["model_name"].is()); this->model_name_ = config["model_name"].get(); @@ -462,9 +475,11 @@ class LLMChat { // so there is no explicit abi dependency on these extra // classes other than basic tvm runtime. this->ft_.Init(reload_lib, device_, this->num_shards_); - UpdateMaxWindowSizeFromMetadata(); - CHECK(max_window_size_ != std::numeric_limits::max()) - << "Key \"max_window_size\" not found."; + if (this->sliding_window_ == -1) { + UpdateMaxWindowSizeFromMetadata(); + CHECK(max_window_size_ != std::numeric_limits::max()) + << "Key \"max_window_size\" not found."; + } // Step 4. Initialize sample functions. 
auto fsample_topp_from_prob_ptr = tvm::runtime::Registry::Get("vm.builtin.sample_top_p_from_prob"); @@ -562,7 +577,8 @@ class LLMChat { std::string all_prompt = GetConcatPrompt(prompts, 0, 0); std::vector encoded = this->tokenizer_->Encode(all_prompt); tokens.insert(tokens.end(), encoded.begin(), encoded.end()); - if (this->total_seq_len_ + tokens.size() + gen_mean_gen_len < this->max_window_size_) { + if (this->sliding_window_ != -1 || // There is no max window size if we use sliding window + this->total_seq_len_ + tokens.size() + gen_mean_gen_len < this->max_window_size_) { return tokens; } // need shift window and re-encode @@ -753,6 +769,10 @@ class LLMChat { if (ft_.use_disco) { LOG(FATAL) << "NotImplementedError: Distributed inference is not supported for this model"; } + if (this->sliding_window_ != -1) { + LOG(FATAL) + << "NotImplementedError: Sliding window attention does not support separate embedding"; + } NDArray embedding = Downcast( EmbedStep(inp, append_conversation, place_in_prompt, generation_config_str)); PrefillWithEmbedStep(embedding, decode_next_token, generation_config_str); @@ -772,8 +792,28 @@ class LLMChat { } auto tstart = std::chrono::high_resolution_clock::now(); - int32_t new_seq_len = total_seq_len_ + token_len; - NDArray logits_on_device = this->ForwardTokens(prompt_tokens, new_seq_len); + int32_t new_seq_len = total_seq_len_; + NDArray logits_on_device; + if (this->sliding_window_ != -1) { + // Use chunking if we use sliding window attention (see Mistral paper figure 3). + int64_t sliding_window_chunk_size = this->sliding_window_chunk_size_; + if (this->sliding_window_chunk_size_ == -1) { + // One chunk if chunk size not specified + sliding_window_chunk_size = token_len; + } + for (int64_t begin = 0; begin < token_len; begin += sliding_window_chunk_size) { + int64_t end = std::min(token_len, begin + sliding_window_chunk_size); + std::vector chunk = + std::vector(prompt_tokens.begin() + begin, prompt_tokens.begin() + end); + new_seq_len += static_cast(chunk.size()); + logits_on_device = this->ForwardTokens(chunk, new_seq_len); + } + ICHECK_EQ(new_seq_len, total_seq_len_ + token_len) << "Expect chunking process all tokens"; + } else { + // Otherwise, prefill entire prompt at once. 
+ new_seq_len += token_len; + logits_on_device = this->ForwardTokens(prompt_tokens, new_seq_len); + } total_seq_len_ = new_seq_len; if (!decode_next_token) { @@ -1111,7 +1151,9 @@ class LLMChat { } // max_window_size_ != -1 to handle // https://github.com/mlc-ai/mlc-llm/blob/main/mlc_llm/relax_model/rwkv.py#L588-L589 - else if (max_window_size_ != -1 && total_seq_len_ >= max_window_size_) { + // sliding_window_ == -1 to make sure we do not stop when using sliding window + else if (max_window_size_ != -1 && sliding_window_ == -1 && + total_seq_len_ >= max_window_size_) { stop_triggered_ = true; } if (stop_triggered_) { @@ -1125,7 +1167,18 @@ class LLMChat { if (input_tokens.size() > 1 && ft_.prefill_func_.defined()) { ObjectRef input_data = ft_.CopyToWorker0(this->GetInputTokenNDArray(input_tokens)); ShapeTuple cur_pos_shape = ShapeTuple({cur_pos}); - ret = ft_.prefill_func_(input_data, cur_pos_shape, kv_cache_, params_); + if (sliding_window_ == -1) { + ret = ft_.prefill_func_(input_data, cur_pos_shape, kv_cache_, params_); + } else { + // Sliding window attention needs extra shape parameters + int64_t seq_len = static_cast(input_tokens.size()); + // Number of elements in the cache + int64_t cache_len = std::min(this->sliding_window_, cur_pos - seq_len); + ShapeTuple cache_len_shape = ShapeTuple({cache_len}); + ShapeTuple kv_seq_len_shape = ShapeTuple({cache_len + seq_len}); + ret = ft_.prefill_func_(input_data, cur_pos_shape, cache_len_shape, kv_seq_len_shape, + kv_cache_, params_); + } } else { // running decode function when prefill is not available for (int i = 0; i < input_tokens.size(); ++i) { @@ -1138,8 +1191,19 @@ class LLMChat { input_data = ft_.CopyToWorker0(this->GetInputTokenNDArray({input_tokens[i]})); } int64_t pos = cur_pos + i + 1 - input_tokens.size(); - ShapeTuple pos_shape = ShapeTuple({cur_pos}); - ret = ft_.decode_func_(input_data, pos_shape, kv_cache_, params_); + ShapeTuple pos_shape = ShapeTuple({pos}); + if (sliding_window_ == -1) { + ret = ft_.decode_func_(input_data, pos_shape, kv_cache_, params_); + } else { + // Sliding window attention needs extra shape parameters + int64_t seq_len = static_cast(input_tokens.size()); + // Number of elements in the cache + int64_t cache_len = std::min(this->sliding_window_, pos - seq_len); + ShapeTuple cache_len_shape = ShapeTuple({cache_len}); + ShapeTuple kv_seq_len_shape = ShapeTuple({cache_len + seq_len}); + ret = ft_.decode_func_(input_data, pos_shape, cache_len_shape, kv_seq_len_shape, + kv_cache_, params_); + } } } if (ft_.use_disco) { @@ -1265,9 +1329,10 @@ class LLMChat { Conversation conversation_; // total sequence len, int64_t total_seq_len_{0}; - // max window size, mean generation length + // max window size, mean and max generation length, sliding window + // If we use sliding window, max window size is its default max() value int64_t max_window_size_{std::numeric_limits::max()}, mean_gen_len_{128}, - max_gen_len_{512}; + max_gen_len_{512}, sliding_window_{-1}, sliding_window_chunk_size_{-1}; // size of the vocab table int64_t vocab_size_; // number of shards in distributed inference diff --git a/mlc_llm/core.py b/mlc_llm/core.py index d914364b9c..a0490ecf10 100644 --- a/mlc_llm/core.py +++ b/mlc_llm/core.py @@ -24,6 +24,7 @@ llama, llama_batched_vllm, minigpt, + mistral, param_manager, rwkv, stablelm_3b, @@ -80,6 +81,13 @@ class BuildArgs: Build with separated embedding layer, only applicable to LlaMa. 
This feature is in testing stage, and will be formally replaced after massive overhaul of embedding feature for all models and use cases. + sliding_window: int + The sliding window size in sliding window attention (SWA). This optional field + overrides the `sliding_window` in config.json for those models that use SWA. + Currently only useful when compiling Mistral. + sliding_window_chunk_size: int + The chunk size in sliding window attention (SWA) during prefilling. By default, + the chunk size is the same as sliding window. Currently only useful when compiling Mistral. cc_path: str ``/path/to/cross_compiler_path``; currently only used for cross-compile for nvidia/jetson device. @@ -184,7 +192,10 @@ class BuildArgs: cc_path: str = field( default="", metadata={ - "help": "/path/to/cross_compiler_path, Currently only used for cross-compile for nvidia/jetson device." + "help": ( + "/path/to/cross_compiler_path, Currently only used for " + "cross-compile for nvidia/jetson device." + ) }, ) system_lib: bool = field( @@ -275,6 +286,26 @@ class BuildArgs: "action": "store_true", }, ) + sliding_window: int = field( + default=-1, + metadata={ + "help": ( + "The sliding window size in sliding window attention (SWA). " + "This optional field overrides the `sliding_window` in config.json for " + "those models that use SWA. Currently only useful when compiling Mistral." + ), + }, + ) + sliding_window_chunk_size: int = field( + default=-1, + metadata={ + "help": ( + "The chunk size in sliding window attention (SWA) during prefilling. " + "By default, the chunk size is the same as sliding window. " + "Currently only useful when compiling Mistral." + ), + }, + ) pdb: bool = field( default=False, metadata={ @@ -286,7 +317,8 @@ class BuildArgs: default=False, metadata={ "help": ( - "Use vLLM paged KV cache and attention kernel, only relevant when enable_batching=True." + "Use vLLM paged KV cache and attention kernel, only relevant when " + "enable_batching=True." ), "action": "store_true", }, @@ -330,7 +362,9 @@ def _parse_args(parsed) -> argparse.Namespace: if parsed.use_vllm_attention: assert parsed.enable_batching, "--enable_batching is required for using vLLM attention." assert parsed.target_kind == "cuda", "vLLM attention is only supported for CUDA." - assert tvm.get_global_func("tvm.contrib.vllm.single_query_cached_kv_attention", True), "TVM needs to be built with -DUSE_VLLM=ON." + assert tvm.get_global_func( + "tvm.contrib.vllm.single_query_cached_kv_attention", True + ), "TVM needs to be built with -DUSE_VLLM=ON." parsed.artifact_path = os.path.join( parsed.artifact_path, f"{parsed.model}-{parsed.quantization.name}" @@ -391,10 +425,10 @@ def _setup_model_path(args: argparse.Namespace): # pylint: disable=too-many-bra def validate_config(model_path: str): if os.path.exists(os.path.join(model_path, "mlc-chat-config.json")): raise KeyError( - "The model located in the directory {} has already been compiled by MLC-LLM. There is" - " no need to compile it again. If you wish to compile a new model, please provide a" - " directory (or hf-path) that contains the pre-compiled model in raw HuggingFace" - " format instead.".format(model_path) + f"The model located in the directory {model_path} has already been compiled " + "by MLC-LLM. There is no need to compile it again. If you wish to compile " + "a new model, please provide a directory (or hf-path) that contains the " + "pre-compiled model in raw HuggingFace format instead." 
) if model_path.split("/")[-1].startswith("minigpt"): # minigpt does not contain a config.json file so we skip the check @@ -467,12 +501,13 @@ def mod_transform_before_build( if max_seq_len: num_key_value_heads = config.get_num_key_value_heads() + # pylint: disable=no-value-for-parameter mod = fuse_split_rotary_embedding( - config.num_attention_heads // args.num_shards, - num_key_value_heads // args.num_shards, - config.hidden_size // args.num_shards, - config.position_embedding_base, - )(mod) + config.num_attention_heads // args.num_shards, + num_key_value_heads // args.num_shards, + config.hidden_size // args.num_shards, + config.position_embedding_base, + )(mod) if args.target_kind == "cuda": patterns = [] @@ -480,6 +515,7 @@ def mod_transform_before_build( has_cutlass = tvm.get_global_func("relax.ext.cutlass", True) if has_cutlass and not args.no_cutlass_attn: + # pylint: disable=no-value-for-parameter if args.use_flash_attn_mqa: mod = rewrite_attention(use_flash_mqa=True)(mod) mod = rewrite_attention(use_flash_mqa=False)(mod) @@ -565,7 +601,6 @@ def dump_mlc_chat_config( config["top_p"] = top_p config["mean_gen_len"] = mean_gen_len config["max_gen_len"] = max_gen_len - config["max_window_size"] = max_window_size config["num_shards"] = args.num_shards config["shift_fill_factor"] = shift_fill_factor if rwkv_world: @@ -575,6 +610,12 @@ def dump_mlc_chat_config( config["model_category"] = args.model_category config["model_name"] = args.model config["vocab_size"] = vocab_size + if args.sliding_window != -1: + # Do not add max window size if use sliding window + config["sliding_window"] = args.sliding_window + config["sliding_window_chunk_size"] = args.sliding_window_chunk_size + else: + config["max_window_size"] = max_window_size args.chat_config_path = os.path.join(args.params_path, "mlc-chat-config.json") with open(args.chat_config_path, "w", encoding="utf-8") as outfile: @@ -640,7 +681,7 @@ def build_model_from_args(args: argparse.Namespace): if args.quantization == "q4f16_0": print( "WARNING: q4f16_1 is preferred to q4f16_0, " - "and it is highly recommended to use q4f16_1 instaed" + "and it is highly recommended to use q4f16_1 instead" ) if args.num_shards > 1: if (not args.build_model_only) and (not args.convert_weight_only): @@ -670,7 +711,7 @@ def build_model_from_args(args: argparse.Namespace): if not use_cache or args.convert_weight_only: model_generators = { "llama": llama, - "mistral": llama, + "mistral": mistral, "stablelm_epoch": stablelm_3b, "gpt_neox": gpt_neox, "gpt_bigcode": gpt_bigcode, @@ -691,6 +732,10 @@ def build_model_from_args(args: argparse.Namespace): args, config ) + if args.model_category == "mistral": + args.sliding_window = model_config.sliding_window + args.sliding_window_chunk_size = model_config.sliding_window_chunk_size + for qspec_updater_class in param_manager.qspec_updater_classes: qspec_updater = qspec_updater_class(param_manager) qspec_updater.visit_module(mod) diff --git a/mlc_llm/relax_model/mistral.py b/mlc_llm/relax_model/mistral.py new file mode 100644 index 0000000000..1ef00ff577 --- /dev/null +++ b/mlc_llm/relax_model/mistral.py @@ -0,0 +1,1062 @@ +# pylint: disable=too-many-lines, missing-class-docstring, missing-function-docstring +"""Implements the mistal model with sliding window attention.""" + +import math +from dataclasses import dataclass +from typing import Any, List, Optional, Tuple + +import numpy as np +import tvm +from tvm import relax, te +from tvm.relax.op import ccl +from tvm.relax.testing import nn +from tvm.script import 
relax as R + +from ..quantization import ParamQuantKind, QuantizationScheme +from .commons import create_metadata_func +from .modules import ModuleList +from .param_manager import ParamManager + + +@dataclass +class MistralConfig: + """Configuration for mistral model.""" + + def __init__( + self, + bos_token_id=1, + eos_token_id=2, + pad_token_id=-1, + hidden_act="silu", + hidden_size=4096, + initializer_range=0.02, + intermediate_size=14336, + max_position_embeddings=32768, + num_attention_heads=32, + num_hidden_layers=32, + num_key_value_heads=8, + rms_norm_eps=1e-5, + rope_theta=10000.0, + sliding_window=4096, + tie_word_embeddings=False, + vocab_size=32000, + dtype="float32", + sliding_window_chunk_size=-1, + max_sequence_length=-1, # Does not play a role, kept for compatibility. + combine_matmul=True, + build_model_only=False, + num_shards=1, + **kwargs, + ): + self.bos_token_id = bos_token_id + self.eos_token_id = eos_token_id + self.pad_token_id = pad_token_id + self.hidden_act = hidden_act + self.hidden_size = hidden_size + self.initializer_range = initializer_range + self.intermediate_size = intermediate_size + self.max_position_embeddings = max_position_embeddings + self.num_attention_heads = num_attention_heads + self.num_hidden_layers = num_hidden_layers + self.num_key_value_heads = num_key_value_heads + self.rms_norm_eps = rms_norm_eps + self.rope_theta = rope_theta + self.sliding_window = sliding_window + self.tie_word_embeddings = tie_word_embeddings + self.vocab_size = vocab_size + self.dtype = dtype + if sliding_window_chunk_size == -1: + # chunk size same as sliding window by default + self.sliding_window_chunk_size = self.sliding_window + else: + self.sliding_window_chunk_size = sliding_window_chunk_size + self.max_sequence_length = max_sequence_length + self.combine_matmul = combine_matmul + if build_model_only and num_shards > 1: + self.num_shards = num_shards + else: + self.num_shards = 1 + self.kwargs = kwargs + + def get_num_key_value_heads(self): + if self.num_key_value_heads is None: + return self.num_attention_heads + + return self.num_key_value_heads + + +class Linear(nn.Module): + def __init__(self, in_features, out_features, dtype: str, bias=True): + self.in_features = in_features + self.out_features = out_features + self.weight = nn.Parameter((out_features, in_features), dtype=dtype, name="linear_weight") + if bias: + self.bias = nn.Parameter((out_features,), dtype=dtype, name="linear_bias") + else: + self.bias = None + + def forward(self, input: relax.Expr) -> relax.Var: + return nn.emit(relax.op.linear(input, self.weight, self.bias)) + + +class Embedding(nn.Module): + def __init__(self, num_embeddings, embedding_dim, dtype: str): + self.num_embeddings = num_embeddings + self.embedding_dim = embedding_dim + self.weight = nn.Parameter( + (num_embeddings, embedding_dim), dtype=dtype, name="embedding_weight" + ) + + def forward(self, x: relax.Expr) -> relax.Var: + from tvm.relax.op import ( # pylint: disable=import-outside-toplevel + reshape, + take, + ) + + ndim = x.struct_info.ndim + if ndim == 1: + return nn.emit(take(self.weight, x, axis=0)) + else: + x_shape = x.struct_info.shape.values + emb_size = self.weight.struct_info.shape.values[-1] + x = nn.emit(reshape(x, shape=[-1])) + embedding = nn.emit(take(self.weight, x, axis=0)) + return nn.emit(reshape(embedding, [*x_shape, emb_size])) + + +class MistralRMSNorm(nn.Module): + def __init__(self, hidden_size, dtype, eps=1e-6): + self.weight = nn.Parameter((hidden_size,), dtype=dtype, name="rms_norm_weight") + 
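+        # RMSNorm: y = weight * x / sqrt(mean(x^2) + eps); forward() below keeps the
+        # mean-of-squares reduction in float32 and casts back to `dtype` at the end.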
self.variance_epsilon = tvm.tir.const(eps, dtype) + + def forward(self, hidden_states): + from tvm import te, tir + + def f_rms_norm(x, weight): + is_float32 = x.dtype == "float32" + + def f_square(x): + return tir.Cast("float32", x) * tir.Cast("float32", x) if not is_float32 else x * x + + k = te.reduce_axis((0, x.shape[2]), name="k") + square_sum = te.compute( + (x.shape[0], x.shape[1]), + lambda bsz, i: te.sum(f_square(x[bsz, i, k]), axis=k), + name=x.op.name + "red_temp", + ) + + def f_div_cast(bsz, i, k): + x_val = x[bsz, i, k] + if not is_float32: + x_val = tir.Cast("float32", x_val) + return x_val / tir.sqrt(square_sum[bsz, i] / x.shape[2] + self.variance_epsilon) + + def f_mul_cast(x, y): + value = x * y + if not is_float32: + value = tir.Cast(x.dtype, value) + return value + + return te.compute( + x.shape, + lambda bsz, i, k: f_mul_cast(weight(k), f_div_cast(bsz, i, k)), + name="rms_norm", + ) + + return nn.emit_te(f_rms_norm, hidden_states, self.weight, primfunc_name_hint="rms_norm") + + +class MistralMLP(nn.Module): + def __init__(self, config: MistralConfig): + self.combine_matmul = config.combine_matmul + self.num_shards = config.num_shards + hidden_size = config.hidden_size + intermediate_size = config.intermediate_size // self.num_shards + dtype = config.dtype + if self.combine_matmul: + self.gate_up_proj = Linear(hidden_size, 2 * intermediate_size, dtype=dtype, bias=False) + self.down_proj = Linear(intermediate_size, hidden_size, dtype=dtype, bias=False) + self.gate_up_proj.weight.shard_dim = 0 + self.gate_up_proj.weight.shard_strategy = "shard_gate_up" + self.down_proj.weight.shard_dim = 1 + self.down_proj.weight.shard_strategy = "shard_mlp_k" + else: + self.gate_proj = Linear(hidden_size, intermediate_size, dtype=dtype, bias=False) + self.down_proj = Linear(intermediate_size, hidden_size, dtype=dtype, bias=False) + self.up_proj = Linear(hidden_size, intermediate_size, dtype=dtype, bias=False) + + def forward(self, x): + if self.combine_matmul: + gate_up_results = nn.emit( + relax.op.split( + self.gate_up_proj(x), + indices_or_sections=2, + axis=-1, + ) + ) + gate_result = relax.TupleGetItem(gate_up_results, 0) + up_result = relax.TupleGetItem(gate_up_results, 1) + else: + gate_result = self.gate_proj(x) + up_result = self.up_proj(x) + + result = self.down_proj(relax.op.nn.silu(gate_result) * up_result) + return result + + +def apply_rotary_pos_emb(q, k, base, offset: int = 0): + def f_rotary_embedding(tensor, offset): + dtype = tensor.dtype + head_dim = tensor.shape[-1] + n_feat_half = tensor.shape[-1] // 2 + + def rotary_compute(*idx): + i, j = idx[-3], idx[-1] + pos = (offset + i).astype("float32") + inv_freq = te.const(1, "float32") / ( + te.power( + te.const(base, "float32"), + ((2 * j) % head_dim).astype("float32") / head_dim.astype("float32"), + ) + ) + freq = pos * inv_freq + return te.cos(freq).astype(dtype) * tensor(*idx) + te.sin(freq).astype( + dtype + ) * tvm.tir.Select( + j >= n_feat_half, + tensor[idx[0], i, idx[2], j - n_feat_half], + -tensor[idx[0], i, idx[2], j + n_feat_half], + ) + + return tvm.te.compute(tensor.shape, rotary_compute, name="rotary") + + q_embed = nn.emit_te(f_rotary_embedding, q, offset, primfunc_name_hint="rotary_embedding") + k_embed = nn.emit_te(f_rotary_embedding, k, offset, primfunc_name_hint="rotary_embedding") + return q_embed, k_embed + + +class MistralAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: MistralConfig): + dtype = config.dtype + self.num_shards 
= config.num_shards + self.hidden_size = config.hidden_size + self.num_key_value_heads = config.get_num_key_value_heads() // config.num_shards + self.num_query_heads = config.num_attention_heads // self.num_shards + self.head_dim = self.hidden_size // config.num_attention_heads + self.rope_theta = config.rope_theta + self.sliding_window = config.sliding_window + + self.combine_matmul = config.combine_matmul + if self.combine_matmul: + self.query_key_value_proj = Linear( + self.hidden_size, + (self.num_query_heads + 2 * self.num_key_value_heads) * self.head_dim, + dtype=dtype, + bias=False, + ) + self.query_key_value_proj.weight.shard_dim = 0 + self.query_key_value_proj.weight.shard_strategy = "shard_qkv" + else: + self.q_proj = Linear( + self.hidden_size, + self.num_query_heads * self.head_dim, + dtype=dtype, + bias=False, + ) + self.k_proj = Linear( + self.hidden_size, + self.num_key_value_heads * self.head_dim, + dtype=dtype, + bias=False, + ) + self.v_proj = Linear( + self.hidden_size, + self.num_key_value_heads * self.head_dim, + dtype=dtype, + bias=False, + ) + self.q_proj.weight.shard_dim = 0 + self.k_proj.weight.shard_dim = 0 + self.v_proj.weight.shard_dim = 0 + + self.o_proj = Linear( + self.head_dim * self.num_query_heads, self.hidden_size, dtype=dtype, bias=False + ) + self.o_proj.weight.shard_dim = 1 + self.o_proj.weight.shard_strategy = "shard_o_proj_k" + + def interleave_kv( + self, + key_cur: relax.Expr, + value_cur: relax.Expr, + kv_seq_len: int, + cache_len: int, + cache_offset: int, + past_key_value: Tuple[relax.Expr], + ): + from tvm.relax.op import reshape, squeeze + + # [bsz, t, nh, hd] + kv_cur_shape = key_cur.struct_info.shape + kv_cur_dtype = key_cur.struct_info.dtype + assert kv_cur_shape[0] == 1 # bsz + kv_batched_cache_shape = R.shape( + [kv_cur_shape[0], cache_len, kv_cur_shape[2], kv_cur_shape[3]] + ) + kv_cache_shape = R.shape([cache_len, kv_cur_shape[2], kv_cur_shape[3]]) + + # fecth past keys and values from cache + k_cache, v_cache = past_key_value + + f_kv_cache_view = relax.extern("vm.builtin.attention_kv_cache_view") + key_cached = nn.emit( + relax.Call( + f_kv_cache_view, + args=[k_cache, kv_cache_shape], + sinfo_args=[R.Tensor(kv_cache_shape, kv_cur_dtype)], + ) + ) + value_cached = nn.emit( + relax.Call( + f_kv_cache_view, + args=[v_cache, kv_cache_shape], + sinfo_args=[R.Tensor(kv_cache_shape, kv_cur_dtype)], + ) + ) + key_cached = nn.emit(reshape(key_cached, kv_batched_cache_shape)) + value_cached = nn.emit(reshape(value_cached, kv_batched_cache_shape)) + + def te_unrotate_concat(x, x_cached, cache_offset, cache_len): + return te.compute( + (kv_cur_shape[0], kv_seq_len, kv_cur_shape[2], kv_cur_shape[3]), + lambda b, s, h, d: te.if_then_else( + s < cache_len - cache_offset, + x_cached[b, cache_offset + s, h, d], + te.if_then_else( + s < cache_len, + x_cached[b, s + cache_offset - cache_len, h, d], + x[b, s - cache_len, h, d], + ), + ), + name="unrotate_concat_te", + ) + + key = nn.emit_te( + te_unrotate_concat, + key_cur, + key_cached, + cache_offset, + cache_len, + primfunc_name_hint="te_unrotate_concat_key", + ) + value = nn.emit_te( + te_unrotate_concat, + value_cur, + value_cached, + cache_offset, + cache_len, + primfunc_name_hint="te_unrotate_concat_value", + ) + + # # update cache + # k_cache, v_cache = past_key_value + # squeezed_key = nn.emit(squeeze(key_cur)) + # squeezed_value = nn.emit(squeeze(value_cur)) + + def te_squeeze(x): + return te.compute( + x.shape[1:], + lambda s, h, d: x[0, s, h, d], + name="squeeze_te", + ) + + # update cache + 
squeezed_key = nn.emit_te(te_squeeze, key_cur) + squeezed_value = nn.emit_te(te_squeeze, value_cur) + + f_kv_cache_overwrite = relax.extern("vm.builtin.attention_kv_cache_window_override") + k_cache = nn.emit( + relax.Call( + f_kv_cache_overwrite, + args=[k_cache, squeezed_key, relax.PrimValue(self.sliding_window)], + sinfo_args=[relax.ObjectStructInfo()], + ) + ) + v_cache = nn.emit( + relax.Call( + f_kv_cache_overwrite, + args=[v_cache, squeezed_value, relax.PrimValue(self.sliding_window)], + sinfo_args=[relax.ObjectStructInfo()], + ) + ) + + return key, value, (k_cache, v_cache) + + def forward( + self, + hidden_states: relax.Expr, + all_seq_len_shape: relax.Expr, + cache_len_shape: relax.Expr, + kv_seq_len_shape: relax.Expr, + past_key_value: Tuple[relax.Expr], + attention_mask: Optional[relax.Expr] = None, + ) -> Tuple[relax.Expr, Optional[relax.Expr], Optional[Tuple[relax.Expr]]]: + # pylint: disable=import-outside-toplevel + from tvm.relax.op import astype, matmul, maximum, permute_dims, reshape, split + from tvm.relax.op.nn import softmax + + bsz, q_len, _ = hidden_states.struct_info.shape + assert bsz == 1, "Only support batch size 1 at this moment." + + if self.combine_matmul: + qkv_cur = nn.emit( + split( + self.query_key_value_proj(hidden_states), + indices_or_sections=[ + self.num_query_heads * self.head_dim, + (self.num_query_heads + self.num_key_value_heads) * self.head_dim, + ], + axis=-1, + ) + ) + query = relax.TupleGetItem(qkv_cur, 0) + key_cur = relax.TupleGetItem(qkv_cur, 1) + value_cur = relax.TupleGetItem(qkv_cur, 2) + else: + query = self.q_proj(hidden_states) + key_cur = self.k_proj(hidden_states) + value_cur = self.v_proj(hidden_states) + + query = nn.emit( + reshape( + query, + (bsz, q_len, self.num_query_heads, self.head_dim), + ), + ) + key_cur = nn.emit( + reshape( + key_cur, + (bsz, q_len, self.num_key_value_heads, self.head_dim), + ), + ) + value_cur = nn.emit( + reshape( + value_cur, + (bsz, q_len, self.num_key_value_heads, self.head_dim), + ), + ) + + all_seq_len = all_seq_len_shape.struct_info.values[0] + offset = all_seq_len - q_len + query, key_cur = apply_rotary_pos_emb( + query, + key_cur, + self.rope_theta, + offset=offset, + ) + + # concat current kv with cached kv (unrotating the cache) + cache_len = cache_len_shape.struct_info.values[0] + kv_seq_len = kv_seq_len_shape.struct_info.values[0] + cache_offset = (all_seq_len - q_len) % self.sliding_window + key, value, updated_key_value = self.interleave_kv( + key_cur, value_cur, kv_seq_len, cache_len, cache_offset, past_key_value + ) + + if self.num_key_value_heads != self.num_query_heads: + n_rep = self.num_query_heads // self.num_key_value_heads + key = nn.emit(relax.op.repeat(key, n_rep, axis=2)) + value = nn.emit(relax.op.repeat(value, n_rep, axis=2)) + + query = nn.emit(permute_dims(query, [0, 2, 1, 3])) + key = nn.emit(permute_dims(key, [0, 2, 1, 3])) + value = nn.emit(permute_dims(value, [0, 2, 1, 3])) + + attn_weights = nn.emit( + matmul(query, permute_dims(key, [0, 1, 3, 2])) + / relax.const(math.sqrt(self.head_dim), query.struct_info.dtype) + ) + + tvm.ir.assert_structural_equal( + attention_mask.struct_info.shape.values, + (bsz, tvm.tir.IntImm("int64", 1), q_len, kv_seq_len), + ) + + attn_weights = nn.emit( + maximum( + attn_weights, + relax.const( + tvm.tir.min_value(attn_weights.struct_info.dtype).value, + attn_weights.struct_info.dtype, + ), + ) + ) + attn_weights = nn.emit(relax.op.minimum(attn_weights, attention_mask)) + + # upcast attention to fp32 + if attn_weights.struct_info.dtype 
!= "float32": + attn_weights = astype(attn_weights, "float32") + attn_weights = nn.emit(softmax(attn_weights, axis=-1)) + if attn_weights.struct_info.dtype != query.struct_info.dtype: + attn_weights = astype(attn_weights, query.struct_info.dtype) + attn_output = nn.emit(matmul(attn_weights, value)) + + attn_output = nn.emit(permute_dims(attn_output, [0, 2, 1, 3])) + attn_output = nn.emit( + reshape(attn_output, (bsz, q_len, self.head_dim * self.num_query_heads)) + ) + + attn_output = self.o_proj(attn_output) + + return attn_output, ((None, None) if updated_key_value is None else updated_key_value) + + +class MistralDecoderLayer(nn.Module): + def __init__(self, config: MistralConfig): + self.hidden_size = config.hidden_size + self.self_attn = MistralAttention(config) + self.mlp = MistralMLP(config) + self.input_layernorm = MistralRMSNorm( + config.hidden_size, dtype=config.dtype, eps=config.rms_norm_eps + ) + self.post_attention_layernorm = MistralRMSNorm( + config.hidden_size, dtype=config.dtype, eps=config.rms_norm_eps + ) + + def forward( + self, + hidden_states: relax.Expr, + all_seq_len_shape: relax.Expr, + cache_len_shape: relax.Expr, + kv_seq_len_shape: relax.Expr, + past_key_value: Tuple[relax.Expr], + attention_mask: Optional[relax.Expr] = None, + ) -> Tuple[relax.Expr, Optional[Tuple[relax.Expr, relax.Expr]]]: + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, present_key_value = self.self_attn( + hidden_states=hidden_states, + past_key_value=past_key_value, + attention_mask=attention_mask, + all_seq_len_shape=all_seq_len_shape, + cache_len_shape=cache_len_shape, + kv_seq_len_shape=kv_seq_len_shape, + ) + if self.self_attn.num_shards > 1: + residual = nn.emit( + residual / R.const(self.self_attn.num_shards, dtype=residual.struct_info.dtype) + ) + hidden_states = nn.emit(residual + hidden_states) + if self.self_attn.num_shards > 1: + hidden_states = nn.emit(ccl.allreduce(hidden_states, "sum")) + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + if self.mlp.num_shards > 1: + residual = nn.emit( + residual / R.const(self.mlp.num_shards, dtype=residual.struct_info.dtype) + ) + hidden_states = nn.emit(residual + hidden_states) + if self.mlp.num_shards > 1: + hidden_states = nn.emit(ccl.allreduce(hidden_states, "sum")) + return hidden_states, present_key_value + + +def _make_sliding_window_mask(input_shape, kv_seq_len, sliding_window, dtype): + # See `tests/python/test_sliding_window_mask.py` for more on its behavior. + # [bsz, tgt_len] -> [bsz, 1, tgt_len, kv_seq_len] + + bsz, tgt_len = input_shape # TODO: only support batch size of 1 for now + cache_len = kv_seq_len - tgt_len # number of elements in cache + + if isinstance(tgt_len, tvm.tir.Var) or tgt_len > 1: + # Either 1. First prefill, or 2. Subsequent prefill + from tvm.relax.op import broadcast_to # pylint: disable=import-outside-toplevel + + def sliding_window_min_max_te(sliding_window): + return te.compute( + (tgt_len, kv_seq_len), + lambda i, j: tvm.tir.Select( + tvm.tir.all(i + cache_len >= j, i + cache_len - j < sliding_window), + tvm.tir.max_value(dtype), + tvm.tir.min_value(dtype), + ), + name="make_diag_mask_sliding_window_te", + ) + + mask = nn.emit_te(sliding_window_min_max_te, sliding_window) + return nn.emit(broadcast_to(mask, (bsz, 1, tgt_len, kv_seq_len))) + + else: + # 3. 
Decode (equivalent to prefilling a chunk of size 1) + # Mask nothing here since WS == cache_size + bsz, tgt_len = input_shape + return nn.emit( + relax.op.full( + (bsz, 1, tgt_len, kv_seq_len), + relax.const(tvm.tir.max_value(dtype).value, dtype), + dtype, + ) + ) + + +class MistralEmbedTokens(nn.Module): + def __init__(self, config: MistralConfig, vocab_size_var: tvm.tir.Var): + self.embed_tokens = Embedding(vocab_size_var, config.hidden_size, dtype=config.dtype) + + def forward(self, input_ids: relax.Expr): + inputs_embeds = self.embed_tokens(input_ids) + return inputs_embeds + + +class MistralEmbedTokensWrapper(nn.Module): + def __init__(self, config: MistralConfig, vocab_size_var: tvm.tir.Var): + # build a wrapper to ensure that the naming of the embed_tokens parameter is consistent + self.model = MistralEmbedTokens(config, vocab_size_var) + + def forward(self, input_ids: relax.Expr): + inputs_embeds = self.model(input_ids) + return inputs_embeds + + +class MistralModel(nn.Module): + def __init__(self, config: MistralConfig, vocab_size_var: tvm.tir.Var, sep_embed: bool = False): + self.num_shards = config.num_shards + self.padding_idx = config.pad_token_id + self.embed_tokens = None + + if not sep_embed: + self.embed_tokens = Embedding(vocab_size_var, config.hidden_size, dtype=config.dtype) + + self.layers = ModuleList( + [MistralDecoderLayer(config) for _ in range(config.num_hidden_layers)] + ) + self.norm = MistralRMSNorm(config.hidden_size, dtype=config.dtype, eps=config.rms_norm_eps) + self.sliding_window = config.sliding_window + + def forward( + self, + inputs: relax.Expr, + all_seq_len_shape: relax.Expr, + cache_len_shape: relax.Expr, + kv_seq_len_shape: relax.Expr, + past_key_values: relax.Expr, + ): + if self.num_shards > 1: + inputs = nn.emit(ccl.broadcast_from_worker0(inputs)) + if self.embed_tokens: + inputs_embeds = self.embed_tokens(inputs) + else: + inputs_embeds = inputs + # retrieve input_ids + batch_size, seq_length, _ = inputs_embeds.struct_info.shape + kv_seq_len = kv_seq_len_shape.struct_info.values[0] + + # embed positions + attention_mask = _make_sliding_window_mask( + (batch_size, seq_length), + kv_seq_len, + self.sliding_window, + inputs_embeds.struct_info.dtype, + ) + + hidden_states = inputs_embeds + + # decoder layers + next_decoder_cache = () + + for idx, decoder_layer in enumerate(self.layers): + assert past_key_values is not None + past_key_value = (past_key_values[idx * 2], past_key_values[idx * 2 + 1]) + + hidden_states, key_value_cache = decoder_layer( + hidden_states, + attention_mask=attention_mask, + past_key_value=past_key_value, + all_seq_len_shape=all_seq_len_shape, + cache_len_shape=cache_len_shape, + kv_seq_len_shape=kv_seq_len_shape, + ) + next_decoder_cache += key_value_cache + + hidden_states = self.norm(hidden_states) + + assert len(next_decoder_cache) == len(self.layers) * 2 + return hidden_states, next_decoder_cache + + +class MistralForCausalLM(nn.Module): + def __init__(self, config: MistralConfig, vocab_size_var: tvm.tir.Var, sep_embed: bool = False): + self.model = MistralModel(config, vocab_size_var, sep_embed) + self.lm_head = Linear(config.hidden_size, vocab_size_var, dtype=config.dtype, bias=False) + + ############ Rotary embedding constants ############ + assert config.hidden_size % config.num_attention_heads == 0 + head_dim = config.hidden_size // config.num_attention_heads + + # Set the cached sin/cos to the maximum of 2048 and max seq len. + # This will be eliminated further with online rotary embedding calculation. 
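+ # `rope_cache_len` stays symbolic here; `get_model` below precomputes the cos/sin tables for 2048 positions and supplies them as the final two parameters.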
+ rope_cache_len = te.var("rope_cache_len", "int64") + self.cos_cached = nn.Parameter( + (rope_cache_len, head_dim), dtype=config.dtype, name="cos_cached" + ) + self.sin_cached = nn.Parameter( + (rope_cache_len, head_dim), dtype=config.dtype, name="sin_cached" + ) + ############ End ############ + + def forward( + self, + inputs: relax.Expr, + all_seq_len_shape: relax.Expr, + cache_len_shape: relax.Expr, + kv_seq_len_shape: relax.Expr, + past_key_values: relax.Expr, + ): + hidden_states, key_value_cache = self.model( + inputs=inputs, + all_seq_len_shape=all_seq_len_shape, + cache_len_shape=cache_len_shape, + kv_seq_len_shape=kv_seq_len_shape, + past_key_values=past_key_values, + ) + + def te_slicing(x: te.Tensor): + return te.compute( + shape=(1, 1, x.shape[-1]), + fcompute=lambda i, j, k: x[i, x.shape[1] - 1, k], + name="slice", + ) + + logits = self.lm_head(nn.emit_te(te_slicing, hidden_states, primfunc_name_hint="slice")) + if logits.struct_info.dtype != "float32": + logits = nn.emit(relax.op.astype(logits, "float32")) + + return logits, key_value_cache + + +def get_param_quant_kind(name: str, param_info: relax.TensorStructInfo) -> ParamQuantKind: + if "embed_tokens" in name: + return ParamQuantKind.embedding_table + elif "lm_head.weight" in name: + return ParamQuantKind.final_fc_weight + elif param_info.ndim == 2 and name.endswith(".weight"): + return ParamQuantKind.linear_weight + else: + return ParamQuantKind.others + + +def create_embed_func( + bb: relax.BlockBuilder, + param_manager: ParamManager, + config: MistralConfig, + quant_scheme: QuantizationScheme, +) -> None: + func_name = "embed" + + bsz = 1 + seq_len = tvm.tir.Var("n", "int64") + with bb.function(func_name): + model = MistralEmbedTokensWrapper(config, tvm.tir.Var("vocab_size", "int64")) + param_manager.register_params(model, func_name, quant_scheme, get_param_quant_kind) + + input_ids = nn.Placeholder((bsz, seq_len), dtype="int32", name="input_ids") + with bb.dataflow(): + inputs_embeds = model(input_ids) + params = [input_ids] + model.parameters() + gv = bb.emit_output(inputs_embeds) + bb.emit_func_output(gv, params) + + mod = bb.get() + gv = mod.get_global_var(func_name) + bb.update_func(gv, mod[gv].with_attr("num_input", 1)) + + +def create_encoding_func( + bb: relax.BlockBuilder, + param_manager: ParamManager, + config: MistralConfig, + quant_scheme: QuantizationScheme, + sep_embed: bool = False, +) -> None: + func_name = "prefill_with_embed" if sep_embed else "prefill" + + bsz = 1 + seq_len = tvm.tir.Var("n", "int64") # number of tokens for the input + all_seq_len = tvm.tir.Var("m", "int64") # total_seq_len in `llm_chat.cc` (including seq_len) + cache_len = tvm.tir.Var("c", "int64") # cache_len captures number of elements in the cache + kv_seq_len = tvm.tir.Var( + "k", "int64" + ) # kv_seq_len captures number of elements in cache + seq_len + + hidden_size = config.hidden_size + with bb.function(func_name): + model = MistralForCausalLM(config, tvm.tir.Var("vocab_size", "int64"), sep_embed) + param_manager.register_params(model, func_name, quant_scheme, get_param_quant_kind) + + inputs = ( + nn.Placeholder((bsz, seq_len, hidden_size), dtype=config.dtype, name="inputs_embeds") + if sep_embed + else nn.Placeholder((bsz, seq_len), dtype="int32", name="input_ids") + ) + all_seq_len_shape = relax.Var("all_seq_len", relax.ShapeStructInfo((all_seq_len,))) + cache_len_shape = relax.Var("cache_len", relax.ShapeStructInfo((cache_len,))) + kv_seq_len_shape = relax.Var("kv_seq_len", relax.ShapeStructInfo((kv_seq_len,))) + 
past_key_values = relax.Var( + "kv_cache", + relax.TupleStructInfo( + [relax.ObjectStructInfo() for _ in range(config.num_hidden_layers * 2)] + ), + ) + with bb.dataflow(): + logits, key_value_cache = model( + inputs, + all_seq_len_shape, + cache_len_shape, + kv_seq_len_shape, + past_key_values=past_key_values, + ) + params = [ + inputs, + all_seq_len_shape, + cache_len_shape, + kv_seq_len_shape, + past_key_values, + ] + model.parameters() + gv = bb.emit_output((logits, relax.Tuple(key_value_cache))) + bb.emit_func_output(gv, params) + + mod = bb.get() + gv = mod.get_global_var(func_name) + bb.update_func(gv, mod[gv].with_attr("num_input", 5)) + + +def create_decoding_func( + bb: relax.BlockBuilder, + param_manager: ParamManager, + config: MistralConfig, + quant_scheme: QuantizationScheme, +) -> None: + func_name = "decode" + + bsz = 1 + all_seq_len = tvm.tir.Var("m", "int64") + cache_len = tvm.tir.Var("c", "int64") # cache_len captures number of elements in the cache + kv_seq_len = tvm.tir.Var( + "k", "int64" + ) # kv_seq_len captures number of elements in cache + seq_len + + with bb.function(func_name): + model = MistralForCausalLM(config, tvm.tir.Var("vocab_size", "int64")) + param_manager.register_params(model, func_name, quant_scheme, get_param_quant_kind) + + input_ids = nn.Placeholder((bsz, 1), dtype="int32", name="input_ids") + all_seq_len_shape = relax.Var("all_seq_len", relax.ShapeStructInfo((all_seq_len,))) + cache_len_shape = relax.Var("cache_len", relax.ShapeStructInfo((cache_len,))) + kv_seq_len_shape = relax.Var("kv_seq_len", relax.ShapeStructInfo((kv_seq_len,))) + past_key_values = relax.Var( + "kv_cache", + relax.TupleStructInfo( + [relax.ObjectStructInfo() for _ in range(config.num_hidden_layers * 2)] + ), + ) + with bb.dataflow(): + logits, key_value_cache = model( + input_ids, + all_seq_len_shape, + cache_len_shape, + kv_seq_len_shape, + past_key_values=past_key_values, + ) + params = [ + input_ids, + all_seq_len_shape, + cache_len_shape, + kv_seq_len_shape, + past_key_values, + ] + model.parameters() + gv = bb.emit_output((logits, relax.Tuple(key_value_cache))) + bb.emit_func_output(gv, params) + + mod = bb.get() + gv = mod.get_global_var(func_name) + bb.update_func(gv, mod[gv].with_attr("num_input", 5)) + + +def create_kv_cache_func(bb: relax.BlockBuilder, config: MistralConfig) -> None: + num_key_value_heads = config.get_num_key_value_heads() // config.num_shards + init_shape = relax.ShapeExpr( + ( + config.sliding_window, + num_key_value_heads, + config.hidden_size // config.num_attention_heads, # head_dim + ) + ) + with bb.function("create_kv_cache", []): + with bb.dataflow(): + zeros = bb.emit(relax.op.zeros(init_shape, config.dtype)) + caches = [] + f_kv_cache_create = relax.extern("vm.builtin.attention_kv_cache_create") + for _ in range(config.num_hidden_layers * 2): + caches.append( + bb.emit( + relax.Call( + f_kv_cache_create, + args=[zeros, init_shape, relax.PrimValue(0)], + sinfo_args=[relax.ObjectStructInfo()], + ) + ) + ) + gv = bb.emit_output(caches) + bb.emit_func_output(gv) + + +def create_softmax_func(bb: relax.BlockBuilder, config: MistralConfig) -> None: + with bb.function("softmax_with_temperature"): + logits = nn.Placeholder( + (1, 1, tvm.tir.Var("vocab_size", "int64")), dtype="float32", name="logits" + ) + temperature = nn.Placeholder((), dtype="float32", name="temperature") + with bb.dataflow(): + div = bb.emit(relax.op.divide(logits, temperature)) + softmax = bb.emit(relax.op.nn.softmax(div, axis=-1)) + gv = bb.emit_output(softmax) + 
bb.emit_func_output(gv, [logits, temperature]) + + +def get_model(args, hf_config): + model_name = args.model + dtype = args.quantization.model_dtype + sep_embed = args.sep_embed + assert not sep_embed, "Mistral does not support separate embedding." + + if args.sliding_window != -1: + hf_config["sliding_window"] = args.sliding_window + + config = MistralConfig( + **hf_config, + dtype=dtype, + combine_matmul=True, + num_shards=args.num_shards, + build_model_only=args.build_model_only, + sliding_window_chunk_size=args.sliding_window_chunk_size, + ) + + param_manager = ParamManager() + bb = relax.BlockBuilder() + + create_encoding_func(bb, param_manager, config, args.quantization, sep_embed) + create_decoding_func(bb, param_manager, config, args.quantization) + create_kv_cache_func(bb, config) + create_softmax_func(bb, config) + create_metadata_func( + bb, + model_name=model_name, + max_window_size=config.max_sequence_length, + stop_tokens=[2], + add_prefix_space=False, + ) + + mod = bb.get() + for gv in mod.functions: + func = mod[gv] + if isinstance(func, relax.Function): + mod[gv] = func.with_attr( + "tir_var_upper_bound", + { + "n": config.sliding_window_chunk_size, + "c": config.sliding_window, + "k": config.sliding_window + config.sliding_window_chunk_size, + }, + ) + + if args.build_model_only: + return mod, param_manager, None, config + + def f_convert_pname_fwd(pname: str) -> List[str]: + if not config.combine_matmul: + return [pname] + + qkv_str = "query_key_value_proj" + gate_up_str = "gate_up_proj" + if qkv_str in pname: + return [ + pname.replace(qkv_str, "q_proj"), + pname.replace(qkv_str, "k_proj"), + pname.replace(qkv_str, "v_proj"), + ] + elif gate_up_str in pname: + return [ + pname.replace(gate_up_str, "gate_proj"), + pname.replace(gate_up_str, "up_proj"), + ] + else: + return [pname] + + def f_convert_param_bkwd(torch_pname: str, torch_param): + if not config.combine_matmul: + return [(torch_pname, torch_param.astype(dtype))] + + combined_layers = ["q_proj", "k_proj", "v_proj", "gate_proj", "up_proj"] + if any([name in torch_pname for name in combined_layers]): + return None + return [(torch_pname, torch_param.astype(dtype))] + + def f_compute_relax_param(relax_pname: str, torch_params: List[Any]): + # Expected to enter this function only for the combined linear matmul weights. + # Other weights are supposed to be loaded in `f_convert_param_bkwd` since + # each other relax param has a unique corresponding torch param. + if not config.combine_matmul: + # When matmul combination is not turned on, each relax param has a unique + # corresponding torch param, and this function is not expected to be entered. 
+ raise NotImplementedError( + "Matmul combination is not turned on, and the function " + "is not expected to be entered" + ) + hidden_size = config.hidden_size + head_dim = config.hidden_size // config.num_attention_heads + + if "query_key_value_proj" in relax_pname: + q_heads = config.num_attention_heads + kv_heads = config.get_num_key_value_heads() + q, k, v = torch_params + assert q.shape == (q_heads * head_dim, hidden_size) + assert k.shape == (kv_heads * head_dim, hidden_size) + assert v.shape == (kv_heads * head_dim, hidden_size) + qkv = np.concatenate([q, k, v], axis=0).astype(dtype) + return qkv + if "gate_up_proj" in relax_pname: + gate, up = torch_params + gate_up = np.concatenate([gate, up], axis=0).astype(dtype) + return gate_up + raise ValueError("Unexpected param loading") + + param_manager.set_param_loading_func( + args.model_path, + args.use_safetensors, + f_convert_pname_fwd, + f_convert_param_bkwd, + f_compute_relax_param, + ) + + device = tvm.cpu() + param_list = [None] * param_manager.nparam_to_load + + head_dim = config.hidden_size / config.num_attention_heads + inv_freq = 1.0 / (config.rope_theta ** (np.arange(0, head_dim, 2).astype("float32") / head_dim)) + + # The following cos/sin values can be removed but **are kept for compatibility issues**. + t = np.arange(2048, dtype=inv_freq.dtype) + freqs = np.einsum("i,j->ij", t, inv_freq) + emb = np.concatenate((freqs, freqs), axis=-1) + param_list[-2] = tvm.nd.array(np.cos(emb).astype(config.dtype), device) + param_list[-1] = tvm.nd.array(np.sin(emb).astype(config.dtype), device) + + return mod, param_manager, param_list, config diff --git a/tests/python/support/test_sliding_window_mask.py b/tests/python/support/test_sliding_window_mask.py new file mode 100644 index 0000000000..fa727ced9a --- /dev/null +++ b/tests/python/support/test_sliding_window_mask.py @@ -0,0 +1,339 @@ +# fmt: off +"""For testing `_make_sliding_window_mask` in mistral.py""" + +import unittest + +import numpy as np +import tvm +from tvm import relax +from tvm.runtime import ShapeTuple + +from mlc_llm.relax_model.mistral import _make_sliding_window_mask + + +def _create_vm(): + # pylint: disable=too-many-locals + bb = relax.BlockBuilder() + + # Step 1: Build `_make_sliding_window_mask()` into an IRModule + bsz = tvm.tir.Var("bsz", "int64") + seq_length = tvm.tir.Var("seq_length", "int64") # tgt_len + kv_seq_len = tvm.tir.Var("kv_seq_len", "int64") + sliding_window = tvm.tir.Var("sliding_window", "int64") + + with bb.function("main"): + # Convert to relax.Var because params to an IRModule function needs to be relax.Var + bsz_shape = relax.Var("bsz", relax.ShapeStructInfo((bsz,))) + seq_length_shape = relax.Var("seq_length", relax.ShapeStructInfo((seq_length,))) + kv_seq_len_shape = relax.Var("kv_seq_len", relax.ShapeStructInfo((kv_seq_len,))) + sliding_window_shape = relax.Var("sliding_window", relax.ShapeStructInfo((sliding_window,))) + + # Convert back to tir.Var since `_prepare_sliding_window_mask` needs it to be tir.Var + with bb.dataflow(): + bsz_input = bsz_shape.struct_info.values[0] + seq_length_input = seq_length_shape.struct_info.values[0] + kv_seq_len_input = kv_seq_len_shape.struct_info.values[0] + sliding_window_input = sliding_window_shape.struct_info.values[0] + mask = _make_sliding_window_mask( + (bsz_input, seq_length_input), + kv_seq_len_input, + sliding_window_input, + "float32", + ) + params = [ + bsz_shape, + seq_length_shape, + kv_seq_len_shape, + sliding_window_shape, + ] + gv = bb.emit_output(mask) + bb.emit_func_output(gv, 
params) + + # Step 2. Optimize IRModule + mod = bb.get() + mod = relax.pipeline.get_pipeline()(mod) # pylint: disable=no-value-for-parameter + with tvm.target.Target("cuda"): + mod = tvm.tir.transform.DefaultGPUSchedule()(mod) + + # Step 3. Deploy to GPU + ex = relax.build(mod, "cuda") + vm = relax.VirtualMachine(ex, tvm.cuda()) #pylint: disable=redefined-outer-name + return vm + + +vm = _create_vm() + +class SlidingWindowMaskTest(unittest.TestCase): + """ + The sliding window mask is based on figure 3 of the Mistral paper. + There are three cases when making a mask: first prefill, subsequent prefill, + and decoding. + + 1. First Prefill + This is when the cache is empty (i.e. kv_seq_len == 0). If tgt_len <= sliding_window, + this is just a normal causal mask. Otherwise, e.g. tgt_len = 3, WS = 2, we create a + mask below: + 1, 0, 0 + 1, 1, 0 + 0, 1, 1 + + 2. Subsequent Prefill + This is when the cache is not empty and yet tgt_len > 1. + e.g. t0-t4 in cache; current input is t5-t7; WS=5 + 0, 1, 2, 3, 4, | 5, 6, 7 + + 0, 1, 1, 1, 1, | 1, 0, 0 + 0, 0, 1, 1, 1, | 1, 1, 0 + 0, 0, 0, 1, 1, | 1, 1, 1 + [in cache] [current] + + 3. Decode + It will always be ones with shape (1 + kv_seq_len) since cache_size equals sliding_window. + Note that a prefilling (first or subsequent) with chunk_size of 1 is equivalent to a decode + in mask making. + """ + + ################### 1. TESTS FOR FIRST PREFILL ################### + def test_first_prefill_chunk_size_smaller_than_ws(self): + """ + When chunk size < WS, we return a normal causal mask. + Here, chunk size 3, WS 5. + """ + bsz = ShapeTuple([1]) + seq_length = ShapeTuple([3]) # chunk size is 3 + kv_seq_len = ShapeTuple([3]) + sliding_window = ShapeTuple([5]) + + result = vm["main"](bsz, seq_length, kv_seq_len, sliding_window) + + correct = np.array([[[ + [3.402823e38, -3.402823e38, -3.402823e38], + [3.402823e38, 3.402823e38, -3.402823e38], + [3.402823e38, 3.402823e38, 3.402823e38], + ]]]).astype("float32") + + np.testing.assert_array_equal(result.numpy(), correct) + + def test_first_prefill_chunk_size_equals_ws(self): + """ + When chunk_size == WS, we also return a normal causal mask. + Here both chunk size and WS are 5. + """ + bsz = ShapeTuple([1]) + seq_length = ShapeTuple([5]) + kv_seq_len = ShapeTuple([5]) + sliding_window = ShapeTuple([5]) + + result = vm["main"](bsz, seq_length, kv_seq_len, sliding_window) + + correct = np.array([[[ + [3.402823e38, -3.402823e38, -3.402823e38, -3.402823e38, -3.402823e38], + [3.402823e38, 3.402823e38, -3.402823e38, -3.402823e38, -3.402823e38], + [3.402823e38, 3.402823e38, 3.402823e38, -3.402823e38, -3.402823e38], + [3.402823e38, 3.402823e38, 3.402823e38, 3.402823e38, -3.402823e38], + [3.402823e38, 3.402823e38, 3.402823e38, 3.402823e38, 3.402823e38], + ]]]).astype("float32") + + np.testing.assert_array_equal(result.numpy(), correct) + + def test_first_prefill_chunk_size_greater_than_ws(self): + """ + When chunk_size > WS, return a normal causal mask but each row only has at most WS 1's. + Here chunk_size = 5, WS=3. 
+ """ + bsz = ShapeTuple([1]) + seq_length = ShapeTuple([5]) + kv_seq_len = ShapeTuple([5]) + sliding_window = ShapeTuple([3]) + + result = vm["main"](bsz, seq_length, kv_seq_len, sliding_window) + + correct = np.array([[[ + [3.402823e38, -3.402823e38, -3.402823e38, -3.402823e38, -3.402823e38], + [3.402823e38, 3.402823e38, -3.402823e38, -3.402823e38, -3.402823e38], + [3.402823e38, 3.402823e38, 3.402823e38, -3.402823e38, -3.402823e38], + [-3.402823e38, 3.402823e38, 3.402823e38, 3.402823e38, -3.402823e38], + [-3.402823e38, -3.402823e38, 3.402823e38, 3.402823e38, 3.402823e38], + ]]]).astype("float32") + + np.testing.assert_array_equal(result.numpy(), correct) + + def test_first_prefill_chunk_size_one(self): + """ + Corner case: the prompt only has 1 token. + """ + bsz = ShapeTuple([1]) + seq_length = ShapeTuple([1]) + kv_seq_len = ShapeTuple([1]) + sliding_window = ShapeTuple([3]) + + result = vm["main"](bsz, seq_length, kv_seq_len, sliding_window) + + correct = np.array([[[ + [3.402823e38] + ]]]).astype("float32") + + np.testing.assert_array_equal(result.numpy(), correct) + + ################### 2. TESTS FOR SUBSEQUENT PREFILL ################### + def test_subsequent_prefill_1(self): + """ + Test 1: chunk size is 3, WS is 5, cache carrying t0, t1, t2; input t3, t4, t5. + """ + + bsz = ShapeTuple([1]) + seq_length = ShapeTuple([3]) + kv_seq_len = ShapeTuple([6]) + sliding_window = ShapeTuple([5]) + + result = vm["main"](bsz, seq_length, kv_seq_len, sliding_window) + + correct = np.array([[[ + # pylint: disable=line-too-long + # | IN CACHE | CURRENT CHUNK | + # t0 t1 t2 t3 t4 t5 + [ 3.402823e+38, 3.402823e+38, 3.402823e+38, 3.402823e+38, -3.402823e+38, -3.402823e+38], + [ 3.402823e+38, 3.402823e+38, 3.402823e+38, 3.402823e+38, 3.402823e+38, -3.402823e+38], + [-3.402823e+38, 3.402823e+38, 3.402823e+38, 3.402823e+38, 3.402823e+38, 3.402823e+38] + ]]]).astype("float32") + + np.testing.assert_array_equal(result.numpy(), correct) + + def test_subsequent_prefill_2(self): + """ + Test 2: chunk size is 3, WS is 5, cache carrying t1 - t5 (t0 is overwritten); + input t6, t7, t8. + """ + bsz = ShapeTuple([1]) + seq_length = ShapeTuple([3]) + kv_seq_len = ShapeTuple([8]) + sliding_window = ShapeTuple([5]) + + result = vm["main"](bsz, seq_length, kv_seq_len, sliding_window) + + correct = np.array([[[ + # pylint: disable=line-too-long + # | IN CACHE | CURRENT CHUNK | + # t1 t2 t3 t4 t5 t6 t7 t8 + [-3.402823e+38, 3.402823e+38, 3.402823e+38, 3.402823e+38, 3.402823e+38, 3.402823e+38, -3.402823e+38, -3.402823e+38], + [-3.402823e+38, -3.402823e+38, 3.402823e+38, 3.402823e+38, 3.402823e+38, 3.402823e+38, 3.402823e+38, -3.402823e+38], + [-3.402823e+38, -3.402823e+38, -3.402823e+38, 3.402823e+38, 3.402823e+38, 3.402823e+38, 3.402823e+38, 3.402823e+38] + ]]]).astype("float32") + + np.testing.assert_array_equal(result.numpy(), correct) + + def test_subsequent_prefill_3(self): + """ + Test 3: chunk size is 5, WS is 5, cache carrying t0-t4; input t5-t9. 
+ """ + bsz = ShapeTuple([1]) + seq_length = ShapeTuple([5]) + kv_seq_len = ShapeTuple([10]) + sliding_window = ShapeTuple([5]) + + result = vm["main"](bsz, seq_length, kv_seq_len, sliding_window) + + correct = np.array([[[ + # pylint: disable=line-too-long + # | IN CACHE | CURRENT CHUNK | + # t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 + [-3.402823e+38, 3.402823e+38, 3.402823e+38, 3.402823e+38, 3.402823e+38, 3.402823e+38, -3.402823e+38, -3.402823e+38, -3.402823e+38, -3.402823e+38], + [-3.402823e+38, -3.402823e+38, 3.402823e+38, 3.402823e+38, 3.402823e+38, 3.402823e+38, 3.402823e+38, -3.402823e+38, -3.402823e+38, -3.402823e+38], + [-3.402823e+38, -3.402823e+38, -3.402823e+38, 3.402823e+38, 3.402823e+38, 3.402823e+38, 3.402823e+38, 3.402823e+38, -3.402823e+38, -3.402823e+38], + [-3.402823e+38, -3.402823e+38, -3.402823e+38, -3.402823e+38, 3.402823e+38, 3.402823e+38, 3.402823e+38, 3.402823e+38, 3.402823e+38, -3.402823e+38], + [-3.402823e+38, -3.402823e+38, -3.402823e+38, -3.402823e+38, -3.402823e+38, 3.402823e+38, 3.402823e+38, 3.402823e+38, 3.402823e+38, 3.402823e+38] + ]]]).astype("float32") + + np.testing.assert_array_equal(result.numpy(), correct) + + def test_subsequent_prefill_4(self): + """ + Test 4: chunk size is 5, WS is 3, cache carrying t2-t4 (t0, t1 did not + stay in cache); input t5-t9. + """ + bsz = ShapeTuple([1]) + seq_length = ShapeTuple([5]) + kv_seq_len = ShapeTuple([8]) + sliding_window = ShapeTuple([3]) + + result = vm["main"](bsz, seq_length, kv_seq_len, sliding_window) + + correct = np.array([[[ + # pylint: disable=line-too-long + # | IN CACHE | CURRENT CHUNK | + # t2 t3 t4 t5 t6 t7 t8 t9 + [-3.402823e+38, 3.402823e+38, 3.402823e+38, 3.402823e+38, -3.402823e+38, -3.402823e+38, -3.402823e+38, -3.402823e+38], + [-3.402823e+38, -3.402823e+38, 3.402823e+38, 3.402823e+38, 3.402823e+38, -3.402823e+38, -3.402823e+38, -3.402823e+38], + [-3.402823e+38, -3.402823e+38, -3.402823e+38, 3.402823e+38, 3.402823e+38, 3.402823e+38, -3.402823e+38, -3.402823e+38], + [-3.402823e+38, -3.402823e+38, -3.402823e+38, -3.402823e+38, 3.402823e+38, 3.402823e+38, 3.402823e+38, -3.402823e+38], + [-3.402823e+38, -3.402823e+38, -3.402823e+38, -3.402823e+38, -3.402823e+38, 3.402823e+38, 3.402823e+38, 3.402823e+38] + ]]]).astype("float32") + + np.testing.assert_array_equal(result.numpy(), correct) + + def test_subsequent_prefill_5(self): + """ + Test 5: chunk size is 5, WS is 5, cache carrying t5-t9 (t0-t4 overwritten); + input t10 (remainder of a prompt). Note that this test can also be + viewed as a decode. That is, prefilling a chunk of size 1, is the same is decoding. + """ + bsz = ShapeTuple([1]) + seq_length = ShapeTuple([1]) + kv_seq_len = ShapeTuple([6]) + sliding_window = ShapeTuple([5]) + + result = vm["main"](bsz, seq_length, kv_seq_len, sliding_window) + + correct = np.array([[[ + # pylint: disable=line-too-long + # | IN CACHE |CURRENT CHUNK| + # t5 t6 t7 t8 t9 t10 + [-3.402823e+38, 3.402823e+38, 3.402823e+38, 3.402823e+38, 3.402823e+38, 3.402823e+38] + ]]]).astype("float32") + + np.testing.assert_array_equal(result.numpy(), correct) + + ################### 3. TESTS FOR DECODE ################### + def test_decode_1(self): + """ + Test 1: chunk size is 5, WS is 5, cache carrying t5-t9 (t0-t4 overwritten); + input t10 (decoding). 
+ """ + bsz = ShapeTuple([1]) + seq_length = ShapeTuple([1]) + kv_seq_len = ShapeTuple([6]) + sliding_window = ShapeTuple([5]) + + result = vm["main"](bsz, seq_length, kv_seq_len, sliding_window) + + correct = np.array([[[ + # pylint: disable=line-too-long + # | IN CACHE |CURRENT CHUNK| + # t5 t6 t7 t8 t9 t10 + [-3.402823e+38, 3.402823e+38, 3.402823e+38, 3.402823e+38, 3.402823e+38, 3.402823e+38] + ]]]).astype("float32") + + np.testing.assert_array_equal(result.numpy(), correct) + + def test_decode_2(self): + """ + Test 2 (Cache not full): prompt is size 4, WS is 5, cache carrying t0-t3; input t4. + """ + bsz = ShapeTuple([1]) + seq_length = ShapeTuple([1]) + kv_seq_len = ShapeTuple([5]) + sliding_window = ShapeTuple([5]) + + result = vm["main"](bsz, seq_length, kv_seq_len, sliding_window) + + correct = np.array([[[ + # | IN CACHE |CURRENT CHUNK| + # t0 t1 t2 t3 t4 + [3.402823e+38, 3.402823e+38, 3.402823e+38, 3.402823e+38, 3.402823e+38] + ]]]).astype("float32") + + np.testing.assert_array_equal(result.numpy(), correct) + + +if __name__ == "__main__": + unittest.main() From 47167040be6bac42b67113736c16844b60e400fc Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Sat, 4 Nov 2023 01:20:43 -0700 Subject: [PATCH 086/116] Add Python API for Weight Conversion (#1182) This PR primarily does a major refactoring to introduce Python API that is consistent with the CLI API. Besides, it includes the following fixes and enhancements: - More info provided to `isort` for better formatting in `pyproject.toml`; - Print out the default value of all arguments in argparse command line; - Ensure `--device` is always available locally when doing weight conversion; - Add argument echoing in weight conversion to be consistent with its counterpart in compilation; - Add a consistency checker to make sure the shapes/dtypes of all tensors from weight conversion is consistent with compilation; - Echo the total size of parameters; - Better logging of each parameter's shape and dtype, and either or not its quantized; - More structure robustification, renaming `parameter/` to `loader/` to be more explicit about its intention; - Inline and remove `ParamQuantizer` into the loader to improve logging and the logic flow; - Always add instructions "Use `--xxx` to override" for any options that are auto detected to be more informative to end users; - Fix wrong shape calculation when quantizing `nn.Embedding`; - Fix wrong dtype calculation in group quantization when the input dtype is different from model dtype (e.g. "float32" in torch, but the model dtype in quantization is fp16 in `q4f16_1`); - Fix inconsistent param names in layers such as `GroupQuantizeLinear`; - Fix dtype inconsistency when a parameter is not quantized; - Fix existing unittests. 
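For reference, the new Python entrypoint mirrors what `python/mlc_chat/cli/convert_weight.py` wires up and can be called directly. The sketch below is illustrative only: the model directory, the output directory and the `q4f16_1` choice are placeholders, and it assumes `detect_config` resolves a model directory to its config.json the same way the CLI's `--config` flag does.

    from pathlib import Path

    from mlc_chat.compiler import QUANTIZATION, convert_weight
    from mlc_chat.support.auto_config import detect_config, detect_model_type
    from mlc_chat.support.auto_target import detect_device
    from mlc_chat.support.auto_weight import detect_weight

    model_dir = Path("dist/models/Llama-2-7b-chat-hf")  # placeholder checkout
    config = detect_config(model_dir)  # assumed to resolve to the model's config.json
    source, source_format = detect_weight(
        weight_path=model_dir, config_json_path=config, weight_format="auto"
    )
    convert_weight(
        config=config,
        quantization=QUANTIZATION["q4f16_1"],  # placeholder quantization choice
        model=detect_model_type("auto", config),  # infers the architecture from config.json
        device=detect_device("auto"),  # auto-detects a local device
        source=source,
        source_format=source_format,
        output=Path("dist/llama-2-7b-q4f16_1-MLC"),  # placeholder output directory
    )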
--- pyproject.toml | 2 + python/mlc_chat/cli/compile.py | 17 ++- python/mlc_chat/cli/convert_weight.py | 78 ++++------- python/mlc_chat/compiler/__init__.py | 3 +- python/mlc_chat/compiler/convert_weight.py | 124 +++++++++++++++++ .../{parameter => loader}/__init__.py | 1 + .../huggingface_loader.py | 48 +++---- python/mlc_chat/compiler/loader/loader.py | 11 ++ .../compiler/{parameter => loader}/mapping.py | 18 ++- .../compiler/{parameter => loader}/stats.py | 4 +- .../compiler/{parameter => loader}/utils.py | 41 +----- .../mlc_chat/compiler/model/llama_loader.py | 85 ++++++++++++ .../compiler/model/llama_parameter.py | 61 --------- .../compiler/model/llama_quantization.py | 5 +- python/mlc_chat/compiler/model/model.py | 8 +- .../quantization/group_quantization.py | 129 +++++++++--------- python/mlc_chat/rest.py | 1 + python/mlc_chat/support/auto_config.py | 6 +- python/mlc_chat/support/auto_target.py | 29 ++++ python/mlc_chat/support/auto_weight.py | 27 +++- .../test_sliding_window_mask.py | 3 +- .../{parameter => loader}/test_huggingface.py | 11 +- tests/python/model/test_llama.py | 1 + tests/python/model/test_llama_quantization.py | 22 ++- .../quantization/test_group_quantization.py | 33 ++++- tests/python/support/test_auto_config.py | 1 + tests/python/support/test_auto_weight.py | 1 + 27 files changed, 481 insertions(+), 289 deletions(-) create mode 100644 python/mlc_chat/compiler/convert_weight.py rename python/mlc_chat/compiler/{parameter => loader}/__init__.py (87%) rename python/mlc_chat/compiler/{parameter => loader}/huggingface_loader.py (86%) create mode 100644 python/mlc_chat/compiler/loader/loader.py rename python/mlc_chat/compiler/{parameter => loader}/mapping.py (84%) rename python/mlc_chat/compiler/{parameter => loader}/stats.py (95%) rename python/mlc_chat/compiler/{parameter => loader}/utils.py (59%) create mode 100644 python/mlc_chat/compiler/model/llama_loader.py delete mode 100644 python/mlc_chat/compiler/model/llama_parameter.py rename tests/{python/support => legacy-python}/test_sliding_window_mask.py (99%) rename tests/python/{parameter => loader}/test_huggingface.py (88%) diff --git a/pyproject.toml b/pyproject.toml index ccf754554f..b1f082240c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,6 +16,8 @@ # under the License. [tool.isort] profile = "black" +src_paths = ["python/mlc_chat"] +known_third_party = ["numpy", "tvm", "tqdm", "torch", "transformers"] [tool.black] line-length = 100 diff --git a/python/mlc_chat/cli/compile.py b/python/mlc_chat/cli/compile.py index d4c648c097..2fcc9c0213 100644 --- a/python/mlc_chat/cli/compile.py +++ b/python/mlc_chat/cli/compile.py @@ -60,14 +60,16 @@ def _parse_output(path: Union[str, Path]) -> Path: default="auto", choices=["auto"] + list(MODELS.keys()), help="Model architecture, for example, llama. If not set, it is inferred " - "from the config.json file.", + "from the config.json file. " + "(default: %(default)s)", ) parser.add_argument( "--device", type=str, default="auto", help="The GPU device to compile the model to. If not set, it is inferred from locally " - "available GPUs.", + "available GPUs. " + "(default: %(default)s)", ) parser.add_argument( "--host", @@ -81,17 +83,19 @@ def _parse_output(path: Union[str, Path]) -> Path: "x86-64", ], help="The host CPU ISA to compile the model to. If not set, it is inferred from the " - "local CPU.", + "local CPU. " + "(default: %(default)s)", ) parser.add_argument( "--opt", type=OptimizationFlags.from_str, - default="", + default="O2", help="Optimization flags. 
MLC LLM maintains a predefined set of optimization flags, " "denoted as O0, O1, O2, O3, where O0 means no optimization, O2 means majority of them, " "and O3 represents extreme optimization that could potentially break the system. " "Meanwhile, optimization flags could be explicitly specified via details knobs, e.g. " - '--opt="cutlass_attn=1;cutlass_norm=0;cublas_gemm=0;cudagraph=0"', + '--opt="cutlass_attn=1;cutlass_norm=0;cublas_gemm=0;cudagraph=0. ' + "(default: %(default)s)", ) parser.add_argument( "--prefix-symbols", @@ -99,7 +103,8 @@ def _parse_output(path: Union[str, Path]) -> Path: default="", help='Adding a prefix to all symbols exported. Similar to "objcopy --prefix-symbols". ' "This is useful when compiling multiple models into a single library to avoid symbol " - "conflicts. Differet from objcopy, this takes no effect for shared library.", + "conflicts. Differet from objcopy, this takes no effect for shared library. " + '(default: "")', ) parser.add_argument( "--output", diff --git a/python/mlc_chat/cli/convert_weight.py b/python/mlc_chat/cli/convert_weight.py index 45fe0fa286..cf4c205009 100644 --- a/python/mlc_chat/cli/convert_weight.py +++ b/python/mlc_chat/cli/convert_weight.py @@ -4,14 +4,10 @@ from pathlib import Path from typing import Union -import tvm -from mlc_chat.compiler import MODELS, QUANTIZATION -from mlc_chat.compiler.parameter import HuggingFaceLoader -from mlc_chat.support import tqdm -from tvm.contrib import tvmjs +from mlc_chat.compiler import MODELS, QUANTIZATION, convert_weight from ..support.auto_config import detect_config, detect_model_type -from ..support.auto_target import detect_target_and_host +from ..support.auto_target import detect_device from ..support.auto_weight import detect_weight logging.basicConfig( @@ -57,17 +53,17 @@ def _parse_output(path: Union[str, Path]) -> Path: parser.add_argument( "--source", type=str, - required=False, default="auto", - help="The path to original model weight, infer from `config` if missing", + help="The path to original model weight, infer from `config` if missing. " + "(default: %(default)s)", ) parser.add_argument( "--source-format", type=str, - required=False, choices=["auto", "huggingface-torch", "huggingface-safetensor"], default="auto", - help="The format of source model weight, infer from `config` if missing", + help="The format of source model weight, infer from `config` if missing. " + "(default: %(default)s)", ) parser.add_argument( "--quantization", @@ -82,14 +78,16 @@ def _parse_output(path: Union[str, Path]) -> Path: default="auto", choices=["auto"] + list(MODELS.keys()), help="Model architecture, for example, llama. If not set, it is inferred " - "from the config.json file.", + "from the config.json file. " + "(default: %(default)s)", ) parser.add_argument( "--device", - type=str, default="auto", - help="The device used to do quantization, \ - for example `auto` / `cuda:0` / `cuda --arch sm86`", + type=detect_device, + help="The device used to do quantization, for example, / `cuda:0`. " + "Detect from local environment if not specified. 
" + "(default: %(default)s)", ) parser.add_argument( "--output", @@ -100,49 +98,21 @@ def _parse_output(path: Union[str, Path]) -> Path: "will contain `params_shard_*.bin` and `ndarray-cache.json`.", ) - # parse arguments parsed = parser.parse_args() - parsed.source = _parse_source(parsed.source, parsed.config) - parsed.params, parsed.source_format = detect_weight( - parsed.source, parsed.config, weight_format=parsed.source_format + parsed.source, parsed.source_format = detect_weight( + weight_path=_parse_source(parsed.source, parsed.config), + config_json_path=parsed.config, + weight_format=parsed.source_format, ) model = detect_model_type(parsed.model_type, parsed.config) - - # detect quantization target - quantization_target, _ = detect_target_and_host(parsed.device) - if parsed.device != "auto": - device = tvm.runtime.device(parsed.device.split(" ")[0]) - else: - if quantization_target.kind.name == "cuda": - device = tvm.cuda(0) - else: - device = tvm.cpu(0) - - # model config & quantization config - model_config = model.config.from_file(parsed.config) - quantization_config = QUANTIZATION[parsed.quantization] - _, quantize_map = model.quantize[quantization_config.kind](model_config, quantization_config) - - # loader setup - if parsed.source_format in ("huggingface-torch", "huggingface-safetensor"): - loader = HuggingFaceLoader( - path=parsed.params, - extern_param_map=model.source[parsed.source_format](model_config, None), - quantize_param_map=quantize_map, - ) - else: - raise ValueError(f"Unsupported loader source format: {parsed.source_format}") - - # load and quantize - with quantization_target, tqdm.redirect(): - param_dict = dict(loader.load(device=device)) - - # dump to output directory - tvmjs.dump_ndarray_cache( - param_dict, - f"{parsed.output}/params", - meta_data={"ParamSize": len(param_dict)}, - encode_format="raw", + convert_weight( + config=parsed.config, + quantization=QUANTIZATION[parsed.quantization], + model=model, + device=parsed.device, + source=parsed.source, + source_format=parsed.source_format, + output=parsed.output, ) diff --git a/python/mlc_chat/compiler/__init__.py b/python/mlc_chat/compiler/__init__.py index cf68426f8e..e65e12a5d9 100644 --- a/python/mlc_chat/compiler/__init__.py +++ b/python/mlc_chat/compiler/__init__.py @@ -4,7 +4,8 @@ """ from . 
import compiler_pass from .compile import CompileArgs, compile # pylint: disable=redefined-builtin +from .convert_weight import convert_weight from .flags_optimization import OptimizationFlags +from .loader import ExternMapping, HuggingFaceLoader, QuantizeMapping from .model import MODEL_PRESETS, MODELS, Model -from .parameter import ExternMapping, HuggingFaceLoader, QuantizeMapping from .quantization import QUANTIZATION diff --git a/python/mlc_chat/compiler/convert_weight.py b/python/mlc_chat/compiler/convert_weight.py new file mode 100644 index 0000000000..554c422a40 --- /dev/null +++ b/python/mlc_chat/compiler/convert_weight.py @@ -0,0 +1,124 @@ +"""Python entrypoint of weight conversion.""" +import dataclasses +import logging +import math +from io import StringIO +from pathlib import Path + +import numpy as np +from tvm.contrib import tvmjs +from tvm.runtime import Device, NDArray +from tvm.runtime import cpu as cpu_device +from tvm.target import Target + +from mlc_chat.support import tqdm + +from ..support.style import bold, green +from .loader import LOADER +from .model import Model +from .quantization import Quantization + +logger = logging.getLogger(__name__) + + +@dataclasses.dataclass +class ConversionArgs: # pylint: disable=too-many-instance-attributes + """Arguments to MLC LLM's weight conversion and quantization flow.""" + + config: Path + quantization: Quantization + model: Model + device: Device + source: Path + source_format: str + output: Path + + +def _echo_args(args: ConversionArgs) -> None: + def _device_to_str(device: Device) -> str: + return f"{Device.MASK2STR[device.device_type]}:{device.device_id}" + + out = StringIO() + print(f"{bold('Weight conversion with arguments:')}", file=out) + print(f" {bold('--config'):<25} {args.config}", file=out) + print(f" {bold('--quantization'):<25} {args.quantization}", file=out) + print(f" {bold('--model-type'):<25} {args.model.name}", file=out) + print(f" {bold('--device'):<25} {_device_to_str(args.device)}", file=out) + print(f" {bold('--source'):<25} {args.source}", file=out) + print(f" {bold('--source-format'):<25} {args.source_format}", file=out) + print(f" {bold('--output'):<25} {args.output}", file=out) + print(out.getvalue().rstrip()) + + +def _convert_args(args: ConversionArgs) -> None: # pylint: disable=too-many-locals + # model config & quantization config + model_config = args.model.config.from_file(args.config) + model, quantize_map = args.model.quantize[args.quantization.kind]( + model_config, args.quantization + ) + _, _named_params = model.export_tvm(spec=model.get_default_spec()) # type: ignore[attr-defined] + named_params = dict(_named_params) + + def _check_param(name: str, param: NDArray): + nonlocal named_params + if name not in named_params: + raise ValueError(f"Parameter not found in model: {name}") + if name in param_dict: + raise ValueError(f"Duplication: Parameter {name} already computed") + expect_shape = tuple(int(x) for x in named_params[name].shape) + expect_dtype = named_params[name].dtype + actual_shape = tuple(int(x) for x in param.shape) + actual_dtype = param.dtype + if actual_shape != expect_shape: + raise ValueError( + f"Parameter {name} has shape {param.shape}, but expected {expect_shape}" + ) + if actual_dtype != expect_dtype: + raise ValueError( + f"Parameter {name} has dtype {param.dtype}, but expected {expect_dtype}" + ) + del named_params[name] + + # load and quantize + param_dict = {} + total_bytes = 0.0 + total_params = 0 + with Target.from_device(args.device), tqdm.redirect(): + for 
name, param in LOADER[args.source_format]( + path=args.source, + extern_param_map=args.model.source[args.source_format](model_config, args.quantization), + quantize_param_map=quantize_map, + ).load(device=args.device): + _check_param(name, param) + param = param.copyto(cpu_device()) + param_dict[name] = param + total_bytes += math.prod(param.shape) * np.dtype(param.dtype).itemsize + total_params += math.prod(param.shape) + if named_params: + raise ValueError(f"Parameter not found in source: {', '.join(named_params.keys())}") + # dump to output directory + tvmjs.dump_ndarray_cache( + param_dict, + str(args.output), + meta_data={"ParamSize": len(param_dict)}, + encode_format="raw", + ) + logger.info("%s to %s", green("Saved"), bold(str(args.output))) + logger.info("%s: %.3f GB", green("Total parameter size"), total_bytes / (1024**3)) + logger.info("%s: %d", green("Total number of parameter tensors"), len(param_dict)) + logger.info(f"%s: {total_params:,}", green("Total number of parameters")) + + +def convert_weight( # pylint: disable=too-many-arguments + config: Path, + quantization: Quantization, + model: Model, + device: Device, + source: Path, + source_format: str, + output: Path, +): + """MLC LLM's weight conversion and quantization flow.""" + args = ConversionArgs(config, quantization, model, device, source, source_format, output) + _echo_args(args) + _convert_args(args) diff --git a/python/mlc_chat/compiler/parameter/__init__.py b/python/mlc_chat/compiler/loader/__init__.py similarity index 87% rename from python/mlc_chat/compiler/parameter/__init__.py rename to python/mlc_chat/compiler/loader/__init__.py index f119b01f91..cc8ba9c9ed 100644 --- a/python/mlc_chat/compiler/parameter/__init__.py +++ b/python/mlc_chat/compiler/loader/__init__.py @@ -3,4 +3,5 @@ parameters and parameters in MLC-defined models. """ from .huggingface_loader import HuggingFaceLoader +from .loader import LOADER, Loader from .mapping import ExternMapping, QuantizeMapping diff --git a/python/mlc_chat/compiler/parameter/huggingface_loader.py b/python/mlc_chat/compiler/loader/huggingface_loader.py similarity index 86% rename from python/mlc_chat/compiler/parameter/huggingface_loader.py rename to python/mlc_chat/compiler/loader/huggingface_loader.py index 550dec3071..a58220ab65 100644 --- a/python/mlc_chat/compiler/parameter/huggingface_loader.py +++ b/python/mlc_chat/compiler/loader/huggingface_loader.py @@ -12,14 +12,10 @@ from tvm.runtime import Device, NDArray from tvm.runtime.ndarray import array as as_ndarray +from ...support.style import bold from .mapping import ExternMapping, QuantizeMapping from .stats import Stats -from .utils import ( - ParamQuantizer, - check_parameter_usage, - load_safetensor_shard, - load_torch_shard, -) +from .utils import check_parameter_usage, load_safetensor_shard, load_torch_shard logger = logging.getLogger(__name__) @@ -100,7 +96,7 @@ def __init__( raise FileNotFoundError(f"Unknown file suffix: {path}") check_parameter_usage(extern_param_map, set(self.torch_to_path.keys())) - def load(self, device: Optional[Device] = None) -> Iterator[Tuple[str, NDArray]]: + def load(self, device: Device) -> Iterator[Tuple[str, NDArray]]: """Load the parameters and yield the MLC parameter and its value. 
Parameters @@ -116,30 +112,27 @@ def load(self, device: Optional[Device] = None) -> Iterator[Tuple[str, NDArray]] mlc_names = _loading_order(self.extern_param_map, self.torch_to_path) for mlc_name in tqdm(mlc_names): param = self._load_mlc_param(mlc_name, device=device) - - if self.quantize_param_map: + if self.quantize_param_map and mlc_name in self.quantize_param_map.param_map: with self.stats.timer("quant_time_sec"): - quantized_params = ParamQuantizer(self.quantize_param_map).quantize( - mlc_name, param - ) - if not quantized_params: + q_names = self.quantize_param_map.param_map[mlc_name] + q_params = self.quantize_param_map.map_func[mlc_name](param) + device.sync() + for q_name, q_param in zip(q_names, q_params): logger.info( - ' Skipped Quantizing Parameter: "%s", shape: %s, dtype: %s', - mlc_name, - param.shape, - param.dtype, + '[Quantized] Parameter: "%s", shape: %s, dtype: %s', + bold(q_name), + q_param.shape, + q_param.dtype, ) - yield mlc_name, param - else: - for quantized_name, quantized_param in quantized_params: - logger.info( - ' Quantized Parameter: "%s", shape: %s, dtype: %s', - quantized_name, - quantized_param.shape, - quantized_param.dtype, - ) - yield quantized_name, quantized_param + yield q_name, q_param else: + logger.info( + '[Not quantized] Parameter: "%s", shape: %s, dtype: %s', + bold(mlc_name), + param.shape, + param.dtype, + ) + device.sync() yield mlc_name, param cached_files = list(self.cached_files.keys()) for path in cached_files: @@ -168,7 +161,6 @@ def _load_mlc_param(self, mlc_name: str, device: Optional[Device]) -> NDArray: # Step 4. Apply the mapping function with self.stats.timer("map_time_sec"): param = self.extern_param_map.map_func[mlc_name](*torch_params) - logger.info(' Parameter: "%s", shape: %s, dtype: %s', mlc_name, param.shape, param.dtype) if device: return as_ndarray(param, device=device) return as_ndarray(param) diff --git a/python/mlc_chat/compiler/loader/loader.py b/python/mlc_chat/compiler/loader/loader.py new file mode 100644 index 0000000000..267ece72ab --- /dev/null +++ b/python/mlc_chat/compiler/loader/loader.py @@ -0,0 +1,11 @@ +"""A centralized registry of all existing loaders.""" +from typing import Any, Dict + +from .huggingface_loader import HuggingFaceLoader + +Loader = Any + +LOADER: Dict[str, Any] = { + "huggingface-torch": HuggingFaceLoader, + "huggingface-safetensor": HuggingFaceLoader, +} diff --git a/python/mlc_chat/compiler/parameter/mapping.py b/python/mlc_chat/compiler/loader/mapping.py similarity index 84% rename from python/mlc_chat/compiler/parameter/mapping.py rename to python/mlc_chat/compiler/loader/mapping.py index aab674cfa8..26d6811086 100644 --- a/python/mlc_chat/compiler/parameter/mapping.py +++ b/python/mlc_chat/compiler/loader/mapping.py @@ -40,10 +40,24 @@ class ExternMapping: Parameter names in the source weights that are not used in the MLC LLM model definition. 
""" - param_map: Dict[str, List[str]] - map_func: Dict[str, MapFuncVariadic] + param_map: Dict[str, List[str]] = dataclasses.field(default_factory=dict) + map_func: Dict[str, MapFuncVariadic] = dataclasses.field(default_factory=dict) unused_params: Set[str] = dataclasses.field(default_factory=set) + def add_mapping( + self, + map_from: str, + map_to: List[str], + func: MapFuncVariadic, + ) -> None: + """Add a mapping from MLC parameters to source parametes as well as a mapping function.""" + self.param_map[map_from] = map_to + self.map_func[map_from] = func + + def add_unused(self, name: str): + """Add a parameter name in the source parameters to the set of unused parameters.""" + self.unused_params.add(name) + @dataclasses.dataclass class QuantizeMapping: diff --git a/python/mlc_chat/compiler/parameter/stats.py b/python/mlc_chat/compiler/loader/stats.py similarity index 95% rename from python/mlc_chat/compiler/parameter/stats.py rename to python/mlc_chat/compiler/loader/stats.py index 9f5d1e16fa..d12cd2f257 100644 --- a/python/mlc_chat/compiler/parameter/stats.py +++ b/python/mlc_chat/compiler/loader/stats.py @@ -67,7 +67,7 @@ def mem_rm(self, nbytes: int): def log_time_info(self, weight_format: str): """Log the time used in loading, pre-quantization and quantization.""" logger.info( - "Time used: " + "Time usage: " "%s loading: %.3f sec; " "Pre-quantization mapping: %.3f sec; " "Quantization: %.3f sec", @@ -80,7 +80,7 @@ def log_time_info(self, weight_format: str): def log_mem_usage(self): """Log the Memory usage information.""" logger.info( - "Memory usage: Total size loaded from disk: %.3f GB; Peak memory usage: %.3f GB", + "RAM usage: Peak RAM: %.3f GB. Total bytes loaded from disk: %.3f GB", self.total_memory_gb, self.max_memory_gb, ) diff --git a/python/mlc_chat/compiler/parameter/utils.py b/python/mlc_chat/compiler/loader/utils.py similarity index 59% rename from python/mlc_chat/compiler/parameter/utils.py rename to python/mlc_chat/compiler/loader/utils.py index 3c6c0d1476..af08b1804d 100644 --- a/python/mlc_chat/compiler/parameter/utils.py +++ b/python/mlc_chat/compiler/loader/utils.py @@ -2,53 +2,20 @@ # pylint: disable=too-few-public-methods import logging from pathlib import Path -from typing import TYPE_CHECKING, Iterator, Optional, Set, Tuple +from typing import TYPE_CHECKING, Iterator, Set, Tuple import numpy as np -from .mapping import ExternMapping - if TYPE_CHECKING: from tvm.runtime import NDArray - from ..parameter import QuantizeMapping - -logger = logging.getLogger(__name__) - - -class ParamQuantizer: - """A parameter quantizer that quantizes given mlc-llm parameters""" + from .mapping import ExternMapping - quantize_map: "QuantizeMapping" - def __init__(self, quantize_map: "QuantizeMapping") -> None: - self.quantize_map = quantize_map - - def quantize(self, name: str, param: "NDArray") -> Optional[Iterator[Tuple[str, "NDArray"]]]: - """Apply quantization to the given parameters - - Parameters - ---------- - name : str - The name of the parameter - param : NDArray - The parameter to be quantized - - Returns - ------- - Optional[Iterator[Tuple[str, "NDArray"]]] - The quantized parameters, each with its name, returns None if the parameter is not - quantized. - """ - if name not in self.quantize_map.param_map: - return None - assert name in self.quantize_map.map_func, f"Quantization function for {name} not found." 
- quantized_names = self.quantize_map.param_map[name] - quantized_params = self.quantize_map.map_func[name](param) - return zip(quantized_names, quantized_params) +logger = logging.getLogger(__name__) -def check_parameter_usage(param_map: ExternMapping, extern_weights: Set[str]): +def check_parameter_usage(param_map: "ExternMapping", extern_weights: Set[str]): """Check that all external parameters have been used and are stored in the weights file.""" used_extern_names = set(sum(param_map.param_map.values(), [])) # Check 1. All extern parameters in the weight files are used unless explicitly specified diff --git a/python/mlc_chat/compiler/model/llama_loader.py b/python/mlc_chat/compiler/model/llama_loader.py new file mode 100644 index 0000000000..94a6d80600 --- /dev/null +++ b/python/mlc_chat/compiler/model/llama_loader.py @@ -0,0 +1,85 @@ +""" +This file specifies how MLC's Llama parameter maps from other formats, for example HuggingFace +PyTorch, HuggingFace safetensors. +""" +import functools + +import numpy as np + +from ..loader import ExternMapping +from ..quantization import Quantization +from .llama_config import LlamaConfig +from .llama_model import LlamaForCasualLM + + +def huggingface(model_config: LlamaConfig, quantization: Quantization) -> ExternMapping: + """Returns a parameter mapping that maps from the names of MLC LLM parameters to + the names of HuggingFace PyTorch parameters. + + Parameters + ---------- + model_config : LlamaConfig + The configuration of the Llama model. + + quantization : Quantization + The quantization configuration. + + Returns + ------- + param_map : ExternMapping + The parameter mapping from MLC to HuggingFace PyTorch. + """ + model = LlamaForCasualLM(model_config) + if quantization is not None: + model.to(quantization.model_dtype) + _, _named_params = model.export_tvm(spec=model.get_default_spec()) + named_parameters = dict(_named_params) + + mapping = ExternMapping() + + for i in range(model_config.num_hidden_layers): + # Add QKV in self attention + attn = f"model.layers.{i}.self_attn" + mlc_name = f"{attn}.qkv_proj.weight" + mlc_param = named_parameters[mlc_name] + mapping.add_mapping( + mlc_name, + [ + f"{attn}.q_proj.weight", + f"{attn}.k_proj.weight", + f"{attn}.v_proj.weight", + ], + functools.partial( + lambda q, k, v, dtype: np.concatenate([q, k, v], axis=0).astype(dtype), + dtype=mlc_param.dtype, + ), + ) + # Add gates in MLP + mlp = f"model.layers.{i}.mlp" + mlc_name = f"{mlp}.gate_up_proj.weight" + mlc_param = named_parameters[mlc_name] + mapping.add_mapping( + mlc_name, + [ + f"{mlp}.gate_proj.weight", + f"{mlp}.up_proj.weight", + ], + functools.partial( + lambda gate, up, dtype: np.concatenate([gate, up], axis=0).astype(dtype), + dtype=mlc_param.dtype, + ), + ) + # inv_freq is not used in the model + mapping.add_unused(f"{attn}.rotary_emb.inv_freq") + + for mlc_name, mlc_param in named_parameters.items(): + if mlc_name not in mapping.param_map: + mapping.add_mapping( + mlc_name, + [mlc_name], + functools.partial( + lambda x, dtype: x.astype(dtype), + dtype=mlc_param.dtype, + ), + ) + return mapping diff --git a/python/mlc_chat/compiler/model/llama_parameter.py b/python/mlc_chat/compiler/model/llama_parameter.py deleted file mode 100644 index 4c68fdc899..0000000000 --- a/python/mlc_chat/compiler/model/llama_parameter.py +++ /dev/null @@ -1,61 +0,0 @@ -""" -This file specifies how MLC's Llama parameter maps from other formats, for example HuggingFace -PyTorch, HuggingFace safetensors. 
-""" -from typing import Callable, Dict, List - -import numpy as np - -from ..parameter import ExternMapping -from .llama_config import LlamaConfig -from .llama_model import LlamaForCasualLM - - -def huggingface(model_config: LlamaConfig, _) -> ExternMapping: - """Returns a parameter mapping that maps from the names of MLC LLM parameters to - the names of HuggingFace PyTorch parameters. - - Parameters - ---------- - model_config : LlamaConfig - The configuration of the Llama model. - - Returns - ------- - param_map : ExternMapping - The parameter mapping from MLC to HuggingFace PyTorch. - """ - model = LlamaForCasualLM(model_config) - _, named_params = model.export_tvm(spec=model.get_default_spec()) - parameter_names = {name for name, _ in named_params} - - param_map: Dict[str, List[str]] = {} - map_func: Dict[str, Callable] = {} - unused_params = set() - - for i in range(model_config.num_hidden_layers): - # Add QKV in self attention - attn = f"model.layers.{i}.self_attn" - assert f"{attn}.qkv_proj.weight" in parameter_names - map_func[f"{attn}.qkv_proj.weight"] = lambda q, k, v: np.concatenate([q, k, v], axis=0) - param_map[f"{attn}.qkv_proj.weight"] = [ - f"{attn}.q_proj.weight", - f"{attn}.k_proj.weight", - f"{attn}.v_proj.weight", - ] - # Add gates in MLP - mlp = f"model.layers.{i}.mlp" - assert f"{mlp}.gate_up_proj.weight" in parameter_names - map_func[f"{mlp}.gate_up_proj.weight"] = lambda gate, up: np.concatenate([gate, up], axis=0) - param_map[f"{mlp}.gate_up_proj.weight"] = [ - f"{mlp}.gate_proj.weight", - f"{mlp}.up_proj.weight", - ] - # inv_freq is not used in the model - unused_params.add(f"{attn}.rotary_emb.inv_freq") - - for name in parameter_names: - if name not in map_func: - map_func[name] = lambda x: x - param_map[name] = [name] - return ExternMapping(param_map, map_func, unused_params) diff --git a/python/mlc_chat/compiler/model/llama_quantization.py b/python/mlc_chat/compiler/model/llama_quantization.py index 02376ab9db..ffb2ec71e3 100644 --- a/python/mlc_chat/compiler/model/llama_quantization.py +++ b/python/mlc_chat/compiler/model/llama_quantization.py @@ -1,9 +1,10 @@ -"""Quantization specs for Llama.""" +"""This file specifies how MLC's Llama parameters are quantized using group quantization +or other formats.""" from typing import Tuple from tvm.relax.frontend import nn -from ..parameter import QuantizeMapping +from ..loader import QuantizeMapping from ..quantization import GroupQuantize from .llama_config import LlamaConfig from .llama_model import LlamaForCasualLM diff --git a/python/mlc_chat/compiler/model/model.py b/python/mlc_chat/compiler/model/model.py index 74159cc188..4e440156c9 100644 --- a/python/mlc_chat/compiler/model/model.py +++ b/python/mlc_chat/compiler/model/model.py @@ -4,9 +4,9 @@ from tvm.relax.frontend import nn -from ..parameter import ExternMapping, QuantizeMapping +from ..loader import ExternMapping, QuantizeMapping from ..quantization.quantization import Quantization -from . import llama_config, llama_model, llama_parameter, llama_quantization +from . import llama_config, llama_loader, llama_model, llama_quantization ModelConfig = Any """A ModelConfig is an object that represents a model architecture. 
It is required to have @@ -56,8 +56,8 @@ class Model: model=llama_model.LlamaForCasualLM, config=llama_config.LlamaConfig, source={ - "huggingface-torch": llama_parameter.huggingface, - "huggingface-safetensor": llama_parameter.huggingface, + "huggingface-torch": llama_loader.huggingface, + "huggingface-safetensor": llama_loader.huggingface, }, quantize={ "group-quant": llama_quantization.group_quant, diff --git a/python/mlc_chat/compiler/quantization/group_quantization.py b/python/mlc_chat/compiler/quantization/group_quantization.py index 6e28b72a97..ea27410bea 100644 --- a/python/mlc_chat/compiler/quantization/group_quantization.py +++ b/python/mlc_chat/compiler/quantization/group_quantization.py @@ -1,16 +1,19 @@ """The group quantization config""" -from dataclasses import dataclass, field -from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple +import logging +from dataclasses import dataclass +from typing import Any, Callable, List, Optional, Sequence, Tuple import numpy as np -from tvm import DataType, DataTypeCode +from tvm import DataType, DataTypeCode, IRModule from tvm import dlight as dl from tvm import relax, te, tir from tvm.relax.frontend import nn from tvm.runtime import NDArray from tvm.target import Target -from ..parameter import QuantizeMapping +from ..loader import QuantizeMapping + +logger = logging.getLogger(__name__) @dataclass @@ -28,10 +31,6 @@ class GroupQuantize: # pylint: disable=too-many-instance-attributes num_storage_per_group: int = 0 max_int_value: int = 0 - prebuilt_quantize_func: Dict[str, Callable[[NDArray], NDArray]] = field( - default_factory=lambda: {} - ) - def __post_init__(self): assert self.kind == "group-quant" quantize_dtype = DataType(self.quantize_dtype) @@ -48,6 +47,7 @@ def __post_init__(self): raise ValueError("Group size should be divisible by numbers of elements per storage") self.num_storage_per_group = self.group_size // self.num_elem_per_storage self.max_int_value = (2 ** (quantize_dtype.bits - 1)) - 1 + self._quantize_func_cache = {} def quantize_model( self, @@ -166,46 +166,55 @@ def quantize_weight(self, weight: NDArray) -> List[NDArray]: ret: List[NDArray] The list of group quantized weights. 
""" - assert weight.dtype == self.model_dtype assert len(weight.shape) == 2 - dev = weight.device - device_type = dev.MASK2STR[dev.device_type] + device = weight.device + device_type = device.MASK2STR[device.device_type] + + def _create_quantize_func() -> IRModule: + bb = relax.BlockBuilder() # pylint: disable=invalid-name + weight_var = relax.Var("weight", relax.TensorStructInfo(weight.shape, weight.dtype)) + with bb.function(name="main", params=[weight_var]): + with bb.dataflow(): + lv = bb.emit_te(self._quantize, weight_var) # pylint: disable=invalid-name + gv = bb.emit_output(lv) # pylint: disable=invalid-name + bb.emit_func_output(gv) + return bb.get() + + def _compile_quantize_func(mod: IRModule) -> Callable: + if device_type in ["cuda", "rocm", "metal", "vulkan"]: + target = Target.current() + if target is None: + target = Target.from_device(device) + with target: + mod = dl.ApplyDefaultSchedule( # type: ignore # pylint: disable=not-callable + dl.gpu.Reduction(), + dl.gpu.GeneralReduction(), + dl.gpu.Fallback(), + )(mod) + elif device_type == "cpu": + target = "llvm" + mod = relax.transform.LegalizeOps()(mod) + else: + raise NotImplementedError(f"Device type {device_type} is not supported") + ex = relax.build(mod, target=target) + vm = relax.VirtualMachine(ex, device) # pylint: disable=invalid-name + return vm["main"] + key = str((int(weight.shape[0]), int(weight.shape[1]), weight.dtype, device_type)) - if key in self.prebuilt_quantize_func: - return self.prebuilt_quantize_func[key](weight) - bb = relax.BlockBuilder() # pylint: disable=invalid-name - weight_var = relax.Var("weight", relax.TensorStructInfo(weight.shape, self.model_dtype)) - with bb.function(name="quantize", params=[weight_var]): - with bb.dataflow(): - lv = bb.emit_te(self._quantize, weight_var) # pylint: disable=invalid-name - gv = bb.emit_output(lv) # pylint: disable=invalid-name - bb.emit_func_output(gv) - mod = bb.get() - if device_type in ["cuda", "rocm", "metal", "vulkan"]: - target = Target.current() - if target is None: - target = Target.from_device(dev) - with target: - mod = dl.ApplyDefaultSchedule( # type: ignore # pylint: disable=not-callable - dl.gpu.Reduction(), - dl.gpu.GeneralReduction(), - dl.gpu.Fallback(), - )(mod) - elif device_type == "cpu": - target = "llvm" - mod = relax.transform.LegalizeOps()(mod) - else: - raise NotImplementedError(f"Device type {device_type} is not supported") - ex = relax.build(mod, target) - vm = relax.VirtualMachine(ex, dev) # pylint: disable=invalid-name - self.prebuilt_quantize_func[key] = vm["quantize"] - return vm["quantize"](weight) + quantize_func = self._quantize_func_cache.get(key, None) + if quantize_func is None: + logger.info("Compiling quantize function for key: %s", key) + quantize_func = _compile_quantize_func(_create_quantize_func()) + self._quantize_func_cache[key] = quantize_func + return quantize_func(weight) def _quantize( # pylint: disable=too-many-locals - self, weight: te.Tensor + self, + weight: te.Tensor, ) -> Tuple[te.Tensor, te.Tensor]: """Group quantization for weight tensor, defined in tensor expression.""" assert len(weight.shape) == 2 + max_int = tir.const(self.max_int_value, self.model_dtype) n, k = weight.shape # pylint: disable=invalid-name quantize_dtype = DataType(self.quantize_dtype) # compute scale per group @@ -223,25 +232,20 @@ def _quantize( # pylint: disable=too-many-locals ) scale = te.compute( scale_shape, - lambda i, j: max_abs[i, j] / tir.const(self.max_int_value, self.model_dtype), + lambda i, j: max_abs[i, 
j].astype(self.model_dtype) / max_int, name="scale", ) - # compute scaled weight - tir_max_int = tir.const(self.max_int_value, self.model_dtype) - tir_zero = tir.const(0, self.model_dtype) - tir_max_int_2 = tir.const(self.max_int_value * 2, self.model_dtype) scaled_weight = te.compute( shape=weight.shape, fcompute=lambda i, j: tir.min( tir.max( - tir.round(weight[i, j] / scale[i, j // self.group_size] + tir_max_int), - tir_zero, + tir.round(weight[i, j] / scale[i, j // self.group_size] + max_int), + tir.const(0, self.model_dtype), ), - tir_max_int_2, + max_int * 2, ).astype(self.storage_dtype), ) - # compute quantized weight per storage r = te.reduce_axis((0, self.num_elem_per_storage), name="r") # pylint: disable=invalid-name num_storage = self.num_storage_per_group * num_group @@ -274,11 +278,11 @@ def __init__( # pylint: disable=too-many-arguments self.out_features = out_features self.out_dtype = out_dtype self.config = config - self.weight = nn.Parameter( + self.q_weight = nn.Parameter( (out_features, tir.ceildiv(in_features, config.num_elem_per_storage)), config.storage_dtype, ) - self.scale = nn.Parameter( + self.q_scale = nn.Parameter( (out_features, tir.ceildiv(in_features, config.group_size)), config.model_dtype ) if bias: @@ -333,7 +337,7 @@ def forward(self, x: nn.Tensor) -> nn.Tensor: # pylint: disable=invalid-name [tir.IntImm("int64", self.out_features), tir.IntImm("int64", self.in_features)], ), name_hint="decode", - args=[self.weight, self.scale], + args=[self.q_weight, self.q_scale], ) w = nn.op.permute_dims(w) # pylint: disable=invalid-name x = nn.op.matmul(x, w, out_dtype=self.out_dtype) @@ -361,11 +365,11 @@ def __init__( # pylint: disable=too-many-arguments self.out_features = out_features self.out_dtype = out_dtype self.config = config - self.weight = nn.Parameter( + self.q_weight = nn.Parameter( (self.total_out_features, tir.ceildiv(in_features, config.num_elem_per_storage)), config.storage_dtype, ) - self.scale = nn.Parameter( + self.q_scale = nn.Parameter( (self.total_out_features, tir.ceildiv(in_features, config.group_size)), config.model_dtype, ) @@ -427,7 +431,7 @@ def forward(self, x: nn.Tensor) -> Sequence[nn.Tensor]: # pylint: disable=inval ], ), name_hint="decode", - args=[self.weight, self.scale], + args=[self.q_weight, self.q_scale], ) # x: [*B, in_features] # w: [in_features, out_features] @@ -447,11 +451,14 @@ def __init__(self, num: int, dim: int, config: GroupQuantize): self.num = num self.dim = dim self.config = config - n_group = tir.ceildiv(dim, config.group_size) - self.weight = nn.Parameter( - (num, n_group * config.num_elem_per_storage), config.storage_dtype + self.q_weight = nn.Parameter( + (num, tir.ceildiv(dim, config.num_elem_per_storage)), + config.storage_dtype, + ) + self.q_scale = nn.Parameter( + (num, tir.ceildiv(dim, config.group_size)), + config.model_dtype, ) - self.scale = nn.Parameter((num, n_group), config.model_dtype) @staticmethod def from_embedding(embedding: nn.Embedding, config: GroupQuantize) -> "GroupQuantizeEmbedding": @@ -495,7 +502,7 @@ def forward(self, x: nn.Tensor): # pylint: disable=invalid-name [tir.IntImm("int64", self.num), tir.IntImm("int64", self.dim)], ), name_hint="decode", - args=[self.weight, self.scale], + args=[self.q_weight, self.q_scale], ) if x.ndim == 1: return nn.op.take(w, x, axis=0) diff --git a/python/mlc_chat/rest.py b/python/mlc_chat/rest.py index e92c2824d3..8611db017a 100644 --- a/python/mlc_chat/rest.py +++ b/python/mlc_chat/rest.py @@ -10,6 +10,7 @@ from fastapi import FastAPI from 
fastapi.middleware.cors import CORSMiddleware from fastapi.responses import StreamingResponse + from mlc_chat.chat_module import GenerationConfig from .base import set_global_random_seed diff --git a/python/mlc_chat/support/auto_config.py b/python/mlc_chat/support/auto_config.py index 0546e49252..708b675513 100644 --- a/python/mlc_chat/support/auto_config.py +++ b/python/mlc_chat/support/auto_config.py @@ -5,7 +5,7 @@ from pathlib import Path from typing import TYPE_CHECKING, Union -from .style import green +from .style import bold, green if TYPE_CHECKING: from mlc_chat.compiler import Model # pylint: disable=unused-import @@ -92,10 +92,10 @@ def detect_model_type(model_type: str, config: Path) -> "Model": if "model_type" not in cfg: raise ValueError( f"'model_type' not found in: {config}. " - f"Please explicitly specify `--model-type` instead" + f"Please explicitly specify `--model-type` instead." ) model_type = cfg["model_type"] - logger.info("%s model type: %s", FOUND, model_type) + logger.info("%s model type: %s. Use `--model-type` to override.", FOUND, bold(model_type)) if model_type not in MODELS: raise ValueError(f"Unknown model type: {model_type}. Available ones: {list(MODELS.keys())}") return MODELS[model_type] diff --git a/python/mlc_chat/support/auto_target.py b/python/mlc_chat/support/auto_target.py index 491402b008..6bfd51c06d 100644 --- a/python/mlc_chat/support/auto_target.py +++ b/python/mlc_chat/support/auto_target.py @@ -3,9 +3,11 @@ import os from typing import TYPE_CHECKING, Callable, Optional, Tuple +import tvm from tvm import IRModule, relax from tvm._ffi import register_func from tvm.contrib import tar, xcode +from tvm.runtime import Device from tvm.target import Target from .style import bold, green, red @@ -44,6 +46,33 @@ def detect_target_and_host(target_hint: str, host_hint: str = "auto") -> Tuple[T return target, build_func +def detect_device(device_hint: str) -> Device: + """Detect locally available device from string hint.""" + if device_hint == "auto": + device = None + for device_type in AUTO_DETECT_DEVICES: + cur_device = tvm.device(dev_type=device_type, dev_id=0) + if cur_device.exist: + logger.info("%s device: %s:0", FOUND, device_type) + if device is None: + device = cur_device + else: + logger.info("%s device: %s:0", NOT_FOUND, device_type) + if device is None: + logger.info("%s: No available device detected. Falling back to CPU", NOT_FOUND) + return tvm.device("cpu:0") + device_str = f"{tvm.runtime.Device.MASK2STR[device.device_type]}:{device.device_id}" + logger.info("Using device: %s. 
Use `--device` to override.", bold(device_str)) + return device + try: + device = tvm.device(device_hint) + except Exception as err: + raise ValueError(f"Invalid device name: {device_hint}") from err + if not device.exist: + raise ValueError(f"Device is not found on your local environment: {device_hint}") + return device + + def _detect_target_gpu(hint: str) -> Tuple[Target, BuildFunc]: if hint in ["iphone", "android", "webgpu", "mali", "opencl"]: hint += ":generic" diff --git a/python/mlc_chat/support/auto_weight.py b/python/mlc_chat/support/auto_weight.py index 96ca55bfcb..959e795169 100644 --- a/python/mlc_chat/support/auto_weight.py +++ b/python/mlc_chat/support/auto_weight.py @@ -4,7 +4,7 @@ from pathlib import Path from typing import List, Optional, Tuple -from .style import green, red +from .style import bold, green, red logger = logging.getLogger(__name__) @@ -67,7 +67,7 @@ def detect_weight( if not weight_path.exists(): raise ValueError(f"weight_path doesn't exist: {weight_path}") - logger.info("%s weights from directory: %s", FOUND, weight_path) + logger.info("Finding weights in: %s", weight_path) # check weight format # weight_format = "auto", guess the weight format. @@ -96,13 +96,18 @@ def _guess_weight_format(weight_path: Path) -> Tuple[Path, str]: if len(possible_formats) == 0: raise ValueError( - "Fail to detect weight format. Use `--weight-format` to manually specify the format." + "Fail to detect source weight format. " + "Use `--source-format` to explicitly specify the format." ) weight_config_path, selected_format = possible_formats[0] logger.info( - "Using %s format now. Use `--weight-format` to manually specify the format.", - selected_format, + "Using source weight configuration: %s. Use `--source` to override.", + bold(str(weight_config_path)), + ) + logger.info( + "Using source weight format: %s. Use `--source-format` to override.", + bold(selected_format), ) return weight_config_path, selected_format @@ -110,7 +115,11 @@ def _guess_weight_format(weight_path: Path) -> Tuple[Path, str]: def _check_pytorch(weight_path: Path) -> Optional[Path]: pytorch_json_path = weight_path / "pytorch_model.bin.index.json" if pytorch_json_path.exists(): - logger.info("%s Huggingface PyTorch: %s", FOUND, pytorch_json_path) + logger.info( + "%s source weight format: huggingface-torch. Source configuration: %s", + FOUND, + pytorch_json_path, + ) return pytorch_json_path logger.info("%s Huggingface PyTorch", NOT_FOUND) return None @@ -119,7 +128,11 @@ def _check_pytorch(weight_path: Path) -> Optional[Path]: def _check_safetensor(weight_path: Path) -> Optional[Path]: safetensor_json_path = weight_path / "model.safetensors.index.json" if safetensor_json_path.exists(): - logger.info("%s Huggingface Safetensor: %s", FOUND, safetensor_json_path) + logger.info( + "%s source weight format: huggingface-safetensor. 
Source configuration: %s", + FOUND, + safetensor_json_path, + ) return safetensor_json_path logger.info("%s Huggingface Safetensor", NOT_FOUND) return None diff --git a/tests/python/support/test_sliding_window_mask.py b/tests/legacy-python/test_sliding_window_mask.py similarity index 99% rename from tests/python/support/test_sliding_window_mask.py rename to tests/legacy-python/test_sliding_window_mask.py index fa727ced9a..51be2d0749 100644 --- a/tests/python/support/test_sliding_window_mask.py +++ b/tests/legacy-python/test_sliding_window_mask.py @@ -5,11 +5,10 @@ import numpy as np import tvm +from mlc_llm.relax_model.mistral import _make_sliding_window_mask from tvm import relax from tvm.runtime import ShapeTuple -from mlc_llm.relax_model.mistral import _make_sliding_window_mask - def _create_vm(): # pylint: disable=too-many-locals diff --git a/tests/python/parameter/test_huggingface.py b/tests/python/loader/test_huggingface.py similarity index 88% rename from tests/python/parameter/test_huggingface.py rename to tests/python/loader/test_huggingface.py index ecd8e16455..8424c8a34e 100644 --- a/tests/python/parameter/test_huggingface.py +++ b/tests/python/loader/test_huggingface.py @@ -4,11 +4,10 @@ from typing import Union import pytest -from mlc_chat.compiler import MODELS +import tvm -# from mlc_chat.compiler.model.llama_config import LlamaConfig -# from mlc_chat.compiler.model.llama_parameter import huggingface -from mlc_chat.compiler.parameter import HuggingFaceLoader +from mlc_chat.compiler import MODELS +from mlc_chat.compiler.loader import HuggingFaceLoader from mlc_chat.support import tqdm logging.basicConfig( @@ -39,7 +38,7 @@ def test_load_torch_llama(base_path: Union[str, Path]): extern_param_map=model.source["huggingface-torch"](config, None), ) with tqdm.redirect(): - for _name, _param in loader.load(): + for _name, _param in loader.load(device=tvm.device("cpu")): return # To reduce the time of the test @@ -63,7 +62,7 @@ def test_load_safetensor_llama(base_path: Union[str, Path]): extern_param_map=model.source["huggingface-safetensor"](config, None), ) with tqdm.redirect(): - for _name, _param in loader.load(): + for _name, _param in loader.load(device=tvm.device("cpu")): return # To reduce the time of the test diff --git a/tests/python/model/test_llama.py b/tests/python/model/test_llama.py index 9e75247c32..0cd22f1572 100644 --- a/tests/python/model/test_llama.py +++ b/tests/python/model/test_llama.py @@ -1,5 +1,6 @@ # pylint: disable=invalid-name,missing-docstring import pytest + from mlc_chat.compiler import MODELS diff --git a/tests/python/model/test_llama_quantization.py b/tests/python/model/test_llama_quantization.py index 92e2b4c1d6..9de7d8fa51 100644 --- a/tests/python/model/test_llama_quantization.py +++ b/tests/python/model/test_llama_quantization.py @@ -1,4 +1,6 @@ # pylint: disable=invalid-name,missing-docstring +import pytest + from mlc_chat.compiler import MODELS, QUANTIZATION from mlc_chat.compiler.quantization.group_quantization import ( GroupQuantizeEmbedding, @@ -7,22 +9,30 @@ ) +@pytest.mark.parametrize( + "model_name, quant_name", + [ + ("llama2_7b", "q4f16_1"), + ("llama2_13b", "q4f16_1"), + ("llama2_70b", "q4f16_1"), + ], +) def test_llama2_group_quantization(model_name: str, quant_name: str): model_info = MODELS["llama"] config = model_info.config.from_predefined(model_name) model, quant_map = model_info.quantize["group-quant"](config, QUANTIZATION[quant_name]) - assert "model.model.embed_tokens.weight" in quant_map.param_map + assert 
"model.embed_tokens.weight" in quant_map.param_map assert isinstance(model.model.embed_tokens, GroupQuantizeEmbedding) - assert "model.lm_head.weight" in quant_map.param_map + assert "lm_head.weight" in quant_map.param_map assert isinstance(model.lm_head, GroupQuantizeLinear) for i in range(config.num_hidden_layers): - assert f"model.model.layers.{i}.self_attn.qkv_proj.weight" in quant_map.param_map + assert f"model.layers.{i}.self_attn.qkv_proj.weight" in quant_map.param_map assert isinstance(model.model.layers[i].self_attn.qkv_proj, GroupQuantizeMultiLinear) - assert f"model.model.layers.{i}.self_attn.o_proj.weight" in quant_map.param_map + assert f"model.layers.{i}.self_attn.o_proj.weight" in quant_map.param_map assert isinstance(model.model.layers[i].self_attn.o_proj, GroupQuantizeLinear) - assert f"model.model.layers.{i}.mlp.gate_up_proj.weight" in quant_map.param_map + assert f"model.layers.{i}.mlp.gate_up_proj.weight" in quant_map.param_map assert isinstance(model.model.layers[i].mlp.gate_up_proj, GroupQuantizeMultiLinear) - assert f"model.model.layers.{i}.mlp.down_proj.weight" in quant_map.param_map + assert f"model.layers.{i}.mlp.down_proj.weight" in quant_map.param_map assert isinstance(model.model.layers[i].mlp.down_proj, GroupQuantizeLinear) diff --git a/tests/python/quantization/test_group_quantization.py b/tests/python/quantization/test_group_quantization.py index 663f7b8e78..106d0f5fb5 100644 --- a/tests/python/quantization/test_group_quantization.py +++ b/tests/python/quantization/test_group_quantization.py @@ -2,18 +2,19 @@ from typing import List import numpy as np +import pytest import torch import tvm import tvm.testing -from mlc_chat.compiler import QUANTIZATION -from mlc_chat.compiler.parameter import QuantizeMapping -from mlc_chat.compiler.quantization import GroupQuantize +from tvm import DataType +from tvm.relax.frontend import nn + +from mlc_chat.compiler import QUANTIZATION, QuantizeMapping from mlc_chat.compiler.quantization.group_quantization import ( + GroupQuantize, GroupQuantizeEmbedding, GroupQuantizeLinear, ) -from tvm import DataType -from tvm.relax.frontend import nn def quantize_np(config: GroupQuantize, weight: np.ndarray): @@ -72,6 +73,12 @@ def dequantize_np( return ((weight_bin - max_int) * scale_repeated)[: out_shape[0]][: out_shape[1]] +@pytest.mark.parametrize( + "quant_name, shape, dtype, device", + [ + ("q4f16_1", [16, 128], "float16", "cpu"), + ], +) def test_quantize_weight(quant_name: str, shape: List[int], dtype: str, device: str): config = QUANTIZATION[quant_name] assert isinstance(config, GroupQuantize) @@ -88,6 +95,12 @@ def test_quantize_weight(quant_name: str, shape: List[int], dtype: str, device: ) +@pytest.mark.parametrize( + "quant_name, shape, dtype", + [ + ("q4f16_1", [16, 128], "float16"), + ], +) def test_dequantize_weight(quant_name: str, shape: List[int], dtype: str): class Test(nn.Module): def __init__(self) -> None: @@ -108,8 +121,8 @@ def forward(self, x: nn.Tensor): config.model_dtype ) mod = config.quantize_model(Test(), QuantizeMapping({}, {}), "") - mod.linear.weight.data = weight_np - mod.linear.scale.data = scale_np + mod.linear.q_weight.data = weight_np + mod.linear.q_scale.data = scale_np model = mod.jit(spec={"forward": {"x": nn.spec.Tensor((shape[1], shape[1]), dtype)}}) out = model["forward"]( torch.from_numpy(np.diag(np.ones(shape[1]).astype(dtype))) # pylint: disable=no-member @@ -118,6 +131,12 @@ def forward(self, x: nn.Tensor): tvm.testing.assert_allclose(out, ref, rtol=1e-3, atol=1e-3) 
+@pytest.mark.parametrize( + "quant_name, shape, dtype", + [ + ("q4f16_1", [16, 128], "float16"), + ], +) def test_quantize_model(quant_name: str, shape: List[int], dtype: str): class Test(nn.Module): def __init__(self) -> None: diff --git a/tests/python/support/test_auto_config.py b/tests/python/support/test_auto_config.py index 540c544c22..ff3dcc4e7b 100644 --- a/tests/python/support/test_auto_config.py +++ b/tests/python/support/test_auto_config.py @@ -5,6 +5,7 @@ from pathlib import Path import pytest + from mlc_chat.support.auto_config import detect_config logging.basicConfig( diff --git a/tests/python/support/test_auto_weight.py b/tests/python/support/test_auto_weight.py index 5776791df1..f4ec20a8b7 100644 --- a/tests/python/support/test_auto_weight.py +++ b/tests/python/support/test_auto_weight.py @@ -6,6 +6,7 @@ from pathlib import Path import pytest + from mlc_chat.support.auto_weight import detect_weight logging.basicConfig( From 9d20575aa6638e168c92da7971ef1583a077516c Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Sat, 4 Nov 2023 01:30:15 -0700 Subject: [PATCH 087/116] Merge `llama_config.CONFIG` into `MODEL_PRESETS` (#1188) --- .../mlc_chat/compiler/model/llama_config.py | 77 ------------------- python/mlc_chat/compiler/model/model.py | 77 ++++++++++++++++++- tests/python/model/test_llama.py | 10 +-- tests/python/model/test_llama_quantization.py | 34 +++++--- 4 files changed, 105 insertions(+), 93 deletions(-) diff --git a/python/mlc_chat/compiler/model/llama_config.py b/python/mlc_chat/compiler/model/llama_config.py index 113acd456f..044161023a 100644 --- a/python/mlc_chat/compiler/model/llama_config.py +++ b/python/mlc_chat/compiler/model/llama_config.py @@ -29,80 +29,3 @@ def __post_init__(self): self.head_dim = self.hidden_size // self.num_attention_heads assert self.num_attention_heads % self.num_key_value_heads == 0 assert self.head_dim * self.num_attention_heads == self.hidden_size - - @staticmethod - def from_predefined(name: str) -> "LlamaConfig": - """Create a LlamaConfig from a predefined configuration.""" - return LlamaConfig.from_dict(CONFIG[name]) - - -CONFIG = { - "llama2_7b": { - "architectures": ["LlamaForCausalLM"], - "bos_token_id": 1, - "eos_token_id": 2, - "hidden_act": "silu", - "hidden_size": 4096, - "initializer_range": 0.02, - "intermediate_size": 11008, - "max_position_embeddings": 2048, - "model_type": "llama", - "num_attention_heads": 32, - "num_hidden_layers": 32, - "num_key_value_heads": 32, - "pad_token_id": 0, - "pretraining_tp": 1, - "rms_norm_eps": 1e-05, - "rope_scaling": None, - "tie_word_embeddings": False, - "torch_dtype": "float16", - "transformers_version": "4.31.0.dev0", - "use_cache": True, - "vocab_size": 32000, - }, - "llama2_13b": { - "_name_or_path": "meta-llama/Llama-2-13b-hf", - "architectures": ["LlamaForCausalLM"], - "bos_token_id": 1, - "eos_token_id": 2, - "hidden_act": "silu", - "hidden_size": 5120, - "initializer_range": 0.02, - "intermediate_size": 13824, - "max_position_embeddings": 2048, - "model_type": "llama", - "num_attention_heads": 40, - "num_hidden_layers": 40, - "num_key_value_heads": 40, - "pad_token_id": 0, - "pretraining_tp": 2, - "rms_norm_eps": 1e-05, - "rope_scaling": None, - "tie_word_embeddings": False, - "torch_dtype": "float16", - "transformers_version": "4.31.0.dev0", - "use_cache": True, - "vocab_size": 32000, - }, - "llama2_70b": { - "architectures": ["LlamaForCausalLM"], - "bos_token_id": 1, - "eos_token_id": 2, - "hidden_act": "silu", - "hidden_size": 8192, - "initializer_range": 0.02, - 
"intermediate_size": 28672, - "max_position_embeddings": 2048, - "model_type": "llama", - "num_attention_heads": 64, - "num_hidden_layers": 80, - "num_key_value_heads": 8, - "pad_token_id": 0, - "rms_norm_eps": 1e-05, - "tie_word_embeddings": False, - "torch_dtype": "float16", - "transformers_version": "4.31.0.dev0", - "use_cache": True, - "vocab_size": 32000, - }, -} diff --git a/python/mlc_chat/compiler/model/model.py b/python/mlc_chat/compiler/model/model.py index 4e440156c9..7e3074388e 100644 --- a/python/mlc_chat/compiler/model/model.py +++ b/python/mlc_chat/compiler/model/model.py @@ -65,4 +65,79 @@ class Model: ) } -MODEL_PRESETS: Dict[str, Dict[str, Any]] = llama_config.CONFIG +MODEL_PRESETS: Dict[str, Any] = { + "llama2_7b": llama_config.LlamaConfig.from_dict( + { + "architectures": ["LlamaForCausalLM"], + "bos_token_id": 1, + "eos_token_id": 2, + "hidden_act": "silu", + "hidden_size": 4096, + "initializer_range": 0.02, + "intermediate_size": 11008, + "max_position_embeddings": 2048, + "model_type": "llama", + "num_attention_heads": 32, + "num_hidden_layers": 32, + "num_key_value_heads": 32, + "pad_token_id": 0, + "pretraining_tp": 1, + "rms_norm_eps": 1e-05, + "rope_scaling": None, + "tie_word_embeddings": False, + "torch_dtype": "float16", + "transformers_version": "4.31.0.dev0", + "use_cache": True, + "vocab_size": 32000, + } + ), + "llama2_13b": llama_config.LlamaConfig.from_dict( + { + "_name_or_path": "meta-llama/Llama-2-13b-hf", + "architectures": ["LlamaForCausalLM"], + "bos_token_id": 1, + "eos_token_id": 2, + "hidden_act": "silu", + "hidden_size": 5120, + "initializer_range": 0.02, + "intermediate_size": 13824, + "max_position_embeddings": 2048, + "model_type": "llama", + "num_attention_heads": 40, + "num_hidden_layers": 40, + "num_key_value_heads": 40, + "pad_token_id": 0, + "pretraining_tp": 2, + "rms_norm_eps": 1e-05, + "rope_scaling": None, + "tie_word_embeddings": False, + "torch_dtype": "float16", + "transformers_version": "4.31.0.dev0", + "use_cache": True, + "vocab_size": 32000, + } + ), + "llama2_70b": llama_config.LlamaConfig.from_dict( + { + "architectures": ["LlamaForCausalLM"], + "bos_token_id": 1, + "eos_token_id": 2, + "hidden_act": "silu", + "hidden_size": 8192, + "initializer_range": 0.02, + "intermediate_size": 28672, + "max_position_embeddings": 2048, + "model_type": "llama", + "num_attention_heads": 64, + "num_hidden_layers": 80, + "num_key_value_heads": 8, + "pad_token_id": 0, + "rms_norm_eps": 1e-05, + "tie_word_embeddings": False, + "torch_dtype": "float16", + "transformers_version": "4.31.0.dev0", + "use_cache": True, + "vocab_size": 32000, + } + ), +} diff --git a/tests/python/model/test_llama.py b/tests/python/model/test_llama.py index 0cd22f1572..24bfa2afe8 100644 --- a/tests/python/model/test_llama.py +++ b/tests/python/model/test_llama.py @@ -1,15 +1,15 @@ # pylint: disable=invalid-name,missing-docstring import pytest -from mlc_chat.compiler import MODELS +from mlc_chat.compiler import MODEL_PRESETS, MODELS @pytest.mark.parametrize("model_name", ["llama2_7b", "llama2_13b", "llama2_70b"]) def test_llama2_creation(model_name: str): - model_info = MODELS["llama"] - config = model_info.config.from_predefined(model_name) - model = model_info.model(config) - mod, named_params = model.export_tvm(spec=model.get_default_spec()) + model = MODELS["llama"].model(MODEL_PRESETS[model_name]) + mod, named_params = model.export_tvm( + spec=model.get_default_spec(), # type: ignore + ) mod.show(black_format=False) for name, param in named_params: print(name, 
param.shape, param.dtype) diff --git a/tests/python/model/test_llama_quantization.py b/tests/python/model/test_llama_quantization.py index 9de7d8fa51..68461e2174 100644 --- a/tests/python/model/test_llama_quantization.py +++ b/tests/python/model/test_llama_quantization.py @@ -1,7 +1,7 @@ # pylint: disable=invalid-name,missing-docstring import pytest -from mlc_chat.compiler import MODELS, QUANTIZATION +from mlc_chat.compiler import MODEL_PRESETS, MODELS, QUANTIZATION from mlc_chat.compiler.quantization.group_quantization import ( GroupQuantizeEmbedding, GroupQuantizeLinear, @@ -18,22 +18,36 @@ ], ) def test_llama2_group_quantization(model_name: str, quant_name: str): - model_info = MODELS["llama"] - config = model_info.config.from_predefined(model_name) - model, quant_map = model_info.quantize["group-quant"](config, QUANTIZATION[quant_name]) + config = MODEL_PRESETS[model_name] + model, quant_map = MODELS["llama"].quantize["group-quant"](config, QUANTIZATION[quant_name]) assert "model.embed_tokens.weight" in quant_map.param_map - assert isinstance(model.model.embed_tokens, GroupQuantizeEmbedding) + assert isinstance( + model.model.embed_tokens, # type: ignore[attr-defined] + GroupQuantizeEmbedding, + ) assert "lm_head.weight" in quant_map.param_map - assert isinstance(model.lm_head, GroupQuantizeLinear) + assert isinstance(model.lm_head, GroupQuantizeLinear) # type: ignore[attr-defined] for i in range(config.num_hidden_layers): assert f"model.layers.{i}.self_attn.qkv_proj.weight" in quant_map.param_map - assert isinstance(model.model.layers[i].self_attn.qkv_proj, GroupQuantizeMultiLinear) + assert isinstance( + model.model.layers[i].self_attn.qkv_proj, # type: ignore[attr-defined] + GroupQuantizeMultiLinear, + ) assert f"model.layers.{i}.self_attn.o_proj.weight" in quant_map.param_map - assert isinstance(model.model.layers[i].self_attn.o_proj, GroupQuantizeLinear) + assert isinstance( + model.model.layers[i].self_attn.o_proj, # type: ignore[attr-defined] + GroupQuantizeLinear, + ) assert f"model.layers.{i}.mlp.gate_up_proj.weight" in quant_map.param_map - assert isinstance(model.model.layers[i].mlp.gate_up_proj, GroupQuantizeMultiLinear) + assert isinstance( + model.model.layers[i].mlp.gate_up_proj, # type: ignore[attr-defined] + GroupQuantizeMultiLinear, + ) assert f"model.layers.{i}.mlp.down_proj.weight" in quant_map.param_map - assert isinstance(model.model.layers[i].mlp.down_proj, GroupQuantizeLinear) + assert isinstance( + model.model.layers[i].mlp.down_proj, # type: ignore[attr-defined] + GroupQuantizeLinear, + ) if __name__ == "__main__": From 5d1dc34a7e319ae487474191c13333e25c2a9956 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Sat, 4 Nov 2023 01:46:19 -0700 Subject: [PATCH 088/116] Merge llama_config.py into llama_model.py (#1189) --- .../mlc_chat/compiler/model/llama_config.py | 31 ---- .../mlc_chat/compiler/model/llama_loader.py | 3 +- python/mlc_chat/compiler/model/llama_model.py | 32 +++- .../compiler/model/llama_quantization.py | 3 +- python/mlc_chat/compiler/model/model.py | 146 +++++++++--------- tests/python/model/test_llama.py | 4 +- tests/python/model/test_llama_quantization.py | 5 +- 7 files changed, 108 insertions(+), 116 deletions(-) delete mode 100644 python/mlc_chat/compiler/model/llama_config.py diff --git a/python/mlc_chat/compiler/model/llama_config.py b/python/mlc_chat/compiler/model/llama_config.py deleted file mode 100644 index 044161023a..0000000000 --- a/python/mlc_chat/compiler/model/llama_config.py +++ /dev/null @@ -1,31 +0,0 @@ -"""Common configuration 
for Llama models.""" -import dataclasses -from typing import Any, Dict - -from ...support.config import ConfigBase - - -@dataclasses.dataclass -class LlamaConfig(ConfigBase): # pylint: disable=too-many-instance-attributes - """Configuration of the Llama model.""" - - hidden_act: str - hidden_size: int - intermediate_size: int - num_attention_heads: int - num_hidden_layers: int - rms_norm_eps: float - vocab_size: int - max_sequence_length: int = 2048 - position_embedding_base: int = 10000 - num_key_value_heads: int = 0 - kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict) - head_dim: int = 0 - - def __post_init__(self): - if self.num_key_value_heads == 0: - self.num_key_value_heads = self.num_attention_heads - if self.head_dim == 0: - self.head_dim = self.hidden_size // self.num_attention_heads - assert self.num_attention_heads % self.num_key_value_heads == 0 - assert self.head_dim * self.num_attention_heads == self.hidden_size diff --git a/python/mlc_chat/compiler/model/llama_loader.py b/python/mlc_chat/compiler/model/llama_loader.py index 94a6d80600..12d957952e 100644 --- a/python/mlc_chat/compiler/model/llama_loader.py +++ b/python/mlc_chat/compiler/model/llama_loader.py @@ -8,8 +8,7 @@ from ..loader import ExternMapping from ..quantization import Quantization -from .llama_config import LlamaConfig -from .llama_model import LlamaForCasualLM +from .llama_model import LlamaConfig, LlamaForCasualLM def huggingface(model_config: LlamaConfig, quantization: Quantization) -> ExternMapping: diff --git a/python/mlc_chat/compiler/model/llama_model.py b/python/mlc_chat/compiler/model/llama_model.py index 0c9d2f45ab..a2a6c28d31 100644 --- a/python/mlc_chat/compiler/model/llama_model.py +++ b/python/mlc_chat/compiler/model/llama_model.py @@ -2,14 +2,42 @@ Implementation for Llama2 architecture. 
TODO: add docstring """ +import dataclasses import math -from typing import Optional +from typing import Any, Dict, Optional from tvm import te, tir from tvm.relax.frontend import nn from tvm.relax.frontend.nn import Tensor, op -from .llama_config import LlamaConfig +from ...support.config import ConfigBase + + +@dataclasses.dataclass +class LlamaConfig(ConfigBase): # pylint: disable=too-many-instance-attributes + """Configuration of the Llama model.""" + + hidden_act: str + hidden_size: int + intermediate_size: int + num_attention_heads: int + num_hidden_layers: int + rms_norm_eps: float + vocab_size: int + max_sequence_length: int = 2048 + position_embedding_base: int = 10000 + num_key_value_heads: int = 0 + head_dim: int = 0 + kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict) + + def __post_init__(self): + if self.num_key_value_heads == 0: + self.num_key_value_heads = self.num_attention_heads + if self.head_dim == 0: + self.head_dim = self.hidden_size // self.num_attention_heads + assert self.num_attention_heads % self.num_key_value_heads == 0 + assert self.head_dim * self.num_attention_heads == self.hidden_size + # pylint: disable=invalid-name,missing-docstring diff --git a/python/mlc_chat/compiler/model/llama_quantization.py b/python/mlc_chat/compiler/model/llama_quantization.py index ffb2ec71e3..cec9bd86e5 100644 --- a/python/mlc_chat/compiler/model/llama_quantization.py +++ b/python/mlc_chat/compiler/model/llama_quantization.py @@ -6,8 +6,7 @@ from ..loader import QuantizeMapping from ..quantization import GroupQuantize -from .llama_config import LlamaConfig -from .llama_model import LlamaForCasualLM +from .llama_model import LlamaConfig, LlamaForCasualLM def group_quant( diff --git a/python/mlc_chat/compiler/model/model.py b/python/mlc_chat/compiler/model/model.py index 7e3074388e..eb1f9b2d11 100644 --- a/python/mlc_chat/compiler/model/model.py +++ b/python/mlc_chat/compiler/model/model.py @@ -6,7 +6,7 @@ from ..loader import ExternMapping, QuantizeMapping from ..quantization.quantization import Quantization -from . import llama_config, llama_loader, llama_model, llama_quantization +from . import llama_loader, llama_model, llama_quantization ModelConfig = Any """A ModelConfig is an object that represents a model architecture. 
It is required to have @@ -54,7 +54,7 @@ class Model: "llama": Model( name="llama", model=llama_model.LlamaForCasualLM, - config=llama_config.LlamaConfig, + config=llama_model.LlamaConfig, source={ "huggingface-torch": llama_loader.huggingface, "huggingface-safetensor": llama_loader.huggingface, @@ -66,78 +66,72 @@ class Model: } MODEL_PRESETS: Dict[str, Any] = { - "llama2_7b": llama_config.LlamaConfig.from_dict( - { - "architectures": ["LlamaForCausalLM"], - "bos_token_id": 1, - "eos_token_id": 2, - "hidden_act": "silu", - "hidden_size": 4096, - "initializer_range": 0.02, - "intermediate_size": 11008, - "max_position_embeddings": 2048, - "model_type": "llama", - "num_attention_heads": 32, - "num_hidden_layers": 32, - "num_key_value_heads": 32, - "pad_token_id": 0, - "pretraining_tp": 1, - "rms_norm_eps": 1e-05, - "rope_scaling": None, - "tie_word_embeddings": False, - "torch_dtype": "float16", - "transformers_version": "4.31.0.dev0", - "use_cache": True, - "vocab_size": 32000, - } - ), - "llama2_13b": llama_config.LlamaConfig.from_dict( - { - "_name_or_path": "meta-llama/Llama-2-13b-hf", - "architectures": ["LlamaForCausalLM"], - "bos_token_id": 1, - "eos_token_id": 2, - "hidden_act": "silu", - "hidden_size": 5120, - "initializer_range": 0.02, - "intermediate_size": 13824, - "max_position_embeddings": 2048, - "model_type": "llama", - "num_attention_heads": 40, - "num_hidden_layers": 40, - "num_key_value_heads": 40, - "pad_token_id": 0, - "pretraining_tp": 2, - "rms_norm_eps": 1e-05, - "rope_scaling": None, - "tie_word_embeddings": False, - "torch_dtype": "float16", - "transformers_version": "4.31.0.dev0", - "use_cache": True, - "vocab_size": 32000, - } - ), - "llama2_70b": llama_config.LlamaConfig.from_dict( - { - "architectures": ["LlamaForCausalLM"], - "bos_token_id": 1, - "eos_token_id": 2, - "hidden_act": "silu", - "hidden_size": 8192, - "initializer_range": 0.02, - "intermediate_size": 28672, - "max_position_embeddings": 2048, - "model_type": "llama", - "num_attention_heads": 64, - "num_hidden_layers": 80, - "num_key_value_heads": 8, - "pad_token_id": 0, - "rms_norm_eps": 1e-05, - "tie_word_embeddings": False, - "torch_dtype": "float16", - "transformers_version": "4.31.0.dev0", - "use_cache": True, - "vocab_size": 32000, - } - ), + "llama2_7b": { + "architectures": ["LlamaForCausalLM"], + "bos_token_id": 1, + "eos_token_id": 2, + "hidden_act": "silu", + "hidden_size": 4096, + "initializer_range": 0.02, + "intermediate_size": 11008, + "max_position_embeddings": 2048, + "model_type": "llama", + "num_attention_heads": 32, + "num_hidden_layers": 32, + "num_key_value_heads": 32, + "pad_token_id": 0, + "pretraining_tp": 1, + "rms_norm_eps": 1e-05, + "rope_scaling": None, + "tie_word_embeddings": False, + "torch_dtype": "float16", + "transformers_version": "4.31.0.dev0", + "use_cache": True, + "vocab_size": 32000, + }, + "llama2_13b": { + "_name_or_path": "meta-llama/Llama-2-13b-hf", + "architectures": ["LlamaForCausalLM"], + "bos_token_id": 1, + "eos_token_id": 2, + "hidden_act": "silu", + "hidden_size": 5120, + "initializer_range": 0.02, + "intermediate_size": 13824, + "max_position_embeddings": 2048, + "model_type": "llama", + "num_attention_heads": 40, + "num_hidden_layers": 40, + "num_key_value_heads": 40, + "pad_token_id": 0, + "pretraining_tp": 2, + "rms_norm_eps": 1e-05, + "rope_scaling": None, + "tie_word_embeddings": False, + "torch_dtype": "float16", + "transformers_version": "4.31.0.dev0", + "use_cache": True, + "vocab_size": 32000, + }, + "llama2_70b": { + "architectures": 
["LlamaForCausalLM"], + "bos_token_id": 1, + "eos_token_id": 2, + "hidden_act": "silu", + "hidden_size": 8192, + "initializer_range": 0.02, + "intermediate_size": 28672, + "max_position_embeddings": 2048, + "model_type": "llama", + "num_attention_heads": 64, + "num_hidden_layers": 80, + "num_key_value_heads": 8, + "pad_token_id": 0, + "rms_norm_eps": 1e-05, + "tie_word_embeddings": False, + "torch_dtype": "float16", + "transformers_version": "4.31.0.dev0", + "use_cache": True, + "vocab_size": 32000, + }, } diff --git a/tests/python/model/test_llama.py b/tests/python/model/test_llama.py index 24bfa2afe8..8bbbd75971 100644 --- a/tests/python/model/test_llama.py +++ b/tests/python/model/test_llama.py @@ -6,7 +6,9 @@ @pytest.mark.parametrize("model_name", ["llama2_7b", "llama2_13b", "llama2_70b"]) def test_llama2_creation(model_name: str): - model = MODELS["llama"].model(MODEL_PRESETS[model_name]) + model_info = MODELS["llama"] + config = model_info.config.from_dict(MODEL_PRESETS[model_name]) + model = model_info.model(config) mod, named_params = model.export_tvm( spec=model.get_default_spec(), # type: ignore ) diff --git a/tests/python/model/test_llama_quantization.py b/tests/python/model/test_llama_quantization.py index 68461e2174..5bf3b2dd08 100644 --- a/tests/python/model/test_llama_quantization.py +++ b/tests/python/model/test_llama_quantization.py @@ -18,8 +18,9 @@ ], ) def test_llama2_group_quantization(model_name: str, quant_name: str): - config = MODEL_PRESETS[model_name] - model, quant_map = MODELS["llama"].quantize["group-quant"](config, QUANTIZATION[quant_name]) + model_info = MODELS["llama"] + config = model_info.config.from_dict(MODEL_PRESETS[model_name]) + model, quant_map = model_info.quantize["group-quant"](config, QUANTIZATION[quant_name]) assert "model.embed_tokens.weight" in quant_map.param_map assert isinstance( model.model.embed_tokens, # type: ignore[attr-defined] From 4832c2f9b84941b08aab0598a7a281828b604e6a Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Sat, 4 Nov 2023 01:58:55 -0700 Subject: [PATCH 089/116] Add CodeLlama as part of model presets (#1190) --- python/mlc_chat/compiler/model/model.py | 70 +++++++++++++++++++++++++ 1 file changed, 70 insertions(+) diff --git a/python/mlc_chat/compiler/model/model.py b/python/mlc_chat/compiler/model/model.py index eb1f9b2d11..b18742201f 100644 --- a/python/mlc_chat/compiler/model/model.py +++ b/python/mlc_chat/compiler/model/model.py @@ -134,4 +134,74 @@ class Model: "use_cache": True, "vocab_size": 32000, }, + "codellama_7b": { + "_name_or_path": "codellama/CodeLlama-7b-hf", + "architectures": ["LlamaForCausalLM"], + "bos_token_id": 1, + "eos_token_id": 2, + "hidden_act": "silu", + "hidden_size": 4096, + "initializer_range": 0.02, + "intermediate_size": 11008, + "max_position_embeddings": 16384, + "model_type": "llama", + "num_attention_heads": 32, + "num_hidden_layers": 32, + "num_key_value_heads": 32, + "pretraining_tp": 1, + "rms_norm_eps": 1e-05, + "rope_scaling": None, + "rope_theta": 1000000, + "tie_word_embeddings": False, + "torch_dtype": "bfloat16", + "transformers_version": "4.33.0.dev0", + "use_cache": True, + "vocab_size": 32016, + }, + "codellama_13b": { + "architectures": ["LlamaForCausalLM"], + "bos_token_id": 1, + "eos_token_id": 2, + "hidden_act": "silu", + "hidden_size": 5120, + "initializer_range": 0.02, + "intermediate_size": 13824, + "max_position_embeddings": 16384, + "model_type": "llama", + "num_attention_heads": 40, + "num_hidden_layers": 40, + "num_key_value_heads": 40, + "pretraining_tp": 1, + 
"rms_norm_eps": 1e-05, + "rope_scaling": None, + "rope_theta": 1000000, + "tie_word_embeddings": False, + "torch_dtype": "bfloat16", + "transformers_version": "4.32.0.dev0", + "use_cache": True, + "vocab_size": 32016, + }, + "codellama_34b": { + "architectures": ["LlamaForCausalLM"], + "bos_token_id": 1, + "eos_token_id": 2, + "hidden_act": "silu", + "hidden_size": 8192, + "initializer_range": 0.02, + "intermediate_size": 22016, + "max_position_embeddings": 16384, + "model_type": "llama", + "num_attention_heads": 64, + "num_hidden_layers": 48, + "num_key_value_heads": 8, + "pretraining_tp": 1, + "rms_norm_eps": 1e-05, + "rope_scaling": None, + "rope_theta": 1000000, + "tie_word_embeddings": False, + "torch_dtype": "bfloat16", + "transformers_version": "4.32.0.dev0", + "use_cache": True, + "vocab_size": 32016, + }, } From 78424f0de6783988933405e85f22d848056c5c89 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Sat, 4 Nov 2023 02:13:21 -0700 Subject: [PATCH 090/116] [Docs] Clarify zstd installation on Windows (#1191) --- docs/install/tvm.rst | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/docs/install/tvm.rst b/docs/install/tvm.rst index c2b7998ada..d9a6890b10 100644 --- a/docs/install/tvm.rst +++ b/docs/install/tvm.rst @@ -132,8 +132,11 @@ A nightly prebuilt Python package of Apache TVM Unity is provided. FileNotFoundError: Could not find module 'path\to\site-packages\tvm\tvm.dll' (or one of its dependencies). Try using the full path with constructor syntax. - It is likely `zstd`, a dependency to LLVM, was missing. Please `download `__ the precompiled binary, rename it to `zstd.dll` and copy to the same folder as `tvm.dll`. Hint - To locate the "tvm.dll" file in Conda, navigate to your user home directory (e.g., "/users/xxxx"). Search for "tvm.dll" and find the folder whose path contains the name of the current environment, such as "mlc-chat-venv." Once located, copy "zstd.dll" to that specific folder. + It is likely `zstd`, a dependency to LLVM, was missing. Please use the command below to get it installed: + .. code-block:: bash + + conda install zstd .. _tvm-unity-build-from-source: From 5d63f7e587472e34739049558ca6b641b5e08f27 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Sat, 4 Nov 2023 15:42:19 -0700 Subject: [PATCH 091/116] [Docs] Clarify zstd installation on Windows (#1196) Update zstd installation --- docs/install/mlc_llm.rst | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/docs/install/mlc_llm.rst b/docs/install/mlc_llm.rst index 13fc373dbf..fff04969b2 100644 --- a/docs/install/mlc_llm.rst +++ b/docs/install/mlc_llm.rst @@ -124,7 +124,11 @@ Select your operating system/compute platform and run the command in your termin FileNotFoundError: Could not find module 'path\to\site-packages\tvm\tvm.dll' (or one of its dependencies). Try using the full path with constructor syntax. - It is likely `zstd`, a dependency to LLVM, was missing. Please `download `__ the 64 bit version of precompiled binary, rename it to `zstd.dll` and copy to the same folder as `tvm.dll`. + It is likely `zstd`, a dependency to LLVM, was missing. Please use the command below to get it installed: + + .. code-block:: bash + + conda install zstd Option 2. 
Build from Source From 3417505a4f63c8d3b188661efd296d329f519a05 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Sat, 4 Nov 2023 19:44:25 -0700 Subject: [PATCH 092/116] Support overriding `--max-sequence-length` in command line (#1197) --- python/mlc_chat/cli/compile.py | 31 ++++++++++--- python/mlc_chat/compiler/__init__.py | 5 ++- python/mlc_chat/compiler/compile.py | 45 ++++++++++++------- python/mlc_chat/compiler/convert_weight.py | 29 ++++++------ .../compiler/flags_model_config_override.py | 32 +++++++++++++ python/mlc_chat/compiler/model/llama_model.py | 22 ++++++++- python/mlc_chat/support/auto_target.py | 7 ++- 7 files changed, 130 insertions(+), 41 deletions(-) create mode 100644 python/mlc_chat/compiler/flags_model_config_override.py diff --git a/python/mlc_chat/cli/compile.py b/python/mlc_chat/cli/compile.py index 2fcc9c0213..c340119c98 100644 --- a/python/mlc_chat/cli/compile.py +++ b/python/mlc_chat/cli/compile.py @@ -1,6 +1,7 @@ """Command line entrypoint of compilation.""" import argparse import logging +import re from pathlib import Path from typing import Union @@ -38,6 +39,15 @@ def _parse_output(path: Union[str, Path]) -> Path: raise argparse.ArgumentTypeError(f"Directory does not exist: {parent}") return path + def _check_prefix_symbols(prefix: str) -> str: + pattern = r"^[a-zA-Z_][a-zA-Z0-9_]*$" + if prefix == "" or re.match(pattern, prefix): + return prefix + raise argparse.ArgumentTypeError( + "Invalid prefix. It should only consist of " + "numbers (0-9), alphabets (A-Z, a-z) and underscore (_)." + ) + parser = argparse.ArgumentParser("MLC LLM Compiler") parser.add_argument( "--config", @@ -106,17 +116,27 @@ def _parse_output(path: Union[str, Path]) -> Path: "conflicts. Differet from objcopy, this takes no effect for shared library. " '(default: "")', ) + parser.add_argument( + "--max-sequence-length", + type=int, + default=None, + help="Option to override the maximum sequence length supported by the model. " + "An LLM is usually trained with a fixed maximum sequence length, which is usually " + "explicitly specified in model spec. By default, if this option is not set explicitly, " + "the maximum sequence length is determined by `max_sequence_length` or " + "`max_position_embeddings` in config.json, which can be inaccuate for some models.", + ) parser.add_argument( "--output", "-o", type=_parse_output, required=True, help="The name of the output file. The suffix determines if the output file is a " - "shared library or a static library. Available suffixes: " - "1) Linux: .so (shared), .tar (static); " - "2) macOS: .dylib (shared), .tar (static); " - "3) Windows: .dll (shared), .tar (static); " - "4) Android, iOS: .tar (static); " + "shared library or objects. Available suffixes: " + "1) Linux: .so (shared), .tar (objects); " + "2) macOS: .dylib (shared), .tar (objects); " + "3) Windows: .dll (shared), .tar (objects); " + "4) Android, iOS: .tar (objects); " "5) Web: .wasm (web assembly)", ) parsed = parser.parse_args() @@ -131,6 +151,7 @@ def _parse_output(path: Union[str, Path]) -> Path: build_func=build_func, prefix_symbols=parsed.prefix_symbols, output=parsed.output, + max_sequence_length=parsed.max_sequence_length, ) diff --git a/python/mlc_chat/compiler/__init__.py b/python/mlc_chat/compiler/__init__.py index e65e12a5d9..6d0c8c223d 100644 --- a/python/mlc_chat/compiler/__init__.py +++ b/python/mlc_chat/compiler/__init__.py @@ -4,8 +4,9 @@ """ from . 
import compiler_pass from .compile import CompileArgs, compile # pylint: disable=redefined-builtin -from .convert_weight import convert_weight +from .convert_weight import ConversionArgs, convert_weight +from .flags_model_config_override import ModelConfigOverride from .flags_optimization import OptimizationFlags -from .loader import ExternMapping, HuggingFaceLoader, QuantizeMapping +from .loader import LOADER, ExternMapping, HuggingFaceLoader, QuantizeMapping from .model import MODEL_PRESETS, MODELS, Model from .quantization import QUANTIZATION diff --git a/python/mlc_chat/compiler/compile.py b/python/mlc_chat/compiler/compile.py index 02842a1903..c4d7a9881f 100644 --- a/python/mlc_chat/compiler/compile.py +++ b/python/mlc_chat/compiler/compile.py @@ -3,12 +3,13 @@ import logging from io import StringIO from pathlib import Path -from typing import Callable +from typing import Callable, Optional from tvm import IRModule, relax from tvm.target import Target from ..support.style import bold +from .flags_model_config_override import ModelConfigOverride from .flags_optimization import OptimizationFlags from .model import Model from .quantization import Quantization @@ -28,25 +29,28 @@ class CompileArgs: # pylint: disable=too-many-instance-attributes build_func: Callable[[IRModule, "CompileArgs"], None] prefix_symbols: str output: Path + overrides: ModelConfigOverride - -def _echo_args(args: CompileArgs) -> None: - out = StringIO() - print(f"{bold('Compiling with arguments:')}", file=out) - print(f" {bold('--config'):<25} {args.config}", file=out) - print(f" {bold('--quantization'):<25} {args.quantization}", file=out) - print(f" {bold('--model-type'):<25} {args.model.name}", file=out) - print(f" {bold('--target'):<25} {args.target.export()}", file=out) - print(f" {bold('--opt'):<25} {args.opt}", file=out) - print(f" {bold('--output'):<25} {args.output}", file=out) - print(out.getvalue().rstrip()) + def display(self) -> None: + """Display the arguments to stdout.""" + out = StringIO() + print(f"{bold('Compiling with arguments:')}", file=out) + print(f" {bold('--config'):<25} {self.config}", file=out) + print(f" {bold('--quantization'):<25} {self.quantization}", file=out) + print(f" {bold('--model-type'):<25} {self.model.name}", file=out) + print(f" {bold('--target'):<25} {self.target.export()}", file=out) + print(f" {bold('--opt'):<25} {self.opt}", file=out) + print(f" {bold('--prefix-symbols'):<25} \"{self.prefix_symbols}\"", file=out) + print(f" {bold('--output'):<25} {self.output}", file=out) + print(f" {bold('--overrides'):<25} {dataclasses.asdict(self.overrides)}", file=out) + print(out.getvalue().rstrip()) def _compile(args: CompileArgs): logger.info("Creating model from: %s", args.config) model_config = args.model.config.from_file(args.config) - quantization = args.quantization - model, _ = args.model.quantize[quantization.kind](model_config, quantization) + args.overrides.apply(model_config) + model, _ = args.model.quantize[args.quantization.kind](model_config, args.quantization) logger.info("Exporting the model to TVM Unity compiler") mod, _named_params = model.export_tvm( spec=model.get_default_spec(), # type: ignore @@ -68,10 +72,19 @@ def compile( # pylint: disable=too-many-arguments,redefined-builtin build_func: Callable[[IRModule, CompileArgs], None], prefix_symbols: str, output: Path, + max_sequence_length: Optional[int], ): """Compile a model given its configuration and quantization format to a specific target.""" args = CompileArgs( - config, quantization, model_type, target, 
opt, build_func, prefix_symbols, output + config, + quantization, + model_type, + target, + opt, + build_func, + prefix_symbols, + output, + ModelConfigOverride(max_sequence_length=max_sequence_length), ) - _echo_args(args) + args.display() _compile(args) diff --git a/python/mlc_chat/compiler/convert_weight.py b/python/mlc_chat/compiler/convert_weight.py index 554c422a40..7b1f4576b9 100644 --- a/python/mlc_chat/compiler/convert_weight.py +++ b/python/mlc_chat/compiler/convert_weight.py @@ -33,21 +33,22 @@ class ConversionArgs: # pylint: disable=too-many-instance-attributes source_format: str output: Path + def display(self) -> None: + """Display the arguments to stdout.""" -def _echo_args(args: ConversionArgs) -> None: - def _device_to_str(device: Device) -> str: - return f"{Device.MASK2STR[device.device_type]}:{device.device_id}" + def _device_to_str(device: Device) -> str: + return f"{Device.MASK2STR[device.device_type]}:{device.device_id}" - out = StringIO() - print(f"{bold('Weight conversion with arguments:')}", file=out) - print(f" {bold('--config'):<25} {args.config}", file=out) - print(f" {bold('--quantization'):<25} {args.quantization}", file=out) - print(f" {bold('--model-type'):<25} {args.model.name}", file=out) - print(f" {bold('--device'):<25} {_device_to_str(args.device)}", file=out) - print(f" {bold('--source'):<25} {args.source}", file=out) - print(f" {bold('--source-format'):<25} {args.source_format}", file=out) - print(f" {bold('--output'):<25} {args.output}", file=out) - print(out.getvalue().rstrip()) + out = StringIO() + print(f"{bold('Weight conversion with arguments:')}", file=out) + print(f" {bold('--config'):<25} {self.config}", file=out) + print(f" {bold('--quantization'):<25} {self.quantization}", file=out) + print(f" {bold('--model-type'):<25} {self.model.name}", file=out) + print(f" {bold('--device'):<25} {_device_to_str(self.device)}", file=out) + print(f" {bold('--source'):<25} {self.source}", file=out) + print(f" {bold('--source-format'):<25} {self.source_format}", file=out) + print(f" {bold('--output'):<25} {self.output}", file=out) + print(out.getvalue().rstrip()) def _convert_args(args: ConversionArgs) -> None: # pylint: disable=too-many-locals @@ -120,5 +121,5 @@ def convert_weight( # pylint: disable=too-many-arguments ): """MLC LLM's weight conversation and quantization flow.""" args = ConversionArgs(config, quantization, model, device, source, source_format, output) - _echo_args(args) + args.display() _convert_args(args) diff --git a/python/mlc_chat/compiler/flags_model_config_override.py b/python/mlc_chat/compiler/flags_model_config_override.py new file mode 100644 index 0000000000..f1a25346a7 --- /dev/null +++ b/python/mlc_chat/compiler/flags_model_config_override.py @@ -0,0 +1,32 @@ +"""Flags for overriding model config.""" +import dataclasses +import logging +from typing import Optional + +from ..support.style import bold + +logger = logging.getLogger(__name__) + + +@dataclasses.dataclass +class ModelConfigOverride: + """Flags for overriding model config.""" + + max_sequence_length: Optional[int] = None + max_batch_size: Optional[int] = None + num_shards: Optional[int] = None + + def apply(self, model_config): + """Apply the overrides to the given model config.""" + if self.max_sequence_length is not None: + logger.info( + "Overriding %s from %d to %d", + bold("max_sequence_length"), + model_config.max_sequence_length, + self.max_sequence_length, + ) + model_config.max_sequence_length = self.max_sequence_length + if self.max_batch_size is not None: + 
model_config.max_batch_size = self.max_batch_size + if self.num_shards is not None: + model_config.num_shards = self.num_shards diff --git a/python/mlc_chat/compiler/model/llama_model.py b/python/mlc_chat/compiler/model/llama_model.py index a2a6c28d31..023db05e82 100644 --- a/python/mlc_chat/compiler/model/llama_model.py +++ b/python/mlc_chat/compiler/model/llama_model.py @@ -3,6 +3,7 @@ TODO: add docstring """ import dataclasses +import logging import math from typing import Any, Dict, Optional @@ -11,26 +12,43 @@ from tvm.relax.frontend.nn import Tensor, op from ...support.config import ConfigBase +from ...support.style import bold + +logger = logging.getLogger(__name__) @dataclasses.dataclass class LlamaConfig(ConfigBase): # pylint: disable=too-many-instance-attributes """Configuration of the Llama model.""" - hidden_act: str hidden_size: int intermediate_size: int num_attention_heads: int num_hidden_layers: int rms_norm_eps: float vocab_size: int - max_sequence_length: int = 2048 position_embedding_base: int = 10000 + max_sequence_length: int = 0 num_key_value_heads: int = 0 head_dim: int = 0 kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict) def __post_init__(self): + if self.max_sequence_length == 0: + if "max_position_embeddings" in self.kwargs: + self.max_sequence_length = self.kwargs.pop("max_position_embeddings") + logger.info( + "%s not found in config.json. Falling back to %s (%d)", + bold("max_sequence_length"), + bold("max_position_embeddings"), + self.max_sequence_length, + ) + else: + raise ValueError( + "Unable to determine the maxmimum sequence length, because neither " + "`max_sequence_length` nor `max_position_embeddings` is provided " + "in `config.json`." + ) if self.num_key_value_heads == 0: self.num_key_value_heads = self.num_attention_heads if self.head_dim == 0: diff --git a/python/mlc_chat/support/auto_target.py b/python/mlc_chat/support/auto_target.py index 6bfd51c06d..f84d3cd6b1 100644 --- a/python/mlc_chat/support/auto_target.py +++ b/python/mlc_chat/support/auto_target.py @@ -133,9 +133,12 @@ def _is_device(device: str): def _add_prefix_symbol(mod: IRModule, prefix: str, is_system_lib: bool) -> IRModule: if is_system_lib and prefix: - mod = mod.with_attr("system_lib_prefix", prefix) + mod = mod.with_attrs({"system_lib_prefix": prefix}) # type: ignore[dict-item] elif is_system_lib: - logger.warning("--prefix-symbols is not specified when building a static library") + logger.warning( + "%s is not specified when building a static library", + bold("--prefix-symbols"), + ) elif prefix: logger.warning( "--prefix-symbols is specified, but it will not take any effect " From 0e08845a116781a11c03ee453a0be03a3a1a9f1c Mon Sep 17 00:00:00 2001 From: Animesh Bohara Date: Sun, 5 Nov 2023 01:01:26 -0400 Subject: [PATCH 093/116] [RestAPI] Added docs (#1193) Add docs for RestAPI Co-authored-by: Animesh Bohara --- docs/deploy/rest.rst | 256 ++++++++++++++++++++++++ python/mlc_chat/interface/openai_api.py | 7 +- 2 files changed, 260 insertions(+), 3 deletions(-) diff --git a/docs/deploy/rest.rst b/docs/deploy/rest.rst index 8451624fdb..d12029a80d 100644 --- a/docs/deploy/rest.rst +++ b/docs/deploy/rest.rst @@ -74,12 +74,136 @@ The REST API provides the following endpoints: .. http:get:: /v1/completions +------------------------------------------------ + Get a completion from MLC-Chat using a prompt. +**Request body** + +**model**: *str* (required) + The model folder after compiling with MLC-LLM build process. 
The parameter + can either be the model name with its quantization scheme + (e.g. ``Llama-2-7b-chat-hf-q4f16_1``), or a full path to the model + folder. In the former case, we will use the provided name to search + for the model folder over possible paths. +**prompt**: *str* (required) + A list of chat messages. The last message should be from the user. +**stream**: *bool* (optional) + Whether to stream the response. If ``True``, the response will be streamed + as the model generates the response. If ``False``, the response will be + returned after the model finishes generating the response. +**temperature**: *float* (optional) + The temperature applied to logits before sampling. The default value is + ``0.7``. A higher temperature encourages more diverse outputs, while a + lower temperature produces more deterministic outputs. +**top_p**: *float* (optional) + This parameter determines the set of tokens from which we sample during + decoding. The default value is set to ``0.95``. At each step, we select + tokens from the minimal set that has a cumulative probability exceeding + the ``top_p`` parameter. + + For additional information on top-p sampling, please refer to this blog + post: https://huggingface.co/blog/how-to-generate#top-p-nucleus-sampling. +**repetition_penalty**: *float* (optional) + The repetition penalty controls the likelihood of the model generating + repeated texts. The default value is set to ``1.0``, indicating that no + repetition penalty is applied. Increasing the value reduces the + likelihood of repeat text generation. However, setting a high + ``repetition_penalty`` may result in the model generating meaningless + texts. The ideal choice of repetition penalty may vary among models. + + For more details on how repetition penalty controls text generation, please + check out the CTRL paper (https://arxiv.org/pdf/1909.05858.pdf). +**presence_penalty**: *float* (optional) + Positive values penalize new tokens if they are already present in the text so far, + decreasing the model's likelihood to repeat tokens. +**frequency_penalty**: *float* (optional) + Positive values penalize new tokens based on their existing frequency in the text so far, + decreasing the model's likelihood to repeat tokens. +**mean_gen_len**: *int* (optional) + The approximated average number of generated tokens in each round. Used + to determine whether the maximum window size would be exceeded. +**max_gen_len**: *int* (optional) + This parameter determines the maximum length of the generated text. If it is + not set, the model will generate text until it encounters a stop token. + +------------------------------------------------ + +**Returns** + If ``stream`` is set to ``False``, the response will be a ``CompletionResponse`` object. + If ``stream`` is set to ``True``, the response will be a stream of ``CompletionStreamResponse`` objects. + + .. http:get:: /v1/chat/completions +------------------------------------------------ + Get a response from MLC-Chat using a prompt, either with or without streaming. +**Request body** + +**model**: *str* (required) + The model folder after compiling with MLC-LLM build process. The parameter + can either be the model name with its quantization scheme + (e.g. ``Llama-2-7b-chat-hf-q4f16_1``), or a full path to the model + folder. In the former case, we will use the provided name to search + for the model folder over possible paths. +**messages**: *list[ChatMessage]* (required) + A list of chat messages. The last message should be from the user. 
+**stream**: *bool* (optional) + Whether to stream the response. If ``True``, the response will be streamed + as the model generates the response. If ``False``, the response will be + returned after the model finishes generating the response. +**temperature**: *float* (optional) + The temperature applied to logits before sampling. The default value is + ``0.7``. A higher temperature encourages more diverse outputs, while a + lower temperature produces more deterministic outputs. +**top_p**: *float* (optional) + This parameter determines the set of tokens from which we sample during + decoding. The default value is set to ``0.95``. At each step, we select + tokens from the minimal set that has a cumulative probability exceeding + the ``top_p`` parameter. + + For additional information on top-p sampling, please refer to this blog + post: https://huggingface.co/blog/how-to-generate#top-p-nucleus-sampling. +**repetition_penalty**: *float* (optional) + The repetition penalty controls the likelihood of the model generating + repeated texts. The default value is set to ``1.0``, indicating that no + repetition penalty is applied. Increasing the value reduces the + likelihood of repeat text generation. However, setting a high + ``repetition_penalty`` may result in the model generating meaningless + texts. The ideal choice of repetition penalty may vary among models. + + For more details on how repetition penalty controls text generation, please + check out the CTRL paper (https://arxiv.org/pdf/1909.05858.pdf). +**presence_penalty**: *float* (optional) + Positive values penalize new tokens if they are already present in the text so far, + decreasing the model's likelihood to repeat tokens. +**frequency_penalty**: *float* (optional) + Positive values penalize new tokens based on their existing frequency in the text so far, + decreasing the model's likelihood to repeat tokens. +**mean_gen_len**: *int* (optional) + The approximated average number of generated tokens in each round. Used + to determine whether the maximum window size would be exceeded. +**max_gen_len**: *int* (optional) + This parameter determines the maximum length of the generated text. If it is + not set, the model will generate text until it encounters a stop token. +**n**: *int* (optional) + This parameter determines the number of text samples to generate. The default + value is ``1``. Note that this parameter is only used when ``stream`` is set to + ``False``. +**stop**: *str* or *list[str]* (optional) + When ``stop`` is encountered, the model will stop generating output. + It can be a string or a list of strings. If it is a list of strings, the model + will stop generating output when any of the strings in the list is encountered. + Note that this parameter does not override the default stop string of the model. + +------------------------------------------------ + +**Returns** + If ``stream`` is set to ``False``, the response will be a ``ChatCompletionResponse`` object. + If ``stream`` is set to ``True``, the response will be a stream of ``ChatCompletionStreamResponse`` objects. + .. http:get:: /chat/reset Reset the chat. @@ -92,6 +216,138 @@ The REST API provides the following endpoints: Get the verbose runtime stats (encode/decode speed, total runtime). + +Request Objects +--------------- + +**ChatMessage** + +**role**: *str* (required) + The role(author) of the message. It can be either ``user`` or ``assistant``. +**content**: *str* (required) + The content of the message. 
+**name**: *str* (optional) + The name of the author of the message. + +Response Objects +---------------- + +**CompletionResponse** + +**id**: *str* + The id of the completion. +**object**: *str* + The object name ``text.completion``. +**created**: *int* + The time when the completion is created. +**choices**: *list[CompletionResponseChoice]* + A list of choices generated by the model. +**usage**: *UsageInfo* or *None* + The usage information of the model. + +------------------------------------------------ + +**CompletionResponseChoice** + +**index**: *int* + The index of the choice. +**text**: *str* + The message generated by the model. +**finish_reason**: *str* + The reason why the model finishes generating the message. It can be either + ``stop`` or ``length``. + + +------------------------------------------------ + +**CompletionStreamResponse** + +**id**: *str* + The id of the completion. +**object**: *str* + The object name ``text.completion.chunk``. +**created**: *int* + The time when the completion is created. +**choices**: *list[ChatCompletionResponseStreamhoice]* + A list of choices generated by the model. + +------------------------------------------------ + +**ChatCompletionResponseStreamChoice** + +**index**: *int* + The index of the choice. +**text**: *str* + The message generated by the model. +**finish_reason**: *str* + The reason why the model finishes generating the message. It can be either + ``stop`` or ``length``. + +------------------------------------------------ + +**ChatCompletionResponse** + +**id**: *str* + The id of the completion. +**object**: *str* + The object name ``chat.completion``. +**created**: *int* + The time when the completion is created. +**choices**: *list[ChatCompletionResponseChoice]* + A list of choices generated by the model. +**usage**: *UsageInfo* or *None* + The usage information of the model. + +------------------------------------------------ + +**ChatCompletionResponseChoice** + +**index**: *int* + The index of the choice. +**message**: *ChatMessage* + The message generated by the model. +**finish_reason**: *str* + The reason why the model finishes generating the message. It can be either + ``stop`` or ``length``. + +------------------------------------------------ + +**ChatCompletionStreamResponse** + +**id**: *str* + The id of the completion. +**object**: *str* + The object name ``chat.completion.chunk``. +**created**: *int* + The time when the completion is created. +**choices**: *list[ChatCompletionResponseStreamhoice]* + A list of choices generated by the model. + +------------------------------------------------ + +**ChatCompletionResponseStreamChoice** + +**index**: *int* + The index of the choice. +**delta**: *DeltaMessage* + The delta message generated by the model. +**finish_reason**: *str* + The reason why the model finishes generating the message. It can be either + ``stop`` or ``length``. + +------------------------------------------------ + + +**DeltaMessage** + +**role**: *str* + The role(author) of the message. It can be either ``user`` or ``assistant``. +**content**: *str* + The content of the message. 
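As a quick illustration of the request and response objects documented above, a minimal chat-completion call from Python might look like the sketch below. The host, port, and model name here are assumptions (the address depends on how the REST server was launched, and the model name should match a model compiled on your machine, e.g. ``Llama-2-7b-chat-hf-q4f16_1`` from the examples above); the response is indexed following the ``ChatCompletionResponse`` / ``ChatCompletionResponseChoice`` layout described earlier.

.. code-block:: python

   import requests

   payload = {
       "model": "Llama-2-7b-chat-hf-q4f16_1",  # assumed example model name
       "messages": [{"role": "user", "content": "What is the capital of France?"}],
       "stream": False,
   }
   # Assumes the REST server is reachable at this local address.
   resp = requests.post("http://127.0.0.1:8000/v1/chat/completions", json=payload)
   resp.raise_for_status()
   # ChatCompletionResponse -> choices[0] -> message -> content
   print(resp.json()["choices"][0]["message"]["content"])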
+ +------------------------------------------------ + + Use REST API in your own program -------------------------------- diff --git a/python/mlc_chat/interface/openai_api.py b/python/mlc_chat/interface/openai_api.py index 654b1646bc..55a32d1f5f 100644 --- a/python/mlc_chat/interface/openai_api.py +++ b/python/mlc_chat/interface/openai_api.py @@ -107,13 +107,14 @@ class CompletionRequest(BaseModel): class CompletionResponseChoice(BaseModel): index: int text: str - logprobs: int | None = None finish_reason: Literal["stop", "length"] | None = None + # TODO: logprobs support + logprobs: int | None = None class CompletionResponse(BaseModel): id: str = Field(default_factory=lambda: f"cmpl-{shortuuid.random()}") - object: str = "text_completion" + object: str = "text.completion" created: int = Field(default_factory=lambda: int(time.time())) choices: list[CompletionResponseChoice] usage: UsageInfo @@ -127,7 +128,7 @@ class CompletionResponseStreamChoice(BaseModel): class CompletionStreamResponse(BaseModel): id: str = Field(default_factory=lambda: f"cmpl-{shortuuid.random()}") - object: str = "text_completion" + object: str = "text.completion.chunk" created: int = Field(default_factory=lambda: int(time.time())) choices: List[CompletionResponseStreamChoice] From 145a984940bbba40351edc5e3015af9f52b25d1b Mon Sep 17 00:00:00 2001 From: David Pissarra <61968959+davidpissarra@users.noreply.github.com> Date: Sun, 5 Nov 2023 06:18:47 +0000 Subject: [PATCH 094/116] [API] ```llm-vscode``` extension support (#1198) This PR enables ```llm-vscode``` extension API support for copilot-like code completion, following [HF's LSP](https://github.com/huggingface/llm-ls). Fully compatible with ```CodeLlama``` and ```starcoder``` on mlc-llm. - https://github.com/huggingface/llm-vscode/pull/103 enhances extension user experience when used with mlc-llm rest api. Thanks @ pacman100, who came up with this on his latest blogpost: https://huggingface.co/blog/personal-copilot --- python/mlc_chat/interface/openai_api.py | 15 +++++++++++++++ python/mlc_chat/rest.py | 19 +++++++++++++++++++ 2 files changed, 34 insertions(+) diff --git a/python/mlc_chat/interface/openai_api.py b/python/mlc_chat/interface/openai_api.py index 55a32d1f5f..7e7cfd67bb 100644 --- a/python/mlc_chat/interface/openai_api.py +++ b/python/mlc_chat/interface/openai_api.py @@ -144,3 +144,18 @@ class EmbeddingsResponse(BaseModel): data: List[Dict[str, Any]] model: Optional[str] = None usage: UsageInfo + + +class VisualStudioCodeCompletionParameters(BaseModel): + temperature: float = None + top_p: float = None + max_new_tokens: int = None + + +class VisualStudioCodeCompletionRequest(BaseModel): + inputs: str + parameters: VisualStudioCodeCompletionParameters + + +class VisualStudioCodeCompletionResponse(BaseModel): + generated_text: str diff --git a/python/mlc_chat/rest.py b/python/mlc_chat/rest.py index 8611db017a..67b872f979 100644 --- a/python/mlc_chat/rest.py +++ b/python/mlc_chat/rest.py @@ -31,6 +31,8 @@ EmbeddingsRequest, EmbeddingsResponse, UsageInfo, + VisualStudioCodeCompletionRequest, + VisualStudioCodeCompletionResponse, ) @@ -364,6 +366,23 @@ async def read_stats_verbose(): return session["chat_mod"].stats(verbose=True) +@app.post("/v1/llm-vscode/completions") +async def request_llm_vscode(request: VisualStudioCodeCompletionRequest): + """ + Creates a vscode code completion for a given prompt. 
+ Follows huggingface LSP (https://github.com/huggingface/llm-ls) + """ + generation_config = GenerationConfig( + temperature=request.parameters.temperature, + top_p=request.parameters.top_p, + mean_gen_len=request.parameters.max_new_tokens, + max_gen_len=request.parameters.max_new_tokens, + ) + msg = session["chat_mod"].generate(prompt=request.inputs, generation_config=generation_config) + + return VisualStudioCodeCompletionResponse(generated_text=msg) + + ARGS = convert_args_to_argparser().parse_args() if __name__ == "__main__": uvicorn.run("mlc_chat.rest:app", host=ARGS.host, port=ARGS.port, reload=False, access_log=False) From 3413d17f36175be6600d7b8b9050ed720236cbd1 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Sun, 5 Nov 2023 12:03:33 -0800 Subject: [PATCH 095/116] [Fix] Use `fabs` as floating point abs function in C++ (#1202) --- cpp/llm_chat.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cpp/llm_chat.cc b/cpp/llm_chat.cc index 70d89db348..b2d486be43 100644 --- a/cpp/llm_chat.cc +++ b/cpp/llm_chat.cc @@ -997,7 +997,7 @@ class LLMChat { } if (generation_config.count("presence_penalty")) { CHECK(generation_config["presence_penalty"].is()); - CHECK(abs(generation_config["presence_penalty"].get()) <= 2) + CHECK(fabs(generation_config["presence_penalty"].get()) <= 2) << "Presence penalty must be in the range -2 to 2!"; *gen_presence_penalty = generation_config["presence_penalty"].get(); } else { @@ -1005,7 +1005,7 @@ class LLMChat { } if (generation_config.count("frequency_penalty")) { CHECK(generation_config["frequency_penalty"].is()); - CHECK(abs(generation_config["frequency_penalty"].get()) <= 2) + CHECK(fabs(generation_config["frequency_penalty"].get()) <= 2) << "Frequency penalty must be in the range -2 to 2!"; *gen_frequency_penalty = generation_config["frequency_penalty"].get(); } else { From 7ccb51ac89513372e89b76e6fb8117c86c72b0e4 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Sun, 5 Nov 2023 18:33:49 -0800 Subject: [PATCH 096/116] Integrating MLC runtime with the new compilation workflow (#1203) --- cpp/json_parser.h | 63 +++++++++++++++++++ cpp/llm_chat.cc | 34 ++++++++-- cpp/model_metadata.cc | 51 +++++++++++++++ cpp/model_metadata.h | 37 +++++++++++ python/mlc_chat/compiler/compile.py | 48 +++++++++++++- ...ise.py => fuse_dequantize_matmul_ewise.py} | 14 +++-- ...decode_take.py => fuse_dequantize_take.py} | 32 +++++----- ...nspose.py => fuse_dequantize_transpose.py} | 42 +++++++------ .../compiler/compiler_pass/pipeline.py | 12 ++-- .../quantization/group_quantization.py | 8 +-- 10 files changed, 282 insertions(+), 59 deletions(-) create mode 100644 cpp/json_parser.h create mode 100644 cpp/model_metadata.cc create mode 100644 cpp/model_metadata.h rename python/mlc_chat/compiler/compiler_pass/{fuse_decode_matmul_ewise.py => fuse_dequantize_matmul_ewise.py} (84%) rename python/mlc_chat/compiler/compiler_pass/{fuse_decode_take.py => fuse_dequantize_take.py} (69%) rename python/mlc_chat/compiler/compiler_pass/{fuse_decode_transpose.py => fuse_dequantize_transpose.py} (67%) diff --git a/cpp/json_parser.h b/cpp/json_parser.h new file mode 100644 index 0000000000..b181c300a9 --- /dev/null +++ b/cpp/json_parser.h @@ -0,0 +1,63 @@ +#ifndef MLC_LLM_CPP_JSON_PARSER_H_ +#define MLC_LLM_CPP_JSON_PARSER_H_ + +#define PICOJSON_USE_INT64 +#ifndef __STDC_FORMAT_MACROS +#define __STDC_FORMAT_MACROS +#endif + +#include +#include +#include +#include + +namespace mlc { +namespace llm { +namespace json { + +template +inline ValueType Lookup(const 
picojson::object& json, const std::string& key) { + auto it = json.find(key); + CHECK(it != json.end()) << "ValueError: key `" << key << "` not found in the JSON object"; + CHECK(it->second.is()) << "ValueError: key `" << key << "` has unexpected type"; + return it->second.get(); +} + +template <> +inline tvm::runtime::DataType Lookup(const picojson::object& json, const std::string& key) { + return tvm::runtime::DataType(tvm::runtime::String2DLDataType(Lookup(json, key))); +} + +template <> +inline tvm::runtime::ShapeTuple Lookup(const picojson::object& json, const std::string& key) { + picojson::array shape = Lookup(json, key); + std::vector result; + result.reserve(shape.size()); + for (const picojson::value& dim : shape) { + CHECK(dim.is()) << "ValueError: key `" << key << "` has unexpected type"; + result.push_back(dim.get()); + } + return tvm::runtime::ShapeTuple(std::move(result)); +} + +inline picojson::object ParseObject(const std::string& json_str) { + picojson::value result; + std::string err = picojson::parse(result, json_str); + if (!err.empty()) { + LOG(FATAL) << "Failed to parse JSON: err. The JSON string is:" << json_str; + } + CHECK(result.is()) + << "ValueError: The given string is not a JSON object: " << json_str; + return result.get(); +} + +inline picojson::object AsJSONObject(const picojson::value& json) { + CHECK(json.is()) << "ValueError: The given value is not a JSON object"; + return json.get(); +} + +} // namespace json +} // namespace llm +} // namespace mlc + +#endif // MLC_LLM_CPP_JSON_PARSER_H_ diff --git a/cpp/llm_chat.cc b/cpp/llm_chat.cc index b2d486be43..24776dc5b6 100644 --- a/cpp/llm_chat.cc +++ b/cpp/llm_chat.cc @@ -32,6 +32,7 @@ #include #include "conversation.h" +#include "model_metadata.h" #include "random.h" #include "support.h" #include "tokenizers.h" @@ -161,13 +162,18 @@ struct FunctionTable { static_cast(relax_vm::AllocatorType::kPooled), static_cast(kDLCPU), 0, static_cast(relax_vm::AllocatorType::kPooled)); this->mod_get_func = [this](const std::string& name) -> PackedFunc { - return this->local_vm->GetFunction(name, false); + PackedFunc func = this->local_vm->GetFunction(name, false); + if (func == nullptr) { + LOG(WARNING) << "Cannot find function in VM: " << name; + } + return func; }; this->get_global_func = [](const std::string& name) -> PackedFunc { const auto* f = tvm::runtime::Registry::Get(name); CHECK(f != nullptr) << "ValueError: Cannot find function " << name; return *f; }; + this->model_metadata_ = ModelMetadata::FromModule(this->local_vm); this->_InitFunctions(); } } @@ -188,10 +194,23 @@ struct FunctionTable { const PackedFunc* fload_cache = tvm::runtime::Registry::Get("vm.builtin.ndarray_cache.load"); ICHECK(fload_cache) << "TVM runtime cannot find vm.builtin.ndarray_cache.load"; (*fload_cache)(model_path, static_cast(device.device_type), device.device_id); - const PackedFunc* fload_params = - tvm::runtime::Registry::Get("vm.builtin.param_array_from_cache"); - ICHECK(fload_params) << "Cannot find env function vm.builtin.param_array_from_cache"; - Array params = (*fload_params)("param", -1); + Array params; + if (this->model_metadata_.params.empty()) { + constexpr const char* name_loader = "vm.builtin.param_array_from_cache"; + const PackedFunc* fload_params = tvm::runtime::Registry::Get(name_loader); + ICHECK(fload_params) << "Cannot find env function: " << name_loader; + params = (*fload_params)("param", -1); + } else { + constexpr const char* name_loader = "vm.builtin.param_array_from_cache_by_name"; + const PackedFunc* 
fload_params = tvm::runtime::Registry::Get(name_loader); + ICHECK(fload_params) << "Cannot find env function: " << name_loader; + Array param_names; + param_names.reserve(this->model_metadata_.params.size()); + for (const auto& param : this->model_metadata_.params) { + param_names.push_back(param.name); + } + params = (*fload_params)(param_names); + } // after we get params, it is safe to simply clear the cached version // as these params are referenced by params_ const PackedFunc* fclear_ndarray_cache = @@ -210,6 +229,9 @@ struct FunctionTable { this->softmax_func_ = mod_get_func("softmax_with_temperature"); this->encoding_without_cache_func_ = mod_get_func("encoding_without_cache"); this->create_kv_cache_func_ = mod_get_func("create_kv_cache"); + if (this->create_kv_cache_func_ == nullptr) { + this->create_kv_cache_func_ = mod_get_func("_initialize_effect"); + } this->reset_kv_cache_func_ = mod_get_func("reset_kv_cache"); if (this->reset_kv_cache_func_ == nullptr) { this->reset_kv_cache_func_ = get_global_func("vm.builtin.attention_kv_cache_array_clear"); @@ -260,6 +282,7 @@ struct FunctionTable { PackedFunc reset_kv_cache_func_; bool support_backtracking_kv_; PackedFunc fkvcache_array_popn_; + ModelMetadata model_metadata_; }; } // namespace @@ -437,6 +460,7 @@ class LLMChat { * \note This function overrides existing configurations. */ void LoadJSONOverride(const std::string& config_str, bool partial_update = false) { + LOG(INFO) << "config_str = " << config_str; picojson::value config_json; std::string err = picojson::parse(config_json, config_str); if (!err.empty()) { diff --git a/cpp/model_metadata.cc b/cpp/model_metadata.cc new file mode 100644 index 0000000000..135e0dcb7c --- /dev/null +++ b/cpp/model_metadata.cc @@ -0,0 +1,51 @@ +#include "./model_metadata.h" + +#include + +#include "./json_parser.h" + +namespace mlc { +namespace llm { + +using namespace tvm::runtime; + +ModelMetadata::Param ModelMetadata::Param::FromJSON(const picojson::object& param) { + Param result; + result.name = json::Lookup(param, "name"); + result.shape = json::Lookup(param, "shape"); + result.dtype = json::Lookup(param, "dtype"); + return result; +} + +ModelMetadata ModelMetadata::FromJSON(const picojson::object& metadata) { + ModelMetadata result; + result.model_type = json::Lookup(metadata, "model_type"); + result.quantization = json::Lookup(metadata, "quantization"); + picojson::array params = json::Lookup(metadata, "params"); + result.params.reserve(params.size()); + for (const picojson::value& json_param : params) { + result.params.emplace_back(ModelMetadata::Param::FromJSON(json::AsJSONObject(json_param))); + } + return result; +} + +ModelMetadata ModelMetadata::FromModule(tvm::runtime::Module module) { + std::string json_str = ""; + try { + TypedPackedFunc pf = module.GetFunction("_metadata"); + ICHECK(pf != nullptr); + json_str = pf(); + } catch (...) { + return ModelMetadata(); // TODO: add a warning message about legacy usecases + } + picojson::object json = json::ParseObject(json_str); + try { + return ModelMetadata::FromJSON(json); + } catch (const std::exception& e) { + LOG(WARNING) << "Failed to parse metadata:\n" << json_str; + throw e; + } +} + +} // namespace llm +} // namespace mlc diff --git a/cpp/model_metadata.h b/cpp/model_metadata.h new file mode 100644 index 0000000000..b408c72b08 --- /dev/null +++ b/cpp/model_metadata.h @@ -0,0 +1,37 @@ +/*! 
+ * \file model_metadata.h + * \brief Metadata stored in model lib + */ +#include +#include +#include +#include + +#include + +namespace picojson { +class value; +using object = std::unordered_map; +} // namespace picojson + +namespace mlc { +namespace llm { + +struct ModelMetadata { + struct Param { + tvm::runtime::String name; + tvm::runtime::ShapeTuple shape; + tvm::runtime::DataType dtype; + + static Param FromJSON(const picojson::object& param_obj); + }; + std::string model_type; + std::string quantization; + std::vector params; + + static ModelMetadata FromJSON(const picojson::object& json_str); + static ModelMetadata FromModule(tvm::runtime::Module module); +}; + +} // namespace llm +} // namespace mlc diff --git a/python/mlc_chat/compiler/compile.py b/python/mlc_chat/compiler/compile.py index c4d7a9881f..678e924a78 100644 --- a/python/mlc_chat/compiler/compile.py +++ b/python/mlc_chat/compiler/compile.py @@ -1,11 +1,13 @@ """Python entrypoint of compilation.""" import dataclasses +import json import logging from io import StringIO from pathlib import Path -from typing import Callable, Optional +from typing import Callable, List, Optional, Tuple from tvm import IRModule, relax +from tvm.relax.frontend import nn from tvm.target import Target from ..support.style import bold @@ -46,21 +48,61 @@ def display(self) -> None: print(out.getvalue().rstrip()) +def _attach_auxiliary_methods( + mod: IRModule, + named_params: List[Tuple[str, nn.Parameter]], + args: CompileArgs, + model_config, +) -> None: + def _metadata(): + metadata = { + "quantization": args.quantization.name, + "model_type": args.model.name, + "params": [ + { + "name": name, + "shape": list(param.shape), + "dtype": param.dtype, + } + for name, param in named_params + ], + } + bb = relax.BlockBuilder() # pylint: disable=invalid-name + with bb.function("main", params=[]): + bb.emit_func_output(relax.StringImm(json.dumps(metadata))) + return bb.get()["main"] + + def _attach_variable_bounds(): + for g_var, func in mod.functions_items(): + if isinstance(func, relax.Function): + mod[g_var] = func.with_attr( + "tir_var_upper_bound", + { + "seq_len": model_config.max_sequence_length, + "total_seq_len": model_config.max_sequence_length, + }, + ) + + mod["_metadata"] = _metadata() + _attach_variable_bounds() + + def _compile(args: CompileArgs): logger.info("Creating model from: %s", args.config) model_config = args.model.config.from_file(args.config) args.overrides.apply(model_config) model, _ = args.model.quantize[args.quantization.kind](model_config, args.quantization) logger.info("Exporting the model to TVM Unity compiler") - mod, _named_params = model.export_tvm( + mod, named_params = model.export_tvm( spec=model.get_default_spec(), # type: ignore ) + _attach_auxiliary_methods(mod, named_params, args, model_config) logger.info("Running optimizations using TVM Unity") with args.target: mod = relax.get_pipeline("mlc_llm")(mod) logger.info("Generating code using TVM Unity") args.build_func(mod, args) - logger.info("Code dumped to: %s", bold(str(args.output))) + logger.info("Generated: %s", bold(str(args.output))) def compile( # pylint: disable=too-many-arguments,redefined-builtin diff --git a/python/mlc_chat/compiler/compiler_pass/fuse_decode_matmul_ewise.py b/python/mlc_chat/compiler/compiler_pass/fuse_dequantize_matmul_ewise.py similarity index 84% rename from python/mlc_chat/compiler/compiler_pass/fuse_decode_matmul_ewise.py rename to python/mlc_chat/compiler/compiler_pass/fuse_dequantize_matmul_ewise.py index 
ddc71818ff..f8a64c8cda 100644 --- a/python/mlc_chat/compiler/compiler_pass/fuse_decode_matmul_ewise.py +++ b/python/mlc_chat/compiler/compiler_pass/fuse_dequantize_matmul_ewise.py @@ -1,12 +1,12 @@ -"""A compiler pass that fuses decode + matmul + elementwise.""" +"""A compiler pass that fuses dequantize + matmul + elementwise.""" import tvm from tvm import IRModule, relax from tvm.relax.dpl.pattern import GlobalVarPattern, TuplePattern, is_op, wildcard -@tvm.transform.module_pass(opt_level=0, name="FuseDecodeMatmulEwise") -class FuseDecodeMatmulEwise: # pylint: disable=too-few-public-methods - """A compiler pass that fuses decode + matmul + elementwise.""" +@tvm.transform.module_pass(opt_level=0, name="FuseDequantizeMatmulEwise") +class FuseDequantizeMatmulEwise: # pylint: disable=too-few-public-methods + """A compiler pass that fuses dequantize + matmul + elementwise.""" def transform_module( self, @@ -23,7 +23,7 @@ def transform_module( relax.transform.FuseOpsByPattern( [ ( - "decode_matmul", + "dequantize_matmul", *_pattern(match_ewise, n_aux_tensor), ) ] @@ -62,7 +62,9 @@ def _check_decoding(ctx: relax.transform.PatternCheckContext) -> bool: g_var = call.args[0] if not isinstance(g_var, relax.GlobalVar): return False - return g_var.name_hint.startswith("decode") or g_var.name_hint.startswith("fused_decode") + return g_var.name_hint.startswith("dequantize") or g_var.name_hint.startswith( + "fused_dequantize" + ) def _check_matmul(ctx: relax.transform.PatternCheckContext) -> bool: call = ctx.annotated_expr["matmul"] diff --git a/python/mlc_chat/compiler/compiler_pass/fuse_decode_take.py b/python/mlc_chat/compiler/compiler_pass/fuse_dequantize_take.py similarity index 69% rename from python/mlc_chat/compiler/compiler_pass/fuse_decode_take.py rename to python/mlc_chat/compiler/compiler_pass/fuse_dequantize_take.py index f2022c1161..80792159ba 100644 --- a/python/mlc_chat/compiler/compiler_pass/fuse_decode_take.py +++ b/python/mlc_chat/compiler/compiler_pass/fuse_dequantize_take.py @@ -1,4 +1,4 @@ -"""A compiler pass that fuses decode + take.""" +"""A compiler pass that fuses dequantize + take.""" import tvm from tvm import IRModule, relax, tir from tvm.relax.dpl.pattern import ( @@ -10,9 +10,9 @@ ) -@tvm.transform.module_pass(opt_level=0, name="FuseDecodeTake") -class FuseDecodeTake: # pylint: disable=too-few-public-methods - """A compiler pass that fuses decode + take.""" +@tvm.transform.module_pass(opt_level=0, name="FuseDequantizeTake") +class FuseDequantizeTake: # pylint: disable=too-few-public-methods + """A compiler pass that fuses dequantize + take.""" def transform_module( self, @@ -27,7 +27,7 @@ def transform_module( relax.transform.FuseOpsByPattern( [ ( - "decode_take", + "dequantize_take", *_pattern(n_aux_tensor, match_tir_vars), ) ] @@ -37,17 +37,19 @@ def transform_module( mod = tvm.transform.Sequential(seq)(mod) for g_var, func in mod.functions_items(): name = g_var.name_hint - if isinstance(func, tir.PrimFunc) and (("fused_decode" in name) and ("take" in name)): + if isinstance(func, tir.PrimFunc) and ( + ("fused_dequantize" in name) and ("take" in name) + ): sch_mod = tvm.IRModule({"main": func}) sch_mod = tir.transform.ForceNarrowIndexToInt32()(sch_mod) sch = tir.Schedule(sch_mod) - sch.compute_inline("decode") + sch.compute_inline("dequantize") mod[g_var] = sch.mod["main"] return mod def _pattern(n_aux_tensor: int, match_tir_vars: bool): - decode = is_op("relax.call_tir")( + dequantize = is_op("relax.call_tir")( GlobalVarPattern(), TuplePattern([wildcard() for _ in 
range(n_aux_tensor)]), add_constraint=False, @@ -56,13 +58,13 @@ def _pattern(n_aux_tensor: int, match_tir_vars: bool): if match_tir_vars: call_tir_args_take = [ GlobalVarPattern(), - TuplePattern([decode, indices]), + TuplePattern([dequantize, indices]), wildcard(), ] else: call_tir_args_take = [ GlobalVarPattern(), - TuplePattern([decode, indices]), + TuplePattern([dequantize, indices]), ] take = is_op("relax.call_tir")( *call_tir_args_take, @@ -70,19 +72,19 @@ def _pattern(n_aux_tensor: int, match_tir_vars: bool): ) annotations = { "take": take, - "decode": decode, + "dequantize": dequantize, "indices": indices, } def _check(ctx: relax.transform.PatternCheckContext) -> bool: take = ctx.annotated_expr["take"] - decode = ctx.annotated_expr["decode"] - if not isinstance(decode, relax.expr.Call): + dequantize = ctx.annotated_expr["dequantize"] + if not isinstance(dequantize, relax.expr.Call): return False if not isinstance(take.args[0], relax.GlobalVar) or not isinstance( - decode.args[0], relax.GlobalVar + dequantize.args[0], relax.GlobalVar ): return False - return "take" in take.args[0].name_hint and "decode" in decode.args[0].name_hint + return "take" in take.args[0].name_hint and "dequantize" in dequantize.args[0].name_hint return take, annotations, _check diff --git a/python/mlc_chat/compiler/compiler_pass/fuse_decode_transpose.py b/python/mlc_chat/compiler/compiler_pass/fuse_dequantize_transpose.py similarity index 67% rename from python/mlc_chat/compiler/compiler_pass/fuse_decode_transpose.py rename to python/mlc_chat/compiler/compiler_pass/fuse_dequantize_transpose.py index e2a826a1fb..816da39956 100644 --- a/python/mlc_chat/compiler/compiler_pass/fuse_decode_transpose.py +++ b/python/mlc_chat/compiler/compiler_pass/fuse_dequantize_transpose.py @@ -6,8 +6,8 @@ from tvm.relax.expr_functor import PyExprMutator, mutator -@tvm.transform.module_pass(opt_level=0, name="FuseDecodeTranspose") -class FuseDecodeTranspose: # pylint: disable=too-few-public-methods +@tvm.transform.module_pass(opt_level=0, name="FuseDequantizeTranspose") +class FuseDequantizeTranspose: # pylint: disable=too-few-public-methods """A compiler pass that fuses transpose + dequantize.""" def __init__(self, skip_gemm: bool) -> None: @@ -15,11 +15,11 @@ def __init__(self, skip_gemm: bool) -> None: def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule: """IRModule-level transformation""" - return _DecodeTransposeFuser(mod, skip_gemm=self.skip_gemm).transform() + return _DequantizeTransposeFuser(mod, skip_gemm=self.skip_gemm).transform() @mutator -class _DecodeTransposeFuser(PyExprMutator): # pylint: disable=abstract-method +class _DequantizeTransposeFuser(PyExprMutator): # pylint: disable=abstract-method def __init__( self, mod: IRModule, @@ -45,7 +45,7 @@ def visit_call_( # pylint: disable=arguments-renamed call = self.visit_expr_post_order(call) if call.op != tvm.ir.Op.get("relax.matmul"): return call - # Do not fuse decode-transpose for GeMM + # Do not fuse dequantize-transpose for GeMM if self.skip_gemm and ( call.args[0].struct_info.ndim < 2 or not isinstance(call.args[0].struct_info.shape[-2], tir.IntImm) @@ -66,25 +66,27 @@ def visit_call_( # pylint: disable=arguments-renamed if ( not isinstance(transpose_input, relax.Call) or transpose_input.op != tvm.ir.Op.get("relax.call_tir") - or not transpose_input.args[0].name_hint.startswith("decode") + or not transpose_input.args[0].name_hint.startswith("dequantize") or not isinstance(transpose_input.struct_info, relax.TensorStructInfo) ): 
return call - decode_tir_func = self.mod[transpose_input.args[0]] - assert isinstance(decode_tir_func, tir.PrimFunc) + dequantize_tir_func = self.mod[transpose_input.args[0]] + assert isinstance(dequantize_tir_func, tir.PrimFunc) if ( # pylint: disable=too-many-boolean-expressions - len(decode_tir_func.body.block.alloc_buffers) != 1 - or not isinstance(decode_tir_func.body.block.body, tir.SeqStmt) - or len(decode_tir_func.body.block.body) != 2 - or not isinstance(decode_tir_func.body.block.body[1], tir.For) - or not isinstance(decode_tir_func.body.block.body[1].body.body, tir.BlockRealize) - or decode_tir_func.body.block.body[1].body.body.block.name_hint != "T_transpose" + len(dequantize_tir_func.body.block.alloc_buffers) != 1 + or not isinstance(dequantize_tir_func.body.block.body, tir.SeqStmt) + or len(dequantize_tir_func.body.block.body) != 2 + or not isinstance(dequantize_tir_func.body.block.body[1], tir.For) + or not isinstance(dequantize_tir_func.body.block.body[1].body.body, tir.BlockRealize) + or dequantize_tir_func.body.block.body[1].body.body.block.name_hint != "T_transpose" ): return call - new_func_buffers = [decode_tir_func.buffer_map[var] for var in decode_tir_func.params] - new_func_buffers[-1] = decode_tir_func.body.block.alloc_buffers[0] + new_func_buffers = [ + dequantize_tir_func.buffer_map[var] for var in dequantize_tir_func.params + ] + new_func_buffers[-1] = dequantize_tir_func.body.block.alloc_buffers[0] new_func = tir.PrimFunc( params=new_func_buffers, body=tir.BlockRealize( @@ -95,15 +97,15 @@ def visit_call_( # pylint: disable=arguments-renamed reads=[], writes=[], name_hint="root", - body=decode_tir_func.body.block.body[0], + body=dequantize_tir_func.body.block.body[0], ), ), ) # Call `renew_defs` for deep-copy to avoid IR node duplication in # different PrimFuncs of an IRModule. new_func = tir.stmt_functor.renew_defs(new_func) - g_var = self.builder_.add_func(new_func, func_name="decode") - decoded_matmul_rhs = self.builder_.emit( + g_var = self.builder_.add_func(new_func, func_name="dequantize") + dequantize_matmul_rhs = self.builder_.emit( relax.call_tir(g_var, transpose_input.args[1], out_sinfo=matmul_rhs.struct_info) ) - return relax.op.matmul(call.args[0], decoded_matmul_rhs, out_dtype=call.attrs.out_dtype) + return relax.op.matmul(call.args[0], dequantize_matmul_rhs, out_dtype=call.attrs.out_dtype) diff --git a/python/mlc_chat/compiler/compiler_pass/pipeline.py b/python/mlc_chat/compiler/compiler_pass/pipeline.py index 43fc8f131c..f9bfdd0c59 100644 --- a/python/mlc_chat/compiler/compiler_pass/pipeline.py +++ b/python/mlc_chat/compiler/compiler_pass/pipeline.py @@ -7,9 +7,9 @@ from tvm.relax import register_pipeline # pylint: disable=no-name-in-module from .clean_up_tir_attrs import CleanUpTIRAttrs -from .fuse_decode_matmul_ewise import FuseDecodeMatmulEwise -from .fuse_decode_take import FuseDecodeTake -from .fuse_decode_transpose import FuseDecodeTranspose +from .fuse_dequantize_matmul_ewise import FuseDequantizeMatmulEwise +from .fuse_dequantize_take import FuseDequantizeTake +from .fuse_dequantize_transpose import FuseDequantizeTranspose from .fuse_transpose_matmul import FuseTransposeMatmul from .lift_global_buffer_alloc import LiftTIRGlobalBufferAlloc @@ -37,7 +37,7 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.I [ # Phase 1. 
Passes on high-level operator graph _LogProgress("Running TVM Relax graph-level optimizations"), - FuseDecodeTranspose(skip_gemm=False), + FuseDequantizeTranspose(skip_gemm=False), FuseTransposeMatmul(), # Phase 2. Lowering to TIR, inherited TVM Relax's official "zero" pipeline _LogProgress("Lowering to TVM TIR kernels"), @@ -48,8 +48,8 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.I tvm.relax.transform.FuseTIR(), # Phase 3. Passes on TIR _LogProgress("Running TVM TIR-level optimizations"), - FuseDecodeMatmulEwise(), - FuseDecodeTake(), + FuseDequantizeMatmulEwise(), + FuseDequantizeTake(), tvm.relax.transform.DeadCodeElimination(), CleanUpTIRAttrs(["op_pattern"]), # Phase 4. Low-level Optimizations diff --git a/python/mlc_chat/compiler/quantization/group_quantization.py b/python/mlc_chat/compiler/quantization/group_quantization.py index ea27410bea..1e43e5430c 100644 --- a/python/mlc_chat/compiler/quantization/group_quantization.py +++ b/python/mlc_chat/compiler/quantization/group_quantization.py @@ -149,7 +149,7 @@ def _dequantize( ), scale[i, j // self.group_size], ), - name="decode", + name="dequantize", ) def quantize_weight(self, weight: NDArray) -> List[NDArray]: @@ -336,7 +336,7 @@ def forward(self, x: nn.Tensor) -> nn.Tensor: # pylint: disable=invalid-name scale, [tir.IntImm("int64", self.out_features), tir.IntImm("int64", self.in_features)], ), - name_hint="decode", + name_hint="dequantize", args=[self.q_weight, self.q_scale], ) w = nn.op.permute_dims(w) # pylint: disable=invalid-name @@ -430,7 +430,7 @@ def forward(self, x: nn.Tensor) -> Sequence[nn.Tensor]: # pylint: disable=inval tir.IntImm("int64", self.in_features), ], ), - name_hint="decode", + name_hint="dequantize", args=[self.q_weight, self.q_scale], ) # x: [*B, in_features] @@ -501,7 +501,7 @@ def forward(self, x: nn.Tensor): # pylint: disable=invalid-name scale, [tir.IntImm("int64", self.num), tir.IntImm("int64", self.dim)], ), - name_hint="decode", + name_hint="dequantize", args=[self.q_weight, self.q_scale], ) if x.ndim == 1: From 65478c88ab989a09b6c05b21aad7d24a5e547a80 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Sun, 5 Nov 2023 19:52:53 -0800 Subject: [PATCH 097/116] [Fix] Remove Redundant Warnings (#1204) PR #1203 introduces some unnecessary and redundant logging messages. This PR gets them removed. --- cpp/llm_chat.cc | 4 ---- 1 file changed, 4 deletions(-) diff --git a/cpp/llm_chat.cc b/cpp/llm_chat.cc index 24776dc5b6..980f7fee18 100644 --- a/cpp/llm_chat.cc +++ b/cpp/llm_chat.cc @@ -163,9 +163,6 @@ struct FunctionTable { static_cast(relax_vm::AllocatorType::kPooled)); this->mod_get_func = [this](const std::string& name) -> PackedFunc { PackedFunc func = this->local_vm->GetFunction(name, false); - if (func == nullptr) { - LOG(WARNING) << "Cannot find function in VM: " << name; - } return func; }; this->get_global_func = [](const std::string& name) -> PackedFunc { @@ -460,7 +457,6 @@ class LLMChat { * \note This function overrides existing configurations. 
*/ void LoadJSONOverride(const std::string& config_str, bool partial_update = false) { - LOG(INFO) << "config_str = " << config_str; picojson::value config_json; std::string err = picojson::parse(config_json, config_str); if (!err.empty()) { From 01d4339a48c83f742bb804dfc7c5efcb49679080 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Mon, 6 Nov 2023 09:08:58 -0800 Subject: [PATCH 098/116] Try fix macOS build with picojson (#1206) The error message below ``` /usr/share/miniconda/envs/mlc-llm-build/conda-bld/mlc-chat-cli-nightly-package_1699286394016/work/3rdparty/tvm/3rdparty/picojson/picojson.h: In member function 'std::string picojson::value::to_str() const': /usr/share/miniconda/envs/mlc-llm-build/conda-bld/mlc-chat-cli-nightly-package_1699286394016/work/3rdparty/tvm/3rdparty/picojson/picojson.h:494:37: error: expected ')' before 'PRId64' 494 | SNPRINTF(buf, sizeof(buf), "%" PRId64, u_.int64_); | ~ ^~~~~~~ | ) /usr/share/miniconda/envs/mlc-llm-build/conda-bld/mlc-chat-cli-nightly-package_1699286394016/work/3rdparty/tvm/3rdparty/picojson/picojson.h:81:1: note: 'PRId64' is defined in header ''; did you forget to '#include '? 80 | #include +++ |+#include 81 | #include ``` indicates that the `__STDC_FORMAT_MACROS` flag is not turned on for some reason. --- cpp/json_parser.h | 2 -- 1 file changed, 2 deletions(-) diff --git a/cpp/json_parser.h b/cpp/json_parser.h index b181c300a9..3505329660 100644 --- a/cpp/json_parser.h +++ b/cpp/json_parser.h @@ -2,9 +2,7 @@ #define MLC_LLM_CPP_JSON_PARSER_H_ #define PICOJSON_USE_INT64 -#ifndef __STDC_FORMAT_MACROS #define __STDC_FORMAT_MACROS -#endif #include #include From 51d6f9cbf427b85f8d4f88b2fa0a3e451a9f4788 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Mon, 6 Nov 2023 09:28:57 -0800 Subject: [PATCH 099/116] Try fix macOS build with picojson again (#1207) Try fix macOS build with picojson --- cpp/model_metadata.cc | 1 + cpp/model_metadata.h | 1 + 2 files changed, 2 insertions(+) diff --git a/cpp/model_metadata.cc b/cpp/model_metadata.cc index 135e0dcb7c..30c2cc0df1 100644 --- a/cpp/model_metadata.cc +++ b/cpp/model_metadata.cc @@ -1,3 +1,4 @@ +#define __STDC_FORMAT_MACROS #include "./model_metadata.h" #include diff --git a/cpp/model_metadata.h b/cpp/model_metadata.h index b408c72b08..7bd0172c5c 100644 --- a/cpp/model_metadata.h +++ b/cpp/model_metadata.h @@ -2,6 +2,7 @@ * \file model_metadata.h * \brief Metadata stored in model lib */ +#define __STDC_FORMAT_MACROS #include #include #include From a7f11835766a3d00968a04e75863de0bbb6b7acc Mon Sep 17 00:00:00 2001 From: Git bot Date: Mon, 6 Nov 2023 18:53:07 +0000 Subject: [PATCH 100/116] Auto updated submodule references --- 3rdparty/tvm | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/3rdparty/tvm b/3rdparty/tvm index 3001b20b0d..36aa05178e 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit 3001b20b0dd114cad23fccb25cbb055ce80a224e +Subproject commit 36aa05178e08793ae937071b3b99c69ed3e13686 From e2c99a8cfe1d938d32512207530e887975bb7ebc Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Mon, 6 Nov 2023 12:01:51 -0800 Subject: [PATCH 101/116] [Fix] Keep update-to-date with upstream API change (#1209) --- cpp/image_embed.cc | 6 +++--- cpp/llm_chat.cc | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/cpp/image_embed.cc b/cpp/image_embed.cc index 716954b272..684d6c2a85 100644 --- a/cpp/image_embed.cc +++ b/cpp/image_embed.cc @@ -9,10 +9,10 @@ #include "image_embed.h" #include +#include #include #include #include -#include #include #include @@ 
-59,9 +59,9 @@ class LLMImage { ICHECK(fload_exec.defined()) << "TVM runtime cannot find vm_load_executable"; vm_ = fload_exec(); vm_->GetFunction("vm_initialization")(static_cast(device_.device_type), device_.device_id, - static_cast(relax_vm::AllocatorType::kPooled), + static_cast(memory::AllocatorType::kPooled), static_cast(kDLCPU), 0, - static_cast(relax_vm::AllocatorType::kPooled)); + static_cast(memory::AllocatorType::kPooled)); embed_func_ = vm_->GetFunction("embed"); diff --git a/cpp/llm_chat.cc b/cpp/llm_chat.cc index 980f7fee18..ff9ca5f5d9 100644 --- a/cpp/llm_chat.cc +++ b/cpp/llm_chat.cc @@ -12,11 +12,11 @@ #include #include #include +#include #include #include #include #include -#include #include #include @@ -159,8 +159,8 @@ struct FunctionTable { this->local_vm = fload_exec(); this->local_vm->GetFunction("vm_initialization")( static_cast(device.device_type), device.device_id, - static_cast(relax_vm::AllocatorType::kPooled), static_cast(kDLCPU), 0, - static_cast(relax_vm::AllocatorType::kPooled)); + static_cast(memory::AllocatorType::kPooled), static_cast(kDLCPU), 0, + static_cast(memory::AllocatorType::kPooled)); this->mod_get_func = [this](const std::string& name) -> PackedFunc { PackedFunc func = this->local_vm->GetFunction(name, false); return func; From e00220ce4c115837514c013f718853b7bdcb7c6e Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Mon, 6 Nov 2023 13:04:36 -0800 Subject: [PATCH 102/116] Detect `mtriple` via LLVM (#1211) --- python/mlc_chat/support/auto_target.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/python/mlc_chat/support/auto_target.py b/python/mlc_chat/support/auto_target.py index f84d3cd6b1..4c25380e8f 100644 --- a/python/mlc_chat/support/auto_target.py +++ b/python/mlc_chat/support/auto_target.py @@ -5,7 +5,7 @@ import tvm from tvm import IRModule, relax -from tvm._ffi import register_func +from tvm._ffi import get_global_func, register_func from tvm.contrib import tar, xcode from tvm.runtime import Device from tvm.target import Target @@ -113,14 +113,10 @@ def _detect_target_gpu(hint: str) -> Tuple[Target, BuildFunc]: def _detect_target_host(hint: str) -> Target: """Detect the host CPU architecture.""" - # cpu = codegen.llvm_get_system_cpu() - # triple = codegen.llvm_get_system_triple() - # vendor = codegen.llvm_get_system_x86_vendor() if hint == "auto": - hint = "x86-64" - if hint == "x86-64": - hint = "x86_64" - return Target({"kind": "llvm", "mtriple": f"{hint}-unknown-unknown"}) + target_triple = get_global_func("tvm.codegen.llvm.GetDefaultTargetTriple")() + logger.info("%s host CPU architecture: %s", FOUND, bold(target_triple)) + return Target({"kind": "llvm", "mtriple": target_triple}) def _is_device(device: str): @@ -156,7 +152,7 @@ def _detect_target_from_device(device: str) -> Optional[Target]: logger.info( '%s configuration of target device "%s": %s', FOUND, - device, + bold(device), target.export(), ) return target From 9869ca6f986ef6d4e07ac4b910df363b01dd760d Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Mon, 6 Nov 2023 16:03:12 -0600 Subject: [PATCH 103/116] Fix Python3.8 compatibility breakage (#1210) The breakage was resulting from newer syntax being used for type annotations, as part of https://github.com/mlc-ai/mlc-llm/pull/592. So long as `mlc_chat.interface.openai_api` wasn't imported, the breaking changes were not encountered. 
In https://github.com/mlc-ai/mlc-llm/pull/1107, the addition of `from .interface.openai_api import ChatMessage` caused this module to be imported, breaking compatibility of `mlc_chat.ChatModule` with Python3.8. This commit updates the type annotations to the supported syntax. --- python/mlc_chat/interface/openai_api.py | 32 ++++++++++++------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/python/mlc_chat/interface/openai_api.py b/python/mlc_chat/interface/openai_api.py index 7e7cfd67bb..46157ccbbe 100644 --- a/python/mlc_chat/interface/openai_api.py +++ b/python/mlc_chat/interface/openai_api.py @@ -14,13 +14,13 @@ class ChatMessage(BaseModel): role: str content: str - name: str | None = None + name: Optional[str] = None class ChatCompletionRequest(BaseModel): model: str - messages: list[ChatMessage] - stream: bool | None = False + messages: List[ChatMessage] + stream: Optional[bool] = False temperature: float = None top_p: float = None # TODO: replace by presence_penalty and frequency_penalty @@ -43,47 +43,47 @@ class ChatCompletionRequest(BaseModel): class UsageInfo(BaseModel): prompt_tokens: int = 0 - completion_tokens: int | None = 0 + completion_tokens: Optional[int] = 0 total_tokens: int = 0 class ChatCompletionResponseChoice(BaseModel): index: int message: ChatMessage - finish_reason: Literal["stop", "length"] | None = None + finish_reason: Optional[Literal["stop", "length"]] = None class ChatCompletionResponse(BaseModel): id: str = Field(default_factory=lambda: f"chatcmpl-{shortuuid.random()}") object: str = "chat.completion" created: int = Field(default_factory=lambda: int(time.time())) - choices: list[ChatCompletionResponseChoice] + choices: List[ChatCompletionResponseChoice] # TODO: Implement support for the following fields - usage: UsageInfo | None = None + usage: Optional[UsageInfo] = None class DeltaMessage(BaseModel): - role: str | None = None - content: str | None = None + role: Optional[str] = None + content: Optional[str] = None class ChatCompletionResponseStreamChoice(BaseModel): index: int delta: DeltaMessage - finish_reason: Literal["stop", "length"] | None = None + finish_reason: Optional[Literal["stop", "length"]] = None class ChatCompletionStreamResponse(BaseModel): id: str = Field(default_factory=lambda: f"chatcmpl-{shortuuid.random()}") object: str = "chat.completion.chunk" created: int = Field(default_factory=lambda: int(time.time())) - choices: list[ChatCompletionResponseStreamChoice] + choices: List[ChatCompletionResponseStreamChoice] class CompletionRequest(BaseModel): model: str - prompt: str | list[str] - stream: bool | None = False + prompt: Union[str, List[str]] + stream: Optional[bool] = False temperature: float = None repetition_penalty: float = None top_p: float = None @@ -107,16 +107,16 @@ class CompletionRequest(BaseModel): class CompletionResponseChoice(BaseModel): index: int text: str - finish_reason: Literal["stop", "length"] | None = None + finish_reason: Optional[Literal["stop", "length"]] = None # TODO: logprobs support - logprobs: int | None = None + logprobs: Optional[int] = None class CompletionResponse(BaseModel): id: str = Field(default_factory=lambda: f"cmpl-{shortuuid.random()}") object: str = "text.completion" created: int = Field(default_factory=lambda: int(time.time())) - choices: list[CompletionResponseChoice] + choices: List[CompletionResponseChoice] usage: UsageInfo From 4042626997c44c486f550b611d8bb904a637e7e3 Mon Sep 17 00:00:00 2001 From: Lesheng Jin <34279105+LeshengJin@users.noreply.github.com> Date: 
Mon, 6 Nov 2023 15:43:21 -0800 Subject: [PATCH 104/116] [Slim-LM] Enable loading from AWQ pre-quantized weight. (#1114) * [SLM] Enable loading from AWQ pre-quantized weight. * remove awq_loader.py * Update to the latest commit * Delete llama_parameter.py * update unittest * fix lint * upd * add Llama-2-7B-AWQ --- .../compiler/loader/huggingface_loader.py | 2 +- .../mlc_chat/compiler/model/llama_loader.py | 74 ++++ .../compiler/model/llama_quantization.py | 19 +- python/mlc_chat/compiler/model/model.py | 1 + .../compiler/quantization/__init__.py | 1 + .../compiler/quantization/awq_quantization.py | 367 ++++++++++++++++++ .../quantization/group_quantization.py | 22 +- .../compiler/quantization/quantization.py | 9 + .../mlc_chat/compiler/quantization/utils.py | 34 ++ tests/python/loader/test_awq.py | 45 +++ .../quantization/test_awq_quantization.py | 90 +++++ 11 files changed, 650 insertions(+), 14 deletions(-) create mode 100644 python/mlc_chat/compiler/quantization/awq_quantization.py create mode 100644 python/mlc_chat/compiler/quantization/utils.py create mode 100644 tests/python/loader/test_awq.py create mode 100644 tests/python/quantization/test_awq_quantization.py diff --git a/python/mlc_chat/compiler/loader/huggingface_loader.py b/python/mlc_chat/compiler/loader/huggingface_loader.py index a58220ab65..651c43b21f 100644 --- a/python/mlc_chat/compiler/loader/huggingface_loader.py +++ b/python/mlc_chat/compiler/loader/huggingface_loader.py @@ -83,7 +83,7 @@ def __init__( self.cached_files = {} self.torch_to_path = {} self.quantize_param_map = quantize_param_map - if path.suffix in (".bin", ".safetensors"): + if path.suffix in (".bin", ".safetensors", ".pt"): self._load_file(path) for name in self.cached_files[path].keys(): self.torch_to_path[name] = path diff --git a/python/mlc_chat/compiler/model/llama_loader.py b/python/mlc_chat/compiler/model/llama_loader.py index 12d957952e..d68a90c4bf 100644 --- a/python/mlc_chat/compiler/model/llama_loader.py +++ b/python/mlc_chat/compiler/model/llama_loader.py @@ -9,6 +9,7 @@ from ..loader import ExternMapping from ..quantization import Quantization from .llama_model import LlamaConfig, LlamaForCasualLM +from .llama_quantization import awq_quant def huggingface(model_config: LlamaConfig, quantization: Quantization) -> ExternMapping: @@ -82,3 +83,76 @@ def huggingface(model_config: LlamaConfig, quantization: Quantization) -> Extern ), ) return mapping + + +def awq(model_config: LlamaConfig, quantization: Quantization) -> ExternMapping: + """Returns a parameter mapping that maps from the names of MLC LLM parameters to + the names of AWQ parameters. + Parameters + ---------- + model_config : LlamaConfig + The configuration of the Llama model. + + quantization : Quantization + The quantization configuration. + + Returns + ------- + param_map : ExternMapping + The parameter mapping from MLC to AWQ. 
+ """ + model, _ = awq_quant(model_config, quantization) + _, _named_params = model.export_tvm(spec=model.get_default_spec()) + named_parameters = dict(_named_params) + + mapping = ExternMapping() + + for i in range(model_config.num_hidden_layers): + # Add QKV in self attention + attn = f"model.layers.{i}.self_attn" + for quantize_suffix in ["qweight", "qzeros", "scales"]: + mlc_name = f"{attn}.qkv_proj.{quantize_suffix}" + assert mlc_name in named_parameters + mlc_param = named_parameters[mlc_name] + mapping.add_mapping( + mlc_name, + [ + f"{attn}.q_proj.{quantize_suffix}", + f"{attn}.k_proj.{quantize_suffix}", + f"{attn}.v_proj.{quantize_suffix}", + ], + functools.partial( + lambda q, k, v, dtype: np.concatenate([q, k, v], axis=0).astype(dtype), + dtype=mlc_param.dtype, + ), + ) + + # Concat gate and up in MLP + mlp = f"model.layers.{i}.mlp" + for quantize_suffix in ["qweight", "qzeros", "scales"]: + mlc_name = f"{mlp}.gate_up_proj.{quantize_suffix}" + assert mlc_name in named_parameters + mlc_param = named_parameters[mlc_name] + mapping.add_mapping( + mlc_name, + [ + f"{mlp}.gate_proj.{quantize_suffix}", + f"{mlp}.up_proj.{quantize_suffix}", + ], + functools.partial( + lambda gate, up, dtype: np.concatenate([gate, up], axis=0).astype(dtype), + dtype=mlc_param.dtype, + ), + ) + + # inv_freq is not used in the model + mapping.add_unused(f"{attn}.rotary_emb.inv_freq") + + for mlc_name, mlc_param in named_parameters.items(): + if mlc_name not in mapping.param_map: + mapping.add_mapping( + mlc_name, + [mlc_name], + functools.partial(lambda x, dtype: x.astype(dtype), dtype=mlc_param.dtype), + ) + return mapping diff --git a/python/mlc_chat/compiler/model/llama_quantization.py b/python/mlc_chat/compiler/model/llama_quantization.py index cec9bd86e5..68fc2dfff5 100644 --- a/python/mlc_chat/compiler/model/llama_quantization.py +++ b/python/mlc_chat/compiler/model/llama_quantization.py @@ -5,7 +5,7 @@ from tvm.relax.frontend import nn from ..loader import QuantizeMapping -from ..quantization import GroupQuantize +from ..quantization import AWQQuantize, GroupQuantize from .llama_model import LlamaConfig, LlamaForCasualLM @@ -15,6 +15,23 @@ def group_quant( ) -> Tuple[nn.Module, QuantizeMapping]: """Quantize a Llama2 model using group quantization.""" model: nn.Module = LlamaForCasualLM(model_config) + model.to(quantization.model_dtype) + quant_map = QuantizeMapping({}, {}) + model = quantization.quantize_model( + model, + quant_map, + "", + ) + return model, quant_map + + +def awq_quant( + model_config: LlamaConfig, + quantization: AWQQuantize, +) -> Tuple[nn.Module, QuantizeMapping]: + """Quantize a Llama2 model using Activation-aware Weight Quantization(AWQ).""" + model: nn.Module = LlamaForCasualLM(model_config) + model.to(quantization.model_dtype) quant_map = QuantizeMapping({}, {}) model = quantization.quantize_model( model, diff --git a/python/mlc_chat/compiler/model/model.py b/python/mlc_chat/compiler/model/model.py index b18742201f..a5e40818de 100644 --- a/python/mlc_chat/compiler/model/model.py +++ b/python/mlc_chat/compiler/model/model.py @@ -58,6 +58,7 @@ class Model: source={ "huggingface-torch": llama_loader.huggingface, "huggingface-safetensor": llama_loader.huggingface, + "awq": llama_loader.awq, }, quantize={ "group-quant": llama_quantization.group_quant, diff --git a/python/mlc_chat/compiler/quantization/__init__.py b/python/mlc_chat/compiler/quantization/__init__.py index 3df96ce18a..74950df832 100644 --- a/python/mlc_chat/compiler/quantization/__init__.py +++ 
b/python/mlc_chat/compiler/quantization/__init__.py @@ -1,3 +1,4 @@ """A subpackage for quantization and dequantization algorithms""" +from .awq_quantization import AWQQuantize from .group_quantization import GroupQuantize from .quantization import QUANTIZATION, Quantization diff --git a/python/mlc_chat/compiler/quantization/awq_quantization.py b/python/mlc_chat/compiler/quantization/awq_quantization.py new file mode 100644 index 0000000000..944ded0ba0 --- /dev/null +++ b/python/mlc_chat/compiler/quantization/awq_quantization.py @@ -0,0 +1,367 @@ +"""AWQ Quantization""" + +from dataclasses import dataclass, field +from typing import Any, Callable, Dict, List, Optional, Sequence + +import numpy as np +from tvm import DataType, DataTypeCode, te, tir +from tvm.relax.frontend import nn +from tvm.runtime import NDArray + +from ..loader import QuantizeMapping +from .utils import convert_uint_to_float + + +def _make_divisible(c, divisor): # pylint: disable=invalid-name + return (c + divisor - 1) // divisor + + +def _calculate_zeros_width(in_features, group_size=128, pack_num=8): + if group_size >= 128: + size_multiplier = 1 + elif group_size == 64: + size_multiplier = 2 + elif group_size == 32: + size_multiplier = 4 + else: + raise NotImplementedError + + base_width = _make_divisible(in_features // group_size, pack_num) + base_width = _make_divisible(base_width, size_multiplier) * size_multiplier + return base_width + + +@dataclass +class AWQQuantize: # pylint: disable=too-many-instance-attributes + """Configuration for AWQ quantization""" + + name: str + kind: str + group_size: int + quantize_dtype: str # "int3", "int4", "int8" + storage_dtype: str # "uint32" + model_dtype: str # "float16", "float32" + + num_elem_per_storage: int = 0 + num_storage_per_group: int = 0 + max_int_value: int = 0 + + prebuilt_quantize_func: Dict[str, Callable[[NDArray], NDArray]] = field( + default_factory=lambda: {} + ) + + def __post_init__(self): + assert self.kind == "awq" + quantize_dtype = DataType(self.quantize_dtype) + storage_dtype = DataType(self.storage_dtype) + model_dtype = DataType(self.model_dtype) + assert quantize_dtype.type_code == DataTypeCode.INT + assert storage_dtype.type_code == DataTypeCode.UINT + assert model_dtype.type_code == DataTypeCode.FLOAT + if storage_dtype.bits < quantize_dtype.bits: + raise ValueError("Storage unit should be greater or equal to quantized element") + + self.num_elem_per_storage = storage_dtype.bits // quantize_dtype.bits + if self.group_size % self.num_elem_per_storage != 0: + raise ValueError("Group size should be divisible by numbers of elements per storage") + self.num_storage_per_group = self.group_size // self.num_elem_per_storage + self.max_int_value = (2 ** (quantize_dtype.bits - 1)) - 1 + + def quantize_model( + self, + model: nn.Module, + quant_map: QuantizeMapping, + name_prefix: str, + ) -> nn.Module: + """ + Quantize model with awq quantization. + + Parameters + ---------- + model : nn.Module + The non-quantized nn.Module. + + quant_map : QuantizeMapping + The quantize mapping with name mapping and func mapping. + + name_prefix : str + The name prefix for visited weight. + + Returns + ------- + ret : nn.Module + The quantized nn.Module. + """ + + class _Mutator(nn.Mutator): + def __init__(self, config: AWQQuantize, quant_map: QuantizeMapping) -> None: + super().__init__() + self.config = config + self.quant_map = quant_map + + def visit_module(self, name: str, node: nn.Module) -> Any: + """ + The visiting method for awq quantization of nn.Module nodes. 
+ + Parameters + ---------- + name : str + The name of the current node + + node : nn.Module + The current node of nn.Module to mutate. + + Returns + ------- + ret_node : Any + The new node to replace current node. + """ + + if isinstance(node, nn.Linear) and name != "lm_head": + return AWQQuantizeLinear.from_linear(node, self.config) + if isinstance(node, nn.MultiLinear): + return AWQQuantizeMultiLinear.from_multilinear(node, self.config) + return self.visit(name, node) + + model.to(dtype=self.model_dtype) + mutator = _Mutator(self, quant_map) + model = mutator.visit(name_prefix, model) + return model + + def _dequantize( + self, + weight: te.Tensor, + zeros: te.Tensor, + scale: te.Tensor, + out_shape: Optional[List[tir.PrimExpr]] = None, + ): + float_weight = convert_uint_to_float( + weight, + DataType(self.quantize_dtype).bits, + self.num_elem_per_storage, + self.storage_dtype, + self.model_dtype, + out_shape, + ) + float_zeros = convert_uint_to_float( + zeros, + DataType(self.quantize_dtype).bits, + self.num_elem_per_storage, + self.storage_dtype, + self.model_dtype, + out_shape, + ) + return te.compute( + shape=[weight.shape[0], weight.shape[1] * self.num_elem_per_storage] + if out_shape is None + else out_shape, + fcompute=lambda i, j: tir.multiply( + tir.subtract(float_weight[i, j], float_zeros[i, j // self.group_size]), + scale[i, j // self.group_size], + ), + name="decode", + ) + + +class AWQQuantizeLinear(nn.Module): # pylint: disable=too-many-instance-attributes + """An nn.Linear module with AWQ quantization""" + + def __init__( # pylint: disable=too-many-arguments + self, + in_features: int, + out_features: int, + config: AWQQuantize, + bias: bool = True, + out_dtype: Optional[str] = None, + ) -> None: + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.out_dtype = out_dtype + self.config = config + self.qweight = nn.Parameter( + (out_features, tir.ceildiv(in_features, config.num_elem_per_storage)), + config.storage_dtype, + ) + self.qzeros = nn.Parameter( + ( + out_features, + _calculate_zeros_width(in_features, config.group_size, config.num_elem_per_storage), + ), + dtype=config.storage_dtype, + ) + self.scales = nn.Parameter( + ( + out_features, + _calculate_zeros_width(in_features, config.group_size, config.num_elem_per_storage) + * config.num_elem_per_storage, + ), + config.model_dtype, + ) + if bias: + self.bias = nn.Parameter((out_features,), config.model_dtype) + else: + self.bias = None + + @staticmethod + def from_linear(linear: nn.Linear, config: AWQQuantize) -> "AWQQuantizeLinear": + """ + Converts a non-quantized nn.Linear to a group quantized AWQQuantizeLinear + + Parameters + ---------- + linear : nn.Linear + The non-quantized nn.Linear. + + config : AWQQuantize + The awq quantization config. + + Returns + ------- + ret : GroupQuantizeLinear + The awq quantized AWQQuantizeLinear layer. + """ + return AWQQuantizeLinear( + in_features=linear.in_features, + out_features=linear.out_features, + config=config, + bias=getattr(linear, "bias", None) is not None, + out_dtype=linear.out_dtype, + ) + + def forward(self, x: nn.Tensor) -> nn.Tensor: # pylint: disable=invalid-name + """ + Forward method for awq quantized linear layer + + Parameters + ---------- + x : nn.Tensor + The input tensor. + + Returns + ------- + ret : nn.Tensor + The output tensor for the group quantized linear layer. 
+ """ + w = nn.op.tensor_expr_op( # pylint: disable=invalid-name + lambda weight, zeros, scale: self.config._dequantize( # pylint: disable=protected-access + weight, + zeros, + scale, + [tir.IntImm("int64", self.out_features), tir.IntImm("int64", self.in_features)], + ), + name_hint="decode", + args=[self.qweight, self.qzeros, self.scales], + ) + w = nn.op.permute_dims(w) # pylint: disable=invalid-name + x = nn.op.matmul(x, w, out_dtype=self.out_dtype) + if self.bias is not None: + x = x + self.bias + return x + + +class AWQQuantizeMultiLinear(nn.Module): # pylint: disable=too-many-instance-attributes + """An nn.MultiLinear module with awq quantization.""" + + def __init__( # pylint: disable=too-many-arguments + self, + in_features: int, + out_features: nn.Sequence[int], + config: AWQQuantize, + bias: bool = True, + out_dtype: Optional[str] = None, + ): + assert len(out_features) > 0 + self.total_out_features = sum(out_features) + + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.out_dtype = out_dtype + self.config = config + self.qweight = nn.Parameter( + (self.total_out_features, tir.ceildiv(in_features, config.num_elem_per_storage)), + config.storage_dtype, + ) + self.qzeros = nn.Parameter( + ( + self.total_out_features, + _calculate_zeros_width(in_features, config.group_size, config.num_elem_per_storage), + ), + dtype=config.storage_dtype, + ) + self.scales = nn.Parameter( + ( + self.total_out_features, + _calculate_zeros_width(in_features, config.group_size, config.num_elem_per_storage) + * config.num_elem_per_storage, + ), + config.model_dtype, + ) + if bias: + self.bias = nn.Parameter((self.total_out_features,), config.model_dtype) + else: + self.bias = None + + @staticmethod + def from_multilinear( + multi_linear: nn.MultiLinear, config: AWQQuantize + ) -> "AWQQuantizeMultiLinear": + """ + Converts a non-quantized nn.MultiLinear to a awq quantized AWQQuantizeLinear. + + Parameters + ---------- + linear : nn.MultiLinear + The non-quantized nn.MultiLinear + + config : AWQQuantize + The awq quantization config. + + Returns + ------- + ret : AWQQuantizeMultiLinear + The awq quantized AWQQuantizeMultiLinear layer. + """ + return AWQQuantizeMultiLinear( + in_features=multi_linear.in_features, + out_features=multi_linear.out_features, + config=config, + bias=getattr(multi_linear, "bias", None) is not None, + out_dtype=multi_linear.out_dtype, + ) + + def forward(self, x: nn.Tensor) -> Sequence[nn.Tensor]: # pylint: disable=invalid-name + """ + Forward method for multi linear layer. + + Parameters + ---------- + x : Tensor + The input tensor. + + Returns + ------- + ret : Tensor + The output tensor for the multi linear layer. 
+ """ + sections = list(np.cumsum(self.out_features)[:-1]) + w = nn.op.tensor_expr_op( # pylint: disable=invalid-name + lambda weight, zeros, scale: self.config._dequantize( # pylint: disable=protected-access + weight, + zeros, + scale, + [ + tir.IntImm("int64", self.total_out_features), + tir.IntImm("int64", self.in_features), + ], + ), + name_hint="decode", + args=[self.qweight, self.qzeros, self.scales], + ) + w = nn.op.permute_dims(w) # pylint: disable=invalid-name + x = nn.op.matmul(x, w, out_dtype=self.out_dtype) + if self.bias is not None: + x = x + self.bias + results = nn.op.split(x, sections, axis=-1) + return results diff --git a/python/mlc_chat/compiler/quantization/group_quantization.py b/python/mlc_chat/compiler/quantization/group_quantization.py index 1e43e5430c..935621173b 100644 --- a/python/mlc_chat/compiler/quantization/group_quantization.py +++ b/python/mlc_chat/compiler/quantization/group_quantization.py @@ -12,6 +12,7 @@ from tvm.target import Target from ..loader import QuantizeMapping +from .utils import convert_uint_to_float logger = logging.getLogger(__name__) @@ -126,25 +127,22 @@ def _dequantize( scale: te.Tensor, out_shape: Optional[List[tir.PrimExpr]] = None, ): - tir_bin_mask = tir.const((1 << DataType(self.quantize_dtype).bits) - 1, self.storage_dtype) tir_max_int = tir.const(self.max_int_value, self.model_dtype) + float_weight = convert_uint_to_float( + weight, + DataType(self.quantize_dtype).bits, + self.num_elem_per_storage, + self.storage_dtype, + self.model_dtype, + out_shape, + ) return te.compute( shape=[weight.shape[0], weight.shape[1] * self.num_elem_per_storage] if out_shape is None else out_shape, fcompute=lambda i, j: tir.multiply( tir.subtract( - tir.bitwise_and( - tir.shift_right( - weight[i, j // self.num_elem_per_storage], - tir.Cast( - self.storage_dtype, - (j % self.num_elem_per_storage) - * DataType(self.quantize_dtype).bits, - ), - ), - tir_bin_mask, - ), + float_weight[i, j], tir_max_int, ), scale[i, j // self.group_size], diff --git a/python/mlc_chat/compiler/quantization/quantization.py b/python/mlc_chat/compiler/quantization/quantization.py index 2efad4beb4..f84881c966 100644 --- a/python/mlc_chat/compiler/quantization/quantization.py +++ b/python/mlc_chat/compiler/quantization/quantization.py @@ -1,6 +1,7 @@ """A centralized registry of all existing quantization methods and their configurations.""" from typing import Any, Dict +from .awq_quantization import AWQQuantize from .group_quantization import GroupQuantize Quantization = Any @@ -31,4 +32,12 @@ def quantize_weight(self, weight: tvm.runtime.NDArray) -> List[tvm.runtime.NDArr storage_dtype="uint32", model_dtype="float16", ), + "q4f16_awq": AWQQuantize( + name="q4f16_awq", + kind="awq", + group_size=128, + quantize_dtype="int4", + storage_dtype="uint32", + model_dtype="float16", + ), } diff --git a/python/mlc_chat/compiler/quantization/utils.py b/python/mlc_chat/compiler/quantization/utils.py new file mode 100644 index 0000000000..3470e42493 --- /dev/null +++ b/python/mlc_chat/compiler/quantization/utils.py @@ -0,0 +1,34 @@ +"""Common utilities for quantization""" +from typing import List, Optional + +from tvm import te, tir + + +def convert_uint_to_float( # pylint: disable=too-many-arguments + weight: te.Tensor, + bits: int, + num_elem_per_storage: int, + storage_dtype: str, + model_dtype: str, + out_shape: Optional[List[tir.PrimExpr]] = None, +) -> te.Tensor: + """Convert a quantized uint weight to an unquantized float weight.""" + tir_bin_mask = tir.const((1 << bits) - 1, 
storage_dtype) + return te.compute( + shape=[weight.shape[0], weight.shape[1] * num_elem_per_storage] + if out_shape is None + else out_shape, + fcompute=lambda i, j: tir.Cast( + model_dtype, + tir.bitwise_and( + tir.shift_right( + weight[i, j // num_elem_per_storage], + tir.Cast( + storage_dtype, + (j % num_elem_per_storage) * bits, + ), + ), + tir_bin_mask, + ), + ), + ) diff --git a/tests/python/loader/test_awq.py b/tests/python/loader/test_awq.py new file mode 100644 index 0000000000..0003145399 --- /dev/null +++ b/tests/python/loader/test_awq.py @@ -0,0 +1,45 @@ +# pylint: disable=missing-docstring +import logging +from pathlib import Path +from typing import Union + +import pytest +import tvm + +from mlc_chat.compiler import MODEL_PRESETS, MODELS, QUANTIZATION +from mlc_chat.compiler.loader import HuggingFaceLoader +from mlc_chat.support import tqdm + +logging.basicConfig( + level=logging.DEBUG, + style="{", + datefmt="%Y-%m-%d %H:%M:%S", + format="[{asctime}] {levelname} {filename}:{lineno}: {message}", +) + + +@pytest.mark.parametrize( + "param_path", + [ + "./dist/models/llama-2-7b-w4-g128-awq.pt", + "./dist/models/Llama-2-7B-AWQ/model.safetensors", + ], +) +def test_load_llama(param_path: Union[str, Path]): + path_params = Path(param_path) + + model = MODELS["llama"] + quantization = QUANTIZATION["q4f16_awq"] + config = model.config.from_dict(MODEL_PRESETS["llama2_7b"]) + loader = HuggingFaceLoader( + path=path_params, + extern_param_map=model.source["awq"](config, quantization), + ) + with tqdm.redirect(): + for _name, _param in loader.load(tvm.device("cpu")): + ... + + +if __name__ == "__main__": + test_load_llama(param_path="./dist/models/llama-2-7b-w4-g128-awq.pt") + test_load_llama(param_path="./dist/models/Llama-2-7B-AWQ/model.safetensors") diff --git a/tests/python/quantization/test_awq_quantization.py b/tests/python/quantization/test_awq_quantization.py new file mode 100644 index 0000000000..fbdb680cb0 --- /dev/null +++ b/tests/python/quantization/test_awq_quantization.py @@ -0,0 +1,90 @@ +# pylint: disable=invalid-name,missing-docstring +from typing import List + +import numpy as np +import pytest +import torch +import tvm +import tvm.testing +from tvm import DataType +from tvm.relax.frontend import nn + +from mlc_chat.compiler import QUANTIZATION +from mlc_chat.compiler.loader import QuantizeMapping +from mlc_chat.compiler.quantization import AWQQuantize + + +def dequantize_np( + config: AWQQuantize, + weight: np.ndarray, + zeros: np.ndarray, + scale: np.ndarray, +) -> np.ndarray: + def decode_int_arr(int_arr: np.ndarray, num_elem_per_storage: int, bits: int): + bin_mask = (1 << bits) - 1 + int_arr_repeated = np.repeat(int_arr, num_elem_per_storage, axis=-1) + indice_j = np.indices(int_arr_repeated.shape)[1] + arr_bin = np.bitwise_and( + np.right_shift( + int_arr_repeated, + (indice_j % num_elem_per_storage) * bits, + ), + bin_mask, + ) + return arr_bin + + weight_bin = decode_int_arr( + weight, config.num_elem_per_storage, DataType(config.quantize_dtype).bits + ) + zero_bin = decode_int_arr( + zeros, config.num_elem_per_storage, DataType(config.quantize_dtype).bits + ) + scale_repeated = np.repeat(scale, config.group_size, axis=-1) + zero_bin_repeated = np.repeat(zero_bin, config.group_size, axis=-1) + return (weight_bin - zero_bin_repeated) * scale_repeated + + +@pytest.mark.parametrize( + "quant_name, shape, dtype", + [ + ("q4f16_awq", [2, 4096], "float16"), + ], +) +def test_dequantize_weight(quant_name: str, shape: List[int], dtype: str): + class Test(nn.Module): + 
def __init__(self) -> None: + super().__init__() + self.linear = nn.Linear(shape[1], shape[0], bias=False, dtype=dtype) + + def forward(self, x: nn.Tensor): + return self.linear(x) + + config = QUANTIZATION[quant_name] + assert isinstance(config, AWQQuantize) + weight_np = np.random.randint( + np.iinfo(config.storage_dtype).min, + np.iinfo(config.storage_dtype).max, + (shape[0], shape[1] // config.num_elem_per_storage), + ).astype(config.storage_dtype) + zeros_np = np.random.randint( + np.iinfo(config.storage_dtype).min, + np.iinfo(config.storage_dtype).max, + (shape[0], shape[1] // config.num_elem_per_storage // config.group_size), + ).astype(config.storage_dtype) + scale_np = np.random.random((shape[0], shape[1] // config.group_size)).astype( + config.model_dtype + ) + mod = config.quantize_model(Test(), QuantizeMapping({}, {}), "") + mod.linear.qweight.data = weight_np + mod.linear.qzeros.data = zeros_np + mod.linear.scales.data = scale_np + model = mod.jit(spec={"forward": {"x": nn.spec.Tensor((shape[1], shape[1]), dtype)}}) + out = model["forward"]( + torch.from_numpy(np.diag(np.ones(shape[1]).astype(dtype))) # pylint: disable=no-member + ) + ref = dequantize_np(config, weight_np, zeros_np, scale_np).T + tvm.testing.assert_allclose(out, ref, rtol=1e-3, atol=1e-3) + + +if __name__ == "__main__": + test_dequantize_weight("q4f16_awq", [2, 4096], "float16") From be1c18bd6271be1705e56da7ed76f1aeebcbb98e Mon Sep 17 00:00:00 2001 From: Charlie Ruan <53290280+CharlieFRuan@users.noreply.github.com> Date: Tue, 7 Nov 2023 13:05:49 -0500 Subject: [PATCH 105/116] [Bugfix] Fix Cannot import name '_LIB' from 'mlc_chat.base' (#1214) Fix Python API doc --- python/mlc_chat/chat_module.py | 32 ++++++++++++++++++++------------ 1 file changed, 20 insertions(+), 12 deletions(-) diff --git a/python/mlc_chat/chat_module.py b/python/mlc_chat/chat_module.py index 9f306c14b6..1e47729ac9 100644 --- a/python/mlc_chat/chat_module.py +++ b/python/mlc_chat/chat_module.py @@ -13,7 +13,7 @@ import tvm from tvm.runtime import disco # pylint: disable=unused-import -from .base import _LIB # pylint: disable=unused-import +from . import base # pylint: disable=unused-import from .interface.openai_api import ChatMessage # pylint: disable=line-too-long @@ -794,14 +794,18 @@ def generate( Parameters ---------- - prompt : Union[str, List[ChatMessage]] + prompt: Union[str, List[ChatMessage]] The user input prompt, i.e. a question to ask the chat module. It can also be the whole conversation history (list of messages with role and content) - eg: ```[ - ChatMessage(role="user", content="Hello, how are you?"), - ChatMessage(role="assistant", content="I'm fine, thank you. How about you?"), - ChatMessage(role="user", content="I'm good too."), - ]``` + eg: + + .. code:: + + [ + ChatMessage(role="user", content="Hello, how are you?"), + ChatMessage(role="assistant", content="I'm fine, thank you. How about you?"), + ChatMessage(role="user", content="I'm good too."), + ] generation_config: Optional[GenerationConfig] The generation config object to override the ChatConfig generation settings. progress_callback: object @@ -1011,11 +1015,15 @@ def _prefill( input : Union[str, List[ChatMessage]] The user input prompt, i.e. a question to ask the chat module. It can also be the whole conversation history (list of messages with role and content) - eg: ```[ - ChatMessage(role="user", content="Hello, how are you?"), - ChatMessage(role="assistant", content="I'm fine, thank you. 
How about you?"), - ChatMessage(role="user", content="I'm good too."), - ]``` + eg: + + .. code:: + + [ + ChatMessage(role="user", content="Hello, how are you?"), + ChatMessage(role="assistant", content="I'm fine, thank you. How about you?"), + ChatMessage(role="user", content="I'm good too."), + ] decode_next_token : bool Whether to decode the next token after prefilling. place_in_prompt: PlaceInPrompt From 1015aaecbb7b1af1209c294ed9cceb31d1f1982e Mon Sep 17 00:00:00 2001 From: Yaxing Cai Date: Tue, 7 Nov 2023 17:01:42 -0800 Subject: [PATCH 106/116] [SLM] Support `q3f16_1` and `q4f32_1` (#1215) This PR supports the int3 and float32 group quantization, and fixes some minor issue in quantization impl and tests. --- .../quantization/group_quantization.py | 15 ++++++--- .../compiler/quantization/quantization.py | 16 ++++++++++ .../quantization/test_group_quantization.py | 32 +++++++++++++++---- 3 files changed, 52 insertions(+), 11 deletions(-) diff --git a/python/mlc_chat/compiler/quantization/group_quantization.py b/python/mlc_chat/compiler/quantization/group_quantization.py index 935621173b..ecce18d3c0 100644 --- a/python/mlc_chat/compiler/quantization/group_quantization.py +++ b/python/mlc_chat/compiler/quantization/group_quantization.py @@ -222,8 +222,11 @@ def _quantize( # pylint: disable=too-many-locals max_abs = te.compute( shape=scale_shape, fcompute=lambda i, j: te.max( - te.abs(weight[i, j * self.group_size + r]), - where=j * self.group_size + r < k, + tir.if_then_else( + j * self.group_size + r < k, + te.abs(weight[i, j * self.group_size + r]), + te.min_value(self.model_dtype), + ), axis=r, ), name="max_abs_value", @@ -251,9 +254,13 @@ def _quantize( # pylint: disable=too-many-locals quantized_weight = te.compute( shape=quantized_weight_shape, fcompute=lambda i, j: tir.sum( - scaled_weight[i, j * self.num_elem_per_storage + r] << (r * quantize_dtype.bits), + tir.if_then_else( + j * self.num_elem_per_storage + r < k, + scaled_weight[i, j * self.num_elem_per_storage + r] + << (r * quantize_dtype.bits), + 0, + ), axis=r, - where=j * self.num_elem_per_storage + r < k, ), name="weight", ) diff --git a/python/mlc_chat/compiler/quantization/quantization.py b/python/mlc_chat/compiler/quantization/quantization.py index f84881c966..bae8d07aec 100644 --- a/python/mlc_chat/compiler/quantization/quantization.py +++ b/python/mlc_chat/compiler/quantization/quantization.py @@ -24,6 +24,14 @@ def quantize_weight(self, weight: tvm.runtime.NDArray) -> List[tvm.runtime.NDArr """ QUANTIZATION: Dict[str, Quantization] = { + "q3f16_1": GroupQuantize( + name="q3f16_1", + kind="group-quant", + group_size=40, + quantize_dtype="int3", + storage_dtype="uint32", + model_dtype="float16", + ), "q4f16_1": GroupQuantize( name="q4f16_1", kind="group-quant", @@ -32,6 +40,14 @@ def quantize_weight(self, weight: tvm.runtime.NDArray) -> List[tvm.runtime.NDArr storage_dtype="uint32", model_dtype="float16", ), + "q4f32_1": GroupQuantize( + name="q4f32_1", + kind="group-quant", + group_size=32, + quantize_dtype="int4", + storage_dtype="uint32", + model_dtype="float32", + ), "q4f16_awq": AWQQuantize( name="q4f16_awq", kind="awq", diff --git a/tests/python/quantization/test_group_quantization.py b/tests/python/quantization/test_group_quantization.py index 106d0f5fb5..04b23e91d3 100644 --- a/tests/python/quantization/test_group_quantization.py +++ b/tests/python/quantization/test_group_quantization.py @@ -35,8 +35,10 @@ def quantize_np(config: GroupQuantize, weight: np.ndarray): 0, config.max_int_value * 2, 
).astype(config.storage_dtype) + weight_filtered = np.reshape(weight_scaled_reshaped, (n, k)) + weight_filtered[..., weight.shape[1] :] = 0 weight_scaled = np.reshape( - weight_scaled_reshaped, (n, k // config.num_elem_per_storage, config.num_elem_per_storage) + weight_filtered, (n, k // config.num_elem_per_storage, config.num_elem_per_storage) ) indice_k = np.indices(weight_scaled.shape, dtype=config.storage_dtype)[-1] quantized_weight = np.sum( @@ -53,6 +55,7 @@ def dequantize_np( scale: np.ndarray, out_shape: List[int] = None, ): + assert weight.shape[0] == scale.shape[0] bin_mask = (1 << DataType(config.quantize_dtype).bits) - 1 max_int = config.max_int_value out_shape = ( @@ -70,13 +73,21 @@ def dequantize_np( ), bin_mask, ) - return ((weight_bin - max_int) * scale_repeated)[: out_shape[0]][: out_shape[1]] + assert weight_bin.shape[1] <= scale_repeated.shape[1] + return ((weight_bin - max_int) * scale_repeated[..., : weight_bin.shape[1]])[ + : out_shape[0], : out_shape[1] + ] @pytest.mark.parametrize( "quant_name, shape, dtype, device", [ + ("q3f16_1", [2, 13], "float16", "cpu"), + ("q3f16_1", [16, 120], "float16", "cpu"), + ("q4f16_1", [2, 13], "float16", "cpu"), ("q4f16_1", [16, 128], "float16", "cpu"), + ("q4f32_1", [2, 13], "float32", "cpu"), + ("q4f32_1", [16, 128], "float32", "cpu"), ], ) def test_quantize_weight(quant_name: str, shape: List[int], dtype: str, device: str): @@ -90,15 +101,20 @@ def test_quantize_weight(quant_name: str, shape: List[int], dtype: str, device: tvm.testing.assert_allclose( dequantize_np(config, quantized_weight, scale, shape), dequantize_np(config, quantized_weight_ref, scale_ref, shape), - rtol=1e-3, - atol=0.2, + rtol=1e-2 if quant_name.startswith("q3") else 1e-3, + atol=0.4 if quant_name.startswith("q3") else 0.2, ) @pytest.mark.parametrize( "quant_name, shape, dtype", [ + ("q3f16_1", [2, 13], "float16"), + ("q3f16_1", [16, 120], "float16"), + ("q4f16_1", [2, 13], "float16"), ("q4f16_1", [16, 128], "float16"), + ("q4f32_1", [2, 13], "float32"), + ("q4f32_1", [16, 128], "float32"), ], ) def test_dequantize_weight(quant_name: str, shape: List[int], dtype: str): @@ -115,9 +131,9 @@ def forward(self, x: nn.Tensor): weight_np = np.random.randint( np.iinfo(config.storage_dtype).min, np.iinfo(config.storage_dtype).max, - (shape[0], shape[1] // config.num_elem_per_storage), + (shape[0], -(shape[1] // -config.num_elem_per_storage)), ).astype(config.storage_dtype) - scale_np = np.random.random((shape[0], shape[1] // config.group_size)).astype( + scale_np = np.random.random((shape[0], -(shape[1] // -config.group_size))).astype( config.model_dtype ) mod = config.quantize_model(Test(), QuantizeMapping({}, {}), "") @@ -127,14 +143,16 @@ def forward(self, x: nn.Tensor): out = model["forward"]( torch.from_numpy(np.diag(np.ones(shape[1]).astype(dtype))) # pylint: disable=no-member ) - ref = dequantize_np(config, weight_np, scale_np).T + ref = dequantize_np(config, weight_np, scale_np, shape).T tvm.testing.assert_allclose(out, ref, rtol=1e-3, atol=1e-3) @pytest.mark.parametrize( "quant_name, shape, dtype", [ + ("q3f16_1", [16, 128], "float16"), ("q4f16_1", [16, 128], "float16"), + ("q4f32_1", [16, 128], "float32"), ], ) def test_quantize_model(quant_name: str, shape: List[int], dtype: str): From 1a6faddd0c13c23e5560ec124d78797be9140c15 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Tue, 7 Nov 2023 23:32:43 -0800 Subject: [PATCH 107/116] Make the Compilation Working E2E (#1218) --- python/mlc_chat/compiler/model/llama_model.py | 94 +++++++++-- 
.../mlc_chat/compiler/quantization/utils.py | 18 +-- tests/legacy-python/module_intercept.py | 147 ++++++++++++++++++ 3 files changed, 233 insertions(+), 26 deletions(-) create mode 100644 tests/legacy-python/module_intercept.py diff --git a/python/mlc_chat/compiler/model/llama_model.py b/python/mlc_chat/compiler/model/llama_model.py index 023db05e82..27d7db0825 100644 --- a/python/mlc_chat/compiler/model/llama_model.py +++ b/python/mlc_chat/compiler/model/llama_model.py @@ -27,7 +27,7 @@ class LlamaConfig(ConfigBase): # pylint: disable=too-many-instance-attributes num_hidden_layers: int rms_norm_eps: float vocab_size: int - position_embedding_base: int = 10000 + position_embedding_base: int = 0 max_sequence_length: int = 0 num_key_value_heads: int = 0 head_dim: int = 0 @@ -49,6 +49,11 @@ def __post_init__(self): "`max_sequence_length` nor `max_position_embeddings` is provided " "in `config.json`." ) + if self.position_embedding_base == 0: + if "rope_theta" in self.kwargs: + self.position_embedding_base = self.kwargs.pop("rope_theta") + else: + self.position_embedding_base = 10000 if self.num_key_value_heads == 0: self.num_key_value_heads = self.num_attention_heads if self.head_dim == 0: @@ -60,6 +65,69 @@ def __post_init__(self): # pylint: disable=invalid-name,missing-docstring +class RMSNorm(nn.Module): + """ + Module for rms norm layer. + """ + + def __init__( # pylint: disable=too-many-arguments + self, + hidden_size: int, + axes, # pylint: disable=unused-argument + epsilon: float = 1e-5, + bias: bool = True, + dtype: Optional[str] = None, + ): + super().__init__() + self.epsilon = epsilon + self.weight = nn.Parameter((hidden_size,), dtype=dtype) + if bias: + self.bias = nn.Parameter((hidden_size,), dtype=dtype) + else: + self.bias = None + + def forward(self, x: Tensor): + """ + Forward method for rms norm layer. + + Parameters + ---------- + x : Tensor + The input tensor. + + Returns + ------- + ret : Tensor + The output tensor for the rms norm layer. 
+ """ + + def f_square(x): + x = x.astype("float32") + return x * x + + def f_div_mult(x, square_sum, weight, *indices): + *i, k = indices + s = tir.sqrt(square_sum[*i] / x.shape[-1] + self.epsilon) + s = x[*i, k].astype("float32") / s + s = (weight[k] * s).astype(x.dtype) + return s + + def te_op(x: te.Tensor, weight: te.Tensor): + k = te.reduce_axis((0, x.shape[-1]), name="k") + square_sum = te.compute( + x.shape[:-1], + lambda *i: te.sum(f_square(x[*i, k]), axis=k), + name=x.op.name + "red_temp", + ) + return te.compute( + x.shape, + lambda *i: f_div_mult(x, square_sum, weight, *i), + name="rms_norm", + ) + + return op.tensor_expr_op(te_op, "rms_norm", args=[x, self.weight]) + + class RotaryEmbedding(nn.Module): def __init__(self, config: LlamaConfig): super().__init__() @@ -80,9 +148,9 @@ def compute(b: tir.Var, s: tir.Var, h: tir.Var, d: tir.Var): freq = (offset + s) / freq cos = tir.cos(freq).astype(dtype) * x[b, s, h, d] sin = tir.sin(freq).astype(dtype) * tir.if_then_else( - d < self.head_dim // 2, - -x[b, s, h, d + self.head_dim // 2], - x[b, s, h, d - self.head_dim // 2], + d < head_dim // 2, + -x[b, s, h, d + head_dim // 2], + x[b, s, h, d - head_dim // 2], ) return cos + sin @@ -146,8 +214,8 @@ def forward( # pylint: disable=too-many-locals self.k_cache.append(op.squeeze(k, axis=0)) self.v_cache.append(op.squeeze(v, axis=0)) - k = op.reshape(self.k_cache.view(total_seq_len), (b, t, h_kv, d)) - v = op.reshape(self.v_cache.view(total_seq_len), (b, t, h_kv, d)) + k = op.reshape(self.k_cache.view(t), (b, t, h_kv, d)) + v = op.reshape(self.v_cache.view(t), (b, t, h_kv, d)) if h_kv != h_q: k = k.repeat(h_q // h_kv, axis=2) v = v.repeat(h_q // h_kv, axis=2) @@ -163,11 +231,9 @@ def forward( # pylint: disable=too-many-locals attn_weights = op.softmax(attn_weights, axis=-1) else: attn_weights = op.softmax(attn_weights.astype("float32"), axis=-1).astype(dtype) - return self.o_proj( - op.matmul(attn_weights, v) # [b, h, s, t] x [b, h, t, d] = [b, h, s, d] - .permute_dims([0, 2, 1, 3]) # [b, s, h, d] - .reshape((b, s, h_q * d)) - ) + # [b, h, s, t] x [b, h, t, d] => [b, h, s, d] => [b, s, h, d] + output = op.matmul(attn_weights, v) + return self.o_proj(output.permute_dims([0, 2, 1, 3]).reshape((b, s, h_q * d))) class LlamaDecoderLayer(nn.Module): @@ -175,8 +241,8 @@ def __init__(self, config: LlamaConfig, rotary_embedding: RotaryEmbedding): rms_norm_eps = config.rms_norm_eps self.self_attn = LlamaAttention(config, rotary_embedding) self.mlp = LlamaFFN(config) - self.input_layernorm = nn.RMSNorm(config.hidden_size, -1, rms_norm_eps, bias=False) - self.post_attention_layernorm = nn.RMSNorm(config.hidden_size, -1, rms_norm_eps, bias=False) + self.input_layernorm = RMSNorm(config.hidden_size, -1, rms_norm_eps, bias=False) + self.post_attention_layernorm = RMSNorm(config.hidden_size, -1, rms_norm_eps, bias=False) def forward(self, hidden_states: Tensor, attention_mask: Tensor, total_seq_len: tir.Var): hidden_states = ( @@ -195,7 +261,7 @@ def __init__(self, config: LlamaConfig): self.layers = nn.ModuleList( [LlamaDecoderLayer(config, rotary_embedding) for _ in range(config.num_hidden_layers)] ) - self.norm = nn.RMSNorm(config.hidden_size, -1, config.rms_norm_eps, bias=False) + self.norm = RMSNorm(config.hidden_size, -1, config.rms_norm_eps, bias=False) def forward(self, inputs: Tensor, total_seq_len: tir.Var, attention_mask: Tensor): hidden_states = self.embed_tokens(inputs) diff --git a/python/mlc_chat/compiler/quantization/utils.py b/python/mlc_chat/compiler/quantization/utils.py index 
3470e42493..9a879d2e96 100644 --- a/python/mlc_chat/compiler/quantization/utils.py +++ b/python/mlc_chat/compiler/quantization/utils.py @@ -18,17 +18,11 @@ def convert_uint_to_float( # pylint: disable=too-many-arguments shape=[weight.shape[0], weight.shape[1] * num_elem_per_storage] if out_shape is None else out_shape, - fcompute=lambda i, j: tir.Cast( - model_dtype, - tir.bitwise_and( - tir.shift_right( - weight[i, j // num_elem_per_storage], - tir.Cast( - storage_dtype, - (j % num_elem_per_storage) * bits, - ), - ), - tir_bin_mask, + fcompute=lambda i, j: tir.bitwise_and( + tir.shift_right( + weight[i, j // num_elem_per_storage], + ((j % num_elem_per_storage) * bits).astype(storage_dtype), ), - ), + tir_bin_mask, + ).astype(model_dtype), ) diff --git a/tests/legacy-python/module_intercept.py b/tests/legacy-python/module_intercept.py new file mode 100644 index 0000000000..e63bb21de6 --- /dev/null +++ b/tests/legacy-python/module_intercept.py @@ -0,0 +1,147 @@ +"""This script is an example of running and comparing the outputs of two different TVM Relax VMs. +""" +# pylint: disable=missing-docstring,invalid-name +import json + +import numpy as np +import torch +import tvm +from transformers import LlamaTokenizer +from tvm import relax +from tvm.contrib import tvmjs + +KVCACHE_FUNCS = [ + "vm.builtin.attention_kv_cache_append", + "vm.builtin.attention_kv_cache_view", +] +DEVICE = "cuda:0" +PROMPT = "What is the meaning of life?" +TOKENIZER = "./dist/debug-llama/" + +COMBO = { + "CURRENT": { + "model_lib": "./dist/debug-llama/llama.so", + "params": "./dist/debug-llama", + "target_func": "fused_fused_dequantize1_NT_matmul6", + }, + "LEGACY": { + "model_lib": "./dist/Llama-2-7b-chat-hf-q4f16_1/Llama-2-7b-chat-hf-q4f16_1-cuda.so", + "params": "./dist/Llama-2-7b-chat-hf-q4f16_1/params", + "target_func": "fused_fused_decode2_NT_matmul", + }, +} + + +class Instrument: # pylint: disable=too-few-public-methods + def __init__( + self, + target_func: str, + ): + self.first_time = True + self.target_func = target_func + self.saved_args = [] # type: ignore + + def __call__( + self, + func, + func_symbol: str, + before_run: bool, + ret_value, + *args, + ): + if before_run: + return + if func_symbol.startswith("vm.builtin."): + if func_symbol not in KVCACHE_FUNCS: + return + if func_symbol == self.target_func and self.first_time: + self.first_time = False + for arg in args: + print(arg.shape, arg.dtype) + self.saved_args.append(arg.numpy()) + + +class TestState: + def __init__(self, device, model_lib, target_func): + self.mod = relax.VirtualMachine( + tvm.runtime.load_module(model_lib), + device, + ) + self.inst = Instrument(target_func=target_func) + self.mod.set_instrument(self.inst) + + +def _tokenize(sentence: str): + tokenizer = LlamaTokenizer.from_pretrained(TOKENIZER, trust_remote_code=True) + tokens = tokenizer(PROMPT, return_tensors="pt").input_ids.to(torch.int32).numpy() + print(f"Tokenizing: {sentence}") + print(f"Tokens: {tokens}") + return tokens + + +def _load_params(params, device, metadata): + param_dict, _ = tvmjs.load_ndarray_cache(params, device) + param_list = [] + for name in [x["name"] for x in metadata["params"]]: + param_list.append(param_dict[name]) + return param_list + + +def _load_params_legacy(params, device): + param_dict, metadata = tvmjs.load_ndarray_cache(params, device) + param_list = [] + for i in range(metadata["ParamSize"]): + param_list.append(param_dict[f"param_{i}"]) + return param_list + + +def _as_input_tuple(scalar): + return tvm.runtime.ShapeTuple([scalar]) + + 
+@tvm.register_func("debug_save") +def _debug_save(x, _): + return tvm.nd.array(x.numpy(), x.device) + + +def main() -> None: + device = tvm.device(DEVICE) + prompt = _tokenize(PROMPT) + + def _run_legacy(model_lib, params, target_func): + state = TestState(device, model_lib, target_func) + kv_cache = state.mod["create_kv_cache"]() + param_list = _load_params_legacy(params, device) + state.mod["prefill"]( + tvm.nd.array(prompt, device), + _as_input_tuple(len(prompt[0])), + kv_cache, + param_list, + ) + return state.inst.saved_args + + def _run_current(model_lib, params, target_func): + state = TestState(device, model_lib, target_func) + metadata = json.loads(state.mod["_metadata"]()) + kv_cache = state.mod["_initialize_effect"]() + param_list = _load_params(params, device, metadata) + state.mod["prefill"]( + tvm.nd.array(prompt, device), + _as_input_tuple(len(prompt[0])), + kv_cache, + param_list, + ) + return state.inst.saved_args + + print("============== Running old flow =================") + new_args = _run_current(**COMBO["CURRENT"]) + print("============== Running new flow =================") + old_args = _run_legacy(**COMBO["LEGACY"]) + + for i, (new_arg, old_arg) in enumerate(zip(new_args, old_args)): + print(f"Checking arg {i}") + np.testing.assert_allclose(new_arg, old_arg, rtol=1e-12, atol=1e-12) + + +if __name__ == "__main__": + main() From 616ca42ccf905bafe4afff1756d7ec2150654af4 Mon Sep 17 00:00:00 2001 From: Charlie Ruan <53290280+CharlieFRuan@users.noreply.github.com> Date: Wed, 8 Nov 2023 11:15:20 -0500 Subject: [PATCH 108/116] [Mistral][SWA] Add sliding window to metadata (#1217) Add sliding window to metadata, make smalle changes to invariants in runtime --- cpp/llm_chat.cc | 15 ++++++++------- mlc_llm/relax_model/commons.py | 4 ++++ mlc_llm/relax_model/mistral.py | 5 +++++ 3 files changed, 17 insertions(+), 7 deletions(-) diff --git a/cpp/llm_chat.cc b/cpp/llm_chat.cc index ff9ca5f5d9..4237febd9c 100644 --- a/cpp/llm_chat.cc +++ b/cpp/llm_chat.cc @@ -397,10 +397,16 @@ class LLMChat { CHECK(!config.count("max_window_size")) << "Cannot specify both sliding_window and max_window_size."; this->sliding_window_ = config["sliding_window"].get(); + CHECK(this->sliding_window_ > 0) << "Sliding window size needs to be positive"; + CHECK(config.count("sliding_window_chunk_size")) + << "Need to specify chunk size if using sliding window attention."; } if (config.count("sliding_window_chunk_size")) { CHECK(config["sliding_window_chunk_size"].is()); this->sliding_window_chunk_size_ = config["sliding_window_chunk_size"].get(); + CHECK(this->sliding_window_chunk_size_ > 0) + << "Sliding window chunk size needs to be positive"; + CHECK(config.count("sliding_window")) << "Need to specify sliding window size."; } if (config.count("model_name")) { CHECK(config["model_name"].is()); @@ -816,13 +822,8 @@ class LLMChat { NDArray logits_on_device; if (this->sliding_window_ != -1) { // Use chunking if we use sliding window attention (see Mistral paper figure 3). 
- int64_t sliding_window_chunk_size = this->sliding_window_chunk_size_; - if (this->sliding_window_chunk_size_ == -1) { - // One chunk if chunk size not specified - sliding_window_chunk_size = token_len; - } - for (int64_t begin = 0; begin < token_len; begin += sliding_window_chunk_size) { - int64_t end = std::min(token_len, begin + sliding_window_chunk_size); + for (int64_t begin = 0; begin < token_len; begin += this->sliding_window_chunk_size_) { + int64_t end = std::min(token_len, begin + this->sliding_window_chunk_size_); std::vector chunk = std::vector(prompt_tokens.begin() + begin, prompt_tokens.begin() + end); new_seq_len += static_cast(chunk.size()); diff --git a/mlc_llm/relax_model/commons.py b/mlc_llm/relax_model/commons.py index 82dd5c367b..4924c2f015 100644 --- a/mlc_llm/relax_model/commons.py +++ b/mlc_llm/relax_model/commons.py @@ -10,6 +10,8 @@ def create_metadata_func( max_window_size: int, stop_tokens: List[int], add_prefix_space: bool, + sliding_window: int = -1, + sliding_window_chunk_size: int = -1, ): metadata = json.dumps( { @@ -17,6 +19,8 @@ def create_metadata_func( "max_window_size": max_window_size, "stop_tokens": stop_tokens, "add_prefix_space": add_prefix_space, + "sliding_window": sliding_window, + "sliding_window_chunk_size": sliding_window_chunk_size, } ) with bb.function("get_metadata", params=[]): diff --git a/mlc_llm/relax_model/mistral.py b/mlc_llm/relax_model/mistral.py index 1ef00ff577..31ed39fdb5 100644 --- a/mlc_llm/relax_model/mistral.py +++ b/mlc_llm/relax_model/mistral.py @@ -949,6 +949,9 @@ def get_model(args, hf_config): sliding_window_chunk_size=args.sliding_window_chunk_size, ) + assert config.sliding_window != -1 + assert config.sliding_window_chunk_size != -1 + param_manager = ParamManager() bb = relax.BlockBuilder() @@ -962,6 +965,8 @@ def get_model(args, hf_config): max_window_size=config.max_sequence_length, stop_tokens=[2], add_prefix_space=False, + sliding_window=config.sliding_window, + sliding_window_chunk_size=config.sliding_window_chunk_size, ) mod = bb.get() From e52f449930c218093b7bc82e559bed54b59cb82d Mon Sep 17 00:00:00 2001 From: Antonio Calatrava Date: Wed, 8 Nov 2023 17:36:27 +0100 Subject: [PATCH 109/116] Support for `chatml` format conversation (for TinyLlama-1.1B-Chat-v0.2) (#956) * added support for chatml format conversation * added template to factory --- cpp/conv_templates.cc | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/cpp/conv_templates.cc b/cpp/conv_templates.cc index dd90a67fb5..b0e47b27a9 100644 --- a/cpp/conv_templates.cc +++ b/cpp/conv_templates.cc @@ -7,6 +7,27 @@ namespace mlc { namespace llm { namespace { +Conversation ChatML() { + Conversation conv; + conv.name = "chatml"; + conv.roles = {"<|im_start|>user", "<|im_start|>assistant"}; + conv.system = + ("<|im_start|>system A conversation between a user and an LLM-based AI assistant. The " + "assistant gives helpful and honest answers.<|im_end|> "); + conv.messages = {}; + conv.offset = 0; + conv.separator_style = SeparatorStyle::kSepRoleMsg; + conv.seps = {"<|im_end|>", "<|im_end|>"}; + conv.role_msg_sep = "\n"; + conv.role_empty_sep = "\n"; + // TODO(mlc-team): add eos to mlc-chat-config + // and remove eos from stop token setting. 
+ conv.stop_tokens = {2}; + conv.stop_str = "<|im_end|>"; + conv.add_bos = true; + return conv; +} + Conversation LlamaDefault() { Conversation conv; conv.name = "llama_default"; @@ -583,6 +604,7 @@ using ConvFactory = Conversation (*)(); Conversation Conversation::FromTemplate(const std::string& name) { static std::unordered_map factory = { + {"chatml", ChatML}, {"llama_default", LlamaDefault}, {"llama-2", Llama2}, {"mistral_default", MistralDefault}, From fbe75e33f13ea7bb18b7ee7fd961ea310c492aaa Mon Sep 17 00:00:00 2001 From: Yuchen Jin Date: Wed, 8 Nov 2023 09:21:25 -0800 Subject: [PATCH 110/116] Add Rust Support for MLC-LLM (#1213) This PR introduces Rust language support for the MLC-LLM project, specifically targeting supporting the `ChatModule` interface. It utilizes the existing C++ implementation of MLC-LLM and leverages both TVM's C API and its Rust bindings. The `rust/examples/mlc_chat.rs` gives an example of how to create a `chat_module` and serve user prompts in Rust. The primary goal of this PR is to enrich the MLC-LLM ecosystem by offering a Rust interface that aligns with the current Python API. This enhancement will empower Rust developers to integrate MLC-LLM into their codebase and applications. **Followup PRs**: - Extend the feature set to achieve parity with the C++/Python interface. - Refine the Rust API, ensuring robustness. - Set up Rust CI if needed. --- rust/.gitignore | 20 ++ rust/Cargo.toml | 13 ++ rust/README.md | 25 +++ rust/build.rs | 4 + rust/examples/mlc_chat.rs | 10 + rust/src/chat_module.rs | 445 ++++++++++++++++++++++++++++++++++++++ rust/src/config.rs | 276 +++++++++++++++++++++++ rust/src/lib.rs | 5 + 8 files changed, 798 insertions(+) create mode 100644 rust/.gitignore create mode 100644 rust/Cargo.toml create mode 100644 rust/README.md create mode 100644 rust/build.rs create mode 100644 rust/examples/mlc_chat.rs create mode 100644 rust/src/chat_module.rs create mode 100644 rust/src/config.rs create mode 100644 rust/src/lib.rs diff --git a/rust/.gitignore b/rust/.gitignore new file mode 100644 index 0000000000..c5e4e0d10a --- /dev/null +++ b/rust/.gitignore @@ -0,0 +1,20 @@ +# Generated by Cargo +# will have compiled files and executables +debug/ +target/ + +# Remove Cargo.lock from gitignore if creating an executable, leave it for libraries +# More information here https://doc.rust-lang.org/cargo/guide/cargo-toml-vs-cargo-lock.html +Cargo.lock + +# Generated by Rust +**/*.rs.bk +/examples/pkg + +# MSVC Windows builds of rustc generate these, which store debugging information +*.pdb + +# IDE files +.idea/ +*.iml +.vscode/ diff --git a/rust/Cargo.toml b/rust/Cargo.toml new file mode 100644 index 0000000000..58cc03f40b --- /dev/null +++ b/rust/Cargo.toml @@ -0,0 +1,13 @@ +[package] +name = "mlc-llm" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +tvm-rt = { path = "../3rdparty/tvm/rust/tvm-rt", features = ["dynamic-linking"] } +tracing = "0.1.32" +derive_builder = "0.12.0" +serde = { version = "1.0.160", features = ["derive"] } +serde_json = "1.0.107" diff --git a/rust/README.md b/rust/README.md new file mode 100644 index 0000000000..8c92525772 --- /dev/null +++ b/rust/README.md @@ -0,0 +1,25 @@ +# MLC-LLM Rust Package + +This folder contains the source code of MLC-LLM Rust package. 
+
+# Installation
+To set up the MLC-LLM Rust package, please follow these steps:
+
+**Step 1:** Begin by following the detailed installation [instructions](https://llm.mlc.ai/docs/deploy/rest.html#optional-build-from-source) for TVM Unity and MLC-LLM.
+
+**Step 2:** Define the environment variables for TVM and MLC-LLM by running the following commands in your terminal:
+```bash
+export TVM_HOME=/path/to/tvm
+export MLC_HOME=/path/to/mlc-llm
+```
+
+**Step 3:** Update your `LD_LIBRARY_PATH` to include the `libtvm_runtime` and `libmlc_llm_module` libraries. These can typically be found within the build directories of your TVM and MLC-LLM installations.
+
+# How to run it?
+To start using the package, refer to the example code provided in the `examples` directory, which demonstrates how to create a `chat_module` and serve prompts.
+
+Execute the example with Cargo using the following command:
+```bash
+cargo run --example mlc_chat
+```
+
diff --git a/rust/build.rs b/rust/build.rs
new file mode 100644
index 0000000000..d8e01a77e4
--- /dev/null
+++ b/rust/build.rs
@@ -0,0 +1,4 @@
+fn main() {
+    println!("cargo:rustc-link-lib=dylib=mlc_llm_module");
+    println!("cargo:rustc-link-search=native={}/build", env!("MLC_HOME"));
+}
diff --git a/rust/examples/mlc_chat.rs b/rust/examples/mlc_chat.rs
new file mode 100644
index 0000000000..2e87d56946
--- /dev/null
+++ b/rust/examples/mlc_chat.rs
@@ -0,0 +1,10 @@
+extern crate mlc_llm;
+
+use mlc_llm::chat_module::ChatModule;
+
+fn main() {
+    let cm = ChatModule::new("/path/to/Llama2-13B-q8f16_1", "rocm", None).unwrap();
+    let output = cm.generate("what is the meaning of life?", None).unwrap();
+    println!("resp: {:?}", output);
+    println!("stats: {:?}", cm.stats(false));
+}
diff --git a/rust/src/chat_module.rs b/rust/src/chat_module.rs
new file mode 100644
index 0000000000..831905eee8
--- /dev/null
+++ b/rust/src/chat_module.rs
@@ -0,0 +1,445 @@
+use std::fs;
+use std::path::{Path, PathBuf};
+use std::result;
+use tracing::info;
+use tvm_rt::{function::Function, Module};
+
+use super::config::*;
+
+#[derive(Debug)]
+pub enum ChatModuleError {
+    /// Global function in a TVM Module is not found
+    GlobalFuncNotFound,
+    /// TVM Runtime error
+    TvmRuntime(tvm_rt::Error),
+}
+
+impl From<tvm_rt::Error> for ChatModuleError {
+    fn from(e: tvm_rt::Error) -> Self {
+        Self::TvmRuntime(e)
+    }
+}
+
+pub type Result<T> = result::Result<T, ChatModuleError>;
+
+/// The ChatModule for MLC LLM.
+///
+/// # Examples
+///
+/// ```
+/// use mlc_llm::chat_module::ChatModule;
+///
+/// // Create a ChatModule instance
+/// let cm = ChatModule::new("Llama-2-7b-chat-hf-q4f16_1", "cuda", None).unwrap();
+///
+/// // Generate a response for a given prompt
+/// let output = cm.generate("What is the meaning of life?", None).unwrap();
+///
+/// // Print prefill and decode performance statistics
+/// println!("Statistics: {:?}\n", cm.stats(false).unwrap());
+///
+/// let output = cm.generate("What is Rust?", None).unwrap();
+/// ```
+pub struct ChatModule {
+    chat_module: Module,
+    chat_config: ChatConfig,
+}
+
+#[derive(Debug, Copy, Clone)]
+pub enum PlaceInPrompt {
+    All = 0,
+    Begin = 1,
+    Middle = 2,
+    End = 3,
+}
+
+impl PlaceInPrompt {
+    pub fn to_value(&self) -> i32 {
+        *self as i32
+    }
+}
+
+/// Parse the input device identifier into device name and id.
+///
+/// # Parameters
+/// * `device` - The device identifier to parse. It can be in the format "device_name" (e.g., "cuda")
+/// or "device_name:device_id" (e.g., "cuda:1").
+///
+/// # Returns
+/// * `device_name` - The name of the device.
+/// * `device_id` - The id of the device, or 0 if not specified in the input. +fn parse_device_str(device: &str) -> (&str, i32) { + let device_err_msg = format!( + "Invalid device name: {}. Please enter the device in the form \ + 'device_name:device_id' or 'device_name', where 'device_name' needs to be \ + one of 'cuda', 'metal', 'vulkan', 'rocm', 'opencl', 'auto'.", + device + ); + let device_args: Vec<&str> = device.split(':').collect(); + match device_args.len() { + 1 => (device_args[0], 0), + 2 => (device_args[0], device_args[1].parse::().unwrap()), + _ => panic!("{}", device_err_msg), + } +} + +/// Use user-provided argument `model` to search for a valid model path. +/// We define "valid" as having an `mlc-chat-config.json` right under the folder. +/// +/// # Parameters +/// * `model`: User's input; may be a compiled model's name, or a full path. +/// +/// # Returns +/// * `model_path`: A "valid" path to model folder with `mlc-chat-config.json` existing under it. +/// * `chat_file`: The path to the `mlc-chat-config.json` file. +/// +/// # Panics +/// * If a valid model_path cannot be found. +pub fn get_model_path(model: &str) -> (PathBuf, PathBuf) { + // Note that the order of this list corresponds to our search priority + let candidate_paths = vec![ + PathBuf::from(model), // full path, or just the name + PathBuf::from(format!("{}/params", model)), // Default directory after mlc_llm.build_model() + PathBuf::from(format!("dist/prebuilt/{}", model)), // Using prebuilt workflow + PathBuf::from(format!("dist/{}/params", model)), // Default directory after mlc_llm.build_model() in the current path + PathBuf::from(format!("dist/prebuilt/mlc-chat-{}", model)), // Also prebuilt workflow, but missed prefix + ]; + + // Look for the first folder that has `mlc-chat-config.json` under it + for candidate in &candidate_paths { + let chat_file = candidate.join("mlc-chat-config.json"); + if chat_file.is_file() { + info!( + "Using model folder: {:?}", + candidate.canonicalize().unwrap() + ); + info!( + "Using mlc chat config: {:?}", + chat_file.canonicalize().unwrap() + ); + return (candidate.clone(), chat_file); + } + } + + let mut found_folder = false; + let mut valid_dir_str = String::new(); + for candidate in &candidate_paths { + if candidate.is_dir() { + valid_dir_str += &format!("- {:?}\n", candidate.canonicalize().unwrap()); + found_folder = true; + } + } + + if found_folder { + // Error 1: there is a folder, but not an mlc-llm model folder (E1) + let err_msg = format!( + "The model folder provided does not seem to refer to a valid mlc-llm model folder.\n\ + Specifically, we cannot find `mlc-chat-config.json`, a required file. You should \ + provide a path that contains the file.\n\ + According to your input `model`, we looked at folder(s):\n\ + {}\n\ + MLC-Chat consumes models that are processed by the MLC-LLM build process.\n\ + ", + valid_dir_str, + ); + panic!("{}", err_msg); + } else { + // Error 2: cannot find a folder (E0) + let all_paths_str = candidate_paths + .iter() + .map(|path| format!("- {}\n", path.display())) + .collect::(); + let err_msg = format!( + "Cannot find the model folder. We searched over the following possible paths:\n\ + {}\n\ + You can try to pass in `model=/path/to/your-model-path`, and confirm \ + that it contains `mlc-chat-config.json`, among other essential files.\n\ + ", + all_paths_str, + ); + panic!("{}", err_msg); + } +} + +/// Read in the config file in model path, then potentially override with user input. 
+/// +/// # Parameters: +/// * `config_file_path`: &Path +/// `chat_file` returned by a function like `get_model_path()`. +fn get_chat_config( + config_file_path: &Path, +) -> result::Result> { + // Read the base configuration from the file + let file_contents = fs::read_to_string(config_file_path)?; + let final_chat_config = ChatConfig::from_json(&file_contents)?; + Ok(final_chat_config) +} + +fn get_lib_module_path( + model: &str, + model_path: &Path, + chat_config: &ChatConfig, + model_lib_path: Option<&str>, + device_name: &str, + config_file_path: &Path, +) -> PathBuf { + // 1. Use user's model_lib_path if provided + if let Some(lib_path) = model_lib_path { + let path = Path::new(lib_path); + if path.is_file() { + info!("Using library model: {:?}", path); + return path.to_path_buf(); + } else { + panic!( + "The `model_lib_path` you passed in is not a file: {:?}.", + lib_path + ); + } + } + + // 2. Generate all possible file names according to OS + let mut candidate_paths = Vec::new(); + if let Some(model_lib) = &chat_config.model_lib { + let candidate_lib_names: Vec = if cfg!(target_os = "linux") { + vec![format!("{}-{}.so", model_lib, device_name)] + } else if cfg!(target_os = "macos") { + vec![ + format!("{}-{}.dylib", model_lib, device_name), + format!("{}-{}.so", model_lib, device_name), + ] + } else if cfg!(target_os = "windows") { + vec![format!("{}-{}.dll", model_lib, device_name)] + } else { + vec![ + format!("{}-{}.dylib", model_lib, device_name), + format!("{}-{}.so", model_lib, device_name), + format!("{}-{}.dll", model_lib, device_name), + ] + }; + + // 3. Generate possible model library paths + let pardir_model_path = model_path.parent().unwrap(); + for lib_name in &candidate_lib_names { + let paths: Vec = vec![ + lib_name.clone(), + format!("dist/prebuilt/lib/{}", lib_name), + format!("dist/{}/{}", model, lib_name), + model_path.join(lib_name).to_string_lossy().into_owned(), + pardir_model_path + .join(lib_name) + .to_string_lossy() + .into_owned(), + ]; + + candidate_paths.extend(paths); + } + + // 4. Search for model library + for candidate in &candidate_paths { + let candidate_path = Path::new(candidate); + if candidate_path.is_file() { + info!("Using library model: {:?}", candidate_path); + return candidate_path.to_path_buf(); + } + } + + // 5. Error + let mut err_msg = format!( + "Cannot find the model library that corresponds to `{:?}`.\n\ + `{:?}` is either provided in the `chat_config` \ + you passed in, or specified in {:?}.\n\ + We searched over the following possible paths: \n", + model_lib, model_lib, config_file_path + ); + for candidate in &candidate_paths { + err_msg += &format!("- {}\n", candidate); + } + err_msg += &format!( + "If you would like to directly specify the model library path, you may \ + consider passing in the `ChatModule.model_lib_path` parameter." + ); + + panic!("{}", err_msg); + } else { + panic!("Cannot find the model library, you need to either pass it in, or specify in the chat_config file."); + } +} + +impl ChatModule { + pub fn new(model: &str, device: &str, model_lib_path: Option<&str>) -> Result { + let device_err_msg = format!( + "Invalid device name: {}. Please enter the device in the form \ + 'device_name:device_id' or 'device_name', where 'device_name' needs to be \ + one of 'cuda', 'metal', 'vulkan', 'rocm', 'opencl', 'auto'.", + device + ); + + let (device_name, device_id) = parse_device_str(device); + + // 1. 
Get device name and id
+        let device_type = match device_name {
+            "cuda" => 2,
+            "opencl" => 4,
+            "vulkan" => 7,
+            "metal" => 8,
+            "rocm" => 10,
+            _ => panic!("{}", device_err_msg),
+        };
+
+        static GLOBAL_FUNC_NAME: &str = "mlc.llm_chat_create";
+        let f = Function::get(GLOBAL_FUNC_NAME).ok_or(ChatModuleError::GlobalFuncNotFound)?;
+        let m: Module = f
+            .invoke(vec![device_type.into(), device_id.into()])
+            .unwrap()
+            .try_into()
+            .expect("call should succeed");
+
+        // 2. Look up the model path
+        let (model_path, config_file_path) = get_model_path(model);
+
+        // 3. Instantiate chat_config
+        let chat_config = get_chat_config(&config_file_path).unwrap();
+
+        // 4. Look up the model library
+        let model_lib_path = get_lib_module_path(
+            model,
+            &model_path,
+            &chat_config,
+            model_lib_path,
+            device_name,
+            &config_file_path,
+        );
+
+        let chat_mod = Self {
+            chat_module: m,
+            chat_config: chat_config,
+        };
+        let model_lib_str = model_lib_path.as_path().display().to_string();
+        let model_path_str = model_path.as_path().display().to_string();
+        chat_mod
+            .reload(&model_lib_str, &model_path_str, "")
+            .unwrap();
+        Ok(chat_mod)
+    }
+
+    /// Reload the chat module from the given library and model path.
+    fn reload(&self, lib: &str, model_path: &str, app_config_json: &str) -> Result<()> {
+        let f = self.chat_module.get_function("reload", false)?;
+        f.invoke(vec![lib.into(), model_path.into(), app_config_json.into()])?;
+        Ok(())
+    }
+
+    /// Reset the chat session, clear all chat history, and potentially
+    /// override the original `mlc-chat-config.json`.
+    pub fn reset_chat(&self) -> Result<()> {
+        // TODO: add optional user-specified ChatConfig
+        let f = self.chat_module.get_function("reset_chat", false)?;
+        f.invoke(vec![])?;
+        Ok(())
+    }
+
+    /// Get the runtime stats of the encoding step, decoding step (and embedding step, if it exists)
+    /// of the chat module in text form.
+    pub fn stats(&self, verbose: bool) -> Result<String> {
+        if verbose {
+            let f = self
+                .chat_module
+                .get_function("verbose_runtime_stats_text", false)?;
+            let res: String = f.invoke(vec![])?.try_into().expect("call should succeed");
+            return Ok(res);
+        }
+        let f = self.chat_module.get_function("runtime_stats_text", false)?;
+        let res: String = f.invoke(vec![])?.try_into().expect("call should succeed");
+        return Ok(res);
+    }
+
+    /// Check if the stop condition is met for the current round.
+    fn stopped(&self) -> Result<bool> {
+        let f = self.chat_module.get_function("stopped", false)?;
+        let res: bool = f.invoke(vec![])?.try_into().expect("call should succeed");
+        Ok(res)
+    }
+
+    /// Get the output message in the current round.
+    fn get_message(&self) -> Result<String> {
+        let f = self.chat_module.get_function("get_message", false)?;
+        let res: String = f.invoke(vec![])?.try_into().expect("call should succeed");
+        Ok(res)
+    }
+
+    /// Decode the next token; the decoding result is stored in a buffer and
+    /// can be retrieved by [get_message].
+    fn decode(&self, generation_config: Option<&GenerationConfig>) -> Result<()> {
+        let generation_config_str = match generation_config {
+            Some(config) => serde_json::to_string(config).unwrap(),
+            None => {
+                let config = GenerationConfig::from_chat_config(&self.chat_config);
+                serde_json::to_string(&config).unwrap()
+            }
+        };
+        let f = self.chat_module.get_function("decode", false)?;
+        f.invoke(vec![generation_config_str.into()])?;
+        Ok(())
+    }
+
+    /// A high-level method that returns the full response from the chat module given a user
+    /// prompt. 
User can optionally specify which callback method to use upon receiving the + /// response. + pub fn generate( + &self, + prompt: &str, + generation_config: Option<&GenerationConfig>, + ) -> Result> { + // TODO: add progress_callback + let mut new_msgs: Vec = vec![]; + let mut num_return_sequences: usize = 1; + + if let Some(gc) = generation_config { + if let Some(n) = gc.n { + num_return_sequences = n; + } + } + + for _ in 0..num_return_sequences { + self.reset_chat().unwrap(); + self.prefill(prompt, true, PlaceInPrompt::All, generation_config) + .unwrap(); + + while !self.stopped().unwrap() { + self.decode(generation_config)?; + } + let new_msg = self.get_message().unwrap(); + new_msgs.push(new_msg); + } + + Ok(new_msgs) + } + + /// Run prefill stage for a given input and optionally decode the first output token. + /// User can decide where to place the input in the prompt. + fn prefill( + &self, + input: &str, + decode_next_token: bool, + place_in_promt: PlaceInPrompt, + generation_config: Option<&GenerationConfig>, + ) -> Result<()> { + let generation_config_str = match generation_config { + Some(config) => serde_json::to_string(config).unwrap(), + None => { + let config = GenerationConfig::from_chat_config(&self.chat_config); + serde_json::to_string(&config).unwrap() + } + }; + + let f = self.chat_module.get_function("prefill", false)?; + f.invoke(vec![ + input.into(), + (&decode_next_token).into(), + place_in_promt.to_value().into(), + generation_config_str.into(), + ])?; + Ok(()) + } +} + diff --git a/rust/src/config.rs b/rust/src/config.rs new file mode 100644 index 0000000000..61371d197a --- /dev/null +++ b/rust/src/config.rs @@ -0,0 +1,276 @@ +use serde::{Deserialize, Serialize}; + +/// A struct that represents user-defined partial configuration for conversation template. +/// +/// This can be passed in to the instantiation of a [ChatModule](crate::chat_module::ChatModule) +/// instance to override the default setting in `mlc-chat-config.json` under the +/// model folder. Note that we will first load the predefined template +/// with the name specified in `conv_template`. +/// +/// Since the configuration is partial, everything will be optional. +#[derive(Clone, Default, Builder, Debug, Serialize, Deserialize)] +#[builder(default)] +pub struct ConvConfig { + /// Name of the conversation. + name: Option, + + /// The prompt encoded before starting the chat. + system: Option, + + /// An array that describes the role names of the user and the model. + roles: Option>, + + /// The chat history represented as an array of string pairs. + messages: Option>>, + + /// The offset used to begin the chat from the chat history. + offset: Option, + + /// Specifies whether we are in chat-bot mode (`0`) or pure LM prompt mode (`1`). + separator_style: Option, + + /// An array of strings indicating the separators to be used after a user message and a model message respectively. + seps: Option>, + + /// A string indicating the separator between a role and a message. + role_msg_sep: Option, + + /// A string indicating the separator to append to a role when there is no message yet. + role_empty_sep: Option, + + /// When the `stop_str` is encountered, the model will stop generating output. + stop_str: Option, + + /// A list of token IDs that act as stop tokens. + stop_tokens: Option>, + + /// Determines whether a beginning-of-string (bos) token should be added before the input tokens. 
+ add_bos: Option, +} + +impl ConvConfig { + pub fn post_init(&mut self) { + if let Some(messages) = &self.messages { + if self.offset.is_none() { + self.offset = Some(messages.len()); + } + } + } +} + +/// A struct that represents user-defined partial configuration for the chat config file. +/// +/// An instance of [ChatConfig] can be passed in to override the default setting. +/// Since the configuration is partial, everything will be optional. +/// +/// Note: This struct is used to represent the chat config during intermediate processing. +#[derive(Builder, Debug, Default, Serialize, Deserialize)] +#[builder(default)] +pub struct ChatConfig { + /// The necessary model library to launch this model architecture. + /// Recommended to reuse model library when possible. + pub model_lib: Option, + + /// Uniquely identifying the model in application. Also used by + /// CLI to specify which model to run. + pub local_id: Option, + + /// The name of the conversation template that this chat uses. + pub conv_template: Option, + + /// Temperature applied to logits before sampling. Encourages diverse outputs if higher. + pub temperature: Option, + + /// Controls the likelihood of the model generating repeated texts. + /// See the CTRL paper for more details: + repetition_penalty: Option, + + /// Determines the set of tokens from which we sample during decoding. + /// More info on top-p sampling: + top_p: Option, + + /// Approximated average number of generated tokens in each round. + mean_gen_len: Option, + + /// Maximum number of tokens to be generated in each round. + max_gen_len: Option, + + /// Fraction of maximum window size to shift when it is exceeded. + shift_fill_factor: Option, + + /// List of tokenizer files of the model. + tokenizer_files: Option>, + + /// Partial overriding configuration for conversation template. + conv_config: Option, + + /// The category of the model's architecture (e.g. `llama`, `gpt_neox`, `rwkv`). + model_category: Option, + + /// Name of the model (e.g. `Llama-2-7b-chat-hf`). + model_name: Option, + + /// Tensor parallel degree. + num_shards: Option, + + /// Maximum kv cache window size. + max_window_size: Option, +} + +impl ChatConfig { + pub fn from_json(json_str: &str) -> Result { + serde_json::from_str(json_str) + } +} + +/// A struct that represents user-defined generation configuration. +/// +/// An instance of [GenerationConfig] can be passed into the +/// [ChatModule::generate](crate::chat_module::ChatModule::generate) function +/// to override the default generation settings specified in `mlc-chat-config.json` +/// and `ChatConfig` under the model folder. +/// +/// Once the generation ends, `GenerationConfig` is discarded, as the values +/// are only intended to override the `ChatConfig` generation settings during a +/// single generation, unless it is recurrently passed to the `generate` function. +/// This allows for changing generation settings over time, without permanently +/// overriding the `ChatConfig`. +/// +/// Since the configuration is partial, all fields are optional. +#[derive(Builder, Debug, Default, Serialize, Deserialize)] +#[builder(default)] +pub struct GenerationConfig { + /// The temperature applied to logits before sampling. The default value is + /// `0.7`. A higher temperature encourages more diverse outputs, while a + /// lower temperature produces more deterministic outputs. + temperature: Option, + + /// The repetition penalty controls the likelihood of the model generating + /// repeated texts. 
The default value is set to `1.0`, indicating that no + /// repetition penalty is applied. Increasing the value reduces the + /// likelihood of repeat text generation. However, setting a high + /// `repetition_penalty` may result in the model generating meaningless + /// texts. The ideal choice of repetition penalty may vary among models. Only + /// Active when presence_penalty and frequency_penalty are both `0.0`. + + /// For more details on how repetition penalty controls text generation, please + /// check out the CTRL paper . + repetition_penalty: Option, + + /// This parameter determines the set of tokens from which we sample during + /// decoding. The default value is set to `0.95`. At each step, we select + /// tokens from the minimal set that has a cumulative probability exceeding + /// the ``top_p` parameter. + + /// For additional information on top-p sampling, please refer to this blog + /// post: . + top_p: Option, + + /// The approximated average number of generated tokens in each round. Used + /// to determine whether the maximum window size would be exceeded. + mean_gen_len: Option, + + /// This parameter determines the maximum length of the generated text. If it is + /// not set, the model will generate text until it encounters a stop token. + max_gen_len: Option, + + /// Number between `-2.0` and `2.0`. Positive values penalize new tokens based on + /// whether they appear in the text so far, increasing the model's likelihood + /// to talk about new topics. Negative values can increase the likelihood of + /// repetition. + presence_penalty: Option, + + /// Number between `-2.0` and `2.0`. Positive values penalize new tokens based on their + /// existing frequency in the text so far, decreasing the model's likelihood to + /// repeat the same line verbatim. Negative values can increase the likelihood of + /// repetition. + frequency_penalty: Option, + + /// This parameter determines the number of text samples to generate. The default + /// value is `1`. Note that this parameter is only used when `stream` is set to + /// `false`. + pub n: Option, + + /// When `stop` is encountered, the model will stop generating output. + /// It can be a string or a list of strings. If it is a list of strings, the model + /// will stop generating output when any of the strings in the list is encountered. + /// Note that this parameter does not override the default stop string of the model. 
+ stop: Option>, +} + +impl GenerationConfig { + pub fn from_chat_config(chat_config: &ChatConfig) -> Self { + Self { + temperature: chat_config.temperature, + repetition_penalty: chat_config.repetition_penalty, + top_p: chat_config.top_p, + mean_gen_len: chat_config.mean_gen_len, + max_gen_len: chat_config.max_gen_len, + presence_penalty: Some(0.0), + frequency_penalty: Some(0.0), + n: Some(0), + stop: None, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_conv_config() { + let mut config = ConvConfig { + messages: Some(vec![vec![ + "User: Hi".to_string(), + "Assistant: Hello".to_string(), + ]]), + offset: None, + ..Default::default() + }; + config.post_init(); + assert_eq!(config.offset, Some(1)); + } + + #[test] + fn test_chat_config() { + let json_data = r#" + { + "model_lib": "some_lib", + "local_id": "id123", + "temperature": 0.7 + } + "#; + + let config = ChatConfig::from_json(json_data).unwrap(); + + assert_eq!(config.model_lib, Some("some_lib".to_string())); + assert_eq!(config.local_id, Some("id123".to_string())); + assert_eq!(config.temperature, Some(0.7)); + let _pretty_json = serde_json::to_string_pretty(&config).unwrap(); + } + + #[test] + fn test_generation_config() { + let chat_config = ChatConfigBuilder::default() + .temperature(Some(0.7)) + .top_p(Some(0.8)) + .mean_gen_len(Some(50)) + .max_gen_len(Some(75)) + .build() + .unwrap(); + + let gen_config = GenerationConfig::from_chat_config(&chat_config); + + assert_eq!(gen_config.temperature, chat_config.temperature); + assert_eq!( + gen_config.repetition_penalty, + chat_config.repetition_penalty + ); + assert_eq!(gen_config.top_p, chat_config.top_p); + assert_eq!(gen_config.mean_gen_len, chat_config.mean_gen_len); + assert_eq!(gen_config.max_gen_len, chat_config.max_gen_len); + assert_eq!(gen_config.presence_penalty, Some(0.0)); + assert_eq!(gen_config.frequency_penalty, Some(0.0)); + } +} diff --git a/rust/src/lib.rs b/rust/src/lib.rs new file mode 100644 index 0000000000..e83534ceeb --- /dev/null +++ b/rust/src/lib.rs @@ -0,0 +1,5 @@ +#[macro_use] +extern crate derive_builder; + +pub mod chat_module; +pub mod config; From beca2ab36004ff81a96bdc49ca0280f9f7a4f567 Mon Sep 17 00:00:00 2001 From: Charlie Ruan <53290280+CharlieFRuan@users.noreply.github.com> Date: Wed, 8 Nov 2023 16:34:27 -0500 Subject: [PATCH 111/116] [Bugfix] Remove dependency on openai_api in chat module (#1222) Remove dependency on openai_api --- python/mlc_chat/chat_module.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/python/mlc_chat/chat_module.py b/python/mlc_chat/chat_module.py index 1e47729ac9..11cf9832a8 100644 --- a/python/mlc_chat/chat_module.py +++ b/python/mlc_chat/chat_module.py @@ -8,13 +8,15 @@ import warnings from dataclasses import asdict, dataclass, fields from enum import Enum -from typing import List, Optional, Tuple, Union +from typing import TYPE_CHECKING, List, Optional, Tuple, Union import tvm from tvm.runtime import disco # pylint: disable=unused-import from . 
import base  # pylint: disable=unused-import
-from .interface.openai_api import ChatMessage
+
+if TYPE_CHECKING:
+    from .interface.openai_api import ChatMessage

# pylint: disable=line-too-long
_PYTHON_GET_STARTED_TUTORIAL_URL = "https://github.com/mlc-ai/notebooks/blob/main/mlc-llm/tutorial_chat_module_getting_started.ipynb"
@@ -784,7 +786,7 @@ def __init__(

    def generate(
        self,
-        prompt: Union[str, List[ChatMessage]],
+        prompt: Union[str, List["ChatMessage"]],
        generation_config: Optional[GenerationConfig] = None,
        progress_callback=None,
    ) -> Union[str, List[str]]:
@@ -1002,7 +1004,7 @@ def _unload(self):

    def _prefill(
        self,
-        input: Union[str, List[ChatMessage]],  # pylint: disable=redefined-builtin
+        input: Union[str, List["ChatMessage"]],  # pylint: disable=redefined-builtin
        decode_next_token: bool = True,
        place_in_prompt: PlaceInPrompt = PlaceInPrompt.All,
        generation_config: Optional[GenerationConfig] = None,

From 9ee570562dca1ef3b777e979f3df25135c7a719d Mon Sep 17 00:00:00 2001
From: Junru Shao
Date: Wed, 8 Nov 2023 14:19:28 -0800
Subject: [PATCH 112/116] Bake in RAM Usage in the Generated DSO (#1224)

With this PR, the metadata in a DSO file, retrieved via `vm["_metadata"]()`, now has
information about the upper-bound RAM estimate for each function. As an example, the
JSON string now is:

```json
{
  "quantization": "q4f16_1",
  "model_type": "llama",
  "memory_usage": {
    "_initialize_effect": 0,
    "prefill": 136192,
    "softmax_with_temperature": 0,
    "decode": 218624
  },
  "params": [
    {"name": "model.embed_tokens.q_weight", "shape": [32000, 512], "dtype": "uint32"},
    {"name": "model.embed_tokens.q_scale", "shape": [32000, 128], "dtype": "float16"},
    ...
  ]
}
```

This helps the MLC runtime better determine whether a method is going to OOM and plan
ahead accordingly, e.g. by pre-allocating the KVCache. The idea originates from
Ruihang's ancient PR that printed the memory usage estimate as debugging information
for demo purposes; this PR further develops it into an IRModule-level attribute that
can be used by the runtime.
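For illustration, a minimal sketch of how this metadata could be read back from a
compiled library; the library path and device below are placeholder assumptions, not
values prescribed by this patch:

```python
import json

import tvm
from tvm import relax

# Placeholder path to a compiled model library; substitute your own artifact.
lib = tvm.runtime.load_module(
    "dist/Llama-2-7b-chat-hf-q4f16_1/Llama-2-7b-chat-hf-q4f16_1-cuda.so"
)
vm = relax.VirtualMachine(lib, tvm.cuda())

# "_metadata" returns the JSON string baked into the DSO at compile time.
metadata = json.loads(vm["_metadata"]())
print(metadata["memory_usage"])  # per-function upper-bound estimates, in bytes
```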
--- python/mlc_chat/compiler/compile.py | 4 +- .../compiler_pass/estimate_memory_usage.py | 77 +++++++++++++++++++ .../compiler/compiler_pass/pipeline.py | 2 + 3 files changed, 82 insertions(+), 1 deletion(-) create mode 100644 python/mlc_chat/compiler/compiler_pass/estimate_memory_usage.py diff --git a/python/mlc_chat/compiler/compile.py b/python/mlc_chat/compiler/compile.py index 678e924a78..9bfa9787f9 100644 --- a/python/mlc_chat/compiler/compile.py +++ b/python/mlc_chat/compiler/compile.py @@ -58,6 +58,7 @@ def _metadata(): metadata = { "quantization": args.quantization.name, "model_type": args.model.name, + "memory_usage": {str(k): int(v) for k, v in mod.attrs["mlc_llm.memory_usage"].items()}, "params": [ { "name": name, @@ -67,6 +68,7 @@ def _metadata(): for name, param in named_params ], } + print(json.dumps(metadata, indent=2)) bb = relax.BlockBuilder() # pylint: disable=invalid-name with bb.function("main", params=[]): bb.emit_func_output(relax.StringImm(json.dumps(metadata))) @@ -96,10 +98,10 @@ def _compile(args: CompileArgs): mod, named_params = model.export_tvm( spec=model.get_default_spec(), # type: ignore ) - _attach_auxiliary_methods(mod, named_params, args, model_config) logger.info("Running optimizations using TVM Unity") with args.target: mod = relax.get_pipeline("mlc_llm")(mod) + _attach_auxiliary_methods(mod, named_params, args, model_config) logger.info("Generating code using TVM Unity") args.build_func(mod, args) logger.info("Generated: %s", bold(str(args.output))) diff --git a/python/mlc_chat/compiler/compiler_pass/estimate_memory_usage.py b/python/mlc_chat/compiler/compiler_pass/estimate_memory_usage.py new file mode 100644 index 0000000000..d6f959accf --- /dev/null +++ b/python/mlc_chat/compiler/compiler_pass/estimate_memory_usage.py @@ -0,0 +1,77 @@ +"""Memory usage estimation analysis function for Relax functions.""" +from typing import Dict + +import tvm +from tvm import relax +from tvm.ir import IRModule, Op +from tvm.relax.expr_functor import PyExprVisitor, visitor + + +@tvm.transform.module_pass(opt_level=0, name="EstimateMemoryUsage") +class EstimateMemoryUsage: # pylint: disable=too-few-public-methods + """A pass that attaches the memory usage information as an IRModule attribute. + + This pass relies on static analysis on each TVM Relax function in the specific IRModule. + It simply accumulates all memory allocation calls in a function, and does not consider + more dynamic runtime features like control flo "if" or function calls. 
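+    The reported number for each function is therefore an upper bound on its statically
+    planned allocations, not an exact runtime peak.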
+ """ + + def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule: + """Entry point of the pass.""" + lowered_mod = tvm.transform.Sequential( + [ + relax.transform.RewriteDataflowReshape(), + relax.transform.ToNonDataflow(), + relax.transform.RemovePurityChecking(), + relax.transform.CallTIRRewrite(), + relax.transform.StaticPlanBlockMemory(), + ], + name="relax.lower", + )(mod) + usage = _MemoryEstimator().run(lowered_mod) + return mod.with_attr("mlc_llm.memory_usage", usage) + + +@visitor +class _MemoryEstimator(PyExprVisitor): + """The IR visitor which estimates the memory usage of each Relax function.""" + + def __init__(self) -> None: + self.planned_alloc_mem = 0 + self.planned_mem_num = 0 + self._op_alloc_tensor = Op.get("relax.builtin.alloc_tensor") + self._op_alloc_storage = Op.get("relax.memory.alloc_storage") + + def run(self, mod: IRModule) -> Dict[str, int]: + """Entry point of the visitor.""" + result: Dict[str, int] = {} + for global_var, func in mod.functions_items(): + if isinstance(func, relax.Function): + self.planned_alloc_mem = 0 + self.planned_mem_num = 0 + self.visit_expr(func) + result[global_var.name_hint] = self.planned_alloc_mem + return result + + def visit_call_(self, call: relax.Call) -> None: # pylint: disable=arguments-renamed + if call.op == self._op_alloc_tensor: + self._builtin_tensor_alloc(shape=call.args[0], dtype_str=call.args[1].value) + elif call.op == self._op_alloc_storage: + self._storage_alloc(size=call.args[0]) + super().visit_call_(call) + + def _builtin_tensor_alloc(self, shape: relax.Expr, dtype_str: str) -> None: + assert isinstance(shape, relax.ShapeExpr) + size = 1 + for dim_len in shape.values: + if not isinstance(dim_len, tvm.tir.IntImm): + return + size *= dim_len.value + dtype = tvm.DataType(dtype_str) + self.planned_mem_num += 1 + self.planned_alloc_mem += size * ((dtype.bits + 7) // 8) * dtype.lanes + + def _storage_alloc(self, size: relax.Expr) -> None: + assert isinstance(size, relax.ShapeExpr) + self.planned_mem_num += 1 + self.planned_alloc_mem += size.values[0].value diff --git a/python/mlc_chat/compiler/compiler_pass/pipeline.py b/python/mlc_chat/compiler/compiler_pass/pipeline.py index f9bfdd0c59..1f8baab3b6 100644 --- a/python/mlc_chat/compiler/compiler_pass/pipeline.py +++ b/python/mlc_chat/compiler/compiler_pass/pipeline.py @@ -7,6 +7,7 @@ from tvm.relax import register_pipeline # pylint: disable=no-name-in-module from .clean_up_tir_attrs import CleanUpTIRAttrs +from .estimate_memory_usage import EstimateMemoryUsage from .fuse_dequantize_matmul_ewise import FuseDequantizeMatmulEwise from .fuse_dequantize_take import FuseDequantizeTake from .fuse_dequantize_transpose import FuseDequantizeTranspose @@ -64,6 +65,7 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.I _LogProgress("Running memory optimizations"), LiftTIRGlobalBufferAlloc(), tvm.tir.transform.ForceNarrowIndexToInt32(), + EstimateMemoryUsage(), ] ) mod = seq(mod._move()) # pylint: disable=protected-access From 069181c084b46aa6f3f994c192ac1518b8855dae Mon Sep 17 00:00:00 2001 From: Yuchen Jin Date: Wed, 8 Nov 2023 14:37:15 -0800 Subject: [PATCH 113/116] [Fix] ChatModule python messages and offset types (#1220) small fixes --- python/mlc_chat/chat_module.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/python/mlc_chat/chat_module.py b/python/mlc_chat/chat_module.py index 11cf9832a8..6b5bd41c8d 100644 --- a/python/mlc_chat/chat_module.py +++ b/python/mlc_chat/chat_module.py @@ 
-43,10 +43,10 @@ class ConvConfig: # pylint: disable=too-many-instance-attributes roles : Optional[List[str]] An array that describes the role names of the user and the model. These names are specific to the model being used. - messages : Optional[List[str]] + messages : Optional[List[List[str]]] The chat history represented as an array of string pairs in the following format: ``[[role_0, msg_0], [role_1, msg_1], ...]``. - offset : Optional[str] + offset : Optional[int] The offset used to begin the chat from the chat history. When offset is not ``0``, ``messages[0:offset-1]`` will be encoded. separator_style : Optional[int] @@ -71,7 +71,7 @@ class ConvConfig: # pylint: disable=too-many-instance-attributes system: Optional[str] = None roles: Optional[List[str]] = None messages: Optional[List[List[str]]] = None - offset: Optional[str] = None + offset: Optional[int] = None separator_style: Optional[int] = None seps: Optional[List[str]] = None role_msg_sep: Optional[str] = None @@ -844,8 +844,6 @@ def generate( if (generation_config is not None) and (generation_config.n is not None): num_return_sequences = generation_config.n return_str = False - else: - num_return_sequences = 1 for _ in range(num_return_sequences): self.reset_chat() From f1bc951d7deebaa6ba0ae5a94c978843a35af3b1 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Wed, 8 Nov 2023 14:47:55 -0800 Subject: [PATCH 114/116] [Fix] Variable Upperbound Should be Injected before Build Pipeline (#1225) Now it shows a more reasonable upper bound for sequence length = 4096. ```json { "_initialize_effect": 0, "prefill": 3479311360, "softmax_with_temperature": 0, "decode": 34531840 } ``` Thanks Ruihang for helping with the fix! --- python/mlc_chat/compiler/compile.py | 64 ++++++++++++++++------------- 1 file changed, 35 insertions(+), 29 deletions(-) diff --git a/python/mlc_chat/compiler/compile.py b/python/mlc_chat/compiler/compile.py index 9bfa9787f9..ade62309c5 100644 --- a/python/mlc_chat/compiler/compile.py +++ b/python/mlc_chat/compiler/compile.py @@ -52,41 +52,46 @@ def _attach_auxiliary_methods( mod: IRModule, named_params: List[Tuple[str, nn.Parameter]], args: CompileArgs, - model_config, ) -> None: - def _metadata(): - metadata = { - "quantization": args.quantization.name, - "model_type": args.model.name, - "memory_usage": {str(k): int(v) for k, v in mod.attrs["mlc_llm.memory_usage"].items()}, - "params": [ - { - "name": name, - "shape": list(param.shape), - "dtype": param.dtype, - } - for name, param in named_params - ], - } - print(json.dumps(metadata, indent=2)) + def _get_memory_usage(): + return {str(k): int(v) for k, v in mod.attrs["mlc_llm.memory_usage"].items()} + + def _get_param_info(): + return [ + { + "name": name, + "shape": list(param.shape), + "dtype": param.dtype, + } + for name, param in named_params + ] + + def _emit_metadata(metadata): bb = relax.BlockBuilder() # pylint: disable=invalid-name with bb.function("main", params=[]): bb.emit_func_output(relax.StringImm(json.dumps(metadata))) return bb.get()["main"] - def _attach_variable_bounds(): - for g_var, func in mod.functions_items(): - if isinstance(func, relax.Function): - mod[g_var] = func.with_attr( - "tir_var_upper_bound", - { - "seq_len": model_config.max_sequence_length, - "total_seq_len": model_config.max_sequence_length, - }, - ) + mod["_metadata"] = _emit_metadata( + metadata={ + "quantization": args.quantization.name, + "model_type": args.model.name, + "memory_usage": _get_memory_usage(), + "params": _get_param_info(), + } + ) + - mod["_metadata"] = 
_metadata() - _attach_variable_bounds() +def _attach_variable_bounds(mod, model_config): + for g_var, func in mod.functions_items(): + if isinstance(func, relax.Function): + mod[g_var] = func.with_attr( + "tir_var_upper_bound", + { + "seq_len": model_config.max_sequence_length, + "total_seq_len": model_config.max_sequence_length, + }, + ) def _compile(args: CompileArgs): @@ -99,9 +104,10 @@ def _compile(args: CompileArgs): spec=model.get_default_spec(), # type: ignore ) logger.info("Running optimizations using TVM Unity") + _attach_variable_bounds(mod, model_config) with args.target: mod = relax.get_pipeline("mlc_llm")(mod) - _attach_auxiliary_methods(mod, named_params, args, model_config) + _attach_auxiliary_methods(mod, named_params, args) logger.info("Generating code using TVM Unity") args.build_func(mod, args) logger.info("Generated: %s", bold(str(args.output))) From 834811f0df84e97561d1934b1215c2f8d509b48b Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Thu, 9 Nov 2023 12:45:48 -0600 Subject: [PATCH 115/116] [MultiGPU] Support pre-sharded model weights (#1096) * [Bugfix] Correct input shape for shard info function Prior to this commit, the sharding functions sharded axis converted from `orig_size * num_shards` to `orig_size // num_shards`. This commit updates the sharding functions to instead convert from `orig_size` to `orig_size // num_shards`. * [Bugfix] Include LegalizeOps in utils.convert_weights Prior to this commit, `utils.convert_weights` assumes that the parameter transformation module is already legalized, and uses no relax operations that require legalization. This commit adds a call to `relax.transform.LegalizeOps` to remove this assumption. * [MultiGPU] Cleanup create_shard_info_func - De-duplicate the `if param.shard_strategy == foo` if/else chain - Return a `tvm.IRModule` instead of modifying an existing module * Extract a ParamManager.optimize_transform_param_order method * Extract ParamManager.create_parameter_transformation call from convert_weights * Support writing of pre-sharded weights * Support execution using pre-sharded weights * Updating for review comments * fix typo --- cpp/llm_chat.cc | 20 ++- mlc_llm/core.py | 69 +++++++++- mlc_llm/relax_model/commons.py | 195 +++++++++++++++++++++------ mlc_llm/relax_model/param_manager.py | 150 ++++++++++++++++++++- mlc_llm/utils.py | 24 ++-- python/mlc_chat/chat_module.py | 3 + 6 files changed, 396 insertions(+), 65 deletions(-) diff --git a/cpp/llm_chat.cc b/cpp/llm_chat.cc index 4237febd9c..1255c18bcc 100644 --- a/cpp/llm_chat.cc +++ b/cpp/llm_chat.cc @@ -175,19 +175,25 @@ struct FunctionTable { } } - ObjectRef LoadParams(const std::string& model_path, Device device) { + ObjectRef LoadParams(const std::string& model_path, Device device, bool use_presharded_weights) { if (this->use_disco) { std::filesystem::path fs_model_path = model_path; std::string metadata_path = (fs_model_path / "ndarray-cache.json").string(); std::string ndarray_cache_metadata = LoadBytesFromFile(metadata_path); PackedFunc loader_create = this->get_global_func("runtime.disco.ShardLoader"); - PackedFunc loader_load_all = this->get_global_func("runtime.disco.ShardLoaderLoadAll"); + + auto load_all_func_name = use_presharded_weights + ? 
"runtime.disco.ShardLoaderLoadAllPresharded" + : "runtime.disco.ShardLoaderLoadAll"; + PackedFunc loader_load_all = this->get_global_func(load_all_func_name); CHECK(loader_create != nullptr); CHECK(loader_load_all != nullptr); DRef loader = loader_create(metadata_path, ndarray_cache_metadata, "", this->disco_mod); DRef params = loader_load_all(loader); return params; } else { + CHECK(!use_presharded_weights) << "Use of pre-sharded weights requires more than one GPU"; + const PackedFunc* fload_cache = tvm::runtime::Registry::Get("vm.builtin.ndarray_cache.load"); ICHECK(fload_cache) << "TVM runtime cannot find vm.builtin.ndarray_cache.load"; (*fload_cache)(model_path, static_cast(device.device_type), device.device_id); @@ -387,6 +393,12 @@ class LLMChat { } else { this->num_shards_ = 1; } + if (config.count("use_presharded_weights")) { + CHECK(config["use_presharded_weights"].is()); + this->use_presharded_weights_ = config["use_presharded_weights"].get(); + } else { + this->use_presharded_weights_ = false; + } if (config.count("max_window_size")) { CHECK(config["max_window_size"].is()); this->max_window_size_ = @@ -518,7 +530,7 @@ class LLMChat { << "Cannot find env function vm.builtin.sample_top_p_from_logits"; fsample_topp_from_logits_ = *fsample_topp_from_logits_ptr; // Step 5. Load params in nd-array cache. - this->params_ = ft_.LoadParams(model_path, device_); + this->params_ = ft_.LoadParams(model_path, device_, use_presharded_weights_); // Step 6. KV cache creation. this->kv_cache_ = ft_.create_kv_cache_func_(); // Step 7. Pre-allocate fixed size ndarray @@ -1358,6 +1370,8 @@ class LLMChat { int64_t vocab_size_; // number of shards in distributed inference int64_t num_shards_; + // Load weights that were saved in sharded form + bool use_presharded_weights_; // shift window fill factor double shift_fill_factor_{0.3}; // temperature diff --git a/mlc_llm/core.py b/mlc_llm/core.py index a0490ecf10..d9f2a4b8b4 100644 --- a/mlc_llm/core.py +++ b/mlc_llm/core.py @@ -1,5 +1,6 @@ # pylint: disable=missing-docstring, redefined-outer-name, not-callable import argparse +import functools import json import os import pickle @@ -29,7 +30,8 @@ rwkv, stablelm_3b, ) -from mlc_llm.relax_model.commons import create_shard_info_func +from mlc_llm.relax_model.commons import create_shard_info_func, create_shard_transformation_func +from mlc_llm.relax_model.param_manager import transform_params_for_each_rank, chain_parameter_transforms from mlc_llm.transform import fuse_split_rotary_embedding, rewrite_attention @@ -279,6 +281,13 @@ class BuildArgs: ), }, ) + use_presharded_weights: bool = field( + default=False, + metadata={ + "action": "store_true", + "help": "Produce separate weight sets for each shard.", + }, + ) use_flash_attn_mqa: bool = field( default=False, metadata={ @@ -366,9 +375,14 @@ def _parse_args(parsed) -> argparse.Namespace: "tvm.contrib.vllm.single_query_cached_kv_attention", True ), "TVM needs to be built with -DUSE_VLLM=ON." 
- parsed.artifact_path = os.path.join( - parsed.artifact_path, f"{parsed.model}-{parsed.quantization.name}" - ) + model_name = [ + parsed.model, + parsed.quantization.name, + ] + if parsed.use_presharded_weights: + model_name.append(f"presharded-{parsed.num_shards}gpu") + + parsed.artifact_path = os.path.join(parsed.artifact_path, "-".join(model_name)) return parsed @@ -602,6 +616,7 @@ def dump_mlc_chat_config( config["mean_gen_len"] = mean_gen_len config["max_gen_len"] = max_gen_len config["num_shards"] = args.num_shards + config["use_presharded_weights"] = args.use_presharded_weights config["shift_fill_factor"] = shift_fill_factor if rwkv_world: config["tokenizer_files"] = ["tokenizer_model"] @@ -741,12 +756,46 @@ def build_model_from_args(args: argparse.Namespace): qspec_updater.visit_module(mod) if not args.build_model_only: + parameter_transforms = [] + # Run pre-quantization if provided. args.model_path = param_manager.run_pre_quantize(args.model_path) param_manager.init_torch_pname_to_bin_name(args.use_safetensors) + parameter_transforms.append(param_manager.create_parameter_transformation()) + + # Run pre-sharding if required + if args.num_shards > 1 and args.use_presharded_weights: + mod_shard = create_shard_transformation_func(param_manager, args, model_config) + mod_shard = transform_params_for_each_rank(mod_shard, num_shards=args.num_shards) + parameter_transforms.append(mod_shard) + + # Chain all parameter transforms together. This allows + # ReorderTransformFunc to be applied to the single + # resulting parameter transformation function. + mod_transform = functools.reduce(chain_parameter_transforms, parameter_transforms) + + seq = tvm.ir.transform.Sequential( + [ + relax.transform.CanonicalizeBindings(), + relax.transform.EliminateCommonSubexpr(), + relax.transform.DeadCodeElimination(), + # TODO(Lunderberg): Implement + # relax.transform.Simplify() that applies + # canonicalization, CSE, and DCE until + # convergence. + relax.transform.CanonicalizeBindings(), + relax.transform.EliminateCommonSubexpr(), + relax.transform.DeadCodeElimination(), + param_manager.optimize_transform_param_order(), + ], + name="SimplifyModTransform", + ) + + mod_transform = seq(mod_transform) + + params = utils.convert_weights(mod_transform, param_manager, params, args) + utils.save_params(params, args.artifact_path, args.num_shards if args.use_presharded_weights else 1) - new_params = utils.convert_weights(param_manager, params, args) - utils.save_params(new_params, args.artifact_path) if args.model_category != "minigpt": utils.copy_tokenizer(args) if args.model_category == "rwkv" or args.model_category == "rwkv_world": @@ -772,7 +821,13 @@ def build_model_from_args(args: argparse.Namespace): mod = mod_transform_before_build(mod, param_manager, args, model_config) if args.num_shards > 1: - create_shard_info_func(mod, param_manager, args, model_config) + # We require a "create_sharding_info" function for all + # multi-GPU models, even if they are using pre-sharded + # weights. When using pre-sharded weights, the list of + # initialization-time transforms to apply is empty. 
+ sharding_module = create_shard_info_func(param_manager, args, model_config) + mod.update(sharding_module) + with open(cache_path, "wb") as outfile: pickle.dump(mod, outfile) print(f"Save a cached module to {cache_path}.") diff --git a/mlc_llm/relax_model/commons.py b/mlc_llm/relax_model/commons.py index 4924c2f015..e314ef0e39 100644 --- a/mlc_llm/relax_model/commons.py +++ b/mlc_llm/relax_model/commons.py @@ -1,7 +1,10 @@ import json -from typing import List +from typing import List, Optional, Dict -from tvm import relax, te, topi +import tvm +from tvm import relax, tir, te, topi + +import mlc_llm def create_metadata_func( @@ -27,8 +30,9 @@ def create_metadata_func( bb.emit_func_output(relax.StringImm(metadata)) -def create_shard_info_func(mod, param_manager, args, model_config): - num_shards = args.num_shards +def _get_shard_strategies( + model_config, num_shards: int, param_shape_is_already_sharded: bool +) -> Dict[str, tvm.tir.PrimFunc]: head_dim = model_config.hidden_size // model_config.num_attention_heads q_heads = model_config.num_attention_heads kv_heads = model_config.get_num_key_value_heads() @@ -36,7 +40,9 @@ def create_shard_info_func(mod, param_manager, args, model_config): # pylint: disable=invalid-name def shard_qkv_weight_scale(weight: relax.TensorStructInfo): (spatial, red), dtype = weight.shape, weight.dtype - spatial, red = int(spatial) * num_shards, int(red) + spatial, red = int(spatial), int(red) + if param_shape_is_already_sharded: + spatial *= num_shards a = te.placeholder((spatial, red), dtype=dtype) w = topi.reshape(a, (spatial // head_dim, head_dim, red)) q = te.compute((q_heads, head_dim, red), lambda i, j, k: w[i, j, k]) @@ -52,7 +58,9 @@ def shard_qkv_weight_scale(weight: relax.TensorStructInfo): def shard_k_weight_scale(weight: relax.TensorStructInfo): (spatial, red), dtype = weight.shape, weight.dtype - spatial, red = int(spatial), int(red) * num_shards + spatial, red = int(spatial), int(red) + if param_shape_is_already_sharded: + red *= num_shards a = te.placeholder((spatial, red), dtype=dtype) w = topi.reshape(a, (spatial, num_shards, red // num_shards)) w = topi.transpose(w, (1, 0, 2)) @@ -61,7 +69,9 @@ def shard_k_weight_scale(weight: relax.TensorStructInfo): def shard_gate_up_weight_scale(weight: relax.TensorStructInfo): (spatial, red), dtype = weight.shape, weight.dtype - spatial, red = int(spatial) * num_shards, int(red) + spatial, red = int(spatial), int(red) + if param_shape_is_already_sharded: + spatial *= num_shards a = te.placeholder((spatial, red), dtype=dtype) g = te.compute((spatial // 2, red), lambda i, j: a[i, j]) u = te.compute((spatial // 2, red), lambda i, j: a[spatial // 2 + i, j]) @@ -74,50 +84,157 @@ def shard_gate_up_weight_scale(weight: relax.TensorStructInfo): # pylint: enable=invalid-name + return { + "shard_qkv": shard_qkv_weight_scale, + "shard_mlp_k": shard_k_weight_scale, + "shard_o_proj_k": shard_k_weight_scale, + "shard_gate_up": shard_gate_up_weight_scale, + } + + +def create_shard_info_func(param_manager, args, model_config) -> tvm.IRModule: + shard_strategy_to_func = _get_shard_strategies( + model_config, + num_shards=args.num_shards, + param_shape_is_already_sharded=args.build_model_only, + ) + shard_info_dict = {} shard_funcs = {} - def add_to_shard_info(param_name: str, func_name: str): - func = shard_funcs[func_name] - buffer = func.buffer_map[func.params[-1]] - shape = [int(i) for i in buffer.shape] - dtype = str(buffer.dtype) - shard_info_dict[param_name] = [(func_name, [shape, dtype])] + def 
add_to_shard_info(param_name: str, func_name: Optional[str]): + shard_info = [] + if func_name is not None: + func = shard_funcs[func_name] + buffer = func.buffer_map[func.params[-1]] + shape = [int(i) for i in buffer.shape] + dtype = str(buffer.dtype) + shard_info.append((func_name, [shape, dtype])) + + shard_info_dict[param_name] = shard_info q_params = param_manager.get_quantized_param_info("prefill").fields for _, param in param_manager.params.items(): if param.shard_strategy is None: pass - elif param.shard_strategy == "shard_qkv": - for i, weight in enumerate(param_manager.param2qrange[param]): - name = f"shard_qkv_{i}" - if name not in shard_funcs: - shard_funcs[name] = shard_qkv_weight_scale(q_params[weight]) - add_to_shard_info(f"param_{weight}", name) - elif param.shard_strategy == "shard_mlp_k": - for i, weight in enumerate(param_manager.param2qrange[param]): - name = f"shard_mlp_k_{i}" - if name not in shard_funcs: - shard_funcs[name] = shard_k_weight_scale(q_params[weight]) - add_to_shard_info(f"param_{weight}", name) - elif param.shard_strategy == "shard_o_proj_k": - for i, weight in enumerate(param_manager.param2qrange[param]): - name = f"shard_o_proj_k_{i}" - if name not in shard_funcs: - shard_funcs[name] = shard_k_weight_scale(q_params[weight]) - add_to_shard_info(f"param_{weight}", name) - elif param.shard_strategy == "shard_gate_up": + elif param.shard_strategy in shard_strategy_to_func: for i, weight in enumerate(param_manager.param2qrange[param]): - name = f"shard_gate_up_{i}" - if name not in shard_funcs: - shard_funcs[name] = shard_gate_up_weight_scale(q_params[weight]) - add_to_shard_info(f"param_{weight}", name) + if args.use_presharded_weights: + sharding_func_name = None + else: + sharding_func_name = f"{param.shard_strategy}_{i}" + if sharding_func_name not in shard_funcs: + shard_funcs[sharding_func_name] = shard_strategy_to_func[ + param.shard_strategy + ](q_params[weight]) + add_to_shard_info(f"param_{weight}", sharding_func_name) else: raise NotImplementedError(f"Shard strategy not implemented: {param.shard_strategy}") + + bb = relax.BlockBuilder() # pylint: disable=invalid-name + for name, func in shard_funcs.items(): func = func.with_attr({"global_symbol": name}) - mod[name] = func - bb = relax.BlockBuilder() # pylint: disable=invalid-name + bb.add_func(func, name) + with bb.function("get_shard_info", params=[]): bb.emit_func_output(relax.StringImm(json.dumps(shard_info_dict))) - mod["get_shard_info"] = bb.get()["get_shard_info"] + + return bb.get() + + +def create_shard_transformation_func(param_manager, args, model_config) -> tvm.IRModule: + shard_strategy_to_func = _get_shard_strategies( + model_config, + num_shards=args.num_shards, + param_shape_is_already_sharded=args.build_model_only, + ) + + q_params = param_manager.get_quantized_param_info("prefill").fields + + # The order of the quantized parameters must be preserved. + # Therefore, we need to loop over q_params and look up information + # as needed, rather than looping over original parameters and + # looking up the quantized parameters as needed. 
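+    # Maps each quantized-parameter index to (original parameter, index of this
+    # quantized piece, total number of quantized pieces of that parameter).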
+ orig_param_lookup = {} + for param in param_manager.params_in_func["prefill"]: + qrange = param_manager.param2qrange[param] + for i_orig_part, i_qparam in enumerate(qrange): + orig_param_lookup[i_qparam] = ( + param, + i_orig_part, + len(qrange), + ) + + bb = relax.BlockBuilder() # pylint: disable=invalid-name + with bb.function("transform_params"): + rank = tir.Var("rank", "int64") + # TODO(Lunderberg): Support primitive inputs to relax + # functions. Currently, using a PrimStructInfo as the + # argument results in an error thrown during + # `vm_shape_lower.cc`, due to BindParams failing to replace + # the symbolic variable "rank" when defined in a R.PrimValue. + # + # rank_arg = relax.Var("rank", relax.PrimStructInfo(value=rank)) + rank_arg = relax.Var("rank_arg", relax.ShapeStructInfo([rank])) + + args = [rank_arg] + output = [] + + for i_qparam, qparam_sinfo in enumerate(q_params): + param, i_orig_part, num_orig_parts = orig_param_lookup[i_qparam] + + if isinstance(param.quant_spec, mlc_llm.quantization.NoQuantizationSpec): + arg_name = param.name + elif num_orig_parts == 1: + arg_name = f"{param.name}.quantized" + else: + arg_name = f"{param.name}.quantized_{i_orig_part}" + + arg = relax.Var(arg_name, qparam_sinfo) + + if param.shard_strategy is None: + sharded = arg + else: + strategy_func = shard_strategy_to_func[param.shard_strategy]( + qparam_sinfo + ).without_attr("global_symbol") + strategy_gvar = bb.add_func( + strategy_func, + func_name=f"{arg_name}.sharding_func", + ) + + # TODO(Lunderberg): Write the strategies as relax + # functions, so the sharded shapes can be inferred. + reordered_buffer = strategy_func.buffer_map[strategy_func.params[-1]] + reordered_sinfo = relax.TensorStructInfo( + reordered_buffer.shape, reordered_buffer.dtype + ) + reordered = relax.op.call_tir( + strategy_gvar, relax.Tuple([arg]), out_sinfo=reordered_sinfo + ) + + # TODO(Lunderberg): Allow relax.PrimValue as the index + # in a TupleGetItem. This would allow all of the + # splits to be generated at once in the merged + # function, and could be optimized to an in-place view. 
+ # + # split = relax.op.split(reordered, indices_or_sections=num_shards, axis=0)[rank] + split = relax.op.strided_slice( + reordered, + axes=[0], + begin=[rank], + end=[rank + 1], + assume_inbound=True, + ) + + sharded = relax.op.squeeze(split, axis=0) + + args.append(arg) + output.append(sharded) + + with bb.dataflow(): + gv = bb.emit_output(output) + bb.emit_func_output(output=gv, params=args) + + return bb.get() diff --git a/mlc_llm/relax_model/param_manager.py b/mlc_llm/relax_model/param_manager.py index 7f0751b2a0..69a25ccb73 100644 --- a/mlc_llm/relax_model/param_manager.py +++ b/mlc_llm/relax_model/param_manager.py @@ -763,14 +763,23 @@ def create_parameter_transformation(self, optimize_parameter_order: bool = True) """ mod = _create_quantize_func(self) if optimize_parameter_order: - reorder_pass = ReorderTransformFunc( - self.pidx2pname, - self.torch_pname2binname, - self.f_convert_pname_fwd, - ) - mod = reorder_pass(mod) + mod = self.optimize_transform_param_order()(mod) return mod + def optimize_transform_param_order(self) -> tvm.transform.Pass: + """Produce an transformation that optimizes for minimal memory footprint + + Returns + ------- + tvm.transform.Pass + The transformation + """ + return ReorderTransformFunc( + self.pidx2pname, + self.torch_pname2binname, + self.f_convert_pname_fwd, + ) + @mutator class ParamReplacer(PyExprMutator): @@ -1006,3 +1015,132 @@ def _create_quantize_func(param_manager: ParamManager) -> tvm.IRModule: param_manager.param2qrange = param2qrange # Return the created IRModule. return bb.get() + + +def transform_params_for_each_rank( + mod: tvm.IRModule, num_shards: int, rank_argument_name: str = "rank_arg" +) -> tvm.IRModule: + """Update a parameter transform to apply across all ranks + + For use in generating a pre-sharded set of weights. Given a + parameter transformation that generates sharded model weights for + a single shard, produce a parameter transformation that generates + sharded model weights for each shard. + + Parameters + ---------- + mod: tvm.IRModule + + A module containing the parameter transformation function, + named "transform_params", along with any subroutines called by + the parameter transformation. + + num_shards: int + + The number of shards to generate. + + rank_argument_name: str + + The name of the argument that specifies the rank. Should be a + R.ShapeTuple with a single R.PrimStructInfo('int64'). + + Returns + ------- + tvm.IRModule + + The modified parameter transformation + """ + generic_transform = mod["transform_params"] + tensor_params = generic_transform.params[1:] + + bb = relax.BlockBuilder() + + with bb.function("transform_params", params=tensor_params): + output = [] + for rank in range(num_shards): + # TODO(Lunderberg): Implement this in terms of a + # generic utility that inlines local functions. + func = generic_transform + func = func.bind_params({rank_argument_name: relax.ShapeExpr([rank])}) + func = relax.utils.copy_with_new_vars(func) + func = func.bind_params( + {var: tensor_param for (var, tensor_param) in zip(func.params, tensor_params)} + ) + shard_tuple = func.body + output.extend([shard_tuple[i] for i in range(len(tensor_params))]) + + with bb.dataflow(): + gv = bb.emit_output(relax.Tuple(output)) + bb.emit_func_output(gv) + + mod["transform_params"] = bb.get()["transform_params"] + return mod + + +def chain_parameter_transforms(mod_a: tvm.IRModule, mod_b: tvm.IRModule) -> tvm.IRModule: + """Chain two sequential parameter transformations + + For use in manipulating sets of model weights. 
+    Given two parameter transformations that could be applied
+    sequentially, produce a single parameter transformation whose
+    output is the same as applying the parameter transformations
+    sequentially.
+
+    .. code-block:: python
+
+        # Before
+        params_after_a = mod_a['transform_params'](orig_params)
+        params_after_b = mod_b['transform_params'](params_after_a)
+
+        # After
+        mod_ab = chain_parameter_transforms(mod_a, mod_b)
+        params_after_b = mod_ab['transform_params'](orig_params)
+
+    Parameters
+    ----------
+    mod_a: tvm.IRModule
+
+        The module containing the first parameter transformation.
+
+    mod_b: tvm.IRModule
+
+        The module containing the second parameter transformation.
+
+    Returns
+    -------
+    tvm.IRModule
+
+        The module containing the chained parameter transformation.
+    """
+    func_a = mod_a["transform_params"]
+    func_b = mod_b["transform_params"]
+
+    bb = relax.BlockBuilder()
+
+    with bb.function("transform_params", params=func_a.params):
+        with bb.dataflow():
+            # TODO(Lunderberg): Implement this in terms of a
+            # generic utility that inlines local functions.
+            func_a_output = bb.emit(func_a.body)
+            func_b_param_map = {param: expr for (param, expr) in zip(func_b.params, func_a_output)}
+            func_b_output = func_b.bind_params(func_b_param_map).body
+            gv = bb.emit_output(func_b_output)
+        bb.emit_func_output(gv)
+
+    merged_transform_func = bb.get()["transform_params"]
+
+    new_mod = {
+        **{
+            gvar: func
+            for gvar, func in mod_a.functions.items()
+            if gvar.name_hint != "transform_params"
+        },
+        **{
+            gvar: func
+            for gvar, func in mod_b.functions.items()
+            if gvar.name_hint != "transform_params"
+        },
+        "transform_params": merged_transform_func,
+    }
+    return tvm.IRModule(new_mod)
diff --git a/mlc_llm/utils.py b/mlc_llm/utils.py
index 1bcf1e8816..b995de2956 100644
--- a/mlc_llm/utils.py
+++ b/mlc_llm/utils.py
@@ -210,16 +210,11 @@ def debug_dump_shader(ex: tvm.relax.Executable, name: str, args: argparse.Namesp
 
 
 def convert_weights(
+    mod_transform: tvm.IRModule,
     param_mgr: param_manager.ParamManager,
     model_params: List[Optional[tvm.nd.NDArray]],
     args: argparse.Namespace,
 ):
-    # Create the quantization function.
-    # We first create an initial one, then reorder it according to each
-    # weight's location in the binary files, in the purpose of reducing
-    # memory usage when loading torch weights as well as acceleration.
-    mod_transform = param_mgr.create_parameter_transformation()
-
     # Save the number of parameters before we lower mod_transform, so
     # we can use them in the progress bar.
     transform_func = mod_transform["transform_params"]
@@ -231,6 +226,7 @@ def convert_weights(
     mod_transform = relax.transform.ToNonDataflow()(mod_transform)
     mod_transform = relax.transform.LazyTransformParams()(mod_transform)
     mod_transform = tvm.tir.transform.ForceNarrowIndexToInt32()(mod_transform)
+    mod_transform = relax.transform.LegalizeOps()(mod_transform)
 
     debug_dump_script(mod_transform, "mod_convert_weights.py", args)
 
@@ -278,16 +274,24 @@ def convert_weights(
     return loaded_params
 
 
-def save_params(params: List[tvm.nd.NDArray], artifact_path: str) -> None:
+def save_params(params: List[tvm.nd.NDArray], artifact_path: str, num_presharded: int = 1) -> None:
     from tvm.contrib import tvmjs  # pylint: disable=import-outside-toplevel
 
+    assert len(params) % num_presharded == 0
+    num_weights = len(params) // num_presharded
+
     meta_data = {}
     param_dict = {}
     meta_data["ParamSize"] = len(params)
-    total_size = 0.0
     for i, nd in enumerate(params):
-        assert nd is not None, f"Missing parameter at index {i}"
-        param_dict[f"param_{i}"] = nd
+        if num_presharded == 1:
+            param_name = f"param_{i}"
+        else:
+            expected_worker_id = i // num_weights
+            orig_param_id = i % num_weights
+            param_name = f"param_{orig_param_id}_shard-{expected_worker_id+1}-of-{num_presharded}"
+
+        param_dict[param_name] = nd
 
     total_size_bytes = sum(math.prod(param.shape) * np.dtype(param.dtype).itemsize for param in params)
     total_size_gb = total_size_bytes / (1024 ** 3)
diff --git a/python/mlc_chat/chat_module.py b/python/mlc_chat/chat_module.py
index 6b5bd41c8d..bcadaa84ba 100644
--- a/python/mlc_chat/chat_module.py
+++ b/python/mlc_chat/chat_module.py
@@ -153,6 +153,8 @@ class ChatConfig:  # pylint: disable=too-many-instance-attributes
         Name of the model (e.g. ``Llama-2-7b-chat-hf``).
     num_shards: Optional[str]
         Tensor parallel degree.
+    use_presharded_weights: Optional[bool]
+        If True, the weights were saved with sharding already applied.
     max_window_size: Optional[str]
         Maximum kv cache window size.
     """
@@ -171,6 +173,7 @@ class ChatConfig:  # pylint: disable=too-many-instance-attributes
     model_category: Optional[str] = None
     model_name: Optional[str] = None
     num_shards: Optional[int] = None
+    use_presharded_weights: Optional[bool] = None
     max_window_size: Optional[int] = None
 
     @classmethod

From b022dc223e169d9e6237192bf42a53080fde035b Mon Sep 17 00:00:00 2001
From: Masahiro Masuda
Date: Thu, 9 Nov 2023 19:36:43 +0000
Subject: [PATCH 116/116] fix

---
 mlc_llm/utils.py | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/mlc_llm/utils.py b/mlc_llm/utils.py
index b995de2956..f6858922cd 100644
--- a/mlc_llm/utils.py
+++ b/mlc_llm/utils.py
@@ -119,8 +119,15 @@ def argparse_postproc_common(args: argparse.Namespace) -> None:
     if args.quantization not in quantization_schemes:
         raise ValueError(f'Quantization "{args.quantization}" is not supported.')
+
+    use_ft_quant = args.quantization in ["q4f16_ft", "q8f16_ft"]
     args.quantization = quantization_schemes[args.quantization]
 
+    if use_ft_quant and args.num_shards > 1:
+        # Preprocess is done after sharding for this case.
+        args.quantization.linear_weight.do_preprocess = False
+        args.quantization.final_fc_weight.do_preprocess = False
+
 
 def debug_dump_script(mod, name, args: argparse.Namespace, show_meta=True):
     """Debug dump mode"""
 
@@ -283,6 +290,7 @@ def save_params(params: List[tvm.nd.NDArray], artifact_path: str, num_presharded
     meta_data = {}
     param_dict = {}
     meta_data["ParamSize"] = len(params)
+
    for i, nd in enumerate(params):
        if num_presharded == 1:
            param_name = f"param_{i}"
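
Editor's note on how the pieces compose: `convert_weights` now receives the transformation module from its caller instead of building it internally, so a caller that wants pre-sharded weights can chain the quantization transform with the per-rank sharding transform before converting and saving. The sketch below is illustrative only and not taken from this patch; `shard_transform_mod` (a stand-in for the IRModule produced by the sharding builder at the top of this section), `param_mgr`, `model_params`, and `args.artifact_path` are assumed placeholders.

.. code-block:: python

    # Illustrative sketch only -- not part of the patch.
    # Build the quantization transform, wrap the (assumed) sharding
    # transform so it runs once per rank, then chain the two together.
    quantize_mod = param_mgr.create_parameter_transformation()
    shard_mod = transform_params_for_each_rank(
        shard_transform_mod, num_shards=args.num_shards
    )
    mod_transform = chain_parameter_transforms(quantize_mod, shard_mod)

    # Run the chained transform over the torch weights and save the
    # resulting shards under per-rank parameter names.
    params = convert_weights(mod_transform, param_mgr, model_params, args)
    save_params(params, args.artifact_path, num_presharded=args.num_shards)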
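
On the `save_params` change itself: when `num_presharded > 1`, the flattened parameter list is expected to hold all of rank 0's shards first, then rank 1's, and so on (the order produced by `transform_params_for_each_rank`), and each entry is renamed accordingly. The snippet below only mirrors that arithmetic for illustration and is not part of the patch.

.. code-block:: python

    # Name generation mirroring save_params (illustration only),
    # for 2 shards of 3 original weights.
    num_presharded = 2
    num_weights = 3
    for i in range(num_presharded * num_weights):
        worker_id = i // num_weights      # rank that owns this shard
        orig_param_id = i % num_weights   # index of the original weight
        print(f"param_{orig_param_id}_shard-{worker_id + 1}-of-{num_presharded}")
    # Prints param_0_shard-1-of-2 ... param_2_shard-1-of-2,
    # then param_0_shard-2-of-2 ... param_2_shard-2-of-2.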