Commit a7c3755

rank-reduced

1 parent 4323bcc commit a7c3755

3 files changed, +148 -88 lines changed

examples/flash_attention.py

Lines changed: 41 additions & 67 deletions
@@ -3,10 +3,11 @@
 import mlir.extras.types as T
 import numpy as np
 from hip import hip
-from mlir.ir import InsertionPoint, IntegerAttr, UnitAttr, Type
+from mlir.ir import InsertionPoint, IntegerAttr, UnitAttr
 from mlir.extras.ast.canonicalize import canonicalize
 from mlir.extras.context import RAIIMLIRContextModule
-from mlir.extras.dialects.ext import memref, scf, arith, gpu, llvm, affine
+from mlir.extras.dialects.ext import memref, scf, arith, gpu, llvm
+from mlir.dialects import math

 # noinspection PyUnresolvedReferences
 from mlir.extras.dialects.ext.gpu import (
@@ -25,12 +26,12 @@
 from util import hip_check, launch_kernel, hip_synchronize


-def init_copy_host_device():
-    q_h = np.random.randint(0, 10, (B * nh * N * d)).astype(dtype=np.float32)
-    k_h = np.random.randint(0, 10, (B * nh * N * d)).astype(dtype=np.float32)
-    v_h = np.random.randint(0, 10, (B * nh * N * d)).astype(dtype=np.float32)
-    l_h = np.zeros((B * nh * N), dtype=np.float32)
-    m_h = np.full((B * nh * N), float(np.finfo(np.float32).min), dtype=np.float32)
+def init_copy_host_device(B, nh, N, d):
+    q_h = np.random.randint(0, 10, (B, nh, N, d)).astype(dtype=np.float32)
+    k_h = np.random.randint(0, 10, (B, nh, N, d)).astype(dtype=np.float32)
+    v_h = np.random.randint(0, 10, (B, nh, N, d)).astype(dtype=np.float32)
+    l_h = np.zeros((B, nh, N), dtype=np.float32)
+    m_h = np.full((B, nh, N), float(np.finfo(np.float32).min), dtype=np.float32)
     O_h = np.zeros_like(q_h, dtype=np.float32)

     host = [q_h, k_h, v_h, l_h, m_h, O_h]
@@ -87,11 +88,7 @@ def gpu_module():
 N = 128
 d = 128

-import math
-
-Tc = math.ceil(N / Bc)
-Tr = math.ceil(N / Br)
-softmax_scale = 1.0 / math.sqrt(d)
+softmax_scale = 1.0 / float(np.sqrt(d))


 def softmax(x, axis=None):
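
The unused Tc/Tr temporaries and the module-level import math go away here because, with from mlir.dialects import math now at the top of the file, the name math refers to the MLIR math dialect rather than the standard library (the comment deleted in the next hunk, "the kernel below overwrites the global math", was about exactly this). A minimal sketch of the shadowing, illustrative only:

import numpy as np

import math                      # standard-library math
from mlir.dialects import math   # rebinds "math" to the MLIR math dialect module

d = 128
# math.sqrt is no longer the stdlib function in this scope, so the host-side
# scale is computed with NumPy instead:
softmax_scale = 1.0 / float(np.sqrt(d))
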
@@ -101,20 +98,13 @@ def softmax(x, axis=None):


 def manual_attn(q, k, v):
-    # the kernel below overwrites the global math.........
-    import math
-
-    q = q.reshape(B, nh, N, d)
-    k = k.reshape(B, nh, N, d)
-    v = v.reshape(B, nh, N, d)
-
-    att = q @ k.transpose(0, 1, 3, 2) * (1.0 / math.sqrt(k.shape[-1]))
+    att = q @ k.transpose(0, 1, 3, 2) * (1.0 / float(np.sqrt(k.shape[-1])))
     att = softmax(att, axis=-1)
     y = att @ v
-    return y.flatten()
+    return y


-from mlir.dialects import math
+rank_reduce = memref.MemRef.rank_reduce


 # https://github.com/tspeterkim/flash-attention-minimal/blob/main/flash.cu
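
manual_attn now consumes and returns (B, nh, N, d) arrays directly instead of flattening and reshaping. A self-contained NumPy sketch of the reference path (the softmax body is an assumption here, the usual numerically stable form; shapes are small stand-ins):

import numpy as np

B, nh, N, d = 2, 4, 16, 8
q, k, v = (np.random.rand(B, nh, N, d).astype(np.float32) for _ in range(3))

def softmax(x, axis=None):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def manual_attn(q, k, v):
    att = q @ k.transpose(0, 1, 3, 2) * (1.0 / float(np.sqrt(k.shape[-1])))
    return softmax(att, axis=-1) @ v

out = manual_attn(q, k, v)
assert out.shape == (B, nh, N, d)
# Previously the result came back flattened; callers that still want the flat
# layout can call out.reshape(-1) themselves.
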
@@ -134,32 +124,18 @@ def flash_attention(
     # gpu.printf("bx %ld, by %ld\n", bx, by)

     # Offset into Q,K,V,O,l,m - different for each batch and head
-    K_ = K[bx, by, :, :]
-    V_ = V[bx, by, :, :]
-    Q_ = Q[bx, by, :, :]
-    O_ = O[bx, by, :, :]
-    l_ = l[bx, by, :]
-    m_ = m[bx, by, :]
+    K = K[bx, by, :, :, rank_reduce]
+    V = V[bx, by, :, :, rank_reduce]
+    Q = Q[bx, by, :, :, rank_reduce]
+    O = O[bx, by, :, :, rank_reduce]
+    l = l[bx, by, :, rank_reduce]
+    m = m[bx, by, :, rank_reduce]

     # Define SRAM for Q,K,V,S
     sram = gpu.dynamic_shared_memory()
-    Qi = memref.view(
-        sram,
-        (Br, d),
-        dtype=T.f32(),
-    )
-    Kj = memref.view(
-        sram,
-        (Bc, d),
-        dtype=T.f32(),
-        shift=Qi.n_elements,
-    )
-    Vj = memref.view(
-        sram,
-        (Bc, d),
-        dtype=T.f32(),
-        shift=Qi.n_elements + Kj.n_elements,
-    )
+    Qi = memref.view(sram, (Br, d), dtype=T.f32())
+    Kj = memref.view(sram, (Bc, d), dtype=T.f32(), shift=Qi.n_elements)
+    Vj = memref.view(sram, (Bc, d), dtype=T.f32(), shift=Qi.n_elements + Kj.n_elements)
     S = memref.view(
         sram,
         (Br, Bc),
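
rank_reduce (bound above to memref.MemRef.rank_reduce) is what lets the per-block views drop their unit dimensions: slicing out [bx, by] previously kept a (1, 1, N, d) shaped view, which is why every load needed [0, 0, tx, x]. NumPy's basic indexing already rank-reduces, so a NumPy analogue of the before/after shapes looks like this (a sketch of the indexing change only, not the memref API):

import numpy as np

B, nh, N, d = 2, 4, 16, 8
K = np.random.rand(B, nh, N, d).astype(np.float32)
bx, by, tx, x = 1, 2, 3, 4

K_keep = K[bx : bx + 1, by : by + 1, :, :]   # (1, 1, N, d): unit dims kept
K_drop = K[bx, by, :, :]                     # (N, d): rank-reduced view

assert K_keep[0, 0, tx, x] == K_drop[tx, x]
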
@@ -169,22 +145,22 @@ def flash_attention(

     for bc in scf.range_(0, N, Bc):
         # Load Kj, Vj to SRAM
-        K_ = K_[:, :, bc : bc + 1, :]
-        V_ = V_[:, :, bc : bc + 1, :]
+        K_ = K[bc : bc + 1, :]
+        V_ = V[bc : bc + 1, :]
         for x in scf.range_(0, d):
-            Kj[tx, x] = K_[0, 0, tx, x]
-            Vj[tx, x] = V_[0, 0, tx, x]
+            Kj[tx, x] = K_[tx, x]
+            Vj[tx, x] = V_[tx, x]

         for br in scf.range_(0, N, Br):
             # Load Qi to SRAM, l and m to registers
-            Q_ = Q_[:, :, br : br + 1, :]
+            Q_ = Q[br : br + 1, :]
             for x in scf.range_(0, d):
-                Qi[tx, x] = Q_[0, 0, tx, x]
+                Qi[tx, x] = Q_[tx, x]

-            l_ = l_[:, :, br : br + 1]
-            m_ = m_[:, :, br : br + 1]
-            row_l_prev = l_[0, 0, tx]
-            row_m_prev = m_[0, 0, tx]
+            l_ = l[br : br + 1]
+            m_ = m[br : br + 1]
+            row_l_prev = l_[tx]
+            row_m_prev = m_[tx]

             # S = QK^T, row_m = rowmax(S)
             row_m: T.f32() = float(np.finfo(np.float32).min)
@@ -218,22 +194,21 @@ def flash_attention(
                 + math.exp(row_m - row_m_new) * row_l
             )
             div = 1.0 / row_l_new
-            c = row_l_prev * math.exp(row_m_prev - row_m_new)
+            f1 = row_l_prev * math.exp(row_m_prev - row_m_new)
+            f2 = math.exp(row_m - row_m_new)

             # Write O, l, m to HBM
-            O_ = O_[:, :, br : br + 1, :]
+            O_ = O[br : br + 1, :]
             for x in scf.range_(0, d):
                 pv: T.f32() = 0.0  # Pij * Vj
                 for y, pv, _ in scf.range_(0, Bc, iter_args=[pv]):
                     pv += S[tx, y] * Vj[y, x]
                     pv = yield pv

-                O_[0, 0, tx, x] = div * (
-                    c * O_[0, 0, tx, x] + math.exp(row_m - row_m_new) * pv
-                )
+                O_[tx, x] = div * (f1 * O_[tx, x] + f2 * pv)

-            l_[0, 0, tx] = row_l_new
-            m_[0, 0, tx] = row_m_new
+            l_[tx] = row_l_new
+            m_[tx] = row_m_new

         gpu.barrier()
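
Renaming c to f1 and adding f2 makes the online-softmax update explicit: f1 rescales the output accumulated from earlier K/V tiles, f2 scales the current tile's Pij * Vj contribution, and div renormalizes by the updated row sum. A standalone NumPy sketch of the same recurrence for one row and two tiles (variable names mirror the kernel's row_m/row_l/f1/f2; this is not the kernel code):

import numpy as np

rng = np.random.default_rng(0)
Bc, d = 4, 8
s1, s2 = rng.standard_normal(Bc), rng.standard_normal(Bc)   # one row of S, two tiles
v1, v2 = rng.standard_normal((Bc, d)), rng.standard_normal((Bc, d))

# Reference: full softmax over the concatenated row.
s = np.concatenate([s1, s2])
ref = (np.exp(s - s.max()) / np.exp(s - s.max()).sum()) @ np.concatenate([v1, v2])

# Online version, one tile at a time.
m_prev, l_prev, o = float(np.finfo(np.float32).min), 0.0, np.zeros(d)
for s_blk, v_blk in [(s1, v1), (s2, v2)]:
    m = s_blk.max()                          # row_m
    p = np.exp(s_blk - m)                    # Pij
    m_new = max(m_prev, m)                   # row_m_new
    l_new = np.exp(m_prev - m_new) * l_prev + np.exp(m - m_new) * p.sum()
    f1 = l_prev * np.exp(m_prev - m_new)
    f2 = np.exp(m - m_new)
    o = (f1 * o + f2 * (p @ v_blk)) / l_new  # div * (f1 * O + f2 * pv)
    m_prev, l_prev = m_new, l_new

assert np.allclose(o, ref)
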

@@ -305,7 +280,7 @@ def flash_attention(
 )
 hsaco = get_compile_object_bytes(lowered_module)
 if output_format in {"isa", "llvm", "offloading"}:
-    with open(Path(__file__).parent / "flashattention.amdgcn", "wb") as f:
+    with open(Path(__file__).parent / f"flashattention.{output_format}", "wb") as f:
         f.write(hsaco)
     exit()

@@ -338,7 +313,7 @@ def flash_attention(
     shared_memory,
 ) = launch_params[kernel.__name__]

-host, device = init_copy_host_device()
+host, device = init_copy_host_device(B, nh, N, d)
 q_h, k_h, v_h, *_ = host
 correct = manual_attn(q_h, k_h, v_h)

@@ -360,8 +335,7 @@ def flash_attention(
     with np.printoptions(threshold=np.inf, linewidth=np.inf):
         print(
             "correct - output:\n",
-            correct.round().reshape(B, nh, N, d)
-            - O_h.round().reshape(B, nh, N, d),
+            correct.round() - O_h.round(),
         )
         print(f"{kernel.__name__} failed\n")
 else:

mlir/extras/dialects/ext/arith.py

Lines changed: 2 additions & 2 deletions
@@ -13,7 +13,6 @@
 from ...ast.canonicalize import StrictTransformer, Canonicalizer, BytecodePatcher
 from ...util import get_user_code_loc, infer_mlir_type, mlir_type_to_np_dtype
 from ...._mlir_libs._mlir import register_value_caster
-from ....dialects.arith import *
 from ....dialects import complex as complex_dialect
 from ....dialects._arith_enum_gen import (
     _arith_cmpfpredicateattr,
@@ -457,8 +456,9 @@ def __repr__(self):
     __mod__ = partialmethod(_binary_op, op="mod")
     __and__ = partialmethod(_binary_op, op="and")
     __or__ = partialmethod(_binary_op, op="or")
-    __radd__ = partialmethod(_binary_op, op="add")
+    # TODO(max): powi/powf using math

+    __radd__ = partialmethod(_rbinary_op, op="add")
     __rsub__ = partialmethod(_rbinary_op, op="sub")
     __rmul__ = partialmethod(_rbinary_op, op="mul")
     __rtruediv__ = partialmethod(_rbinary_op, op="truediv")
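
This hunk switches __radd__ from _binary_op to _rbinary_op so it matches the other reflected operators. Python only calls a reflected method when the left operand's own method declines, and it passes the right-hand value as the argument, so the builder has to put the operands back in source order; addition happens to be commutative, but subtraction and division are not. A plain-Python sketch of the distinction (the Scalar class and string results are illustrative, not the extras API):

from functools import partialmethod

class Scalar:
    def __init__(self, name):
        self.name = name

    def _binary_op(self, other, *, op):
        # normal form: self is the left operand
        return f"({self.name} {op} {other})"

    def _rbinary_op(self, other, *, op):
        # reflected form: Python invokes this for `other <op> self`,
        # so the operands are swapped when building the result
        return f"({other} {op} {self.name})"

    __sub__ = partialmethod(_binary_op, op="-")
    __rsub__ = partialmethod(_rbinary_op, op="-")

x = Scalar("x")
assert x - 1 == "(x - 1)"
assert 1 - x == "(1 - x)"   # wired to _binary_op, this would come out as "(x - 1)"
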

0 commit comments