
Commit 9a56c81

fix unit test quantlinear name (ModelCloud#138)
1 parent 192e660 commit 9a56c81

7 files changed: 11 additions & 11 deletions
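
This commit renames each backend's quantized-linear class from the shared name QuantLinear to a backend-specific one, so the tests no longer need import aliases to tell kernels apart. A minimal before/after sketch of the import pattern, assembled from the hunks below:

# Before: every backend module exported a class named "QuantLinear",
# so callers had to alias the imports to distinguish backends:
#   from gptqmodel.nn_modules.qlinear.qlinear_marlin import QuantLinear as MarlinQuantLinear
# After: each backend module exports a uniquely named class directly:
from gptqmodel.nn_modules.qlinear.qlinear_cuda_old import CudaOldQuantLinear
from gptqmodel.nn_modules.qlinear.qlinear_exllama import ExllamaQuantLinear
from gptqmodel.nn_modules.qlinear.qlinear_exllamav2 import ExllamaV2QuantLinear
from gptqmodel.nn_modules.qlinear.qlinear_marlin import MarlinQuantLinear
from gptqmodel.nn_modules.qlinear.qlinear_tritonv2 import TritonV2QuantLinear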

tests/test_lm_head.py

Lines changed: 1 addition & 1 deletion
@@ -35,7 +35,7 @@ def test_load(self):
         model = GPTQModel.from_quantized(self.MODEL_ID, use_safetensors=True, device=self.DEVICE)
 
         # validate lm_head is loaded as quantized layer
-        assert model.model.lm_head.__class__.__name__ == "QuantLinear"
+        assert model.model.lm_head.__class__.__name__ == "ExllamaV2QuantLinear"
 
         res = model.model.generate(
             **inputs, num_beams=1, min_new_tokens=1, max_new_tokens=128, repetition_penalty=1.25
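
For reference, a minimal sketch of the updated lm_head check, assuming a checkpoint quantized with lm_head enabled and the ExLlama v2 kernel selected at load time; the model ID and device below are placeholders, not values from this test:

from gptqmodel import GPTQModel

# Placeholder model ID and device (the test's MODEL_ID/DEVICE constants are not shown in this diff).
model = GPTQModel.from_quantized("your-org/your-gptq-model", use_safetensors=True, device="cuda:0")
# The quantized lm_head now reports its backend-specific class name rather than
# the old generic "QuantLinear".
assert model.model.lm_head.__class__.__name__ == "ExllamaV2QuantLinear"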

tests/test_q4_cuda.py

Lines changed: 2 additions & 2 deletions
@@ -9,7 +9,7 @@
 import unittest  # noqa: E402
 
 import torch  # noqa: E402
-from gptqmodel.nn_modules.qlinear.qlinear_cuda_old import QuantLinear as QuantLinearCudaOld  # noqa: E402
+from gptqmodel.nn_modules.qlinear.qlinear_cuda_old import CudaOldQuantLinear  # noqa: E402
 from parameterized import parameterized  # noqa: E402
 
 try:
@@ -561,7 +561,7 @@ def test_cuda_old(self, use_half2: bool):
         device = "cuda"
 
         weight_dtype = torch.float16 if use_half2 else torch.float32
-        linear = QuantLinearCudaOld(
+        linear = CudaOldQuantLinear(
             bits=4,
             group_size=group_size,
             desc_act=False,

tests/test_q4_exallama.py

Lines changed: 2 additions & 2 deletions
@@ -11,7 +11,7 @@
 import torch  # noqa: E402
 from gptqmodel import GPTQModel, exllama_set_max_input_length  # noqa: E402
 from gptqmodel.models._const import EXLLAMA_DEFAULT_MAX_INPUT_LENGTH  # noqa: E402
-from gptqmodel.nn_modules.qlinear.qlinear_exllama import QuantLinear  # noqa: E402
+from gptqmodel.nn_modules.qlinear.qlinear_exllama import ExllamaQuantLinear  # noqa: E402
 from gptqmodel.quantization import FORMAT
 from gptqmodel.utils.importer import select_quant_linear  # noqa: E402
 from gptqmodel.utils.model import gptqmodel_post_init  # noqa: E402
@@ -1078,7 +1078,7 @@ def test_exllama(self):
             outfeatures=n,
             bias=False,
         )
-        self.assertTrue(isinstance(linear, QuantLinear))
+        self.assertTrue(isinstance(linear, ExllamaQuantLinear))
 
         torch.manual_seed(42)
tests/test_q4_exallama_v2.py

Lines changed: 2 additions & 2 deletions
@@ -8,7 +8,7 @@
 
 import torch  # noqa: E402
 from gptqmodel import Backend, GPTQModel  # noqa: E402
-from gptqmodel.nn_modules.qlinear.qlinear_exllamav2 import QuantLinear  # noqa: E402
+from gptqmodel.nn_modules.qlinear.qlinear_exllamav2 import ExllamaV2QuantLinear  # noqa: E402
 from gptqmodel.quantization import FORMAT
 from gptqmodel.utils.importer import select_quant_linear  # noqa: E402
 from gptqmodel.utils.model import gptqmodel_post_init  # noqa: E402
@@ -46,7 +46,7 @@ def test_exllamav2(self):
             bias=False,
         )
 
-        self.assertTrue(isinstance(linear, QuantLinear))
+        self.assertTrue(isinstance(linear, ExllamaV2QuantLinear))
 
         torch.manual_seed(42)
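A minimal sketch of what the renamed assertion pins down; building `linear` through select_quant_linear is elided here because its full argument list is not visible in this diff:

from gptqmodel.nn_modules.qlinear.qlinear_exllamav2 import ExllamaV2QuantLinear

# The backend-specific class replaces the old shared "QuantLinear" name, so the
# test's isinstance() check now ties it to the ExLlama v2 implementation:
#   self.assertTrue(isinstance(linear, ExllamaV2QuantLinear))
print(ExllamaV2QuantLinear.__name__)  # -> "ExllamaV2QuantLinear"
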
tests/test_q4_marlin.py

Lines changed: 1 addition & 1 deletion
@@ -8,7 +8,7 @@
 
 import torch  # noqa: E402
 from gptqmodel import Backend, GPTQModel  # noqa: E402
-from gptqmodel.nn_modules.qlinear.qlinear_marlin import QuantLinear as MarlinQuantLinear  # noqa: E402
+from gptqmodel.nn_modules.qlinear.qlinear_marlin import MarlinQuantLinear  # noqa: E402
 from transformers import AutoTokenizer  # noqa: E402
 
tests/test_q4_triton.py

Lines changed: 1 addition & 1 deletion
@@ -8,7 +8,7 @@
 
 import torch  # noqa: E402
 from gptqmodel import Backend, GPTQModel  # noqa: E402
-from gptqmodel.nn_modules.qlinear.qlinear_tritonv2 import QuantLinear as TritonV2QuantLinear  # noqa: E402
+from gptqmodel.nn_modules.qlinear.qlinear_tritonv2 import TritonV2QuantLinear  # noqa: E402
 from transformers import AutoTokenizer  # noqa: E402
 
 GENERATE_EVAL_SIZE = 100

tests/test_repacking.py

Lines changed: 2 additions & 2 deletions
@@ -12,8 +12,8 @@
 import torch.nn as nn  # noqa: E402
 import gptqmodel_marlin_cuda  # noqa: E402
 # isort: on
-from gptqmodel.nn_modules.qlinear.qlinear_cuda_old import QuantLinear as CudaOldQuantLinear  # noqa: E402
-from gptqmodel.nn_modules.qlinear.qlinear_marlin import QuantLinear as MarlinQuantLinear  # noqa: E402
+from gptqmodel.nn_modules.qlinear.qlinear_cuda_old import CudaOldQuantLinear  # noqa: E402
+from gptqmodel.nn_modules.qlinear.qlinear_marlin import MarlinQuantLinear  # noqa: E402
 from gptqmodel.nn_modules.qlinear.qlinear_marlin import _get_perms, dequantize_weight  # noqa: E402