diff --git a/python/mlc_chat/compiler/model/llama_quantization.py b/python/mlc_chat/compiler/model/llama_quantization.py
new file mode 100644
index 0000000000..dbf360c31d
--- /dev/null
+++ b/python/mlc_chat/compiler/model/llama_quantization.py
@@ -0,0 +1,101 @@
+"""
+Quantization specs for Llama2 architecture.
+TODO: add docstring
+"""
+from typing import Callable, Dict, List, Optional
+
+import tvm
+from tvm.runtime import NDArray
+
+from ..parameter import QuantizeMapping
+from ..quantization import QuantizeConfig
+from ..quantization.group_quantizer import te_quantize as te_group_quantize
+from .llama_config import LlamaConfig
+from .llama_model import LlamaForCasualLM
+
+
+def huggingface_group_quantize(
+    model_config: LlamaConfig,
+    quantize_config: QuantizeConfig,
+    target: Optional[tvm.target.Target] = None,
+) -> QuantizeMapping:
+    """Returns a parameter mapping that maps a parameter in MLC LLM's model
+    definition to its eventual names and values after quantization.
+
+    Parameters
+    ----------
+    model_config : LlamaConfig
+        The configuration of the Llama model.
+    quantize_config : GroupQuantizeConfig
+        The configuration of the group quantization.
+    target : Optional[tvm.target.Target]
+        The target device to run the quantization on, by default None, which
+        means the quantization will be run on CPU.
+
+    Returns
+    -------
+    quantize_map : QuantizeMapping
+        The parameter mapping from a parameter in MLC LLM's model definition to
+        its eventual names and values after quantization.
+    """
+
+    def group_quantize(
+        param: NDArray, config: QuantizeConfig, target: Optional[tvm.target.Target] = None
+    ):
+        if target is None or target.kind.name == "llvm":
+            target = tvm.target.Target("llvm")
+            device = tvm.cpu()
+        elif target.kind.name == "cuda":
+            device = tvm.cuda()
+        else:
+            raise ValueError(f"Invalid target device: {target}")
+        param_tensor = tvm.te.placeholder(param.shape, dtype=param.dtype, name="param")
+        weight_compute, scale_compute, other_computes = te_group_quantize(  # type: ignore
+            param_tensor, config
+        )
+        s = tvm.te.create_schedule(
+            [compute.op for compute in [weight_compute, scale_compute] + other_computes]
+        )
+        if target.kind.name == "cuda":
+            # thread_binding for cuda
+            for compute in [weight_compute, scale_compute] + other_computes:
+                xo, xi = s[compute].split(compute.op.axis[0], 256)
+                s[compute].bind(xo, tvm.te.thread_axis("blockIdx.x"))
+                s[compute].bind(xi, tvm.te.thread_axis("threadIdx.x"))
+        f_quantize = tvm.build(
+            s, [param_tensor, weight_compute, scale_compute], name="group_quantize", target=target
+        )
+        weight = tvm.nd.empty(weight_compute.shape, weight_compute.dtype, device=device)
+        scale = tvm.nd.empty(scale_compute.shape, scale_compute.dtype, device=device)
+        f_quantize(param.copyto(device), weight, scale)
+        return weight, scale
+
+    # Param check
+    assert (
+        quantize_config.kind == "group_quantize"
+    ), f"Invalid quantization config: group quantization expected but got {quantize_config.kind}"
+    assert (
+        quantize_config.name == "q4f16_1"
+    ), """Only support q4f16_1 quantization scheme for now."""
+
+    # Fetch model parameter & names
+    model = LlamaForCasualLM(model_config)
+    _, named_params = model.export_tvm(spec=model.get_default_spec())
+    parameter_names = {name for name, _ in named_params}
+
+    # Init mappings
+    param_map: Dict[str, List[str]] = {}
+    map_func: Dict[str, Callable] = {}
+
+    # Dispatch quantization scheme
+    # Also see https://github.com/mlc-ai/mlc-llm/blob/main/mlc_llm/quantization/__init__.py
+    for name in parameter_names:
+        if "norm.weight" not in name and "embed" not in name:
"norm.weight" not in name and "embed" not in name: + param_map[name] = [f"{name}_quantized", f"{name}_scale"] + map_func[name] = lambda x: group_quantize(x, quantize_config, target=target) + else: + # skip these parameters + param_map[name] = [name] + map_func[name] = lambda x: [x] + + return QuantizeMapping(param_map, map_func) diff --git a/python/mlc_chat/compiler/parameter/huggingface_loader.py b/python/mlc_chat/compiler/parameter/huggingface_loader.py index fa6beb40eb..ed91255c81 100644 --- a/python/mlc_chat/compiler/parameter/huggingface_loader.py +++ b/python/mlc_chat/compiler/parameter/huggingface_loader.py @@ -5,16 +5,21 @@ import logging from collections import OrderedDict, defaultdict from pathlib import Path -from typing import Dict, Iterator, List, Tuple +from typing import Dict, Iterator, List, Optional, Tuple import numpy as np from tqdm import tqdm from tvm.runtime import NDArray from tvm.runtime.ndarray import array as as_ndarray -from .mapping import ExternMapping +from .mapping import ExternMapping, QuantizeMapping from .stats import Stats -from .utils import check_parameter_usage, load_safetensor_shard, load_torch_shard +from .utils import ( + ParamQuantizer, + check_parameter_usage, + load_safetensor_shard, + load_torch_shard, +) logger = logging.getLogger(__name__) @@ -38,17 +43,22 @@ class HuggingFaceLoader: # pylint: disable=too-few-public-methods cached_files : Dict[Path, Dict[str, np.ndarray]] A cache of the loaded files. The key is the path of the file, and the value is a mapping from parameter name to the parameter value. + + quantize_param_map : Optional[QuantizeMapping] + The quantization mapping from MLC to quantized MLC parameters. """ stats: Stats - extern_param_map: ExternMapping cached_files: Dict[Path, Dict[str, np.ndarray]] torch_to_path: Dict[str, Path] + extern_param_map: ExternMapping + quantize_param_map: Optional[QuantizeMapping] def __init__( self, path: Path, extern_param_map: ExternMapping, + quantize_param_map: Optional[QuantizeMapping] = None, ) -> None: """Create a parameter loader from HuggingFace PyTorch format. @@ -66,12 +76,17 @@ def __init__( extern_param_map : ExternMapping Maps an MLC parameter to a list of PyTorch/SafeTensor parameters. + + quantize_param_map: Optional[QuantizeMapping] + The quantization mapping from MLC to quantized MLC parameters, default to None, which + means no quantization. 
""" assert path.is_file() self.stats = Stats() self.extern_param_map = extern_param_map self.cached_files = {} self.torch_to_path = {} + self.quantize_param_map = quantize_param_map if path.suffix in (".bin", ".safetensors"): self._load_file(path) for name in self.cached_files[path].keys(): @@ -90,7 +105,21 @@ def load(self) -> Iterator[Tuple[str, NDArray]]: mlc_names = _loading_order(self.extern_param_map, self.torch_to_path) for mlc_name in tqdm(mlc_names): param = self._load_mlc_param(mlc_name) - yield mlc_name, param + if self.quantize_param_map: + with self.stats.timer("quant_time_sec"): + quantized_params = ParamQuantizer(self.quantize_param_map).quantize( + mlc_name, param + ) + for quantized_name, quantized_param in quantized_params: + logger.info( + ' Quantized Parameter: "%s", shape: %s, dtype: %s', + quantized_name, + quantized_param.shape, + quantized_param.dtype, + ) + yield quantized_name, quantized_param + else: + yield mlc_name, param cached_files = list(self.cached_files.keys()) for path in cached_files: self._unload_file(path) diff --git a/python/mlc_chat/compiler/parameter/mapping.py b/python/mlc_chat/compiler/parameter/mapping.py index 6f63dce71a..aab674cfa8 100644 --- a/python/mlc_chat/compiler/parameter/mapping.py +++ b/python/mlc_chat/compiler/parameter/mapping.py @@ -80,7 +80,7 @@ class QuantizeMapping: used to convert the quantized parameters into the desired form. """ - param_map: Dict[str, Callable[[str], List[str]]] + param_map: Dict[str, List[str]] map_func: Dict[str, Callable[[NDArray], List[NDArray]]] diff --git a/python/mlc_chat/compiler/parameter/utils.py b/python/mlc_chat/compiler/parameter/utils.py index 596941aaca..a2789cee55 100644 --- a/python/mlc_chat/compiler/parameter/utils.py +++ b/python/mlc_chat/compiler/parameter/utils.py @@ -1,15 +1,51 @@ """Common utilities for loading parameters""" +# pylint: disable=too-few-public-methods import logging from pathlib import Path -from typing import Iterator, Set, Tuple +from typing import TYPE_CHECKING, Iterator, Set, Tuple import numpy as np from .mapping import ExternMapping +if TYPE_CHECKING: + from tvm.runtime import NDArray + + from ..parameter import QuantizeMapping + logger = logging.getLogger(__name__) +class ParamQuantizer: + """A parameter quantizer that quantizes given mlc-llm parameters""" + + quantize_map: "QuantizeMapping" + + def __init__(self, quantize_map: "QuantizeMapping") -> None: + self.quantize_map = quantize_map + + def quantize(self, name: str, param: "NDArray") -> Iterator[Tuple[str, "NDArray"]]: + """Apply quantization to the given parameters + + Parameters + ---------- + name : str + The name of the parameter + param : NDArray + The parameter to be quantized + + Returns + ------- + List[Tuple[str, NDArray]] + The quantized parameters, each with its name + """ + + assert name in self.quantize_map.param_map + quantized_names = self.quantize_map.param_map[name] + quantized_params = self.quantize_map.map_func[name](param) + return zip(quantized_names, quantized_params) + + def check_parameter_usage(param_map: ExternMapping, extern_weights: Set[str]): """Check that all external parameters have been used and are stored in the weights file.""" used_extern_names = set(sum(param_map.param_map.values(), [])) diff --git a/python/mlc_chat/compiler/quantization/__init__.py b/python/mlc_chat/compiler/quantization/__init__.py index ab352fc6c2..a932119f9c 100644 --- a/python/mlc_chat/compiler/quantization/__init__.py +++ b/python/mlc_chat/compiler/quantization/__init__.py @@ -1,2 +1,2 @@ """A 
 """A subpackage for quantization and dequantization algorithms"""
-from .quantization import QUANT
+from .quantization import QUANT, QuantizeConfig
diff --git a/python/mlc_chat/compiler/quantization/group_quantizer.py b/python/mlc_chat/compiler/quantization/group_quantizer.py
new file mode 100644
index 0000000000..418617dd70
--- /dev/null
+++ b/python/mlc_chat/compiler/quantization/group_quantizer.py
@@ -0,0 +1,70 @@
+"""A group quantizer for on-the-fly parameter quantization"""
+# pylint: disable=too-few-public-methods
+
+from typing import List, Tuple
+
+from tvm import te, tir
+
+from .quantization import QuantizeConfig
+
+
+def te_quantize(
+    weight: te.Tensor, config: QuantizeConfig
+) -> Tuple[te.Tensor, te.Tensor, List[te.Tensor]]:
+    """Group quantization for weight tensor, defined in tensor expression."""
+    # pylint: disable=too-many-locals
+    assert len(weight.shape) == 2
+    n, m = weight.shape
+    # compute scale per group
+    r = te.reduce_axis((0, config.group_size), name="r")
+    num_group = tir.ceildiv(m, config.group_size)
+    scale_shape = (n, num_group)
+    max_abs = te.compute(
+        shape=scale_shape,
+        fcompute=lambda i, j: te.max(
+            tir.if_then_else(
+                j * config.group_size + r < weight.shape[1],
+                te.abs(weight[i, j * config.group_size + r]),
+                tir.const(1e-4, config.weight_dtype),
+            ),
+            axis=r,
+        ),
+        name="max_abs_value",
+    )
+    scale = te.compute(
+        scale_shape,
+        lambda i, j: max_abs[i, j] / tir.const(config.max_int_value, dtype=config.weight_dtype),
+        name="scale",
+    )
+
+    # compute scaled weight
+    tir_max_int = tir.const(config.max_int_value, config.weight_dtype)
+    tir_zero = tir.const(0, config.weight_dtype)
+    tir_max_int_2 = tir.const(config.max_int_value * 2, config.weight_dtype)
+    scaled_weight = te.compute(
+        shape=weight.shape,
+        fcompute=lambda i, j: tir.min(
+            tir.max(
+                tir.round(weight[i, j] / scale[i, j // config.group_size] + tir_max_int),
+                tir_zero,
+            ),
+            tir_max_int_2,
+        ).astype(config.storage_dtype),
+    )
+
+    # compute quantized weight per storage
+    r = te.reduce_axis((0, config.num_elem_per_storage), name="r")
+    num_storage = config.num_storage_per_group * num_group
+    quantized_weight_shape = (n, num_storage)
+    quantized_weight = te.compute(
+        shape=quantized_weight_shape,
+        fcompute=lambda i, j: tir.sum(
+            scaled_weight[i, j * config.num_elem_per_storage + r]
+            << (r * config.quantize_dtype_bits),
+            axis=r,
+            where=j * config.num_elem_per_storage + r < m,
+        ),
+        name="weight",
+    )
+    return quantized_weight, scale, [max_abs, scaled_weight]
+    # pylint: enable=too-many-locals
diff --git a/tests/python/parameter/test_group_quantizer.py b/tests/python/parameter/test_group_quantizer.py
new file mode 100644
index 0000000000..b0e4b6522f
--- /dev/null
+++ b/tests/python/parameter/test_group_quantizer.py
@@ -0,0 +1,157 @@
+# pylint: disable=missing-docstring,too-many-instance-attributes
+import logging
+from dataclasses import dataclass
+from pathlib import Path
+from typing import TYPE_CHECKING, Tuple, Union
+
+import numpy as np
+import tvm
+from mlc_chat.compiler import MODELS
+from mlc_chat.compiler.model.llama_config import LlamaConfig
+from mlc_chat.compiler.model.llama_quantization import huggingface_group_quantize
+from mlc_chat.compiler.parameter import HuggingFaceLoader
+from mlc_chat.support import tqdm
+from tvm.runtime import NDArray
+
+if TYPE_CHECKING:
+    from tvm.relax.frontend import nn
+
+logging.basicConfig(
+    level=logging.DEBUG,
+    style="{",
+    datefmt="%Y-%m-%d %H:%M:%S",
+    format="[{asctime}] {levelname} {filename}:{lineno}: {message}",
+)
+
+
+def test_load_torch_llama_group_quantize(base_path: Union[str, Path], target: str = "llvm"):
+    @dataclass
+    class TestGroupQuantizeConfig:
+        name: str = "q4f16_1"
+        kind: str = "group_quantize"
+        group_size: int = 32
+        weight_dtype: str = "float16"
+        max_int_value: int = 7
+        storage_dtype: str = "uint32"
+        num_elem_per_storage: int = 8
+        num_storage_per_group: int = 4
+        quantize_dtype_bits: int = 4
+
+        def quantize(self, _: "nn.Module") -> "nn.Module":
+            raise NotImplementedError
+
+    base_path = Path(base_path)
+    path_config = base_path / "config.json"
+    path_params = base_path / "pytorch_model.bin.index.json"
+
+    model = MODELS["llama"]
+    model_config = LlamaConfig.from_file(path_config)
+    quantize_config = TestGroupQuantizeConfig()
+    loader = HuggingFaceLoader(
+        path=path_params,
+        extern_param_map=model.source["huggingface-torch"](model_config, None),
+        quantize_param_map=huggingface_group_quantize(
+            model_config,
+            quantize_config,
+            target=tvm.target.Target(target),
+        ),
+    )
+    with tqdm.redirect():
+        for _name, _param in loader.load():
+            ...
+
+
+def test_group_quantize_vs_numpy():
+    bits = {
+        "int4": 4,
+        "int8": 8,
+        "fp16": 16,
+        "fp32": 32,
+        "int32": 32,
+        "uint32": 32,
+    }
+
+    # pylint: disable=unused-variable
+    def group_quantize_np(
+        w: NDArray,
+        quantize_dtype: str = "int4",
+        storage_dtype: str = "uint32",
+        group_size: int = 32,
+        # symmetric: bool = True,
+        # transpose: bool = False,
+    ) -> Tuple[NDArray, NDArray]:
+        # pylint: disable=too-many-locals
+        def _pad_axis_by_factor(tensor: np.ndarray, axis: int, factor: int) -> np.ndarray:
+            dim = int(tensor.shape[axis])
+            if dim % factor == 0:
+                return tensor
+            pad_width = [[0, 0] for i in tensor.shape]
+            pad_width[axis][1] = factor - (dim % factor)
+            return np.pad(tensor, pad_width, mode="constant", constant_values=0)
+
+        def _clip(
+            x: np.ndarray,
+            x_min: int,
+            x_max: int,
+            dtype: str,
+        ) -> np.ndarray:
+            return np.clip(x, a_min=x_min, a_max=x_max).astype(dtype)
+
+        num_elem_per_storage = bits[storage_dtype] // bits[quantize_dtype]
+        assert group_size % num_elem_per_storage == 0
+        num_storage_units = (group_size + num_elem_per_storage - 1) // num_elem_per_storage
+
+        # using numpy for now
+        w = w.numpy()
+
+        # Step 1. Tile `w`: [n, k'] -> [n, k, group_size]
+        w = _pad_axis_by_factor(w, axis=1, factor=group_size)
+        n, k = [int(v) for v in w.shape]  # pylint: disable=invalid-name
+        assert k % group_size == 0, "Padding is not working properly"
+        k = k // group_size
+        w = w.reshape([n, k, group_size])
+
+        # Step 2. Calculate
+        if quantize_dtype.startswith("int"):
+            max_int_value = (2 ** (bits[quantize_dtype] - 1)) - 1
+            # 1) `scale`: [n, k, group_size] -> [n, k]
+            scale = np.maximum(np.amax(np.abs(w), axis=-1), 1e-4) / max_int_value
+            # 2) `w`: w / scale
+
+            w = _clip(
+                np.round(w / scale[:, :, np.newaxis]).astype("int") + max_int_value,
+                x_min=0,
+                x_max=max_int_value * 2,
+                dtype=storage_dtype,
+            )
+        else:
+            raise NotImplementedError
+
+        # Step 3. Compress `w`: pack every `num_elem_per_storage` elements into one storage unit
+        res = np.zeros((n, k, num_storage_units), dtype=np.uint32)
+        for i in range(n):
+            for j in range(k):
+                for m in range(num_storage_units):
+                    for e in range(num_elem_per_storage):
+                        res[i, j, m] += w[i, j, m * num_elem_per_storage + e] << (e * bits[quantize_dtype])
+        return tvm.nd.array(res), tvm.nd.array(scale)
+        # pylint: enable=too-many-locals
+
+
+if __name__ == "__main__":
+    test_load_torch_llama_group_quantize(
+        base_path="./dist/models/Llama-2-7b-hf",
+        target="llvm",
+    )
+    test_load_torch_llama_group_quantize(
+        base_path="./dist/models/Llama-2-7b-hf",
+        target="nvidia/nvidia-a100",
+    )
+    test_load_torch_llama_group_quantize(
+        base_path="./dist/models/Llama-2-13b-hf",
+        target="llvm",
+    )
+    test_load_torch_llama_group_quantize(
+        base_path="./dist/models/Llama-2-13b-hf",
+        target="nvidia/nvidia-a100",
+    )
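
Sanity check for the packing layout (a standalone sketch, not part of the patch): with the q4f16_1 scheme exercised above, group_size is 32, the quantized values are 4-bit, and the storage dtype is uint32, so num_elem_per_storage = 32 // 4 = 8 and num_storage_per_group = 32 // 8 = 4. Each storage word therefore holds eight 4-bit codes, with code i occupying bits [4*i, 4*i + 4); this is the same shift as r * config.quantize_dtype_bits in te_quantize and e * bits[quantize_dtype] in group_quantize_np. The snippet below depends only on numpy, and the names `codes` and `packed` are illustrative, not taken from the patch.

    import numpy as np

    # Eight 4-bit codes in [0, 15], e.g. the clipped, offset values of one storage unit.
    codes = np.array([3, 7, 0, 15, 1, 9, 4, 2], dtype=np.uint32)

    # Pack into a single uint32: code i lands in bits [4 * i, 4 * i + 4).
    packed = np.uint32(0)
    for i, c in enumerate(codes):
        packed |= np.uint32(c) << np.uint32(4 * i)

    # Unpack and verify the round trip.
    unpacked = np.array(
        [(packed >> np.uint32(4 * i)) & np.uint32(0xF) for i in range(8)], dtype=np.uint32
    )
    assert (unpacked == codes).all()
    print(hex(int(packed)))  # 0x2491f073 for the codes above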