python/mlc_chat/cli/compile.py (59 changes: 15 additions & 44 deletions)
@@ -6,12 +6,14 @@
 from typing import Union
 
 from mlc_chat.compiler import (  # pylint: disable=redefined-builtin
+    HELP,
     MODELS,
     QUANTIZATION,
     OptimizationFlags,
     compile,
 )
 
+from ..support.argparse import ArgumentParser
 from ..support.auto_config import detect_config, detect_model_type
 from ..support.auto_target import detect_target_and_host
@@ -26,17 +28,11 @@
 def main():
     """Parse command line arguments and call `mlc_llm.compiler.compile`."""
 
-    def _parse_config(path: Union[str, Path]) -> Path:
-        try:
-            return detect_config(path)
-        except ValueError as err:
-            raise argparse.ArgumentTypeError(f"No valid config.json in: {path}. Error: {err}")
-
     def _parse_output(path: Union[str, Path]) -> Path:
         path = Path(path)
         parent = path.parent
         if not parent.is_dir():
-            raise argparse.ArgumentTypeError(f"Directory does not exist: {parent}")
+            raise ValueError(f"Directory does not exist: {parent}")
         return path
 
     def _check_prefix_symbols(prefix: str) -> str:
@@ -48,88 +44,63 @@ def _check_prefix_symbols(prefix: str) -> str:
                 "numbers (0-9), alphabets (A-Z, a-z) and underscore (_)."
             )
 
-    parser = argparse.ArgumentParser("MLC LLM Compiler")
+    parser = ArgumentParser("MLC LLM Compiler")
     parser.add_argument(
-        "--config",
-        type=_parse_config,
+        "--model",
+        type=detect_config,
         required=True,
-        help="Path to config.json file or to the directory that contains config.json, which is "
-        "a HuggingFace standard that defines model architecture, for example, "
-        "https://huggingface.co/codellama/CodeLlama-7b-Instruct-hf/blob/main/config.json",
+        help=HELP["model"] + " (required)",
     )
     parser.add_argument(
         "--quantization",
         type=str,
         required=True,
         choices=list(QUANTIZATION.keys()),
-        help="Quantization format.",
+        help=HELP["quantization"] + " (required, choices: %(choices)s)",
     )
     parser.add_argument(
         "--model-type",
         type=str,
         default="auto",
         choices=["auto"] + list(MODELS.keys()),
-        help="Model architecture, for example, llama. If not set, it is inferred "
-        "from the config.json file. "
-        "(default: %(default)s)",
+        help=HELP["model_type"] + ' (default: "%(default)s")',
     )
     parser.add_argument(
         "--device",
         type=str,
         default="auto",
-        help="The GPU device to compile the model to. If not set, it is inferred from locally "
-        "available GPUs. "
-        "(default: %(default)s)",
+        help=HELP["device_compile"] + ' (default: "%(default)s")',
     )
     parser.add_argument(
         "--host",
         type=str,
         default="auto",
-        help="The host CPU ISA to compile the model to. If not set, it is inferred from the "
-        "local CPU. (default: %(default)s)",
+        help=HELP["host"] + ' (default: "%(default)s")',
     )
     parser.add_argument(
         "--opt",
         type=OptimizationFlags.from_str,
         default="O2",
-        help="Optimization flags. MLC LLM maintains a predefined set of optimization flags, "
-        "denoted as O0, O1, O2, O3, where O0 means no optimization, O2 means majority of them, "
-        "and O3 represents extreme optimization that could potentially break the system. "
-        "Meanwhile, optimization flags could be explicitly specified via detailed knobs, e.g. "
-        '--opt="cutlass_attn=1;cutlass_norm=0;cublas_gemm=0;cudagraph=0". '
-        "(default: %(default)s)",
+        help=HELP["opt"] + ' (default: "%(default)s")',
     )
     parser.add_argument(
         "--prefix-symbols",
         type=str,
         default="",
-        help='Add a prefix to all exported symbols. Similar to "objcopy --prefix-symbols". '
-        "This is useful when compiling multiple models into a single library to avoid symbol "
-        "conflicts. Different from objcopy, this takes no effect for shared libraries. "
-        '(default: "")',
+        help=HELP["prefix_symbols"] + ' (default: "%(default)s")',
    )
     parser.add_argument(
         "--max-sequence-length",
         type=int,
         default=None,
-        help="Option to override the maximum sequence length supported by the model. "
-        "An LLM is usually trained with a fixed maximum sequence length, which is typically "
-        "specified in the model spec. By default, if this option is not set explicitly, "
-        "the maximum sequence length is determined by `max_sequence_length` or "
-        "`max_position_embeddings` in config.json, which can be inaccurate for some models.",
+        help=HELP["max_sequence_length"] + ' (default: "%(default)s")',
     )
     parser.add_argument(
         "--output",
         "-o",
         type=_parse_output,
         required=True,
-        help="The name of the output file. The suffix determines whether the output file is a "
-        "shared library or objects. Available suffixes: "
-        "1) Linux: .so (shared), .tar (objects); "
-        "2) macOS: .dylib (shared), .tar (objects); "
-        "3) Windows: .dll (shared), .tar (objects); "
-        "4) Android, iOS: .tar (objects); "
-        "5) Web: .wasm (web assembly)",
+        help=HELP["output_compile"] + " (required)",
     )
     parsed = parser.parse_args()
     target, build_func = detect_target_and_host(parsed.device, parsed.host)
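The deleted _parse_config wrappers existed to route detect_config failures through argparse.ArgumentTypeError. The sketch below shows the standard-library behavior that makes passing detect_config directly to type= viable; demo_detect_config is a hypothetical stand-in, and the custom ArgumentParser from ..support.argparse is not shown in this diff, so how it formats the underlying ValueError message is an assumption, not something this diff confirms.

import argparse
from pathlib import Path


def demo_detect_config(path: str) -> Path:
    """Hypothetical stand-in for detect_config: resolve config.json or fail."""
    config = Path(path) / "config.json"
    if not config.is_file():
        # Both ValueError and argparse.ArgumentTypeError raised inside a
        # type= callable abort parsing with a usage error, not a traceback.
        raise ValueError(f"No valid config.json in: {path}")
    return config


parser = argparse.ArgumentParser("demo")
parser.add_argument("--model", type=demo_detect_config, required=True)
# parser.parse_args(["--model", "/nonexistent"]) exits with:
#   demo: error: argument --model: invalid demo_detect_config value: '/nonexistent'
# Stock argparse swallows the ValueError text; ArgumentTypeError would surface it.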
python/mlc_chat/cli/convert_weight.py (60 changes: 23 additions & 37 deletions)
@@ -4,8 +4,9 @@
 from pathlib import Path
 from typing import Union
 
-from mlc_chat.compiler import MODELS, QUANTIZATION, convert_weight
+from mlc_chat.compiler import HELP, MODELS, QUANTIZATION, convert_weight
 
+from ..support.argparse import ArgumentParser
 from ..support.auto_config import detect_config, detect_model_type
 from ..support.auto_target import detect_device
 from ..support.auto_weight import detect_weight
@@ -21,12 +22,6 @@
 def main():
     """Parse command line arguments and apply quantization."""
 
-    def _parse_config(path: Union[str, Path]) -> Path:
-        try:
-            return detect_config(path)
-        except ValueError as err:
-            raise argparse.ArgumentTypeError(f"No valid config.json in: {path}. Error: {err}")
-
     def _parse_source(path: Union[str, Path], config_path: Path) -> Path:
         if path == "auto":
             return config_path.parent
@@ -41,61 +36,52 @@ def _parse_output(path: Union[str, Path]) -> Path:
         path.mkdir(parents=True, exist_ok=True)
         return path
 
-    parser = argparse.ArgumentParser("MLC AutoLLM Quantization Framework")
+    parser = ArgumentParser("MLC AutoLLM Quantization Framework")
     parser.add_argument(
-        "--config",
-        type=_parse_config,
+        "--model",
+        type=detect_config,
         required=True,
-        help="Path to config.json file or to the directory that contains config.json, which is "
-        "a HuggingFace standard that defines model architecture, for example, "
-        "https://huggingface.co/codellama/CodeLlama-7b-Instruct-hf/blob/main/config.json",
-    )
-    parser.add_argument(
-        "--source",
-        type=str,
-        default="auto",
-        help="The path to the original model weights, inferred from `config` if missing. "
-        "(default: %(default)s)",
-    )
-    parser.add_argument(
-        "--source-format",
-        type=str,
-        choices=["auto", "huggingface-torch", "huggingface-safetensor", "awq"],
-        default="auto",
-        help="The format of the source model weights, inferred from `config` if missing. "
-        "(default: %(default)s)",
+        help=HELP["model"] + " (required)",
     )
     parser.add_argument(
         "--quantization",
         type=str,
         required=True,
         choices=list(QUANTIZATION.keys()),
-        help="Quantization format, for example `q4f16_1`.",
+        help=HELP["quantization"] + " (required, choices: %(choices)s)",
     )
     parser.add_argument(
         "--model-type",
         type=str,
         default="auto",
         choices=["auto"] + list(MODELS.keys()),
-        help="Model architecture, for example, llama. If not set, it is inferred "
-        "from the config.json file. "
-        "(default: %(default)s)",
+        help=HELP["model_type"] + ' (default: "%(default)s")',
     )
     parser.add_argument(
         "--device",
         default="auto",
         type=detect_device,
-        help="The device used to run quantization, for example `cuda:0`. "
-        "Detected from the local environment if not specified. "
-        "(default: %(default)s)",
+        help=HELP["device_quantize"] + ' (default: "%(default)s")',
     )
+    parser.add_argument(
+        "--source",
+        type=str,
+        default="auto",
+        help=HELP["source"] + ' (default: "%(default)s")',
+    )
+    parser.add_argument(
+        "--source-format",
+        type=str,
+        choices=["auto", "huggingface-torch", "huggingface-safetensor", "awq"],
+        default="auto",
+        help=HELP["source_format"] + ' (default: "%(default)s", choices: %(choices)s)',
+    )
     parser.add_argument(
         "--output",
         "-o",
         type=_parse_output,
         required=True,
-        help="The output directory to save the quantized model weights; it will contain "
-        "`params_shard_*.bin` and `ndarray-cache.json`.",
+        help=HELP["output_quantize"] + " (required)",
     )
 
     parsed = parser.parse_args()
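A self-contained illustration of the %(default)s and %(choices)s placeholders these help strings now rely on: argparse expands them from the action's own attributes when --help is rendered, so the documented default and choice list can never drift from the real ones. The option mirrors --source-format above; the rendered wording is approximate.

import argparse

parser = argparse.ArgumentParser("demo")
parser.add_argument(
    "--source-format",
    choices=["auto", "huggingface-torch", "huggingface-safetensor", "awq"],
    default="auto",
    help='Source weight format. (default: "%(default)s", choices: %(choices)s)',
)
parser.parse_args(["--help"])
# Renders roughly as:
#   Source weight format. (default: "auto", choices: auto, huggingface-torch,
#   huggingface-safetensor, awq)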
python/mlc_chat/cli/gen_mlc_chat_config.py (40 changes: 11 additions & 29 deletions)
@@ -1,11 +1,11 @@
"""Command line entrypoint of configuration generation."""
import argparse
import logging
from pathlib import Path
from typing import Union

from mlc_chat.compiler import CONV_TEMPLATES, MODELS, QUANTIZATION, gen_config
from mlc_chat.compiler import CONV_TEMPLATES, HELP, MODELS, QUANTIZATION, gen_config

from ..support.argparse import ArgumentParser
from ..support.auto_config import detect_config, detect_model_type

logging.basicConfig(
@@ -18,13 +18,7 @@
 
 def main():
     """Parse command line arguments and call `mlc_llm.compiler.gen_config`."""
-    parser = argparse.ArgumentParser("MLC LLM Configuration Generator")
-
-    def _parse_config(path: Union[str, Path]) -> Path:
-        try:
-            return detect_config(path)
-        except ValueError as err:
-            raise argparse.ArgumentTypeError(f"No valid config.json in: {path}. Error: {err}")
+    parser = ArgumentParser("MLC LLM Configuration Generator")
 
     def _parse_output(path: Union[str, Path]) -> Path:
         path = Path(path)
@@ -33,56 +27,44 @@ def _parse_output(path: Union[str, Path]) -> Path:
         return path
 
     parser.add_argument(
-        "--config",
-        type=_parse_config,
+        "--model",
+        type=detect_config,
         required=True,
-        help="Path to config.json file or to the directory that contains config.json, which is "
-        "a HuggingFace standard that defines model architecture, for example, "
-        "https://huggingface.co/codellama/CodeLlama-7b-Instruct-hf/blob/main/config.json. "
-        "This `config.json` file is expected to colocate with other configurations, such as "
-        "tokenizer configuration and `generation_config.json`.",
+        help=HELP["model"] + " (required)",
     )
     parser.add_argument(
         "--quantization",
         type=str,
         required=True,
         choices=list(QUANTIZATION.keys()),
-        help="Quantization format.",
+        help=HELP["quantization"] + " (required, choices: %(choices)s)",
    )
     parser.add_argument(
         "--model-type",
         type=str,
         default="auto",
         choices=["auto"] + list(MODELS.keys()),
-        help="Model architecture, for example, llama. If not set, it is inferred "
-        "from the config.json file. "
-        "(default: %(default)s)",
+        help=HELP["model_type"] + ' (default: "%(default)s", choices: %(choices)s)',
     )
     parser.add_argument(
         "--conv-template",
         type=str,
         required=True,
         choices=list(CONV_TEMPLATES),
-        help='Conversation template. It depends on how the model is tuned. Use "LM" for vanilla '
-        "base models.",
+        help=HELP["conv_template"] + " (required, choices: %(choices)s)",
     )
     parser.add_argument(
         "--max-sequence-length",
         type=int,
         default=None,
-        help="Option to override the maximum sequence length supported by the model. "
-        "An LLM is usually trained with a fixed maximum sequence length, which is typically "
-        "specified in the model spec. By default, if this option is not set explicitly, "
-        "the maximum sequence length is determined by `max_sequence_length` or "
-        "`max_position_embeddings` in config.json, which can be inaccurate for some models.",
+        help=HELP["max_sequence_length"] + ' (default: "%(default)s")',
     )
     parser.add_argument(
         "--output",
         "-o",
         type=_parse_output,
         required=True,
-        help="The output directory for generated configurations, including `mlc-chat-config.json` "
-        "and tokenizer configuration.",
+        help=HELP["output_gen_mlc_chat_config"] + " (required)",
     )
     parsed = parser.parse_args()
     model = detect_model_type(parsed.model_type, parsed.config)
python/mlc_chat/compiler/__init__.py (1 change: 1 addition & 0 deletions)
@@ -8,6 +8,7 @@
 from .flags_model_config_override import ModelConfigOverride
 from .flags_optimization import OptimizationFlags
 from .gen_mlc_chat_config import CONV_TEMPLATES, gen_config
+from .help import HELP
 from .loader import LOADER, ExternMapping, HuggingFaceLoader, QuantizeMapping
 from .model import MODEL_PRESETS, MODELS, Model
 from .quantization import QUANTIZATION
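The new python/mlc_chat/compiler/help.py module is not included in the portion of the diff that loaded here. Below is a plausible sketch of its shape, inferred only from the HELP keys the three CLIs above consume; every string is illustrative wording, not the PR's actual text.

"""Centralized help messages shared by the MLC LLM command line tools."""
HELP = {
    "model": "Path to config.json, or to a directory containing config.json.",
    "quantization": "The quantization format to use.",
    "model_type": "Model architecture, e.g. llama; inferred from config.json if unset.",
    "device_compile": "The GPU device to compile the model for.",
    "device_quantize": "The device on which to run quantization, e.g. cuda:0.",
    "host": "The host CPU ISA to compile the model for.",
    "opt": "Optimization flags: O0/O1/O2/O3 or explicit knobs.",
    "prefix_symbols": "Prefix prepended to all exported symbols.",
    "max_sequence_length": "Override the model's maximum sequence length.",
    "source": "Path to the original model weights.",
    "source_format": "Format of the source model weights.",
    "conv_template": "Conversation template the model was tuned with.",
    "output_compile": "Output path of the compiled model library.",
    "output_quantize": "Output directory for the quantized model weights.",
    "output_gen_mlc_chat_config": "Output directory for mlc-chat-config.json.",
}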
python/mlc_chat/compiler/gen_mlc_chat_config.py (3 changes: 3 additions & 0 deletions)
@@ -15,12 +15,15 @@
 
 FOUND = green("Found")
 NOT_FOUND = red("Not found")
+VERSION = "0.1.0"
 
 
 @dataclasses.dataclass
 class MLCChatConfig:  # pylint: disable=too-many-instance-attributes
     """Arguments for `mlc_chat.compiler.gen_config`."""
 
+    version: str = VERSION
+
     model_type: str = None
     quantization: str = None
     model_config: Dict[str, Any] = None
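A short sketch of what the new version field buys, assuming gen_config serializes MLCChatConfig via dataclasses.asdict and json.dump (the serialization path is not shown in this diff): every generated mlc-chat-config.json now self-reports the schema version it was written with, which downstream loaders can check before parsing.

import dataclasses
import json

VERSION = "0.1.0"


@dataclasses.dataclass
class DemoChatConfig:
    """Trimmed-down stand-in for MLCChatConfig."""

    version: str = VERSION
    model_type: str = None
    quantization: str = None


config = DemoChatConfig(model_type="llama", quantization="q4f16_1")
print(json.dumps(dataclasses.asdict(config)))
# {"version": "0.1.0", "model_type": "llama", "quantization": "q4f16_1"}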