
Commit 2679106

cleanup mlx code (#1101)

* cleanup mlx code
* format

1 parent 1c96a66 commit 2679106

23 files changed, +77 -74 lines changed

gptqmodel/models/auto.py

Lines changed: 8 additions & 8 deletions

@@ -34,11 +34,10 @@
 from huggingface_hub import list_repo_files # noqa: E402
 from transformers import AutoConfig # noqa: E402

-from ..quantization import QUANT_CONFIG_FILENAME, FORMAT # noqa: E402
+from ..quantization import QUANT_CONFIG_FILENAME # noqa: E402
 from ..utils import BACKEND, EVAL # noqa: E402
 from ..utils.logger import setup_logger # noqa: E402
 from ..utils.model import check_and_get_model_type # noqa: E402
-from ..nn_modules.qlinear.torch import TorchQuantLinear
 from .base import BaseGPTQModel, QuantizeConfig # noqa: E402
 from .definitions.baichuan import BaiChuanGPTQ # noqa: E402
 from .definitions.bloom import BloomGPTQ # noqa: E402

@@ -341,28 +340,29 @@ def eval(
             return results
         else:
             raise ValueError("Eval framework support: EVAL.LM_EVAL, EVAL.EVALPLUS")
-
+

     @staticmethod
     def export(model_id_or_path: str, target_path: str, format: str, trust_remote_code: bool = False):
         # load config
         config = AutoConfig.from_pretrained(model_id_or_path, trust_remote_code=trust_remote_code)

         if not config.quantization_config:
             raise ValueError("Model is not quantized")
-
+
         gptq_config = config.quantization_config
-
+
         # load gptq model
         gptq_model = GPTQModel.load(model_id_or_path, backend=BACKEND.TORCH)

         if format == "mlx":
             try:
-                from mlx_lm.utils import save_weights, save_config
+                from mlx_lm.utils import save_config, save_weights
+
                 from ..utils.mlx import convert_gptq_to_mlx_weights
             except ImportError:
                 raise ValueError("MLX not installed. Please install via `pip install gptqmodel[mlx] --no-build-isolation`.")
-
-            mlx_weights, mlx_config = convert_gptq_to_mlx_weights(model_id_or_path, gptq_model.model, gptq_config)
+
+            mlx_weights, mlx_config = convert_gptq_to_mlx_weights(model_id_or_path, gptq_model, gptq_config)

             save_weights(target_path, mlx_weights, donate_weights=True)
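Usage note for the hunk above: `export` now hands the whole `gptq_model` wrapper, rather than `gptq_model.model`, to `convert_gptq_to_mlx_weights`. A minimal sketch of driving this MLX export path from user code, assuming the top-level `GPTQModel` import; the model id and output directory are placeholders:

# Sketch only: the model id and target path are placeholders, not values from this commit.
from gptqmodel import GPTQModel

GPTQModel.export(
    model_id_or_path="your-org/your-model-gptq-4bit",  # any GPTQ-quantized repo id or local path
    target_path="./model-mlx",                         # directory that receives the MLX weights/config
    format="mlx",                                      # routes through convert_gptq_to_mlx_weights above
)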

gptqmodel/models/base.py

Lines changed: 8 additions & 9 deletions

@@ -15,17 +15,15 @@

 from __future__ import annotations

-import copy
 import json
 import os
 import shutil
 import time
-from typing import Any, Dict, List, Optional, Union, Tuple
+from typing import Any, Dict, List, Optional, Tuple, Union

 import torch
 import torch.nn as nn
 from packaging import version
-from torch import autocast
 from transformers import AutoModelForCausalLM, PreTrainedModel, PreTrainedTokenizerBase, modeling_utils

 from ..nn_modules.hooked_linear import replace_linear_with_hooked_linear

@@ -36,11 +34,12 @@
 from ..utils.device import get_cpu_usage_memory, get_gpu_usage_memory
 from ..utils.importer import select_quant_linear
 from ..utils.logger import setup_logger
-from ..utils.model import (MODALITY, check_to_quantized, find_layers, get_device, get_module_by_name_prefix,
-                           get_moe_layer_modules, move_to, nested_move_to, normalize_tokenizer, pack_model, get_module)
+from ..utils.model import (MODALITY, check_to_quantized, find_layers, get_device,
+                           get_module, get_module_by_name_prefix, get_moe_layer_modules,
+                           move_to, nested_move_to, normalize_tokenizer, pack_model)
 from ..utils.progress import ProgressBar
 from ..utils.torch import torch_empty_cache
-from ._const import CPU, DEVICE, CUDA, SUPPORTS_MODULE_TYPES
+from ._const import CPU, DEVICE, SUPPORTS_MODULE_TYPES
 from .loader import ModelLoader
 from .writer import (QUANT_LOG_DAMP, QUANT_LOG_FWD_TIME, QUANT_LOG_LAYER,
                      QUANT_LOG_LOSS, QUANT_LOG_MODULE, QUANT_LOG_TIME, ModelWriter)

@@ -402,8 +401,8 @@ def collate_batch(batch):
             tied_keys = self.model._tied_weights_keys
             for item in tied_keys:
                 if self.lm_head in item:
-                    raise NotImplementedError(f"quantizing lm_head with tied weights has not been supported "
-                                              f"currently")
+                    raise NotImplementedError("quantizing lm_head with tied weights has not been supported "
+                                              "currently")

             lm_head_module = get_module(self.model, key=self.lm_head)
             if get_module(self.model, key=self.lm_head) is None:

@@ -566,7 +565,7 @@ def store_lm_head_input_hook(_, args, kwargs):
         for i in layer_pb:
             is_lm_head = i >= layer_count
             if is_lm_head:
-                layer_pb.set_description(f"Quantizing lm_head")
+                layer_pb.set_description("Quantizing lm_head")
                 layer = get_module(self.model, key=self.lm_head)
                 if self.quantize_config.lm_head and not self.quantize_config.lm_head_low_gpu_mem_usage:
                     layer_inputs = lm_head_inputs
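Side note on the lm_head hunks above: the branch is gated by `quantize_config.lm_head`, and a tied lm_head still raises `NotImplementedError`. A minimal sketch of requesting lm_head quantization, assuming the top-level `QuantizeConfig` import; the `bits` and `group_size` values are illustrative, not prescribed by this commit:

# Sketch only: enables the quantize_config.lm_head path checked in base.py.
from gptqmodel import QuantizeConfig

qcfg = QuantizeConfig(bits=4, group_size=128, lm_head=True)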

gptqmodel/models/loader.py

Lines changed: 5 additions & 5 deletions

@@ -470,8 +470,7 @@ def skip(*args, **kwargs):
             )

             t = time.time()
-            logger.info(
-                f"Compatibility: converting `{FORMAT_FIELD_JSON}` from `{FORMAT.GPTQ}` to `{FORMAT.GPTQ_V2}`.")
+            logger.info(f"Converting `{FORMAT_FIELD_JSON}` from `{FORMAT.GPTQ}` to `{FORMAT.GPTQ_V2}`.")
             model = convert_gptq_v1_to_v2_format(
                 model,
                 quantize_config=quantize_config,

@@ -578,9 +577,10 @@ def skip(*args, **kwargs):
         if backend == BACKEND.MLX:
             import tempfile
             try:
-                from ..utils.mlx import convert_gptq_to_mlx_weights, mlx_generate
-                from mlx_lm.utils import save_weights, save_config
                 from mlx_lm import load
+                from mlx_lm.utils import save_config, save_weights
+
+                from ..utils.mlx import convert_gptq_to_mlx_weights, mlx_generate
             except ModuleNotFoundError as exception:
                 raise type(exception)(
                     "GPTQModel load mlx model required dependencies are not installed.",

@@ -593,7 +593,7 @@ def skip(*args, **kwargs):
                 save_weights(temp_dir, mlx_weights, donate_weights=True)
                 save_config(mlx_config, config_path=temp_dir + "/config.json")
                 tokenizer.save_pretrained(temp_dir)
-
+
                 model, _ = load(temp_dir)

                 cls.generate = lambda _, **kwargs: mlx_generate(model=model, tokenizer=tokenizer, **kwargs)
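For the BACKEND.MLX branch above: the GPTQ weights are converted, round-tripped through a temporary directory, reloaded with `mlx_lm.load`, and `generate` is rebound to `mlx_generate`. A rough usage sketch, assuming the top-level `GPTQModel`/`BACKEND` imports; the model id and the generate keyword arguments are assumptions, since whatever kwargs are passed are forwarded verbatim to `mlx_generate`:

# Sketch only: the repo id and generate kwargs are placeholders/assumptions.
from gptqmodel import BACKEND, GPTQModel

model = GPTQModel.load("your-org/your-model-gptq-4bit", backend=BACKEND.MLX)
print(model.generate(prompt="Hello from MLX"))  # forwarded unchanged to mlx_generate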

gptqmodel/nn_modules/hooked_linear.py

Lines changed: 1 addition & 0 deletions

@@ -16,6 +16,7 @@
 import torch
 from transformers.pytorch_utils import Conv1D

+
 # Models using conv1d: gpt2
 class HookedConv1D(Conv1D):
     def __init__(self, nf: int, nx: int) -> None:

gptqmodel/nn_modules/qlinear/marlin.py

Lines changed: 1 addition & 1 deletion

@@ -318,7 +318,7 @@ def validate_device(cls, device: DEVICE):
         if device == DEVICE.CUDA:
             if IS_ROCM:
                 raise NotImplementedError("Marlin kernel is not supported on ROCm.")
-
+
             if CUDA_VISIBLE_DEVICES is None:
                 has_cuda_v8 = all(torch.cuda.get_device_capability(i)[0] >= 8 for i in range(torch.cuda.device_count()))
             else:

gptqmodel/nn_modules/qlinear/utils.py

Lines changed: 2 additions & 1 deletion

@@ -1,5 +1,6 @@
 import torch

+
 # Copied from https://github.com/IST-DASLab/marlin/pull/1
 @torch.no_grad()
 def unpack_4bit_to_32bit_signed(qweight, qzeros):

@@ -35,4 +36,4 @@ def dequantize_4bits_weight(layer):
     scales = scales.repeat_interleave(group_size, dim=0)
     unpacked_qzeros = unpacked_qzeros.repeat_interleave(group_size, dim=0)
     unpacked_qweight = (unpacked_qweight - unpacked_qzeros) * scales
-    return unpacked_qweight.T, unpacked_qzeros
+    return unpacked_qweight.T, unpacked_qzeros
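On `dequantize_4bits_weight` above: the math is group-wise, with one scale and one zero point per `group_size` rows, both expanded via `repeat_interleave` before `(q - z) * s`. A toy sketch of that arithmetic with made-up shapes and values; it does not call the repo's `unpack_4bit_to_32bit_signed` helper:

# Toy illustration of the group-wise dequantization math; all values are invented.
import torch

group_size = 2
qweight = torch.tensor([[3., 7.], [1., 5.], [0., 15.], [8., 2.]])  # already-unpacked 4-bit values
qzeros = torch.tensor([[8., 8.], [8., 8.]])                        # one zero point per group
scales = torch.tensor([[0.10, 0.20], [0.05, 0.40]])                # one scale per group

scales = scales.repeat_interleave(group_size, dim=0)  # expand to one row per weight row
qzeros = qzeros.repeat_interleave(group_size, dim=0)
weight = (qweight - qzeros) * scales                  # dequantized fp weights
print(weight.T)                                       # transposed, matching the function's return layout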

gptqmodel/quantization/config.py

Lines changed: 2 additions & 3 deletions

@@ -19,12 +19,11 @@
 import re
 from dataclasses import dataclass, field, fields
 from importlib.metadata import version as pkg_version
-from os.path import isdir, join
+from os.path import join
 from typing import Any, Dict, List, Optional, Tuple, Union

 import torch
 from packaging import version
-from transformers.utils.hub import cached_file

 from ..utils.logger import setup_logger

@@ -367,7 +366,7 @@ def to_dict(self):

     def calculate_bits_per_weight(self):
         bpw = ((self.group_size * self.bits) + 16 * 2) / self.group_size
-        logger.info(f"Effective BPW (bits per weight): {bpw} bits")
+        logger.info(f"Effective Quantization BPW (bits per weight): {bpw} bpw, based on [bits: {self.bits}, group_size: {self.group_size}]")

 @dataclass
 class AutoRoundQuantizeConfig(QuantizeConfig):
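On the reworded log line above: the formula charges each group of `group_size` weights with `group_size * bits` payload bits plus `16 * 2` bits of per-group overhead (e.g. a 16-bit scale and a 16-bit zero point). A quick worked example with an assumed, typical 4-bit / group_size-128 configuration:

# Worked example of the bits-per-weight formula; the config values are assumed.
bits, group_size = 4, 128
bpw = ((group_size * bits) + 16 * 2) / group_size  # (512 + 32) / 128
print(bpw)  # 4.25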

gptqmodel/quantization/gptq.py

Lines changed: 1 addition & 1 deletion

@@ -127,7 +127,7 @@ def quantize(
         start = time.time()
         if self.device.type not in ["mps", "cpu"]:
             self.layer.weight.data = self.layer.weight.data.cpu()
-
+
         # TODO: waiting for pytorch implementation of ops for MPS
         if sys.platform == "darwin" and os.getenv("PYTORCH_ENABLE_MPS_FALLBACK") != "1":
             raise RuntimeError("For MacOS you must set env `PYTORCH_ENABLE_MPS_FALLBACK=1` before running quantization.")
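The macOS guard above requires `PYTORCH_ENABLE_MPS_FALLBACK=1`. A minimal way to satisfy it, set either in the launching shell (`export PYTORCH_ENABLE_MPS_FALLBACK=1`) or in Python before torch is imported so PyTorch picks it up:

# Set the MPS fallback flag before importing torch so ops missing on MPS can
# fall back to CPU during quantization, as the RuntimeError above demands.
import os
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

import torch  # noqa: E402  (imported after the env var is set)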

gptqmodel/utils/importer.py

Lines changed: 1 addition & 1 deletion

@@ -19,7 +19,6 @@

 import torch

-from .rocm import IS_ROCM
 from ..models._const import DEVICE, normalize_device
 from ..nn_modules.qlinear import BaseQuantLinear
 from ..nn_modules.qlinear.bitblas import BitBLASQuantLinear

@@ -33,6 +32,7 @@
 from ..quantization import FORMAT
 from ..utils.logger import setup_logger
 from . import BACKEND
+from .rocm import IS_ROCM
 from .torch import HAS_CUDA, HAS_MPS, HAS_XPU

 message_logged = False

gptqmodel/utils/logger.py

Lines changed: 3 additions & 1 deletion

@@ -15,7 +15,9 @@

 import logging

+# global static/shared logger instance
 logger = None
+
 def setup_logger():
     global logger
     if logger is not None:

@@ -27,6 +29,6 @@ def setup_logger():
     handler.setFormatter(formatter)
     logger.propagate = False
     logger.addHandler(handler)
-    logger.setLevel(logging.INFO)
+    logger.setLevel(logging.DEBUG)

     return logger
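On the logger changes above: `setup_logger` keeps a single module-level instance and the level is now DEBUG. A small sketch of the resulting behavior:

# Sketch only: repeated setup_logger() calls return the same shared instance,
# and debug records are emitted now that the level is DEBUG.
from gptqmodel.utils.logger import setup_logger

log = setup_logger()
assert log is setup_logger()  # the early-return path hands back the global logger
log.debug("visible at the new DEBUG level")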
