
Commit 4323bcc

works
1 parent 76e4a5c commit 4323bcc

File tree

1 file changed (+51, -33)

examples/flash_attention.py

Lines changed: 51 additions & 33 deletions
@@ -121,65 +121,81 @@ def manual_attn(q, k, v):
 @gpu_func(emit=True)
 @canonicalize(using=[scf.canonicalizer, arith.canonicalizer])
 def flash_attention(
-    Q: T.memref(B * nh * N * d, T.f32()),
+    Q: T.memref(B, nh, N, d, T.f32()),
     K: T.memref(B, nh, N, d, T.f32()),
-    V: T.memref(B * nh * N * d, T.f32()),
-    l: T.memref(B * nh * N, T.f32()),
-    m: T.memref(B * nh * N, T.f32()),
-    O: T.memref(B * nh * N * d, T.f32()),
+    V: T.memref(B, nh, N, d, T.f32()),
+    l: T.memref(B, nh, N, T.f32()),
+    m: T.memref(B, nh, N, T.f32()),
+    O: T.memref(B, nh, N, d, T.f32()),
 ):
     tx = thread_idx.x
     # batch idx, head_idx
     bx, by = block_idx.x, block_idx.y
     # gpu.printf("bx %ld, by %ld\n", bx, by)

     # Offset into Q,K,V,O,l,m - different for each batch and head
-    qkv_offset = bx * nh * N * d + by * N * d
-    lm_offset = bx * nh * N + by * N  # offset for l and m
+    K_ = K[bx, by, :, :]
+    V_ = V[bx, by, :, :]
+    Q_ = Q[bx, by, :, :]
+    O_ = O[bx, by, :, :]
+    l_ = l[bx, by, :]
+    m_ = m[bx, by, :]

     # Define SRAM for Q,K,V,S
     sram = gpu.dynamic_shared_memory()
-    Qi = memref.view(sram, (Br * d,), dtype=T.f32())
-    Kj = memref.view(sram, (Bc * d,), dtype=T.f32(), shift=Qi.n_elements)
+    Qi = memref.view(
+        sram,
+        (Br, d),
+        dtype=T.f32(),
+    )
+    Kj = memref.view(
+        sram,
+        (Bc, d),
+        dtype=T.f32(),
+        shift=Qi.n_elements,
+    )
     Vj = memref.view(
-        sram, (Bc * d,), dtype=T.f32(), shift=Qi.n_elements + Kj.n_elements
+        sram,
+        (Bc, d),
+        dtype=T.f32(),
+        shift=Qi.n_elements + Kj.n_elements,
     )
     S = memref.view(
         sram,
-        (Br * Bc,),
+        (Br, Bc),
         dtype=T.f32(),
         shift=Qi.n_elements + Kj.n_elements + Vj.n_elements,
     )

-    # K_ = memref.reinterpret_cast(K, [0], [B, nh, N, d])
-    # K_ = K_[bx : bx + 1, by : by + 1, :, :]
-    for j in scf.range_(0, Tc):
-        K_ = K[bx, by, :, :]
+    for bc in scf.range_(0, N, Bc):
         # Load Kj, Vj to SRAM
+        K_ = K_[:, :, bc : bc + 1, :]
+        V_ = V_[:, :, bc : bc + 1, :]
         for x in scf.range_(0, d):
-            # Kj[tx * d + x] = K[qkv_offset + Bc * d * j + tx * d + x]
-            K_ = K_[:, :, j * Bc: (j + 1) * Bc, :]
-            Vj[tx * d + x] = V[qkv_offset + Bc * d * j + tx * d + x]
+            Kj[tx, x] = K_[0, 0, tx, x]
+            Vj[tx, x] = V_[0, 0, tx, x]

-        for i in scf.range_(0, Tr):
+        for br in scf.range_(0, N, Br):
             # Load Qi to SRAM, l and m to registers
+            Q_ = Q_[:, :, br : br + 1, :]
             for x in scf.range_(0, d):
-                ii = qkv_offset + Bc * d * i + tx * d + x
-                Qi[tx * d + x] = Q[ii]
+                Qi[tx, x] = Q_[0, 0, tx, x]

-            row_m_prev = m[lm_offset + Br * i + tx]
-            row_l_prev = l[lm_offset + Br * i + tx]
+            l_ = l_[:, :, br : br + 1]
+            m_ = m_[:, :, br : br + 1]
+            row_l_prev = l_[0, 0, tx]
+            row_m_prev = m_[0, 0, tx]

             # S = QK^T, row_m = rowmax(S)
             row_m: T.f32() = float(np.finfo(np.float32).min)
             for y, row_m, _ in scf.range_(0, Bc, iter_args=[row_m]):
                 sum: T.f32() = 0.0
                 for x, sum, _ in scf.range_(0, d, iter_args=[sum]):
-                    sum += Qi[tx * d + x] * Kj[y * d + x]
+                    sum += Qi[tx, x] * Kj[y, x]
                     sum = yield sum

                 sum *= softmax_scale
-                S[Bc * tx + y] = sum
+                S[tx, y] = sum

                 if sum > row_m:
                     row_m_ = yield sum
@@ -191,8 +207,8 @@ def flash_attention(
             # P = exp(S - row_m), row_l = rowsum(P)
             row_l: T.f32() = 0.0
             for y, row_l, _ in scf.range_(0, Bc, iter_args=[row_l]):
-                S[Bc * tx + y] = math.exp(S[Bc * tx + y] - row_m)
-                row_l += S[Bc * tx + y]
+                S[tx, y] = math.exp(S[tx, y] - row_m)
+                row_l += S[tx, y]
                 row_l = yield row_l

             # Compute new m and l
@@ -205,24 +221,26 @@ def flash_attention(
             c = row_l_prev * math.exp(row_m_prev - row_m_new)

             # Write O, l, m to HBM
+            O_ = O_[:, :, br : br + 1, :]
             for x in scf.range_(0, d):
                 pv: T.f32() = 0.0  # Pij * Vj
                 for y, pv, _ in scf.range_(0, Bc, iter_args=[pv]):
-                    pv += S[Bc * tx + y] * Vj[y * d + x]
+                    pv += S[tx, y] * Vj[y, x]
                     pv = yield pv

-                ii = qkv_offset + Bc * d * i + tx * d + x
-                O[ii] = div * (c * O[ii] + math.exp(row_m - row_m_new) * pv)
+                O_[0, 0, tx, x] = div * (
+                    c * O_[0, 0, tx, x] + math.exp(row_m - row_m_new) * pv
+                )

-            m[lm_offset + Br * i + tx] = row_m_new
-            l[lm_offset + Br * i + tx] = row_l_new
+            l_[0, 0, tx] = row_l_new
+            m_[0, 0, tx] = row_m_new

         gpu.barrier()


 ip.__exit__(None, None, None)

-print(gpu_module)
+# print(gpu_module)

 sram_size = 4 * Bc * d * np.float32().itemsize

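The substance of the change: flash_attention now takes 4-D (B, nh, N, d) memrefs (and 3-D (B, nh, N) memrefs for l and m) and slices them per batch/head and per tile, instead of taking flat 1-D buffers and computing qkv_offset and lm_offset by hand. For a row-major layout the two addressing schemes reach the same elements; below is a small NumPy sketch of that equivalence (illustrative sizes and indices only, not part of the example file):

import numpy as np

# Names mirror the kernel: B, nh, N, d, Bc, and the old qkv_offset formula.
B, nh, N, d, Bc = 2, 3, 8, 4, 2  # illustrative sizes only

Q4 = np.arange(B * nh * N * d, dtype=np.float32).reshape(B, nh, N, d)
Q1 = Q4.reshape(-1)  # the old flat (B * nh * N * d,) view

bx, by, j, tx, x = 1, 2, 1, 1, 3  # arbitrary in-range block/tile/thread indices
qkv_offset = bx * nh * N * d + by * N * d  # old per-(batch, head) offset
# The old flat addressing picks the same element as 4-D indexing of row j * Bc + tx:
assert Q1[qkv_offset + Bc * d * j + tx * d + x] == Q4[bx, by, j * Bc + tx, x]

# Likewise, the old lm_offset = bx * nh * N + by * N is where the flat
# (B * nh * N,) l and m buffers hold the new l[bx, by, :] and m[bx, by, :].

With the reshaped views, the shared-memory tiles Qi (Br, d), Kj (Bc, d), Vj (Bc, d) and S (Br, Bc) occupy Br * d + 2 * Bc * d + Br * Bc f32 elements in total, which is exactly what the shift arguments to memref.view step through.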
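For reference, here is a NumPy sketch of the tiled forward pass the kernel is computing for one (batch, head) slice, i.e. what it sees after K_ = K[bx, by, :, :] and friends. The pieces not visible in this diff (row_m_new, row_l_new, div) are assumed to follow the standard flash-attention online-softmax update with div = 1 / row_l_new; treat this as a hedged reading of the kernel, not a drop-in replacement for it:

import numpy as np

def flash_attention_ref(Q, K, V, Bc, Br, softmax_scale):
    # Q, K, V: (N, d) float32 arrays for a single (batch, head) slice.
    N, d = Q.shape
    O = np.zeros((N, d), dtype=Q.dtype)
    l = np.zeros(N, dtype=Q.dtype)                         # running row sums
    m = np.full(N, np.finfo(Q.dtype).min, dtype=Q.dtype)   # running row maxes

    for bc in range(0, N, Bc):          # K/V tiles (the kernel's `for bc` loop)
        Kj = K[bc:bc + Bc]              # the Kj SRAM tile, (Bc, d)
        Vj = V[bc:bc + Bc]              # the Vj SRAM tile, (Bc, d)
        for br in range(0, N, Br):      # Q tiles (the kernel's `for br` loop)
            Qi = Q[br:br + Br]          # the Qi SRAM tile, (Br, d)
            S = softmax_scale * Qi @ Kj.T                   # (Br, Bc)

            row_m = S.max(axis=1)                           # rowmax(S)
            P = np.exp(S - row_m[:, None])                  # P = exp(S - row_m)
            row_l = P.sum(axis=1)                           # rowsum(P)

            m_prev, l_prev = m[br:br + Br], l[br:br + Br]
            m_new = np.maximum(m_prev, row_m)
            l_new = (l_prev * np.exp(m_prev - m_new)
                     + row_l * np.exp(row_m - m_new))

            # Online rescale of the running output block; c and the trailing
            # division by l_new play the role of `c` and `div` in the kernel.
            c = l_prev * np.exp(m_prev - m_new)
            O[br:br + Br] = (
                c[:, None] * O[br:br + Br]
                + np.exp(row_m - m_new)[:, None] * (P @ Vj)
            ) / l_new[:, None]

            l[br:br + Br] = l_new
            m[br:br + Br] = m_new
    return O

Its output should agree with plain softmax(softmax_scale * Q @ K.T) @ V up to floating-point tolerance, which makes it a convenient check for the GPU kernel on a single (batch, head) slice.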