16 changes: 8 additions & 8 deletions gptqmodel/models/base.py
@@ -435,7 +435,7 @@ def tmp(_, inp, out):
torch.cuda.empty_cache()

if self.quantize_config.format == FORMAT.BITBLAS:
from ..nn_modules.qlinear.qlinear_bitblas import QuantLinear as BitBLASQuantLinear
from ..nn_modules.qlinear.qlinear_bitblas import BitBLASQuantLinear

# BitBLASQuantLinear does not have a pack method and needs to be converted to BitBLAS format when saving.
logger.info("Converting model to BitBlas Format...")
@@ -515,15 +515,15 @@ def save_quantized(
# no need to set it back, no calculation below
if quantize_config.bits != 4:
cuda_name_modules = {}
from gptqmodel.nn_modules.qlinear.qlinear_cuda import BaseCudaQuantLinear
from gptqmodel.nn_modules.qlinear.qlinear_cuda import CudaQuantLinear
for name, module in model.named_modules():
if isinstance(module, BaseCudaQuantLinear):
if isinstance(module, CudaQuantLinear):
cuda_name_modules[name] = module.gptqmodel_cuda
module.gptqmodel_cuda = None
model = copy.deepcopy(self.model)

for name, modules in model.named_modules():
if isinstance(module, BaseCudaQuantLinear) and name in cuda_name_modules:
if isinstance(module, CudaQuantLinear) and name in cuda_name_modules:
module.gptqmodel_cuda = cuda_name_modules[name]

del cuda_name_modules
@@ -1109,9 +1109,9 @@ def skip(*args, **kwargs):

# == step6: (optional) warmup triton == #
if backend == Backend.TRITON and warmup_triton:
from ..nn_modules.qlinear.qlinear_tritonv2 import QuantLinear
from ..nn_modules.qlinear.qlinear_tritonv2 import TritonV2QuantLinear

QuantLinear.warmup(model, seqlen=model.seqlen)
TritonV2QuantLinear.warmup(model, seqlen=model.seqlen)

return cls(
model,
@@ -1124,9 +1124,9 @@ def warmup_triton(self, enabled: bool = True):
if not enabled:
return

from ..nn_modules.qlinear.qlinear_tritonv2 import QuantLinear
from ..nn_modules.qlinear.qlinear_tritonv2 import TritonV2QuantLinear

QuantLinear.warmup(self.model, seqlen=self.model.seqlen)
TritonV2QuantLinear.warmup(self.model, seqlen=self.model.seqlen)

def __getattr__(self, item):
try:
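Reviewer note on the save_quantized hunk above: for the non-4-bit CUDA path the model is deep-copied before saving, so each CudaQuantLinear's native `gptqmodel_cuda` handle is stashed and nulled first (deepcopy cannot serialize the compiled extension object) and then restored on the copy; per the in-line comment, the original modules are not restored because nothing uses them afterwards. A minimal sketch of that pattern, where the function wrapper and variable names are illustrative rather than taken from this PR:

```python
import copy

from gptqmodel.nn_modules.qlinear.qlinear_cuda import CudaQuantLinear


def deepcopy_without_cuda_handles(source_model):
    """Deep-copy a model whose CudaQuantLinear layers hold a non-copyable
    native handle (`gptqmodel_cuda`), then restore the handle on the copy."""
    stashed = {}
    for name, module in source_model.named_modules():
        if isinstance(module, CudaQuantLinear):
            stashed[name] = module.gptqmodel_cuda
            module.gptqmodel_cuda = None  # deepcopy cannot serialize the extension object

    copied = copy.deepcopy(source_model)

    # Re-attach the stashed handles on the copy so it stays usable for saving.
    for name, module in copied.named_modules():
        if isinstance(module, CudaQuantLinear) and name in stashed:
            module.gptqmodel_cuda = stashed[name]
    return copied
```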
15 changes: 4 additions & 11 deletions gptqmodel/nn_modules/qlinear/__init__.py
@@ -2,8 +2,6 @@


class BaseQuantLinear(nn.Module):
# override me
QUANT_TYPE = "base"

SUPPORTED_BITS = []
SUPPORTED_GROUP_SIZE = []
@@ -17,16 +15,16 @@ def validate(cls, bits: int, group_size: int, desc_act: bool, sym: bool, raise_e
err = ""
if cls.SUPPORTED_BITS and bits not in cls.SUPPORTED_BITS:
validate = False
err = f"{cls.QUANT_TYPE} only supports `{cls.SUPPORTED_BITS}` bits: actual bits = `{bits}`"
err = f"{cls} only supports `{cls.SUPPORTED_BITS}` bits: actual bits = `{bits}`"
elif cls.SUPPORTED_GROUP_SIZE and group_size not in cls.SUPPORTED_GROUP_SIZE:
validate = False
err = f"{cls.QUANT_TYPE} only supports `{cls.SUPPORTED_GROUP_SIZE}` group_size: actual group_size = `{group_size}`"
err = f"{cls} only supports `{cls.SUPPORTED_GROUP_SIZE}` group_size: actual group_size = `{group_size}`"
elif cls.SUPPORTED_SYM and sym not in cls.SUPPORTED_SYM:
validate = False
err = f"{cls.QUANT_TYPE} only supports `{cls.SUPPORTED_SYM}` bits: actual sym = `{sym}`"
err = f"{cls} only supports `{cls.SUPPORTED_SYM}` bits: actual sym = `{sym}`"
elif cls.SUPPORTED_DESC_ACT and desc_act not in cls.SUPPORTED_DESC_ACT:
validate = False
err = f"{cls.QUANT_TYPE} only supports `{cls.SUPPORTED_DESC_ACT}` bits: actual desc_act = `{desc_act}`"
err = f"{cls} only supports `{cls.SUPPORTED_DESC_ACT}` bits: actual desc_act = `{desc_act}`"

if not validate and raise_error:
raise NotImplementedError(err)
@@ -36,8 +34,3 @@ def validate(cls, bits: int, group_size: int, desc_act: bool, sym: bool, raise_e
# override me
def post_init(self):
pass


class BaseCudaQuantLinear(BaseQuantLinear):
# override me
QUANT_TYPE = "base-cuda"
7 changes: 3 additions & 4 deletions gptqmodel/nn_modules/qlinear/qlinear_bitblas.py
@@ -9,7 +9,7 @@
import torch.nn as nn
from gptqmodel.nn_modules.qlinear import BaseQuantLinear

from .qlinear_cuda_old import QuantLinear as QuantLinearOld
from .qlinear_cuda_old import CudaOldQuantLinear

logger = getLogger(__name__)

@@ -65,8 +65,7 @@ def unpack_qzeros(qzeros, bits):
return unpacked_zeros


class QuantLinear(BaseQuantLinear):
QUANT_TYPE = "bitblas"
class BitBLASQuantLinear(BaseQuantLinear):
SUPPORTED_BITS = [1, 2, 4]
SUPPORTED_DESC_ACT = [False]
SUPPORTED_SHARDS = False
@@ -244,7 +243,7 @@ def post_init(self):
param_list.append(self.bias)
self.q_params = [ctypes.c_void_p(arr.data_ptr()) for arr in param_list]

def repack_from_gptq(self, gptq_module: QuantLinearOld):
def repack_from_gptq(self, gptq_module: CudaOldQuantLinear):
from bitblas.quantization.utils import general_compress

# qweight in gptq old quant linear stored with (outfeatures, infeatures), should be transposed.
5 changes: 2 additions & 3 deletions gptqmodel/nn_modules/qlinear/qlinear_cuda.py
@@ -7,13 +7,12 @@
import torch
import torch.nn as nn
import transformers
from gptqmodel.nn_modules.qlinear import BaseCudaQuantLinear
from gptqmodel.nn_modules.qlinear import BaseQuantLinear

logger = getLogger(__name__)


class QuantLinear(BaseCudaQuantLinear):
QUANT_TYPE = "cuda"
class CudaQuantLinear(BaseQuantLinear):
SUPPORTED_BITS = [2, 3, 4, 8]

def __init__(
5 changes: 2 additions & 3 deletions gptqmodel/nn_modules/qlinear/qlinear_cuda_old.py
@@ -7,13 +7,12 @@
import torch
import torch.nn as nn
import transformers
from gptqmodel.nn_modules.qlinear import BaseCudaQuantLinear
from gptqmodel.nn_modules.qlinear import BaseQuantLinear

logger = getLogger(__name__)


class QuantLinear(BaseCudaQuantLinear):
QUANT_TYPE = "cuda-old"
class CudaOldQuantLinear(BaseQuantLinear):
SUPPORTED_BITS = [2, 3, 4, 8]

def __init__(
3 changes: 1 addition & 2 deletions gptqmodel/nn_modules/qlinear/qlinear_exllama.py
@@ -33,8 +33,7 @@ def ext_q4_matmul(x, q4, q4_width):
return output.view(outshape)


class QuantLinear(BaseQuantLinear):
QUANT_TYPE = "exllama"
class ExllamaQuantLinear(BaseQuantLinear):
SUPPORTED_BITS = [4]


3 changes: 1 addition & 2 deletions gptqmodel/nn_modules/qlinear/qlinear_exllamav2.py
@@ -94,8 +94,7 @@ def ext_make_q_matrix(w: dict, temp_dq, key: str = None):
)


class QuantLinear(BaseQuantLinear):
QUANT_TYPE = "exllamav2"
class ExllamaV2QuantLinear(BaseQuantLinear):
SUPPORTED_BITS = [4]

"""Linear layer implementation with per-group 4-bit quantization of the weights"""
3 changes: 1 addition & 2 deletions gptqmodel/nn_modules/qlinear/qlinear_marlin.py
@@ -61,8 +61,7 @@ def _get_perms():
_perm, _scale_perm, _scale_perm_single = _get_perms()


class QuantLinear(BaseQuantLinear):
QUANT_TYPE = "marlin"
class MarlinQuantLinear(BaseQuantLinear):
SUPPORTED_BITS = [4]
SUPPORTED_GROUP_SIZE = [128, -1]
SUPPORTED_DESC_ACT = [False]
4 changes: 1 addition & 3 deletions gptqmodel/nn_modules/qlinear/qlinear_tritonv2.py
@@ -13,7 +13,7 @@
logger = getLogger(__name__)


class QuantLinear(BaseQuantLinear, TritonModuleMixin):
class TritonV2QuantLinear(BaseQuantLinear, TritonModuleMixin):
"""
Triton v2 quantized linear layer.

@@ -22,8 +22,6 @@ class QuantLinear(BaseQuantLinear, TritonModuleMixin):
dequant and matmul into single kernel.add()
"""

QUANT_TYPE = "tritonv2"

def __init__(self, bits, group_size, infeatures, outfeatures, bias, **kwargs,):
super().__init__()
if bits not in [2, 4, 8]:
2 changes: 1 addition & 1 deletion gptqmodel/utils/bitblas.py
@@ -6,7 +6,7 @@
from accelerate.utils import find_tied_parameters
from tqdm import tqdm

from ..nn_modules.qlinear.qlinear_bitblas import QuantLinear as BitBLASQuantLinear
from ..nn_modules.qlinear.qlinear_bitblas import BitBLASQuantLinear
from ..quantization import FORMAT, QuantizeConfig
from .model import recurse_getattr, recurse_setattr

2 changes: 1 addition & 1 deletion gptqmodel/utils/exllama.py
@@ -2,7 +2,7 @@

import torch

from ..nn_modules.qlinear.qlinear_exllama import QuantLinear as ExllamaQuantLinear
from ..nn_modules.qlinear.qlinear_exllama import ExllamaQuantLinear


def exllama_set_max_input_length(model, max_input_length: int):
37 changes: 21 additions & 16 deletions gptqmodel/utils/importer.py
@@ -1,13 +1,13 @@
from collections import OrderedDict
from logging import getLogger

from ..nn_modules.qlinear.qlinear_bitblas import QuantLinear as BitBLASQuantLinear
from ..nn_modules.qlinear.qlinear_cuda import QuantLinear as CudaQuantLinear
from ..nn_modules.qlinear.qlinear_cuda_old import QuantLinear as CudaOldQuantLinear
from ..nn_modules.qlinear.qlinear_exllama import QuantLinear as ExllamaQuantLinear
from ..nn_modules.qlinear.qlinear_exllamav2 import QuantLinear as ExllamaV2QuantLinear
from ..nn_modules.qlinear.qlinear_marlin import QuantLinear as MarlinQuantLinear
from ..nn_modules.qlinear.qlinear_tritonv2 import QuantLinear as TritonV2QuantLinear
from ..nn_modules.qlinear.qlinear_bitblas import BitBLASQuantLinear
from ..nn_modules.qlinear.qlinear_cuda import CudaQuantLinear
from ..nn_modules.qlinear.qlinear_cuda_old import CudaOldQuantLinear
from ..nn_modules.qlinear.qlinear_exllama import ExllamaQuantLinear
from ..nn_modules.qlinear.qlinear_exllamav2 import ExllamaV2QuantLinear
from ..nn_modules.qlinear.qlinear_marlin import MarlinQuantLinear
from ..nn_modules.qlinear.qlinear_tritonv2 import TritonV2QuantLinear
from ..quantization import FORMAT
from .backend import Backend

@@ -55,18 +55,23 @@ def select_quant_linear(
# Handle the case where backend is not AUTO.
if backend == Backend.TRITON:
logger.info("Using tritonv2 for GPTQ")
from ..nn_modules.qlinear.qlinear_tritonv2 import QuantLinear
from ..nn_modules.qlinear.qlinear_tritonv2 import TritonV2QuantLinear
return TritonV2QuantLinear
elif backend == Backend.BITBLAS:
from ..nn_modules.qlinear.qlinear_bitblas import QuantLinear
from ..nn_modules.qlinear.qlinear_bitblas import BitBLASQuantLinear
return BitBLASQuantLinear
elif bits == 4 and sym and not desc_act and backend == Backend.MARLIN:
from ..nn_modules.qlinear.qlinear_marlin import QuantLinear
from ..nn_modules.qlinear.qlinear_marlin import MarlinQuantLinear
return MarlinQuantLinear
elif bits == 4 and backend == Backend.EXLLAMA_V2:
from ..nn_modules.qlinear.qlinear_exllamav2 import QuantLinear
from ..nn_modules.qlinear.qlinear_exllamav2 import ExllamaV2QuantLinear
return ExllamaV2QuantLinear
elif bits == 4 and backend == Backend.EXLLAMA:
from ..nn_modules.qlinear.qlinear_exllama import QuantLinear
from ..nn_modules.qlinear.qlinear_exllama import ExllamaQuantLinear
return ExllamaQuantLinear
elif not desc_act or group_size == -1:
from ..nn_modules.qlinear.qlinear_cuda_old import QuantLinear
from ..nn_modules.qlinear.qlinear_cuda_old import CudaOldQuantLinear
return CudaOldQuantLinear
else:
from ..nn_modules.qlinear.qlinear_cuda import QuantLinear

return QuantLinear
from ..nn_modules.qlinear.qlinear_cuda import CudaQuantLinear
return CudaQuantLinear
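The net effect in importer.py: every branch now returns a concrete, distinctly named class rather than re-importing a generic `QuantLinear` under an alias. A hedged usage sketch follows; the keyword names mirror the variables referenced in the hunk, and any further parameters of `select_quant_linear` (its full signature is not visible in this diff) are omitted:

```python
from gptqmodel.nn_modules.qlinear.qlinear_tritonv2 import TritonV2QuantLinear
from gptqmodel.utils.backend import Backend
from gptqmodel.utils.importer import select_quant_linear

# Force the tritonv2 branch from the hunk above (Backend.AUTO would consult the registry instead).
qlinear_cls = select_quant_linear(
    bits=4,
    group_size=128,
    desc_act=False,
    sym=True,
    backend=Backend.TRITON,
)

# With distinct class names, callers can dispatch on the class object itself
# instead of comparing a QUANT_TYPE string.
assert qlinear_cls is TritonV2QuantLinear
```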
3 changes: 1 addition & 2 deletions gptqmodel/utils/marlin.py
@@ -6,8 +6,7 @@
from accelerate.utils import find_tied_parameters
from tqdm import tqdm

from ..nn_modules.qlinear.qlinear_marlin import QuantLinear as MarlinQuantLinear
from ..nn_modules.qlinear.qlinear_marlin import _get_perms, unpack_qzeros
from ..nn_modules.qlinear.qlinear_marlin import MarlinQuantLinear, _get_perms, unpack_qzeros
from ..quantization import FORMAT, QuantizeConfig
from .model import recurse_getattr, recurse_setattr

7 changes: 4 additions & 3 deletions gptqmodel/utils/model.py
@@ -17,8 +17,9 @@

from ..models._const import CPU, CUDA_0, EXLLAMA_DEFAULT_MAX_INPUT_LENGTH, EXPERT_INDEX_PLACEHOLDER, SUPPORTED_MODELS
from ..nn_modules.qlinear import BaseQuantLinear
from ..nn_modules.qlinear.qlinear_exllama import QuantLinear as ExllamaQuantLinear
from ..nn_modules.qlinear.qlinear_exllamav2 import QuantLinear as ExllamaV2QuantLinear
from ..nn_modules.qlinear.qlinear_exllama import ExllamaQuantLinear
from ..nn_modules.qlinear.qlinear_exllamav2 import ExllamaV2QuantLinear
from ..nn_modules.qlinear.qlinear_marlin import MarlinQuantLinear
from ..quantization import FORMAT, QuantizeConfig
from .backend import Backend
from .importer import select_quant_linear
@@ -303,7 +304,7 @@ def pack_model(
zero.to(CPU),
g_idx.to(CPU),
)
if QuantLinear.QUANT_TYPE == "marlin":
if QuantLinear is MarlinQuantLinear:
qlayers[name].pack(layers[name], scale)
else:
qlayers[name].pack(layers[name], scale, zero, g_idx)