Skip to content

Commit 1f06127

Browse files
authored
marlin fp32 mode should also be enabled if kernel was selected due to backend.auto (#1318)
Signed-off-by: Qubitium <[email protected]>
1 parent d30c983 commit 1f06127

File tree

2 files changed

+13
-2
lines changed

2 files changed

+13
-2
lines changed

gptqmodel/nn_modules/qlinear/marlin.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from torch.nn.parameter import Parameter
2828

2929
from ...models._const import DEVICE, PLATFORM
30+
from ...utils.logger import setup_logger
3031
from ...utils.rocm import IS_ROCM
3132

3233
marlin_import_exception = None
@@ -35,6 +36,8 @@
3536
except ImportError as e:
3637
marlin_import_exception = e
3738

39+
logger = setup_logger()
40+
fp32_warning_logged = False
3841

3942
GPTQ_MARLIN_TILE = 16
4043
GPTQ_MARLIN_MIN_THREAD_N = 64
@@ -163,7 +166,6 @@ def apply_gptq_marlin_linear(
163166

164167
return output.reshape(out_shape)
165168

166-
167169
class MarlinQuantLinear(BaseQuantLinear):
168170
SUPPORTS_BITS = [4, 8]
169171
SUPPORTS_GROUP_SIZE = [-1, 32, 64, 128]
@@ -220,7 +222,12 @@ def __init__(
220222
**kwargs)
221223

222224
# toggle fp32 mode depending on MARLIN or MARLIN_FP16 backend
223-
self.fp32 = True if self.backend is BACKEND.MARLIN else False
225+
self.fp32 = True if self.backend in [BACKEND.MARLIN, BACKEND.AUTO] else False
226+
227+
global fp32_warning_logged
228+
if not fp32_warning_logged:
229+
fp32_warning_logged = True
230+
logger.warn("Kernel: Marlin FP16 mode is activated with reduced accuracy. Use default Marlin model for improved inference quality.")
224231

225232
# Determine sharding
226233
if marlin_repeat_scales_on_all_ranks(desc_act,

gptqmodel/utils/logger.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,10 @@ def _process(self, op: Callable, msg, *args, **kwargs):
8484
handler.flush = sys.stdout.flush
8585
logger.addHandler(handler)
8686

87+
# fix warnings about warn() deprecated
88+
if hasattr(logger, "warning"):
89+
logger.warn = logger.warning
90+
8791
return logger
8892

8993

0 commit comments

Comments
 (0)