101 changes: 101 additions & 0 deletions python/mlc_chat/compiler/model/llama_quantization.py
@@ -0,0 +1,101 @@
"""
Quantization specs for Llama2 architecture.
TODO: add docstring
Comment on lines +2 to +3
From my POV, the docstring looks pretty reasonable. Any specific ones in your mind?

"""
from typing import Callable, Dict, List, Optional

import tvm
from tvm.runtime import NDArray

from ..parameter import QuantizeMapping
from ..quantization import QuantizeConfig
from ..quantization.group_quantizer import te_quantize as te_group_quantize
from .llama_config import LlamaConfig
from .llama_model import LlamaForCasualLM


def huggingface_group_quantize(
model_config: LlamaConfig,
quantize_config: QuantizeConfig,
target: Optional[tvm.target.Target] = None,
) -> QuantizeMapping:
"""Returns a parameter mapping that maps a parameter in MLC LLM's model
definition to its eventual names and values after quantization.
Parameters
----------
model_config : LlamaConfig
The configuration of the Llama model.
quantize_config : GroupQuantizeConfig
The configuration of the group quantization.
target : Optional[tvm.target.Target]
The target device to run the quantization on, by default None, which
means the quantization will be run on CPU.
Returns
-------
quantize_map : QuantizeMapping
The parameter mapping from a parameter in MLC LLM's model definition to
its eventual names and values after quantization.
"""

def group_quantize(
param: NDArray, config: QuantizeConfig, target: Optional[tvm.target.Target] = None
):
if target is None or target.kind.name == "llvm":
target = tvm.target.Target("llvm")
device = tvm.cpu()
elif target.kind.name == "cuda":
device = tvm.cuda()
else:
raise ValueError(f"Invalid target device: {target}")
param_tensor = tvm.te.placeholder(param.shape, dtype=param.dtype, name="param")
weight_compute, scale_compute, other_computes = te_group_quantize( # type: ignore
param_tensor, config
)
s = tvm.te.create_schedule(
[compute.op for compute in [weight_compute, scale_compute] + other_computes]
)
if target.kind.name == "cuda":
# thread_binding for cuda
for compute in [weight_compute, scale_compute] + other_computes:
xo, xi = s[compute].split(compute.op.axis[0], 256)
s[compute].bind(xo, tvm.te.thread_axis("blockIdx.x"))
s[compute].bind(xi, tvm.te.thread_axis("threadIdx.x"))
f_quantize = tvm.build(
s, [param_tensor, weight_compute, scale_compute], name="group_quantize", target=target
)
weight = tvm.nd.empty(weight_compute.shape, weight_compute.dtype, device=device)
scale = tvm.nd.empty(scale_compute.shape, scale_compute.dtype, device=device)
f_quantize(param.copyto(device), weight, scale)
return weight, scale

# Param check
assert (
quantize_config.kind == "group_quantize"
), f"Invalid quantization config: group quantization expected but got {quantize_config.kind}"
assert (
quantize_config.name == "q4f16_1"
), """Only support q4f16_1 quantization scheme for now."""

# Fetch model parameter & names
model = LlamaForCasualLM(model_config)
_, named_params = model.export_tvm(spec=model.get_default_spec())
parameter_names = {name for name, _ in named_params}

# Init mappings
param_map: Dict[str, List[str]] = {}
map_func: Dict[str, Callable] = {}

# Dispatch quantization scheme
# Also see https://github.com/mlc-ai/mlc-llm/blob/main/mlc_llm/quantization/__init__.py
for name in parameter_names:
if "norm.weight" not in name and "embed" not in name:
param_map[name] = [f"{name}_quantized", f"{name}_scale"]
map_func[name] = lambda x: group_quantize(x, quantize_config, target=target)
else:
# skip these parameters
param_map[name] = [name]
map_func[name] = lambda x: [x]

return QuantizeMapping(param_map, map_func)
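
For context, the sketch below shows roughly how this helper is meant to plug into the HuggingFaceLoader changes later in this PR. The objects llama_config, quant_config, and hf_param_map are placeholders for the LlamaConfig, the "q4f16_1" group-quantize QuantizeConfig, and the ExternMapping that the calling pipeline already builds; this is an illustration, not code from the diff.

from pathlib import Path

import tvm

# Illustrative wiring only; llama_config, quant_config, and hf_param_map are
# stand-ins for objects constructed elsewhere in the compiler pipeline.
quantize_map = huggingface_group_quantize(
    llama_config,                      # LlamaConfig for the model being converted
    quant_config,                      # QuantizeConfig with kind="group_quantize", name="q4f16_1"
    target=tvm.target.Target("cuda"),  # or None to quantize on CPU
)
loader = HuggingFaceLoader(
    path=Path("pytorch_model-00001-of-00002.bin"),
    extern_param_map=hf_param_map,     # MLC parameter name -> HF tensor name(s)
    quantize_param_map=quantize_map,   # MLC parameter name -> quantized names + quantize function
)
for name, value in loader.load():
    print(name, value.shape, value.dtype)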
39 changes: 34 additions & 5 deletions python/mlc_chat/compiler/parameter/huggingface_loader.py
@@ -5,16 +5,21 @@
import logging
from collections import OrderedDict, defaultdict
from pathlib import Path
from typing import Dict, Iterator, List, Tuple
from typing import Dict, Iterator, List, Optional, Tuple

import numpy as np
from tqdm import tqdm
from tvm.runtime import NDArray
from tvm.runtime.ndarray import array as as_ndarray

from .mapping import ExternMapping
from .mapping import ExternMapping, QuantizeMapping
from .stats import Stats
from .utils import check_parameter_usage, load_safetensor_shard, load_torch_shard
from .utils import (
ParamQuantizer,
check_parameter_usage,
load_safetensor_shard,
load_torch_shard,
)

logger = logging.getLogger(__name__)

@@ -38,17 +43,22 @@ class HuggingFaceLoader: # pylint: disable=too-few-public-methods
cached_files : Dict[Path, Dict[str, np.ndarray]]
A cache of the loaded files. The key is the path of the file, and the value is a mapping
from parameter name to the parameter value.

quantize_param_map : Optional[QuantizeMapping]
The quantization mapping from MLC to quantized MLC parameters.
"""

stats: Stats
extern_param_map: ExternMapping
cached_files: Dict[Path, Dict[str, np.ndarray]]
torch_to_path: Dict[str, Path]
extern_param_map: ExternMapping
quantize_param_map: Optional[QuantizeMapping]

def __init__(
self,
path: Path,
extern_param_map: ExternMapping,
quantize_param_map: Optional[QuantizeMapping] = None,
) -> None:
"""Create a parameter loader from HuggingFace PyTorch format.

@@ -66,12 +76,17 @@ def __init__(

extern_param_map : ExternMapping
Maps an MLC parameter to a list of PyTorch/SafeTensor parameters.

quantize_param_map: Optional[QuantizeMapping]
The quantization mapping from MLC to quantized MLC parameters, defaults to None, which
means no quantization.
"""
assert path.is_file()
self.stats = Stats()
self.extern_param_map = extern_param_map
self.cached_files = {}
self.torch_to_path = {}
self.quantize_param_map = quantize_param_map
Comment on lines 84 to +89
nit: reorder to be consistent with type annotation

Suggested change
assert path.is_file()
self.stats = Stats()
self.extern_param_map = extern_param_map
self.cached_files = {}
self.torch_to_path = {}
self.quantize_param_map = quantize_param_map
assert path.is_file()
self.stats = Stats()
self.cached_files = {}
self.torch_to_path = {}
self.extern_param_map = extern_param_map
self.quantize_param_map = quantize_param_map

if path.suffix in (".bin", ".safetensors"):
self._load_file(path)
for name in self.cached_files[path].keys():
@@ -90,7 +105,21 @@ def load(self) -> Iterator[Tuple[str, NDArray]]:
mlc_names = _loading_order(self.extern_param_map, self.torch_to_path)
for mlc_name in tqdm(mlc_names):
param = self._load_mlc_param(mlc_name)
yield mlc_name, param
if self.quantize_param_map:
handle the case that param is not in self.quantize_param_map - meaning no quantization

with self.stats.timer("quant_time_sec"):
quantized_params = ParamQuantizer(self.quantize_param_map).quantize(
mlc_name, param
)
for quantized_name, quantized_param in quantized_params:
logger.info(
' Quantized Parameter: "%s", shape: %s, dtype: %s',
quantized_name,
quantized_param.shape,
quantized_param.dtype,
)
yield quantized_name, quantized_param
else:
yield mlc_name, param
cached_files = list(self.cached_files.keys())
for path in cached_files:
self._unload_file(path)
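
A minimal way to address the review note above (parameters with no entry in quantize_param_map) would be to check membership before quantizing and fall back to yielding the raw tensor. The following is a sketch of that idea, not the change that was merged:

# Sketch: quantize only when the mapping covers this parameter, else pass it through.
if self.quantize_param_map and mlc_name in self.quantize_param_map.param_map:
    with self.stats.timer("quant_time_sec"):
        quantized_params = ParamQuantizer(self.quantize_param_map).quantize(mlc_name, param)
    for quantized_name, quantized_param in quantized_params:
        yield quantized_name, quantized_param
else:
    yield mlc_name, param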
2 changes: 1 addition & 1 deletion python/mlc_chat/compiler/parameter/mapping.py
@@ -80,7 +80,7 @@ class QuantizeMapping:
used to convert the quantized parameters into the desired form.
"""

param_map: Dict[str, Callable[[str], List[str]]]
param_map: Dict[str, List[str]]
map_func: Dict[str, Callable[[NDArray], List[NDArray]]]


38 changes: 37 additions & 1 deletion python/mlc_chat/compiler/parameter/utils.py
@@ -1,15 +1,51 @@
"""Common utilities for loading parameters"""
# pylint: disable=too-few-public-methods
import logging
from pathlib import Path
from typing import Iterator, Set, Tuple
from typing import TYPE_CHECKING, Iterator, Set, Tuple

import numpy as np

from .mapping import ExternMapping

if TYPE_CHECKING:
from tvm.runtime import NDArray

from ..parameter import QuantizeMapping

logger = logging.getLogger(__name__)


class ParamQuantizer:
we likely don't need this helper class - seems its core logic is ~10 lines of code?

"""A parameter quantizer that quantizes given mlc-llm parameters"""

quantize_map: "QuantizeMapping"

def __init__(self, quantize_map: "QuantizeMapping") -> None:
self.quantize_map = quantize_map

def quantize(self, name: str, param: "NDArray") -> Iterator[Tuple[str, "NDArray"]]:
"""Apply quantization to the given parameters
Parameters
----------
name : str
The name of the parameter
param : NDArray
The parameter to be quantized
Returns
-------
Iterator[Tuple[str, NDArray]]
The quantized parameters, each with its name
"""

assert name in self.quantize_map.param_map
quantized_names = self.quantize_map.param_map[name]
quantized_params = self.quantize_map.map_func[name](param)
return zip(quantized_names, quantized_params)
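
If the helper class were dropped as the review note above suggests, its core lookup could be inlined at the call site in HuggingFaceLoader.load(); a hypothetical sketch:

# Hypothetical inlined equivalent of ParamQuantizer.quantize:
quantized_names = quantize_param_map.param_map[mlc_name]
quantized_values = quantize_param_map.map_func[mlc_name](param)
for quantized_name, quantized_param in zip(quantized_names, quantized_values):
    yield quantized_name, quantized_param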


def check_parameter_usage(param_map: ExternMapping, extern_weights: Set[str]):
"""Check that all external parameters have been used and are stored in the weights file."""
used_extern_names = set(sum(param_map.param_map.values(), []))
2 changes: 1 addition & 1 deletion python/mlc_chat/compiler/quantization/__init__.py
@@ -1,2 +1,2 @@
"""A subpackage for quantization and dequantization algorithms"""
from .quantization import QUANT
from .quantization import QUANT, QuantizeConfig
70 changes: 70 additions & 0 deletions python/mlc_chat/compiler/quantization/group_quantizer.py
@@ -0,0 +1,70 @@
"""A group quantizer for on the fly parameter quantization"""
# pylint: disable=too-few-public-methods

from typing import List, Tuple

from tvm import te, tir

from .quantization import QuantizeConfig

define the dataclass GroupQuantizeConfig here
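
For reference, a dataclass along the lines the reviewer asks for might look like the sketch below. The field names mirror the config attributes this file and llama_quantization.py actually read; the default values are illustrative for a q4f16_1-style scheme and are assumptions, not something fixed by this PR.

from dataclasses import dataclass


@dataclass
class GroupQuantizeConfig:
    """Illustrative sketch; fields mirror the config attributes used below."""
    kind: str = "group_quantize"
    name: str = "q4f16_1"
    group_size: int = 32
    weight_dtype: str = "float16"
    storage_dtype: str = "uint32"
    quantize_dtype_bits: int = 4
    max_int_value: int = 7          # 2 ** (4 - 1) - 1
    num_elem_per_storage: int = 8   # one 32-bit storage word holds eight 4-bit values
    num_storage_per_group: int = 4  # group_size / num_elem_per_storage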


def te_quantize(
weight: te.Tensor, config: QuantizeConfig
) -> Tuple[te.Tensor, te.Tensor, List[te.Tensor]]:
"""Group quantization for weight tensor, defined in tensor expression."""
# pylint: disable=too-many-locals
assert len(weight.shape) == 2
n, m = weight.shape
nit: use k to denote the reduction axis

Suggested change
n, m = weight.shape
n, k = weight.shape

# compute scale per group
r = te.reduce_axis((0, config.group_size), name="r")
num_group = tir.ceildiv(m, config.group_size)
scale_shape = (n, num_group)
max_abs = te.compute(
shape=scale_shape,
fcompute=lambda i, j: te.max(
tir.if_then_else(
j * config.group_size + r < weight.shape[1],
te.abs(weight[i, j * config.group_size + r]),
tir.const(1e-4, config.weight_dtype),
),
axis=r,
),
name="max_abs_value",
)
scale = te.compute(
scale_shape,
lambda i, j: max_abs[i, j] / tir.const(config.max_int_value, dtype=config.weight_dtype),
name="scale",
)

# compute scaled weight
tir_max_int = tir.const(config.max_int_value, config.weight_dtype)
tir_zero = tir.const(0, config.weight_dtype)
tir_max_int_2 = tir.const(config.max_int_value * 2, config.weight_dtype)
scaled_weight = te.compute(
shape=weight.shape,
fcompute=lambda i, j: tir.min(
tir.max(
tir.round(weight[i, j] / scale[i, j // config.group_size] + tir_max_int),
tir_zero,
),
tir_max_int_2,
).astype(config.storage_dtype),
)

# compute quantized weight per storage
r = te.reduce_axis((0, config.num_elem_per_storage), name="r")
num_storage = config.num_storage_per_group * num_group
quantized_weight_shape = (n, num_storage)
quantized_weight = te.compute(
shape=quantized_weight_shape,
fcompute=lambda i, j: tir.sum(
scaled_weight[i, j * config.num_elem_per_storage + r]
<< (r * config.quantize_dtype_bits),
axis=r,
where=j * config.num_elem_per_storage + r < m,
),
name="weight",
)
return quantized_weight, scale, [max_abs, scaled_weight]
# pylint: enable=too-many-locals
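
To make the tensor-expression kernel above easier to follow, here is a plain NumPy sketch of the same per-row, per-group computation. The concrete numbers (group_size=32, 4-bit values, eight elements per uint32 storage word, max_int=7) are illustrative q4f16_1-style assumptions, not values fixed by this file.

import numpy as np


def group_quantize_ref(w, group_size=32, bits=4, max_int=7, elem_per_storage=8):
    """NumPy reference for the TE kernel above (a sketch, not used by the PR)."""
    n, k = w.shape
    num_group = -(-k // group_size)  # ceildiv
    # Pad |w| with 1e-4 so a partial trailing group behaves like the TE if_then_else.
    padded = np.full((n, num_group * group_size), 1e-4, dtype=w.dtype)
    padded[:, :k] = np.abs(w)
    max_abs = padded.reshape(n, num_group, group_size).max(axis=-1)
    scale = max_abs / max_int                      # shape (n, num_group)
    group_of = np.arange(k) // group_size
    q = np.clip(np.round(w / scale[:, group_of] + max_int), 0, 2 * max_int).astype(np.uint32)
    # Pack elem_per_storage consecutive values into each storage word by bit-shifting.
    num_storage = num_group * (group_size // elem_per_storage)
    packed = np.zeros((n, num_storage), dtype=np.uint32)
    for r in range(elem_per_storage):
        cols = np.arange(num_storage) * elem_per_storage + r
        valid = cols < k
        packed[:, valid] |= q[:, cols[valid]] << np.uint32(r * bits)
    return packed, scale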