Skip to content

Commit 9872c48

Browse files
authored
[Python] Extract common device str parse function in ChatModule (mlc-ai#1074)
This PR lifts the device string parsing (just a few of lines) to a standalone function, so that on the serving side the serving can make use of this function as well. Tested Python API and it does not seem to incur regression.
1 parent d202077 commit 9872c48

File tree

1 file changed

+33
-7
lines changed

1 file changed

+33
-7
lines changed

python/mlc_chat/chat_module.py

Lines changed: 33 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -548,6 +548,38 @@ def _convert_generation_config_to_json_str(generation_config: Optional[Generatio
548548
return json.dumps(asdict(generation_config))
549549

550550

551+
def _parse_device_str(device: str):
552+
"""Parse the input device identifier into device name and id.
553+
554+
Parameters
555+
----------
556+
device : str
557+
The device identifier to parse.
558+
It can be "device_name" (e.g., "cuda") or
559+
"device_name:device_id" (e.g., "cuda:1").
560+
561+
Returns
562+
-------
563+
device_name : str
564+
The name of the device.
565+
566+
device_id : int
567+
The id of the device, or 0 if not specified in the input.
568+
"""
569+
device_err_msg = (
570+
f"Invalid device name: {device}. Please enter the device in the form "
571+
"'device_name:device_id' or 'device_name', where 'device_name' needs to be "
572+
"one of 'cuda', 'metal', 'vulkan', 'rocm', 'opencl', 'auto'."
573+
)
574+
device_args = device.split(":")
575+
if len(device_args) == 1:
576+
return device_args[0], 0
577+
elif len(device_args) == 2:
578+
return device_args[0], int(device_args[1])
579+
elif len(device_args) > 2:
580+
raise ValueError(device_err_msg)
581+
582+
551583
def _detect_local_device(device_id: int = 0):
552584
"""Automatically detect the local device if user does not specify.
553585
@@ -647,13 +679,7 @@ def __init__(
647679
)
648680

649681
# 0. Retrieve device_name and device_id (if any, default 0) from device arg
650-
device_args = device.split(":")
651-
if len(device_args) == 1:
652-
device_name, device_id = device_args[0], 0
653-
elif len(device_args) == 2:
654-
device_name, device_id = device_args[0], int(device_args[1])
655-
elif len(device_args) > 2:
656-
raise ValueError(device_err_msg)
682+
device_name, device_id = _parse_device_str(device)
657683

658684
# 1. Get self.device
659685
if device_name == "cuda":

0 commit comments

Comments
 (0)