|
13 | 13 | mlir_mod_ctx, |
14 | 14 | MLIRContext, |
15 | 15 | ) |
16 | | -from mlir.extras.dialects.ext import arith, memref, gpu, scf, linalg |
| 16 | +from mlir.extras.dialects.ext import arith, memref, gpu, scf, linalg, vector |
17 | 17 | from mlir.extras.dialects.ext.gpu import ( |
18 | 18 | block_idx, |
19 | 19 | thread_idx, |
20 | 20 | block_dim, |
21 | 21 | get_compile_object_bytes, |
22 | 22 | ) |
| 23 | +from mlir.extras.dialects.ext.memref import S |
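|  | +# S is the dynamic-size placeholder used below to type the dynamic shared memory memref |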
23 | 24 | from mlir.extras.dialects.ext.scf import range_ |
24 | 25 | from mlir.extras.runtime.passes import Pipeline, run_pipeline |
25 | 26 |
|
@@ -47,23 +48,62 @@ def compile_module(module, enable_ir_printing=False, print_ptx_=False): |
47 | 48 | print_ptx_ = True |
48 | 49 | mod = run_pipeline( |
49 | 50 | module, |
| 51 | + # if you're not using vectors, you can just uncomment the gpu-lower-to-nvvm-pipeline below |
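|  | + # the explicit pass list below mirrors what gpu-lower-to-nvvm-pipeline does, with the vector lowering passes (convert-vector-to-scf / convert-vector-to-llvm) spliced in |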
50 | 52 | Pipeline() |
51 | 53 | .convert_linalg_to_loops() |
| 54 | + .convert_nvgpu_to_nvvm() |
| 55 | + .gpu_kernel_outlining() |
| 56 | + .convert_vector_to_scf() |
| 57 | + .convert_scf_to_cf() |
| 58 | + .convert_nvvm_to_llvm() |
| 59 | + .convert_func_to_llvm() |
| 60 | + .expand_strided_metadata() |
52 | 61 | .add_pass( |
53 | | - "gpu-lower-to-nvvm-pipeline", |
54 | | - # https://github.com/llvm/llvm-project/blob/ace69e6b942b8fa7e610d70be2a92e801ceea481/mlir/include/mlir/Dialect/GPU/Pipelines/Passes.h#L18 |
| 62 | + "nvvm-attach-target", |
55 | 63 | **{ |
56 | | - "cubin-chip": "sm_80", |
57 | | - "cubin-features": "+ptx83", |
58 | | - "cubin-format": "isa", |
59 | | - "kernel-bare-ptr-calling-convention": "1", |
60 | | - "opt-level": "2", |
61 | | - # "cubin-format": "fatbin", |
62 | | - # "cubin-format": "bin", |
| 64 | + "chip": "sm_80", |
| 65 | + "features": "+ptx83", |
| 66 | + "O": "2", |
63 | 67 | }, |
64 | | - ), |
| 68 | + ) |
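|  | + # nvvm-attach-target's chip/features/O correspond to the old cubin-chip/cubin-features/opt-level options |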
| 69 | + .lower_affine() |
| 70 | + .convert_arith_to_llvm() |
| 71 | + .convert_index_to_llvm() |
| 72 | + .canonicalize() |
| 73 | + .cse() |
| 74 | + .Gpu( |
| 75 | + Pipeline() |
| 76 | + .strip_debuginfo() |
| 77 | + # TODO(max): upstream this (add to gpu pipeline) |
| 78 | + # vector.transfer |
| 79 | + .convert_vector_to_llvm() |
| 80 | + .convert_gpu_to_nvvm(use_bare_ptr_memref_call_conv=True) |
| 81 | + .canonicalize() |
| 82 | + .cse() |
| 83 | + .reconcile_unrealized_casts() |
| 84 | + ) |
| 85 | + .gpu_to_llvm(use_bare_pointers_for_kernels=True) |
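|  | + # format="isa" keeps the serialized GPU module as PTX text rather than a cubin, which is handy for print_ptx below |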
| 86 | + .gpu_module_to_binary(format="isa") |
| 87 | + .canonicalize() |
| 88 | + .cse() |
| 89 | + .reconcile_unrealized_casts() |
| 90 | + # .add_pass( |
| 91 | + # "gpu-lower-to-nvvm-pipeline", |
| 92 | + # # https://github.com/llvm/llvm-project/blob/ace69e6b942b8fa7e610d70be2a92e801ceea481/mlir/include/mlir/Dialect/GPU/Pipelines/Passes.h#L18 |
| 93 | + # **{ |
| 94 | + # "cubin-chip": "sm_80", |
| 95 | + # "cubin-features": "+ptx83", |
| 96 | + # "cubin-format": "isa", |
| 97 | + # "kernel-bare-ptr-calling-convention": "1", |
| 98 | + # "opt-level": "2", |
| 99 | + # # "cubin-format": "fatbin", |
| 100 | + # # "cubin-format": "bin", |
| 101 | + # }, |
| 102 | + # ) |
| 103 | + , |
65 | 104 | enable_ir_printing=enable_ir_printing, |
66 | 105 | ) |
| 106 | + |
67 | 107 | if print_ptx_: |
68 | 108 | print_ptx(mod) |
69 | 109 |
|
@@ -420,6 +460,102 @@ def sgemm_shared_mem_2d_block_tiling[ |
420 | 460 | ) |
421 | 461 |
|
422 | 462 |
|
| 463 | +@gpu.func |
| 464 | +@canonicalize(using=(arith.canonicalizer, scf.canonicalizer)) |
| 465 | +def sgemm_shared_mem_2d_block_tiling_vectorize[ |
| 466 | + M, K, N, dtype, BM, BN, BK, TM, TN |
| 467 | +](A: T.memref(M, K, dtype), B: T.memref(K, N, dtype), C: T.memref(M, N, dtype)): |
| 468 | + VECTOR_WIDTH = 4 |
| 469 | + DTYPE_WIDTH = dtype.width // 8 |
| 470 | + |
| 471 | + # ld.global.v4.u32 and st.global.v4.f32 are emitted only if the input args are aligned |
| 472 | + # cupy allocations are aligned to 512 bytes https://github.com/cupy/cupy/blob/59e6c2b2e0c722b09c7a7af13f908942ef7806cc/cupy/cuda/memory.pyx#L805-L809 |
| 473 | + # so we're good |
| 474 | + memref.assume_alignment(A, VECTOR_WIDTH * DTYPE_WIDTH) |
| 475 | + memref.assume_alignment(B, VECTOR_WIDTH * DTYPE_WIDTH) |
| 476 | + memref.assume_alignment(C, VECTOR_WIDTH * DTYPE_WIDTH) |
| 477 | + |
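|  | + # all shared memory comes out of a single dynamic-smem base pointer; A_shared and B_shared are carved out of it as views below |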
| 478 | + base = gpu.dynamic_shared_memory() |
| 479 | + base = memref.memory_space_cast(T.memref(S, element_type=T.i8()), base) |
| 480 | + |
| 481 | + # A is stored transposed in SMEM: the tile is BK x BM instead of BM x BK |
| 482 | + A_shared = memref.view(base, (BK, BM), dtype=dtype) |
| 483 | + B_shared = memref.view(base, (BK, BN), dtype=dtype, shift=BM * BK) |
| 484 | + |
| 485 | + c_row = block_idx.y * BM |
| 486 | + c_col = block_idx.x * BN |
| 487 | + |
| 488 | + tid = gpu.thread_id() |
| 489 | + # BN/TN is the number of threads needed to span a column |
| 490 | + thread_col = tid % (BN // TN) |
| 491 | + thread_row = tid / (BN // TN) |
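|  | + # each thread computes a TM x TN sub-tile of the BM x BN block tile, so a block uses (BM // TM) * (BN // TN) threads |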
| 492 | + |
| 493 | + # calculate the indices that this thread will load into SMEM |
| 494 | + # we'll load 128 bit / 32 bit = 4 elements per thread at each step |
| 495 | + inner_col_A = tid % (BK // VECTOR_WIDTH) # warp-level GMEM coalescing |
| 496 | + inner_row_A = tid / (BK // VECTOR_WIDTH) |
| 497 | + inner_col_B = tid % (BN // VECTOR_WIDTH) # warp-level GMEM coalescing |
| 498 | + inner_row_B = tid / (BN // VECTOR_WIDTH) |
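|  | + # i.e. each thread moves one VECTOR_WIDTH-element (128-bit for f32) chunk of A and one of B from GMEM to SMEM per BK step |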
| 499 | + |
| 500 | + thread_results = memref.alloca((TM, TN), dtype) |
| 501 | + linalg.fill(0, thread_results) |
| 502 | + |
| 503 | + reg_M = memref.alloca((TM,), dtype) |
| 504 | + linalg.fill(0, reg_M) |
| 505 | + |
| 506 | + reg_N = memref.alloca((TN,), dtype) |
| 507 | + linalg.fill(0, reg_N) |
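|  | + # per-thread accumulator and staging tiles; allocas this small are typically promoted to registers by ptxas |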
| 508 | + |
| 509 | + for bk_idx in range_(0, K, BK): |
| 510 | + A_ = A[c_row : c_row + BM, bk_idx : bk_idx + BK] |
| 511 | + B_ = B[bk_idx : bk_idx + BK, c_col : c_col + BN] |
| 512 | + |
| 513 | + A_vec = vector.load( |
| 514 | + T.vector(VECTOR_WIDTH, dtype), A_, [inner_row_A, inner_col_A * VECTOR_WIDTH] |
| 515 | + ) |
| 516 | + for j in range(VECTOR_WIDTH): |
| 517 | + # transpose A while loading it |
| 518 | + A_shared[inner_col_A * VECTOR_WIDTH + j, inner_row_A] = A_vec[j] |
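|  | + # storing A transposed means the dot loop below reads A_shared along contiguous rows |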
| 519 | + |
| 520 | + B_vec = vector.load( |
| 521 | + T.vector(VECTOR_WIDTH, dtype), B_, [inner_row_B, inner_col_B * VECTOR_WIDTH] |
| 522 | + ) |
| 523 | + vector.store(B_vec, B_shared, [inner_row_B, inner_col_B * VECTOR_WIDTH]) |
| 524 | + |
| 525 | + gpu.barrier() |
| 526 | + |
| 527 | + for dot_idx in range_(BK): |
| 528 | + for i in range_(TM): |
| 529 | + reg_M[i] = A_shared[dot_idx, thread_row * TM + i] |
| 530 | + |
| 531 | + for i in range_(TN): |
| 532 | + reg_N[i] = B_shared[dot_idx, thread_col * TN + i] |
| 533 | + |
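|  | + # outer product: every pair of the TM + TN staged values contributes one FMA, giving TM * TN accumulations per dot_idx |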
| 534 | + for res_idx_m in range_(TM): |
| 535 | + for res_idx_n in range_(TN): |
| 536 | + thread_results[res_idx_m, res_idx_n] += ( |
| 537 | + reg_M[res_idx_m] * reg_N[res_idx_n] |
| 538 | + ) |
| 539 | + |
| 540 | + gpu.barrier() |
| 541 | + |
| 542 | + one = arith.constant(1.0, type=dtype) |
| 543 | + C_ = C[c_row : c_row + BM, c_col : c_col + BN] |
| 544 | + |
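|  | + # epilogue: write the per-thread results back to C with 128-bit vector stores; the vector.load just materializes a vector to insert into (every lane is overwritten) |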
| 545 | + for res_idx_m in range_(TM): |
| 546 | + for res_idx_n in range_(0, TN, VECTOR_WIDTH): |
| 547 | + tmp = vector.load( |
| 548 | + T.vector(VECTOR_WIDTH, dtype), |
| 549 | + C_, |
| 550 | + [thread_row * TM + res_idx_m, thread_col * TN + res_idx_n], |
| 551 | + ) |
| 552 | + for j in range(VECTOR_WIDTH): |
| 553 | + tmp[j] = thread_results[res_idx_m, res_idx_n + j] + one |
| 554 | + vector.store( |
| 555 | + tmp, C_, [thread_row * TM + res_idx_m, thread_col * TN + res_idx_n] |
| 556 | + ) |
| 557 | + |
| 558 | + |
423 | 559 | def prepare_tiled_kernel(ctx: MLIRContext, kernel, M, K, N): |
424 | 560 | dtype = T.f32() |
425 | 561 | npy_dtype = np.float32 |
@@ -560,6 +696,7 @@ def run_eval( |
560 | 696 | for k in [ |
561 | 697 | sgemm_shared_mem_1d_block_tiling, |
562 | 698 | sgemm_shared_mem_2d_block_tiling, |
| 699 | + sgemm_shared_mem_2d_block_tiling_vectorize, |
563 | 700 | ]: |
564 | 701 | print(f"\n{k.__name__}") |
565 | 702 | for s in sizes: |
|