Commit 874c79c

flash attention
1 parent f6bff8f commit 874c79c

File tree

5 files changed (+554, -90 lines)

examples/flash_attention.py

Lines changed: 367 additions & 0 deletions
@@ -0,0 +1,367 @@
import mlir.extras.types as T
import numpy as np
from hip import hip
from mlir.ir import InsertionPoint, IntegerAttr, UnitAttr

from mlir.extras.ast.canonicalize import canonicalize
from mlir.extras.context import RAIIMLIRContextModule
from mlir.extras.dialects.ext import memref, scf, arith, gpu, llvm

# noinspection PyUnresolvedReferences
from mlir.extras.dialects.ext.gpu import (
    all_reduce,
    wait,
    thread_attr as thread,
    block_idx,
    thread_idx,
    grid_dim,
    block_dim,
    func as gpu_func,
    set_container_module,
    launch,
    all_reduce_,
    module,
    get_compile_object_bytes,
    lds_space,
)
from mlir.extras.runtime.passes import run_pipeline, Pipeline
from mlir.extras.util import find_ops

# noinspection PyUnresolvedReferences
from util import hip_check, launch_kernel, hip_synchronize

# just so it doesn't get DCE'd by black/reformat
# TypeError: 'mlir._mlir_libs._mlir.ir.BlockArgument' object is not subscriptable
_ = memref

ctx = RAIIMLIRContextModule()
set_container_module(ctx.module)

props = hip.hipDeviceProp_t()
hip_check(hip.hipGetDeviceProperties(props, 0))
arch = props.gcnArchName.decode()


# just a default attr - the actual target is set below
@module("kernels", [f'#rocdl.target<abi = "500">'])
def gpu_module():
    pass


ip = InsertionPoint.at_block_begin(gpu_module.regions[0].blocks[0])
ip.__enter__()

batch_size = 16
n_head = 12
seq_len = 64
head_embd = 64

Bc = 32
Br = 32

B = batch_size
nh = n_head
N = seq_len
d = head_embd

import math

Tc = math.ceil(N / Bc)
Tr = math.ceil(N / Br)
softmax_scale = 1.0 / math.sqrt(d)
tile_size = Bc * d  # size of Qi, Kj, Vj
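# --- illustrative sanity check, not part of the original file ---
# The kernel keeps four Bc x d f32 tiles (Qi, Kj, Vj, S) in dynamic LDS, so each
# block needs 4 * tile_size * 4 bytes (32 KiB with these sizes). The comparison
# below assumes `sharedMemPerBlock` is the hipDeviceProp_t field for that limit.
required_lds = 4 * tile_size * np.float32().itemsize
assert required_lds <= props.sharedMemPerBlock, (
    f"kernel wants {required_lds} bytes of LDS, device allows {props.sharedMemPerBlock}"
)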
def softmax(x, axis=None):
    x_max = np.amax(x, axis=axis, keepdims=True)
    exp_x_shifted = np.exp(x - x_max)
    return exp_x_shifted / np.sum(exp_x_shifted, axis=axis, keepdims=True)


def manual_attn(q, k, v):
    # the kernel below shadows the global `math` with mlir.dialects.math
    import math

    q = q.reshape(batch_size, n_head, seq_len, head_embd)
    k = k.reshape(batch_size, n_head, seq_len, head_embd)
    v = v.reshape(batch_size, n_head, seq_len, head_embd)

    att = q @ k.transpose(0, 1, 3, 2) * (1.0 / math.sqrt(k.shape[-1]))  # QK^T / sqrt(d)
    att = softmax(att, axis=-1)
    y = att @ v
    return y.flatten()

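# --- illustrative sketch, not part of the original file ---
# The GPU kernel below relies on the "online softmax" recurrence: a running row
# max and a rescaled running row sum are updated one Bc-wide tile at a time.
# Quick NumPy check that the recurrence reproduces a plain softmax on one row:
_row = np.random.rand(N).astype(np.float32)
_run_max, _run_sum = np.float32(np.finfo(np.float32).min), np.float32(0.0)
for _tile in np.split(_row, Tc):
    _new_max = max(_run_max, _tile.max())
    _run_sum = np.exp(_run_max - _new_max) * _run_sum + np.exp(_tile - _new_max).sum()
    _run_max = _new_max
assert np.allclose(np.exp(_row - _run_max) / _run_sum, softmax(_row))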
from mlir.dialects import math


# https://github.com/tspeterkim/flash-attention-minimal/blob/main/flash.cu
@gpu_func(emit=True)
@canonicalize(using=[scf.canonicalizer, arith.canonicalizer])
def flash_attention(
    Q: T.memref(batch_size * n_head * seq_len * head_embd, T.f32()),
    K: T.memref(batch_size * n_head * seq_len * head_embd, T.f32()),
    V: T.memref(batch_size * n_head * seq_len * head_embd, T.f32()),
    l: T.memref(B * nh * N, T.f32()),
    m: T.memref(B * nh * N, T.f32()),
    O: T.memref(batch_size * n_head * seq_len * head_embd, T.f32()),
):
    tx = thread_idx.x
    bx = block_idx.x
    by = block_idx.y  # batch and head index

    # Offset into Q,K,V,O,l,m - different for each batch and head
    qkv_offset = (bx * grid_dim.y * N * d) + (by * N * d)  # gridDim.y = nh
    lm_offset = (bx * grid_dim.y * N) + (by * N)  # offset for l and m

    # Define SRAM for Q,K,V,S
    sram = gpu.dynamic_shared_memory()
    Qi = memref.view(sram, (tile_size,), dtype=T.f32())
    Kj = memref.view(sram, (tile_size,), dtype=T.f32(), shift=tile_size * 1)
    Vj = memref.view(sram, (tile_size,), dtype=T.f32(), shift=tile_size * 2)
    S = memref.view(sram, (tile_size,), dtype=T.f32(), shift=tile_size * 3)

    for j in scf.range_(0, Tc):
        # Load Kj, Vj to SRAM
        for x in scf.range_(0, d):
            Kj[(tx * d) + x] = K[qkv_offset + (tile_size * j) + (tx * d) + x]
            Vj[(tx * d) + x] = V[qkv_offset + (tile_size * j) + (tx * d) + x]

        gpu.barrier()  # such that the inner loop can use the correct Kj, Vj

        for i in scf.range_(0, Tr):
            # Load Qi to SRAM, l and m to registers
            for x in scf.range_(0, d):
                Qi[(tx * d) + x] = Q[qkv_offset + (tile_size * i) + (tx * d) + x]

            row_m_prev = m[lm_offset + (Br * i) + tx]
            row_l_prev = l[lm_offset + (Br * i) + tx]

            # S = QK^T, row_m = rowmax(S)
            row_m: T.f32() = float(np.finfo(np.float32).min)
            for y, row_m, _ in scf.range_(0, Bc, iter_args=[row_m]):
                sum: T.f32() = 0.0
                for x, sum, _ in scf.range_(0, d, iter_args=[sum]):
                    sum += Qi[(tx * d) + x] * Kj[(y * d) + x]
                    sum = yield sum

                sum *= softmax_scale
                S[(Bc * tx) + y] = sum

                if sum > row_m:
                    row_m_ = yield sum
                else:
                    row_m_ = yield row_m

                row_m = yield row_m_

            # P = exp(S - row_m), row_l = rowsum(P)
            row_l: T.f32() = 0.0
            for y, row_l, _ in scf.range_(0, Bc, iter_args=[row_l]):
                S[(Bc * tx) + y] = math.exp(S[(Bc * tx) + y] - row_m)
                row_l += S[(Bc * tx) + y]
                row_l = yield row_l

            # Compute new m and l
            row_m_new = arith.maximumf(row_m_prev, row_m)
            row_l_new = (math.exp(row_m_prev - row_m_new) * row_l_prev) + (
                math.exp(row_m - row_m_new) * row_l
            )
            div = 1.0 / row_l_new
            c = row_l_prev * math.exp(row_m_prev - row_m_new)

            # Write O, l, m to HBM
            for x in scf.range_(0, d):
                pv: T.f32() = 0.0  # Pij * Vj
                for y, pv, _ in scf.range_(0, Bc, iter_args=[pv]):
                    pv += S[(Bc * tx) + y] * Vj[(y * d) + x]
                    pv = yield pv

                ii = qkv_offset + (tile_size * i) + (tx * d) + x
                O[ii] = div * ((c * O[ii]) + (math.exp(row_m - row_m_new) * pv))

            gpu.barrier()  # otherwise, thread can use the wrong Kj, Vj in inner loop

            m[lm_offset + (Br * i) + tx] = row_m_new
            l[lm_offset + (Br * i) + tx] = row_l_new

        # gpu.barrier()  # otherwise, thread can use the wrong Kj, Vj in inner loop
        # gpu.barrier()  # otherwise, thread can use the wrong Kj, Vj in inner loop

ip.__exit__(None, None, None)

sram_size = 4 * tile_size * np.float32().itemsize

launch_params = {
    flash_attention.__name__: (
        (B, nh, 1),
        (Bc, 1, 1),
        sram_size,
    )
}
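# --- illustrative note, not part of the original file ---
# One block per (batch, head) pair and Bc threads per block: thread tx loads row
# tx of the Br-row Q tile and owns row tx of the Br x Bc score tile, which only
# lines up because the tiles are square (as in the reference flash.cu).
assert Br == Bc, "the kernel's per-thread row indexing assumes Br == Bc"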
simplified_module = run_pipeline(
    ctx.module,
    Pipeline()
    .canonicalize()
    .cse()
    .loop_invariant_code_motion()
    .loop_invariant_subset_hoisting()
    .rocdl_attach_target(chip=arch, O=3, abi="500"),
)

lowered_module = run_pipeline(
    simplified_module,
    Pipeline()
    .Gpu(
        Pipeline().convert_gpu_to_rocdl(
            use_bare_ptr_memref_call_conv=True,
            runtime="HIP",
        )
    )
    .gpu_to_llvm()
    .lower_to_llvm(),
    # .Nested("llvm.func", Pipeline().sroa()),
)
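# --- illustrative note, not part of the original file ---
# use_bare_ptr_memref_call_conv=True lowers each memref argument to a single bare
# pointer, which is what allows the host code below to pass raw hipMalloc'd
# device pointers to the kernel instead of full memref descriptors.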
# print(lowered_module)
gep = find_ops(lowered_module.operation, lambda o: isinstance(o.opview, llvm.GEPOp))
for g in gep:
    g.attributes["inbounds"] = UnitAttr.get()

kernel_funcs = find_ops(
    lowered_module.operation, lambda o: isinstance(o.opview, llvm.LLVMFuncOp)
)
for k in kernel_funcs:
    if k.sym_name.value != flash_attention.__name__:
        continue
    _, thread_dims, _ = launch_params[k.sym_name.value]
    k.attributes["rocdl.max_flat_work_group_size"] = IntegerAttr.get(
        T.index(), np.prod(thread_dims)
    )

lowered_module = run_pipeline(lowered_module, Pipeline().gpu_module_to_binary())
hsaco = get_compile_object_bytes(lowered_module)

hip_module = hip_check(hip.hipModuleLoadData(hsaco))
q_h = np.random.randint(0, 10, (batch_size * n_head * seq_len * head_embd)).astype(
    dtype=np.float32
)
k_h = np.random.randint(0, 10, (batch_size * n_head * seq_len * head_embd)).astype(
    dtype=np.float32
)
v_h = np.random.randint(0, 10, (batch_size * n_head * seq_len * head_embd)).astype(
    dtype=np.float32
)
l_h = np.zeros((B * nh * N), dtype=np.float32)
m_h = np.full((B * nh * N), float(np.finfo(np.float32).min), dtype=np.float32)
O_h = np.zeros_like(q_h, dtype=np.float32)
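# --- illustrative note, not part of the original file ---
# l and m hold the running row sums and running row maxima of the online softmax,
# so they start at 0 and float32 min (standing in for -inf); O starts at zero
# because the kernel rescales and accumulates into it across the K/V tiles.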
q_num_bytes = q_h.size * q_h.itemsize
k_num_bytes = k_h.size * k_h.itemsize
v_num_bytes = v_h.size * v_h.itemsize
l_num_bytes = l_h.size * l_h.itemsize
m_num_bytes = m_h.size * m_h.itemsize
O_num_bytes = O_h.size * O_h.itemsize

q_d = hip_check(hip.hipMalloc(q_num_bytes))
k_d = hip_check(hip.hipMalloc(k_num_bytes))
v_d = hip_check(hip.hipMalloc(v_num_bytes))
l_d = hip_check(hip.hipMalloc(l_num_bytes))
m_d = hip_check(hip.hipMalloc(m_num_bytes))
O_d = hip_check(hip.hipMalloc(O_num_bytes))

stream = 0

times = {
    flash_attention: 0,
}
# random.shuffle(kernels)
runs = 16
for kernel in times:
    for i in range(runs):
        function = hip_check(
            hip.hipModuleGetFunction(hip_module, kernel.__name__.encode())
        )
        hip_check(hip.hipDeviceSynchronize())

        for d, h, num_bytes in zip(
            [q_d, k_d, v_d, l_d, m_d, O_d],
            [q_h, k_h, v_h, l_h, m_h, O_h],
            [
                q_num_bytes,
                k_num_bytes,
                v_num_bytes,
                l_num_bytes,
                m_num_bytes,
                O_num_bytes,
            ],
        ):
            hip_check(
                hip.hipMemcpy(d, h, num_bytes, hip.hipMemcpyKind.hipMemcpyHostToDevice)
            )

        (
            (
                blocks_per_grid_x,
                blocks_per_grid_y,
                blocks_per_grid_z,
            ),
            (
                threads_per_block_x,
                threads_per_block_y,
                threads_per_block_z,
            ),
            shared_memory,
        ) = launch_params[kernel.__name__]

        time_compute = launch_kernel(
            function.as_c_void_p(),
            blocks_per_grid_x,
            blocks_per_grid_y,
            blocks_per_grid_z,
            threads_per_block_x,
            threads_per_block_y,
            threads_per_block_z,
            stream,
            shared_memory,
            q_d,
            k_d,
            v_d,
            l_d,
            m_d,
            O_d,
        )

        hip_check(
            hip.hipMemcpy(
                l_h, l_d, l_num_bytes, hip.hipMemcpyKind.hipMemcpyDeviceToHost
            )
        )
        hip_check(
            hip.hipMemcpy(
                m_h, m_d, m_num_bytes, hip.hipMemcpyKind.hipMemcpyDeviceToHost
            )
        )
        hip_check(
            hip.hipMemcpy(
                O_h, O_d, O_num_bytes, hip.hipMemcpyKind.hipMemcpyDeviceToHost
            )
        )
        correct = manual_attn(q_h, k_h, v_h)
        if not np.allclose(correct, O_h):
            print("correct", correct)
            print("l_h", l_h)
            print("m_h", m_h)
            print("output", O_h)
            print(f"{kernel.__name__} failed")

        times[kernel] += time_compute

for k in times:
    times[k] /= runs

for k, v in times.items():
    print(f"{k.__name__}: {v:.03f}ms")
