
Commit 3cec8cd

[WIP] fp8_attention
stack-info: PR: #279, branch: yf225/stack/14
1 parent 47878bf · commit 3cec8cd

File tree

2 files changed: 254 additions, 0 deletions

benchmarks/run.py
examples/fp8_attention.py

benchmarks/run.py

Lines changed: 9 additions & 0 deletions
@@ -25,6 +25,10 @@
 from typing import Any
 from typing import Callable
 
+# Import tritonbench's run module which applies the async_task patch
+# This ensures the patch is applied before we need it
+import tritonbench.run  # noqa: F401  # This applies the async_task patch
+
 # Maps tritonbench op names to Helion kernel examples
 KERNEL_MAPPINGS: dict[str, tuple[str, str, str]] = {
     # <tritonbench_op_name>: (<tritonbench_module_path>, <helion_kernel_module_path>, <helion_kernel_function_name>)
@@ -60,6 +64,11 @@
         "examples.attention",
         "attention",
     ),
+    "fp8_attention": (
+        "tritonbench.operators.fp8_attention.operator",
+        "examples.fp8_attention",
+        "fp8_attention_tritonbench",
+    ),
 }
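Each KERNEL_MAPPINGS entry pairs a tritonbench operator module with the Helion example module and the wrapper function to benchmark. As a rough illustration only (not the actual benchmarks/run.py logic), such a tuple could be resolved at runtime like this:

import importlib

tb_module_path, helion_module_path, helion_func_name = KERNEL_MAPPINGS["fp8_attention"]
helion_module = importlib.import_module(helion_module_path)  # examples.fp8_attention
kernel_fn = getattr(helion_module, helion_func_name)  # fp8_attention_tritonbench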

examples/fp8_attention.py

Lines changed: 245 additions & 0 deletions
@@ -0,0 +1,245 @@
from __future__ import annotations

import math

import torch

import helion
import helion.language as hl

# TritonBench configuration
TRITONBENCH_ARGS = {
    "batch": 4,
    "n_heads": 48,
    "d_head": 64,
}


@helion.kernel(
    static_shapes=True,
    config=helion.Config(
        block_sizes=[64, 32],  # [BLOCK_M, BLOCK_N]
        num_warps=4,
        num_stages=3,
    ),
)
def fp8_attention_kernel(
    q: torch.Tensor,  # [batch*heads, seq, dim]
    k: torch.Tensor,  # [batch*heads, seq, dim]
    v: torch.Tensor,  # [batch*heads, dim, seq] - pre-transposed
) -> torch.Tensor:
    """FP8 attention kernel processing batch*heads in parallel."""
    batch_heads = q.size(0)
    seq_len = q.size(1)
    head_dim = q.size(2)

    # Output tensor
    out = torch.empty(
        [batch_heads, seq_len, head_dim], dtype=torch.float32, device=q.device
    )

    # Scale factor for attention
    sm_scale = 1.0 / math.sqrt(float(head_dim))
    # Triton multiplies sm_scale by 1.44269504 (log2(e) = 1/ln(2)) for exp2
    sm_scale = sm_scale * 1.44269504
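    # Folding log2(e) into sm_scale lets the inner loop call exp2 directly,
    # since exp(x) == exp2(x * log2(e)) and log2(e) ≈ 1.44269504.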

    # Process each batch*head in parallel
    for bh in hl.grid(batch_heads):
        # Process each tile of query positions
        for tile_m in hl.tile(seq_len):
            # Initialize for online softmax
            m_i = hl.full([tile_m], float("-inf"), dtype=torch.float32)
            l_i = hl.full([tile_m], 0.0, dtype=torch.float32)  # running softmax denominator
            acc = hl.zeros([tile_m, head_dim], dtype=torch.float32)

            # Load query tile - keep in FP8
            q_tile = q[bh, tile_m, :]  # [tile_m, dim]

            # Compute attention scores for all keys
            for tile_n in hl.tile(seq_len):
                # Load key tile and transpose for Q @ K^T
                k_tile = k[bh, tile_n, :]  # [tile_n, dim] - keep in FP8
                k_tile_t = k_tile.transpose(0, 1)  # [dim, tile_n]

                # Compute Q @ K^T with FP8 inputs, result in FP32
                qk = torch.matmul(q_tile, k_tile_t).to(
                    torch.float32
                )  # [tile_m, tile_n]

                # Scale QK scores first
                qk_scaled = qk * sm_scale  # [tile_m, tile_n]

                # Compute max of scaled scores
                qk_max = torch.amax(qk_scaled, dim=-1)  # [tile_m]

                # Update global max
                m_new = torch.maximum(m_i, qk_max)

                # Shift by max for numerical stability
                qk_shifted = qk_scaled - m_new[:, None]

                # Use exp2 to match Triton's implementation
                # Note: sm_scale already includes the 1.44269504 factor
                p = torch.exp2(qk_shifted)  # [tile_m, tile_n]

                # Sum of exponentials for this block
                l_ij = torch.sum(p, dim=-1)  # [tile_m]

                # Update accumulators with correction factor
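                # Online-softmax invariant: acc holds sum_j exp2(s_j - m_i) * v_j and
                # l_i holds sum_j exp2(s_j - m_i) over the keys seen so far; when the
                # running max moves from m_i to m_new, both are rescaled by
                # alpha = exp2(m_i - m_new) so earlier contributions stay consistent.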
                # Correction factor for previous blocks
                alpha = torch.exp2(m_i - m_new)
                l_i = l_i * alpha + l_ij
                acc = acc * alpha[:, None]

                # Load values - V is [dim, seq]
                v_tile = v[bh, :, tile_n]  # [dim, tile_n] - keep in FP8

                # Convert p to FP8 for FP8 GEMM
                p_fp8 = p.to(v.dtype)  # Convert to same FP8 type as V

                # Accumulate attention @ V with FP8 GEMM
                v_t = v_tile.transpose(0, 1)  # [tile_n, dim]
                pv = torch.matmul(p_fp8, v_t).to(torch.float32)  # [tile_m, dim]
                acc = acc + pv

                # Update max tracker
                m_i = m_new

            # Final normalization
            acc = acc / l_i[:, None]
            out[bh, tile_m, :] = acc

    return out


def prepare_fp8_attention_inputs(
    q: torch.Tensor,  # [batch, heads, seq, dim]
    k: torch.Tensor,  # [batch, heads, seq, dim]
    v: torch.Tensor,  # [batch, heads, seq, dim]
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, tuple[int, int, int, int]]:
    """
    Common preprocessing for FP8 attention implementations.

    Returns:
        q_reshaped_fp8: [batch*heads, seq, dim] - in FP8 e5m2
        k_reshaped_fp8: [batch*heads, seq, dim] - in FP8 e5m2
        v_transposed_fp8: [batch*heads, dim, seq] - in FP8 e5m2
        shape: (batch, heads, seq_len, head_dim)
    """
    batch, heads, seq_len, head_dim = q.shape

    # Reshape to [batch*heads, seq, dim]
    q_reshaped = q.reshape(batch * heads, seq_len, head_dim)
    k_reshaped = k.reshape(batch * heads, seq_len, head_dim)

    # Transpose V to [batch, heads, dim, seq] then reshape
    v_transposed = v.permute(0, 1, 3, 2).reshape(batch * heads, head_dim, seq_len)

    # Convert to FP8 e5m2
    q_reshaped_fp8 = q_reshaped.to(torch.float8_e5m2)
    k_reshaped_fp8 = k_reshaped.to(torch.float8_e5m2)
    v_transposed_fp8 = v_transposed.to(torch.float8_e5m2)

    return q_reshaped_fp8, k_reshaped_fp8, v_transposed_fp8, (batch, heads, seq_len, head_dim)


def fp8_attention_tritonbench(
    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor
) -> torch.Tensor:
    """Wrapper for TritonBench compatibility."""
    # Common preprocessing with FP8 conversion
    q_fp8, k_fp8, v_fp8, shape = prepare_fp8_attention_inputs(q, k, v)
    batch, heads, seq_len, head_dim = shape

    # Call the fused kernel that processes all batch*heads at once
    out_fused = fp8_attention_kernel(q_fp8, k_fp8, v_fp8)

    # Reshape back and convert to FP16
    out = out_fused.reshape(batch, heads, seq_len, head_dim)
    return out.to(torch.float16)


def fp8_attention_pytorch(
    q: torch.Tensor,  # [batch, heads, seq, dim]
    k: torch.Tensor,  # [batch, heads, seq, dim]
    v: torch.Tensor,  # [batch, heads, seq, dim]
) -> torch.Tensor:
    """
    Manual baseline implementation using FP8 e5m2.

    This serves as a PyTorch reference implementation for FP8 attention.
    Uses e5m2 format with dequantization since torch._scaled_mm doesn't support e5m2.
    """
    # Get preprocessed inputs with FP8 conversion
    q_fp8, k_fp8, v_fp8, shape = prepare_fp8_attention_inputs(q, k, v)
    batch, heads, seq_len, head_dim = shape

    sm_scale = 1.0 / math.sqrt(float(head_dim))

    outputs = []

    for i in range(batch * heads):
        q_i = q_fp8[i]  # [seq, dim] - already FP8
        k_i = k_fp8[i]  # [seq, dim] - already FP8
        v_i = v_fp8[i]  # [dim, seq] - pre-transposed, already FP8

        # For Q @ K^T, we need K^T to be column-major
        kt_fp8 = k_i.t()  # column-major [dim, seq]

        # Q @ K^T - dequantize and use regular matmul since e5m2 not supported by _scaled_mm
        q_deq = q_i.to(torch.float32)
        kt_deq = kt_fp8.to(torch.float32)
        qk = torch.matmul(q_deq, kt_deq)

        # Compute max before scaling (following Triton Flash v2 algorithm)
        qk_max = torch.amax(qk, dim=-1, keepdim=True)

        # Scale and shift in one operation, then use exp2
        qk_scaled_shifted = qk * sm_scale - qk_max * sm_scale
        p = torch.exp2(qk_scaled_shifted * 1.44269504)
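        # exp2(x * log2(e)) == exp(x), so p equals exp((qk - qk_max) * sm_scale),
        # i.e. the unnormalized softmax of the scaled attention scores.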

        # Normalize
        p_norm = p / p.sum(dim=-1, keepdim=True)

        # Step 2: Attention @ V using FP8
        # P is [seq, seq], V is [dim, seq]
        # We want P @ V^T = [seq, seq] @ [seq, dim] = [seq, dim]
        p_fp8 = p_norm.to(torch.float8_e5m2)  # row-major [seq, seq]

        # v_i is [dim, seq], already FP8
        vt_fp8 = v_i.t()  # column-major [seq, dim]

        # P @ V^T - dequantize and use regular matmul since e5m2 not supported by _scaled_mm
        p_deq = p_fp8.to(torch.float32)
        vt_deq = vt_fp8.to(torch.float32)
        out_i = torch.matmul(p_deq, vt_deq)

        outputs.append(out_i)

    # Stack and reshape back
    out_stacked = torch.stack(outputs, dim=0)  # [batch*heads, seq, dim]
    out = out_stacked.reshape(batch, heads, seq_len, head_dim)

    return out.to(torch.float16)


def check(batch: int, heads: int, seq_len: int, head_dim: int) -> None:
    torch.manual_seed(42)
    q = torch.randn(batch, heads, seq_len, head_dim, dtype=torch.float16, device="cuda")
    k = torch.randn(batch, heads, seq_len, head_dim, dtype=torch.float16, device="cuda")
    v = torch.randn(batch, heads, seq_len, head_dim, dtype=torch.float16, device="cuda")

    from helion._testing import run_example

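    # Compare the Helion kernel against the PyTorch reference; the loose
    # tolerances (atol=0.1, rtol=0.1) allow for FP8 quantization error.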
    run_example(
        fp8_attention_tritonbench, fp8_attention_pytorch, (q, k, v), atol=0.1, rtol=0.1
    )


def main() -> None:
    check(1, 2, 128, 64)
    check(2, 4, 256, 64)
    check(4, 8, 512, 128)


if __name__ == "__main__":
    main()
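For reference, a minimal usage sketch of the wrapper (assumes the repository root is on PYTHONPATH and a CUDA GPU with FP8 support; batch, heads, and head_dim follow TRITONBENCH_ARGS, while the sequence length of 256 is chosen arbitrarily for illustration):

import torch

from examples.fp8_attention import fp8_attention_tritonbench

q = torch.randn(4, 48, 256, 64, dtype=torch.float16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)

# Inputs are converted to FP8 e5m2 internally; the output comes back as float16.
out = fp8_attention_tritonbench(q, k, v)  # [4, 48, 256, 64]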
