From 9adfc710d5db04bbaa1288fde789718be467774a Mon Sep 17 00:00:00 2001
From: Lesheng Jin
Date: Mon, 18 Dec 2023 05:36:27 +0000
Subject: [PATCH] [SLIM][AWQ] AWQ GEMM support

---
 .../compiler/model/llama/llama_loader.py      | 10 +++++--
 .../compiler/quantization/awq_quantization.py | 28 ++++++++-----------
 .../compiler/quantization/quantization.py     |  4 +--
 .../mlc_chat/compiler/quantization/utils.py   |  7 ++++-
 4 files changed, 28 insertions(+), 21 deletions(-)

diff --git a/python/mlc_chat/compiler/model/llama/llama_loader.py b/python/mlc_chat/compiler/model/llama/llama_loader.py
index 7118340d21..2f1d758d53 100644
--- a/python/mlc_chat/compiler/model/llama/llama_loader.py
+++ b/python/mlc_chat/compiler/model/llama/llama_loader.py
@@ -122,7 +122,10 @@ def awq(model_config: LlamaConfig, quantization: Quantization) -> ExternMapping:
                     f"{attn}.v_proj.{quantize_suffix}",
                 ],
                 functools.partial(
-                    lambda q, k, v, dtype: np.concatenate([q, k, v], axis=0).astype(dtype),
+                    lambda q, k, v, dtype: np.concatenate(
+                        [q, k, v],
+                        axis=1,  # AWQ GEMM would transpose the weight
+                    ).astype(dtype),
                     dtype=mlc_param.dtype,
                 ),
             )
@@ -140,7 +143,10 @@ def awq(model_config: LlamaConfig, quantization: Quantization) -> ExternMapping:
                     f"{mlp}.up_proj.{quantize_suffix}",
                 ],
                 functools.partial(
-                    lambda gate, up, dtype: np.concatenate([gate, up], axis=0).astype(dtype),
+                    lambda gate, up, dtype: np.concatenate(
+                        [gate, up],
+                        axis=1,  # AWQ GEMM would transpose the weight
+                    ).astype(dtype),
                     dtype=mlc_param.dtype,
                 ),
             )
diff --git a/python/mlc_chat/compiler/quantization/awq_quantization.py b/python/mlc_chat/compiler/quantization/awq_quantization.py
index ff5dc6bee8..58e5cc1583 100644
--- a/python/mlc_chat/compiler/quantization/awq_quantization.py
+++ b/python/mlc_chat/compiler/quantization/awq_quantization.py
@@ -3,7 +3,7 @@
 from dataclasses import dataclass, field
 from typing import Any, Callable, Dict, List, Optional
 
-from tvm import DataType, DataTypeCode, te, tir
+from tvm import DataType, DataTypeCode, te, tir, topi
 from tvm.relax.frontend import nn
 from tvm.runtime import NDArray
 
@@ -138,7 +138,8 @@ def _dequantize(
             self.num_elem_per_storage,
             self.storage_dtype,
             self.model_dtype,
-            out_shape,
+            [weight.shape[0], weight.shape[1] * self.num_elem_per_storage],
+            ft_reorder=True,
         )
         float_zeros = convert_uint_to_float(
             zeros,
@@ -146,8 +147,12 @@ def _dequantize(
             self.num_elem_per_storage,
             self.storage_dtype,
             self.model_dtype,
-            out_shape,
+            [zeros.shape[0], zeros.shape[1] * self.num_elem_per_storage],
+            ft_reorder=True,
         )
+        float_weight = topi.transpose(float_weight)
+        float_zeros = topi.transpose(float_zeros)
+        scale = topi.transpose(scale)
         return te.compute(
             shape=[weight.shape[0], weight.shape[1] * self.num_elem_per_storage]
             if out_shape is None
@@ -177,23 +182,14 @@ def __init__(  # pylint: disable=too-many-arguments
         self.out_dtype = out_dtype
         self.config = config
         self.qweight = nn.Parameter(
-            (out_features, tir.ceildiv(in_features, config.num_elem_per_storage)),
-            config.storage_dtype,
+            (in_features, out_features // config.num_elem_per_storage), config.storage_dtype
         )
         self.qzeros = nn.Parameter(
-            (
-                out_features,
-                _calculate_zeros_width(in_features, config.group_size, config.num_elem_per_storage),
-            ),
-            dtype=config.storage_dtype,
+            (in_features // config.group_size, out_features // config.num_elem_per_storage),
+            config.storage_dtype,
         )
         self.scales = nn.Parameter(
-            (
-                out_features,
-                _calculate_zeros_width(in_features, config.group_size, config.num_elem_per_storage)
-                * config.num_elem_per_storage,
-            ),
-            config.model_dtype,
+            (in_features // config.group_size, out_features), config.model_dtype
         )
         if bias:
             self.bias = nn.Parameter(
diff --git a/python/mlc_chat/compiler/quantization/quantization.py b/python/mlc_chat/compiler/quantization/quantization.py
index a4d915fc7d..75f6169abe 100644
--- a/python/mlc_chat/compiler/quantization/quantization.py
+++ b/python/mlc_chat/compiler/quantization/quantization.py
@@ -59,8 +59,8 @@ def quantize_weight(self, weight: tvm.runtime.NDArray) -> List[tvm.runtime.NDArr
         storage_dtype="uint32",
         model_dtype="float32",
     ),
-    "q4f16_awq": AWQQuantize(
-        name="q4f16_awq",
+    "q4f16_autoawq": AWQQuantize(
+        name="q4f16_autoawq",
         kind="awq",
         group_size=128,
         quantize_dtype="int4",
diff --git a/python/mlc_chat/compiler/quantization/utils.py b/python/mlc_chat/compiler/quantization/utils.py
index 9a879d2e96..a9ade4d64b 100644
--- a/python/mlc_chat/compiler/quantization/utils.py
+++ b/python/mlc_chat/compiler/quantization/utils.py
@@ -11,6 +11,7 @@ def convert_uint_to_float(  # pylint: disable=too-many-arguments
     storage_dtype: str,
     model_dtype: str,
     out_shape: Optional[List[tir.PrimExpr]] = None,
+    ft_reorder: Optional[bool] = False,
 ) -> te.Tensor:
     """Convert a quantized uint weight to an unquantized float weight."""
     tir_bin_mask = tir.const((1 << bits) - 1, storage_dtype)
@@ -21,7 +22,11 @@ def convert_uint_to_float(  # pylint: disable=too-many-arguments
         fcompute=lambda i, j: tir.bitwise_and(
             tir.shift_right(
                 weight[i, j // num_elem_per_storage],
-                ((j % num_elem_per_storage) * bits).astype(storage_dtype),
+                (
+                    ((j % num_elem_per_storage) % 2 * 4 + (j % num_elem_per_storage) // 2) * bits
+                    if ft_reorder
+                    else (j % num_elem_per_storage) * bits
+                ).astype(storage_dtype),
             ),
             tir_bin_mask,
         ).astype(model_dtype),
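
Note on the ft_reorder index math (an illustrative sketch, not part of the patch): with ft_reorder=True, convert_uint_to_float reads nibble (j % 2) * 4 + j // 2 of each storage word instead of nibble j, which undoes an interleaved 4-bit packing order, assumed here to be the AWQ GEMM order 0, 2, 4, 6, 1, 3, 5, 7. The plain-Python sketch below only mirrors the patched TE expression; the helper name and constants are hypothetical and not part of this repository:

    BITS = 4
    NUM_ELEM_PER_STORAGE = 8  # one uint32 word holds eight 4-bit values

    def unpack_word(word, ft_reorder=True):
        """Unpack one uint32 into eight 4-bit values, mirroring the patched TE lambda."""
        values = []
        for j in range(NUM_ELEM_PER_STORAGE):
            if ft_reorder:
                # Same shift expression as the patch: undo the assumed interleave.
                shift = (j % 2 * 4 + j // 2) * BITS
            else:
                shift = j * BITS
            values.append((word >> shift) & ((1 << BITS) - 1))
        return values

    # Round-trip check: pack values 0..7 in the assumed interleave order, then unpack.
    awq_order = [0, 2, 4, 6, 1, 3, 5, 7]
    word = 0
    for pos, idx in enumerate(awq_order):
        word |= idx << (pos * BITS)
    assert unpack_word(word) == list(range(8))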