|
| 1 | +import pytest |
| 2 | + |
| 3 | +triton = pytest.importorskip( |
| 4 | + "triton", minversion="3.0.0", reason="Triton > 3.0.0 required to run this test" |
| 5 | +) |
| 6 | +hqq = pytest.importorskip("hqq", reason="hqq required to run this test") |
| 7 | +hqq_quantize = pytest.importorskip( |
| 8 | + "hqq.core.quantize", reason="hqq required to run this test" |
| 9 | +) |
| 10 | +HQQLinear = hqq_quantize.HQQLinear |
| 11 | +BaseQuantizeConfig = hqq_quantize.BaseQuantizeConfig |
| 12 | + |
| 13 | +import itertools |
| 14 | + |
| 15 | +import torch |
| 16 | +from hqq.core.quantize import BaseQuantizeConfig, HQQLinear, Quantizer |
| 17 | + |
| 18 | +from torchao.prototype.hqq import pack_2xint4, triton_mixed_mm |
| 19 | + |
| 20 | +torch.manual_seed(0) |
| 21 | +# N, K = shape |
| 22 | +Q_SHAPES = [[4096, 4096]] |
| 23 | +KV_SHAPES = [[4096, 4096], [1024, 4096]] |
| 24 | +GROUP_SIZES = [64, 128] |
| 25 | +AXES = [1] |
| 26 | +DTYPES = [torch.bfloat16] |
| 27 | + |
| 28 | +TRANSPOSED = [False, True] |
| 29 | +TRITON_KERNEL_TYPE = ["compute_bound"] |
| 30 | +TEST_CONFIGS = list( |
| 31 | + itertools.product( |
| 32 | + Q_SHAPES, KV_SHAPES, GROUP_SIZES, AXES, DTYPES, TRANSPOSED, TRITON_KERNEL_TYPE |
| 33 | + ) |
| 34 | +) |
| 35 | + |
| 36 | + |
| 37 | +BASE_QUANT_CONFIG = { |
| 38 | + "optimize": True, |
| 39 | + "view_as_float": False, |
| 40 | + "nbits": 4, |
| 41 | + "bitpack": False, |
| 42 | + "axis": 1, |
| 43 | +} |
| 44 | + |
| 45 | + |
| 46 | +def _arg_to_id(arg): |
| 47 | + if isinstance(arg, list): |
| 48 | + return "x".join([str(x) for x in arg]) |
| 49 | + return str(arg) |
| 50 | + |
| 51 | + |
| 52 | +def quantize_helper( |
| 53 | + weight_shape, quant_config, dtype, device="cuda", quant_dtype=torch.uint8 |
| 54 | +): |
| 55 | + N, K = weight_shape |
| 56 | + linear = torch.nn.Linear(K, N, bias=False, dtype=dtype, device=device) |
| 57 | + |
| 58 | + hqq_linear = HQQLinear(linear, quant_config, compute_dtype=dtype, del_orig=False) |
| 59 | + W_q, meta = hqq_linear.W_q, hqq_linear.meta |
| 60 | + W_q = W_q.to(dtype=quant_dtype) |
| 61 | + W_q = ( |
| 62 | + W_q.reshape(meta["shape"]) |
| 63 | + if quant_config["weight_quant_params"]["bitpack"] == False |
| 64 | + else W_q |
| 65 | + ) |
| 66 | + |
| 67 | + scale, zero = meta["scale"], meta["zero"] |
| 68 | + scale = scale.reshape(N, -1) |
| 69 | + zero = zero.reshape(N, -1) |
| 70 | + |
| 71 | + return W_q, scale, zero |
| 72 | + |
| 73 | + |
| 74 | +def fuse_qkv(W_qs, scales, zeros): |
| 75 | + """ |
| 76 | + Args: |
| 77 | + W_qs (list[torch.Tensor]): len 3 list of tensors with shapes Nq x K, Nk x K, Nv x K where Nk == Nv |
| 78 | + scales (list[torch.Tensor]): each is N x (K // group_size), with same N requirements per W_qs |
| 79 | + zeros (list[torch.Tensor]): same as scales |
| 80 | +
|
| 81 | + Returns: |
| 82 | + qkv (torch.Tensor): (N_qkv x K) where N_qkv = Nq + Nk + Nv |
| 83 | + scales (torch.Tensor): (N_qkv x (K // group_size)) |
| 84 | + zeros (torch.Tensor): (N_qkv x (K // group_size)) |
| 85 | + """ |
| 86 | + qkv = torch.cat(W_qs, dim=0) # Fuse along N |
| 87 | + fused_scales = torch.cat([s for s in scales], dim=0) |
| 88 | + fused_zeros = torch.cat([z for z in zeros], dim=0) |
| 89 | + return qkv, fused_scales, fused_zeros |
| 90 | + |
| 91 | + |
| 92 | +def ref_proj(x, packed_w, scale, zero, group_size, kernel_type, transposed=False): |
| 93 | + return triton_mixed_mm( |
| 94 | + x, |
| 95 | + packed_w, |
| 96 | + scale.T, |
| 97 | + zero.T, |
| 98 | + transposed=transposed, |
| 99 | + group_size=group_size, |
| 100 | + fp8_fast_accum=False, |
| 101 | + kernel_type=kernel_type, |
| 102 | + ) |
| 103 | + |
| 104 | + |
| 105 | +@pytest.mark.parametrize( |
| 106 | + "q_shape, kv_shape, group_size, axis, dtype, transposed, kernel_type", |
| 107 | + TEST_CONFIGS, |
| 108 | + ids=_arg_to_id, |
| 109 | +) |
| 110 | +def test_mixed_mm( |
| 111 | + q_shape, |
| 112 | + kv_shape, |
| 113 | + group_size, |
| 114 | + axis, |
| 115 | + dtype, |
| 116 | + transposed, |
| 117 | + kernel_type, |
| 118 | + seqlen=16, |
| 119 | + device="cuda", |
| 120 | + quant_dtype=torch.uint8, |
| 121 | +): |
| 122 | + """ |
| 123 | + Note we test with dtype float32 in the transposed case, since fused and non-fused ops are not exactly equivalent in this case. |
| 124 | +
|
| 125 | + More specifically when running transposed matmul: |
| 126 | + - fused: we are reducing along fused N within the kernel |
| 127 | + - non-fused: we are launching 3 individual kernels and reducing along N within each of these kernels for q, k, v then post-hoc summing these three terms to simulate the fused op |
| 128 | +
|
| 129 | + This gives rise to a number of numeric issues when testing equivalence, given how accumulation is treated within triton MAC loop. |
| 130 | + Using higher precision mitigates these issues for the purposes of this test. |
| 131 | + """ |
| 132 | + |
| 133 | + # Override dtype per the above comment |
| 134 | + if transposed: |
| 135 | + dtype = torch.float32 |
| 136 | + |
| 137 | + qcfg = { |
| 138 | + **BASE_QUANT_CONFIG, |
| 139 | + **dict(group_size=group_size, axis=axis), |
| 140 | + } |
| 141 | + |
| 142 | + quant_config = BaseQuantizeConfig( |
| 143 | + quant_zero=False, quant_scale=False, offload_meta=False, view_as_float=False |
| 144 | + ) |
| 145 | + quant_config.update({"weight_quant_params": qcfg}) |
| 146 | + |
| 147 | + # Quantize q, k, v individually |
| 148 | + W_qs, packed_ws, scales, zeros = [], [], [], [] |
| 149 | + for shape in [q_shape, kv_shape, kv_shape]: |
| 150 | + W_q, scale, zero = quantize_helper( |
| 151 | + shape, quant_config, dtype, device, quant_dtype |
| 152 | + ) |
| 153 | + W_qs.append(W_q) |
| 154 | + packed_ws.append(pack_2xint4(W_q.T)) |
| 155 | + scales.append(scale) |
| 156 | + zeros.append(zero) |
| 157 | + |
| 158 | + # Fuse q, k, v, scales, zeros |
| 159 | + qkv_fused, scales_fused, zeros_fused = fuse_qkv(W_qs, scales, zeros) |
| 160 | + qkv_fused_packed = pack_2xint4(qkv_fused.T) |
| 161 | + |
| 162 | + Ks = [shape[1] for shape in [q_shape, kv_shape]] |
| 163 | + |
| 164 | + K = Ks[0] |
| 165 | + |
| 166 | + # Check shapes |
| 167 | + assert all([k == K for k in Ks]) |
| 168 | + assert qkv_fused_packed.shape[0] * 2 == qkv_fused.shape[1] == Ks[0] |
| 169 | + |
| 170 | + if transposed: |
| 171 | + Ns = [q_shape[0], kv_shape[0], kv_shape[0]] |
| 172 | + xs = [torch.randn(seqlen, n, dtype=dtype, device=device) for n in Ns] |
| 173 | + x_fused = torch.cat(xs, dim=1) |
| 174 | + q_ref, k_ref, v_ref = [ |
| 175 | + ref_proj(x, p, s, z, group_size, kernel_type, transposed=True) |
| 176 | + for x, p, s, z in zip(xs, packed_ws, scales, zeros) |
| 177 | + ] |
| 178 | + tt_fused = triton_mixed_mm( |
| 179 | + x_fused, |
| 180 | + qkv_fused_packed, |
| 181 | + scales_fused.T, |
| 182 | + zeros_fused.T, |
| 183 | + transposed=True, |
| 184 | + group_size=group_size, |
| 185 | + fp8_fast_accum=False, |
| 186 | + kernel_type=kernel_type, |
| 187 | + ) |
| 188 | + tt_ref = q_ref + k_ref + v_ref |
| 189 | + assert torch.allclose(tt_ref, tt_fused, atol=1e-4) |
| 190 | + else: |
| 191 | + x = torch.randn(seqlen, K, dtype=dtype, device=device) |
| 192 | + |
| 193 | + q_ref, k_ref, v_ref = [ |
| 194 | + ref_proj(x, p, s, z, group_size, kernel_type) |
| 195 | + for p, s, z in zip(packed_ws, scales, zeros) |
| 196 | + ] |
| 197 | + |
| 198 | + tt_fused = triton_mixed_mm( |
| 199 | + x, |
| 200 | + qkv_fused_packed, |
| 201 | + scales_fused.T, |
| 202 | + zeros_fused.T, |
| 203 | + transposed=False, |
| 204 | + group_size=group_size, |
| 205 | + fp8_fast_accum=False, |
| 206 | + kernel_type=kernel_type, |
| 207 | + ) |
| 208 | + qN, kN, vN = q_shape[0], kv_shape[0], kv_shape[0] |
| 209 | + q_fused, k_fused, v_fused = tt_fused.split([qN, kN, vN], dim=1) |
| 210 | + |
| 211 | + for ref, fused in zip([q_ref, k_ref, v_ref], [q_fused, k_fused, v_fused]): |
| 212 | + assert torch.allclose(ref, fused) |
0 commit comments