[Slim-LM] Enable Group Quant #1129
Changes from all commits: cbd45b2, b6644a4, 20e449e, ab83723, 3b2867d, 110598d, 42aed4b, b8e153a, 6826fdb, ce075f3
New file: group-quantization specs for the Llama2 architecture (`@@ -0,0 +1,101 @@`):

```python
"""
Quantization specs for the Llama2 architecture.
TODO: add docstring
"""
from typing import Callable, Dict, List, Optional

import tvm
from tvm.runtime import NDArray

from ..parameter import QuantizeMapping
from ..quantization import QuantizeConfig
from ..quantization.group_quantizer import te_quantize as te_group_quantize
from .llama_config import LlamaConfig
from .llama_model import LlamaForCasualLM


def huggingface_group_quantize(
    model_config: LlamaConfig,
    quantize_config: QuantizeConfig,
    target: Optional[tvm.target.Target] = None,
) -> QuantizeMapping:
    """Returns a parameter mapping that maps a parameter in MLC LLM's model
    definition to its eventual names and values after quantization.

    Parameters
    ----------
    model_config : LlamaConfig
        The configuration of the Llama model.
    quantize_config : QuantizeConfig
        The configuration of the group quantization.
    target : Optional[tvm.target.Target]
        The target device to run the quantization on. Defaults to None,
        meaning the quantization is run on CPU.

    Returns
    -------
    quantize_map : QuantizeMapping
        The parameter mapping from a parameter in MLC LLM's model definition
        to its eventual names and values after quantization.
    """

    def group_quantize(
        param: NDArray, config: QuantizeConfig, target: Optional[tvm.target.Target] = None
    ):
        if target is None or target.kind.name == "llvm":
            target = tvm.target.Target("llvm")
            device = tvm.cpu()
        elif target.kind.name == "cuda":
            device = tvm.cuda()
        else:
            raise ValueError(f"Invalid target device: {target}")
        param_tensor = tvm.te.placeholder(param.shape, dtype=param.dtype, name="param")
        weight_compute, scale_compute, other_computes = te_group_quantize(  # type: ignore
            param_tensor, config
        )
        s = tvm.te.create_schedule(
            [compute.op for compute in [weight_compute, scale_compute] + other_computes]
        )
        if target.kind.name == "cuda":
            # Thread binding for CUDA
            for compute in [weight_compute, scale_compute] + other_computes:
                xo, xi = s[compute].split(compute.op.axis[0], 256)
                s[compute].bind(xo, tvm.te.thread_axis("blockIdx.x"))
                s[compute].bind(xi, tvm.te.thread_axis("threadIdx.x"))
        f_quantize = tvm.build(
            s, [param_tensor, weight_compute, scale_compute], name="group_quantize", target=target
        )
        weight = tvm.nd.empty(weight_compute.shape, weight_compute.dtype, device=device)
        scale = tvm.nd.empty(scale_compute.shape, scale_compute.dtype, device=device)
        f_quantize(param.copyto(device), weight, scale)
        return weight, scale

    # Param check
    assert (
        quantize_config.kind == "group_quantize"
    ), f"Invalid quantization config: group quantization expected but got {quantize_config.kind}"
    assert (
        quantize_config.name == "q4f16_1"
    ), "Only the q4f16_1 quantization scheme is supported for now."

    # Fetch model parameters & names
    model = LlamaForCasualLM(model_config)
    _, named_params = model.export_tvm(spec=model.get_default_spec())
    parameter_names = {name for name, _ in named_params}

    # Init mappings
    param_map: Dict[str, List[str]] = {}
    map_func: Dict[str, Callable] = {}

    # Dispatch quantization scheme
    # Also see https://github.com/mlc-ai/mlc-llm/blob/main/mlc_llm/quantization/__init__.py
    for name in parameter_names:
        if "norm.weight" not in name and "embed" not in name:
            param_map[name] = [f"{name}_quantized", f"{name}_scale"]
            map_func[name] = lambda x: group_quantize(x, quantize_config, target=target)
        else:
            # Skip quantization for these parameters
            param_map[name] = [name]
            map_func[name] = lambda x: [x]

    return QuantizeMapping(param_map, map_func)
```
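For orientation, here is a hypothetical usage sketch, not part of the diff; the parameter name and the `model_config`/`quantize_config`/`param` objects are illustrative assumptions:

```python
# Hypothetical usage of the mapping produced above; the parameter name and the
# config objects are illustrative assumptions, not values from this PR.
quantize_map = huggingface_group_quantize(model_config, quantize_config)

# Each quantized parameter maps to a (quantized weight, scale) name pair ...
names = quantize_map.param_map["model.layers.0.mlp.down_proj.weight"]
# -> ["model.layers.0.mlp.down_proj.weight_quantized",
#     "model.layers.0.mlp.down_proj.weight_scale"]

# ... and map_func computes the corresponding NDArray values.
weight, scale = quantize_map.map_func["model.layers.0.mlp.down_proj.weight"](param)
```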
Changes to the HuggingFace parameter loader (`@@ -5,16 +5,21 @@`):

```diff
 import logging
 from collections import OrderedDict, defaultdict
 from pathlib import Path
-from typing import Dict, Iterator, List, Tuple
+from typing import Dict, Iterator, List, Optional, Tuple

 import numpy as np
 from tqdm import tqdm
 from tvm.runtime import NDArray
 from tvm.runtime.ndarray import array as as_ndarray

-from .mapping import ExternMapping
+from .mapping import ExternMapping, QuantizeMapping
 from .stats import Stats
-from .utils import check_parameter_usage, load_safetensor_shard, load_torch_shard
+from .utils import (
+    ParamQuantizer,
+    check_parameter_usage,
+    load_safetensor_shard,
+    load_torch_shard,
+)

 logger = logging.getLogger(__name__)
```
`@@ -38,17 +43,22 @@ class HuggingFaceLoader:`

```diff
     cached_files : Dict[Path, Dict[str, np.ndarray]]
         A cache of the loaded files. The key is the path of the file, and the value is a mapping
         from parameter name to the parameter value.
+
+    quantize_param_map : Optional[QuantizeMapping]
+        The quantization mapping from MLC to quantized MLC parameters.
     """

     stats: Stats
-    extern_param_map: ExternMapping
     cached_files: Dict[Path, Dict[str, np.ndarray]]
     torch_to_path: Dict[str, Path]
+    extern_param_map: ExternMapping
+    quantize_param_map: Optional[QuantizeMapping]

     def __init__(
         self,
         path: Path,
         extern_param_map: ExternMapping,
+        quantize_param_map: Optional[QuantizeMapping] = None,
     ) -> None:
         """Create a parameter loader from HuggingFace PyTorch format.
```

`@@ -66,12 +76,17 @@ def __init__`

```diff
         extern_param_map : ExternMapping
             Maps an MLC parameter to a list of PyTorch/SafeTensor parameters.
+
+        quantize_param_map : Optional[QuantizeMapping]
+            The quantization mapping from MLC to quantized MLC parameters. Defaults to
+            None, which means no quantization.
         """
         assert path.is_file()
         self.stats = Stats()
         self.extern_param_map = extern_param_map
         self.cached_files = {}
         self.torch_to_path = {}
+        self.quantize_param_map = quantize_param_map
         if path.suffix in (".bin", ".safetensors"):
             self._load_file(path)
             for name in self.cached_files[path].keys():
```

Review comment on lines 84 to +89: nit: reorder to be consistent with the type annotations.
`@@ -90,7 +105,21 @@ def load(self) -> Iterator[Tuple[str, NDArray]]:`

```diff
         mlc_names = _loading_order(self.extern_param_map, self.torch_to_path)
         for mlc_name in tqdm(mlc_names):
             param = self._load_mlc_param(mlc_name)
-            yield mlc_name, param
+            if self.quantize_param_map:
+                with self.stats.timer("quant_time_sec"):
+                    quantized_params = ParamQuantizer(self.quantize_param_map).quantize(
+                        mlc_name, param
+                    )
+                for quantized_name, quantized_param in quantized_params:
+                    logger.info(
+                        ' Quantized Parameter: "%s", shape: %s, dtype: %s',
+                        quantized_name,
+                        quantized_param.shape,
+                        quantized_param.dtype,
+                    )
+                    yield quantized_name, quantized_param
+            else:
+                yield mlc_name, param
         cached_files = list(self.cached_files.keys())
         for path in cached_files:
             self._unload_file(path)
```

Review comment on `if self.quantize_param_map:`: handle the case that …
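A hypothetical end-to-end wiring, not part of the diff; the checkpoint path and the mapping/config objects are assumptions for illustration:

```python
from pathlib import Path

# Hypothetical wiring: passing a QuantizeMapping makes the loader emit
# quantized parameters on the fly instead of the raw MLC parameters.
loader = HuggingFaceLoader(
    path=Path("pytorch_model.bin"),    # assumed checkpoint path
    extern_param_map=extern_param_map, # HF -> MLC name mapping (assumed)
    quantize_param_map=huggingface_group_quantize(model_config, quantize_config),
)
for name, param in loader.load():      # yields quantized names and NDArrays
    print(name, param.shape, param.dtype)
```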
Changes to the parameter-loading utilities (`@@ -1,15 +1,51 @@`):

```diff
 """Common utilities for loading parameters"""
+# pylint: disable=too-few-public-methods
 import logging
 from pathlib import Path
-from typing import Iterator, Set, Tuple
+from typing import TYPE_CHECKING, Iterator, Set, Tuple

 import numpy as np

 from .mapping import ExternMapping

+if TYPE_CHECKING:
+    from tvm.runtime import NDArray
+
+    from ..parameter import QuantizeMapping
+
 logger = logging.getLogger(__name__)


+class ParamQuantizer:
+    """A parameter quantizer that quantizes given mlc-llm parameters"""
+
+    quantize_map: "QuantizeMapping"
+
+    def __init__(self, quantize_map: "QuantizeMapping") -> None:
+        self.quantize_map = quantize_map
+
+    def quantize(self, name: str, param: "NDArray") -> Iterator[Tuple[str, "NDArray"]]:
+        """Apply quantization to the given parameter.
+
+        Parameters
+        ----------
+        name : str
+            The name of the parameter
+        param : NDArray
+            The parameter to be quantized
+
+        Returns
+        -------
+        Iterator[Tuple[str, NDArray]]
+            The quantized parameters, each with its name
+        """
+        assert name in self.quantize_map.param_map
+        quantized_names = self.quantize_map.param_map[name]
+        quantized_params = self.quantize_map.map_func[name](param)
+        return zip(quantized_names, quantized_params)
+
+
 def check_parameter_usage(param_map: ExternMapping, extern_weights: Set[str]):
     """Check that all external parameters have been used and are stored in the weights file."""
     used_extern_names = set(sum(param_map.param_map.values(), []))
```

Review comment on `class ParamQuantizer`: we likely don't need this helper class - seems its core logic is ~10 lines of code?
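A minimal sketch of using `ParamQuantizer` directly; `quantize_map`, the parameter name, and `param` are illustrative assumptions:

```python
# Sketch: quantize_map would come from huggingface_group_quantize(...) above.
quantizer = ParamQuantizer(quantize_map)
for new_name, new_param in quantizer.quantize("lm_head.weight", param):
    # e.g. "lm_head.weight_quantized" followed by "lm_head.weight_scale"
    print(new_name, new_param.shape, new_param.dtype)
```

Note that `quantize` returns `zip(...)`, so the pairs are produced lazily; the actual quantization work happens inside the `map_func` call.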
Changes to the quantization subpackage `__init__` (`@@ -1,2 +1,2 @@`):

```diff
 """A subpackage for quantization and dequantization algorithms"""
-from .quantization import QUANT
+from .quantization import QUANT, QuantizeConfig
```
New file: the group quantizer kernel (`@@ -0,0 +1,70 @@`):

```python
"""A group quantizer for on-the-fly parameter quantization"""
# pylint: disable=too-few-public-methods

from typing import List, Tuple

from tvm import te, tir

from .quantization import QuantizeConfig


def te_quantize(
    weight: te.Tensor, config: QuantizeConfig
) -> Tuple[te.Tensor, te.Tensor, List[te.Tensor]]:
    """Group quantization for weight tensor, defined in tensor expression."""
    # pylint: disable=too-many-locals
    assert len(weight.shape) == 2
    n, m = weight.shape

    # Compute the scale per group: the max absolute value in each group,
    # divided by the largest representable integer (1e-4 pads the tail).
    r = te.reduce_axis((0, config.group_size), name="r")
    num_group = tir.ceildiv(m, config.group_size)
    scale_shape = (n, num_group)
    max_abs = te.compute(
        shape=scale_shape,
        fcompute=lambda i, j: te.max(
            tir.if_then_else(
                j * config.group_size + r < weight.shape[1],
                te.abs(weight[i, j * config.group_size + r]),
                tir.const(1e-4, config.weight_dtype),
            ),
            axis=r,
        ),
        name="max_abs_value",
    )
    scale = te.compute(
        scale_shape,
        lambda i, j: max_abs[i, j] / tir.const(config.max_int_value, dtype=config.weight_dtype),
        name="scale",
    )

    # Compute the scaled weight: round(w / scale + max_int), clipped to [0, 2 * max_int]
    tir_max_int = tir.const(config.max_int_value, config.weight_dtype)
    tir_zero = tir.const(0, config.weight_dtype)
    tir_max_int_2 = tir.const(config.max_int_value * 2, config.weight_dtype)
    scaled_weight = te.compute(
        shape=weight.shape,
        fcompute=lambda i, j: tir.min(
            tir.max(
                tir.round(weight[i, j] / scale[i, j // config.group_size] + tir_max_int),
                tir_zero,
            ),
            tir_max_int_2,
        ).astype(config.storage_dtype),
    )

    # Pack quantized elements into storage words, num_elem_per_storage per word
    r = te.reduce_axis((0, config.num_elem_per_storage), name="r")
    num_storage = config.num_storage_per_group * num_group
    quantized_weight_shape = (n, num_storage)
    quantized_weight = te.compute(
        shape=quantized_weight_shape,
        fcompute=lambda i, j: tir.sum(
            scaled_weight[i, j * config.num_elem_per_storage + r]
            << (r * config.quantize_dtype_bits),
            axis=r,
            where=j * config.num_elem_per_storage + r < m,
        ),
        name="weight",
    )
    return quantized_weight, scale, [max_abs, scaled_weight]
    # pylint: enable=too-many-locals
```

Note: the extracted diff computed `scale` over shape `(n, m)`, which would read `max_abs[i, j]` out of bounds for `j >= num_group`; it is reconstructed above with `scale_shape`, consistent with how `scaled_weight` indexes `scale[i, j // config.group_size]`.

Review comment on the imports: define the dataclass …

Review comment on `n, m = weight.shape`: nit: use … (suggested change)
Review reply: From my POV, the docstring looks pretty reasonable. Any specific ones in your mind?