Skip to content

Commit b0373d1

Browse files
authored
Support lib_path override in C++. Improvements on docs and error messages (mlc-ai#1086)
* Support lib_path option in C++ CLI. Disable ChatConfig.model_lib override in Python API. Improvements on helper messages and error messages * Update docs * Rename lib_path -> model_lib_path
1 parent 56a8004 commit b0373d1

File tree

4 files changed

+89
-36
lines changed

4 files changed

+89
-36
lines changed

cpp/cli_main.cc

Lines changed: 44 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ struct ModelPaths {
163163
*/
164164
std::filesystem::path lib;
165165

166-
static ModelPaths Find(const std::string& device_name, const std::string& local_id);
166+
static ModelPaths Find(const std::string& device_name, const std::string& local_id, const std::string& user_lib_path);
167167
};
168168

169169
/*!
@@ -337,7 +337,7 @@ std::string ReadStringFromJSONFile(const std::filesystem::path& config_path,
337337
return config[key].get<std::string>();
338338
}
339339

340-
ModelPaths ModelPaths::Find(const std::string& device_name, const std::string& local_id) {
340+
ModelPaths ModelPaths::Find(const std::string& device_name, const std::string& local_id, const std::string &user_lib_path) {
341341
// Step 1. Find config path
342342
std::filesystem::path config_path;
343343
if (auto path = TryInferMLCChatConfig(local_id)) {
@@ -368,26 +368,36 @@ ModelPaths ModelPaths::Find(const std::string& device_name, const std::string& l
368368
}
369369
std::cout << "Use model weights: " << params_json << std::endl;
370370
// Step 3. Find model lib path
371-
std::string lib_local_id = ReadStringFromJSONFile(config_path, "model_lib");
372-
std::string lib_name = lib_local_id + "-" + device_name;
373371
std::filesystem::path lib_path;
374-
if (auto path = FindFile({lib_local_id,
372+
if (!user_lib_path.empty()) {
373+
lib_path = user_lib_path;
374+
if (!std::filesystem::exists(lib_path) || !std::filesystem::is_regular_file(lib_path)) {
375+
LOG(FATAL) << "The `lib_path` you passed in is not a file: " << user_lib_path << "\n";
376+
exit(1);
377+
}
378+
} else {
379+
std::string lib_local_id = ReadStringFromJSONFile(config_path, "model_lib");
380+
std::string lib_name = lib_local_id + "-" + device_name;
381+
if (auto path = FindFile({lib_local_id,
375382
"dist/prebuilt/lib", // Using prebuilt workflow
376383
"dist/" + local_id, "dist/prebuilt/" + lib_local_id},
377384
{
378385
lib_name + GetArchSuffix(),
379386
lib_name,
380387
},
381388
GetLibSuffixes())) {
382-
lib_path = path.value();
383-
} else {
384-
LOG(FATAL) << "Cannot find the model library that corresponds to `" << lib_local_id << "`.\n"
385-
<< "We searched over the following possible paths: \n"
386-
<< "- " + lib_local_id << "\n"
387-
<< "- dist/prebuilt/lib \n"
388-
<< "- dist/" + local_id << "\n"
389-
<< "- dist/prebuilt/" + lib_local_id;
390-
exit(1);
389+
lib_path = path.value();
390+
} else {
391+
LOG(FATAL) << "Cannot find the model library that corresponds to `" << lib_local_id << "`.\n"
392+
<< "We searched over the following possible paths: \n"
393+
<< "- " + lib_local_id << "\n"
394+
<< "- dist/prebuilt/lib \n"
395+
<< "- dist/" + local_id << "\n"
396+
<< "- dist/prebuilt/" + lib_local_id << "\n"
397+
<< "If you would like to directly specify the full model library path, you may "
398+
<< "consider passing in the `--model-lib-path` argument.\n";
399+
exit(1);
400+
}
391401
}
392402
std::cout << "Use model library: " << lib_path << std::endl;
393403
return ModelPaths{config_path, params_json, lib_path};
@@ -427,8 +437,8 @@ void Converse(ChatModule* chat, const std::string& input, int stream_interval,
427437
* \param stream_interval The interval that should be used for streaming the response.
428438
*/
429439
void Chat(ChatModule* chat, const std::string& device_name, std::string local_id,
430-
int stream_interval = 2) {
431-
ModelPaths model = ModelPaths::Find(device_name, local_id);
440+
std::string lib_path, int stream_interval = 2) {
441+
ModelPaths model = ModelPaths::Find(device_name, local_id, lib_path);
432442
PrintSpecialCommands();
433443
chat->Reload(model);
434444
chat->ProcessSystemPrompts();
@@ -456,7 +466,7 @@ void Chat(ChatModule* chat, const std::string& device_name, std::string local_id
456466
if (new_local_id.empty()) {
457467
new_local_id = local_id;
458468
}
459-
model = ModelPaths::Find(device_name, new_local_id);
469+
model = ModelPaths::Find(device_name, new_local_id, lib_path);
460470
chat->Reload(model);
461471
local_id = new_local_id;
462472
} else if (input.substr(0, 5) == "/help") {
@@ -470,7 +480,17 @@ void Chat(ChatModule* chat, const std::string& device_name, std::string local_id
470480
int main(int argc, char* argv[]) {
471481
argparse::ArgumentParser args("mlc_chat");
472482

473-
args.add_argument("--model");
483+
args.add_description("MLCChat CLI is the command line tool to run MLC-compiled LLMs out of the box.\n"
484+
"Note: the --model argument is required. It can either be the model name with its "
485+
"quantization scheme or a full path to the model folder. In the former case, the "
486+
"provided name will be used to search for the model folder over possible paths. "
487+
"--model-lib-path argument is optional. If unspecified, the --model argument will be used "
488+
"to search for the library file over possible paths.");
489+
490+
args.add_argument("--model")
491+
.help("[required] the model to use");
492+
args.add_argument("--model-lib-path")
493+
.help("[optional] the full path to the model library file to use");
474494
args.add_argument("--device").default_value("auto");
475495
args.add_argument("--evaluate").default_value(false).implicit_value(true);
476496
args.add_argument("--eval-prompt-len").default_value(128).scan<'i', int>();
@@ -485,6 +505,10 @@ int main(int argc, char* argv[]) {
485505
}
486506

487507
std::string local_id = args.get<std::string>("--model");
508+
std::string lib_path;
509+
if (args.present("--model-lib-path")) {
510+
lib_path = args.get<std::string>("--model-lib-path");
511+
}
488512
auto [device_name, device_id] = DetectDevice(args.get<std::string>("--device"));
489513

490514
try {
@@ -494,14 +518,14 @@ int main(int argc, char* argv[]) {
494518
// that are not supposed to be used in chat app setting
495519
int prompt_len = args.get<int>("--eval-prompt-len");
496520
int gen_len = args.get<int>("--eval-gen-len");
497-
ModelPaths model = ModelPaths::Find(device_name, local_id);
521+
ModelPaths model = ModelPaths::Find(device_name, local_id, lib_path);
498522
tvm::runtime::Module chat_mod = mlc::llm::CreateChatModule(GetDevice(device_name, device_id));
499523
std::string model_path = model.config.parent_path().string();
500524
tvm::runtime::Module lib = tvm::runtime::Module::LoadFromFile(model.lib.string());
501525
chat_mod.GetFunction("reload")(lib, tvm::String(model_path));
502526
chat_mod.GetFunction("evaluate")(prompt_len, gen_len);
503527
} else {
504-
Chat(&chat, device_name, local_id);
528+
Chat(&chat, device_name, local_id, lib_path);
505529
}
506530
} catch (const std::runtime_error& err) {
507531
std::cerr << err.what() << std::endl;

docs/deploy/cli.rst

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,12 @@ Once ``mlc_chat_cli`` is installed, you are able to run any MLC-compiled model o
111111
- Model lib should be placed at ``./dist/prebuilt/lib/$(local_id)-$(arch).$(suffix)``.
112112
- Model weights and chat config are located under ``./dist/prebuilt/mlc-chat-$(local_id)/``.
113113
114+
.. note::
115+
Please make sure that you have the same directory structure as above, because the CLI tool
116+
relies on it to automatically search for model lib and weights. If you would like to directly
117+
provide a full model lib path to override the auto-search, you can pass in a ``--model-lib-path`` argument
118+
to the CLI.
119+
114120
.. collapse:: Example
115121
116122
.. code:: shell
@@ -134,6 +140,12 @@ Once ``mlc_chat_cli`` is installed, you are able to run any MLC-compiled model o
134140
- Model libraries should be placed at ``./dist/$(local_id)/$(local_id)-$(arch).$(suffix)``.
135141
- Model weights and chat config are located under ``./dist/$(local_id)/params/``.
136142
143+
.. note::
144+
Please make sure that you have the same directory structure as above, because the CLI tool
145+
relies on it to automatically search for model lib and weights. If you would like to directly
146+
provide a full model lib path to override the auto-search, you can pass in a ``--model-lib-path`` argument
147+
to the CLI.
148+
137149
.. collapse:: Example
138150
139151
.. code:: shell

docs/deploy/python.rst

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,11 @@ If you do not have the MLC-compiled ``model`` ready:
5151
- Model lib should be placed at ``./dist/prebuilt/lib/$(model)-$(arch).$(suffix)``.
5252
- Model weights and chat config are located under ``./dist/prebuilt/mlc-chat-$(model)/``.
5353
54+
.. note::
55+
Please make sure that you have the same directory structure as above, because Python API
56+
relies on it to automatically search for model lib and weights. If you would like to directly
57+
provide a full model lib path to override the auto-search, you can specify ``ChatModule.model_lib_path``.
58+
5459
.. collapse:: Example
5560
5661
.. code:: shell
@@ -74,6 +79,11 @@ If you do not have the MLC-compiled ``model`` ready:
7479
- Model libraries should be placed at ``./dist/$(model)/$(model)-$(arch).$(suffix)``.
7580
- Model weights and chat config are located under ``./dist/$(model)/params/``.
7681
82+
.. note::
83+
Please make sure that you have the same directory structure as above, because Python API
84+
relies on it to automatically search for model lib and weights. If you would like to directly
85+
provide a full model lib path to override the auto-search, you can specify ``ChatModule.model_lib_path``.
86+
7787
.. collapse:: Example
7888
7989
.. code:: shell
@@ -157,7 +167,7 @@ You can also checkout the :doc:`/prebuilt_models` page to run other models.
157167
|
158168
159169
.. note::
160-
You could also specify the address of ``model`` and ``lib_path`` explicitly. If
170+
You could also specify the address of ``model`` and ``model_lib_path`` explicitly. If
161171
you only specify ``model`` as ``model_name`` and ``quantize_mode``, we will
162172
do a search for you. See more in the documentation of :meth:`mlc_chat.ChatModule.__init__`.
163173

python/mlc_chat/chat_module.py

Lines changed: 22 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import logging
66
import os
77
import sys
8+
import warnings
89
from dataclasses import asdict, dataclass, fields
910
from enum import Enum
1011
from typing import List, Optional
@@ -351,6 +352,12 @@ def _get_chat_config(config_file_path: str, user_chat_config: Optional[ChatConfi
351352
# We override using user's chat config
352353
for field in fields(user_chat_config):
353354
field_name = field.name
355+
if field_name == 'model_lib':
356+
warn_msg = ('WARNING: Do not override "model_lib" in ChatConfig. '
357+
'This override will be ignored. '
358+
'Please use ChatModule.model_lib_path to override the full model library path instead.')
359+
warnings.warn(warn_msg)
360+
continue
354361
field_value = getattr(user_chat_config, field_name)
355362
if field_value is not None:
356363
setattr(final_chat_config, field_name, field_value)
@@ -389,7 +396,7 @@ def _get_lib_module_path(
389396
model: str,
390397
model_path: str,
391398
chat_config: ChatConfig,
392-
lib_path: Optional[str],
399+
model_lib_path: Optional[str],
393400
device_name: str,
394401
config_file_path: str,
395402
) -> str:
@@ -403,7 +410,7 @@ def _get_lib_module_path(
403410
Model path found by `_get_model_path`.
404411
chat_config : ChatConfig
405412
Chat config after potential overrides. Returned by ``_get_chat_config``.
406-
lib_path : Optional[str]
413+
model_lib_path : Optional[str]
407414
User's input. Supposedly a full path to model library. Prioritized to use.
408415
device_name : str
409416
User's input. Used to construct the library model file name.
@@ -412,21 +419,21 @@ def _get_lib_module_path(
412419
413420
Returns
414421
------
415-
lib_path : str
422+
model_lib_path : str
416423
The path pointing to the model library we find.
417424
418425
Raises
419426
------
420427
FileNotFoundError: if we cannot find a valid model library file.
421428
"""
422-
# 1. Use user's lib_path if provided
423-
if lib_path is not None:
424-
if os.path.isfile(lib_path):
425-
logging.info(f"Using library model: {lib_path}")
426-
return lib_path
429+
# 1. Use user's model_lib_path if provided
430+
if model_lib_path is not None:
431+
if os.path.isfile(model_lib_path):
432+
logging.info(f"Using library model: {model_lib_path}")
433+
return model_lib_path
427434
else:
428435
err_msg = (
429-
f"The `lib_path` you passed in is not a file: {lib_path}.\nPlease checkout "
436+
f"The `model_lib_path` you passed in is not a file: {model_lib_path}.\nPlease checkout "
430437
f"{_PYTHON_GET_STARTED_TUTORIAL_URL} for an example on how to load a model."
431438
)
432439
raise FileNotFoundError(err_msg)
@@ -482,7 +489,7 @@ def _get_lib_module_path(
482489
err_msg += f"- {candidate}\n"
483490
err_msg += (
484491
"If you would like to directly specify the model library path, you may "
485-
"consider passing in the `lib_path` parameter.\n"
492+
"consider passing in the `ChatModule.model_lib_path` parameter.\n"
486493
f"Please checkout {_PYTHON_GET_STARTED_TUTORIAL_URL} for an example "
487494
"on how to load a model."
488495
)
@@ -659,7 +666,7 @@ class ChatModule:
659666
A ``ChatConfig`` instance partially filled. Will be used to override the
660667
``mlc-chat-config.json``.
661668
662-
lib_path : Optional[str]
669+
model_lib_path : Optional[str]
663670
The full path to the model library file to use (e.g. a ``.so`` file).
664671
If unspecified, we will use the provided ``model`` to search over
665672
possible paths.
@@ -670,7 +677,7 @@ def __init__(
670677
model: str,
671678
device: str = "auto",
672679
chat_config: Optional[ChatConfig] = None,
673-
lib_path: Optional[str] = None,
680+
model_lib_path: Optional[str] = None,
674681
):
675682
device_err_msg = (
676683
f"Invalid device name: {device}. Please enter the device in the form "
@@ -732,15 +739,15 @@ def __init__(
732739
self.chat_config = _get_chat_config(self.config_file_path, chat_config)
733740

734741
# 5. Look up model library
735-
self.lib_path = _get_lib_module_path(
736-
model, self.model_path, self.chat_config, lib_path, device_name, self.config_file_path
742+
self.model_lib_path = _get_lib_module_path(
743+
model, self.model_path, self.chat_config, model_lib_path, device_name, self.config_file_path
737744
)
738745

739746
# 6. Call reload
740747
user_chat_config_json_str = _convert_chat_config_to_json_str(
741748
self.chat_config, self.chat_config.conv_template
742749
)
743-
self._reload(self.lib_path, self.model_path, user_chat_config_json_str)
750+
self._reload(self.model_lib_path, self.model_path, user_chat_config_json_str)
744751

745752
def generate(
746753
self,

0 commit comments

Comments
 (0)