3
3
import dataclasses
4
4
import unittest
5
5
6
- from mlc_llm import BuildArgs , utils , core
6
+ from mlc_llm import BuildArgs , core , utils
7
+
7
8
8
9
def old_make_args ():
9
10
"""The exact old way of creating `ArgumentParser`, used to test whether
10
- `BuildArgs` is equivalent to this. """
11
+ `BuildArgs` is equivalent to this."""
11
12
args = argparse .ArgumentParser ()
12
13
args .add_argument (
13
14
"--model" ,
@@ -17,7 +18,7 @@ def old_make_args():
17
18
'The name of the model to build. If it is "auto", we will '
18
19
'automatically set the model name according to "--model-path", '
19
20
'"hf-path" or the model folders under "--artifact-path/models"'
20
- )
21
+ ),
21
22
)
22
23
args .add_argument (
23
24
"--hf-path" ,
@@ -30,19 +31,16 @@ def old_make_args():
30
31
type = str ,
31
32
choices = [* utils .quantization_schemes .keys ()],
32
33
default = list (utils .quantization_schemes .keys ())[0 ],
33
- help = "The quantization mode we use to compile."
34
+ help = "The quantization mode we use to compile." ,
34
35
)
35
36
args .add_argument (
36
37
"--max-seq-len" ,
37
38
type = int ,
38
39
default = - 1 ,
39
- help = "The maximum allowed sequence length for the model."
40
+ help = "The maximum allowed sequence length for the model." ,
40
41
)
41
42
args .add_argument (
42
- "--target" ,
43
- type = str ,
44
- default = "auto" ,
45
- help = "The target platform to compile the model for."
43
+ "--target" , type = str , default = "auto" , help = "The target platform to compile the model for."
46
44
)
47
45
args .add_argument (
48
46
"--reuse-lib" ,
@@ -51,10 +49,7 @@ def old_make_args():
51
49
help = "Whether to reuse a previously generated lib." ,
52
50
)
53
51
args .add_argument (
54
- "--artifact-path" ,
55
- type = str ,
56
- default = "dist" ,
57
- help = "Where to store the output."
52
+ "--artifact-path" , type = str , default = "dist" , help = "Where to store the output."
58
53
)
59
54
args .add_argument (
60
55
"--use-cache" ,
@@ -66,13 +61,13 @@ def old_make_args():
66
61
"--debug-dump" ,
67
62
action = "store_true" ,
68
63
default = False ,
69
- help = "Whether to dump debugging files during compilation."
64
+ help = "Whether to dump debugging files during compilation." ,
70
65
)
71
66
args .add_argument (
72
67
"--debug-load-script" ,
73
68
action = "store_true" ,
74
69
default = False ,
75
- help = "Whether to load the script for debugging."
70
+ help = "Whether to load the script for debugging." ,
76
71
)
77
72
args .add_argument (
78
73
"--llvm-mingw" ,
@@ -81,10 +76,7 @@ def old_make_args():
81
76
help = "/path/to/llvm-mingw-root, use llvm-mingw to cross compile to windows." ,
82
77
)
83
78
args .add_argument (
84
- "--system-lib" ,
85
- action = "store_true" ,
86
- default = False ,
87
- help = "A parameter to `relax.build`."
79
+ "--system-lib" , action = "store_true" , default = False , help = "A parameter to `relax.build`."
88
80
)
89
81
args .add_argument (
90
82
"--sep-embed" ,
@@ -99,17 +91,20 @@ def old_make_args():
99
91
100
92
return args
101
93
94
+
102
95
# Referred to HfArgumentParserTest from https://github.com/huggingface/
103
96
# transformers/blob/e84bf1f734f87aa2bedc41b9b9933d00fc6add98/tests/utils
104
97
# /test_hf_argparser.py#L143
105
98
class BuildArgsTest (unittest .TestCase ):
106
99
"""Tests whether BuildArgs reaches parity with regular ArgumentParser."""
107
- def argparsers_equal ( self , parse_a : argparse . ArgumentParser ,
108
- parse_b : argparse .ArgumentParser ):
100
+
101
+ def argparsers_equal ( self , parse_a : argparse . ArgumentParser , parse_b : argparse .ArgumentParser ):
109
102
"""
110
103
Small helper to check pseudo-equality of parsed arguments on `ArgumentParser` instances.
111
104
"""
112
- self .assertEqual (len (parse_a ._actions ), len (parse_b ._actions )) # pylint: disable=protected-access
105
+ self .assertEqual (
106
+ len (parse_a ._actions ), len (parse_b ._actions )
107
+ ) # pylint: disable=protected-access
113
108
for x , y in zip (parse_a ._actions , parse_b ._actions ): # pylint: disable=protected-access
114
109
xx = {k : v for k , v in vars (x ).items () if k != "container" }
115
110
yy = {k : v for k , v in vars (y ).items () if k != "container" }
@@ -175,5 +170,6 @@ def test_namespaces_are_equivalent_str_boolean_int(self):
175
170
build_args_namespace = argparse .Namespace (** build_args_as_dict )
176
171
self .assertNotEqual (build_args_namespace , parsed_args )
177
172
178
- if __name__ == '__main__' :
173
+
174
+ if __name__ == "__main__" :
179
175
unittest .main ()
0 commit comments