|
| 1 | +import mlir.extras.types as T |
| 2 | +import numpy as np |
| 3 | +from hip import hip |
| 4 | +from mlir.ir import InsertionPoint, IntegerAttr, UnitAttr |
| 5 | + |
| 6 | +from mlir.extras.ast.canonicalize import canonicalize |
| 7 | +from mlir.extras.context import RAIIMLIRContextModule |
| 8 | +from mlir.extras.dialects.ext import memref, scf, arith, gpu, llvm |
| 9 | + |
| 10 | +# noinspection PyUnresolvedReferences |
| 11 | +from mlir.extras.dialects.ext.gpu import ( |
| 12 | + all_reduce, |
| 13 | + wait, |
| 14 | + thread_attr as thread, |
| 15 | + block_idx, |
| 16 | + thread_idx, |
| 17 | + grid_dim, |
| 18 | + block_dim, |
| 19 | + func as gpu_func, |
| 20 | + set_container_module, |
| 21 | + launch, |
| 22 | + all_reduce_, |
| 23 | + module, |
| 24 | + get_compile_object_bytes, |
| 25 | + lds_space, |
| 26 | +) |
| 27 | +from mlir.extras.runtime.passes import run_pipeline, Pipeline |
| 28 | +from mlir.extras.util import find_ops |
| 29 | + |
| 30 | +# noinspection PyUnresolvedReferences |
| 31 | +from util import hip_check, launch_kernel, hip_synchronize |
| 32 | + |
| 33 | +# just so it doesn't get DCE'd by black/reformat |
| 34 | +# TypeError: 'mlir._mlir_libs._mlir.ir.BlockArgument' object is not subscriptable |
| 35 | +_ = memref |
| 36 | + |
| 37 | +ctx = RAIIMLIRContextModule() |
| 38 | +set_container_module(ctx.module) |
| 39 | + |
| 40 | +props = hip.hipDeviceProp_t() |
| 41 | +hip_check(hip.hipGetDeviceProperties(props, 0)) |
| 42 | +arch = props.gcnArchName.decode() |
| 43 | + |
| 44 | + |
| 45 | +# just a default attr - actual target is set blow |
| 46 | +@module("kernels", [f'#rocdl.target<abi = "500">']) |
| 47 | +def gpu_module(): |
| 48 | + pass |
| 49 | + |
| 50 | + |
| 51 | +ip = InsertionPoint.at_block_begin(gpu_module.regions[0].blocks[0]) |
| 52 | +ip.__enter__() |
| 53 | + |
| 54 | +batch_size = 16 |
| 55 | +n_head = 12 |
| 56 | +seq_len = 64 |
| 57 | +head_embd = 64 |
| 58 | + |
| 59 | +Bc = 32 |
| 60 | +Br = 32 |
| 61 | + |
| 62 | +B = batch_size |
| 63 | +nh = n_head |
| 64 | +N = seq_len |
| 65 | +d = head_embd |
| 66 | + |
| 67 | +import math |
| 68 | + |
| 69 | +Tc = math.ceil(N / Bc) |
| 70 | +Tr = math.ceil(N / Br) |
| 71 | +softmax_scale = 1.0 / math.sqrt(d) |
| 72 | +tile_size = Bc * d # size of Qi, Kj, Vj |
| 73 | + |
| 74 | + |
| 75 | +def softmax(x, axis=None): |
| 76 | + x_max = np.amax(x, axis=axis, keepdims=True) |
| 77 | + exp_x_shifted = np.exp(x - x_max) |
| 78 | + return exp_x_shifted / np.sum(exp_x_shifted, axis=axis, keepdims=True) |
| 79 | + |
| 80 | + |
| 81 | +def manual_attn(q, k, v): |
| 82 | + # the kernel below overwrites the global math......... |
| 83 | + import math |
| 84 | + |
| 85 | + q = q.reshape(batch_size, n_head, seq_len, head_embd) |
| 86 | + k = k.reshape(batch_size, n_head, seq_len, head_embd) |
| 87 | + v = v.reshape(batch_size, n_head, seq_len, head_embd) |
| 88 | + |
| 89 | + att = q @ k.transpose(0, 1, -2, -1) * (1.0 / math.sqrt(k.shape[-1])) |
| 90 | + att = softmax(att, axis=-1) |
| 91 | + y = att @ v |
| 92 | + return y.flatten() |
| 93 | + |
| 94 | + |
| 95 | +from mlir.dialects import math |
| 96 | + |
| 97 | + |
| 98 | +# https://github.com/tspeterkim/flash-attention-minimal/blob/main/flash.cu |
| 99 | +@gpu_func(emit=True) |
| 100 | +@canonicalize(using=[scf.canonicalizer, arith.canonicalizer]) |
| 101 | +def flash_attention( |
| 102 | + Q: T.memref(batch_size * n_head * seq_len * head_embd, T.f32()), |
| 103 | + K: T.memref(batch_size * n_head * seq_len * head_embd, T.f32()), |
| 104 | + V: T.memref(batch_size * n_head * seq_len * head_embd, T.f32()), |
| 105 | + l: T.memref(B * nh * N, T.f32()), |
| 106 | + m: T.memref(B * nh * N, T.f32()), |
| 107 | + O: T.memref(batch_size * n_head * seq_len * head_embd, T.f32()), |
| 108 | +): |
| 109 | + tx = thread_idx.x |
| 110 | + bx = block_idx.x |
| 111 | + by = block_idx.y # batch and head index |
| 112 | + |
| 113 | + # Offset into Q,K,V,O,l,m - different for each batch and head |
| 114 | + qkv_offset = (bx * grid_dim.y * N * d) + (by * N * d) # gridDim.y = nh |
| 115 | + lm_offset = (bx * grid_dim.y * N) + (by * N) # offset for l and m |
| 116 | + |
| 117 | + # Define SRAM for Q,K,V,S |
| 118 | + sram = gpu.dynamic_shared_memory() |
| 119 | + Qi = memref.view(sram, (tile_size,), dtype=T.f32()) |
| 120 | + Kj = memref.view(sram, (tile_size,), dtype=T.f32(), shift=tile_size * 1) |
| 121 | + Vj = memref.view(sram, (tile_size,), dtype=T.f32(), shift=tile_size * 2) |
| 122 | + S = memref.view(sram, (tile_size,), dtype=T.f32(), shift=tile_size * 3) |
| 123 | + |
| 124 | + for j in scf.range_(0, Tc): |
| 125 | + # Load Kj, Vj to SRAM |
| 126 | + for x in scf.range_(0, d): |
| 127 | + Kj[(tx * d) + x] = K[qkv_offset + (tile_size * j) + (tx * d) + x] |
| 128 | + Vj[(tx * d) + x] = V[qkv_offset + (tile_size * j) + (tx * d) + x] |
| 129 | + |
| 130 | + gpu.barrier() # such that the inner loop can use the correct Kj, Vj |
| 131 | + |
| 132 | + for i in scf.range_(0, Tr): |
| 133 | + # Load Qi to SRAM, l and m to registers |
| 134 | + for x in scf.range_(0, d): |
| 135 | + Qi[(tx * d) + x] = Q[qkv_offset + (tile_size * i) + (tx * d) + x] |
| 136 | + |
| 137 | + row_m_prev = m[lm_offset + (Br * i) + tx] |
| 138 | + row_l_prev = l[lm_offset + (Br * i) + tx] |
| 139 | + |
| 140 | + # S = QK^T, row_m = rowmax(S) |
| 141 | + row_m: T.f32() = float(np.finfo(np.float32).min) |
| 142 | + for y, row_m, _ in scf.range_(0, Bc, iter_args=[row_m]): |
| 143 | + sum: T.f32() = 0.0 |
| 144 | + for x, sum, _ in scf.range_(0, d, iter_args=[sum]): |
| 145 | + sum += Qi[(tx * d) + x] * Kj[(y * d) + x] |
| 146 | + sum = yield sum |
| 147 | + |
| 148 | + sum *= softmax_scale |
| 149 | + S[(Bc * tx) + y] = sum |
| 150 | + |
| 151 | + if sum > row_m: |
| 152 | + row_m_ = yield sum |
| 153 | + else: |
| 154 | + row_m_ = yield row_m |
| 155 | + |
| 156 | + row_m = yield row_m_ |
| 157 | + |
| 158 | + # P = exp(S - row_m), row_l = rowsum(P) |
| 159 | + row_l: T.f32() = 0.0 |
| 160 | + for y, row_l, _ in scf.range_(0, Bc, iter_args=[row_l]): |
| 161 | + S[(Bc * tx) + y] = math.exp(S[(Bc * tx) + y] - row_m) |
| 162 | + row_l += S[(Bc * tx) + y] |
| 163 | + row_l = yield row_l |
| 164 | + |
| 165 | + # Compute new m and l |
| 166 | + row_m_new = arith.maximumf(row_m_prev, row_m) |
| 167 | + row_l_new = (math.exp(row_m_prev - row_m_new) * row_l_prev) + ( |
| 168 | + math.exp(row_m - row_m_new) * row_l |
| 169 | + ) |
| 170 | + div = 1.0 / row_l_new |
| 171 | + c = row_l_prev * math.exp(row_m_prev - row_m_new) |
| 172 | + |
| 173 | + # Write O, l, m to HBM |
| 174 | + for x in scf.range_(0, d): |
| 175 | + pv: T.f32() = 0.0 # Pij * Vj |
| 176 | + for y, pv, _ in scf.range_(0, Bc, iter_args=[pv]): |
| 177 | + pv += S[(Bc * tx) + y] * Vj[(y * d) + x] |
| 178 | + pv = yield pv |
| 179 | + |
| 180 | + ii = qkv_offset + (tile_size * i) + (tx * d) + x |
| 181 | + O[ii] = div * ((c * O[ii]) + (math.exp(row_m - row_m_new) * pv)) |
| 182 | + |
| 183 | + gpu.barrier() # otherwise, thread can use the wrong Kj, Vj in inner loop |
| 184 | + |
| 185 | + m[lm_offset + (Br * i) + tx] = row_m_new |
| 186 | + l[lm_offset + (Br * i) + tx] = row_l_new |
| 187 | + |
| 188 | + # gpu.barrier() # otherwise, thread can use the wrong Kj, Vj in inner loop |
| 189 | + # gpu.barrier() # otherwise, thread can use the wrong Kj, Vj in inner loop |
| 190 | + |
| 191 | + |
| 192 | +ip.__exit__(None, None, None) |
| 193 | + |
| 194 | +sram_size = 4 * tile_size * np.float32().itemsize |
| 195 | + |
| 196 | +launch_params = { |
| 197 | + flash_attention.__name__: ( |
| 198 | + (B, nh, 1), |
| 199 | + (Bc, 1, 1), |
| 200 | + sram_size, |
| 201 | + ) |
| 202 | +} |
| 203 | + |
| 204 | +simplified_module = run_pipeline( |
| 205 | + ctx.module, |
| 206 | + Pipeline() |
| 207 | + .canonicalize() |
| 208 | + .cse() |
| 209 | + .loop_invariant_code_motion() |
| 210 | + .loop_invariant_subset_hoisting() |
| 211 | + .rocdl_attach_target(chip=arch, O=3, abi="500"), |
| 212 | +) |
| 213 | + |
| 214 | +lowered_module = run_pipeline( |
| 215 | + simplified_module, |
| 216 | + Pipeline() |
| 217 | + .Gpu( |
| 218 | + Pipeline().convert_gpu_to_rocdl( |
| 219 | + use_bare_ptr_memref_call_conv=True, |
| 220 | + runtime="HIP", |
| 221 | + ) |
| 222 | + ) |
| 223 | + .gpu_to_llvm() |
| 224 | + .lower_to_llvm(), |
| 225 | + # .Nested("llvm.func", Pipeline().sroa()), |
| 226 | +) |
| 227 | + |
| 228 | +# print(lowered_module) |
| 229 | +gep = find_ops(lowered_module.operation, lambda o: isinstance(o.opview, llvm.GEPOp)) |
| 230 | +for g in gep: |
| 231 | + g.attributes["inbounds"] = UnitAttr.get() |
| 232 | + |
| 233 | +kernel_funcs = find_ops( |
| 234 | + lowered_module.operation, lambda o: isinstance(o.opview, llvm.LLVMFuncOp) |
| 235 | +) |
| 236 | +for k in kernel_funcs: |
| 237 | + if k.sym_name.value != flash_attention.__name__: |
| 238 | + continue |
| 239 | + _, thread_dims, _ = launch_params[k.sym_name.value] |
| 240 | + k.attributes["rocdl.max_flat_work_group_size"] = IntegerAttr.get( |
| 241 | + T.index(), np.prod(thread_dims) |
| 242 | + ) |
| 243 | + |
| 244 | +lowered_module = run_pipeline(lowered_module, Pipeline().gpu_module_to_binary()) |
| 245 | +hsaco = get_compile_object_bytes(lowered_module) |
| 246 | + |
| 247 | +hip_module = hip_check(hip.hipModuleLoadData(hsaco)) |
| 248 | + |
| 249 | +q_h = np.random.randint(0, 10, (batch_size * n_head * seq_len * head_embd)).astype( |
| 250 | + dtype=np.float32 |
| 251 | +) |
| 252 | +k_h = np.random.randint(0, 10, (batch_size * n_head * seq_len * head_embd)).astype( |
| 253 | + dtype=np.float32 |
| 254 | +) |
| 255 | +v_h = np.random.randint(0, 10, (batch_size * n_head * seq_len * head_embd)).astype( |
| 256 | + dtype=np.float32 |
| 257 | +) |
| 258 | +l_h = np.zeros((B * nh * N), dtype=np.float32) |
| 259 | +m_h = np.full((B * nh * N), float(np.finfo(np.float32).min), dtype=np.float32) |
| 260 | +O_h = np.zeros_like(q_h, dtype=np.float32) |
| 261 | + |
| 262 | +q_num_bytes = q_h.size * q_h.itemsize |
| 263 | +k_num_bytes = k_h.size * k_h.itemsize |
| 264 | +v_num_bytes = v_h.size * v_h.itemsize |
| 265 | +l_num_bytes = l_h.size * l_h.itemsize |
| 266 | +m_num_bytes = m_h.size * m_h.itemsize |
| 267 | +O_num_bytes = O_h.size * O_h.itemsize |
| 268 | + |
| 269 | +q_d = hip_check(hip.hipMalloc(q_num_bytes)) |
| 270 | +k_d = hip_check(hip.hipMalloc(k_num_bytes)) |
| 271 | +v_d = hip_check(hip.hipMalloc(v_num_bytes)) |
| 272 | +l_d = hip_check(hip.hipMalloc(l_num_bytes)) |
| 273 | +m_d = hip_check(hip.hipMalloc(m_num_bytes)) |
| 274 | +O_d = hip_check(hip.hipMalloc(O_num_bytes)) |
| 275 | + |
| 276 | +stream = 0 |
| 277 | + |
| 278 | +times = { |
| 279 | + flash_attention: 0, |
| 280 | +} |
| 281 | +# random.shuffle(kernels) |
| 282 | +runs = 16 |
| 283 | +for kernel in times: |
| 284 | + for i in range(runs): |
| 285 | + function = hip_check( |
| 286 | + hip.hipModuleGetFunction(hip_module, kernel.__name__.encode()) |
| 287 | + ) |
| 288 | + hip_check(hip.hipDeviceSynchronize()) |
| 289 | + |
| 290 | + for d, h, num_bytes in zip( |
| 291 | + [q_d, k_d, v_d, l_d, m_d, O_d], |
| 292 | + [q_h, k_h, v_h, l_h, m_h, O_h], |
| 293 | + [ |
| 294 | + q_num_bytes, |
| 295 | + k_num_bytes, |
| 296 | + v_num_bytes, |
| 297 | + l_num_bytes, |
| 298 | + m_num_bytes, |
| 299 | + O_num_bytes, |
| 300 | + ], |
| 301 | + ): |
| 302 | + hip_check( |
| 303 | + hip.hipMemcpy(d, h, num_bytes, hip.hipMemcpyKind.hipMemcpyHostToDevice) |
| 304 | + ) |
| 305 | + |
| 306 | + ( |
| 307 | + ( |
| 308 | + blocks_per_grid_x, |
| 309 | + blocks_per_grid_y, |
| 310 | + blocks_per_grid_z, |
| 311 | + ), |
| 312 | + ( |
| 313 | + threads_per_block_x, |
| 314 | + threads_per_block_y, |
| 315 | + threads_per_block_z, |
| 316 | + ), |
| 317 | + shared_memory, |
| 318 | + ) = launch_params[kernel.__name__] |
| 319 | + |
| 320 | + time_compute = launch_kernel( |
| 321 | + function.as_c_void_p(), |
| 322 | + blocks_per_grid_x, |
| 323 | + blocks_per_grid_y, |
| 324 | + blocks_per_grid_z, |
| 325 | + threads_per_block_x, |
| 326 | + threads_per_block_y, |
| 327 | + threads_per_block_z, |
| 328 | + stream, |
| 329 | + shared_memory, |
| 330 | + q_d, |
| 331 | + k_d, |
| 332 | + v_d, |
| 333 | + l_d, |
| 334 | + m_d, |
| 335 | + O_d, |
| 336 | + ) |
| 337 | + |
| 338 | + hip_check( |
| 339 | + hip.hipMemcpy( |
| 340 | + l_h, l_d, l_num_bytes, hip.hipMemcpyKind.hipMemcpyDeviceToHost |
| 341 | + ) |
| 342 | + ) |
| 343 | + hip_check( |
| 344 | + hip.hipMemcpy( |
| 345 | + m_h, m_d, m_num_bytes, hip.hipMemcpyKind.hipMemcpyDeviceToHost |
| 346 | + ) |
| 347 | + ) |
| 348 | + hip_check( |
| 349 | + hip.hipMemcpy( |
| 350 | + O_h, O_d, O_num_bytes, hip.hipMemcpyKind.hipMemcpyDeviceToHost |
| 351 | + ) |
| 352 | + ) |
| 353 | + correct = manual_attn(q_h, k_h, v_h) |
| 354 | + if not np.allclose(correct, O_h): |
| 355 | + print("correct", correct) |
| 356 | + print("l_h", l_h) |
| 357 | + print("m_h", m_h) |
| 358 | + print("output", O_h) |
| 359 | + print(f"{kernel.__name__} failed") |
| 360 | + |
| 361 | + times[kernel] += time_compute |
| 362 | + |
| 363 | +for k in times: |
| 364 | + times[k] /= runs |
| 365 | + |
| 366 | +for k, v in times.items(): |
| 367 | + print(f"{k.__name__}: {v:.03f}ms") |
0 commit comments