Skip to content

Commit e9b85ce

Browse files
authored
More formatting (mlc-ai#1099)
1 parent cf39bf6 commit e9b85ce

File tree

6 files changed

+92
-99
lines changed

6 files changed

+92
-99
lines changed

tests/debug/compare_lib.py renamed to tests/python/legacy/compare_lib.py

Lines changed: 10 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,14 @@
1-
from typing import List
2-
31
import argparse
4-
import os
52
import json
3+
import os
4+
from typing import List
65

7-
import tvm
8-
from tvm import relax
9-
from tvm import rpc
10-
from tvm.relax.testing.lib_comparator import LibCompareVMInstrument
116
import numpy as np
12-
137
import torch
8+
import tvm
149
from transformers import AutoTokenizer, LlamaTokenizer
10+
from tvm import relax, rpc
11+
from tvm.relax.testing.lib_comparator import LibCompareVMInstrument
1512

1613
from mlc_llm import utils
1714

@@ -53,7 +50,7 @@ def compare(
5350

5451
if self.time_eval and name not in self.time_eval_results:
5552
res = self.mod.time_evaluator(
56-
name, self.device, number=20, repeat=3#, cache_flush_bytes=256 * 10**6
53+
name, self.device, number=20, repeat=3 # , cache_flush_bytes=256 * 10**6
5754
)(*new_args)
5855
self.time_eval_results[name] = (res.mean, 1)
5956
print(f"Time-eval result {name} on {self.device}: {res}")
@@ -121,9 +118,7 @@ def __init__(self, args):
121118
)
122119
)
123120
self.cmp_device = tvm.device(args.cmp_device)
124-
self.const_params_dict = utils.load_params(
125-
args.artifact_path, self.primary_device
126-
)
121+
self.const_params_dict = utils.load_params(args.artifact_path, self.primary_device)
127122
self.cmp_instrument = LibCompare(
128123
self.lib,
129124
self.cmp_device,
@@ -134,9 +129,7 @@ def __init__(self, args):
134129

135130

136131
def deploy_to_pipeline(args) -> None:
137-
with open(
138-
os.path.join(args.artifact_path, "params", "mlc-chat-config.json"), "r"
139-
) as f:
132+
with open(os.path.join(args.artifact_path, "params", "mlc-chat-config.json"), "r") as f:
140133
config = json.load(f)
141134

142135
primary_device = tvm.device(args.primary_device)
@@ -157,18 +150,14 @@ def deploy_to_pipeline(args) -> None:
157150
tokenizer(args.prompt, return_tensors="pt").input_ids.to(torch.int32).numpy(),
158151
primary_device,
159152
)
160-
first_sampled_token = tvm.nd.array(
161-
np.array([[6234]]).astype("int32"), primary_device
162-
)
153+
first_sampled_token = tvm.nd.array(np.array([[6234]]).astype("int32"), primary_device)
163154
seq_len_shape = tvm.runtime.ShapeTuple([inputs.shape[1]])
164155
second_seq_len_shape = tvm.runtime.ShapeTuple([inputs.shape[1] + 1])
165156
kv_caches = state.vm["create_kv_cache"]()
166157

167158
print("Running inference...")
168159
print("======================= Starts Encoding =======================")
169-
logits, kv_caches = state.vm["prefill"](
170-
inputs, seq_len_shape, kv_caches, const_params
171-
)
160+
logits, kv_caches = state.vm["prefill"](inputs, seq_len_shape, kv_caches, const_params)
172161
print_as_table(
173162
sorted(
174163
state.cmp_instrument.time_eval_results.items(),

tests/debug/dump_intermediate.py renamed to tests/python/legacy/dump_intermediate.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
import argparse
22
import os
3+
import pickle
34

45
import numpy as np
56
import torch
67
import tvm
78
from transformers import AutoTokenizer
89
from tvm import relax
9-
import pickle
1010

1111
from mlc_llm import utils
1212

@@ -77,12 +77,8 @@ def deploy_to_pipeline(args) -> None:
7777
)
7878

7979
print("Tokenizing...")
80-
inputs = (
81-
tokenizer(args.prompt, return_tensors="pt").input_ids.to(torch.int32).numpy()
82-
)
83-
first_sampled_token = tvm.nd.array(
84-
np.array([[6234]]).astype("int32"), primary_device
85-
)
80+
inputs = tokenizer(args.prompt, return_tensors="pt").input_ids.to(torch.int32).numpy()
81+
first_sampled_token = tvm.nd.array(np.array([[6234]]).astype("int32"), primary_device)
8682
seq_len_shape = tvm.runtime.ShapeTuple([inputs.shape[1]])
8783
second_seq_len_shape = tvm.runtime.ShapeTuple([inputs.shape[1] + 1])
8884
kv_caches = state.vm["create_kv_cache"]()

tests/evaluate.py renamed to tests/python/legacy/evaluate.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -58,9 +58,7 @@ def compare(
5858
repeat=3,
5959
)(*new_args).mean
6060
shapes = [arg.shape for arg in new_args]
61-
total_bytes = sum(
62-
arg.numpy().size * arg.numpy().itemsize for arg in new_args
63-
)
61+
total_bytes = sum(arg.numpy().size * arg.numpy().itemsize for arg in new_args)
6462
self.time_eval_results[name] = (res, 1, shapes, total_bytes)
6563
else:
6664
record = self.time_eval_results[name]
@@ -177,9 +175,7 @@ def deploy_to_pipeline(args) -> None: # pylint: disable=too-many-locals
177175
print("Profiling...")
178176
kv_caches = vm["create_kv_cache"]()
179177

180-
logits, kv_caches = vm["prefill"](
181-
inputs, seq_len_shape, kv_caches, const_params
182-
)
178+
logits, kv_caches = vm["prefill"](inputs, seq_len_shape, kv_caches, const_params)
183179
print("======================= Encoding Profiling =======================")
184180
print_as_table(
185181
sorted(

tests/python/test_build_args.py renamed to tests/python/legacy/test_build_args.py

Lines changed: 19 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,12 @@
33
import dataclasses
44
import unittest
55

6-
from mlc_llm import BuildArgs, utils, core
6+
from mlc_llm import BuildArgs, core, utils
7+
78

89
def old_make_args():
910
"""The exact old way of creating `ArgumentParser`, used to test whether
10-
`BuildArgs` is equivalent to this. """
11+
`BuildArgs` is equivalent to this."""
1112
args = argparse.ArgumentParser()
1213
args.add_argument(
1314
"--model",
@@ -17,7 +18,7 @@ def old_make_args():
1718
'The name of the model to build. If it is "auto", we will '
1819
'automatically set the model name according to "--model-path", '
1920
'"hf-path" or the model folders under "--artifact-path/models"'
20-
)
21+
),
2122
)
2223
args.add_argument(
2324
"--hf-path",
@@ -30,19 +31,16 @@ def old_make_args():
3031
type=str,
3132
choices=[*utils.quantization_schemes.keys()],
3233
default=list(utils.quantization_schemes.keys())[0],
33-
help="The quantization mode we use to compile."
34+
help="The quantization mode we use to compile.",
3435
)
3536
args.add_argument(
3637
"--max-seq-len",
3738
type=int,
3839
default=-1,
39-
help="The maximum allowed sequence length for the model."
40+
help="The maximum allowed sequence length for the model.",
4041
)
4142
args.add_argument(
42-
"--target",
43-
type=str,
44-
default="auto",
45-
help="The target platform to compile the model for."
43+
"--target", type=str, default="auto", help="The target platform to compile the model for."
4644
)
4745
args.add_argument(
4846
"--reuse-lib",
@@ -51,10 +49,7 @@ def old_make_args():
5149
help="Whether to reuse a previously generated lib.",
5250
)
5351
args.add_argument(
54-
"--artifact-path",
55-
type=str,
56-
default="dist",
57-
help="Where to store the output."
52+
"--artifact-path", type=str, default="dist", help="Where to store the output."
5853
)
5954
args.add_argument(
6055
"--use-cache",
@@ -66,13 +61,13 @@ def old_make_args():
6661
"--debug-dump",
6762
action="store_true",
6863
default=False,
69-
help="Whether to dump debugging files during compilation."
64+
help="Whether to dump debugging files during compilation.",
7065
)
7166
args.add_argument(
7267
"--debug-load-script",
7368
action="store_true",
7469
default=False,
75-
help="Whether to load the script for debugging."
70+
help="Whether to load the script for debugging.",
7671
)
7772
args.add_argument(
7873
"--llvm-mingw",
@@ -81,10 +76,7 @@ def old_make_args():
8176
help="/path/to/llvm-mingw-root, use llvm-mingw to cross compile to windows.",
8277
)
8378
args.add_argument(
84-
"--system-lib",
85-
action="store_true",
86-
default=False,
87-
help="A parameter to `relax.build`."
79+
"--system-lib", action="store_true", default=False, help="A parameter to `relax.build`."
8880
)
8981
args.add_argument(
9082
"--sep-embed",
@@ -99,17 +91,20 @@ def old_make_args():
9991

10092
return args
10193

94+
10295
# Referred to HfArgumentParserTest from https://github.com/huggingface/
10396
# transformers/blob/e84bf1f734f87aa2bedc41b9b9933d00fc6add98/tests/utils
10497
# /test_hf_argparser.py#L143
10598
class BuildArgsTest(unittest.TestCase):
10699
"""Tests whether BuildArgs reaches parity with regular ArgumentParser."""
107-
def argparsers_equal(self, parse_a: argparse.ArgumentParser,
108-
parse_b: argparse.ArgumentParser):
100+
101+
def argparsers_equal(self, parse_a: argparse.ArgumentParser, parse_b: argparse.ArgumentParser):
109102
"""
110103
Small helper to check pseudo-equality of parsed arguments on `ArgumentParser` instances.
111104
"""
112-
self.assertEqual(len(parse_a._actions), len(parse_b._actions)) # pylint: disable=protected-access
105+
self.assertEqual(
106+
len(parse_a._actions), len(parse_b._actions)
107+
) # pylint: disable=protected-access
113108
for x, y in zip(parse_a._actions, parse_b._actions): # pylint: disable=protected-access
114109
xx = {k: v for k, v in vars(x).items() if k != "container"}
115110
yy = {k: v for k, v in vars(y).items() if k != "container"}
@@ -175,5 +170,6 @@ def test_namespaces_are_equivalent_str_boolean_int(self):
175170
build_args_namespace = argparse.Namespace(**build_args_as_dict)
176171
self.assertNotEqual(build_args_namespace, parsed_args)
177172

178-
if __name__ == '__main__':
173+
174+
if __name__ == "__main__":
179175
unittest.main()

0 commit comments

Comments
 (0)