1+ from pathlib import Path
2+
13import mlir .extras .types as T
24import numpy as np
35from hip import hip
46from mlir .ir import InsertionPoint , IntegerAttr , UnitAttr
5-
67from mlir .extras .ast .canonicalize import canonicalize
78from mlir .extras .context import RAIIMLIRContextModule
89from mlir .extras .dialects .ext import memref , scf , arith , gpu , llvm
2324# noinspection PyUnresolvedReferences
2425from util import hip_check , launch_kernel , hip_synchronize
2526
27+
def init_copy_host_device():
    """Create fresh host tensors and upload them to newly allocated device buffers.

    Sizes are derived from the module-level ``B``, ``nh``, ``N`` and ``d``.
    Six flat float32 arrays are built, in this order:

    * ``q_h``, ``k_h``, ``v_h`` -- ``B * nh * N * d`` random integers in
      ``[0, 10)`` stored as float32,
    * ``l_h`` -- ``B * nh * N`` zeros,
    * ``m_h`` -- ``B * nh * N`` copies of the most negative finite float32,
    * ``O_h`` -- zeros with the shape of ``q_h``.

    Returns:
        A ``(host, device)`` pair of equally ordered lists: the NumPy arrays
        and the matching device allocations returned by ``hip.hipMalloc``.
    """
    n_qkv = B * nh * N * d
    n_stats = B * nh * N

    def random_f32():
        # Integer-valued samples, stored as float32.
        return np.random.randint(0, 10, n_qkv).astype(dtype=np.float32)

    # Draw order (q, then k, then v) is kept so a seeded RNG yields the same data.
    q_h = random_f32()
    k_h = random_f32()
    v_h = random_f32()
    l_h = np.zeros(n_stats, dtype=np.float32)
    # NOTE(review): presumably the softmax running-max seed -- confirm against the kernel.
    lowest_f32 = float(np.finfo(np.float32).min)
    m_h = np.full(n_stats, lowest_f32, dtype=np.float32)
    O_h = np.zeros_like(q_h, dtype=np.float32)

    host = [q_h, k_h, v_h, l_h, m_h, O_h]

    # Allocate every device buffer first, then upload, one buffer per host array.
    device = [hip_check(hip.hipMalloc(arr.nbytes)) for arr in host]

    host_to_device = hip.hipMemcpyKind.hipMemcpyHostToDevice
    for dev, arr in zip(device, host):
        hip_check(hip.hipMemcpy(dev, arr, arr.nbytes, host_to_device))

    return host, device
47+
48+
def copy_device_host(host, device):
    """Download each device buffer into its host array, then free the device buffer.

    ``host`` and ``device`` are walked in lockstep, so they must be ordered
    identically. Each host array is overwritten in place with the contents
    of its device counterpart.

    Args:
        host: NumPy arrays that receive the data.
        device: Device allocations paired with ``host``; every one is passed
            to ``hip.hipFree`` after its copy, so they must not be reused.

    Returns:
        The same ``host`` list, now holding the downloaded data.
    """
    device_to_host = hip.hipMemcpyKind.hipMemcpyDeviceToHost

    for dev, arr in zip(device, host):
        hip_check(hip.hipMemcpy(arr, dev, arr.nbytes, device_to_host))
        # The buffer is released as soon as its contents are safely on the host.
        hip_check(hip.hipFree(dev))

    return host
59+
60+
2661# just so it doesn't get DCE'd by black/reformat
2762# TypeError: 'mlir._mlir_libs._mlir.ir.BlockArgument' object is not subscriptable
2863_ = memref
@@ -44,25 +79,19 @@ def gpu_module():
4479ip = InsertionPoint .at_block_begin (gpu_module .regions [0 ].blocks [0 ])
4580ip .__enter__ ()
4681
47- batch_size = 16
48- n_head = 12
49- seq_len = 64
50- head_embd = 64
51-
5282Bc = 32
5383Br = 32
5484
55- B = batch_size
56- nh = n_head
57- N = seq_len
58- d = head_embd
85+ B = 16
86+ nh = 12
87+ N = 64
88+ d = 64
5989
6090import math
6191
6292Tc = math .ceil (N / Bc )
6393Tr = math .ceil (N / Br )
6494softmax_scale = 1.0 / math .sqrt (d )
65- tile_size = Bc * d # size of Qi, Kj, Vj
6695
6796
6897def softmax (x , axis = None ):
@@ -75,11 +104,11 @@ def manual_attn(q, k, v):
75104 # the kernel below overwrites the global math.........
76105 import math
77106
78- q = q .reshape (batch_size , n_head , seq_len , head_embd )
79- k = k .reshape (batch_size , n_head , seq_len , head_embd )
80- v = v .reshape (batch_size , n_head , seq_len , head_embd )
107+ q = q .reshape (B , nh , N , d )
108+ k = k .reshape (B , nh , N , d )
109+ v = v .reshape (B , nh , N , d )
81110
82- att = q @ k .transpose (0 , 1 , - 2 , - 1 ) * (1.0 / math .sqrt (k .shape [- 1 ]))
111+ att = q @ k .transpose (0 , 1 , 3 , 2 ) * (1.0 / math .sqrt (k .shape [- 1 ]))
83112 att = softmax (att , axis = - 1 )
84113 y = att @ v
85114 return y .flatten ()
@@ -92,42 +121,49 @@ def manual_attn(q, k, v):
92121@gpu_func (emit = True )
93122@canonicalize (using = [scf .canonicalizer , arith .canonicalizer ])
94123def flash_attention (
95- Q : T .memref (batch_size * n_head * seq_len * head_embd , T .f32 ()),
96- K : T .memref (batch_size * n_head * seq_len * head_embd , T .f32 ()),
97- V : T .memref (batch_size * n_head * seq_len * head_embd , T .f32 ()),
124+ Q : T .memref (B * nh * N * d , T .f32 ()),
125+ K : T .memref (B * nh * N * d , T .f32 ()),
126+ V : T .memref (B * nh * N * d , T .f32 ()),
98127 l : T .memref (B * nh * N , T .f32 ()),
99128 m : T .memref (B * nh * N , T .f32 ()),
100- O : T .memref (batch_size * n_head * seq_len * head_embd , T .f32 ()),
129+ O : T .memref (B * nh * N * d , T .f32 ()),
101130):
102131 tx = thread_idx .x
103- bx = block_idx .x
104- by = block_idx . y # batch and head index
132+ bx , by = block_idx .x , block_idx . y
133+ gy = grid_dim . y
105134
106135 # Offset into Q,K,V,O,l,m - different for each batch and head
107- qkv_offset = bx * grid_dim . y * N * d + by * N * d # gridDim.y = nh
108- lm_offset = bx * grid_dim . y * N + by * N # offset for l and m
136+ qkv_offset = bx * gy * N * d + by * N * d # gridDim.y = nh
137+ lm_offset = bx * gy * N + by * N # offset for l and m
109138
110139 # Define SRAM for Q,K,V,S
111140 sram = gpu .dynamic_shared_memory ()
112- Qi = memref .view (sram , (tile_size ,), dtype = T .f32 ())
113- Kj = memref .view (sram , (tile_size ,), dtype = T .f32 (), shift = tile_size * 1 )
114- Vj = memref .view (sram , (tile_size ,), dtype = T .f32 (), shift = tile_size * 2 )
115- S = memref .view (sram , (tile_size ,), dtype = T .f32 (), shift = tile_size * 3 )
141+ Qi = memref .view (sram , (Br * d ,), dtype = T .f32 ())
142+ Kj = memref .view (sram , (Bc * d ,), dtype = T .f32 (), shift = Qi .n_elements )
143+ Vj = memref .view (
144+ sram , (Bc * d ,), dtype = T .f32 (), shift = Qi .n_elements + Kj .n_elements
145+ )
146+ S = memref .view (
147+ sram ,
148+ (Br * Bc ,),
149+ dtype = T .f32 (),
150+ shift = Qi .n_elements + Kj .n_elements + Vj .n_elements ,
151+ )
116152
117153 for j in scf .range_ (0 , Tc ):
118154 # Load Kj, Vj to SRAM
119155 for x in scf .range_ (0 , d ):
120- Kj [tx * d + x ] = K [qkv_offset + tile_size * j + tx * d + x ]
121- Vj [tx * d + x ] = V [qkv_offset + tile_size * j + tx * d + x ]
122-
123- gpu .barrier () # such that the inner loop can use the correct Kj, Vj
156+ Kj [tx * d + x ] = K [qkv_offset + Bc * d * j + tx * d + x ]
157+ Vj [tx * d + x ] = V [qkv_offset + Bc * d * j + tx * d + x ]
124158
125159 for i in scf .range_ (0 , Tr ):
126160 # Load Qi to SRAM, l and m to registers
127161 for x in scf .range_ (0 , d ):
128- ii = qkv_offset + tile_size * i + tx * d + x
162+ ii = qkv_offset + Bc * d * i + tx * d + x
129163 Qi [tx * d + x ] = Q [ii ]
130164
165+ gpu .barrier ()
166+
131167 row_m_prev = m [lm_offset + Br * i + tx ]
132168 row_l_prev = l [lm_offset + Br * i + tx ]
133169
@@ -172,21 +208,16 @@ def flash_attention(
172208 pv += S [Bc * tx + y ] * Vj [y * d + x ]
173209 pv = yield pv
174210
175- ii = qkv_offset + tile_size * i + tx * d + x
211+ ii = qkv_offset + Bc * d * i + tx * d + x
176212 O [ii ] = div * (c * O [ii ] + math .exp (row_m - row_m_new ) * pv )
177213
178- gpu .barrier () # otherwise, thread can use the wrong Kj, Vj in inner loop
179-
180214 m [lm_offset + Br * i + tx ] = row_m_new
181215 l [lm_offset + Br * i + tx ] = row_l_new
182216
183- # gpu.barrier() # otherwise, thread can use the wrong Kj, Vj in inner loop
184- # gpu.barrier() # otherwise, thread can use the wrong Kj, Vj in inner loop
185-
186217
187218ip .__exit__ (None , None , None )
188219
189- sram_size = 4 * tile_size * np .float32 ().itemsize
220+ sram_size = 4 * Bc * d * np .float32 ().itemsize
190221
191222launch_params = {
192223 flash_attention .__name__ : (
@@ -206,6 +237,9 @@ def flash_attention(
206237 .rocdl_attach_target (chip = arch , O = 3 , abi = "500" ),
207238)
208239
240+ # print(simplified_module)
241+ # exit()
242+
209243lowered_module = run_pipeline (
210244 simplified_module ,
211245 Pipeline ()
@@ -216,7 +250,8 @@ def flash_attention(
216250 )
217251 )
218252 .gpu_to_llvm ()
219- .lower_to_llvm (),
253+ .lower_to_llvm ()
254+ .ensure_debug_info_scope_on_llvm_func (emission_kind = "Full" ),
220255 # .Nested("llvm.func", Pipeline().sroa()),
221256)
222257
@@ -236,68 +271,33 @@ def flash_attention(
236271 T .index (), np .prod (thread_dims )
237272 )
238273
239- lowered_module = run_pipeline (lowered_module , Pipeline ().gpu_module_to_binary ())
274+ output_format = "bin"
275+ # output_format = "isa"
276+
277+ lowered_module = run_pipeline (
278+ lowered_module , Pipeline ().gpu_module_to_binary (format = output_format )
279+ )
240280hsaco = get_compile_object_bytes (lowered_module )
281+ if output_format == "isa" :
282+ with open (Path (__file__ ).parent / "flashattention.amdgcn" , "w" ) as f :
283+ f .write (hsaco .decode ())
284+ exit ()
241285
242286hip_module = hip_check (hip .hipModuleLoadData (hsaco ))
243287
244- q_h = np .random .randint (0 , 10 , (batch_size * n_head * seq_len * head_embd )).astype (
245- dtype = np .float32
246- )
247- k_h = np .random .randint (0 , 10 , (batch_size * n_head * seq_len * head_embd )).astype (
248- dtype = np .float32
249- )
250- v_h = np .random .randint (0 , 10 , (batch_size * n_head * seq_len * head_embd )).astype (
251- dtype = np .float32
252- )
253- l_h = np .zeros ((B * nh * N ), dtype = np .float32 )
254- m_h = np .full ((B * nh * N ), float (np .finfo (np .float32 ).min ), dtype = np .float32 )
255- O_h = np .zeros_like (q_h , dtype = np .float32 )
256-
257- q_num_bytes = q_h .size * q_h .itemsize
258- k_num_bytes = k_h .size * k_h .itemsize
259- v_num_bytes = v_h .size * v_h .itemsize
260- l_num_bytes = l_h .size * l_h .itemsize
261- m_num_bytes = m_h .size * m_h .itemsize
262- O_num_bytes = O_h .size * O_h .itemsize
263-
264- q_d = hip_check (hip .hipMalloc (q_num_bytes ))
265- k_d = hip_check (hip .hipMalloc (k_num_bytes ))
266- v_d = hip_check (hip .hipMalloc (v_num_bytes ))
267- l_d = hip_check (hip .hipMalloc (l_num_bytes ))
268- m_d = hip_check (hip .hipMalloc (m_num_bytes ))
269- O_d = hip_check (hip .hipMalloc (O_num_bytes ))
270-
271288stream = 0
272289
273290times = {
274291 flash_attention : 0 ,
275292}
276- # random.shuffle(kernels)
277- runs = 16
293+ runs = 32
278294for kernel in times :
279295 for i in range (runs ):
280296 function = hip_check (
281297 hip .hipModuleGetFunction (hip_module , kernel .__name__ .encode ())
282298 )
283299 hip_check (hip .hipDeviceSynchronize ())
284300
285- for d , h , num_bytes in zip (
286- [q_d , k_d , v_d , l_d , m_d , O_d ],
287- [q_h , k_h , v_h , l_h , m_h , O_h ],
288- [
289- q_num_bytes ,
290- k_num_bytes ,
291- v_num_bytes ,
292- l_num_bytes ,
293- m_num_bytes ,
294- O_num_bytes ,
295- ],
296- ):
297- hip_check (
298- hip .hipMemcpy (d , h , num_bytes , hip .hipMemcpyKind .hipMemcpyHostToDevice )
299- )
300-
301301 (
302302 (
303303 blocks_per_grid_x ,
@@ -312,6 +312,10 @@ def flash_attention(
312312 shared_memory ,
313313 ) = launch_params [kernel .__name__ ]
314314
315+ host , device = init_copy_host_device ()
316+ q_h , k_h , v_h , * _ = host
317+ correct = manual_attn (q_h , k_h , v_h )
318+
315319 time_compute = launch_kernel (
316320 function .as_c_void_p (),
317321 blocks_per_grid_x ,
@@ -322,36 +326,20 @@ def flash_attention(
322326 threads_per_block_z ,
323327 stream ,
324328 shared_memory ,
325- q_d ,
326- k_d ,
327- v_d ,
328- l_d ,
329- m_d ,
330- O_d ,
329+ * device ,
331330 )
332331
333- hip_check (
334- hip .hipMemcpy (
335- l_h , l_d , l_num_bytes , hip .hipMemcpyKind .hipMemcpyDeviceToHost
336- )
337- )
338- hip_check (
339- hip .hipMemcpy (
340- m_h , m_d , m_num_bytes , hip .hipMemcpyKind .hipMemcpyDeviceToHost
341- )
342- )
343- hip_check (
344- hip .hipMemcpy (
345- O_h , O_d , O_num_bytes , hip .hipMemcpyKind .hipMemcpyDeviceToHost
346- )
347- )
348- correct = manual_attn (q_h , k_h , v_h )
332+ * _ , O_h = copy_device_host (host , device )
349333 if not np .allclose (correct , O_h ):
350- print ("correct" , correct )
351- print ("l_h" , l_h )
352- print ("m_h" , m_h )
353- print ("output" , O_h )
354- print (f"{ kernel .__name__ } failed" )
334+ with np .printoptions (threshold = np .inf , linewidth = np .inf ):
335+ print (
336+ "correct - output:\n " ,
337+ correct .round ().reshape (B , nh , N , d )
338+ - O_h .round ().reshape (B , nh , N , d ),
339+ )
340+ print (f"{ kernel .__name__ } failed\n " )
341+ else :
342+ print (f"{ kernel .__name__ } : { time_compute :.03f} ms" )
355343
356344 times [kernel ] += time_compute
357345
0 commit comments