Skip to content

Commit 8d63ca4

Browse files
committed
warp tiling fast
1 parent 175959f commit 8d63ca4

File tree

2 files changed

+15
-13
lines changed

2 files changed

+15
-13
lines changed

examples/cuda_matmul_opt.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -775,15 +775,16 @@ def prepare_warp_tiled_kernel(ctx: MLIRContext, kernel, M, K, N):
775775

776776
gpu.set_container_module(ctx.module)
777777

778+
# Settings for A100 (looks like it works for 3070 too?)
778779
NUM_THREADS = 128
779780
BN = 128
780-
BM = 128
781+
BM = 64
781782
BK = 16
782783
WN = 64
783-
WM = 64
784-
WNITER = 4
784+
WM = 32
785+
WNITER = 1
785786
TN = 4
786-
TM = 8
787+
TM = 4
787788

788789
@gpu.module("matmul", ["#nvvm.target"])
789790
def matmul_mod():
@@ -869,11 +870,11 @@ def run_eval(
869870
repeats = None
870871

871872
for k in [
872-
# sgemm_naive,
873-
# sgemm_naive_row_order,
874-
# sgemm_coalesce,
875-
# sgemm_coalesce_transpose_B,
876-
# sgemm_shared_mem_block,
873+
sgemm_naive,
874+
sgemm_naive_row_order,
875+
sgemm_coalesce,
876+
sgemm_coalesce_transpose_B,
877+
sgemm_shared_mem_block,
877878
]:
878879
print(f"\n{k.__name__}")
879880
for s in sizes:
@@ -899,9 +900,9 @@ def run_eval(
899900

900901

901902
for k in [
902-
# sgemm_shared_mem_1d_block_tiling,
903-
# sgemm_shared_mem_2d_block_tiling,
904-
# sgemm_shared_mem_2d_block_tiling_vectorize,
903+
sgemm_shared_mem_1d_block_tiling,
904+
sgemm_shared_mem_2d_block_tiling,
905+
sgemm_shared_mem_2d_block_tiling_vectorize,
905906
]:
906907
print(f"\n{k.__name__}")
907908
for s in sizes:
@@ -925,6 +926,7 @@ def run_eval(
925926
transpose_B,
926927
)
927928

929+
print(f"\n{sgemm_warp_tiling.__name__}")
928930
for s in sizes:
929931
with (
930932
mlir_mod_ctx() as ctx,

mlir/extras/ast/canonicalize.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ def transform_ast(
116116
max([l for _, l in line_starts]) - min([l for _, l in line_starts]) + 1
117117
> n_lines
118118
) or (f.__code__.co_firstlineno != min([l for _, l in line_starts])):
119-
warnings.warn(
119+
logger.debug(
120120
"something went wrong with the line numbers for the rewritten/canonicalized function"
121121
)
122122
f.__code__ = new_f_code_o

0 commit comments

Comments
 (0)