
Commit d30c983

Improve Marlin accuracy by default, and add a MARLIN_FP16 backend for faster but less accurate inference (#1317)
* sync marlin kernel from upstream for fp32 reduce ops precision fix
* add `MARLIN_FP16` backend

Signed-off-by: Qubitium <[email protected]>
1 parent 629e7ca commit d30c983

File tree: 13 files changed (+305, -78 lines)

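In user terms, BACKEND.MARLIN now runs the Marlin GEMM with fp32 reduce ops for higher accuracy, while the new BACKEND.MARLIN_FP16 keeps fp16 reduce for slightly more speed at lower accuracy. A minimal usage sketch, assuming the usual top-level GPTQModel.load entry point (which routes into from_quantized) and a placeholder model id:

from gptqmodel import BACKEND, GPTQModel  # assumes both are exported at the package top level

model_id = "your-org/your-gptq-4bit-model"  # hypothetical checkpoint path, substitute your own

# Default choice: fp32 reduce ops (more accurate, slightly slower).
model = GPTQModel.load(model_id, backend=BACKEND.MARLIN)

# Opt-in: fp16 reduce ops (less accurate, slightly faster).
model_fp16 = GPTQModel.load(model_id, backend=BACKEND.MARLIN_FP16)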

gptqmodel/models/loader.py

Lines changed: 5 additions & 6 deletions
@@ -47,8 +47,7 @@
 from ..utils.backend import BACKEND
 from ..utils.importer import auto_select_device, normalize_device_device_map, select_quant_linear
 from ..utils.logger import setup_logger
-from ..utils.marlin import (_validate_marlin_compatibility,
-                            _validate_marlin_device_support)
+from ..utils.marlin import _validate_marlin_compatibility, _validate_marlin_device_support
 from ..utils.model import (auto_dtype, convert_gptq_v1_to_v2_format, find_modules, get_checkpoints,
                            get_moe_layer_modules, gptqmodel_post_init, load_checkpoint_in_model_then_tie_weights,
                            make_quant, simple_dispatch_model, verify_model_hash, verify_sharded_model_hashes)
@@ -339,14 +338,14 @@ def from_quantized(
 
         if qcfg.format == FORMAT.MARLIN:
             # format marlin requires marlin kernel
-            if backend != BACKEND.MARLIN and backend != BACKEND.AUTO:
+            if backend not in [BACKEND.MARLIN, BACKEND.MARLIN_FP16] and backend != BACKEND.AUTO:
                 raise TypeError(f"FORMAT.MARLIN requires BACKEND.AUTO or BACKEND.MARLIN: actual = `{backend}`.")
             backend = BACKEND.MARLIN
 
         marlin_compatible = False if backend == BACKEND.IPEX else _validate_marlin_device_support()
 
         # check for marlin compat for cuda device onnly
-        if backend != BACKEND.MARLIN and device == DEVICE.CUDA:
+        if backend not in [BACKEND.MARLIN, BACKEND.MARLIN_FP16] and device == DEVICE.CUDA:
             unsupported = _validate_marlin_compatibility(qcfg)
             if unsupported is None and marlin_compatible:
                 logger.info(
@@ -504,7 +503,7 @@ def skip(*args, **kwargs):
             load_checkpoint_in_model = False
             qcfg.runtime_format = FORMAT.GPTQ_V2
 
-        if backend == BACKEND.MARLIN and (
+        if backend in [BACKEND.MARLIN, BACKEND.MARLIN_FP16] and (
                 preload_qlinear_kernel == ExllamaV2QuantLinear or qcfg.format == FORMAT.MARLIN):
             if is_sharded:
                 raise ValueError(
@@ -541,7 +540,7 @@ def skip(*args, **kwargs):
 
         # If we use marlin or bitblas to load the quantized model, the model is already a converted model,
         # and we no longer need to call load_checkpoint_in_model()
-        if load_checkpoint_in_model and backend not in [BACKEND.MARLIN, BACKEND.BITBLAS]:
+        if load_checkpoint_in_model and backend not in [BACKEND.MARLIN, BACKEND.MARLIN_FP16, BACKEND.BITBLAS]:
            load_checkpoint_in_model_then_tie_weights(
                model,
                dtype=torch_dtype,

gptqmodel/nn_modules/qlinear/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -23,6 +23,7 @@
 import torch.nn as nn
 import transformers
 from gptqmodel.adapter.adapter import LORA_MERGED_WEIGHT_PATHS, Adapter
+from gptqmodel.utils.backend import BACKEND
 
 from ...models._const import DEVICE, PLATFORM
 
@@ -52,6 +53,7 @@ def __init__(self,
                  out_features: int,
                  bias: bool,
                  pack_dtype: t.dtype,
+                 backend: BACKEND,
                  adapter: Adapter,
                  name: str = None,
                  register_buffers: bool = False,
@@ -68,6 +70,7 @@ def __init__(self,
         self.bits = bits
         self.desc_act = desc_act
         self.pack_dtype = pack_dtype
+        self.backend = backend
         self.maxq = 2 ** self.bits - 1
         self.pack_dtype = pack_dtype
         # we need to clone the adapter since passed in adapter may be shared

gptqmodel/nn_modules/qlinear/exllama.py

Lines changed: 17 additions & 16 deletions
@@ -43,15 +43,6 @@ def ext_make_q4(qweight, qzeros, scales, g_idx, device):
     return make_q4(qweight, qzeros, scales, g_idx if g_idx is not None else NON_TENSOR, device)
 
 
-def ext_q4_matmul(x, q4, q4_width):
-    """Matrix multiplication, returns x @ q4"""
-    outshape = x.shape[:-1] + (q4_width,)
-    x = x.view(-1, x.shape[-1])
-    output = torch.empty((x.shape[0], q4_width), dtype=torch.float16, device=x.device)
-
-    q4_matmul(x, q4, output)
-
-    return output.view(outshape)
 
 
 class ExllamaQuantLinear(BaseQuantLinear):
@@ -151,6 +142,22 @@ def post_init(self):
 
         super().post_init()
 
+    def ext_q4_matmul(self, x, q4, q4_width):
+        """Matrix multiplication, returns x @ q4"""
+        outshape = x.shape[:-1] + (q4_width,)
+        x = x.view(-1, x.shape[-1])
+
+        output = torch.empty((x.shape[0], q4_width), dtype=torch.float16, device=x.device)
+        q4_matmul(x, q4, output)
+
+        if self.bias is not None:
+            output.add_(self.bias)
+
+        if self.adapter:
+            output = self.adapter.apply(x=x, out=output)
+
+        return output.view(outshape)
+
 
     def forward(self, x):
         x_dtype = x.dtype
@@ -166,12 +173,6 @@ def forward(self, x):
         # if x.size(-1) != self.in_features:
         #     x = F.pad(x, self.in_features_padding_shape)
 
-        out = ext_q4_matmul(x, self.q4, self.width)
-
-        if self.bias is not None:
-            out.add_(self.bias)
-
-        if self.adapter:
-            out = self.adapter.apply(x=x, out=out)
+        out = self.ext_q4_matmul(x, self.q4, self.width)
 
         return out.to(x_dtype)

gptqmodel/nn_modules/qlinear/marlin.py

Lines changed: 27 additions & 16 deletions
@@ -23,6 +23,7 @@
 import torch
 from gptqmodel.adapter.adapter import Adapter, Lora
 from gptqmodel.nn_modules.qlinear import BaseQuantLinear
+from gptqmodel.utils.backend import BACKEND
 from torch.nn.parameter import Parameter
 
 from ...models._const import DEVICE, PLATFORM
@@ -133,23 +134,29 @@ def apply_gptq_marlin_linear(
         output_size_per_partition: int,
         input_size_per_partition: int,
         is_k_full: bool,
-        bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+        bias: torch.Tensor,
+        fp32: bool,
+) -> torch.Tensor:
+
     reshaped_x = input.reshape(-1, input.shape[-1])
     out_shape = input.shape[:-1] + (output_size_per_partition, )
 
-    output = gptqmodel_marlin_kernels.gptq_marlin_gemm(reshaped_x,
-                                                       weight,
-                                                       weight_scale,
-                                                       weight_zp,
-                                                       g_idx,
-                                                       g_idx_sort_indices,
-                                                       workspace,
-                                                       num_bits,
-                                                       reshaped_x.shape[0],
-                                                       output_size_per_partition,
-                                                       input_size_per_partition,
-                                                       is_k_full,
-                                                       False)
+    output = gptqmodel_marlin_kernels.gptq_marlin_gemm(
+        reshaped_x,
+        weight,
+        weight_scale,
+        weight_zp,
+        g_idx,
+        g_idx_sort_indices,
+        workspace,
+        num_bits,
+        reshaped_x.shape[0],
+        output_size_per_partition,
+        input_size_per_partition,
+        is_k_full,
+        False,
+        fp32, # <- True: enable fp32 reduce for higher accuracy, False: fp16
+    )
 
     if bias is not None:
         output.add_(bias) # In-place add
@@ -191,8 +198,8 @@ def __init__(
                 f"Trying to use the marlin backend, but could not import the C++/CUDA dependencies with the following error: {marlin_import_exception}"
             )
 
-        self.original_in_features = in_features
-        self.original_out_features = out_features
+        # self.original_in_features = in_features
+        # self.original_out_features = out_features
 
         if desc_act and group_size == -1:
             # In this case, act_order == True is the same as act_order == False
@@ -212,6 +219,9 @@ def __init__(
                          register_buffers=False,
                          **kwargs)
 
+        # toggle fp32 mode depending on MARLIN or MARLIN_FP16 backend
+        self.fp32 = True if self.backend is BACKEND.MARLIN else False
+
         # Determine sharding
         if marlin_repeat_scales_on_all_ranks(desc_act,
                                              self.group_size,
@@ -390,6 +400,7 @@ def forward(self, A: torch.Tensor):
             input_size_per_partition=self.in_features,
             is_k_full=self.is_k_full,
             bias=self.bias,
+            fp32=self.fp32,
         )
 
         if self.adapter:
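The accuracy/speed switch reduces to a single boolean that MarlinQuantLinear derives from the selected backend and forwards to gptqmodel_marlin_kernels.gptq_marlin_gemm as its new last argument. A standalone sketch of that mapping (the helper name is illustrative, not part of the library):

from gptqmodel.utils.backend import BACKEND

def marlin_uses_fp32_reduce(backend: BACKEND) -> bool:
    # Mirrors the toggle in MarlinQuantLinear.__init__: MARLIN -> fp32 reduce
    # (higher accuracy), MARLIN_FP16 -> fp16 reduce (slightly faster).
    return backend is BACKEND.MARLIN

assert marlin_uses_fp32_reduce(BACKEND.MARLIN) is True
assert marlin_uses_fp32_reduce(BACKEND.MARLIN_FP16) is False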

gptqmodel/quantization/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -14,7 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from .config import (FORMAT, FORMAT_FIELD_CODE, FORMAT_FIELD_JSON,
-                     QUANT_CONFIG_FILENAME, QUANT_METHOD, QUANT_METHOD_FIELD, BaseQuantizeConfig, QuantizeConfig)
+from .config import (FORMAT, FORMAT_FIELD_CODE, FORMAT_FIELD_JSON, QUANT_CONFIG_FILENAME,
+                     QUANT_METHOD, QUANT_METHOD_FIELD, BaseQuantizeConfig, QuantizeConfig)
 from .gptq import GPTQ
 from .quantizer import Quantizer, quantize

gptqmodel/utils/backend.py

Lines changed: 12 additions & 11 deletions
@@ -20,15 +20,16 @@
 class BACKEND(str, Enum):
     AUTO = "auto" # choose the optimal local kernel based on quant_config compatibility
     AUTO_TRAINABLE = "auto_trainable" # choose the optimal trainable local kernel for post-quant training
-    CUDA = "cuda"
-    TORCH = "torch"
-    TRITON = "triton"
-    EXLLAMA_V1 = "exllama_v1"
-    EXLLAMA_V2 = "exllama_v2"
+    CUDA = "cuda" # OK: Performance same as Torch for most cases
+    TORCH = "torch" # GOOD: about 80% of triton
+    TRITON = "triton" # VERY GOOD: all-around kernel
+    EXLLAMA_V1 = "exllama_v1" # FAST: optimized for batching == 1
+    EXLLAMA_V2 = "exllama_v2" # FASTER: optimized for batching > 1
     # EXLLAMA_EORA = "exllama_eora"
-    MARLIN = "marlin"
-    BITBLAS = "bitblas"
-    IPEX = "ipex"
-    VLLM = "vllm" # external inference engine (CUDA + ROCM + IPEX)
-    SGLANG = "sglang" # external inference engine (CUDA + ROCm)
-    MLX = "mlx" # external inference engine (Apple MLX on M1+)
+    MARLIN = "marlin" # FASTEST: marlin reduce ops in fp32 (higher precision -> more accurate, slightly slower)
+    MARLIN_FP16 = "marlin_fp16" # FASTEST and then some: marlin reduce ops in fp16 (lower precision -> less accurate, slightly faster)
+    BITBLAS = "bitblas" # EXTREMELY FAST: speed at the cost of 10+ minutes of AOT (ahead of time compilation with disk cache)
+    IPEX = "ipex" # Best kernel for Intel XPU and Intel/AMD CPU with AVX512, AMX, XMX
+    VLLM = "vllm" # External inference engine: CUDA + ROCm + IPEX
+    SGLANG = "sglang" # External inference engine: CUDA + ROCm
+    MLX = "mlx" # External inference engine: Apple MLX on M1+ (Apple Silicon)
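Because BACKEND subclasses both str and Enum, the new member can be selected by its string value as well as by attribute, which is how config files and CLI flags typically pass it in. A quick sanity sketch:

from gptqmodel.utils.backend import BACKEND

# Value lookup and attribute access resolve to the same member.
assert BACKEND("marlin_fp16") is BACKEND.MARLIN_FP16
assert BACKEND.MARLIN_FP16.value == "marlin_fp16"

# Distinct members, even though both route to the same Marlin kernel class.
assert BACKEND.MARLIN is not BACKEND.MARLIN_FP16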

gptqmodel/utils/importer.py

Lines changed: 4 additions & 4 deletions
@@ -52,9 +52,9 @@
 })
 
 FORMAT_DICT = {
-    FORMAT.GPTQ: [BACKEND.MARLIN, BACKEND.EXLLAMA_V2, BACKEND.EXLLAMA_V1, BACKEND.TRITON, BACKEND.CUDA, BACKEND.IPEX, BACKEND.TORCH], # BACKEND.EXLLAMA_EORA
-    FORMAT.GPTQ_V2: [BACKEND.MARLIN, BACKEND.EXLLAMA_V2, BACKEND.EXLLAMA_V1, BACKEND.TRITON, BACKEND.CUDA, BACKEND.TORCH], # , BACKEND.EXLLAMA_EORA
-    FORMAT.MARLIN: [BACKEND.MARLIN],
+    FORMAT.GPTQ: [BACKEND.MARLIN, BACKEND.EXLLAMA_V2, BACKEND.EXLLAMA_V1, BACKEND.TRITON, BACKEND.CUDA, BACKEND.IPEX, BACKEND.TORCH, BACKEND.MARLIN_FP16], # BACKEND.EXLLAMA_EORA
+    FORMAT.GPTQ_V2: [BACKEND.EXLLAMA_V2, BACKEND.EXLLAMA_V1, BACKEND.TRITON, BACKEND.CUDA, BACKEND.TORCH], # , BACKEND.EXLLAMA_EORA
+    FORMAT.MARLIN: [BACKEND.MARLIN, BACKEND.MARLIN_FP16],
     FORMAT.BITBLAS: [BACKEND.BITBLAS],
     FORMAT.IPEX: [BACKEND.IPEX],
 }
@@ -228,7 +228,7 @@ def select_quant_linear(
             qlinear = TritonV2QuantLinear
         elif backend == BACKEND.BITBLAS:
             qlinear = BitBLASQuantLinear
-        elif backend == BACKEND.MARLIN:
+        elif backend in [BACKEND.MARLIN, BACKEND.MARLIN_FP16]:
             qlinear = MarlinQuantLinear
         # elif backend == BACKEND.EXLLAMA_EORA:
         #     qlinear = ExllamaEoraQuantLinear
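FORMAT_DICT is the lookup that decides which backends may serve which checkpoint formats, and select_quant_linear then maps either Marlin backend onto the same MarlinQuantLinear class. A hedged sketch of querying that table (the helper is illustrative, not a library function):

from gptqmodel.quantization import FORMAT
from gptqmodel.utils.backend import BACKEND
from gptqmodel.utils.importer import FORMAT_DICT

def backend_supports_format(backend: BACKEND, fmt: FORMAT) -> bool:
    # A backend is considered usable only if FORMAT_DICT lists it for the format.
    return backend in FORMAT_DICT.get(fmt, [])

# Per this commit, both Marlin variants can load FORMAT.MARLIN checkpoints.
assert backend_supports_format(BACKEND.MARLIN, FORMAT.MARLIN)
assert backend_supports_format(BACKEND.MARLIN_FP16, FORMAT.MARLIN)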

gptqmodel/utils/marlin.py

Lines changed: 1 addition & 1 deletion
@@ -40,4 +40,4 @@ def _validate_marlin_compatibility(cfg: QuantizeConfig, throw_error: bool = Fals
     validate, err = MarlinQuantLinear.validate(bits=cfg.bits, group_size=cfg.group_size, desc_act=cfg.desc_act, sym=cfg.sym, pack_dtype=cfg.pack_dtype, dynamic=cfg.dynamic)
     if throw_error and err is not None:
         raise ValueError(err)
-    return err
+    return err

(whitespace-only change: the removed and added lines are identical, consistent with adding a trailing newline at end of file)

gptqmodel/utils/model.py

Lines changed: 3 additions & 0 deletions
@@ -226,6 +226,7 @@ def make_quant(
             device=device,
             lm_head_name=lm_head_name,
             pack_dtype=pack_dtype,
+            backend=backend,
             adapter=qcfg.adapter,
         )
         logger.info(f"Kernel: selected -> `{linear_cls.__name__}`.")
@@ -252,6 +253,7 @@ def create_quant_layer(
         device: DEVICE,
         lm_head_name: str,
         pack_dtype: torch.dtype,
+        backend: BACKEND,
         adapter: Optional[Adapter] = None,
) -> Type[BaseQuantLinear]:
     if isinstance(module, linear_cls):
@@ -334,6 +336,7 @@ def create_quant_layer(
             #weight_dtype=submodule.qweight.dtype if isinstance(submodule, BaseQuantLinear) else submodule.weight.dtype,
             name=name,
             lm_head_name=lm_head_name,
+            backend=backend,
             adapter=adapter,
         )
         new_layer.device = ori_layer_device

0 commit comments