|
11 | 11 | from torchao.quantization.quant_primitives import (
|
12 | 12 | get_group_qparams_symmetric,
|
13 | 13 | get_groupwise_affine_qparams,
|
| 14 | + groupwise_affine_quantize_tensor_from_qparams, |
| 15 | + groupwise_affine_dequantize_tensor_from_qparams, |
14 | 16 | quantize_affine,
|
15 | 17 | dequantize_affine,
|
16 | 18 | choose_qparams_affine,
|
@@ -38,6 +40,86 @@ def check_idempotent(self, fn, *args, **kwargs):
|
38 | 40 | self.assertTrue(torch.equal(output0, output1), f"Expected given function {fn} to be idempotent.")
|
39 | 41 | return output1
|
40 | 42 |
|
| 43 | +# Legacy tinygemm ops |
| 44 | +def _get_groupwise_affine_qparams(w, n_bit=4, groupsize=128, dtype=torch.bfloat16): |
| 45 | + if groupsize > w.shape[-1]: |
| 46 | + groupsize = w.shape[-1] |
| 47 | + assert groupsize > 1 |
| 48 | + assert w.shape[-1] % groupsize == 0 |
| 49 | + assert w.dim() == 2 |
| 50 | + |
| 51 | + to_quant = w.reshape(-1, groupsize) |
| 52 | + # assert torch.isnan(to_quant).sum() == 0 |
| 53 | + |
| 54 | + max_val = to_quant.amax(dim=1, keepdim=True) |
| 55 | + min_val = to_quant.amin(dim=1, keepdim=True) |
| 56 | + max_int = 2**n_bit - 1 |
| 57 | + scales = (max_val - min_val).clamp(min=1e-6) / max_int |
| 58 | + zeros = min_val + scales * (2 ** (n_bit - 1)) |
| 59 | + return scales.to(dtype=dtype).reshape(w.shape[0], -1), zeros.to( |
| 60 | + dtype=dtype |
| 61 | + ).reshape(w.shape[0], -1) |
| 62 | + |
| 63 | +def _groupwise_affine_quantize_tensor_from_qparams( |
| 64 | + w, |
| 65 | + scales, |
| 66 | + zeros, |
| 67 | + n_bit=4, |
| 68 | + groupsize=128, |
| 69 | +): |
| 70 | + assert groupsize > 1 |
| 71 | + # needed for GPTQ single column quantize |
| 72 | + if groupsize > w.shape[-1] and scales.shape[-1] == 1: |
| 73 | + groupsize = w.shape[-1] |
| 74 | + |
| 75 | + assert w.shape[-1] % groupsize == 0 |
| 76 | + assert w.dim() == 2 |
| 77 | + |
| 78 | + to_quant = w.reshape(-1, groupsize) |
| 79 | + # assert torch.isnan(to_quant).sum() == 0 |
| 80 | + |
| 81 | + scales = scales.reshape(-1, 1) |
| 82 | + zeros = zeros.reshape(-1, 1) |
| 83 | + min_val = zeros - scales * (2 ** (n_bit - 1)) |
| 84 | + max_int = 2**n_bit - 1 |
| 85 | + min_int = 0 |
| 86 | + w_int4x8 = ( |
| 87 | + to_quant.sub(min_val) |
| 88 | + .div(scales) |
| 89 | + .round() |
| 90 | + .clamp_(min_int, max_int) |
| 91 | + .to(torch.int32) |
| 92 | + .reshape_as(w) |
| 93 | + ) |
| 94 | + |
| 95 | + return w_int4x8 |
| 96 | + |
| 97 | +def _groupwise_affine_dequantize_tensor_from_qparams( |
| 98 | + w_int4x8, |
| 99 | + scales, |
| 100 | + zeros, |
| 101 | + n_bit=4, |
| 102 | + groupsize=128, |
| 103 | +): |
| 104 | + assert groupsize > 1 |
| 105 | + # needed for GPTQ single column dequantize |
| 106 | + if groupsize > w_int4x8.shape[-1] and scales.shape[-1] == 1: |
| 107 | + groupsize = w_int4x8.shape[-1] |
| 108 | + assert w_int4x8.shape[-1] % groupsize == 0 |
| 109 | + assert w_int4x8.dim() == 2 |
| 110 | + |
| 111 | + w_int4x8_grouped = w_int4x8.reshape(-1, groupsize) |
| 112 | + scales = scales.reshape(-1, 1) |
| 113 | + zeros = zeros.reshape(-1, 1) |
| 114 | + |
| 115 | + w_dq = ( |
| 116 | + w_int4x8_grouped.sub(2 ** (n_bit - 1)) |
| 117 | + .mul(scales) |
| 118 | + .add(zeros) |
| 119 | + .reshape_as(w_int4x8) |
| 120 | + ) |
| 121 | + return w_dq |
| 122 | + |
41 | 123 |
|
42 | 124 | class TestQuantPrimitives(unittest.TestCase):
|
43 | 125 | SEED = 123
|
@@ -356,12 +438,12 @@ def test_not_preserve_zero_not_supported(self):
|
356 | 438 | )
|
357 | 439 |
|
358 | 440 |
|
359 |
| - def test_tinygemm_get_groupwise_affine_qparams(self): |
| 441 | + def test_get_groupwise_affine_qparams(self): |
360 | 442 | from torchao.quantization.quant_primitives import ZeroPointDomain
|
361 | 443 |
|
362 | 444 | input = torch.randn(10, 256)
|
363 | 445 | n_bit = 4
|
364 |
| - scale_ref, zero_point_ref = get_groupwise_affine_qparams(input, n_bit=n_bit, groupsize=128, dtype=torch.bfloat16) |
| 446 | + scale_ref, zero_point_ref = _get_groupwise_affine_qparams(input, n_bit=n_bit, groupsize=128, dtype=torch.bfloat16) |
365 | 447 |
|
366 | 448 | mapping_type = MappingType.ASYMMETRIC
|
367 | 449 | dtype = torch.int8
|
@@ -389,6 +471,29 @@ def test_tinygemm_get_groupwise_affine_qparams(self):
|
389 | 471 | self.assertTrue(torch.equal(scale, scale_ref))
|
390 | 472 | self.assertTrue(torch.equal(zero_point, zero_point_ref))
|
391 | 473 |
|
| 474 | + def test_groupwise_affine_quantize_tensor_from_qparams(self): |
| 475 | + input = torch.randn(10, 256) |
| 476 | + scales = torch.randn(10, 2) |
| 477 | + zeros = torch.randn(10, 2) |
| 478 | + n_bit = 4 |
| 479 | + groupsize = 128 |
| 480 | + |
| 481 | + w_int4x8 = groupwise_affine_quantize_tensor_from_qparams(input, scales, zeros, n_bit, groupsize) |
| 482 | + w_int4x8_ref = _groupwise_affine_quantize_tensor_from_qparams(input, scales, zeros, n_bit, groupsize) |
| 483 | + |
| 484 | + self.assertTrue(torch.equal(w_int4x8, w_int4x8_ref)) |
| 485 | + |
| 486 | + def test_groupwise_affine_dequantize_tensor_from_qparams(self): |
| 487 | + input = torch.randint(0, 15, (10, 256), dtype=torch.int32) |
| 488 | + scales = torch.randn(10, 2).bfloat16() |
| 489 | + zeros = torch.randn(10, 2).bfloat16() |
| 490 | + n_bit = 4 |
| 491 | + groupsize = 128 |
| 492 | + |
| 493 | + w_bf16 = groupwise_affine_dequantize_tensor_from_qparams(input, scales, zeros, n_bit, groupsize) |
| 494 | + w_bf16_ref = _groupwise_affine_dequantize_tensor_from_qparams(input, scales, zeros, n_bit, groupsize) |
| 495 | + |
| 496 | + self.assertTrue(torch.equal(w_bf16, w_bf16_ref)) |
392 | 497 |
|
# Allow running this test module directly (e.g. `python test_file.py`).
if __name__ == "__main__":
    unittest.main()
|
0 commit comments