import mlir.extras.types as T
import numpy as np
from hip import hip
- from mlir.ir import InsertionPoint, IntegerAttr, UnitAttr, Type
+ from mlir.ir import InsertionPoint, IntegerAttr, UnitAttr
from mlir.extras.ast.canonicalize import canonicalize
from mlir.extras.context import RAIIMLIRContextModule
- from mlir.extras.dialects.ext import memref, scf, arith, gpu, llvm, affine
+ from mlir.extras.dialects.ext import memref, scf, arith, gpu, llvm
+ from mlir.dialects import math

# noinspection PyUnresolvedReferences
from mlir.extras.dialects.ext.gpu import (
from util import hip_check, launch_kernel, hip_synchronize


- def init_copy_host_device():
-     q_h = np.random.randint(0, 10, (B * nh * N * d)).astype(dtype=np.float32)
-     k_h = np.random.randint(0, 10, (B * nh * N * d)).astype(dtype=np.float32)
-     v_h = np.random.randint(0, 10, (B * nh * N * d)).astype(dtype=np.float32)
-     l_h = np.zeros((B * nh * N), dtype=np.float32)
-     m_h = np.full((B * nh * N), float(np.finfo(np.float32).min), dtype=np.float32)
+ def init_copy_host_device(B, nh, N, d):
+     q_h = np.random.randint(0, 10, (B, nh, N, d)).astype(dtype=np.float32)
+     k_h = np.random.randint(0, 10, (B, nh, N, d)).astype(dtype=np.float32)
+     v_h = np.random.randint(0, 10, (B, nh, N, d)).astype(dtype=np.float32)
+     l_h = np.zeros((B, nh, N), dtype=np.float32)
+     m_h = np.full((B, nh, N), float(np.finfo(np.float32).min), dtype=np.float32)
    O_h = np.zeros_like(q_h, dtype=np.float32)

    host = [q_h, k_h, v_h, l_h, m_h, O_h]
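
(Annotation, not part of the commit: the device-allocation half of init_copy_host_device falls outside this hunk. A minimal sketch of what it presumably does with hip-python, assuming hip_check from util unwraps the (err, value) tuples the HIP bindings return; the helper name and loop are illustrative only.)

def copy_to_device(host_arrays):  # hypothetical helper, for illustration only
    device = []
    for h in host_arrays:
        d_ptr = hip_check(hip.hipMalloc(h.nbytes))
        hip_check(
            hip.hipMemcpy(d_ptr, h, h.nbytes, hip.hipMemcpyKind.hipMemcpyHostToDevice)
        )
        device.append(d_ptr)
    return host_arrays, device
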
@@ -87,11 +88,7 @@ def gpu_module():
N = 128
d = 128

- import math
-
- Tc = math.ceil(N / Bc)
- Tr = math.ceil(N / Br)
- softmax_scale = 1.0 / math.sqrt(d)
+ softmax_scale = 1.0 / float(np.sqrt(d))


def softmax(x, axis=None):
@@ -101,20 +98,13 @@ def softmax(x, axis=None):


def manual_attn(q, k, v):
-     # the kernel below overwrites the global math.........
-     import math
-
-     q = q.reshape(B, nh, N, d)
-     k = k.reshape(B, nh, N, d)
-     v = v.reshape(B, nh, N, d)
-
-     att = q @ k.transpose(0, 1, 3, 2) * (1.0 / math.sqrt(k.shape[-1]))
+     att = q @ k.transpose(0, 1, 3, 2) * (1.0 / float(np.sqrt(k.shape[-1])))
    att = softmax(att, axis=-1)
    y = att @ v
-     return y.flatten()
+     return y


- from mlir.dialects import math
+ rank_reduce = memref.MemRef.rank_reduce


# https://github.com/tspeterkim/flash-attention-minimal/blob/main/flash.cu
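
(Annotation, not part of the commit: a plain NumPy sketch of the tiled online-softmax scheme the kernel below implements, following the flash.cu reference linked above. It works on one (batch, head) slice; the names Br, Bc, f1, f2 mirror the kernel, but the function itself is illustrative only.)

def flash_attention_ref(q, k, v, Br=32, Bc=32):
    # q, k, v: (N, d) arrays for a single (batch, head) pair.
    N, d = q.shape
    O = np.zeros((N, d), dtype=np.float32)
    l = np.zeros(N, dtype=np.float32)                            # running softmax denominators
    m = np.full(N, np.finfo(np.float32).min, dtype=np.float32)   # running row maxima
    scale = 1.0 / float(np.sqrt(d))
    for bc in range(0, N, Bc):                                   # key/value tiles
        Kj, Vj = k[bc : bc + Bc], v[bc : bc + Bc]
        for br in range(0, N, Br):                               # query tiles
            Qi = q[br : br + Br]
            S = scale * (Qi @ Kj.T)                              # (Br, Bc) scores
            row_m = S.max(axis=1)
            P = np.exp(S - row_m[:, None])
            row_l = P.sum(axis=1)
            m_prev, l_prev = m[br : br + Br], l[br : br + Br]
            m_new = np.maximum(m_prev, row_m)
            l_new = np.exp(m_prev - m_new) * l_prev + np.exp(row_m - m_new) * row_l
            f1 = l_prev * np.exp(m_prev - m_new)                 # rescales the old (normalized) O
            f2 = np.exp(row_m - m_new)                           # scales the new tile's contribution
            O[br : br + Br] = (
                f1[:, None] * O[br : br + Br] + f2[:, None] * (P @ Vj)
            ) / l_new[:, None]
            m[br : br + Br], l[br : br + Br] = m_new, l_new
    return O

Stacked over every (batch, head) slice, this should agree with manual_attn up to floating-point rounding.
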
@@ -134,32 +124,18 @@ def flash_attention(
    # gpu.printf("bx %ld, by %ld\n", bx, by)

    # Offset into Q,K,V,O,l,m - different for each batch and head
-     K_ = K[bx, by, :, :]
-     V_ = V[bx, by, :, :]
-     Q_ = Q[bx, by, :, :]
-     O_ = O[bx, by, :, :]
-     l_ = l[bx, by, :]
-     m_ = m[bx, by, :]
+     K = K[bx, by, :, :, rank_reduce]
+     V = V[bx, by, :, :, rank_reduce]
+     Q = Q[bx, by, :, :, rank_reduce]
+     O = O[bx, by, :, :, rank_reduce]
+     l = l[bx, by, :, rank_reduce]
+     m = m[bx, by, :, rank_reduce]

    # Define SRAM for Q,K,V,S
    sram = gpu.dynamic_shared_memory()
-     Qi = memref.view(
-         sram,
-         (Br, d),
-         dtype=T.f32(),
-     )
-     Kj = memref.view(
-         sram,
-         (Bc, d),
-         dtype=T.f32(),
-         shift=Qi.n_elements,
-     )
-     Vj = memref.view(
-         sram,
-         (Bc, d),
-         dtype=T.f32(),
-         shift=Qi.n_elements + Kj.n_elements,
-     )
+     Qi = memref.view(sram, (Br, d), dtype=T.f32())
+     Kj = memref.view(sram, (Bc, d), dtype=T.f32(), shift=Qi.n_elements)
+     Vj = memref.view(sram, (Bc, d), dtype=T.f32(), shift=Qi.n_elements + Kj.n_elements)
    S = memref.view(
        sram,
        (Br, Bc),
@@ -169,22 +145,22 @@ def flash_attention(

    for bc in scf.range_(0, N, Bc):
        # Load Kj, Vj to SRAM
-         K_ = K_[:, :, bc : bc + 1, :]
-         V_ = V_[:, :, bc : bc + 1, :]
+         K_ = K[bc : bc + 1, :]
+         V_ = V[bc : bc + 1, :]
        for x in scf.range_(0, d):
-             Kj[tx, x] = K_[0, 0, tx, x]
-             Vj[tx, x] = V_[0, 0, tx, x]
+             Kj[tx, x] = K_[tx, x]
+             Vj[tx, x] = V_[tx, x]

        for br in scf.range_(0, N, Br):
            # Load Qi to SRAM, l and m to registers
-             Q_ = Q_[:, :, br : br + 1, :]
+             Q_ = Q[br : br + 1, :]
            for x in scf.range_(0, d):
-                 Qi[tx, x] = Q_[0, 0, tx, x]
+                 Qi[tx, x] = Q_[tx, x]

-             l_ = l_[:, :, br : br + 1]
-             m_ = m_[:, :, br : br + 1]
-             row_l_prev = l_[0, 0, tx]
-             row_m_prev = m_[0, 0, tx]
+             l_ = l[br : br + 1]
+             m_ = m[br : br + 1]
+             row_l_prev = l_[tx]
+             row_m_prev = m_[tx]

            # S = QK^T, row_m = rowmax(S)
            row_m: T.f32() = float(np.finfo(np.float32).min)
@@ -218,22 +194,21 @@ def flash_attention(
                + math.exp(row_m - row_m_new) * row_l
            )
            div = 1.0 / row_l_new
-             c = row_l_prev * math.exp(row_m_prev - row_m_new)
+             f1 = row_l_prev * math.exp(row_m_prev - row_m_new)
+             f2 = math.exp(row_m - row_m_new)

            # Write O, l, m to HBM
-             O_ = O_[:, :, br : br + 1, :]
+             O_ = O[br : br + 1, :]
            for x in scf.range_(0, d):
                pv: T.f32() = 0.0  # Pij * Vj
                for y, pv, _ in scf.range_(0, Bc, iter_args=[pv]):
                    pv += S[tx, y] * Vj[y, x]
                    pv = yield pv

-                 O_[0, 0, tx, x] = div * (
-                     c * O_[0, 0, tx, x] + math.exp(row_m - row_m_new) * pv
-                 )
+                 O_[tx, x] = div * (f1 * O_[tx, x] + f2 * pv)

-             l_[0, 0, tx] = row_l_new
-             m_[0, 0, tx] = row_m_new
+             l_[tx] = row_l_new
+             m_[tx] = row_m_new

        gpu.barrier()

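(Annotation, not part of the commit: the c -> f1/f2 rename above just gives the two rescaling factors of the online-softmax update their own names. A small NumPy check of the identity they implement, on a single row of scores split into two tiles; the values are made up.)

rng = np.random.default_rng(0)
s = rng.standard_normal(64).astype(np.float32)  # one row of attention scores
s1, s2 = s[:32], s[32:]                         # two Bc-wide tiles

# state after the first tile: normalized weights w1, running max m1, running sum l1
m1, l1 = s1.max(), np.exp(s1 - s1.max()).sum()
w1 = np.exp(s1 - m1) / l1

# fold in the second tile using the kernel's factors
m2, l2 = s2.max(), np.exp(s2 - s2.max()).sum()
m_new = max(m1, m2)
l_new = np.exp(m1 - m_new) * l1 + np.exp(m2 - m_new) * l2
f1 = l1 * np.exp(m1 - m_new)                    # un-normalizes and rescales the old state
f2 = np.exp(m2 - m_new)                         # rescales the new tile
w = np.concatenate([f1 * w1, f2 * np.exp(s2 - m2)]) / l_new

assert np.allclose(w, np.exp(s - s.max()) / np.exp(s - s.max()).sum())
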
@@ -305,7 +280,7 @@ def flash_attention(
)
hsaco = get_compile_object_bytes(lowered_module)
if output_format in {"isa", "llvm", "offloading"}:
-     with open(Path(__file__).parent / "flashattention.amdgcn", "wb") as f:
+     with open(Path(__file__).parent / f"flashattention.{output_format}", "wb") as f:
        f.write(hsaco)
    exit()

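(Annotation, not part of the commit: launch_params is defined outside this diff. Based on the flash.cu reference and the bx/by indexing in the kernel above, the launch configuration is presumably one block per (batch, head) pair, Bc threads per block, and dynamic shared memory sized for the four SRAM views; the concrete values below are an assumption, not the repo's.)

Bc = Br = 32                                 # tile sizes assumed to match flash.cu
blocks_per_grid = (B, nh, 1)                 # block_idx.x -> batch, block_idx.y -> head
threads_per_block = (Bc, 1, 1)               # tx indexes one row of a tile
# Qi (Br x d) + Kj (Bc x d) + Vj (Bc x d) + S (Br x Bc), all f32 (4 bytes each)
shared_memory = 4 * (Br * d + 2 * Bc * d + Br * Bc)
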
@@ -338,7 +313,7 @@ def flash_attention(
        shared_memory,
    ) = launch_params[kernel.__name__]

-     host, device = init_copy_host_device()
+     host, device = init_copy_host_device(B, nh, N, d)
    q_h, k_h, v_h, *_ = host
    correct = manual_attn(q_h, k_h, v_h)

@@ -360,8 +335,7 @@ def flash_attention(
        with np.printoptions(threshold=np.inf, linewidth=np.inf):
            print(
                "correct - output:\n",
-                 correct.round().reshape(B, nh, N, d)
-                 - O_h.round().reshape(B, nh, N, d),
+                 correct.round() - O_h.round(),
            )
        print(f"{kernel.__name__} failed\n")
    else: