Skip to content

Commit 1607990

Browse files
committed
warp tiling fast
1 parent 175959f commit 1607990

File tree

1 file changed

+14
-12
lines changed

1 file changed

+14
-12
lines changed

examples/cuda_matmul_opt.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -775,15 +775,16 @@ def prepare_warp_tiled_kernel(ctx: MLIRContext, kernel, M, K, N):
775775

776776
gpu.set_container_module(ctx.module)
777777

778+
# Settings for A100 (looks like it works for 3070 too?)
778779
NUM_THREADS = 128
779780
BN = 128
780-
BM = 128
781+
BM = 64
781782
BK = 16
782783
WN = 64
783-
WM = 64
784-
WNITER = 4
784+
WM = 32
785+
WNITER = 1
785786
TN = 4
786-
TM = 8
787+
TM = 4
787788

788789
@gpu.module("matmul", ["#nvvm.target"])
789790
def matmul_mod():
@@ -869,11 +870,11 @@ def run_eval(
869870
repeats = None
870871

871872
for k in [
872-
# sgemm_naive,
873-
# sgemm_naive_row_order,
874-
# sgemm_coalesce,
875-
# sgemm_coalesce_transpose_B,
876-
# sgemm_shared_mem_block,
873+
sgemm_naive,
874+
sgemm_naive_row_order,
875+
sgemm_coalesce,
876+
sgemm_coalesce_transpose_B,
877+
sgemm_shared_mem_block,
877878
]:
878879
print(f"\n{k.__name__}")
879880
for s in sizes:
@@ -899,9 +900,9 @@ def run_eval(
899900

900901

901902
for k in [
902-
# sgemm_shared_mem_1d_block_tiling,
903-
# sgemm_shared_mem_2d_block_tiling,
904-
# sgemm_shared_mem_2d_block_tiling_vectorize,
903+
sgemm_shared_mem_1d_block_tiling,
904+
sgemm_shared_mem_2d_block_tiling,
905+
sgemm_shared_mem_2d_block_tiling_vectorize,
905906
]:
906907
print(f"\n{k.__name__}")
907908
for s in sizes:
@@ -925,6 +926,7 @@ def run_eval(
925926
transpose_B,
926927
)
927928

929+
print(f"\n{sgemm_warp_tiling.__name__}")
928930
for s in sizes:
929931
with (
930932
mlir_mod_ctx() as ctx,

0 commit comments

Comments
 (0)