Skip to content

Commit 87f39e6

Browse files
committed
works
1 parent 4660cd5 commit 87f39e6

File tree

1 file changed

+67
-82
lines changed

1 file changed

+67
-82
lines changed

examples/flash_attention.py

Lines changed: 67 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1+
from pathlib import Path
2+
13
import mlir.extras.types as T
24
import numpy as np
35
from hip import hip
46
from mlir.ir import InsertionPoint, IntegerAttr, UnitAttr
5-
67
from mlir.extras.ast.canonicalize import canonicalize
78
from mlir.extras.context import RAIIMLIRContextModule
89
from mlir.extras.dialects.ext import memref, scf, arith, gpu, llvm
@@ -23,6 +24,46 @@
2324
# noinspection PyUnresolvedReferences
2425
from util import hip_check, launch_kernel, hip_synchronize
2526

27+
28+
def init_copy_host_device():
29+
q_h = np.random.randint(0, 10, (batch_size * n_head * seq_len * head_embd)).astype(
30+
dtype=np.float32
31+
)
32+
k_h = np.random.randint(0, 10, (batch_size * n_head * seq_len * head_embd)).astype(
33+
dtype=np.float32
34+
)
35+
v_h = np.random.randint(0, 10, (batch_size * n_head * seq_len * head_embd)).astype(
36+
dtype=np.float32
37+
)
38+
l_h = np.zeros((B * nh * N), dtype=np.float32)
39+
m_h = np.full((B * nh * N), float(np.finfo(np.float32).min), dtype=np.float32)
40+
O_h = np.zeros_like(q_h, dtype=np.float32)
41+
42+
host = [q_h, k_h, v_h, l_h, m_h, O_h]
43+
device = [hip_check(hip.hipMalloc(h.size * h.itemsize)) for h in host]
44+
45+
for d, h in zip(device, host):
46+
hip_check(
47+
hip.hipMemcpy(
48+
d, h, h.size * h.itemsize, hip.hipMemcpyKind.hipMemcpyHostToDevice
49+
)
50+
)
51+
52+
return host, device
53+
54+
55+
def copy_device_host(host, device):
56+
for d, h in zip(device, host):
57+
hip_check(
58+
hip.hipMemcpy(
59+
h, d, h.size * h.itemsize, hip.hipMemcpyKind.hipMemcpyDeviceToHost
60+
)
61+
)
62+
hip_check(hip.hipFree(d))
63+
64+
return host
65+
66+
2667
# just so it doesn't get DCE'd by black/reformat
2768
# TypeError: 'mlir._mlir_libs._mlir.ir.BlockArgument' object is not subscriptable
2869
_ = memref
@@ -79,7 +120,7 @@ def manual_attn(q, k, v):
79120
k = k.reshape(batch_size, n_head, seq_len, head_embd)
80121
v = v.reshape(batch_size, n_head, seq_len, head_embd)
81122

82-
att = q @ k.transpose(0, 1, -2, -1) * (1.0 / math.sqrt(k.shape[-1]))
123+
att = q @ k.transpose(0, 1, 3, 2) * (1.0 / math.sqrt(k.shape[-1]))
83124
att = softmax(att, axis=-1)
84125
y = att @ v
85126
return y.flatten()
@@ -100,12 +141,12 @@ def flash_attention(
100141
O: T.memref(batch_size * n_head * seq_len * head_embd, T.f32()),
101142
):
102143
tx = thread_idx.x
103-
bx = block_idx.x
104-
by = block_idx.y # batch and head index
144+
bx, by = block_idx.x, block_idx.y
145+
gy = grid_dim.y
105146

106147
# Offset into Q,K,V,O,l,m - different for each batch and head
107-
qkv_offset = bx * grid_dim.y * N * d + by * N * d # gridDim.y = nh
108-
lm_offset = bx * grid_dim.y * N + by * N # offset for l and m
148+
qkv_offset = bx * gy * N * d + by * N * d # gridDim.y = nh
149+
lm_offset = bx * gy * N + by * N # offset for l and m
109150

110151
# Define SRAM for Q,K,V,S
111152
sram = gpu.dynamic_shared_memory()
@@ -120,8 +161,6 @@ def flash_attention(
120161
Kj[tx * d + x] = K[qkv_offset + tile_size * j + tx * d + x]
121162
Vj[tx * d + x] = V[qkv_offset + tile_size * j + tx * d + x]
122163

123-
gpu.barrier() # such that the inner loop can use the correct Kj, Vj
124-
125164
for i in scf.range_(0, Tr):
126165
# Load Qi to SRAM, l and m to registers
127166
for x in scf.range_(0, d):
@@ -175,13 +214,10 @@ def flash_attention(
175214
ii = qkv_offset + tile_size * i + tx * d + x
176215
O[ii] = div * (c * O[ii] + math.exp(row_m - row_m_new) * pv)
177216

178-
gpu.barrier() # otherwise, thread can use the wrong Kj, Vj in inner loop
179-
180217
m[lm_offset + Br * i + tx] = row_m_new
181218
l[lm_offset + Br * i + tx] = row_l_new
182219

183-
# gpu.barrier() # otherwise, thread can use the wrong Kj, Vj in inner loop
184-
# gpu.barrier() # otherwise, thread can use the wrong Kj, Vj in inner loop
220+
gpu.barrier() # otherwise, thread can use the wrong Kj, Vj in inner loop
185221

186222

187223
ip.__exit__(None, None, None)
@@ -236,68 +272,33 @@ def flash_attention(
236272
T.index(), np.prod(thread_dims)
237273
)
238274

239-
lowered_module = run_pipeline(lowered_module, Pipeline().gpu_module_to_binary())
275+
output_format = "bin"
276+
# output_format = "isa"
277+
278+
lowered_module = run_pipeline(
279+
lowered_module, Pipeline().gpu_module_to_binary(format=output_format)
280+
)
240281
hsaco = get_compile_object_bytes(lowered_module)
282+
if output_format == "isa":
283+
with open(Path(__file__).parent / "flashattention.amdgcn", "w") as f:
284+
f.write(hsaco.decode())
285+
exit()
241286

242287
hip_module = hip_check(hip.hipModuleLoadData(hsaco))
243288

244-
q_h = np.random.randint(0, 10, (batch_size * n_head * seq_len * head_embd)).astype(
245-
dtype=np.float32
246-
)
247-
k_h = np.random.randint(0, 10, (batch_size * n_head * seq_len * head_embd)).astype(
248-
dtype=np.float32
249-
)
250-
v_h = np.random.randint(0, 10, (batch_size * n_head * seq_len * head_embd)).astype(
251-
dtype=np.float32
252-
)
253-
l_h = np.zeros((B * nh * N), dtype=np.float32)
254-
m_h = np.full((B * nh * N), float(np.finfo(np.float32).min), dtype=np.float32)
255-
O_h = np.zeros_like(q_h, dtype=np.float32)
256-
257-
q_num_bytes = q_h.size * q_h.itemsize
258-
k_num_bytes = k_h.size * k_h.itemsize
259-
v_num_bytes = v_h.size * v_h.itemsize
260-
l_num_bytes = l_h.size * l_h.itemsize
261-
m_num_bytes = m_h.size * m_h.itemsize
262-
O_num_bytes = O_h.size * O_h.itemsize
263-
264-
q_d = hip_check(hip.hipMalloc(q_num_bytes))
265-
k_d = hip_check(hip.hipMalloc(k_num_bytes))
266-
v_d = hip_check(hip.hipMalloc(v_num_bytes))
267-
l_d = hip_check(hip.hipMalloc(l_num_bytes))
268-
m_d = hip_check(hip.hipMalloc(m_num_bytes))
269-
O_d = hip_check(hip.hipMalloc(O_num_bytes))
270-
271289
stream = 0
272290

273291
times = {
274292
flash_attention: 0,
275293
}
276-
# random.shuffle(kernels)
277-
runs = 16
294+
runs = 32
278295
for kernel in times:
279296
for i in range(runs):
280297
function = hip_check(
281298
hip.hipModuleGetFunction(hip_module, kernel.__name__.encode())
282299
)
283300
hip_check(hip.hipDeviceSynchronize())
284301

285-
for d, h, num_bytes in zip(
286-
[q_d, k_d, v_d, l_d, m_d, O_d],
287-
[q_h, k_h, v_h, l_h, m_h, O_h],
288-
[
289-
q_num_bytes,
290-
k_num_bytes,
291-
v_num_bytes,
292-
l_num_bytes,
293-
m_num_bytes,
294-
O_num_bytes,
295-
],
296-
):
297-
hip_check(
298-
hip.hipMemcpy(d, h, num_bytes, hip.hipMemcpyKind.hipMemcpyHostToDevice)
299-
)
300-
301302
(
302303
(
303304
blocks_per_grid_x,
@@ -312,6 +313,10 @@ def flash_attention(
312313
shared_memory,
313314
) = launch_params[kernel.__name__]
314315

316+
host, device = init_copy_host_device()
317+
q_h, k_h, v_h, *_ = host
318+
correct = manual_attn(q_h, k_h, v_h)
319+
315320
time_compute = launch_kernel(
316321
function.as_c_void_p(),
317322
blocks_per_grid_x,
@@ -322,36 +327,16 @@ def flash_attention(
322327
threads_per_block_z,
323328
stream,
324329
shared_memory,
325-
q_d,
326-
k_d,
327-
v_d,
328-
l_d,
329-
m_d,
330-
O_d,
330+
*device,
331331
)
332332

333-
hip_check(
334-
hip.hipMemcpy(
335-
l_h, l_d, l_num_bytes, hip.hipMemcpyKind.hipMemcpyDeviceToHost
336-
)
337-
)
338-
hip_check(
339-
hip.hipMemcpy(
340-
m_h, m_d, m_num_bytes, hip.hipMemcpyKind.hipMemcpyDeviceToHost
341-
)
342-
)
343-
hip_check(
344-
hip.hipMemcpy(
345-
O_h, O_d, O_num_bytes, hip.hipMemcpyKind.hipMemcpyDeviceToHost
346-
)
347-
)
348-
correct = manual_attn(q_h, k_h, v_h)
333+
*_, O_h = copy_device_host(host, device)
349334
if not np.allclose(correct, O_h):
350335
print("correct", correct)
351-
print("l_h", l_h)
352-
print("m_h", m_h)
353336
print("output", O_h)
354337
print(f"{kernel.__name__} failed")
338+
else:
339+
print(f"{kernel.__name__}: {time_compute:.03f}ms")
355340

356341
times[kernel] += time_compute
357342

0 commit comments

Comments
 (0)