Skip to content

Commit 7ccb51a

Browse files
authored
Integrating MLC runtime with the new compilation workflow (mlc-ai#1203)
1 parent 3413d17 commit 7ccb51a

File tree

10 files changed

+282
-59
lines changed

10 files changed

+282
-59
lines changed

cpp/json_parser.h

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
#ifndef MLC_LLM_CPP_JSON_PARSER_H_
2+
#define MLC_LLM_CPP_JSON_PARSER_H_
3+
4+
#define PICOJSON_USE_INT64
5+
#ifndef __STDC_FORMAT_MACROS
6+
#define __STDC_FORMAT_MACROS
7+
#endif
8+
9+
#include <picojson.h>
10+
#include <tvm/runtime/container/shape_tuple.h>
11+
#include <tvm/runtime/data_type.h>
12+
#include <tvm/runtime/logging.h>
13+
14+
namespace mlc {
15+
namespace llm {
16+
namespace json {
17+
18+
template <typename ValueType>
19+
inline ValueType Lookup(const picojson::object& json, const std::string& key) {
20+
auto it = json.find(key);
21+
CHECK(it != json.end()) << "ValueError: key `" << key << "` not found in the JSON object";
22+
CHECK(it->second.is<ValueType>()) << "ValueError: key `" << key << "` has unexpected type";
23+
return it->second.get<ValueType>();
24+
}
25+
26+
template <>
27+
inline tvm::runtime::DataType Lookup(const picojson::object& json, const std::string& key) {
28+
return tvm::runtime::DataType(tvm::runtime::String2DLDataType(Lookup<std::string>(json, key)));
29+
}
30+
31+
template <>
32+
inline tvm::runtime::ShapeTuple Lookup(const picojson::object& json, const std::string& key) {
33+
picojson::array shape = Lookup<picojson::array>(json, key);
34+
std::vector<int64_t> result;
35+
result.reserve(shape.size());
36+
for (const picojson::value& dim : shape) {
37+
CHECK(dim.is<int64_t>()) << "ValueError: key `" << key << "` has unexpected type";
38+
result.push_back(dim.get<int64_t>());
39+
}
40+
return tvm::runtime::ShapeTuple(std::move(result));
41+
}
42+
43+
inline picojson::object ParseObject(const std::string& json_str) {
44+
picojson::value result;
45+
std::string err = picojson::parse(result, json_str);
46+
if (!err.empty()) {
47+
LOG(FATAL) << "Failed to parse JSON: err. The JSON string is:" << json_str;
48+
}
49+
CHECK(result.is<picojson::object>())
50+
<< "ValueError: The given string is not a JSON object: " << json_str;
51+
return result.get<picojson::object>();
52+
}
53+
54+
inline picojson::object AsJSONObject(const picojson::value& json) {
55+
CHECK(json.is<picojson::object>()) << "ValueError: The given value is not a JSON object";
56+
return json.get<picojson::object>();
57+
}
58+
59+
} // namespace json
60+
} // namespace llm
61+
} // namespace mlc
62+
63+
#endif // MLC_LLM_CPP_JSON_PARSER_H_

cpp/llm_chat.cc

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
#include <vector>
3333

3434
#include "conversation.h"
35+
#include "model_metadata.h"
3536
#include "random.h"
3637
#include "support.h"
3738
#include "tokenizers.h"
@@ -161,13 +162,18 @@ struct FunctionTable {
161162
static_cast<int>(relax_vm::AllocatorType::kPooled), static_cast<int>(kDLCPU), 0,
162163
static_cast<int>(relax_vm::AllocatorType::kPooled));
163164
this->mod_get_func = [this](const std::string& name) -> PackedFunc {
164-
return this->local_vm->GetFunction(name, false);
165+
PackedFunc func = this->local_vm->GetFunction(name, false);
166+
if (func == nullptr) {
167+
LOG(WARNING) << "Cannot find function in VM: " << name;
168+
}
169+
return func;
165170
};
166171
this->get_global_func = [](const std::string& name) -> PackedFunc {
167172
const auto* f = tvm::runtime::Registry::Get(name);
168173
CHECK(f != nullptr) << "ValueError: Cannot find function " << name;
169174
return *f;
170175
};
176+
this->model_metadata_ = ModelMetadata::FromModule(this->local_vm);
171177
this->_InitFunctions();
172178
}
173179
}
@@ -188,10 +194,23 @@ struct FunctionTable {
188194
const PackedFunc* fload_cache = tvm::runtime::Registry::Get("vm.builtin.ndarray_cache.load");
189195
ICHECK(fload_cache) << "TVM runtime cannot find vm.builtin.ndarray_cache.load";
190196
(*fload_cache)(model_path, static_cast<int32_t>(device.device_type), device.device_id);
191-
const PackedFunc* fload_params =
192-
tvm::runtime::Registry::Get("vm.builtin.param_array_from_cache");
193-
ICHECK(fload_params) << "Cannot find env function vm.builtin.param_array_from_cache";
194-
Array<NDArray> params = (*fload_params)("param", -1);
197+
Array<NDArray> params;
198+
if (this->model_metadata_.params.empty()) {
199+
constexpr const char* name_loader = "vm.builtin.param_array_from_cache";
200+
const PackedFunc* fload_params = tvm::runtime::Registry::Get(name_loader);
201+
ICHECK(fload_params) << "Cannot find env function: " << name_loader;
202+
params = (*fload_params)("param", -1);
203+
} else {
204+
constexpr const char* name_loader = "vm.builtin.param_array_from_cache_by_name";
205+
const PackedFunc* fload_params = tvm::runtime::Registry::Get(name_loader);
206+
ICHECK(fload_params) << "Cannot find env function: " << name_loader;
207+
Array<String> param_names;
208+
param_names.reserve(this->model_metadata_.params.size());
209+
for (const auto& param : this->model_metadata_.params) {
210+
param_names.push_back(param.name);
211+
}
212+
params = (*fload_params)(param_names);
213+
}
195214
// after we get params, it is safe to simply clear the cached version
196215
// as these params are referenced by params_
197216
const PackedFunc* fclear_ndarray_cache =
@@ -210,6 +229,9 @@ struct FunctionTable {
210229
this->softmax_func_ = mod_get_func("softmax_with_temperature");
211230
this->encoding_without_cache_func_ = mod_get_func("encoding_without_cache");
212231
this->create_kv_cache_func_ = mod_get_func("create_kv_cache");
232+
if (this->create_kv_cache_func_ == nullptr) {
233+
this->create_kv_cache_func_ = mod_get_func("_initialize_effect");
234+
}
213235
this->reset_kv_cache_func_ = mod_get_func("reset_kv_cache");
214236
if (this->reset_kv_cache_func_ == nullptr) {
215237
this->reset_kv_cache_func_ = get_global_func("vm.builtin.attention_kv_cache_array_clear");
@@ -260,6 +282,7 @@ struct FunctionTable {
260282
PackedFunc reset_kv_cache_func_;
261283
bool support_backtracking_kv_;
262284
PackedFunc fkvcache_array_popn_;
285+
ModelMetadata model_metadata_;
263286
};
264287

265288
} // namespace
@@ -437,6 +460,7 @@ class LLMChat {
437460
* \note This function overrides existing configurations.
438461
*/
439462
void LoadJSONOverride(const std::string& config_str, bool partial_update = false) {
463+
LOG(INFO) << "config_str = " << config_str;
440464
picojson::value config_json;
441465
std::string err = picojson::parse(config_json, config_str);
442466
if (!err.empty()) {

cpp/model_metadata.cc

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
#include "./model_metadata.h"
2+
3+
#include <tvm/runtime/packed_func.h>
4+
5+
#include "./json_parser.h"
6+
7+
namespace mlc {
8+
namespace llm {
9+
10+
using namespace tvm::runtime;
11+
12+
ModelMetadata::Param ModelMetadata::Param::FromJSON(const picojson::object& param) {
13+
Param result;
14+
result.name = json::Lookup<std::string>(param, "name");
15+
result.shape = json::Lookup<ShapeTuple>(param, "shape");
16+
result.dtype = json::Lookup<DataType>(param, "dtype");
17+
return result;
18+
}
19+
20+
ModelMetadata ModelMetadata::FromJSON(const picojson::object& metadata) {
21+
ModelMetadata result;
22+
result.model_type = json::Lookup<std::string>(metadata, "model_type");
23+
result.quantization = json::Lookup<std::string>(metadata, "quantization");
24+
picojson::array params = json::Lookup<picojson::array>(metadata, "params");
25+
result.params.reserve(params.size());
26+
for (const picojson::value& json_param : params) {
27+
result.params.emplace_back(ModelMetadata::Param::FromJSON(json::AsJSONObject(json_param)));
28+
}
29+
return result;
30+
}
31+
32+
ModelMetadata ModelMetadata::FromModule(tvm::runtime::Module module) {
33+
std::string json_str = "";
34+
try {
35+
TypedPackedFunc<String()> pf = module.GetFunction("_metadata");
36+
ICHECK(pf != nullptr);
37+
json_str = pf();
38+
} catch (...) {
39+
return ModelMetadata(); // TODO: add a warning message about legacy usecases
40+
}
41+
picojson::object json = json::ParseObject(json_str);
42+
try {
43+
return ModelMetadata::FromJSON(json);
44+
} catch (const std::exception& e) {
45+
LOG(WARNING) << "Failed to parse metadata:\n" << json_str;
46+
throw e;
47+
}
48+
}
49+
50+
} // namespace llm
51+
} // namespace mlc

cpp/model_metadata.h

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
/*!
2+
* \file model_metadata.h
3+
* \brief Metadata stored in model lib
4+
*/
5+
#include <tvm/runtime/container/shape_tuple.h>
6+
#include <tvm/runtime/container/string.h>
7+
#include <tvm/runtime/data_type.h>
8+
#include <tvm/runtime/module.h>
9+
10+
#include <unordered_map>
11+
12+
namespace picojson {
13+
class value;
14+
using object = std::unordered_map<std::string, value>;
15+
} // namespace picojson
16+
17+
namespace mlc {
18+
namespace llm {
19+
20+
struct ModelMetadata {
21+
struct Param {
22+
tvm::runtime::String name;
23+
tvm::runtime::ShapeTuple shape;
24+
tvm::runtime::DataType dtype;
25+
26+
static Param FromJSON(const picojson::object& param_obj);
27+
};
28+
std::string model_type;
29+
std::string quantization;
30+
std::vector<Param> params;
31+
32+
static ModelMetadata FromJSON(const picojson::object& json_str);
33+
static ModelMetadata FromModule(tvm::runtime::Module module);
34+
};
35+
36+
} // namespace llm
37+
} // namespace mlc

python/mlc_chat/compiler/compile.py

Lines changed: 45 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
"""Python entrypoint of compilation."""
22
import dataclasses
3+
import json
34
import logging
45
from io import StringIO
56
from pathlib import Path
6-
from typing import Callable, Optional
7+
from typing import Callable, List, Optional, Tuple
78

89
from tvm import IRModule, relax
10+
from tvm.relax.frontend import nn
911
from tvm.target import Target
1012

1113
from ..support.style import bold
@@ -46,21 +48,61 @@ def display(self) -> None:
4648
print(out.getvalue().rstrip())
4749

4850

51+
def _attach_auxiliary_methods(
52+
mod: IRModule,
53+
named_params: List[Tuple[str, nn.Parameter]],
54+
args: CompileArgs,
55+
model_config,
56+
) -> None:
57+
def _metadata():
58+
metadata = {
59+
"quantization": args.quantization.name,
60+
"model_type": args.model.name,
61+
"params": [
62+
{
63+
"name": name,
64+
"shape": list(param.shape),
65+
"dtype": param.dtype,
66+
}
67+
for name, param in named_params
68+
],
69+
}
70+
bb = relax.BlockBuilder() # pylint: disable=invalid-name
71+
with bb.function("main", params=[]):
72+
bb.emit_func_output(relax.StringImm(json.dumps(metadata)))
73+
return bb.get()["main"]
74+
75+
def _attach_variable_bounds():
76+
for g_var, func in mod.functions_items():
77+
if isinstance(func, relax.Function):
78+
mod[g_var] = func.with_attr(
79+
"tir_var_upper_bound",
80+
{
81+
"seq_len": model_config.max_sequence_length,
82+
"total_seq_len": model_config.max_sequence_length,
83+
},
84+
)
85+
86+
mod["_metadata"] = _metadata()
87+
_attach_variable_bounds()
88+
89+
4990
def _compile(args: CompileArgs):
5091
logger.info("Creating model from: %s", args.config)
5192
model_config = args.model.config.from_file(args.config)
5293
args.overrides.apply(model_config)
5394
model, _ = args.model.quantize[args.quantization.kind](model_config, args.quantization)
5495
logger.info("Exporting the model to TVM Unity compiler")
55-
mod, _named_params = model.export_tvm(
96+
mod, named_params = model.export_tvm(
5697
spec=model.get_default_spec(), # type: ignore
5798
)
99+
_attach_auxiliary_methods(mod, named_params, args, model_config)
58100
logger.info("Running optimizations using TVM Unity")
59101
with args.target:
60102
mod = relax.get_pipeline("mlc_llm")(mod)
61103
logger.info("Generating code using TVM Unity")
62104
args.build_func(mod, args)
63-
logger.info("Code dumped to: %s", bold(str(args.output)))
105+
logger.info("Generated: %s", bold(str(args.output)))
64106

65107

66108
def compile( # pylint: disable=too-many-arguments,redefined-builtin

python/mlc_chat/compiler/compiler_pass/fuse_decode_matmul_ewise.py renamed to python/mlc_chat/compiler/compiler_pass/fuse_dequantize_matmul_ewise.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
1-
"""A compiler pass that fuses decode + matmul + elementwise."""
1+
"""A compiler pass that fuses dequantize + matmul + elementwise."""
22
import tvm
33
from tvm import IRModule, relax
44
from tvm.relax.dpl.pattern import GlobalVarPattern, TuplePattern, is_op, wildcard
55

66

7-
@tvm.transform.module_pass(opt_level=0, name="FuseDecodeMatmulEwise")
8-
class FuseDecodeMatmulEwise: # pylint: disable=too-few-public-methods
9-
"""A compiler pass that fuses decode + matmul + elementwise."""
7+
@tvm.transform.module_pass(opt_level=0, name="FuseDequantizeMatmulEwise")
8+
class FuseDequantizeMatmulEwise: # pylint: disable=too-few-public-methods
9+
"""A compiler pass that fuses dequantize + matmul + elementwise."""
1010

1111
def transform_module(
1212
self,
@@ -23,7 +23,7 @@ def transform_module(
2323
relax.transform.FuseOpsByPattern(
2424
[
2525
(
26-
"decode_matmul",
26+
"dequantize_matmul",
2727
*_pattern(match_ewise, n_aux_tensor),
2828
)
2929
]
@@ -62,7 +62,9 @@ def _check_decoding(ctx: relax.transform.PatternCheckContext) -> bool:
6262
g_var = call.args[0]
6363
if not isinstance(g_var, relax.GlobalVar):
6464
return False
65-
return g_var.name_hint.startswith("decode") or g_var.name_hint.startswith("fused_decode")
65+
return g_var.name_hint.startswith("dequantize") or g_var.name_hint.startswith(
66+
"fused_dequantize"
67+
)
6668

6769
def _check_matmul(ctx: relax.transform.PatternCheckContext) -> bool:
6870
call = ctx.annotated_expr["matmul"]

0 commit comments

Comments
 (0)