1- import ctypes
21import platform
2+ import random
33import sys
44import tempfile
5+ import time
56from textwrap import dedent
67
78import mlir .extras .types as T
4041
4142# noinspection PyUnresolvedReferences
4243from mlir .extras .testing import mlir_ctx as ctx , filecheck , MLIRContext
43- from util import hip_bindings_not_installed , hip_check , launch_kernel
44+ from util import hip_bindings_not_installed , hip_check , launch_kernel , hip_synchronize
4445
# needed since the "ctx" fixture isn't defined in this file or in conftest.py
4647pytest .mark .usefixtures ("ctx" )
@@ -962,6 +963,7 @@ def test_amdgpu_vector(ctx: MLIRContext):
962963
963964 scale = 2
964965 M , K , N = 2 * scale , 4 * scale , 6 * scale
966+ tz_a , tz_b , tz_c = [2 , 2 , 2 ]
965967 v2f32 = T .vector (2 , T .f32 ())
966968
967969 @gpu_func
@@ -972,11 +974,11 @@ def smol_matmul(
972974 ):
973975 cst = arith .constant (np .full ([4 ], 0.0 , np .float32 ), T .vector (4 , T .f32 ()))
974976 cst_0 = arith .constant (
975- np .full ([2 , 2 ], 0.0 , np .float32 ), T .vector (2 , 2 , T .f32 ())
977+ np .full ([tz_a , tz_b ], 0.0 , np .float32 ), T .vector (tz_a , tz_b , T .f32 ())
976978 )
977- for i , C , v0 in scf .range_ (0 , M , 2 , iter_args = [C ]):
978- for j , C , v1 in scf .range_ (0 , N , 2 , iter_args = [C ]):
979- for k , C , v2 in scf .range_ (0 , K , 2 , iter_args = [C ]):
979+ for i , C , v0 in scf .range_ (0 , M , tz_a , iter_args = [C ]):
980+ for j , C , v1 in scf .range_ (0 , N , tz_b , iter_args = [C ]):
981+ for k , C , v2 in scf .range_ (0 , K , tz_c , iter_args = [C ]):
980982 cst [0 ::1 ] = A @ load (v2f32 ) @ [i , k ]
981983 cst [2 ::1 ] = A @ load (v2f32 ) @ [i + 1 , k ]
982984 cst_0 [0 ] = C @ load (v2f32 ) @ [i , j ]
@@ -1078,3 +1080,116 @@ def gpu_module():
10781080 hip_check (hip .hipFree (c_d ))
10791081
10801082 hip_check (hip .hipModuleUnload (hip_module ))
1083+
1084+
@pytest.mark.skipif(hip_bindings_not_installed(), reason="hip not installed")
def test_amdgpu_bank_conflicts(ctx: MLIRContext):
    """Compile two GPU kernels, time repeated launches, and print the mean per kernel.

    Both kernels read ``A[i, thread_idx.x]`` and store its square into ``B``;
    they differ only in the index order of the store (``B[i, tx]`` versus
    ``B[tx, i]``).  Each kernel is launched ``runs`` times in a randomly
    shuffled order.  The first launch of each kernel is treated as a warm-up
    and excluded from the reported mean, which is printed in milliseconds.

    Args:
        ctx: MLIR context fixture whose module receives the emitted GPU module.
    """
    from hip import hip

    set_container_module(ctx.module)

    M = 1024

    @gpu_func
    def no_bank_conflicts(A: T.memref(M, M, T.f32()), B: T.memref(M, M, T.f32())):
        # Load and store use the same (i, thread_idx.x) index order.
        for i in range(M):
            a = A[i, thread_idx.x]
            B[i, thread_idx.x] = a * a

    @gpu_func
    def all_bank_conflicts(A: T.memref(M, M, T.f32()), B: T.memref(M, M, T.f32())):
        # Same load, but the store swaps the indices: B[thread_idx.x, i].
        for i in range(M):
            a = A[i, thread_idx.x]
            B[thread_idx.x, i] = a * a

    # Target whatever architecture device 0 reports.
    props = hip.hipDeviceProp_t()
    hip_check(hip.hipGetDeviceProperties(props, 0))
    arch = props.gcnArchName.decode()

    @module("naive", [f'#rocdl.target<chip = "{arch}">'])
    def gpu_module():
        no_bank_conflicts.emit()
        all_bank_conflicts.emit()

    lowered_module = run_pipeline(
        gpu_module,
        Pipeline()
        .Gpu(Pipeline().convert_gpu_to_rocdl(use_bare_ptr_memref_call_conv=True))
        .rocdl_attach_target(chip=arch)
        .gpu_to_llvm()
        .lower_to_llvm()
        .gpu_module_to_binary(),
    )

    hsaco = get_compile_object_bytes(lowered_module)
    hip_module = hip_check(hip.hipModuleLoadData(hsaco))

    # Every row of A is 0..M-1; B starts zeroed.
    a_h = np.arange(M).astype(dtype=np.float32)
    a_h = np.tile(a_h, (M, 1))
    b_h = np.zeros((M, M), dtype=np.float32)

    a_num_bytes = a_h.size * a_h.itemsize
    b_num_bytes = b_h.size * b_h.itemsize

    # Launch configuration, passed positionally to launch_kernel below.
    # NOTE(review): warp_size/num_warps presumably act as block dims x/y --
    # confirm against util.launch_kernel.
    gridX = max(M // 32, 1)
    gridY = max(M // 8, 1)
    gridZ = 1
    warp_size = 32
    num_warps = 8
    stream = 0
    shared_memory = 0

    runs = 10
    # Iteration 0 of every kernel is a warm-up and is not accumulated.
    timed_runs = runs - 1

    a_d = b_d = None
    try:
        a_d = hip_check(hip.hipMalloc(a_num_bytes))
        b_d = hip_check(hip.hipMalloc(b_num_bytes))

        # Resolve the kernel handles once instead of on every launch.
        functions = {
            kernel.__name__: hip_check(
                hip.hipModuleGetFunction(hip_module, kernel.__name__.encode())
            )
            for kernel in (no_bank_conflicts, all_bank_conflicts)
        }
        times = {name: 0.0 for name in functions}

        for i in range(runs):
            # Randomize launch order so neither kernel always runs first.
            order = list(functions)
            random.shuffle(order)
            for name in order:
                hip_check(
                    hip.hipMemcpy(
                        a_d, a_h, a_num_bytes, hip.hipMemcpyKind.hipMemcpyHostToDevice
                    )
                )
                hip_check(
                    hip.hipMemcpy(
                        b_d, b_h, b_num_bytes, hip.hipMemcpyKind.hipMemcpyHostToDevice
                    )
                )

                start = time.monotonic()
                launch_kernel(
                    functions[name].as_c_void_p(),
                    gridX,
                    gridY,
                    gridZ,
                    warp_size,
                    num_warps,
                    stream,
                    shared_memory,
                    a_d,
                    b_d,
                )
                # Synchronize before reading the clock so the interval covers
                # the launch plus the wait in hip_synchronize().
                hip_synchronize()
                if i > 0:
                    times[name] += time.monotonic() - start

                hip_check(
                    hip.hipMemcpy(
                        b_h, b_d, b_num_bytes, hip.hipMemcpyKind.hipMemcpyDeviceToHost
                    )
                )

        for name, total_seconds in times.items():
            # time.monotonic() reports seconds; average over the timed runs
            # only and convert to milliseconds to match the printed unit.
            mean_ms = total_seconds / timed_runs * 1e3
            print(f"{name}: {mean_ms:.3e}ms")
    finally:
        # Release device resources even if a launch or copy failed.
        if b_d is not None:
            hip_check(hip.hipFree(b_d))
        if a_d is not None:
            hip_check(hip.hipFree(a_d))
        hip_check(hip.hipModuleUnload(hip_module))
0 commit comments