
Commit 96a2573

remove Quant_Type and rename QuantLinear with type (#116)
1 parent b64b2a0 commit 96a2573
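
In short: each backend's kernel class was previously exported under the generic name QuantLinear and distinguished only by a QUANT_TYPE string attribute; this commit drops QUANT_TYPE and gives every backend a distinctly named class (BitBLASQuantLinear, CudaQuantLinear, CudaOldQuantLinear, ExllamaQuantLinear, ExllamaV2QuantLinear, MarlinQuantLinear, TritonV2QuantLinear). A before/after sketch of a typical import site, assuming the gptqmodel package is importable (module paths taken from the diffs below):

    # before: every backend module exported the same class name, so callers had to alias it
    from gptqmodel.nn_modules.qlinear.qlinear_bitblas import QuantLinear as BitBLASQuantLinear
    from gptqmodel.nn_modules.qlinear.qlinear_tritonv2 import QuantLinear

    # after: the class name itself identifies the backend
    from gptqmodel.nn_modules.qlinear.qlinear_bitblas import BitBLASQuantLinear
    from gptqmodel.nn_modules.qlinear.qlinear_tritonv2 import TritonV2QuantLinear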

14 files changed: +51 −61 lines changed

gptqmodel/models/base.py

Lines changed: 8 additions & 8 deletions
@@ -435,7 +435,7 @@ def tmp(_, inp, out):
         torch.cuda.empty_cache()

         if self.quantize_config.format == FORMAT.BITBLAS:
-            from ..nn_modules.qlinear.qlinear_bitblas import QuantLinear as BitBLASQuantLinear
+            from ..nn_modules.qlinear.qlinear_bitblas import BitBLASQuantLinear

             # BitBLASQuantLinear does not have a pack method and needs to be converted to BitBLAS format when saving.
             logger.info("Converting model to BitBlas Format...")
@@ -515,15 +515,15 @@ def save_quantized(
         # no need to set it back, no calculation below
         if quantize_config.bits != 4:
             cuda_name_modules = {}
-            from gptqmodel.nn_modules.qlinear.qlinear_cuda import BaseCudaQuantLinear
+            from gptqmodel.nn_modules.qlinear.qlinear_cuda import CudaQuantLinear
             for name, module in model.named_modules():
-                if isinstance(module, BaseCudaQuantLinear):
+                if isinstance(module, CudaQuantLinear):
                     cuda_name_modules[name] = module.gptqmodel_cuda
                     module.gptqmodel_cuda = None
             model = copy.deepcopy(self.model)

             for name, modules in model.named_modules():
-                if isinstance(module, BaseCudaQuantLinear) and name in cuda_name_modules:
+                if isinstance(module, CudaQuantLinear) and name in cuda_name_modules:
                     module.gptqmodel_cuda = cuda_name_modules[name]

             del cuda_name_modules
@@ -1109,9 +1109,9 @@ def skip(*args, **kwargs):

         # == step6: (optional) warmup triton == #
         if backend == Backend.TRITON and warmup_triton:
-            from ..nn_modules.qlinear.qlinear_tritonv2 import QuantLinear
+            from ..nn_modules.qlinear.qlinear_tritonv2 import TritonV2QuantLinear

-            QuantLinear.warmup(model, seqlen=model.seqlen)
+            TritonV2QuantLinear.warmup(model, seqlen=model.seqlen)

         return cls(
             model,
@@ -1124,9 +1124,9 @@ def warmup_triton(self, enabled: bool = True):
         if not enabled:
             return

-        from ..nn_modules.qlinear.qlinear_tritonv2 import QuantLinear
+        from ..nn_modules.qlinear.qlinear_tritonv2 import TritonV2QuantLinear

-        QuantLinear.warmup(self.model, seqlen=self.model.seqlen)
+        TritonV2QuantLinear.warmup(self.model, seqlen=self.model.seqlen)

     def __getattr__(self, item):
         try:

gptqmodel/nn_modules/qlinear/__init__.py

Lines changed: 4 additions & 11 deletions
@@ -2,8 +2,6 @@


 class BaseQuantLinear(nn.Module):
-    # override me
-    QUANT_TYPE = "base"

     SUPPORTED_BITS = []
     SUPPORTED_GROUP_SIZE = []
@@ -17,16 +15,16 @@ def validate(cls, bits: int, group_size: int, desc_act: bool, sym: bool, raise_e
         err = ""
         if cls.SUPPORTED_BITS and bits not in cls.SUPPORTED_BITS:
             validate = False
-            err = f"{cls.QUANT_TYPE} only supports `{cls.SUPPORTED_BITS}` bits: actual bits = `{bits}`"
+            err = f"{cls} only supports `{cls.SUPPORTED_BITS}` bits: actual bits = `{bits}`"
         elif cls.SUPPORTED_GROUP_SIZE and group_size not in cls.SUPPORTED_GROUP_SIZE:
             validate = False
-            err = f"{cls.QUANT_TYPE} only supports `{cls.SUPPORTED_GROUP_SIZE}` group_size: actual group_size = `{group_size}`"
+            err = f"{cls} only supports `{cls.SUPPORTED_GROUP_SIZE}` group_size: actual group_size = `{group_size}`"
         elif cls.SUPPORTED_SYM and sym not in cls.SUPPORTED_SYM:
             validate = False
-            err = f"{cls.QUANT_TYPE} only supports `{cls.SUPPORTED_SYM}` bits: actual sym = `{sym}`"
+            err = f"{cls} only supports `{cls.SUPPORTED_SYM}` bits: actual sym = `{sym}`"
         elif cls.SUPPORTED_DESC_ACT and desc_act not in cls.SUPPORTED_DESC_ACT:
             validate = False
-            err = f"{cls.QUANT_TYPE} only supports `{cls.SUPPORTED_DESC_ACT}` bits: actual desc_act = `{desc_act}`"
+            err = f"{cls} only supports `{cls.SUPPORTED_DESC_ACT}` bits: actual desc_act = `{desc_act}`"

         if not validate and raise_error:
             raise NotImplementedError(err)
@@ -36,8 +34,3 @@ def validate(cls, bits: int, group_size: int, desc_act: bool, sym: bool, raise_e
     # override me
     def post_init(self):
         pass
-
-
-class BaseCudaQuantLinear(BaseQuantLinear):
-    # override me
-    QUANT_TYPE = "base-cuda"
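
One practical effect of dropping QUANT_TYPE is visible in validate() above: the error strings now interpolate the class object itself ({cls}), so the concrete subclass is named in the message without a separate string attribute. A self-contained sketch of that pattern, simplified to the bits check only (the example class, its SUPPORTED_BITS value, and the trimmed signature are illustrative, not the library's exact API):

    class ExampleQuantLinear:
        # mirrors BaseQuantLinear's pattern after this commit: no QUANT_TYPE string
        SUPPORTED_BITS = [2, 4, 8]

        @classmethod
        def validate(cls, bits: int, raise_error: bool = True) -> bool:
            ok = not cls.SUPPORTED_BITS or bits in cls.SUPPORTED_BITS
            if not ok and raise_error:
                # f"{cls}" renders the subclass, e.g. "<class '__main__.ExampleQuantLinear'>"
                raise NotImplementedError(
                    f"{cls} only supports `{cls.SUPPORTED_BITS}` bits: actual bits = `{bits}`"
                )
            return ok

    ExampleQuantLinear.validate(bits=3)  # raises NotImplementedError naming the class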

gptqmodel/nn_modules/qlinear/qlinear_bitblas.py

Lines changed: 3 additions & 4 deletions
@@ -9,7 +9,7 @@
 import torch.nn as nn
 from gptqmodel.nn_modules.qlinear import BaseQuantLinear

-from .qlinear_cuda_old import QuantLinear as QuantLinearOld
+from .qlinear_cuda_old import CudaOldQuantLinear

 logger = getLogger(__name__)

@@ -66,8 +66,7 @@ def unpack_qzeros(qzeros, bits):
     return unpacked_zeros


-class QuantLinear(BaseQuantLinear):
-    QUANT_TYPE = "bitblas"
+class BitBLASQuantLinear(BaseQuantLinear):
     SUPPORTED_BITS = [1, 2, 4]
     SUPPORTED_DESC_ACT = [False]
     SUPPORTED_SHARDS = False
@@ -245,7 +244,7 @@ def post_init(self):
             param_list.append(self.bias)
         self.q_params = [ctypes.c_void_p(arr.data_ptr()) for arr in param_list]

-    def repack_from_gptq(self, gptq_module: QuantLinearOld):
+    def repack_from_gptq(self, gptq_module: CudaOldQuantLinear):
         from bitblas.quantization.utils import general_compress

         # qweight in gptq old quant linear stored with (outfeatures, infeatures), should be transposed.

gptqmodel/nn_modules/qlinear/qlinear_cuda.py

Lines changed: 2 additions & 3 deletions
@@ -7,13 +7,12 @@
 import torch
 import torch.nn as nn
 import transformers
-from gptqmodel.nn_modules.qlinear import BaseCudaQuantLinear
+from gptqmodel.nn_modules.qlinear import BaseQuantLinear

 logger = getLogger(__name__)


-class QuantLinear(BaseCudaQuantLinear):
-    QUANT_TYPE = "cuda"
+class CudaQuantLinear(BaseQuantLinear):
     SUPPORTED_BITS = [2, 3, 4, 8]

     def __init__(

gptqmodel/nn_modules/qlinear/qlinear_cuda_old.py

Lines changed: 2 additions & 3 deletions
@@ -7,13 +7,12 @@
 import torch
 import torch.nn as nn
 import transformers
-from gptqmodel.nn_modules.qlinear import BaseCudaQuantLinear
+from gptqmodel.nn_modules.qlinear import BaseQuantLinear

 logger = getLogger(__name__)


-class QuantLinear(BaseCudaQuantLinear):
-    QUANT_TYPE = "cuda-old"
+class CudaOldQuantLinear(BaseQuantLinear):
     SUPPORTED_BITS = [2, 3, 4, 8]

     def __init__(

gptqmodel/nn_modules/qlinear/qlinear_exllama.py

Lines changed: 1 addition & 2 deletions
@@ -34,8 +34,7 @@ def ext_q4_matmul(x, q4, q4_width):
     return output.view(outshape)


-class QuantLinear(BaseQuantLinear):
-    QUANT_TYPE = "exllama"
+class ExllamaQuantLinear(BaseQuantLinear):
     SUPPORTED_BITS = [4]


gptqmodel/nn_modules/qlinear/qlinear_exllamav2.py

Lines changed: 1 addition & 2 deletions
@@ -95,8 +95,7 @@ def ext_make_q_matrix(w: dict, temp_dq, key: str = None):
     )


-class QuantLinear(BaseQuantLinear):
-    QUANT_TYPE = "exllamav2"
+class ExllamaV2QuantLinear(BaseQuantLinear):
     SUPPORTED_BITS = [4]

     """Linear layer implementation with per-group 4-bit quantization of the weights"""

gptqmodel/nn_modules/qlinear/qlinear_marlin.py

Lines changed: 1 addition & 2 deletions
@@ -61,8 +61,7 @@ def _get_perms():
 _perm, _scale_perm, _scale_perm_single = _get_perms()


-class QuantLinear(BaseQuantLinear):
-    QUANT_TYPE = "marlin"
+class MarlinQuantLinear(BaseQuantLinear):
     SUPPORTED_BITS = [4]
     SUPPORTED_GROUP_SIZE = [128, -1]
     SUPPORTED_DESC_ACT = [False]

gptqmodel/nn_modules/qlinear/qlinear_tritonv2.py

Lines changed: 1 addition & 3 deletions
@@ -13,7 +13,7 @@
 logger = getLogger(__name__)


-class QuantLinear(BaseQuantLinear, TritonModuleMixin):
+class TritonV2QuantLinear(BaseQuantLinear, TritonModuleMixin):
     """
     Triton v2 quantized linear layer.

@@ -22,8 +22,6 @@ class QuantLinear(BaseQuantLinear, TritonModuleMixin):
     dequant and matmul into single kernel.add()
     """

-    QUANT_TYPE = "tritonv2"
-
     def __init__(self, bits, group_size, infeatures, outfeatures, bias, **kwargs,):
         super().__init__()
         if bits not in [2, 4, 8]:

gptqmodel/utils/bitblas.py

Lines changed: 1 addition & 1 deletion
@@ -8,7 +8,7 @@
 from accelerate.utils import find_tied_parameters
 from tqdm import tqdm

-from ..nn_modules.qlinear.qlinear_bitblas import QuantLinear as BitBLASQuantLinear
+from ..nn_modules.qlinear.qlinear_bitblas import BitBLASQuantLinear
 from ..quantization import FORMAT, QuantizeConfig
 from .model import recurse_getattr, recurse_setattr

