32
32
#include < vector>
33
33
34
34
#include " conversation.h"
35
+ #include " model_metadata.h"
35
36
#include " random.h"
36
37
#include " support.h"
37
38
#include " tokenizers.h"
@@ -161,13 +162,18 @@ struct FunctionTable {
161
162
static_cast <int >(relax_vm::AllocatorType::kPooled ), static_cast <int >(kDLCPU ), 0 ,
162
163
static_cast <int >(relax_vm::AllocatorType::kPooled ));
163
164
this ->mod_get_func = [this ](const std::string& name) -> PackedFunc {
164
- return this ->local_vm ->GetFunction (name, false );
165
+ PackedFunc func = this ->local_vm ->GetFunction (name, false );
166
+ if (func == nullptr ) {
167
+ LOG (WARNING) << " Cannot find function in VM: " << name;
168
+ }
169
+ return func;
165
170
};
166
171
this ->get_global_func = [](const std::string& name) -> PackedFunc {
167
172
const auto * f = tvm::runtime::Registry::Get (name);
168
173
CHECK (f != nullptr ) << " ValueError: Cannot find function " << name;
169
174
return *f;
170
175
};
176
+ this ->model_metadata_ = ModelMetadata::FromModule (this ->local_vm );
171
177
this ->_InitFunctions ();
172
178
}
173
179
}
@@ -188,10 +194,23 @@ struct FunctionTable {
188
194
const PackedFunc* fload_cache = tvm::runtime::Registry::Get (" vm.builtin.ndarray_cache.load" );
189
195
ICHECK (fload_cache) << " TVM runtime cannot find vm.builtin.ndarray_cache.load" ;
190
196
(*fload_cache)(model_path, static_cast <int32_t >(device.device_type ), device.device_id );
191
- const PackedFunc* fload_params =
192
- tvm::runtime::Registry::Get (" vm.builtin.param_array_from_cache" );
193
- ICHECK (fload_params) << " Cannot find env function vm.builtin.param_array_from_cache" ;
194
- Array<NDArray> params = (*fload_params)(" param" , -1 );
197
+ Array<NDArray> params;
198
+ if (this ->model_metadata_ .params .empty ()) {
199
+ constexpr const char * name_loader = " vm.builtin.param_array_from_cache" ;
200
+ const PackedFunc* fload_params = tvm::runtime::Registry::Get (name_loader);
201
+ ICHECK (fload_params) << " Cannot find env function: " << name_loader;
202
+ params = (*fload_params)(" param" , -1 );
203
+ } else {
204
+ constexpr const char * name_loader = " vm.builtin.param_array_from_cache_by_name" ;
205
+ const PackedFunc* fload_params = tvm::runtime::Registry::Get (name_loader);
206
+ ICHECK (fload_params) << " Cannot find env function: " << name_loader;
207
+ Array<String> param_names;
208
+ param_names.reserve (this ->model_metadata_ .params .size ());
209
+ for (const auto & param : this ->model_metadata_ .params ) {
210
+ param_names.push_back (param.name );
211
+ }
212
+ params = (*fload_params)(param_names);
213
+ }
195
214
// after we get params, it is safe to simply clear the cached version
196
215
// as these params are referenced by params_
197
216
const PackedFunc* fclear_ndarray_cache =
@@ -210,6 +229,9 @@ struct FunctionTable {
210
229
this ->softmax_func_ = mod_get_func (" softmax_with_temperature" );
211
230
this ->encoding_without_cache_func_ = mod_get_func (" encoding_without_cache" );
212
231
this ->create_kv_cache_func_ = mod_get_func (" create_kv_cache" );
232
+ if (this ->create_kv_cache_func_ == nullptr ) {
233
+ this ->create_kv_cache_func_ = mod_get_func (" _initialize_effect" );
234
+ }
213
235
this ->reset_kv_cache_func_ = mod_get_func (" reset_kv_cache" );
214
236
if (this ->reset_kv_cache_func_ == nullptr ) {
215
237
this ->reset_kv_cache_func_ = get_global_func (" vm.builtin.attention_kv_cache_array_clear" );
@@ -260,6 +282,7 @@ struct FunctionTable {
260
282
PackedFunc reset_kv_cache_func_;
261
283
bool support_backtracking_kv_;
262
284
PackedFunc fkvcache_array_popn_;
285
+ ModelMetadata model_metadata_;
263
286
};
264
287
265
288
} // namespace
@@ -437,6 +460,7 @@ class LLMChat {
437
460
* \note This function overrides existing configurations.
438
461
*/
439
462
void LoadJSONOverride (const std::string& config_str, bool partial_update = false ) {
463
+ LOG (INFO) << " config_str = " << config_str;
440
464
picojson::value config_json;
441
465
std::string err = picojson::parse (config_json, config_str);
442
466
if (!err.empty ()) {
0 commit comments