Skip to content

Commit 2ec0cc8

Browse files
authored
Minor enhancements to ChatModule (mlc-ai#1132)
Some minor enhancements to `ChatModule`, mainly handle the device parsing solely in `_parse_device_str` instead of handling it both in the member function and the `__init__` function to avoid redundancy; and some type annotation fix.
1 parent 2c492e5 commit 2ec0cc8

File tree

2 files changed

+43
-41
lines changed

2 files changed

+43
-41
lines changed

cpp/llm_chat.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -437,7 +437,7 @@ class LLMChat {
437437

438438
/*!
439439
* \brief Reload model, tokenizers and configurations from the specified model path.
440-
* \param executable The module to reload.
440+
* \param reload_lib The module to reload, it can either be a path to the library or a tvm Module.
441441
* \param model_path The path to search for models.
442442
* \param app_config_json The JSON string used to partially override the configuration loaded from
443443
* disk, default to empty string.

python/mlc_chat/chat_module.py

Lines changed: 42 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -532,14 +532,16 @@ def _get_lib_module_path(
532532
raise FileNotFoundError(err_msg)
533533

534534

535-
def _convert_chat_config_to_json_str(chat_config: Optional[ChatConfig], conv_template: str) -> str:
535+
def _convert_chat_config_to_json_str(
536+
chat_config: Optional[ChatConfig], conv_template: Optional[str]
537+
) -> str:
536538
"""Convert user's input ChatConfig to a json string, omitting ``None`` fields.
537539
538540
Parameters
539541
----------
540542
chat_config : Optional[ChatConfig]
541543
User's input. A partial ChatConfig for overriding ``mlc-chat-config.json``.
542-
conv_template : str
544+
conv_template : Optional[str]
543545
The ``conv_template`` that will be used after considering potential override.
544546
545547
Returns
@@ -591,7 +593,7 @@ def _convert_generation_config_to_json_str(generation_config: Optional[Generatio
591593
return json.dumps(asdict(generation_config))
592594

593595

594-
def _parse_device_str(device: str):
596+
def _parse_device_str(device: str) -> (tvm.runtime.Device, str):
595597
"""Parse the input device identifier into device name and id.
596598
597599
Parameters
@@ -603,11 +605,11 @@ def _parse_device_str(device: str):
603605
604606
Returns
605607
-------
608+
dev : tvm.runtime.Device
609+
The device.
610+
606611
device_name : str
607612
The name of the device.
608-
609-
device_id : int
610-
The id of the device, or 0 if not specified in the input.
611613
"""
612614
device_err_msg = (
613615
f"Invalid device name: {device}. Please enter the device in the form "
@@ -616,14 +618,32 @@ def _parse_device_str(device: str):
616618
)
617619
device_args = device.split(":")
618620
if len(device_args) == 1:
619-
return device_args[0], 0
621+
device_name, device_id = device_args[0], 0
620622
elif len(device_args) == 2:
621-
return device_args[0], int(device_args[1])
623+
device_name, device_id = device_args[0], int(device_args[1])
622624
elif len(device_args) > 2:
623625
raise ValueError(device_err_msg)
624626

627+
if device_name == "cuda":
628+
device = tvm.cuda(device_id)
629+
elif device_name == "metal":
630+
device = tvm.metal(device_id)
631+
elif device_name == "vulkan":
632+
device = tvm.vulkan(device_id)
633+
elif device_name == "rocm":
634+
device = tvm.rocm(device_id)
635+
elif device_name == "opencl":
636+
device = tvm.opencl(device_id)
637+
elif device_name == "auto":
638+
device, device_name = _detect_local_device(device_id)
639+
logging.info(f"System automatically detected device: {device_name}")
640+
else:
641+
raise ValueError(device_err_msg)
642+
643+
return device, device_name
625644

626-
def _detect_local_device(device_id: int = 0):
645+
646+
def _detect_local_device(device_id: int = 0) -> (tvm.runtime.Device, str):
627647
"""Automatically detect the local device if user does not specify.
628648
629649
Parameters
@@ -633,8 +653,11 @@ def _detect_local_device(device_id: int = 0):
633653
634654
Returns
635655
------
636-
dev : Device
656+
dev : tvm.runtime.Device
637657
The local device.
658+
659+
device_name : str
660+
The name of the device.
638661
"""
639662
if tvm.metal().exist:
640663
return tvm.metal(device_id), "metal"
@@ -715,34 +738,13 @@ def __init__(
715738
chat_config: Optional[ChatConfig] = None,
716739
model_lib_path: Optional[str] = None,
717740
):
718-
device_err_msg = (
719-
f"Invalid device name: {device}. Please enter the device in the form "
720-
"'device_name:device_id' or 'device_name', where 'device_name' needs to be "
721-
"one of 'cuda', 'metal', 'vulkan', 'rocm', 'opencl', 'auto'."
722-
)
723-
724-
# 0. Retrieve device_name and device_id (if any, default 0) from device arg
725-
device_name, device_id = _parse_device_str(device)
726-
727-
# 1. Get self.device
728-
if device_name == "cuda":
729-
self.device = tvm.cuda(device_id)
730-
elif device_name == "metal":
731-
self.device = tvm.metal(device_id)
732-
elif device_name == "vulkan":
733-
self.device = tvm.vulkan(device_id)
734-
elif device_name == "rocm":
735-
self.device = tvm.rocm(device_id)
736-
elif device_name == "opencl":
737-
self.device = tvm.opencl(device_id)
738-
elif device_name == "auto":
739-
self.device, device_name = _detect_local_device(device_id)
740-
logging.info(f"System automatically detected device: {device_name}")
741-
else:
742-
raise ValueError(device_err_msg)
741+
# 0. Get device:
742+
# Retrieve device_name and device_id (if any, default 0) from device arg
743+
self.device, device_name = _parse_device_str(device)
743744
device_type = self.device.device_type
745+
device_id = self.device.device_id
744746

745-
# 2. Populate chat module and their functions
747+
# 1. Populate chat module and their functions
746748
fcreate_chat_mod = tvm.get_global_func("mlc.llm_chat_create")
747749
assert fcreate_chat_mod is not None
748750
chat_mod = fcreate_chat_mod(device_type, device_id)
@@ -768,13 +770,13 @@ def __init__(
768770
self._get_role0_func = chat_mod["get_role0"]
769771
self._get_role1_func = chat_mod["get_role1"]
770772

771-
# 3. Look up model_path
773+
# 2. Look up model_path
772774
self.model_path, self.config_file_path = _get_model_path(model)
773775

774-
# 4. Instantiate chat_config
776+
# 3. Instantiate chat_config
775777
self.chat_config = _get_chat_config(self.config_file_path, chat_config)
776778

777-
# 5. Look up model library
779+
# 4. Look up model library
778780
self.model_lib_path = _get_lib_module_path(
779781
model,
780782
self.model_path,
@@ -784,7 +786,7 @@ def __init__(
784786
self.config_file_path,
785787
)
786788

787-
# 6. Call reload
789+
# 5. Call reload
788790
user_chat_config_json_str = _convert_chat_config_to_json_str(
789791
self.chat_config, self.chat_config.conv_template
790792
)

0 commit comments

Comments
 (0)