1717# -- do not touch
1818import os
1919
20+ from gptqmodel import BACKEND
21+
2022os .environ ["CUDA_DEVICE_ORDER" ] = "PCI_BUS_ID"
2123# -- end do not touch
2224
@@ -86,7 +88,7 @@ class TestRepacking(unittest.TestCase):
8688 _ , linear , s = gen_quant4 (k , n , group_size )
8789 print ("gen_quant: start...end" )
8890
89- def pack (self , qlinearCls ):
91+ def pack (self , qlinearCls , backend ):
9092 qlinear = qlinearCls (
9193 bits = 4 ,
9294 group_size = self .group_size ,
@@ -95,6 +97,7 @@ def pack(self, qlinearCls):
9597 in_features = self .k ,
9698 out_features = self .n ,
9799 pack_dtype = torch .int32 ,
100+ backend = backend ,
98101 bias = False ,
99102 )
100103
@@ -106,14 +109,14 @@ def pack(self, qlinearCls):
106109 [
107110 # [ExllamaQuantLinear, 9.63], # A100 Z3: 36.89 # 4090? 26.5349
108111 # [TritonV2QuantLinear, 9.67], # A100 Z3: 35.04 # 4090? 26.5268
109- [TorchQuantLinear , 16.63 ], # A100 Z3 33.56 # 4090? 27.0297
112+ [TorchQuantLinear , BACKEND . TORCH , 16.63 ], # A100 Z3 33.56 # 4090? 27.0297
110113 ]
111114 )
112- def test_pack_speed (self , qlinearCls , expect_time ):
115+ def test_pack_speed (self , qlinearCls , backend , expect_time ):
113116 start = time .time ()
114117 with threadpoolctl .threadpool_limits (limits = 1 ):
115118 for i in range (30 ):
116- self .pack (qlinearCls )
119+ self .pack (qlinearCls , backend )
117120 time_usage = time .time () - start
118121 speed = self .k * self .k / time_usage
119122 print (f"{ qlinearCls .__name__ } , time={ time_usage } , speed={ speed :.4f} " )
@@ -124,14 +127,14 @@ def test_pack_speed(self, qlinearCls, expect_time):
124127 [
125128 # [ExllamaQuantLinear, 9.63], # A100 Z3: 36.89 # 4090? 26.5349
126129 # [TritonV2QuantLinear, 9.67], # A100 Z3: 35.04 # 4090? 26.5268
127- [TorchQuantLinear , 12.51 ], # A100 Z3 33.56 # 4090? 27.0297
130+ [TorchQuantLinear , BACKEND . TORCH , 12.51 ], # A100 Z3 33.56 # 4090? 27.0297
128131 ]
129132 )
130- def test_pack_speed_2_threads (self , qlinearCls , expect_time ):
133+ def test_pack_speed_2_threads (self , qlinearCls , backend , expect_time ):
131134 start = time .time ()
132135 with threadpoolctl .threadpool_limits (limits = 2 ):
133136 for i in range (30 ):
134- self .pack (qlinearCls )
137+ self .pack (qlinearCls , backend )
135138 time_usage = time .time () - start
136139 speed = self .k * self .k / time_usage
137140 print (f"{ qlinearCls .__name__ } , time={ time_usage } , speed={ speed :.4f} " )
0 commit comments