From 5ab83738787ab2d286dee84a2794418610a6a688 Mon Sep 17 00:00:00 2001 From: max Date: Mon, 22 Apr 2024 14:20:18 -0500 Subject: [PATCH 1/5] [wip] finish cuda opt --- examples/cuda_matmul_opt.py | 310 ++++++++++++++++++++++++++--- mlir/extras/ast/util.py | 3 + mlir/extras/dialects/ext/arith.py | 6 +- mlir/extras/dialects/ext/memref.py | 2 +- mlir/extras/dialects/ext/scf.py | 10 +- 5 files changed, 298 insertions(+), 33 deletions(-) diff --git a/examples/cuda_matmul_opt.py b/examples/cuda_matmul_opt.py index 62865e85..f74d4497 100644 --- a/examples/cuda_matmul_opt.py +++ b/examples/cuda_matmul_opt.py @@ -13,7 +13,7 @@ mlir_mod_ctx, MLIRContext, ) -from mlir.extras.dialects.ext import arith, memref, gpu, scf +from mlir.extras.dialects.ext import arith, memref, gpu, scf, linalg from mlir.extras.dialects.ext.gpu import ( block_idx, thread_idx, @@ -47,7 +47,9 @@ def compile_module(module, enable_ir_printing=False, print_ptx_=False): print_ptx_ = True mod = run_pipeline( module, - Pipeline().add_pass( + Pipeline() + .convert_linalg_to_loops() + .add_pass( "gpu-lower-to-nvvm-pipeline", # https://github.com/llvm/llvm-project/blob/ace69e6b942b8fa7e610d70be2a92e801ceea481/mlir/include/mlir/Dialect/GPU/Pipelines/Passes.h#L18 **{ @@ -124,7 +126,7 @@ def sgemm_naive_row_order[ @gpu.func @canonicalize(using=(arith.canonicalizer, scf.canonicalizer)) def sgemm_coalesce[ - M, K, N, dtype, BLOCK_SIZE + M, K, N, dtype, BLOCK_SIZE: 32 ](A: T.memref(M, K, dtype), B: T.memref(K, N, dtype), C: T.memref(M, N, dtype)): tid = gpu.thread_id() @@ -183,7 +185,7 @@ def sgemm_coalesce[ @gpu.func @canonicalize(using=(arith.canonicalizer, scf.canonicalizer)) def sgemm_coalesce_transpose_B[ - M, K, N, dtype, BLOCK_SIZE + M, K, N, dtype, BLOCK_SIZE: 32 ](A: T.memref(M, K, dtype), B: T.memref(K, N, dtype), C: T.memref(M, N, dtype)): tid = gpu.thread_id() @@ -205,7 +207,7 @@ def sgemm_coalesce_transpose_B[ @gpu.func @canonicalize(using=(arith.canonicalizer, scf.canonicalizer)) def 
sgemm_shared_mem_block[ - M, K, N, dtype, BLOCK_SIZE + M, K, N, dtype, BLOCK_SIZE: 32 ](A: T.memref(M, K, dtype), B: T.memref(K, N, dtype), C: T.memref(M, N, dtype)): # allocate buffer for current block in fast shared mem # shared mem is shared between all threads in a block @@ -224,7 +226,6 @@ def sgemm_shared_mem_block[ c_row = block_idx.x * BLOCK_SIZE c_col = block_idx.y * BLOCK_SIZE - one = arith.constant(1.0, type=dtype) tmp = arith.constant(0, type=dtype) for bk_idx, tmp in range_(0, K, BLOCK_SIZE, iter_args=[tmp]): @@ -232,7 +233,7 @@ def sgemm_shared_mem_block[ B_ = B[bk_idx : bk_idx + BLOCK_SIZE, c_col : c_col + BLOCK_SIZE] # Have each thread load one of the elements in A & B - # Make the threadCol (=threadIdx.x) the consecutive index + # Make the threadCol (=thread_idx.x) the consecutive index # to allow global memory access coalescing A_shared[thread_row, thread_col] = A_[thread_row, thread_col] B_shared[thread_row, thread_col] = B_[thread_row, thread_col] @@ -241,8 +242,8 @@ def sgemm_shared_mem_block[ gpu.barrier() # execute the dotproduct on the currently cached block - for k, tmp in range_(BLOCK_SIZE, iter_args=[tmp]): - tmp += A_shared[thread_row, k] * B_shared[k, thread_col] + for dot_idx, tmp in range_(BLOCK_SIZE, iter_args=[tmp]): + tmp += A_shared[thread_row, dot_idx] * B_shared[dot_idx, thread_col] tmp = yield tmp # need to sync again at the end, to avoid faster threads @@ -251,50 +252,250 @@ def sgemm_shared_mem_block[ tmp = yield tmp + one = arith.constant(1.0, type=dtype) C_ = C[c_row : c_row + BLOCK_SIZE, c_col : c_col + BLOCK_SIZE] C_[thread_row, thread_col] = tmp + one -def main(ctx: MLIRContext, M, K, N, BLOCK_SIZE=32, repeat_times=None): - if repeat_times is None: - repeat_times = 50 +def prepare_non_tiled_kernel(ctx: MLIRContext, kernel, M, K, N, BLOCK_SIZE=32): + dtype = T.f32() + npy_dtype = np.float32 + + gpu.set_container_module(ctx.module) + + @gpu.module("matmul", ["#nvvm.target"]) + def matmul_mod(): + kernel[M, K, N, 
dtype].emit() + + # print(ctx.module) + # print(ctx.module.operation.verify()) + # exit() + + kernel_name = kernel.__name__ + compiled_module = compile_module(ctx.module) + cuda_func = build_cuda_func(compiled_module, kernel_name) + # print_ptx(compiled_module) + + grid_dims = (math.ceil(M / BLOCK_SIZE), math.ceil(N / BLOCK_SIZE)) + block_dims = (BLOCK_SIZE, BLOCK_SIZE) + + if "shared" in kernel_name: + shared_mem = 2 * BLOCK_SIZE * BLOCK_SIZE * npy_dtype().nbytes + else: + shared_mem = 0 + + return ( + cuda_func, + grid_dims, + block_dims, + shared_mem, + npy_dtype, + "transpose_B" in kernel_name, + ) + + +@gpu.func +@canonicalize(using=(arith.canonicalizer, scf.canonicalizer)) +def sgemm_shared_mem_1d_block_tiling[ + M, K, N, dtype, BM, BN, BK, TM +](A: T.memref(M, K, dtype), B: T.memref(K, N, dtype), C: T.memref(M, N, dtype)): + base = gpu.dynamic_shared_memory() + A_shared = memref.view(base, (BM, BK), dtype=dtype) + B_shared = memref.view(base, (BK, BN), dtype=dtype, shift=BM * BK) + + c_row = block_idx.y * BM + c_col = block_idx.x * BN + + tid = gpu.thread_id() + thread_col = tid % BN + thread_row = tid / BN + + inner_col_A = tid % BK # warp-level GMEM coalescing + inner_row_A = tid / BK + inner_col_B = tid % BN # warp-level GMEM coalescing + inner_row_B = tid / BN + + thread_results = memref.alloca((TM,), dtype) + linalg.fill(0, thread_results) + + for bk_idx in range_(0, K, BK): + # Move blocktile to beginning of A's row and B's column + A_ = A[c_row : c_row + BM, bk_idx : bk_idx + BK] + B_ = B[bk_idx : bk_idx + BK, c_col : c_col + BN] + + A_shared[inner_row_A, inner_col_A] = A_[inner_row_A, inner_col_A] + B_shared[inner_row_B, inner_col_B] = B_[inner_row_B, inner_col_B] + + gpu.barrier() + + for dot_idx in range_(BK): + tmp_B = B_shared[dot_idx, thread_col] + for res_idx, tmp_B in range_(TM, iter_args=[tmp_B]): + thread_results[res_idx] += ( + A_shared[thread_row * TM + res_idx, dot_idx] * tmp_B + ) + yield tmp_B + + gpu.barrier() + + one = 
arith.constant(1.0, type=dtype) + C_ = C[c_row : c_row + BM, c_col : c_col + BN] + for res_idx in range_(TM): + C_[thread_row * TM + res_idx, thread_col] = thread_results[res_idx] + one + + +@gpu.func +@canonicalize(using=(arith.canonicalizer, scf.canonicalizer)) +def sgemm_shared_mem_2d_block_tiling[ + M, K, N, dtype, BM, BN, BK, TM, TN +](A: T.memref(M, K, dtype), B: T.memref(K, N, dtype), C: T.memref(M, N, dtype)): + base = gpu.dynamic_shared_memory() + A_shared = memref.view(base, (BM, BK), dtype=dtype) + B_shared = memref.view(base, (BK, BN), dtype=dtype, shift=BM * BK) + + c_row = block_idx.y * BM + c_col = block_idx.x * BN + + total_results_blocktile = BM * BN + num_threads_blocktile = total_results_blocktile // (TM * TN) + + tid = gpu.thread_id() + # BN/TN are the number of threads to span a column + thread_col = tid % (BN // TN) + thread_row = tid / (BN // TN) + + inner_col_A = tid % BK # warp-level GMEM coalescing + inner_row_A = tid / BK + stride_A = num_threads_blocktile // BK + + inner_col_B = tid % BN # warp-level GMEM coalescing + inner_row_B = tid / BN + stride_B = num_threads_blocktile // BN + + thread_results = memref.alloca((TM, TN), dtype) + linalg.fill(0, thread_results) + + reg_M = memref.alloca((TM,), dtype) + linalg.fill(0, reg_M) + + reg_N = memref.alloca((TN,), dtype) + linalg.fill(0, reg_N) + + for bk_idx in range_(0, K, BK): + A_ = A[c_row : c_row + BM, bk_idx : bk_idx + BK] + B_ = B[bk_idx : bk_idx + BK, c_col : c_col + BN] + + for load_offset in range_(0, BM, stride_A): + A_shared[inner_row_A + load_offset, inner_col_A] = A_[ + inner_row_A + load_offset, inner_col_A + ] + for load_offset in range_(0, BK, stride_B): + B_shared[inner_row_B + load_offset, inner_col_B] = B_[ + inner_row_B + load_offset, inner_col_B + ] + + gpu.barrier() + + for dot_idx in range_(BK): + for i in range_(TM): + reg_M[i] = A_shared[thread_row * TM + i, dot_idx] + for i in range_(TN): + reg_N[i] = B_shared[dot_idx, thread_col * TN + i] + + for res_idx_m in 
range_(TM): + for res_idx_n in range_(TN): + thread_results[res_idx_m, res_idx_n] += ( + reg_M[res_idx_m] * reg_N[res_idx_n] + ) + + gpu.barrier() + + one = arith.constant(1.0, type=dtype) + C_ = C[c_row : c_row + BM, c_col : c_col + BN] + + for res_idx_m in range_(TM): + for res_idx_n in range_(TN): + C_[thread_row * TM + res_idx_m, thread_col * TN + res_idx_n] = ( + thread_results[res_idx_m, res_idx_n] + one + ) + + +def prepare_tiled_kernel(ctx: MLIRContext, kernel, M, K, N): dtype = T.f32() npy_dtype = np.float32 + kernel_name = kernel.__name__ gpu.set_container_module(ctx.module) + BK = 8 + TM = 8 + TN = 8 + if "2d" in kernel_name and M >= 128 and N >= 128: + BM = 128 + BN = 128 + else: + BM = 64 + BN = 64 + @gpu.module("matmul", ["#nvvm.target"]) def matmul_mod(): - sgemm_shared_mem_block[M, K, N, dtype, BLOCK_SIZE].emit() + kernel[M, K, N, dtype, BM, BN, BK, TM, TN].emit() # print(ctx.module) # print(ctx.module.operation.verify()) # exit() - kernel_name = matmul_mod.opview.body.operations[0].attributes["sym_name"].value compiled_module = compile_module(ctx.module) cuda_func = build_cuda_func(compiled_module, kernel_name) # print_ptx(compiled_module) + grid_dims = (math.ceil(N / BN), math.ceil(M / BM)) + if "2d" in kernel_name: + block_dims = (BM // TM, BN // TN) + else: + block_dims = (BM // TM, BN) + + if "shared" in kernel_name: + shared_mem = ((BM * BK) + (BK * BN)) * npy_dtype().nbytes + else: + shared_mem = 0 + + return ( + cuda_func, + grid_dims, + block_dims, + shared_mem, + npy_dtype, + False, + ) + + +def run_eval( + M, + K, + N, + cuda_func, + grid_dims, + block_dims, + shared_mem, + npy_dtype, + transpose_B, + repeat_times=None, +): + if repeat_times is None: + repeat_times = 50 + A = np.random.randint(0, 10, (M, K)).astype(npy_dtype) B = np.random.randint(0, 10, (K, N)).astype(npy_dtype) C = np.zeros((M, N)).astype(npy_dtype) dA = cp.asarray(A) - if "transpose_B" in kernel_name: + if transpose_B: dB = cp.asarray(np.ascontiguousarray(B.T)) else: 
dB = cp.asarray(B) dC = cp.asarray(C) - grid_dims = (math.ceil(M / BLOCK_SIZE), math.ceil(N / BLOCK_SIZE)) - block_dims = (BLOCK_SIZE, BLOCK_SIZE) - - if "shared" in kernel_name: - shared_mem = 2 * BLOCK_SIZE * BLOCK_SIZE * npy_dtype().nbytes - else: - shared_mem = 0 - cuda_func( grid_dims, block_dims, @@ -320,15 +521,64 @@ def matmul_mod(): t_gpu = cp.cuda.get_elapsed_time(start_gpu, end_gpu) - print(f"t_gpu={t_gpu / repeat_times:.6f} ms") + print(f"t={t_gpu / repeat_times:.6f} ms") sizes = [128, 256, 512, 1024] repeats = None -for s in sizes: - with ( - mlir_mod_ctx() as ctx, - # enable_debug() - ): - main(ctx, s, s, s, repeat_times=repeats) +for k in [ + sgemm_naive, + sgemm_naive_row_order, + sgemm_coalesce, + sgemm_coalesce_transpose_B, + sgemm_shared_mem_block, +]: + print(f"\n{k.__name__}") + for s in sizes: + with ( + mlir_mod_ctx() as ctx, + # enable_debug() + ): + print(f"{s=}", end=" ") + cuda_func, grid_dims, block_dims, shared_mem, npy_dtype, transpose_B = ( + prepare_non_tiled_kernel(ctx, k, s, s, s) + ) + run_eval( + s, + s, + s, + cuda_func, + grid_dims, + block_dims, + shared_mem, + npy_dtype, + transpose_B, + ) + + +for k in [ + sgemm_shared_mem_1d_block_tiling, + sgemm_shared_mem_2d_block_tiling, +]: + print(f"\n{k.__name__}") + for s in sizes: + with ( + mlir_mod_ctx() as ctx, + # enable_debug() + ): + print(f"{s=}", end=" ") + cuda_func, grid_dims, block_dims, shared_mem, npy_dtype, transpose_B = ( + prepare_tiled_kernel(ctx, k, s, s, s) + ) + run_eval( + s, + s, + s, + cuda_func, + grid_dims, + block_dims, + shared_mem, + npy_dtype, + transpose_B, + ) diff --git a/mlir/extras/ast/util.py b/mlir/extras/ast/util.py index a0f18246..ce053add 100644 --- a/mlir/extras/ast/util.py +++ b/mlir/extras/ast/util.py @@ -143,6 +143,9 @@ def copy_func(f, new_closure: Dict = None): def append_hidden_node(node_body, new_node): last_statement = node_body[-1] + assert ( + last_statement.end_lineno is not None + ), f"last_statement {ast.unparse(last_statement)} 
must have end_lineno" new_node = ast.fix_missing_locations( set_lineno(new_node, last_statement.end_lineno) ) diff --git a/mlir/extras/dialects/ext/arith.py b/mlir/extras/dialects/ext/arith.py index 9e4fee72..95f1d954 100644 --- a/mlir/extras/dialects/ext/arith.py +++ b/mlir/extras/dialects/ext/arith.py @@ -1,4 +1,5 @@ import ast +import copy import operator from abc import abstractmethod from copy import deepcopy @@ -513,6 +514,8 @@ def visit_AugAssign( and isinstance(updated_node.value, ast.BinOp) and isinstance(updated_node.value.op, ast.Mult) ): + target = copy.deepcopy(updated_node.target) + target.ctx = ast.Load() updated_node = ast.Assign( targets=[updated_node.target], value=ast_call( @@ -520,10 +523,11 @@ def visit_AugAssign( [ updated_node.value.left, updated_node.value.right, - ast.Name(updated_node.target.id, ast.Load()), + target, ], ), ) + updated_node = ast.fix_missing_locations(updated_node) return updated_node diff --git a/mlir/extras/dialects/ext/memref.py b/mlir/extras/dialects/ext/memref.py index e5920e25..b4ba2d9e 100644 --- a/mlir/extras/dialects/ext/memref.py +++ b/mlir/extras/dialects/ext/memref.py @@ -113,7 +113,7 @@ def __getitem__(self, idx: tuple) -> "MemRef": if idx is None: return expand_shape(self, (0,), loc=loc) - idx = list((idx,) if isinstance(idx, (int, slice)) else idx) + idx = list((idx,) if isinstance(idx, (int, Scalar, slice)) else idx) for i, d in enumerate(idx): if isinstance(d, int): idx[i] = constant(d, index=True, loc=loc) diff --git a/mlir/extras/dialects/ext/scf.py b/mlir/extras/dialects/ext/scf.py index 69eacfb7..f2436b76 100644 --- a/mlir/extras/dialects/ext/scf.py +++ b/mlir/extras/dialects/ext/scf.py @@ -2,7 +2,7 @@ import logging from contextlib import contextmanager from copy import deepcopy -from typing import List +from typing import List, Union, Optional, Sequence from bytecode import ConcreteBytecode @@ -18,6 +18,7 @@ get_op_result_or_op_results, ) from ....dialects.linalg.opdsl.lang.emitter import 
_is_index_type + # gotta come first from ....dialects.scf import * from ....dialects.scf import _Dialect, yield_ as yield__ @@ -432,6 +433,7 @@ def visit_If(self, updated_node: ast.If) -> ast.If: updated_node.orelse, deepcopy(new_yield) ) + updated_node = ast.fix_missing_locations(updated_node) return updated_node def visit_For(self, updated_node: ast.For) -> ast.For: @@ -439,6 +441,7 @@ def visit_For(self, updated_node: ast.For) -> ast.For: new_yield = ast.Expr(ast.Yield(value=None)) if not is_yield(updated_node.body[-1]): updated_node.body = append_hidden_node(updated_node.body, new_yield) + updated_node = ast.fix_missing_locations(updated_node) return updated_node @@ -480,6 +483,7 @@ def visit_If(self, updated_node: ast.If) -> ast.If: if needs_forward(updated_node.orelse): updated_node.orelse = forward_yield_from_nested_if(updated_node.orelse) + updated_node = ast.fix_missing_locations(updated_node) return updated_node @@ -515,6 +519,10 @@ def visit_While(self, updated_node: ast.While) -> List[ast.AST]: ) new_test = ast.copy_location(new_test, updated_node) updated_node.test = new_test + + updated_node = ast.fix_missing_locations(updated_node) + assign = ast.fix_missing_locations(assign) + return [assign, updated_node] From add8434fd3c72f7e169abfaa9d6b486444d6e1ad Mon Sep 17 00:00:00 2001 From: max Date: Mon, 22 Apr 2024 23:23:36 -0500 Subject: [PATCH 2/5] vectors in gpu --- examples/cuda_matmul_opt.py | 159 +++++++++++++++++++++++++++-- examples/mlir_python_extras.ipynb | 18 ++-- mlir/extras/dialects/ext/gpu.py | 29 ++++-- mlir/extras/dialects/ext/memref.py | 61 ++++++++--- mlir/extras/dialects/ext/scf.py | 13 ++- mlir/extras/dialects/ext/tensor.py | 1 + mlir/extras/dialects/ext/vector.py | 68 +++++++++--- 7 files changed, 289 insertions(+), 60 deletions(-) diff --git a/examples/cuda_matmul_opt.py b/examples/cuda_matmul_opt.py index f74d4497..c283c688 100644 --- a/examples/cuda_matmul_opt.py +++ b/examples/cuda_matmul_opt.py @@ -13,13 +13,14 @@ mlir_mod_ctx, 
MLIRContext, ) -from mlir.extras.dialects.ext import arith, memref, gpu, scf, linalg +from mlir.extras.dialects.ext import arith, memref, gpu, scf, linalg, vector from mlir.extras.dialects.ext.gpu import ( block_idx, thread_idx, block_dim, get_compile_object_bytes, ) +from mlir.extras.dialects.ext.memref import S from mlir.extras.dialects.ext.scf import range_ from mlir.extras.runtime.passes import Pipeline, run_pipeline @@ -47,23 +48,62 @@ def compile_module(module, enable_ir_printing=False, print_ptx_=False): print_ptx_ = True mod = run_pipeline( module, + # if you're not using vectors you can just uncomment the gpu-lower-to-nvvm-pipeline below Pipeline() .convert_linalg_to_loops() + .convert_nvgpu_to_nvvm() + .gpu_kernel_outlining() + .convert_vector_to_scf() + .convert_scf_to_cf() + .convert_nvvm_to_llvm() + .convert_func_to_llvm() + .expand_strided_metadata() .add_pass( - "gpu-lower-to-nvvm-pipeline", - # https://github.com/llvm/llvm-project/blob/ace69e6b942b8fa7e610d70be2a92e801ceea481/mlir/include/mlir/Dialect/GPU/Pipelines/Passes.h#L18 + "nvvm-attach-target", **{ - "cubin-chip": "sm_80", - "cubin-features": "+ptx83", - "cubin-format": "isa", - "kernel-bare-ptr-calling-convention": "1", - "opt-level": "2", - # "cubin-format": "fatbin", - # "cubin-format": "bin", + "chip": "sm_80", + "features": "+ptx83", + "O": "2", }, - ), + ) + .lower_affine() + .convert_arith_to_llvm() + .convert_index_to_llvm() + .canonicalize() + .cse() + .Gpu( + Pipeline() + .strip_debuginfo() + # TODO(max): upstream this (add to gpu pipeline) + # vector.transfer + .convert_vector_to_llvm() + .convert_gpu_to_nvvm(use_bare_ptr_memref_call_conv=True) + .canonicalize() + .cse() + .reconcile_unrealized_casts() + ) + .gpu_to_llvm(use_bare_pointers_for_kernels=True) + .gpu_module_to_binary(format="isa") + .canonicalize() + .cse() + .reconcile_unrealized_casts() + # .add_pass( + # "gpu-lower-to-nvvm-pipeline", + # # 
https://github.com/llvm/llvm-project/blob/ace69e6b942b8fa7e610d70be2a92e801ceea481/mlir/include/mlir/Dialect/GPU/Pipelines/Passes.h#L18 + # **{ + # "cubin-chip": "sm_80", + # "cubin-features": "+ptx83", + # "cubin-format": "isa", + # "kernel-bare-ptr-calling-convention": "1", + # "opt-level": "2", + # # "cubin-format": "fatbin", + # # "cubin-format": "bin", + # }, + # ) + , enable_ir_printing=enable_ir_printing, ) + if print_ptx_: print_ptx(mod) @@ -420,6 +460,102 @@ def sgemm_shared_mem_2d_block_tiling[ ) +@gpu.func +@canonicalize(using=(arith.canonicalizer, scf.canonicalizer)) +def sgemm_shared_mem_2d_block_tiling_vectorize[ + M, K, N, dtype, BM, BN, BK, TM, TN +](A: T.memref(M, K, dtype), B: T.memref(K, N, dtype), C: T.memref(M, N, dtype)): + VECTOR_WIDTH = 4 + DTYPE_WIDTH = dtype.width // 8 + + # ld.global.v4.u32 and st.global.v4.f32 emitted only input args are aligned + # alignment for cupy is 512 bytes https://github.com/cupy/cupy/blob/59e6c2b2e0c722b09c7a7af13f908942ef7806cc/cupy/cuda/memory.pyx#L805-L809 + # so we're good + memref.assume_alignment(A, VECTOR_WIDTH * DTYPE_WIDTH) + memref.assume_alignment(B, VECTOR_WIDTH * DTYPE_WIDTH) + memref.assume_alignment(C, VECTOR_WIDTH * DTYPE_WIDTH) + + base = gpu.dynamic_shared_memory() + base = memref.memory_space_cast(T.memref(S, element_type=T.i8()), base) + + # transpose A + A_shared = memref.view(base, (BK, BM), dtype=dtype) + B_shared = memref.view(base, (BK, BN), dtype=dtype, shift=BM * BK) + + c_row = block_idx.y * BM + c_col = block_idx.x * BN + + tid = gpu.thread_id() + # BN/TN are the number of threads to span a column + thread_col = tid % (BN // TN) + thread_row = tid / (BN // TN) + + # calculating the indices that this thread will load into SMEM + # we'll load 128bit / 32bit = 4 elements per thread at each step + inner_col_A = tid % (BK // VECTOR_WIDTH) # warp-level GMEM coalescing + inner_row_A = tid / (BK // VECTOR_WIDTH) + inner_col_B = tid % (BN // VECTOR_WIDTH) # warp-level GMEM coalescing + 
inner_row_B = tid / (BN // VECTOR_WIDTH) + + thread_results = memref.alloca((TM, TN), dtype) + linalg.fill(0, thread_results) + + reg_M = memref.alloca((TM,), dtype) + linalg.fill(0, reg_M) + + reg_N = memref.alloca((TN,), dtype) + linalg.fill(0, reg_N) + + for bk_idx in range_(0, K, BK): + A_ = A[c_row : c_row + BM, bk_idx : bk_idx + BK] + B_ = B[bk_idx : bk_idx + BK, c_col : c_col + BN] + + A_vec = vector.load( + T.vector(VECTOR_WIDTH, dtype), A_, [inner_row_A, inner_col_A * VECTOR_WIDTH] + ) + for j in range(VECTOR_WIDTH): + # transpose A while loading it + A_shared[inner_col_A * VECTOR_WIDTH + j, inner_row_A] = A_vec[j] + + B_vec = vector.load( + T.vector(VECTOR_WIDTH, dtype), B_, [inner_row_B, inner_col_B * VECTOR_WIDTH] + ) + vector.store(B_vec, B_shared, [inner_row_B, inner_col_B * VECTOR_WIDTH]) + + gpu.barrier() + + for dot_idx in range_(BK): + for i in range_(TM): + reg_M[i] = A_shared[dot_idx, thread_row * TM + i] + + for i in range_(TN): + reg_N[i] = B_shared[dot_idx, thread_col * TN + i] + + for res_idx_m in range_(TM): + for res_idx_n in range_(TN): + thread_results[res_idx_m, res_idx_n] += ( + reg_M[res_idx_m] * reg_N[res_idx_n] + ) + + gpu.barrier() + + one = arith.constant(1.0, type=dtype) + C_ = C[c_row : c_row + BM, c_col : c_col + BN] + + for res_idx_m in range_(TM): + for res_idx_n in range_(0, TN, VECTOR_WIDTH): + tmp = vector.load( + T.vector(VECTOR_WIDTH, dtype), + C_, + [thread_row * TM + res_idx_m, thread_col * TN + res_idx_n], + ) + for j in range(VECTOR_WIDTH): + tmp[j] = thread_results[res_idx_m, res_idx_n + j] + one + vector.store( + tmp, C_, [thread_row * TM + res_idx_m, thread_col * TN + res_idx_n] + ) + + def prepare_tiled_kernel(ctx: MLIRContext, kernel, M, K, N): dtype = T.f32() npy_dtype = np.float32 @@ -560,6 +696,7 @@ def run_eval( for k in [ sgemm_shared_mem_1d_block_tiling, sgemm_shared_mem_2d_block_tiling, + sgemm_shared_mem_2d_block_tiling_vectorize, ]: print(f"\n{k.__name__}") for s in sizes: diff --git 
a/examples/mlir_python_extras.ipynb b/examples/mlir_python_extras.ipynb index 939834c7..c372938f 100644 --- a/examples/mlir_python_extras.ipynb +++ b/examples/mlir_python_extras.ipynb @@ -56,7 +56,7 @@ "from mlir.extras.dialects.ext.arith import constant\n", "from mlir.extras.dialects.ext.memref import S\n", "from mlir.extras.dialects.ext.func import func\n", - "from mlir.extras.dialects.ext.scf import canonicalizer as scf, range_ as range\n", + "from mlir.extras.dialects.ext.scf import canonicalizer as scf, range_\n", "from mlir.extras.runtime.passes import Pipeline, run_pipeline\n", "from mlir.extras.runtime.refbackend import LLVMJITBackend\n", "from mlir.ir import StridedLayoutAttr\n", @@ -102,8 +102,8 @@ " if one > two:\n", " C[0, 0] = constant(3, T.i64())\n", " else:\n", - " for i in range(0, K):\n", - " for j in range(0, K):\n", + " for i in range_(0, K):\n", + " for j in range_(0, K):\n", " C[i, j] = A[i, j] * B[i, j]" ] }, @@ -457,8 +457,8 @@ "def tile(\n", " A: ranked_memref_dxd_f32, B: ranked_memref_dxd_f32, C: ranked_memref_dxd_f32\n", "):\n", - " for i in range(0, D):\n", - " for j in range(0, D):\n", + " for i in range_(0, D):\n", + " for j in range_(0, D):\n", " C[i, j] = A[i, j] + B[i, j]\n", "\n", "@func(emit=True)\n", @@ -466,8 +466,8 @@ "def tiled_memfoo(\n", " A: ranked_memref_kxk_f32, B: ranked_memref_kxk_f32, C: ranked_memref_kxk_f32\n", "):\n", - " for i in range(0, F):\n", - " for j in range(0, F):\n", + " for i in range_(0, F):\n", + " for j in range_(0, F):\n", " l = lambda l: l * D\n", " r = lambda r: (r + 1) * D\n", " a, b, c = (\n", @@ -797,8 +797,8 @@ "def linalg_memfoo(\n", " A: ranked_memref_kxk_f32, B: ranked_memref_kxk_f32, C: ranked_memref_kxk_f32\n", "):\n", - " for i in range(0, F):\n", - " for j in range(0, F):\n", + " for i in range_(0, F):\n", + " for j in range_(0, F):\n", " l = lambda l: l * D\n", " r = lambda r: (r + 1) * D\n", " a, b, c = (\n", diff --git a/mlir/extras/dialects/ext/gpu.py b/mlir/extras/dialects/ext/gpu.py 
index 05f2ae89..8d394170 100644 --- a/mlir/extras/dialects/ext/gpu.py +++ b/mlir/extras/dialects/ext/gpu.py @@ -2,6 +2,8 @@ from functools import partial from typing import Any, List, Optional, Tuple, Union +from mlir.dialects._gpu_enum_gen import AddressSpace + from .arith import constant from .func import FuncBase from ... import types as T @@ -117,32 +119,39 @@ def get_device_mapping_array_attr( return ArrayAttr.get(mapping, context=context) -def device_mapping_attr(mnemonic, mapping_id_enum: MappingId): +def gpu_attr(mnemonic, mapping_id_enum: MappingId): return Attribute.parse(f"#gpu.{mnemonic}<{mapping_id_enum}>") def thread_attr(thread): - return device_mapping_attr("thread", thread) + return gpu_attr("thread", thread) def block_attr(block): - return device_mapping_attr("block", block) + return gpu_attr("block", block) def warp_attr(warp): - return device_mapping_attr("warp", warp) + return gpu_attr("warp", warp) def warpgroup_attr(warpgroup): - return device_mapping_attr("warpgroup", warpgroup) + return gpu_attr("warpgroup", warpgroup) def address_space_attr(address_space: AddressSpace): - return device_mapping_attr("address_space", address_space) + return gpu_attr("address_space", address_space) + + +_int = int + +def smem_space(int=False): + a = AddressSpace.Workgroup + if int: + return _int(a) -def smem_space(): - return address_space_attr(AddressSpace.Workgroup) + return address_space_attr(a) @_cext.register_operation(_Dialect, replace=True) @@ -577,12 +586,12 @@ def printf(format, *args): _dynamic_shared_memory = dynamic_shared_memory -def dynamic_shared_memory(*, loc=None, ip=None): +def dynamic_shared_memory(*, int=False, loc=None, ip=None): return _dynamic_shared_memory( T.memref( ShapedType.get_dynamic_size(), element_type=T.i8(), - memory_space=smem_space(), + memory_space=smem_space(int), ), loc=loc, ip=ip, diff --git a/mlir/extras/dialects/ext/memref.py b/mlir/extras/dialects/ext/memref.py index b4ba2d9e..b091dd68 100644 --- 
a/mlir/extras/dialects/ext/memref.py +++ b/mlir/extras/dialects/ext/memref.py @@ -33,6 +33,7 @@ def _alloc( sizes: Sequence[Union[int, Value]], element_type: Type, memory_space=None, + alignment=None, loc=None, ip=None, ): @@ -52,21 +53,56 @@ def _alloc( symbol_operands = [] return get_op_result_or_op_results( - op_ctor(result_type, dynamic_sizes, symbol_operands, loc=loc, ip=ip) + op_ctor( + result_type, + dynamic_sizes, + symbol_operands, + alignment=alignment, + loc=loc, + ip=ip, + ) ) -def alloc(sizes: Union[int, Value], element_type: Type = None, memory_space=None): - loc = get_user_code_loc() +def alloc( + sizes: Union[int, Value], + element_type: Type = None, + memory_space=None, + alignment=None, + loc=None, + ip=None, +): + if loc is None: + loc = get_user_code_loc() return _alloc( - AllocOp, sizes, element_type, memory_space=memory_space, loc=loc, ip=None + AllocOp, + sizes, + element_type, + memory_space=memory_space, + alignment=alignment, + loc=loc, + ip=ip, ) -def alloca(sizes: Union[int, Value], element_type: Type = None, memory_space=None): - loc = get_user_code_loc() +def alloca( + sizes: Union[int, Value], + element_type: Type = None, + memory_space=None, + alignment=None, + loc=None, + ip=None, +): + if loc is None: + loc = get_user_code_loc() return _alloc( - AllocaOp, sizes, element_type, memory_space=memory_space, loc=loc, ip=None + AllocaOp, + sizes, + element_type, + memory_space=memory_space, + alignment=alignment, + loc=loc, + ip=ip, ) @@ -115,6 +151,7 @@ def __getitem__(self, idx: tuple) -> "MemRef": idx = list((idx,) if isinstance(idx, (int, Scalar, slice)) else idx) for i, d in enumerate(idx): + # TODO(max): rethink this since subview and etc probably take constant attributes? 
if isinstance(d, int): idx[i] = constant(d, index=True, loc=loc) @@ -123,7 +160,7 @@ def __getitem__(self, idx: tuple) -> "MemRef": else: return _subview(self, tuple(idx), loc=loc) - def __setitem__(self, idx, source): + def __setitem__(self, idx, val): loc = get_user_code_loc() if not self.has_rank(): @@ -135,12 +172,10 @@ def __setitem__(self, idx, source): idx[i] = constant(d, index=True, loc=loc) if all(isinstance(d, Scalar) for d in idx) and len(idx) == len(self.shape): - assert isinstance( - source, Scalar - ), "coordinate insert requires scalar element" - store(source, self, idx, loc=loc) + assert isinstance(val, Scalar), "coordinate insert requires scalar element" + store(val, self, idx, loc=loc) else: - _copy_to_subview(self, source, tuple(idx), loc=loc) + _copy_to_subview(self, val, tuple(idx), loc=loc) def expand_shape( diff --git a/mlir/extras/dialects/ext/scf.py b/mlir/extras/dialects/ext/scf.py index f2436b76..962de593 100644 --- a/mlir/extras/dialects/ext/scf.py +++ b/mlir/extras/dialects/ext/scf.py @@ -437,11 +437,14 @@ def visit_If(self, updated_node: ast.If) -> ast.If: return updated_node def visit_For(self, updated_node: ast.For) -> ast.For: - updated_node = self.generic_visit(updated_node) - new_yield = ast.Expr(ast.Yield(value=None)) - if not is_yield(updated_node.body[-1]): - updated_node.body = append_hidden_node(updated_node.body, new_yield) - updated_node = ast.fix_missing_locations(updated_node) + # TODO(max): this isn't robust at all... 
+ line = ast.dump(updated_node.iter.func) + if "range_" in line or "for_" in line: + updated_node = self.generic_visit(updated_node) + new_yield = ast.Expr(ast.Yield(value=None)) + if not is_yield(updated_node.body[-1]): + updated_node.body = append_hidden_node(updated_node.body, new_yield) + updated_node = ast.fix_missing_locations(updated_node) return updated_node diff --git a/mlir/extras/dialects/ext/tensor.py b/mlir/extras/dialects/ext/tensor.py index ecfd55c7..7b05c1fd 100644 --- a/mlir/extras/dialects/ext/tensor.py +++ b/mlir/extras/dialects/ext/tensor.py @@ -102,6 +102,7 @@ def insert_slice( ) +# TODO(max): unify vector/memref/tensor @register_value_caster(RankedTensorType.static_typeid) class Tensor(ShapedValue, ArithValue): def __getitem__(self, idx: tuple) -> "Tensor": diff --git a/mlir/extras/dialects/ext/vector.py b/mlir/extras/dialects/ext/vector.py index 3291a889..8a498117 100644 --- a/mlir/extras/dialects/ext/vector.py +++ b/mlir/extras/dialects/ext/vector.py @@ -1,20 +1,46 @@ +import inspect from typing import List from ._shaped_value import ShapedValue -from .arith import ArithValue, FastMathFlags, constant -from ...util import get_user_code_loc +from .arith import ArithValue, FastMathFlags, constant, Scalar +from ...util import get_user_code_loc, _update_caller_vars from ...._mlir_libs._mlir import register_value_caster from ....dialects._ods_common import _dispatch_mixed_values # noinspection PyUnresolvedReferences from ....dialects.vector import * from ....extras import types as T -from ....ir import AffineMap, VectorType +from ....ir import AffineMap, VectorType, Value @register_value_caster(VectorType.static_typeid) class Vector(ShapedValue, ArithValue): - pass + def __getitem__(self, idx: tuple) -> "Vector": + loc = get_user_code_loc() + + if not self.has_rank(): + raise ValueError("only ranked memref slicing/indexing supported") + + if idx == Ellipsis or idx == slice(None): + return self + if isinstance(idx, tuple) and all(i == slice(None) 
for i in idx): + return self + if idx is None: + raise RuntimeError("None idx not supported") + + idx = list((idx,) if isinstance(idx, (int, Scalar, slice)) else idx) + return extract(self, tuple(idx), loc=loc) + + def __setitem__(self, idx, val): + loc = get_user_code_loc() + + if not self.has_rank(): + raise ValueError("only ranked memref slicing/indexing supported") + + idx = list((idx,) if isinstance(idx, (Scalar, int, Value)) else idx) + res = insert(self, val, idx, loc=loc) + previous_frame = inspect.currentframe().f_back + _update_caller_vars(previous_frame, [self], [res]) _transfer_write = transfer_write @@ -22,28 +48,27 @@ class Vector(ShapedValue, ArithValue): def transfer_write( vector: Vector, - source, + dest, indices, *, permutation_map=None, mask: List[int] = None, in_bounds: List[bool] = None, loc=None, - ip=None + ip=None, ): if loc is None: loc = get_user_code_loc() if permutation_map is None: - permutation_map = AffineMap.get_minor_identity( - source.type.rank, vector.type.rank - ) + permutation_map = AffineMap.get_minor_identity(dest.type.rank, vector.type.rank) for j, i in enumerate(indices): if isinstance(i, int): indices[j] = constant(i, index=True) return _transfer_write( result=None, vector=vector, - source=source, + # no clue why they chose this name... 
+ source=dest, indices=indices, permutation_map=permutation_map, mask=mask, @@ -66,7 +91,7 @@ def transfer_read( mask=None, in_bounds=None, loc=None, - ip=None + ip=None, ): if loc is None: loc = get_user_code_loc() @@ -111,6 +136,25 @@ def extract(vector, position, *, loc=None, ip=None): ) +_insert = insert + + +def insert(vector, val, position, *, loc=None, ip=None): + if loc is None: + loc = get_user_code_loc() + dynamic_position, _packed_position, static_position = _dispatch_mixed_values( + position + ) + return _insert( + val, + dest=vector, + dynamic_position=dynamic_position, + static_position=static_position, + loc=loc, + ip=ip, + ) + + _reduction = reduction @@ -121,7 +165,7 @@ def reduction( acc=None, fastmath: FastMathFlags = None, loc=None, - ip=None + ip=None, ): if loc is None: loc = get_user_code_loc() From d25b03e63a074064f73ade8bcc2f91f8c31ff162 Mon Sep 17 00:00:00 2001 From: max Date: Tue, 23 Apr 2024 12:46:16 -0500 Subject: [PATCH 3/5] warp tiling not working --- examples/cuda_matmul_opt.py | 242 ++++++++++++++++++++++++++++++- mlir/extras/dialects/ext/gpu.py | 16 ++ mlir/extras/dialects/ext/llvm.py | 1 + 3 files changed, 251 insertions(+), 8 deletions(-) diff --git a/examples/cuda_matmul_opt.py b/examples/cuda_matmul_opt.py index c283c688..0becbd5c 100644 --- a/examples/cuda_matmul_opt.py +++ b/examples/cuda_matmul_opt.py @@ -556,6 +556,168 @@ def sgemm_shared_mem_2d_block_tiling_vectorize[ ) +WARP_SIZE = 32 + + +@gpu.func +@canonicalize(using=(arith.canonicalizer, scf.canonicalizer)) +def sgemm_warp_tiling[ + M, K, N, dtype, BM, BN, BK, WM, WN, WNITER, TM, TN, NUM_THREADS +](A: T.memref(M, K, dtype), B: T.memref(K, N, dtype), C: T.memref(M, N, dtype)): + VECTOR_WIDTH = 4 + DTYPE_WIDTH = dtype.width // 8 + + tid = gpu.thread_id() + + # ld.global.v4.u32 and st.global.v4.f32 are emitted only if the input args are aligned + # alignment for cupy is 512 bytes 
https://github.com/cupy/cupy/blob/59e6c2b2e0c722b09c7a7af13f908942ef7806cc/cupy/cuda/memory.pyx#L805-L809 + # so we're good + memref.assume_alignment(A, VECTOR_WIDTH * DTYPE_WIDTH) + memref.assume_alignment(B, VECTOR_WIDTH * DTYPE_WIDTH) + memref.assume_alignment(C, VECTOR_WIDTH * DTYPE_WIDTH) + + base = gpu.dynamic_shared_memory() + base = memref.memory_space_cast(T.memref(S, element_type=T.i8()), base) + + # transpose A + A_shared = memref.view(base, (BK, BM), dtype=dtype) + B_shared = memref.view(base, (BK, BN), dtype=dtype, shift=BM * BK) + + c_row = block_idx.y * BM + c_col = block_idx.x * BN + + # Placement of the warp in the threadblock tile + warp_idx = tid / WARP_SIZE + warp_row = warp_idx / (BN // WN) + warp_col = warp_idx % (BN // WN) + + # size of the warp subtile + WMITER = (WM * WN) // (WARP_SIZE * TM * TN * WNITER) + WSUBM = WM // WMITER + WSUBN = WN // WNITER + + # Placement of the thread in the warp subtile + thread_idx_in_warp = tid % WARP_SIZE + thread_col_in_warp = thread_idx_in_warp % (WSUBN // TN) + thread_row_in_warp = thread_idx_in_warp / (WSUBN // TN) + + # calculating the indices that this thread will load into SMEM + # we'll load 128bit / 32bit = 4 elements per thread at each step + inner_row_A = tid / (BK // VECTOR_WIDTH) + inner_col_A = tid % (BK // VECTOR_WIDTH) + row_stride_A = (NUM_THREADS * VECTOR_WIDTH) // BK + inner_row_B = tid / (BN // VECTOR_WIDTH) + inner_col_B = tid % (BN // VECTOR_WIDTH) + row_stride_B = NUM_THREADS // (BN // VECTOR_WIDTH) + + # allocate thread-local cache for results in registerfile + thread_results = memref.alloca((WMITER * TM, WNITER * TN), dtype) + linalg.fill(0, thread_results) + + reg_M = memref.alloca((WMITER, TM), dtype) + linalg.fill(0, reg_M) + + reg_N = memref.alloca((WNITER, TN), dtype) + linalg.fill(0, reg_N) + + for bk_idx in range_(0, K, BK): + A_ = A[c_row : c_row + BM, bk_idx : bk_idx + BK] + B_ = B[bk_idx : bk_idx + BK, c_col : c_col + BN] + + for offset in range(0, BM - row_stride_A + 1, 
row_stride_A): + A_vec = vector.load( + T.vector(VECTOR_WIDTH, dtype), + A_, + [inner_row_A + offset, inner_col_A * VECTOR_WIDTH], + ) + for j in range(VECTOR_WIDTH): + # transpose A while loading it + A_shared[inner_col_A * VECTOR_WIDTH + j, inner_row_A + offset] = A_vec[ + j + ] + + for offset in range(0, BK - row_stride_B + 1, row_stride_B): + B_vec = vector.load( + T.vector(VECTOR_WIDTH, dtype), + B_, + [inner_row_B + offset, inner_col_B * VECTOR_WIDTH], + ) + vector.store( + B_vec, B_shared, [inner_row_B + offset, inner_col_B * VECTOR_WIDTH] + ) + + gpu.barrier() + + for dot_idx in range_(BK): + for w_sub_row_idx in range_(WMITER): + for i in range_(TM): + reg_M[w_sub_row_idx, i] = A_shared[ + dot_idx, + warp_row * WM + + w_sub_row_idx * WSUBM + + thread_row_in_warp * TM + + i, + ] + + for w_sub_col_idx in range_(WNITER): + for i in range_(TN): + reg_N[w_sub_col_idx, i] = A_shared[ + dot_idx, + warp_col * WN + + w_sub_col_idx * WSUBN + + thread_col_in_warp * TN + + i, + ] + + for w_sub_row_idx in range_(WMITER): + for w_sub_col_idx in range_(WNITER): + for res_idx_m in range_(TM): + for res_idx_n in range_(TN): + thread_results[ + w_sub_row_idx * TM + res_idx_m, + w_sub_col_idx * TN + res_idx_n, + ] += ( + reg_M[w_sub_row_idx, res_idx_m] + * reg_N[w_sub_col_idx, res_idx_n] + ) + + gpu.barrier() + + one = arith.constant(1.0, type=dtype) + + for w_sub_row_idx in range_(WMITER): + for w_sub_col_idx in range_(WNITER): + r = c_row + warp_row * WM + w_sub_row_idx * WSUBM + c = c_col + warp_col * WN + w_sub_col_idx * WSUBN + C_ = C[r : r + WSUBM, c : c + WSUBN] + for res_idx_m in range_(TM): + for res_idx_n in range_(0, TN, VECTOR_WIDTH): + tmp = vector.load( + T.vector(VECTOR_WIDTH, dtype), + C_, + [ + thread_row_in_warp * TM + res_idx_m, + thread_col_in_warp * TN + res_idx_n, + ], + ) + for j in range(VECTOR_WIDTH): + tmp[j] = ( + thread_results[ + w_sub_row_idx * TM + res_idx_m, + w_sub_col_idx * TN + res_idx_n + j, + ] + + one + ) + vector.store( + tmp, + C_, + [ 
+ thread_row_in_warp * TM + res_idx_m, + thread_col_in_warp * TN + res_idx_n, + ], + ) + + def prepare_tiled_kernel(ctx: MLIRContext, kernel, M, K, N): dtype = T.f32() npy_dtype = np.float32 @@ -606,6 +768,49 @@ def matmul_mod(): ) +def prepare_warp_tiled_kernel(ctx: MLIRContext, kernel, M, K, N): + dtype = T.f32() + npy_dtype = np.float32 + kernel_name = kernel.__name__ + + gpu.set_container_module(ctx.module) + + NUM_THREADS = 128 + BN = 128 + BM = 128 + BK = 16 + WN = 64 + WM = 64 + WNITER = 4 + TN = 4 + TM = 8 + + @gpu.module("matmul", ["#nvvm.target"]) + def matmul_mod(): + kernel[M, K, N, dtype, BM, BN, BK, WM, WN, WNITER, TM, TN, NUM_THREADS].emit() + + # print(ctx.module) + # print(ctx.module.operation.verify()) + # exit() + + compiled_module = compile_module(ctx.module) + cuda_func = build_cuda_func(compiled_module, kernel_name) + # print_ptx(compiled_module) + + grid_dims = (math.ceil(N / BN), math.ceil(M / BM)) + block_dims = (NUM_THREADS,) + shared_mem = ((BM * BK) + (BK * BN)) * npy_dtype().nbytes + + return ( + cuda_func, + grid_dims, + block_dims, + shared_mem, + npy_dtype, + False, + ) + + def run_eval( M, K, @@ -664,11 +869,11 @@ def run_eval( repeats = None for k in [ - sgemm_naive, - sgemm_naive_row_order, - sgemm_coalesce, - sgemm_coalesce_transpose_B, - sgemm_shared_mem_block, + # sgemm_naive, + # sgemm_naive_row_order, + # sgemm_coalesce, + # sgemm_coalesce_transpose_B, + # sgemm_shared_mem_block, ]: print(f"\n{k.__name__}") for s in sizes: @@ -694,9 +899,9 @@ def run_eval( for k in [ - sgemm_shared_mem_1d_block_tiling, - sgemm_shared_mem_2d_block_tiling, - sgemm_shared_mem_2d_block_tiling_vectorize, + # sgemm_shared_mem_1d_block_tiling, + # sgemm_shared_mem_2d_block_tiling, + # sgemm_shared_mem_2d_block_tiling_vectorize, ]: print(f"\n{k.__name__}") for s in sizes: @@ -719,3 +924,24 @@ def run_eval( npy_dtype, transpose_B, ) + +for s in sizes: + with ( + mlir_mod_ctx() as ctx, + # enable_debug() + ): + print(f"{s=}", end=" ") + cuda_func, 
grid_dims, block_dims, shared_mem, npy_dtype, transpose_B = ( + prepare_warp_tiled_kernel(ctx, sgemm_warp_tiling, s, s, s) + ) + run_eval( + s, + s, + s, + cuda_func, + grid_dims, + block_dims, + shared_mem, + npy_dtype, + transpose_B, + ) diff --git a/mlir/extras/dialects/ext/gpu.py b/mlir/extras/dialects/ext/gpu.py index 8d394170..6d601bf7 100644 --- a/mlir/extras/dialects/ext/gpu.py +++ b/mlir/extras/dialects/ext/gpu.py @@ -596,3 +596,19 @@ def dynamic_shared_memory(*, int=False, loc=None, ip=None): loc=loc, ip=ip, ) + + +_memset = memset + + +def memset(dst, value, async_dependencies=None, *, loc=None, ip=None): + if loc is None: + loc = get_user_code_loc() + if async_dependencies is None: + async_dependencies = [] + async_token = None + if len(async_dependencies): + async_token = gpu_async_token() + if isinstance(value, (int, float, bool)): + value = constant(value, type=dst.type.element_type) + return _memset(async_token, async_dependencies, dst, value, loc=loc, ip=ip) diff --git a/mlir/extras/dialects/ext/llvm.py b/mlir/extras/dialects/ext/llvm.py index 62a5ad4d..cf85d94e 100644 --- a/mlir/extras/dialects/ext/llvm.py +++ b/mlir/extras/dialects/ext/llvm.py @@ -5,3 +5,4 @@ def llvm_ptr_t(): return Type.parse("!llvm.ptr") + From 175959f55ba375817a4e9688134d83aa5c3c6ca2 Mon Sep 17 00:00:00 2001 From: max Date: Tue, 23 Apr 2024 12:51:01 -0500 Subject: [PATCH 4/5] warp tiling NA but slow --- examples/cuda_matmul_opt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/cuda_matmul_opt.py b/examples/cuda_matmul_opt.py index 0becbd5c..9baaab2b 100644 --- a/examples/cuda_matmul_opt.py +++ b/examples/cuda_matmul_opt.py @@ -661,7 +661,7 @@ def sgemm_warp_tiling[ for w_sub_col_idx in range_(WNITER): for i in range_(TN): - reg_N[w_sub_col_idx, i] = A_shared[ + reg_N[w_sub_col_idx, i] = B_shared[ dot_idx, warp_col * WN + w_sub_col_idx * WSUBN From 8d63ca4042c2ca28637d5c48fa1eaf312699efa8 Mon Sep 17 00:00:00 2001 From: max Date: Tue, 23 Apr 2024 
12:53:57 -0500 Subject: [PATCH 5/5] warp tiling fast --- examples/cuda_matmul_opt.py | 26 ++++++++++++++------------ mlir/extras/ast/canonicalize.py | 2 +- 2 files changed, 15 insertions(+), 13 deletions(-) diff --git a/examples/cuda_matmul_opt.py b/examples/cuda_matmul_opt.py index 9baaab2b..fd100486 100644 --- a/examples/cuda_matmul_opt.py +++ b/examples/cuda_matmul_opt.py @@ -775,15 +775,16 @@ def prepare_warp_tiled_kernel(ctx: MLIRContext, kernel, M, K, N): gpu.set_container_module(ctx.module) + # Settings for A100 (looks like it works for 3070 too?) NUM_THREADS = 128 BN = 128 - BM = 128 + BM = 64 BK = 16 WN = 64 - WM = 64 - WNITER = 4 + WM = 32 + WNITER = 1 TN = 4 - TM = 8 + TM = 4 @gpu.module("matmul", ["#nvvm.target"]) def matmul_mod(): @@ -869,11 +870,11 @@ def run_eval( repeats = None for k in [ - # sgemm_naive, - # sgemm_naive_row_order, - # sgemm_coalesce, - # sgemm_coalesce_transpose_B, - # sgemm_shared_mem_block, + sgemm_naive, + sgemm_naive_row_order, + sgemm_coalesce, + sgemm_coalesce_transpose_B, + sgemm_shared_mem_block, ]: print(f"\n{k.__name__}") for s in sizes: @@ -899,9 +900,9 @@ def run_eval( for k in [ - # sgemm_shared_mem_1d_block_tiling, - # sgemm_shared_mem_2d_block_tiling, - # sgemm_shared_mem_2d_block_tiling_vectorize, + sgemm_shared_mem_1d_block_tiling, + sgemm_shared_mem_2d_block_tiling, + sgemm_shared_mem_2d_block_tiling_vectorize, ]: print(f"\n{k.__name__}") for s in sizes: @@ -925,6 +926,7 @@ def run_eval( transpose_B, ) +print(f"\n{sgemm_warp_tiling.__name__}") for s in sizes: with ( mlir_mod_ctx() as ctx, diff --git a/mlir/extras/ast/canonicalize.py b/mlir/extras/ast/canonicalize.py index 3cdcf4dd..638786b6 100644 --- a/mlir/extras/ast/canonicalize.py +++ b/mlir/extras/ast/canonicalize.py @@ -116,7 +116,7 @@ def transform_ast( max([l for _, l in line_starts]) - min([l for _, l in line_starts]) + 1 > n_lines ) or (f.__code__.co_firstlineno != min([l for _, l in line_starts])): - warnings.warn( + logger.debug( "something went wrong 
with the line numbers for the rewritten/canonicalized function" ) f.__code__ = new_f_code_o