
Commit 9cb3d6b

rank-reduced

1 parent 4323bcc commit 9cb3d6b

File tree

2 files changed: +146 -86 lines

examples/flash_attention.py

Lines changed: 41 additions & 67 deletions
@@ -3,10 +3,11 @@
 import mlir.extras.types as T
 import numpy as np
 from hip import hip
-from mlir.ir import InsertionPoint, IntegerAttr, UnitAttr, Type
+from mlir.ir import InsertionPoint, IntegerAttr, UnitAttr
 from mlir.extras.ast.canonicalize import canonicalize
 from mlir.extras.context import RAIIMLIRContextModule
-from mlir.extras.dialects.ext import memref, scf, arith, gpu, llvm, affine
+from mlir.extras.dialects.ext import memref, scf, arith, gpu, llvm
+from mlir.dialects import math

 # noinspection PyUnresolvedReferences
 from mlir.extras.dialects.ext.gpu import (
@@ -25,12 +26,12 @@
 from util import hip_check, launch_kernel, hip_synchronize


-def init_copy_host_device():
-    q_h = np.random.randint(0, 10, (B * nh * N * d)).astype(dtype=np.float32)
-    k_h = np.random.randint(0, 10, (B * nh * N * d)).astype(dtype=np.float32)
-    v_h = np.random.randint(0, 10, (B * nh * N * d)).astype(dtype=np.float32)
-    l_h = np.zeros((B * nh * N), dtype=np.float32)
-    m_h = np.full((B * nh * N), float(np.finfo(np.float32).min), dtype=np.float32)
+def init_copy_host_device(B, nh, N, d):
+    q_h = np.random.randint(0, 10, (B, nh, N, d)).astype(dtype=np.float32)
+    k_h = np.random.randint(0, 10, (B, nh, N, d)).astype(dtype=np.float32)
+    v_h = np.random.randint(0, 10, (B, nh, N, d)).astype(dtype=np.float32)
+    l_h = np.zeros((B, nh, N), dtype=np.float32)
+    m_h = np.full((B, nh, N), float(np.finfo(np.float32).min), dtype=np.float32)
     O_h = np.zeros_like(q_h, dtype=np.float32)

     host = [q_h, k_h, v_h, l_h, m_h, O_h]
@@ -87,11 +88,7 @@ def gpu_module():
 N = 128
 d = 128

-import math
-
-Tc = math.ceil(N / Bc)
-Tr = math.ceil(N / Br)
-softmax_scale = 1.0 / math.sqrt(d)
+softmax_scale = 1.0 / float(np.sqrt(d))


 def softmax(x, axis=None):
@@ -101,20 +98,13 @@ def softmax(x, axis=None):


 def manual_attn(q, k, v):
-    # the kernel below overwrites the global math.........
-    import math
-
-    q = q.reshape(B, nh, N, d)
-    k = k.reshape(B, nh, N, d)
-    v = v.reshape(B, nh, N, d)
-
-    att = q @ k.transpose(0, 1, 3, 2) * (1.0 / math.sqrt(k.shape[-1]))
+    att = q @ k.transpose(0, 1, 3, 2) * (1.0 / float(np.sqrt(k.shape[-1])))
     att = softmax(att, axis=-1)
     y = att @ v
-    return y.flatten()
+    return y


-from mlir.dialects import math
+rank_reduce = memref.MemRef.rank_reduce


 # https://github.com/tspeterkim/flash-attention-minimal/blob/main/flash.cu
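
Side note on the reference path above: with init_copy_host_device now returning (B, nh, N, d) arrays, manual_attn no longer needs the reshape/flatten round-trip. A self-contained NumPy sketch of that reference path (the softmax body is assumed here, since only its signature appears in the diff; the shapes are illustrative):

import numpy as np


def softmax(x, axis=None):
    # numerically stable softmax along `axis` (assumed implementation)
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def manual_attn(q, k, v):
    # scaled dot-product attention on (B, nh, N, d) arrays, as in the diff
    att = q @ k.transpose(0, 1, 3, 2) * (1.0 / float(np.sqrt(k.shape[-1])))
    att = softmax(att, axis=-1)
    return att @ v


B, nh, N, d = 1, 1, 8, 4
q, k, v = (np.random.rand(B, nh, N, d).astype(np.float32) for _ in range(3))
out = manual_attn(q, k, v)
assert out.shape == (B, nh, N, d)
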
@@ -134,32 +124,18 @@ def flash_attention(
     # gpu.printf("bx %ld, by %ld\n", bx, by)

     # Offset into Q,K,V,O,l,m - different for each batch and head
-    K_ = K[bx, by, :, :]
-    V_ = V[bx, by, :, :]
-    Q_ = Q[bx, by, :, :]
-    O_ = O[bx, by, :, :]
-    l_ = l[bx, by, :]
-    m_ = m[bx, by, :]
+    K = K[bx, by, :, :, rank_reduce]
+    V = V[bx, by, :, :, rank_reduce]
+    Q = Q[bx, by, :, :, rank_reduce]
+    O = O[bx, by, :, :, rank_reduce]
+    l = l[bx, by, :, rank_reduce]
+    m = m[bx, by, :, rank_reduce]

     # Define SRAM for Q,K,V,S
     sram = gpu.dynamic_shared_memory()
-    Qi = memref.view(
-        sram,
-        (Br, d),
-        dtype=T.f32(),
-    )
-    Kj = memref.view(
-        sram,
-        (Bc, d),
-        dtype=T.f32(),
-        shift=Qi.n_elements,
-    )
-    Vj = memref.view(
-        sram,
-        (Bc, d),
-        dtype=T.f32(),
-        shift=Qi.n_elements + Kj.n_elements,
-    )
+    Qi = memref.view(sram, (Br, d), dtype=T.f32())
+    Kj = memref.view(sram, (Bc, d), dtype=T.f32(), shift=Qi.n_elements)
+    Vj = memref.view(sram, (Bc, d), dtype=T.f32(), shift=Qi.n_elements + Kj.n_elements)
     S = memref.view(
         sram,
         (Br, Bc),
@@ -169,22 +145,22 @@ def flash_attention(

     for bc in scf.range_(0, N, Bc):
         # Load Kj, Vj to SRAM
-        K_ = K_[:, :, bc : bc + 1, :]
-        V_ = V_[:, :, bc : bc + 1, :]
+        K_ = K[bc : bc + 1, :]
+        V_ = V[bc : bc + 1, :]
         for x in scf.range_(0, d):
-            Kj[tx, x] = K_[0, 0, tx, x]
-            Vj[tx, x] = V_[0, 0, tx, x]
+            Kj[tx, x] = K_[tx, x]
+            Vj[tx, x] = V_[tx, x]

         for br in scf.range_(0, N, Br):
             # Load Qi to SRAM, l and m to registers
-            Q_ = Q_[:, :, br : br + 1, :]
+            Q_ = Q[br : br + 1, :]
             for x in scf.range_(0, d):
-                Qi[tx, x] = Q_[0, 0, tx, x]
+                Qi[tx, x] = Q_[tx, x]

-            l_ = l_[:, :, br : br + 1]
-            m_ = m_[:, :, br : br + 1]
-            row_l_prev = l_[0, 0, tx]
-            row_m_prev = m_[0, 0, tx]
+            l_ = l[br : br + 1]
+            m_ = m[br : br + 1]
+            row_l_prev = l_[tx]
+            row_m_prev = m_[tx]

             # S = QK^T, row_m = rowmax(S)
             row_m: T.f32() = float(np.finfo(np.float32).min)
@@ -218,22 +194,21 @@ def flash_attention(
                 + math.exp(row_m - row_m_new) * row_l
             )
             div = 1.0 / row_l_new
-            c = row_l_prev * math.exp(row_m_prev - row_m_new)
+            f1 = row_l_prev * math.exp(row_m_prev - row_m_new)
+            f2 = math.exp(row_m - row_m_new)

             # Write O, l, m to HBM
-            O_ = O_[:, :, br : br + 1, :]
+            O_ = O[br : br + 1, :]
             for x in scf.range_(0, d):
                 pv: T.f32() = 0.0  # Pij * Vj
                 for y, pv, _ in scf.range_(0, Bc, iter_args=[pv]):
                     pv += S[tx, y] * Vj[y, x]
                     pv = yield pv

-                O_[0, 0, tx, x] = div * (
-                    c * O_[0, 0, tx, x] + math.exp(row_m - row_m_new) * pv
-                )
+                O_[tx, x] = div * (f1 * O_[tx, x] + f2 * pv)

-            l_[0, 0, tx] = row_l_new
-            m_[0, 0, tx] = row_m_new
+            l_[tx] = row_l_new
+            m_[tx] = row_m_new

         gpu.barrier()

@@ -305,7 +280,7 @@ def flash_attention(
 )
 hsaco = get_compile_object_bytes(lowered_module)
 if output_format in {"isa", "llvm", "offloading"}:
-    with open(Path(__file__).parent / "flashattention.amdgcn", "wb") as f:
+    with open(Path(__file__).parent / f"flashattention.{output_format}", "wb") as f:
        f.write(hsaco)
    exit()
@@ -338,7 +313,7 @@ def flash_attention(
        shared_memory,
    ) = launch_params[kernel.__name__]

-    host, device = init_copy_host_device()
+    host, device = init_copy_host_device(B, nh, N, d)
    q_h, k_h, v_h, *_ = host
    correct = manual_attn(q_h, k_h, v_h)
@@ -360,8 +335,7 @@ def flash_attention(
        with np.printoptions(threshold=np.inf, linewidth=np.inf):
            print(
                "correct - output:\n",
-                correct.round().reshape(B, nh, N, d)
-                - O_h.round().reshape(B, nh, N, d),
+                correct.round() - O_h.round(),
            )
        print(f"{kernel.__name__} failed\n")
    else:
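
What the kernel-side change buys: previously K[bx, by, :, :] produced a rank-4 subview with two leading unit dimensions (hence the K_[0, 0, tx, x] indexing), while passing the new rank_reduce sentinel drops those unit dimensions so the shared-memory loads can index the view as K_[tx, x]. A rough NumPy analogy of the before/after indexing (NumPy only, not the mlir-python-extras API):

import numpy as np

B, nh, N, d = 2, 3, 8, 4
K = np.arange(B * nh * N * d, dtype=np.float32).reshape(B, nh, N, d)
bx, by, tx, x = 1, 2, 5, 3

# before: the per-(batch, head) view keeps two unit dims, so it stays rank 4
K_keep = K[bx : bx + 1, by : by + 1, :, :]  # shape (1, 1, N, d)
old_val = K_keep[0, 0, tx, x]

# after: the rank-reduced view drops the unit dims (what MemRef.rank_reduce requests)
K_red = K[bx, by, :, :]  # shape (N, d)
new_val = K_red[tx, x]

assert old_val == new_val
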

mlir/extras/dialects/ext/memref.py

Lines changed: 105 additions & 19 deletions
@@ -1,6 +1,6 @@
 import inspect
 import operator
-from itertools import accumulate
+from itertools import accumulate, zip_longest
 from typing import Sequence, Union, Optional

 import numpy as np
@@ -24,7 +24,11 @@
     MixedValues,
     _dispatch_mixed_values,
 )
-from ....dialects.memref import _is_static_int_like, _infer_memref_subview_result_type
+from ....dialects.memref import (
+    _is_static_int_like,
+    _infer_memref_subview_result_type,
+    _generated_subview,
+)
 from ....dialects.memref import *
 from ....ir import (
     DenseElementsAttr,
@@ -175,6 +179,8 @@ def __str__(self):
     def __repr__(self):
         return str(self)

+    rank_reduce = object()
+
     def __getitem__(self, idx: tuple) -> "MemRef":
         loc = get_user_code_loc()

@@ -189,6 +195,10 @@ def __getitem__(self, idx: tuple) -> "MemRef":
             return expand_shape(self, (0,), loc=loc)

         idx = list((idx,) if isinstance(idx, (int, Scalar, slice)) else idx)
+        rank_reduce = MemRef.rank_reduce in idx
+        if rank_reduce:
+            idx.remove(MemRef.rank_reduce)
+
         for i, d in enumerate(idx):
             # TODO(max): rethink this since subview and etc probably take constant attributes?
             if isinstance(d, int):
@@ -197,7 +207,7 @@ def __getitem__(self, idx: tuple) -> "MemRef":
         if all(isinstance(d, Scalar) for d in idx) and len(idx) == len(self.shape):
             return load(self, idx, loc=loc)
         else:
-            return _subview(self, tuple(idx), loc=loc)
+            return _subview(self, tuple(idx), rank_reduce=rank_reduce, loc=loc)

     def __setitem__(self, idx, val):
         loc = get_user_code_loc()
@@ -306,10 +316,89 @@ def _maybe_compute_size(start, stop, step):
     return stop - start


+def subview(
+    source: Value,
+    offsets: MixedValues,
+    sizes: MixedValues,
+    strides: MixedValues,
+    *,
+    rank_reduce=False,
+    result_type: Optional[MemRefType] = None,
+    loc=None,
+    ip=None,
+):
+    if offsets is None:
+        offsets = []
+    if sizes is None:
+        sizes = []
+    if strides is None:
+        strides = []
+    source_strides, source_offset = source.type.get_strides_and_offset()
+    if result_type is None and all(
+        all(_is_static_int_like(i) for i in s) for s in [sizes, strides, source_strides]
+    ):
+        # If any are arith.constant results then this will canonicalize to python int
+        # (which can then be used to fully specify the subview).
+        (
+            offsets,
+            sizes,
+            strides,
+            result_type,
+        ) = _infer_memref_subview_result_type(source.type, offsets, sizes, strides)
+    elif result_type is None:
+        raise ValueError(
+            "mixed static/dynamic offset/sizes/strides requires explicit result type."
+        )
+
+    offsets, _packed_offsets, static_offsets = _dispatch_mixed_values(offsets)
+    sizes, _packed_sizes, static_sizes = _dispatch_mixed_values(sizes)
+    strides, _packed_strides, static_strides = _dispatch_mixed_values(strides)
+
+    if rank_reduce:
+        result_shape = list(result_type.shape)
+        layout_strides = None
+        if result_type.layout:
+            layout_strides = result_type.layout.strides
+        for i, (s, ss) in reversed(
+            list(enumerate(list(zip_longest(sizes, static_sizes))))
+        ):
+            if (
+                s is not None and _is_static_int_like(s) and s.literal_value == 1
+            ) or ss == 1:
+                del result_shape[i]
+                if layout_strides is not None:
+                    del layout_strides[i]
+        reduced_layout = None
+        if layout_strides is not None:
+            reduced_layout = StridedLayoutAttr.get(
+                result_type.layout.offset, layout_strides
+            )
+        result_type = MemRefType.get(
+            result_shape,
+            result_type.element_type,
+            reduced_layout,
+            result_type.memory_space,
+        )
+
+    return _generated_subview(
+        result_type,
+        source,
+        offsets,
+        sizes,
+        strides,
+        static_offsets,
+        static_sizes,
+        static_strides,
+        loc=loc,
+        ip=ip,
+    )
+
+
 def _subview(
     mem: MemRef,
     idx,
     *,
+    rank_reduce=False,
     loc=None,
     ip=None,
 ) -> MemRef:
@@ -320,14 +409,9 @@ def _subview(
     out = mem

     if indexer.is_constant():
-        out = subview(
-            out,
-            offsets=indexer.static_offsets(),
-            sizes=indexer.static_sizes(),
-            strides=indexer.static_strides(),
-            loc=loc,
-            ip=ip,
-        )
+        offsets = indexer.static_offsets()
+        sizes = indexer.static_sizes()
+        strides = indexer.static_strides()
     else:
         # special tile case
         offsets = [None] * len(indexer.in_shape)
@@ -354,14 +438,16 @@ def _subview(
         assert all(
             map(lambda x: x is not None, offsets + sizes + strides)
         ), f"not each slice is statically known: {indexer.indices}"
-        out = subview(
-            out,
-            offsets=offsets,
-            sizes=sizes,
-            strides=strides,
-            loc=loc,
-            ip=ip,
-        )
+
+    out = subview(
+        out,
+        offsets=offsets,
+        sizes=sizes,
+        strides=strides,
+        rank_reduce=rank_reduce,
+        loc=loc,
+        ip=ip,
+    )

     # This adds newaxis/None dimensions.
     return expand_shape(out, indexer.newaxis_dims, loc=loc, ip=ip)
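
The heart of the rank_reduce branch in subview above is: drop every unit-size result dimension together with its layout stride, walking from the back so earlier indices stay valid. A plain-Python sketch of that step, detached from the MLIR types (rank_reduced_shape_and_strides is a hypothetical helper for illustration, not part of the library):

def rank_reduced_shape_and_strides(shape, strides, static_sizes):
    # drop size-1 dims and their strides, mirroring the reversed() loop in subview()
    shape, strides = list(shape), list(strides)
    for i, ss in reversed(list(enumerate(static_sizes))):
        if ss == 1:
            del shape[i]
            del strides[i]
    return shape, strides


# e.g. the (1, 1, N, d) subview taken per (batch, head) in the kernel collapses to (N, d)
B, nh, N, d = 2, 3, 128, 128
shape, strides = rank_reduced_shape_and_strides(
    shape=[1, 1, N, d],
    strides=[nh * N * d, N * d, d, 1],  # strides of a contiguous (B, nh, N, d) buffer
    static_sizes=[1, 1, N, d],
)
assert shape == [N, d] and strides == [d, 1]
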
