@@ -775,15 +775,16 @@ def prepare_warp_tiled_kernel(ctx: MLIRContext, kernel, M, K, N):
 
     gpu.set_container_module(ctx.module)
 
+    # Settings for A100 (looks like it works for 3070 too?)
     NUM_THREADS = 128
     BN = 128
-    BM = 128
+    BM = 64
     BK = 16
     WN = 64
-    WM = 64
-    WNITER = 4
+    WM = 32
+    WNITER = 1
     TN = 4
-    TM = 8
+    TM = 4
 
     @gpu.module("matmul", ["#nvvm.target"])
     def matmul_mod():
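
As a quick sanity check on the retuned tile sizes, the usual warp-tiling divisibility invariants still hold. A minimal sketch, assuming the conventional WARPSIZE/WMITER decomposition from warp-tiled SGEMM write-ups (the derived names are illustrative, not taken from this file):

```python
# Values from the hunk above (the new A100 settings).
NUM_THREADS, BN, BM, BK = 128, 128, 64, 16
WN, WM, WNITER, TN, TM = 64, 32, 1, 4, 4
WARPSIZE = 32  # threads per warp on NVIDIA GPUs

# Each warp owns a WM x WN sub-tile of the BM x BN block tile,
# so the warp grid must exactly cover the block tile.
NUM_WARPS = NUM_THREADS // WARPSIZE
assert (BM // WM) * (BN // WN) == NUM_WARPS  # 2 * 2 == 4 warps

# Each thread accumulates TM x TN micro-tiles, repeated WMITER x WNITER
# times, so a warp's 32 threads cover the warp tile exactly once.
WMITER = (WM * WN) // (WARPSIZE * TM * TN * WNITER)  # -> 4
assert WMITER * WNITER * TM * TN * WARPSIZE == WM * WN
```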
@@ -869,11 +870,11 @@ def run_eval(
 repeats = None
 
 for k in [
-    # sgemm_naive,
-    # sgemm_naive_row_order,
-    # sgemm_coalesce,
-    # sgemm_coalesce_transpose_B,
-    # sgemm_shared_mem_block,
+    sgemm_naive,
+    sgemm_naive_row_order,
+    sgemm_coalesce,
+    sgemm_coalesce_transpose_B,
+    sgemm_shared_mem_block,
 ]:
     print(f"\n{k.__name__}")
     for s in sizes:
@@ -899,9 +900,9 @@ def run_eval(
 
 
 for k in [
-    # sgemm_shared_mem_1d_block_tiling,
-    # sgemm_shared_mem_2d_block_tiling,
-    # sgemm_shared_mem_2d_block_tiling_vectorize,
+    sgemm_shared_mem_1d_block_tiling,
+    sgemm_shared_mem_2d_block_tiling,
+    sgemm_shared_mem_2d_block_tiling_vectorize,
 ]:
     print(f"\n{k.__name__}")
     for s in sizes:
@@ -925,6 +926,7 @@ def run_eval(
         transpose_B,
     )
 
+print(f"\n{sgemm_warp_tiling.__name__}")
 for s in sizes:
     with (
         mlir_mod_ctx() as ctx,
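
For context, BM and BN also fix the launch shape of the warp-tiled kernel: one thread block of NUM_THREADS computes one BM x BN tile of the output. A hedged sketch of that mapping (`ceil_div` and `launch_dims` are hypothetical helpers, not part of this commit):

```python
# Illustrative one-block-per-output-tile launch shape implied by BM/BN;
# a sketch, not code from this diff.
def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b

def launch_dims(M: int, N: int, BM: int = 64, BN: int = 128, NUM_THREADS: int = 128):
    grid = (ceil_div(N, BN), ceil_div(M, BM), 1)  # one block per BM x BN tile of C
    block = (NUM_THREADS, 1, 1)                   # flat block; warps are carved out inside
    return grid, block

print(launch_dims(4096, 4096))  # ((32, 64, 1), (128, 1, 1))
```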