|
24 | 24 | import torch |
25 | 25 | import transformers |
26 | 26 |
|
| 27 | +from ..nn_modules.qlinear.exllama_eora import ExllamaEoraQuantLinear |
| 28 | +from ..nn_modules.qlinear.marlin import MarlinQuantLinear |
| 29 | + |
27 | 30 | if os.getenv('GPTQMODEL_USE_MODELSCOPE', 'False').lower() in ['true', '1']: |
28 | 31 | try: |
29 | 32 | from modelscope import snapshot_download |
@@ -342,15 +345,14 @@ def from_quantized( |
342 | 345 | raise TypeError(f"FORMAT.MARLIN requires BACKEND.AUTO or BACKEND.MARLIN: actual = `{backend}`.") |
343 | 346 | backend = BACKEND.MARLIN |
344 | 347 |
|
345 | | - marlin_compatible = False if backend == BACKEND.IPEX else _validate_marlin_device_support() |
346 | | - |
347 | | - # check for marlin compat for cuda device onnly |
348 | | - if backend not in [BACKEND.MARLIN, BACKEND.MARLIN_FP16] and device == DEVICE.CUDA: |
349 | | - unsupported = _validate_marlin_compatibility(qcfg) |
350 | | - if unsupported is None and marlin_compatible: |
351 | | - logger.info( |
352 | | - "Hint: Model is compatible with the Marlin kernel. Marlin is optimized for batched inference on Nvidia GPU: `model = GPTQModel.load(..., backend=BACKEND.MARLIN)`." |
353 | | - ) |
| 348 | + # marlin_compatible = False if backend == BACKEND.IPEX else _validate_marlin_device_support() |
| 349 | + # check for marlin compat for cuda device only |
| 350 | + # if backend not in [BACKEND.MARLIN, BACKEND.MARLIN_FP16] and device == DEVICE.CUDA: |
| 351 | + # unsupported = _validate_marlin_compatibility(qcfg) |
| 352 | + # if unsupported is None and marlin_compatible: |
| 353 | + # logger.info( |
| 354 | + # "Hint: Model is compatible with the Marlin kernel. Marlin is optimized for batched inference on Nvidia GPU: `model = GPTQModel.load(..., backend=BACKEND.MARLIN)`." |
| 355 | + # ) |
354 | 356 |
|
355 | 357 | if qcfg.format == FORMAT.BITBLAS: |
356 | 358 | # format bitblas requires bitblas kernel |
@@ -491,14 +493,16 @@ def skip(*args, **kwargs): |
491 | 493 | f"Format: Loading of a sym=False model with format={FORMAT.GPTQ} is only supported if produced by gptqmodel version >= {MIN_VERSION_WITH_V2}" |
492 | 494 | ) |
493 | 495 |
|
494 | | - t = time.time() |
495 | | - logger.info(f"Format: Converting `{FORMAT_FIELD_JSON}` from `{FORMAT.GPTQ}` to internal `{FORMAT.GPTQ_V2}`.") |
496 | | - model = convert_gptq_v1_to_v2_format( |
497 | | - model, |
498 | | - cfg=qcfg, |
499 | | - qlinear_kernel=preload_qlinear_kernel, |
500 | | - ) |
501 | | - logger.info(f"Format: Conversion complete: {time.time() - t}s") |
| 496 | + # skip v1 to v2 conversion for kernels that can only operate on sym=True (gptq_v1) |
| 497 | + if preload_qlinear_kernel not in [IPEXQuantLinear, MarlinQuantLinear, ExllamaEoraQuantLinear]: |
| 498 | + t = time.time() |
| 499 | + logger.info(f"Format: Converting `{FORMAT_FIELD_JSON}` from `{FORMAT.GPTQ}` to internal `{FORMAT.GPTQ_V2}`.") |
| 500 | + model = convert_gptq_v1_to_v2_format( |
| 501 | + model, |
| 502 | + cfg=qcfg, |
| 503 | + qlinear_kernel=preload_qlinear_kernel, |
| 504 | + ) |
| 505 | + logger.info(f"Format: Conversion complete: {time.time() - t}s") |
502 | 506 |
|
503 | 507 | load_checkpoint_in_model = False |
504 | 508 | qcfg.runtime_format = FORMAT.GPTQ_V2 |
|
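For context, a hedged usage sketch of what the second hunk means for callers (the model id below is a placeholder, not from this PR): when one of the listed kernels is preloaded, the checkpoint stays in GPTQ v1 form at load time instead of being repacked to the internal v2 layout.

```python
from gptqmodel import GPTQModel, BACKEND

# Placeholder model id; any sym=True GPTQ checkpoint would do for this sketch.
model = GPTQModel.load("some-org/some-model-GPTQ", backend=BACKEND.MARLIN)

# With the new guard, kernels such as Marlin, ExLlama-EoRA and IPEX consume the
# GPTQ v1 tensors directly, so the "Converting ... to internal GPTQ_V2" step
# (and its load-time cost) should be skipped for those backends.
```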