Merged
Changes from all commits
3 changes: 3 additions & 0 deletions gptqmodel/models/loader.py
@@ -16,6 +16,7 @@
from __future__ import annotations

import os
import time
from importlib.metadata import PackageNotFoundError, version
from typing import Dict, List, Optional, Union

@@ -464,13 +465,15 @@ def skip(*args, **kwargs):
f"Loading of a sym=False model with format={FORMAT.GPTQ} is only supported if produced by gptqmodel version >= {MIN_VERSION_WITH_V2}"
)

t = time.time()
logger.info(
f"Compatibility: converting `{FORMAT_FIELD_JSON}` from `{FORMAT.GPTQ}` to `{FORMAT.GPTQ_V2}`.")
model = convert_gptq_v1_to_v2_format(
model,
quantize_config=quantize_config,
qlinear_kernel=preload_qlinear_kernel,
)
logger.info(f"Conversion complete: {time.time()-t}s")
load_checkpoint_in_model = True
quantize_config.runtime_format = FORMAT.GPTQ_V2

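For reference, the new lines in loader.py simply time the v1-to-v2 compatibility conversion with a `time.time()` delta and log it. A minimal sketch of that pattern factored into a reusable helper; the `timed` name and logger setup are illustrative and not part of this PR:

```python
import logging
import time

logger = logging.getLogger(__name__)

def timed(label, fn, *args, **kwargs):
    # Record wall-clock time before the call and log the elapsed seconds after.
    t = time.time()
    result = fn(*args, **kwargs)
    logger.info(f"{label} complete: {time.time() - t}s")
    return result
```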
7 changes: 7 additions & 0 deletions gptqmodel/nn_modules/qlinear/torch.py
@@ -234,6 +234,13 @@ def _forward(self, x, x_dtype, out_shape):
out = out + self.bias if self.bias is not None else out
return out

# clear gptq-only weights: useful after de-quantization
def _empty_gptq_only_weights(self):
self.qzeros = None
self.qweight = None
self.g_idx = None
self.scales = None

@torch.no_grad()
def dequantize_weight(self, num_itr=1):
if self.wf.device != self.qzeros.device:
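A hedged usage sketch of the new helper: once the dense weight has been materialized via `dequantize_weight()`, the packed GPTQ buffers can be dropped to free memory. Only `dequantize_weight()` and `_empty_gptq_only_weights()` come from this diff; the `dequantize_and_release` wrapper is hypothetical:

```python
import torch
from gptqmodel.nn_modules.qlinear.torch import TorchQuantLinear

def dequantize_and_release(module: TorchQuantLinear) -> torch.Tensor:
    # Materialize the dense fp16 weight on CPU (mirrors the mlx.py call site)...
    dense = module.dequantize_weight().T.detach().to("cpu", torch.float16)
    # ...then drop qweight/qzeros/g_idx/scales so they can be reclaimed.
    module._empty_gptq_only_weights()
    return dense
```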
24 changes: 22 additions & 2 deletions gptqmodel/utils/mlx.py
@@ -1,3 +1,10 @@
import gc
import logging

import mlx.core.metal

from .progress import ProgressBar
from .torch import torch_empty_cache
from ..quantization import QuantizeConfig, FORMAT
from transformers import PreTrainedModel
from ..nn_modules.qlinear.torch import TorchQuantLinear
@@ -39,12 +46,23 @@ def convert_gptq_to_mlx_weights(model_id_or_path: str, gptq_model: PreTrainedMod

# Convert weights
weights = {}
for name, module in gptq_model.named_modules():
n = 1
pb = ProgressBar(gptq_model.named_modules(), total=len(list(gptq_model.named_modules())))
for name, module in pb:
pb.set_description(f" Converting to mlx: {name}")
if isinstance(module, TorchQuantLinear):
weights[f"{name}.weight"] = mx.array(
module.dequantize_weight().T.detach().to("cpu", torch.float16).numpy()
)

module._empty_gptq_only_weights()

if n % 14 == 0:
# Clearing the cache below saves memory but also makes each iteration slower: call it only every N loops
torch_empty_cache()

n += 1

elif hasattr(module, "weight") and (
name != "lm_head" if config.get("tie_word_embeddings", False) else True):
weights[f"{name}.weight"] = mx.array(
@@ -57,6 +75,8 @@ def convert_gptq_to_mlx_weights(model_id_or_path: str, gptq_model: PreTrainedMod
module.bias.detach().to("cpu", torch.float16).numpy()
)

torch_empty_cache()

# Load and quantize weights
mlx_model.load_weights(list(weights.items()))
weights, mlx_config = quantize_model(mlx_model, config, q_group_size=gptq_config["group_size"],
@@ -121,4 +141,4 @@ def mlx_generate(model, tokenizer, **kwargs,):
if "min_tokens_to_keep" in kwargs:
sampling_params["min_tokens_to_keep"] = kwargs.pop("min_tokens_to_keep", None)

return generate(model=model, tokenizer=tokenizer, prompt=prompt, formatter=formatter ,verbose=verbose, **sampling_params)
return generate(model=model, tokenizer=tokenizer, prompt=prompt, formatter=formatter ,verbose=verbose, **sampling_params)
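The throttled cache clearing above (every 14th converted module) trades peak memory against per-iteration overhead. A minimal standalone sketch of that pattern, assuming only the repo's `torch_empty_cache()` helper; the `convert_one` callback and `EMPTY_CACHE_EVERY_N` constant are illustrative:

```python
from gptqmodel.utils.torch import torch_empty_cache

EMPTY_CACHE_EVERY_N = 14  # the diff empties caches on every 14th module

def convert_all(named_modules, convert_one):
    out = {}
    for n, (name, module) in enumerate(named_modules, start=1):
        out[name] = convert_one(module)
        # Clearing on every iteration saves the most memory but slows each
        # step, so the cache is only emptied every Nth module.
        if n % EMPTY_CACHE_EVERY_N == 0:
            torch_empty_cache()
    return out
```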
1 change: 1 addition & 0 deletions gptqmodel/utils/model.py
@@ -16,6 +16,7 @@
from __future__ import annotations

import functools
import gc
import hashlib
import json
import operator
13 changes: 13 additions & 0 deletions gptqmodel/utils/torch.py
@@ -20,6 +20,7 @@
HAS_CUDA = False
HAS_XPU = False
HAS_MPS = False
HAS_MLX = False

if hasattr(torch, "cuda") and hasattr(torch.cuda, "is_available") and torch.cuda.is_available():
HAS_CUDA = True
@@ -30,6 +31,12 @@
if hasattr(torch, "mps") and hasattr(torch.mps, "is_available") and torch.mps.is_available():
HAS_MPS = True

# mlx check
try:
import mlx.core.metal
HAS_MLX = True
except BaseException:
pass

def torch_sync(device: torch.device = None):
# check all backends
@@ -61,6 +68,8 @@ def torch_empty_cache(device: torch.device = None, gc: bool = True):
torch.xpu.empty_cache()
if HAS_MPS:
torch.mps.empty_cache()
if HAS_MLX:
mlx.core.metal.clear_cache()
return

# if device passed, only execute for device backend
@@ -70,3 +79,7 @@
torch.xpu.empty_cache()
elif device.type == "mps":
torch.mps.empty_cache()

# mlx is detached from pytorch
if HAS_MLX:
mlx.core.metal.clear_cache()
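With the new `HAS_MLX` flag, a single `torch_empty_cache()` call now covers the MLX Metal allocator as well as the torch backends. A small usage sketch; the tensor is just a stand-in for a memory-heavy step:

```python
import torch
from gptqmodel.utils.torch import torch_empty_cache

buf = torch.empty(4096, 4096, dtype=torch.float16)  # stand-in for a heavy intermediate
del buf
# Runs Python gc, empties whichever of the CUDA/XPU/MPS caches are present,
# and, when mlx is importable, clears MLX's Metal cache too.
torch_empty_cache()
```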