Commit 4b8c2f1

fix no backend when creating a quant linear (#1329)
Parent commit: 21b1759

3 files changed: +12 -7 lines changed

tests/test_packing_speed.py

Lines changed: 10 additions & 7 deletions
@@ -17,6 +17,8 @@
 # -- do not touch
 import os
 
+from gptqmodel import BACKEND
+
 os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
 # -- end do not touch
 
@@ -86,7 +88,7 @@ class TestRepacking(unittest.TestCase):
     _, linear, s = gen_quant4(k, n, group_size)
     print("gen_quant: start...end")
 
-    def pack(self, qlinearCls):
+    def pack(self, qlinearCls, backend):
         qlinear = qlinearCls(
             bits=4,
             group_size=self.group_size,
@@ -95,6 +97,7 @@ def pack(self, qlinearCls):
             in_features=self.k,
             out_features=self.n,
             pack_dtype=torch.int32,
+            backend=backend,
             bias=False,
         )
 
@@ -106,14 +109,14 @@ def pack(self, qlinearCls):
         [
             # [ExllamaQuantLinear, 9.63], # A100 Z3: 36.89 # 4090? 26.5349
             # [TritonV2QuantLinear, 9.67], # A100 Z3: 35.04 # 4090? 26.5268
-            [TorchQuantLinear, 16.63], # A100 Z3 33.56 # 4090? 27.0297
+            [TorchQuantLinear, BACKEND.TORCH, 16.63], # A100 Z3 33.56 # 4090? 27.0297
         ]
     )
-    def test_pack_speed(self, qlinearCls, expect_time):
+    def test_pack_speed(self, qlinearCls, backend, expect_time):
         start = time.time()
         with threadpoolctl.threadpool_limits(limits=1):
             for i in range(30):
-                self.pack(qlinearCls)
+                self.pack(qlinearCls, backend)
         time_usage = time.time() - start
         speed = self.k * self.k / time_usage
         print(f"{qlinearCls.__name__}, time={time_usage}, speed={speed:.4f}")
@@ -124,14 +127,14 @@ def test_pack_speed(self, qlinearCls, expect_time):
         [
             # [ExllamaQuantLinear, 9.63], # A100 Z3: 36.89 # 4090? 26.5349
             # [TritonV2QuantLinear, 9.67], # A100 Z3: 35.04 # 4090? 26.5268
-            [TorchQuantLinear, 12.51], # A100 Z3 33.56 # 4090? 27.0297
+            [TorchQuantLinear, BACKEND.TORCH, 12.51], # A100 Z3 33.56 # 4090? 27.0297
         ]
     )
-    def test_pack_speed_2_threads(self, qlinearCls, expect_time):
+    def test_pack_speed_2_threads(self, qlinearCls, backend, expect_time):
         start = time.time()
         with threadpoolctl.threadpool_limits(limits=2):
             for i in range(30):
-                self.pack(qlinearCls)
+                self.pack(qlinearCls, backend)
         time_usage = time.time() - start
         speed = self.k * self.k / time_usage
         print(f"{qlinearCls.__name__}, time={time_usage}, speed={speed:.4f}")

tests/test_q4_exllama_v1.py

Lines changed: 1 addition & 0 deletions
@@ -1099,6 +1099,7 @@ def test_exllama(self):
             out_features=n,
             bias=False,
             pack_dtype=pack_dtype,
+            backend=BACKEND.EXLLAMA_V1,
         )
         self.assertTrue(isinstance(linear, ExllamaQuantLinear))
 

tests/test_q4_exllama_v2.py

Lines changed: 1 addition & 0 deletions
@@ -62,6 +62,7 @@ def test_exllamav2(self):
             out_features=n,
             bias=False,
             pack_dtype=pack_dtype,
+            backend=BACKEND.EXLLAMA_V2,
         )
 
         self.assertTrue(isinstance(linear, ExllamaV2QuantLinear))
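
The two Exllama tests get the same one-line treatment: the layer is now requested with an explicit backend. Below is a hedged sketch of the changed call site; make_quant_linear is a hypothetical stand-in for whatever constructor or factory the tests actually call, since the hunks only show keyword arguments, and n and pack_dtype come from the surrounding test code.

# make_quant_linear is a hypothetical placeholder; only the keyword
# arguments and BACKEND values below come from the diffs.
linear = make_quant_linear(
    out_features=n,
    bias=False,
    pack_dtype=pack_dtype,
    backend=BACKEND.EXLLAMA_V1,  # BACKEND.EXLLAMA_V2 in test_q4_exllama_v2.py
)
assert isinstance(linear, ExllamaQuantLinear)  # ExllamaV2QuantLinear for the v2 test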
