From 83622a1b577539518ffc23f4fd609b6b383f1972 Mon Sep 17 00:00:00 2001
From: Junru Shao
Date: Sat, 28 Oct 2023 23:43:26 -0700
Subject: [PATCH] Support parameter packing

---
 pyproject.toml                                |  1 -
 python/mlc_chat/compiler/compile.py           | 12 ++++++++++--
 python/mlc_chat/compiler/model/llama_model.py | 12 ++++++++++++
 3 files changed, 22 insertions(+), 3 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index 85ca20eb24..ccf754554f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -28,7 +28,6 @@ follow_imports = "skip"
 ignore_errors = false
 strict_optional = false
 install_types = true
-non_interactive = true
 
 [tool.pylint.messages_control]
 max-line-length = 100
diff --git a/python/mlc_chat/compiler/compile.py b/python/mlc_chat/compiler/compile.py
index 8415ca21b8..cc6b61b1c2 100644
--- a/python/mlc_chat/compiler/compile.py
+++ b/python/mlc_chat/compiler/compile.py
@@ -66,7 +66,7 @@ class CompileArgs:  # pylint: disable=too-many-instance-attributes
 
     config: Path
     quantization: str
-    model_type: Model
+    model: Model
     target: Target
     opt: OptimizationFlags
     build_func: Callable[[IRModule, "CompileArgs"], None]
@@ -79,7 +79,7 @@ def _echo_args(args: CompileArgs) -> None:
     print(f"{bold('Compiling with arguments:')}", file=out)
     print(f"  {bold('--config'):<25} {args.config}", file=out)
     print(f"  {bold('--quantization'):<25} {args.quantization}", file=out)
-    print(f"  {bold('--model-type'):<25} {args.model_type.name}", file=out)
+    print(f"  {bold('--model-type'):<25} {args.model.name}", file=out)
     print(f"  {bold('--target'):<25} {args.target.export()}", file=out)
     print(f"  {bold('--opt'):<25} {args.opt}", file=out)
     print(f"  {bold('--output'):<25} {args.output}", file=out)
@@ -101,6 +101,14 @@ def compile(  # pylint: disable=too-many-arguments,redefined-builtin
         config, quantization, model_type, target, opt, build_func, prefix_symbols, output
     )
     _echo_args(args)
+    model_config = args.model.config.from_file(args.config)
+    model = args.model.model(model_config)
+    mod, named_params = model.export_tvm(
+        spec=model.get_default_spec(),  # type: ignore
+    )
+    mod.show(black_format=False)
+    for name, param in named_params:
+        print(f"{name}: {param.shape} {param.dtype}")
 
 
 OPT_FLAG_PRESET = {
diff --git a/python/mlc_chat/compiler/model/llama_model.py b/python/mlc_chat/compiler/model/llama_model.py
index 49e947f741..6bf7647ff1 100644
--- a/python/mlc_chat/compiler/model/llama_model.py
+++ b/python/mlc_chat/compiler/model/llama_model.py
@@ -217,14 +217,26 @@ def get_default_spec(self):
             "prefill": {
                 "inputs": nn.spec.Tensor([batch_size, "seq_len"], "int32"),
                 "total_seq_len": int,
+                "$": {
+                    "param_mode": "packed",
+                    "effect_mode": "packed",
+                },
             },
             "decode": {
                 "inputs": nn.spec.Tensor([batch_size, 1], "int32"),
                 "total_seq_len": int,
+                "$": {
+                    "param_mode": "packed",
+                    "effect_mode": "packed",
+                },
             },
             "softmax_with_temperature": {
                 "logits": nn.spec.Tensor([1, 1, "vocab_size"], "float32"),
                 "temperature": nn.spec.Tensor([], "float32"),
+                "$": {
+                    "param_mode": "none",
+                    "effect_mode": "none",
+                },
             },
         }
         return nn.spec.ModuleSpec.from_raw(mod_spec, self)
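
Note (reviewer sketch, not part of the patch): a minimal illustration of the
export path this change wires up, mirroring the new logic in compile(). The
class names `LlamaConfig` and `LlamaForCasualLM` are assumptions about what
llama_model.py exports; the calls themselves follow the patch:

    from mlc_chat.compiler.model.llama_model import LlamaConfig, LlamaForCasualLM

    # Load hyperparameters and build the nn.Module (assumed class names).
    config = LlamaConfig.from_file("config.json")
    model = LlamaForCasualLM(config)

    # With "param_mode": "packed", prefill/decode accept all weights as one
    # packed argument rather than one argument per tensor; "effect_mode":
    # "packed" does the same for effects such as the KV cache. The "none"
    # modes on softmax_with_temperature omit both, since that function uses
    # no weights or state.
    mod, named_params = model.export_tvm(spec=model.get_default_spec())

    # Each exported weight with its shape/dtype, as the new compile() prints.
    for name, param in named_params:
        print(f"{name}: {param.shape} {param.dtype}")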