@@ -95,6 +95,7 @@ def test_quant_llm_linear_correctness(self, ebits, mbits, BS, OC, IC, splitK):
 TEST_CONFIGS_DEQUANT = list(itertools.product(SHAPES, INNERKTILES, QGROUP_SIZES))
 
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(TORCH_VERSION_AFTER_2_5, reason="weight packing is updated in 2.5+")
 @pytest.mark.parametrize("shape, inner_k_tiles", TEST_CONFIGS_UNPACK, ids=str)
 def test_unpack_tensor_core_tiled_layout_correctness(shape, inner_k_tiles):
     N, K = shape
@@ -107,14 +108,15 @@ def test_unpack_tensor_core_tiled_layout_correctness(shape, inner_k_tiles):
 
 # TODO: Fix "test_aot_dispatch_dynamic" test failure
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(TORCH_VERSION_AFTER_2_5, reason="weight packing is updated in 2.5+")
 @pytest.mark.parametrize("shape, inner_k_tiles", TEST_CONFIGS_UNPACK, ids=str)
 def test_unpack_tensor_core_tiled_layout_op(shape, inner_k_tiles):
     test_utils = [
         "test_schema",
         "test_autograd_registration",
         "test_faketensor",
     ]
-
+
     # TODO: Figure out why test fails unless torch >= 2.5
     if TORCH_VERSION_AFTER_2_5:
         test_utils.append("test_aot_dispatch_dynamic")
@@ -137,10 +139,10 @@ def dequant_ref(q, scales, zeros, group_size, nbits=4, dtype=torch.bfloat16):
     assert scales.shape == zeros.shape
 
     midpoint = 2 ** (nbits - 1)
-
+
     # Convert from u4 -> s4 and upcast to bfloat16
     q = q.sub(midpoint).to(dtype)
-
+
     # Dequantize
     q = q.reshape(-1, group_size)
     dq = q * scales.reshape(-1, 1) + zeros.reshape(-1, 1)
@@ -149,21 +151,22 @@ def dequant_ref(q, scales, zeros, group_size, nbits=4, dtype=torch.bfloat16):
 
 
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(TORCH_VERSION_AFTER_2_5, reason="weight packing is updated in 2.5+")
 @pytest.mark.parametrize("shape, inner_k_tiles, group_size", TEST_CONFIGS_DEQUANT, ids=str)
 def test_dequantize_tensor_core_tiled_layout_correctness_quant_dequant(shape, inner_k_tiles, group_size):
     n, k = shape
-    dtype = torch.bfloat16
+    dtype = torch.bfloat16
 
     device = "cuda"
 
     t = torch.randn(n, k, dtype=dtype, device=device)
     scales, zeros = get_groupwise_affine_qparams(t, n_bit=4, groupsize=group_size, dtype=dtype)
-
+
     # Quantize
     q = groupwise_affine_quantize_tensor_from_qparams(
         t, scales, zeros, n_bit=4, groupsize=group_size
     )
-
+
     # Pack to tensor core layout
     packed = torch.ops.aten._convert_weight_to_int4pack(q, inner_k_tiles)
     scales_and_zeros = pack_tinygemm_scales_and_zeros(scales, zeros)
@@ -174,7 +177,7 @@ def test_dequantize_tensor_core_tiled_layout_correctness_quant_dequant(shape, in
     dq_ao = groupwise_affine_dequantize_tensor_from_qparams(
         q, scales, zeros, n_bit=4, groupsize=group_size
     )
-
+
     # Dequantize by passing in an identity matrix as the activation
     a_eye = torch.eye(k, device=device, dtype=dtype)
     dq_id = torch.ops.aten._weight_int4pack_mm(
@@ -183,34 +186,35 @@ def test_dequantize_tensor_core_tiled_layout_correctness_quant_dequant(shape, in
         group_size,
         scales_and_zeros,
     ).t()
-
+
     # Actual operation to test
     dq_op = torchao.ops.dequantize_tensor_core_tiled_layout(packed, scales_and_zeros, group_size, inner_k_tiles)
-
+
     # Compare results
     diff_ao_id = (dq_id - dq_ao).abs().max()
     diff_op_id = (dq_op - dq_id).abs().max()
     diff_op_ao = (dq_op - dq_ao).abs().max()
-
+
     # There are slight numerical differences when dequantizing with an identity matrix when compared to `groupwise_affine_dequantize`
     # Since the `dequantize_tensor_core_layout` kernel relies on the same underlying bit twiddling tricks for fast
     # conversion from u4 -> s4 -> bf16, the identity matrix dequant hack and `dequantize_tensor_core_layout` are
     # expected to give same results, while both will have similar numerical differences to `groupwise_affine_dequantize`.
-
-    # Test that the `dequant` kernel gives same results as identity matrix-based dequant
+
+    # Test that the `dequant` kernel gives same results as identity matrix-based dequant
     assert diff_op_id == 0
-
+
     # Test that the `dequant` kernel gives same numerical diffs as the `groupwise_affine_dequantize` when compared against the identity matrix
     assert diff_op_ao == diff_ao_id
 
     assert diff_op_ao < 1e-1
 
 # This test differs from one above in that it uses `unpack_tensor_core_tiled_layout` to unpack then dequantize
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(TORCH_VERSION_AFTER_2_5, reason="weight packing is updated in 2.5+")
 @pytest.mark.parametrize("shape, inner_k_tiles, group_size", TEST_CONFIGS_DEQUANT, ids=str)
 def test_dequantize_tensor_core_tiled_layout_correctness_unpack_and_dequant(shape, inner_k_tiles, group_size):
     n, k = shape
-    dtype = torch.bfloat16
+    dtype = torch.bfloat16
 
     device = "cuda"
 
     # Quantize and pack
@@ -222,13 +226,13 @@ def test_dequantize_tensor_core_tiled_layout_correctness_unpack_and_dequant(shap
 
     packed = torch.ops.aten._convert_weight_to_int4pack(q, inner_k_tiles)
     scales_and_zeros = pack_tinygemm_scales_and_zeros(scales, zeros)
-
+
     # Unpack and dequantize
     unpacked = torchao.ops.unpack_tensor_core_tiled_layout(packed, inner_k_tiles)
     dq_ao = groupwise_affine_dequantize_tensor_from_qparams(
         unpacked, scales, zeros, n_bit=4, groupsize=group_size
     )
-
+
     # Dequantize by passing in an identity matrix as the activation
     a_eye = torch.eye(k, device=device, dtype=dtype)
     dq_id = torch.ops.aten._weight_int4pack_mm(
@@ -237,29 +241,30 @@ def test_dequantize_tensor_core_tiled_layout_correctness_unpack_and_dequant(shap
         group_size,
         scales_and_zeros,
     ).t()
-
+
     # Actual operation to test
     dq_op = torchao.ops.dequantize_tensor_core_tiled_layout(packed, scales_and_zeros, group_size, inner_k_tiles)
-
+
     # Compare results
     diff_ao_id = (dq_id - dq_ao).abs().max()
     diff_op_id = (dq_op - dq_id).abs().max()
     diff_op_ao = (dq_op - dq_ao).abs().max()
-
+
     # There are slight numerical differences when dequantizing with an identity matrix when compared to `groupwise_affine_dequantize`
     # Since the `dequantize_tensor_core_layout` kernel relies on the same underlying bit twiddling tricks for fast
     # conversion from u4 -> s4 -> bf16, the identity matrix dequant hack and `dequantize_tensor_core_layout` are
     # expected to give same results, while both will have similar numerical differences to `groupwise_affine_dequantize`.
-
-    # Test that the `dequant` kernel gives same results as identity matrix-based dequant
+
+    # Test that the `dequant` kernel gives same results as identity matrix-based dequant
     assert diff_op_id == 0
-
+
     # Test that the `dequant` kernel gives same numerical diffs as the `groupwise_affine_dequantize` when compared against the identity matrix
     assert diff_op_ao == diff_ao_id
 
     assert diff_op_ao < 1e-1
 
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(TORCH_VERSION_AFTER_2_5, reason="weight packing is updated in 2.5+")
 @pytest.mark.parametrize("shape, inner_k_tiles, group_size", TEST_CONFIGS_DEQUANT, ids=str)
 def test_dequantize_tensor_core_tiled_layout_op(shape, inner_k_tiles, group_size):
     n, k = shape
@@ -271,7 +276,7 @@ def test_dequantize_tensor_core_tiled_layout_op(shape, inner_k_tiles, group_size
     scales = torch.randn(n, q_groups, dtype=torch.bfloat16, device=device)
     zeros = torch.randn_like(scales)
     scales_and_zeros = pack_tinygemm_scales_and_zeros(scales, zeros)
-
+
     test_utils = [
         "test_schema",
         "test_autograd_registration",
@@ -287,4 +292,4 @@ def test_dequantize_tensor_core_tiled_layout_op(shape, inner_k_tiles, group_size
     )
 
 if __name__ == "__main__":
-    run_tests()
+    run_tests()
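
Note: the reference checks in this patch lean on the fact that multiplying by the identity dequantizes a packed weight, since the int4 tinygemm matmul dequantizes internally. Below is a minimal standalone sketch of that trick using only the ops and helpers that appear in the diff above; the import path (torchao.quantization.utils) and the example shapes (n = k = 256, group_size = 64, inner_k_tiles = 8) are assumptions for illustration, not taken from this PR, and pre-2.5 packing semantics are assumed.

    # Sketch of the identity-matrix dequant reference used in the tests above.
    # Assumptions: CUDA available, torch < 2.5 weight packing, helper import
    # path may differ across torchao versions.
    import torch
    from torchao.quantization.utils import (
        get_groupwise_affine_qparams,
        groupwise_affine_quantize_tensor_from_qparams,
        pack_tinygemm_scales_and_zeros,
    )

    n, k, group_size, inner_k_tiles = 256, 256, 64, 8  # hypothetical example sizes
    t = torch.randn(n, k, dtype=torch.bfloat16, device="cuda")

    # Groupwise-quantize to uint4, then pack into the tensor-core tiled layout
    scales, zeros = get_groupwise_affine_qparams(t, n_bit=4, groupsize=group_size, dtype=torch.bfloat16)
    q = groupwise_affine_quantize_tensor_from_qparams(t, scales, zeros, n_bit=4, groupsize=group_size)
    packed = torch.ops.aten._convert_weight_to_int4pack(q, inner_k_tiles)
    scales_and_zeros = pack_tinygemm_scales_and_zeros(scales, zeros)

    # The int4 matmul computes activation @ dequant(weight).t(), so feeding the
    # k x k identity returns the dequantized transpose; .t() recovers the weight.
    a_eye = torch.eye(k, device="cuda", dtype=torch.bfloat16)
    dq_id = torch.ops.aten._weight_int4pack_mm(a_eye, packed, group_size, scales_and_zeros).t()
    assert dq_id.shape == (n, k)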