Commit 2acc761

fix bitblas loading regression (#1324)
Signed-off-by: Qubitium <[email protected]>
1 parent 26d6911 commit 2acc761
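
For context, a minimal usage sketch of the load path this commit repairs. The checkpoint path below is a placeholder, and the top-level `GPTQModel` / `BACKEND` imports are assumed from the public API that the hint string in loader.py references (`model = GPTQModel.load(..., backend=BACKEND.MARLIN)`), here pointed at the BitBLAS kernel instead:

```python
# Sketch only: load a BitBLAS-format GPTQ checkpoint and pin the BitBLAS kernel,
# i.e. the code path whose loading regression this commit fixes.
from gptqmodel import BACKEND, GPTQModel  # assumed top-level exports

model = GPTQModel.load(
    "path/to/bitblas-quantized-model",  # placeholder checkpoint path
    backend=BACKEND.BITBLAS,            # same enum value loader.py/bitblas.py now pass through
)
```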

3 files changed: +23 -22 lines changed


gptqmodel/models/loader.py

Lines changed: 21 additions & 17 deletions
```diff
@@ -24,6 +24,9 @@
 import torch
 import transformers
 
+from ..nn_modules.qlinear.exllama_eora import ExllamaEoraQuantLinear
+from ..nn_modules.qlinear.marlin import MarlinQuantLinear
+
 if os.getenv('GPTQMODEL_USE_MODELSCOPE', 'False').lower() in ['true', '1']:
     try:
         from modelscope import snapshot_download
@@ -342,15 +345,14 @@ def from_quantized(
                 raise TypeError(f"FORMAT.MARLIN requires BACKEND.AUTO or BACKEND.MARLIN: actual = `{backend}`.")
             backend = BACKEND.MARLIN
 
-        marlin_compatible = False if backend == BACKEND.IPEX else _validate_marlin_device_support()
-
-        # check for marlin compat for cuda device onnly
-        if backend not in [BACKEND.MARLIN, BACKEND.MARLIN_FP16] and device == DEVICE.CUDA:
-            unsupported = _validate_marlin_compatibility(qcfg)
-            if unsupported is None and marlin_compatible:
-                logger.info(
-                    "Hint: Model is compatible with the Marlin kernel. Marlin is optimized for batched inference on Nvidia GPU: `model = GPTQModel.load(..., backend=BACKEND.MARLIN)`."
-                )
+        # marlin_compatible = False if backend == BACKEND.IPEX else _validate_marlin_device_support()
+        # check for marlin compat for cuda device only
+        # if backend not in [BACKEND.MARLIN, BACKEND.MARLIN_FP16] and device == DEVICE.CUDA:
+        #     unsupported = _validate_marlin_compatibility(qcfg)
+        #     if unsupported is None and marlin_compatible:
+        #         logger.info(
+        #             "Hint: Model is compatible with the Marlin kernel. Marlin is optimized for batched inference on Nvidia GPU: `model = GPTQModel.load(..., backend=BACKEND.MARLIN)`."
+        #         )
 
         if qcfg.format == FORMAT.BITBLAS:
             # format bitblas requires bitblas kernel
@@ -491,14 +493,16 @@ def skip(*args, **kwargs):
                     f"Format: Loading of a sym=False model with format={FORMAT.GPTQ} is only supported if produced by gptqmodel version >= {MIN_VERSION_WITH_V2}"
                 )
 
-            t = time.time()
-            logger.info(f"Format: Converting `{FORMAT_FIELD_JSON}` from `{FORMAT.GPTQ}` to internal `{FORMAT.GPTQ_V2}`.")
-            model = convert_gptq_v1_to_v2_format(
-                model,
-                cfg=qcfg,
-                qlinear_kernel=preload_qlinear_kernel,
-            )
-            logger.info(f"Format: Conversion complete: {time.time() - t}s")
+            # skip v1 to v2 conversion for kernels that can only operate on sym=True (gptq_v1)
+            if preload_qlinear_kernel not in [IPEXQuantLinear, MarlinQuantLinear, ExllamaEoraQuantLinear]:
+                t = time.time()
+                logger.info(f"Format: Converting `{FORMAT_FIELD_JSON}` from `{FORMAT.GPTQ}` to internal `{FORMAT.GPTQ_V2}`.")
+                model = convert_gptq_v1_to_v2_format(
+                    model,
+                    cfg=qcfg,
+                    qlinear_kernel=preload_qlinear_kernel,
+                )
+                logger.info(f"Format: Conversion complete: {time.time() - t}s")
 
             load_checkpoint_in_model = False
             qcfg.runtime_format = FORMAT.GPTQ_V2
```

gptqmodel/nn_modules/qlinear/bitblas.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -25,6 +25,7 @@
 import torch.nn as nn
 from gptqmodel.adapter.adapter import Adapter, Lora
 from gptqmodel.nn_modules.qlinear import PackableQuantLinear
+from gptqmodel.utils import BACKEND
 
 from ...models._const import DEVICE, PLATFORM
 from ...utils.logger import setup_logger
@@ -140,6 +141,7 @@ def __init__(
             out_features=out_features,
             bias=bias,
             pack_dtype=pack_dtype,
+            backend=BACKEND.BITBLAS,
             adapter=adapter,
             register_buffers=False,
             **kwargs)
```

gptqmodel/utils/model.py

Lines changed: 0 additions & 5 deletions
```diff
@@ -357,16 +357,11 @@ def hf_convert_gptq_v1_to_v2_format(
     else:
         return model, False
 
-# TODO: FIXME: the v1 -> v2 zeropoint offsets are assuming INT32 pack_dtype
 def convert_gptq_v1_to_v2_format(
     model,
     cfg: QuantizeConfig,
     qlinear_kernel: Type[BaseQuantLinear],
 ):
-    # skip v1 to v2 conversion for kernels that can only operate on sym=True (gptq_v1)
-    if qlinear_kernel in [IPEXQuantLinear, MarlinQuantLinear, ExllamaEoraQuantLinear]:
-        return model
-
     # Limit thread usage to avoid auto-parallizataion regression
     with tctl.threadpool_limits(limits=1):
         for _, submodule in model.named_modules():
```
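
Taken together, the loader.py and model.py hunks move the "sym=True-only kernel" check out of `convert_gptq_v1_to_v2_format()` and up to its call site in the loader. A condensed, illustrative restatement of that call-site guard (names as imported at the top of loader.py; not additional code from this commit):

```python
# Kernels that consume the on-disk GPTQ (v1) layout directly keep it;
# everything else is rewritten to the internal GPTQ_V2 layout at load time.
if preload_qlinear_kernel not in [IPEXQuantLinear, MarlinQuantLinear, ExllamaEoraQuantLinear]:
    model = convert_gptq_v1_to_v2_format(
        model,
        cfg=qcfg,
        qlinear_kernel=preload_qlinear_kernel,
    )
```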
