|
1 | 1 | """Parameter mapping for converting different LLM implementations to MLC LLM."""
|
2 | 2 | import dataclasses
|
3 |
| -from typing import Callable, Dict, List, Set |
| 3 | +from typing import Callable, Dict, List, Set, Union |
4 | 4 |
|
5 | 5 | import numpy as np
|
6 | 6 | from tvm.runtime import NDArray
|
7 | 7 |
|
| 8 | +MapFuncVariadic = Union[ |
| 9 | + Callable[[], np.ndarray], |
| 10 | + Callable[[np.ndarray], np.ndarray], |
| 11 | + Callable[[np.ndarray, np.ndarray], np.ndarray], |
| 12 | + Callable[[np.ndarray, np.ndarray, np.ndarray], np.ndarray], |
| 13 | + Callable[[np.ndarray, np.ndarray, np.ndarray, np.ndarray], np.ndarray], |
| 14 | +] |
| 15 | + |
8 | 16 |
|
9 | 17 | @dataclasses.dataclass
|
10 | 18 | class ExternMapping:
|
@@ -33,8 +41,8 @@ class ExternMapping:
|
33 | 41 | """
|
34 | 42 |
|
35 | 43 | param_map: Dict[str, List[str]]
|
36 |
| - map_func: Dict[str, Callable[[np.ndarray, ...], np.ndarray]] |
37 |
| - unused_params: Set[str] = dataclasses.field(default_factory=dict) |
| 44 | + map_func: Dict[str, MapFuncVariadic] |
| 45 | + unused_params: Set[str] = dataclasses.field(default_factory=set) |
38 | 46 |
|
39 | 47 |
|
40 | 48 | @dataclasses.dataclass
|
@@ -72,8 +80,8 @@ class QuantizeMapping:
|
72 | 80 | used to convert the quantized parameters into the desired form.
|
73 | 81 | """
|
74 | 82 |
|
75 |
| - param_map: Dict[str, Callable[str, List[str]]] |
76 |
| - map_func: Dict[str, Callable[NDArray, List[NDArray]]] |
| 83 | + param_map: Dict[str, Callable[[str], List[str]]] |
| 84 | + map_func: Dict[str, Callable[[NDArray], List[NDArray]]] |
77 | 85 |
|
78 | 86 |
|
79 | 87 | __all__ = ["ExternMapping", "QuantizeMapping"]
|
0 commit comments