
Commit 729fa4d

Add fused QKV HQQ triton_mm test (#306)
1 parent 8dbf031 commit 729fa4d

File tree

1 file changed: +212 -0 lines changed


test/hqq/test_triton_qkv_fused.py

Lines changed: 212 additions & 0 deletions
@@ -0,0 +1,212 @@
import pytest

triton = pytest.importorskip(
    "triton", minversion="3.0.0", reason="Triton > 3.0.0 required to run this test"
)
hqq = pytest.importorskip("hqq", reason="hqq required to run this test")
hqq_quantize = pytest.importorskip(
    "hqq.core.quantize", reason="hqq required to run this test"
)
HQQLinear = hqq_quantize.HQQLinear
BaseQuantizeConfig = hqq_quantize.BaseQuantizeConfig

import itertools

import torch
from hqq.core.quantize import BaseQuantizeConfig, HQQLinear, Quantizer

from torchao.prototype.hqq import pack_2xint4, triton_mixed_mm

torch.manual_seed(0)
# N, K = shape
Q_SHAPES = [[4096, 4096]]
KV_SHAPES = [[4096, 4096], [1024, 4096]]
GROUP_SIZES = [64, 128]
AXES = [1]
DTYPES = [torch.bfloat16]

TRANSPOSED = [False, True]
TRITON_KERNEL_TYPE = ["compute_bound"]
TEST_CONFIGS = list(
    itertools.product(
        Q_SHAPES, KV_SHAPES, GROUP_SIZES, AXES, DTYPES, TRANSPOSED, TRITON_KERNEL_TYPE
    )
)


BASE_QUANT_CONFIG = {
    "optimize": True,
    "view_as_float": False,
    "nbits": 4,
    "bitpack": False,
    "axis": 1,
}


def _arg_to_id(arg):
    if isinstance(arg, list):
        return "x".join([str(x) for x in arg])
    return str(arg)


def quantize_helper(
    weight_shape, quant_config, dtype, device="cuda", quant_dtype=torch.uint8
):
    N, K = weight_shape
    linear = torch.nn.Linear(K, N, bias=False, dtype=dtype, device=device)

    hqq_linear = HQQLinear(linear, quant_config, compute_dtype=dtype, del_orig=False)
    W_q, meta = hqq_linear.W_q, hqq_linear.meta
    W_q = W_q.to(dtype=quant_dtype)
    W_q = (
        W_q.reshape(meta["shape"])
        if quant_config["weight_quant_params"]["bitpack"] == False
        else W_q
    )

    scale, zero = meta["scale"], meta["zero"]
    scale = scale.reshape(N, -1)
    zero = zero.reshape(N, -1)

    return W_q, scale, zero


def fuse_qkv(W_qs, scales, zeros):
    """
    Args:
        W_qs (list[torch.Tensor]): len 3 list of tensors with shapes Nq x K, Nk x K, Nv x K where Nk == Nv
        scales (list[torch.Tensor]): each is N x (K // group_size), with same N requirements per W_qs
        zeros (list[torch.Tensor]): same as scales

    Returns:
        qkv (torch.Tensor): (N_qkv x K) where N_qkv = Nq + Nk + Nv
        scales (torch.Tensor): (N_qkv x (K // group_size))
        zeros (torch.Tensor): (N_qkv x (K // group_size))
    """
    qkv = torch.cat(W_qs, dim=0)  # Fuse along N
    fused_scales = torch.cat([s for s in scales], dim=0)
    fused_zeros = torch.cat([z for z in zeros], dim=0)
    return qkv, fused_scales, fused_zeros
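

# Shape example (assuming the default configs above): with q_shape = [4096, 4096],
# kv_shape = [1024, 4096] and group_size = 64, fuse_qkv returns qkv of shape
# (4096 + 1024 + 1024) x 4096 = 6144 x 4096 and scales/zeros of shape
# 6144 x (4096 // 64) = 6144 x 64.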


def ref_proj(x, packed_w, scale, zero, group_size, kernel_type, transposed=False):
    return triton_mixed_mm(
        x,
        packed_w,
        scale.T,
        zero.T,
        transposed=transposed,
        group_size=group_size,
        fp8_fast_accum=False,
        kernel_type=kernel_type,
    )


@pytest.mark.parametrize(
    "q_shape, kv_shape, group_size, axis, dtype, transposed, kernel_type",
    TEST_CONFIGS,
    ids=_arg_to_id,
)
def test_mixed_mm(
    q_shape,
    kv_shape,
    group_size,
    axis,
    dtype,
    transposed,
    kernel_type,
    seqlen=16,
    device="cuda",
    quant_dtype=torch.uint8,
):
    """
    Note that we test with dtype float32 in the transposed case, since the fused and
    non-fused ops are not exactly equivalent in that case.

    More specifically, when running the transposed matmul:
    - fused: we reduce along the fused N within a single kernel
    - non-fused: we launch 3 individual kernels and reduce along N within each of them
      for q, k, v, then sum these three terms post hoc to simulate the fused op

    This gives rise to a number of numeric issues when testing equivalence, given how
    accumulation is handled within the triton MAC loop. Using higher precision mitigates
    these issues for the purposes of this test.
    """

    # Override dtype per the above comment
    if transposed:
        dtype = torch.float32

    qcfg = {
        **BASE_QUANT_CONFIG,
        **dict(group_size=group_size, axis=axis),
    }

    quant_config = BaseQuantizeConfig(
        quant_zero=False, quant_scale=False, offload_meta=False, view_as_float=False
    )
    quant_config.update({"weight_quant_params": qcfg})

    # Quantize q, k, v individually
    W_qs, packed_ws, scales, zeros = [], [], [], []
    for shape in [q_shape, kv_shape, kv_shape]:
        W_q, scale, zero = quantize_helper(
            shape, quant_config, dtype, device, quant_dtype
        )
        W_qs.append(W_q)
        packed_ws.append(pack_2xint4(W_q.T))
        scales.append(scale)
        zeros.append(zero)

    # Fuse q, k, v, scales, zeros
    qkv_fused, scales_fused, zeros_fused = fuse_qkv(W_qs, scales, zeros)
    qkv_fused_packed = pack_2xint4(qkv_fused.T)

    Ks = [shape[1] for shape in [q_shape, kv_shape]]

    K = Ks[0]

    # Check shapes
    assert all([k == K for k in Ks])
    assert qkv_fused_packed.shape[0] * 2 == qkv_fused.shape[1] == Ks[0]

    if transposed:
        Ns = [q_shape[0], kv_shape[0], kv_shape[0]]
        xs = [torch.randn(seqlen, n, dtype=dtype, device=device) for n in Ns]
        x_fused = torch.cat(xs, dim=1)
        q_ref, k_ref, v_ref = [
            ref_proj(x, p, s, z, group_size, kernel_type, transposed=True)
            for x, p, s, z in zip(xs, packed_ws, scales, zeros)
        ]
        tt_fused = triton_mixed_mm(
            x_fused,
            qkv_fused_packed,
            scales_fused.T,
            zeros_fused.T,
            transposed=True,
            group_size=group_size,
            fp8_fast_accum=False,
            kernel_type=kernel_type,
        )
        tt_ref = q_ref + k_ref + v_ref
        assert torch.allclose(tt_ref, tt_fused, atol=1e-4)
    else:
        x = torch.randn(seqlen, K, dtype=dtype, device=device)

        q_ref, k_ref, v_ref = [
            ref_proj(x, p, s, z, group_size, kernel_type)
            for p, s, z in zip(packed_ws, scales, zeros)
        ]

        tt_fused = triton_mixed_mm(
            x,
            qkv_fused_packed,
            scales_fused.T,
            zeros_fused.T,
            transposed=False,
            group_size=group_size,
            fp8_fast_accum=False,
            kernel_type=kernel_type,
        )
        qN, kN, vN = q_shape[0], kv_shape[0], kv_shape[0]
        q_fused, k_fused, v_fused = tt_fused.split([qN, kN, vN], dim=1)

        for ref, fused in zip([q_ref, k_ref, v_ref], [q_fused, k_fused, v_fused]):
            assert torch.allclose(ref, fused)
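
As a minimal sketch of the property the new test exercises (plain PyTorch on CPU, independent of triton, hqq, and torchao; the small shapes below are illustrative, not the ones used in the test): fusing the Q/K/V weights along N commutes with the matmul in the forward direction, and in the transposed direction equals the sum of the three partial reductions, up to floating-point accumulation order.

import torch

torch.manual_seed(0)
seqlen, K, Nq, Nkv = 16, 64, 32, 8
Wq, Wk, Wv = (torch.randn(n, K) for n in (Nq, Nkv, Nkv))
W_qkv = torch.cat([Wq, Wk, Wv], dim=0)  # fuse along N -> (Nq + 2 * Nkv) x K

# Forward direction: splitting the fused projection recovers the individual projections
x = torch.randn(seqlen, K)
q, k, v = (x @ W_qkv.T).split([Nq, Nkv, Nkv], dim=1)
assert torch.allclose(q, x @ Wq.T)
assert torch.allclose(k, x @ Wk.T)
assert torch.allclose(v, x @ Wv.T)

# Transposed direction: reducing over the fused N equals summing three partial reductions
xs = [torch.randn(seqlen, n) for n in (Nq, Nkv, Nkv)]
x_fused = torch.cat(xs, dim=1)
fused = x_fused @ W_qkv
ref = sum(xi @ Wi for xi, Wi in zip(xs, (Wq, Wk, Wv)))
assert torch.allclose(fused, ref, atol=1e-5)

The test in this commit checks the same two directions, but through pack_2xint4-packed 4-bit weights and the triton_mixed_mm kernel, switching the transposed case to float32 inputs for the accumulation-order reason given in its docstring.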
