+from pathlib import Path
+
 import mlir.extras.types as T
 import numpy as np
 from hip import hip
 from mlir.ir import InsertionPoint, IntegerAttr, UnitAttr
-
 from mlir.extras.ast.canonicalize import canonicalize
 from mlir.extras.context import RAIIMLIRContextModule
 from mlir.extras.dialects.ext import memref, scf, arith, gpu, llvm
 # noinspection PyUnresolvedReferences
 from util import hip_check, launch_kernel, hip_synchronize
 
+
+def init_copy_host_device():
+    q_h = np.random.randint(0, 10, (batch_size * n_head * seq_len * head_embd)).astype(
+        dtype=np.float32
+    )
+    k_h = np.random.randint(0, 10, (batch_size * n_head * seq_len * head_embd)).astype(
+        dtype=np.float32
+    )
+    v_h = np.random.randint(0, 10, (batch_size * n_head * seq_len * head_embd)).astype(
+        dtype=np.float32
+    )
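+    # l holds the running softmax denominators, m the running row maxima (start at the most negative f32)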
+    l_h = np.zeros((B * nh * N), dtype=np.float32)
+    m_h = np.full((B * nh * N), float(np.finfo(np.float32).min), dtype=np.float32)
+    O_h = np.zeros_like(q_h, dtype=np.float32)
+
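+    # keep this order: it matches the kernel's (Q, K, V, l, m, O) argument order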
+    host = [q_h, k_h, v_h, l_h, m_h, O_h]
+    device = [hip_check(hip.hipMalloc(h.size * h.itemsize)) for h in host]
+
+    for d, h in zip(device, host):
+        hip_check(
+            hip.hipMemcpy(
+                d, h, h.size * h.itemsize, hip.hipMemcpyKind.hipMemcpyHostToDevice
+            )
+        )
+
+    return host, device
+
+
+def copy_device_host(host, device):
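+    # copy every buffer back to the host, then free its device allocation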
+    for d, h in zip(device, host):
+        hip_check(
+            hip.hipMemcpy(
+                h, d, h.size * h.itemsize, hip.hipMemcpyKind.hipMemcpyDeviceToHost
+            )
+        )
+        hip_check(hip.hipFree(d))
+
+    return host
+
+
 # just so it doesn't get DCE'd by black/reformat
 # TypeError: 'mlir._mlir_libs._mlir.ir.BlockArgument' object is not subscriptable
 _ = memref
@@ -79,7 +120,7 @@ def manual_attn(q, k, v):
     k = k.reshape(batch_size, n_head, seq_len, head_embd)
     v = v.reshape(batch_size, n_head, seq_len, head_embd)
 
-    att = q @ k.transpose(0, 1, -2, -1) * (1.0 / math.sqrt(k.shape[-1]))
+    att = q @ k.transpose(0, 1, 3, 2) * (1.0 / math.sqrt(k.shape[-1]))
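+    # (0, 1, 3, 2) swaps the last two axes, i.e. K^T within each (batch, head)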
     att = softmax(att, axis=-1)
     y = att @ v
     return y.flatten()
@@ -100,12 +141,12 @@ def flash_attention(
     O: T.memref(batch_size * n_head * seq_len * head_embd, T.f32()),
 ):
     tx = thread_idx.x
-    bx = block_idx.x
-    by = block_idx.y  # batch and head index
+    bx, by = block_idx.x, block_idx.y
+    gy = grid_dim.y
 
     # Offset into Q,K,V,O,l,m - different for each batch and head
-    qkv_offset = bx * grid_dim.y * N * d + by * N * d  # gridDim.y = nh
-    lm_offset = bx * grid_dim.y * N + by * N  # offset for l and m
+    qkv_offset = bx * gy * N * d + by * N * d  # gridDim.y = nh
+    lm_offset = bx * gy * N + by * N  # offset for l and m
 
     # Define SRAM for Q,K,V,S
     sram = gpu.dynamic_shared_memory()
@@ -120,8 +161,6 @@ def flash_attention(
             Kj[tx * d + x] = K[qkv_offset + tile_size * j + tx * d + x]
             Vj[tx * d + x] = V[qkv_offset + tile_size * j + tx * d + x]
 
-        gpu.barrier()  # such that the inner loop can use the correct Kj, Vj
-
         for i in scf.range_(0, Tr):
             # Load Qi to SRAM, l and m to registers
             for x in scf.range_(0, d):
@@ -175,13 +214,10 @@ def flash_attention(
                 ii = qkv_offset + tile_size * i + tx * d + x
                 O[ii] = div * (c * O[ii] + math.exp(row_m - row_m_new) * pv)
 
-            gpu.barrier()  # otherwise, thread can use the wrong Kj, Vj in inner loop
-
             m[lm_offset + Br * i + tx] = row_m_new
             l[lm_offset + Br * i + tx] = row_l_new
 
-        # gpu.barrier() # otherwise, thread can use the wrong Kj, Vj in inner loop
-        # gpu.barrier() # otherwise, thread can use the wrong Kj, Vj in inner loop
+        gpu.barrier()  # otherwise, thread can use the wrong Kj, Vj in inner loop
 
 
 ip.__exit__(None, None, None)
@@ -236,68 +272,33 @@ def flash_attention(
     T.index(), np.prod(thread_dims)
 )
 
-lowered_module = run_pipeline(lowered_module, Pipeline().gpu_module_to_binary())
+output_format = "bin"
+# output_format = "isa"
+
+lowered_module = run_pipeline(
+    lowered_module, Pipeline().gpu_module_to_binary(format=output_format)
+)
 hsaco = get_compile_object_bytes(lowered_module)
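+# with "isa" the compiled object is AMDGCN assembly text, so dump it to a file and stop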
+if output_format == "isa":
+    with open(Path(__file__).parent / "flashattention.amdgcn", "w") as f:
+        f.write(hsaco.decode())
+    exit()
 
 hip_module = hip_check(hip.hipModuleLoadData(hsaco))
 
-q_h = np.random.randint(0, 10, (batch_size * n_head * seq_len * head_embd)).astype(
-    dtype=np.float32
-)
-k_h = np.random.randint(0, 10, (batch_size * n_head * seq_len * head_embd)).astype(
-    dtype=np.float32
-)
-v_h = np.random.randint(0, 10, (batch_size * n_head * seq_len * head_embd)).astype(
-    dtype=np.float32
-)
-l_h = np.zeros((B * nh * N), dtype=np.float32)
-m_h = np.full((B * nh * N), float(np.finfo(np.float32).min), dtype=np.float32)
-O_h = np.zeros_like(q_h, dtype=np.float32)
-
-q_num_bytes = q_h.size * q_h.itemsize
-k_num_bytes = k_h.size * k_h.itemsize
-v_num_bytes = v_h.size * v_h.itemsize
-l_num_bytes = l_h.size * l_h.itemsize
-m_num_bytes = m_h.size * m_h.itemsize
-O_num_bytes = O_h.size * O_h.itemsize
-
-q_d = hip_check(hip.hipMalloc(q_num_bytes))
-k_d = hip_check(hip.hipMalloc(k_num_bytes))
-v_d = hip_check(hip.hipMalloc(v_num_bytes))
-l_d = hip_check(hip.hipMalloc(l_num_bytes))
-m_d = hip_check(hip.hipMalloc(m_num_bytes))
-O_d = hip_check(hip.hipMalloc(O_num_bytes))
-
 stream = 0
 
 times = {
     flash_attention: 0,
 }
-# random.shuffle(kernels)
-runs = 16
+runs = 32
 for kernel in times:
     for i in range(runs):
         function = hip_check(
             hip.hipModuleGetFunction(hip_module, kernel.__name__.encode())
         )
         hip_check(hip.hipDeviceSynchronize())
 
-        for d, h, num_bytes in zip(
-            [q_d, k_d, v_d, l_d, m_d, O_d],
-            [q_h, k_h, v_h, l_h, m_h, O_h],
-            [
-                q_num_bytes,
-                k_num_bytes,
-                v_num_bytes,
-                l_num_bytes,
-                m_num_bytes,
-                O_num_bytes,
-            ],
-        ):
-            hip_check(
-                hip.hipMemcpy(d, h, num_bytes, hip.hipMemcpyKind.hipMemcpyHostToDevice)
-            )
-
         (
             (
                 blocks_per_grid_x,
@@ -312,6 +313,10 @@ def flash_attention(
             shared_memory,
         ) = launch_params[kernel.__name__]
 
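+        # fresh host/device buffers and a fresh reference result for every run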
+        host, device = init_copy_host_device()
+        q_h, k_h, v_h, *_ = host
+        correct = manual_attn(q_h, k_h, v_h)
+
         time_compute = launch_kernel(
             function.as_c_void_p(),
             blocks_per_grid_x,
@@ -322,36 +327,16 @@ def flash_attention(
             threads_per_block_z,
             stream,
             shared_memory,
-            q_d,
-            k_d,
-            v_d,
-            l_d,
-            m_d,
-            O_d,
+            *device,
         )
 
-        hip_check(
-            hip.hipMemcpy(
-                l_h, l_d, l_num_bytes, hip.hipMemcpyKind.hipMemcpyDeviceToHost
-            )
-        )
-        hip_check(
-            hip.hipMemcpy(
-                m_h, m_d, m_num_bytes, hip.hipMemcpyKind.hipMemcpyDeviceToHost
-            )
-        )
-        hip_check(
-            hip.hipMemcpy(
-                O_h, O_d, O_num_bytes, hip.hipMemcpyKind.hipMemcpyDeviceToHost
-            )
-        )
-        correct = manual_attn(q_h, k_h, v_h)
+        *_, O_h = copy_device_host(host, device)
         if not np.allclose(correct, O_h):
             print("correct", correct)
-            print("l_h", l_h)
-            print("m_h", m_h)
             print("output", O_h)
             print(f"{kernel.__name__} failed")
+        else:
+            print(f"{kernel.__name__}: {time_compute:.03f} ms")
 
         times[kernel] += time_compute
 