@@ -163,7 +163,7 @@ struct ModelPaths {
    */
   std::filesystem::path lib;
 
-  static ModelPaths Find(const std::string& device_name, const std::string& local_id);
+  static ModelPaths Find(const std::string& device_name, const std::string& local_id, const std::string& user_lib_path);
 };
 
 /*!
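
The new `user_lib_path` parameter is added without a default value, so every existing caller of `ModelPaths::Find` has to be updated in this patch. As the implementation below shows, an empty string preserves the old search behavior, e.g. (illustrative call site, not part of the diff):

    ModelPaths model = ModelPaths::Find(device_name, local_id, /*user_lib_path=*/"");
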
@@ -337,7 +337,7 @@ std::string ReadStringFromJSONFile(const std::filesystem::path& config_path,
   return config[key].get<std::string>();
 }
 
-ModelPaths ModelPaths::Find(const std::string& device_name, const std::string& local_id) {
+ModelPaths ModelPaths::Find(const std::string& device_name, const std::string& local_id, const std::string& user_lib_path) {
   // Step 1. Find config path
   std::filesystem::path config_path;
   if (auto path = TryInferMLCChatConfig(local_id)) {
@@ -368,26 +368,36 @@ ModelPaths ModelPaths::Find(const std::string& device_name, const std::string& l
   }
   std::cout << "Use model weights: " << params_json << std::endl;
   // Step 3. Find model lib path
-  std::string lib_local_id = ReadStringFromJSONFile(config_path, "model_lib");
-  std::string lib_name = lib_local_id + "-" + device_name;
   std::filesystem::path lib_path;
-  if (auto path = FindFile({lib_local_id,
+  if (!user_lib_path.empty()) {
+    lib_path = user_lib_path;
+    if (!std::filesystem::exists(lib_path) || !std::filesystem::is_regular_file(lib_path)) {
+      LOG(FATAL) << "The `lib_path` you passed in is not a file: " << user_lib_path << "\n";
+      exit(1);
+    }
+  } else {
+    std::string lib_local_id = ReadStringFromJSONFile(config_path, "model_lib");
+    std::string lib_name = lib_local_id + "-" + device_name;
+    if (auto path = FindFile({lib_local_id,
                             "dist/prebuilt/lib",  // Using prebuilt workflow
                             "dist/" + local_id, "dist/prebuilt/" + lib_local_id},
                            {
                                lib_name + GetArchSuffix(),
                                lib_name,
                            },
                            GetLibSuffixes())) {
-    lib_path = path.value();
-  } else {
-    LOG(FATAL) << "Cannot find the model library that corresponds to `" << lib_local_id << "`.\n"
-               << "We searched over the following possible paths: \n"
-               << "- " + lib_local_id << "\n"
-               << "- dist/prebuilt/lib \n"
-               << "- dist/" + local_id << "\n"
-               << "- dist/prebuilt/" + lib_local_id;
-    exit(1);
+      lib_path = path.value();
+    } else {
+      LOG(FATAL) << "Cannot find the model library that corresponds to `" << lib_local_id << "`.\n"
+                 << "We searched over the following possible paths: \n"
+                 << "- " + lib_local_id << "\n"
+                 << "- dist/prebuilt/lib \n"
+                 << "- dist/" + local_id << "\n"
+                 << "- dist/prebuilt/" + lib_local_id << "\n"
+                 << "If you would like to directly specify the full model library path, you may "
+                 << "consider passing in the `--model-lib-path` argument.\n";
+      exit(1);
+    }
   }
   std::cout << "Use model library: " << lib_path << std::endl;
   return ModelPaths{config_path, params_json, lib_path};
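
The resolution order this hunk introduces: a non-empty `user_lib_path` is validated and used verbatim; only when it is empty does the CLI derive `model_lib` from the config file and search the usual `dist/` locations. A condensed, self-contained sketch of that control flow (`ResolveModelLib` and `searched` are illustrative names, not part of the patch):

    #include <cstdlib>
    #include <filesystem>
    #include <iostream>
    #include <optional>
    #include <string>

    // `searched` stands in for the result of the FindFile-based lookup above.
    std::filesystem::path ResolveModelLib(const std::string& user_lib_path,
                                          std::optional<std::filesystem::path> searched) {
      if (!user_lib_path.empty()) {
        std::filesystem::path lib_path = user_lib_path;
        // A user-supplied path must point at an existing regular file; fail fast otherwise.
        if (!std::filesystem::exists(lib_path) || !std::filesystem::is_regular_file(lib_path)) {
          std::cerr << "The `lib_path` you passed in is not a file: " << user_lib_path << "\n";
          std::exit(1);
        }
        return lib_path;
      }
      // No override given: fall back to the search over the well-known locations.
      if (!searched) {
        std::cerr << "Cannot find the model library.\n";
        std::exit(1);
      }
      return *searched;
    }
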
@@ -427,8 +437,8 @@ void Converse(ChatModule* chat, const std::string& input, int stream_interval,
  * \param stream_interval The interval that should be used for streaming the response.
  */
 void Chat(ChatModule* chat, const std::string& device_name, std::string local_id,
-          int stream_interval = 2) {
-  ModelPaths model = ModelPaths::Find(device_name, local_id);
+          std::string lib_path, int stream_interval = 2) {
+  ModelPaths model = ModelPaths::Find(device_name, local_id, lib_path);
   PrintSpecialCommands();
   chat->Reload(model);
   chat->ProcessSystemPrompts();
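
With `lib_path` placed before the defaulted `stream_interval`, call sites now spell the library path explicitly (the second line is a hypothetical override, shown only to illustrate the defaulted trailing parameter):

    Chat(&chat, device_name, local_id, lib_path);                         // stream_interval = 2
    Chat(&chat, device_name, local_id, lib_path, /*stream_interval=*/4);  // hypothetical
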
@@ -456,7 +466,7 @@ void Chat(ChatModule* chat, const std::string& device_name, std::string local_id
       if (new_local_id.empty()) {
         new_local_id = local_id;
       }
-      model = ModelPaths::Find(device_name, new_local_id);
+      model = ModelPaths::Find(device_name, new_local_id, lib_path);
       chat->Reload(model);
       local_id = new_local_id;
     } else if (input.substr(0, 5) == "/help") {
@@ -470,7 +480,17 @@ void Chat(ChatModule* chat, const std::string& device_name, std::string local_id
 int main(int argc, char* argv[]) {
   argparse::ArgumentParser args("mlc_chat");
 
-  args.add_argument("--model");
+  args.add_description("MLCChat CLI is the command line tool to run MLC-compiled LLMs out of the box.\n"
+                       "Note: the --model argument is required. It can either be the model name with its "
+                       "quantization scheme or a full path to the model folder. In the former case, the "
+                       "provided name will be used to search for the model folder over possible paths. "
+                       "--model-lib-path argument is optional. If unspecified, the --model argument will be used "
+                       "to search for the library file over possible paths.");
+
+  args.add_argument("--model")
+      .help("[required] the model to use");
+  args.add_argument("--model-lib-path")
+      .help("[optional] the full path to the model library file to use");
   args.add_argument("--device").default_value("auto");
   args.add_argument("--evaluate").default_value(false).implicit_value(true);
   args.add_argument("--eval-prompt-len").default_value(128).scan<'i', int>();
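
Taken together, a typical invocation after this change could look like the following, where the binary name, model name, and library path are purely illustrative:

    mlc_chat --model Llama-2-7b-chat-hf-q4f16_1 --model-lib-path dist/prebuilt/lib/Llama-2-7b-chat-hf-q4f16_1-cuda.so
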
@@ -485,6 +505,10 @@ int main(int argc, char* argv[]) {
   }
 
   std::string local_id = args.get<std::string>("--model");
+  std::string lib_path;
+  if (args.present("--model-lib-path")) {
+    lib_path = args.get<std::string>("--model-lib-path");
+  }
   auto [device_name, device_id] = DetectDevice(args.get<std::string>("--device"));
 
   try {
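
An aside on the argparse pattern above: in p-ranav/argparse, `present` returns a `std::optional` (defaulting to `std::optional<std::string>`), which is why it can be used directly as the `if` condition. Assuming that library, the check-then-get could equally be collapsed into one line:

    std::string lib_path = args.present("--model-lib-path").value_or("");
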
@@ -494,14 +518,14 @@ int main(int argc, char* argv[]) {
       // that are not supposed to be used in chat app setting
       int prompt_len = args.get<int>("--eval-prompt-len");
       int gen_len = args.get<int>("--eval-gen-len");
-      ModelPaths model = ModelPaths::Find(device_name, local_id);
+      ModelPaths model = ModelPaths::Find(device_name, local_id, lib_path);
       tvm::runtime::Module chat_mod = mlc::llm::CreateChatModule(GetDevice(device_name, device_id));
       std::string model_path = model.config.parent_path().string();
       tvm::runtime::Module lib = tvm::runtime::Module::LoadFromFile(model.lib.string());
       chat_mod.GetFunction("reload")(lib, tvm::String(model_path));
       chat_mod.GetFunction("evaluate")(prompt_len, gen_len);
     } else {
-      Chat(&chat, device_name, local_id);
+      Chat(&chat, device_name, local_id, lib_path);
     }
   } catch (const std::runtime_error& err) {
     std::cerr << err.what() << std::endl;