@@ -532,14 +532,16 @@ def _get_lib_module_path(
532
532
raise FileNotFoundError (err_msg )
533
533
534
534
535
- def _convert_chat_config_to_json_str (chat_config : Optional [ChatConfig ], conv_template : str ) -> str :
535
+ def _convert_chat_config_to_json_str (
536
+ chat_config : Optional [ChatConfig ], conv_template : Optional [str ]
537
+ ) -> str :
536
538
"""Convert user's input ChatConfig to a json string, omitting ``None`` fields.
537
539
538
540
Parameters
539
541
----------
540
542
chat_config : Optional[ChatConfig]
541
543
User's input. A partial ChatConfig for overriding ``mlc-chat-config.json``.
542
- conv_template : str
544
+ conv_template : Optional[ str]
543
545
The ``conv_template`` that will be used after considering potential override.
544
546
545
547
Returns
@@ -591,7 +593,7 @@ def _convert_generation_config_to_json_str(generation_config: Optional[Generatio
591
593
return json .dumps (asdict (generation_config ))
592
594
593
595
594
- def _parse_device_str (device : str ):
596
+ def _parse_device_str (device : str ) -> ( tvm . runtime . Device , str ) :
595
597
"""Parse the input device identifier into device name and id.
596
598
597
599
Parameters
@@ -603,11 +605,11 @@ def _parse_device_str(device: str):
603
605
604
606
Returns
605
607
-------
608
+ dev : tvm.runtime.Device
609
+ The device.
610
+
606
611
device_name : str
607
612
The name of the device.
608
-
609
- device_id : int
610
- The id of the device, or 0 if not specified in the input.
611
613
"""
612
614
device_err_msg = (
613
615
f"Invalid device name: { device } . Please enter the device in the form "
@@ -616,14 +618,32 @@ def _parse_device_str(device: str):
616
618
)
617
619
device_args = device .split (":" )
618
620
if len (device_args ) == 1 :
619
- return device_args [0 ], 0
621
+ device_name , device_id = device_args [0 ], 0
620
622
elif len (device_args ) == 2 :
621
- return device_args [0 ], int (device_args [1 ])
623
+ device_name , device_id = device_args [0 ], int (device_args [1 ])
622
624
elif len (device_args ) > 2 :
623
625
raise ValueError (device_err_msg )
624
626
627
+ if device_name == "cuda" :
628
+ device = tvm .cuda (device_id )
629
+ elif device_name == "metal" :
630
+ device = tvm .metal (device_id )
631
+ elif device_name == "vulkan" :
632
+ device = tvm .vulkan (device_id )
633
+ elif device_name == "rocm" :
634
+ device = tvm .rocm (device_id )
635
+ elif device_name == "opencl" :
636
+ device = tvm .opencl (device_id )
637
+ elif device_name == "auto" :
638
+ device , device_name = _detect_local_device (device_id )
639
+ logging .info (f"System automatically detected device: { device_name } " )
640
+ else :
641
+ raise ValueError (device_err_msg )
642
+
643
+ return device , device_name
625
644
626
- def _detect_local_device (device_id : int = 0 ):
645
+
646
+ def _detect_local_device (device_id : int = 0 ) -> (tvm .runtime .Device , str ):
627
647
"""Automatically detect the local device if user does not specify.
628
648
629
649
Parameters
@@ -633,8 +653,11 @@ def _detect_local_device(device_id: int = 0):
633
653
634
654
Returns
635
655
------
636
- dev : Device
656
+ dev : tvm.runtime. Device
637
657
The local device.
658
+
659
+ device_name : str
660
+ The name of the device.
638
661
"""
639
662
if tvm .metal ().exist :
640
663
return tvm .metal (device_id ), "metal"
@@ -715,34 +738,13 @@ def __init__(
715
738
chat_config : Optional [ChatConfig ] = None ,
716
739
model_lib_path : Optional [str ] = None ,
717
740
):
718
- device_err_msg = (
719
- f"Invalid device name: { device } . Please enter the device in the form "
720
- "'device_name:device_id' or 'device_name', where 'device_name' needs to be "
721
- "one of 'cuda', 'metal', 'vulkan', 'rocm', 'opencl', 'auto'."
722
- )
723
-
724
- # 0. Retrieve device_name and device_id (if any, default 0) from device arg
725
- device_name , device_id = _parse_device_str (device )
726
-
727
- # 1. Get self.device
728
- if device_name == "cuda" :
729
- self .device = tvm .cuda (device_id )
730
- elif device_name == "metal" :
731
- self .device = tvm .metal (device_id )
732
- elif device_name == "vulkan" :
733
- self .device = tvm .vulkan (device_id )
734
- elif device_name == "rocm" :
735
- self .device = tvm .rocm (device_id )
736
- elif device_name == "opencl" :
737
- self .device = tvm .opencl (device_id )
738
- elif device_name == "auto" :
739
- self .device , device_name = _detect_local_device (device_id )
740
- logging .info (f"System automatically detected device: { device_name } " )
741
- else :
742
- raise ValueError (device_err_msg )
741
+ # 0. Get device:
742
+ # Retrieve device_name and device_id (if any, default 0) from device arg
743
+ self .device , device_name = _parse_device_str (device )
743
744
device_type = self .device .device_type
745
+ device_id = self .device .device_id
744
746
745
- # 2 . Populate chat module and their functions
747
+ # 1 . Populate chat module and their functions
746
748
fcreate_chat_mod = tvm .get_global_func ("mlc.llm_chat_create" )
747
749
assert fcreate_chat_mod is not None
748
750
chat_mod = fcreate_chat_mod (device_type , device_id )
@@ -768,13 +770,13 @@ def __init__(
768
770
self ._get_role0_func = chat_mod ["get_role0" ]
769
771
self ._get_role1_func = chat_mod ["get_role1" ]
770
772
771
- # 3 . Look up model_path
773
+ # 2 . Look up model_path
772
774
self .model_path , self .config_file_path = _get_model_path (model )
773
775
774
- # 4 . Instantiate chat_config
776
+ # 3 . Instantiate chat_config
775
777
self .chat_config = _get_chat_config (self .config_file_path , chat_config )
776
778
777
- # 5 . Look up model library
779
+ # 4 . Look up model library
778
780
self .model_lib_path = _get_lib_module_path (
779
781
model ,
780
782
self .model_path ,
@@ -784,7 +786,7 @@ def __init__(
784
786
self .config_file_path ,
785
787
)
786
788
787
- # 6 . Call reload
789
+ # 5 . Call reload
788
790
user_chat_config_json_str = _convert_chat_config_to_json_str (
789
791
self .chat_config , self .chat_config .conv_template
790
792
)
0 commit comments