|
47 | 47 | from ..utils.backend import BACKEND |
48 | 48 | from ..utils.importer import auto_select_device, normalize_device_device_map, select_quant_linear |
49 | 49 | from ..utils.logger import setup_logger |
50 | | -from ..utils.marlin import (_validate_marlin_compatibility, |
51 | | - _validate_marlin_device_support) |
| 50 | +from ..utils.marlin import _validate_marlin_compatibility, _validate_marlin_device_support |
52 | 51 | from ..utils.model import (auto_dtype, convert_gptq_v1_to_v2_format, find_modules, get_checkpoints, |
53 | 52 | get_moe_layer_modules, gptqmodel_post_init, load_checkpoint_in_model_then_tie_weights, |
54 | 53 | make_quant, simple_dispatch_model, verify_model_hash, verify_sharded_model_hashes) |
@@ -339,14 +338,14 @@ def from_quantized( |
339 | 338 |
|
340 | 339 | if qcfg.format == FORMAT.MARLIN: |
341 | 340 | # format marlin requires marlin kernel |
342 | | - if backend != BACKEND.MARLIN and backend != BACKEND.AUTO: |
| 341 | + if backend not in [BACKEND.MARLIN, BACKEND.MARLIN_FP16] and backend != BACKEND.AUTO: |
343 | 342 | raise TypeError(f"FORMAT.MARLIN requires BACKEND.AUTO or BACKEND.MARLIN: actual = `{backend}`.") |
344 | 343 | backend = BACKEND.MARLIN |
345 | 344 |
|
346 | 345 | marlin_compatible = False if backend == BACKEND.IPEX else _validate_marlin_device_support() |
347 | 346 |
|
348 | 347 |         # check for marlin compat for cuda device only
349 | | - if backend != BACKEND.MARLIN and device == DEVICE.CUDA: |
| 348 | + if backend not in [BACKEND.MARLIN, BACKEND.MARLIN_FP16] and device == DEVICE.CUDA: |
350 | 349 | unsupported = _validate_marlin_compatibility(qcfg) |
351 | 350 | if unsupported is None and marlin_compatible: |
352 | 351 | logger.info( |
@@ -504,7 +503,7 @@ def skip(*args, **kwargs): |
504 | 503 | load_checkpoint_in_model = False |
505 | 504 | qcfg.runtime_format = FORMAT.GPTQ_V2 |
506 | 505 |
|
507 | | - if backend == BACKEND.MARLIN and ( |
| 506 | + if backend in [BACKEND.MARLIN, BACKEND.MARLIN_FP16] and ( |
508 | 507 | preload_qlinear_kernel == ExllamaV2QuantLinear or qcfg.format == FORMAT.MARLIN): |
509 | 508 | if is_sharded: |
510 | 509 | raise ValueError( |
@@ -541,7 +540,7 @@ def skip(*args, **kwargs): |
541 | 540 |
|
542 | 541 | # If we use marlin or bitblas to load the quantized model, the model is already a converted model, |
543 | 542 | # and we no longer need to call load_checkpoint_in_model() |
544 | | - if load_checkpoint_in_model and backend not in [BACKEND.MARLIN, BACKEND.BITBLAS]: |
| 543 | + if load_checkpoint_in_model and backend not in [BACKEND.MARLIN, BACKEND.MARLIN_FP16, BACKEND.BITBLAS]: |
545 | 544 | load_checkpoint_in_model_then_tie_weights( |
546 | 545 | model, |
547 | 546 | dtype=torch_dtype, |
|
0 commit comments