 )
 from mlir.extras.dialects.ext import arith, memref, gpu, scf
 from mlir.extras.dialects.ext.gpu import (
-    block_id,
-    thread_id,
+    block_idx,
+    thread_idx,
     block_dim,
     get_compile_object_bytes,
 )
@@ -30,13 +30,44 @@
 _ = memref


-def build_cuda_func(compiled_module, kernel_name="mat_product_kernel"):
+def build_cuda_func(compiled_module, kernel_name="naive"):
     ptx = get_compile_object_bytes(compiled_module)
     mod = Module()
     mod.load(ptx)
     return mod.get_function(kernel_name)


+def print_ptx(compiled_module):
+    ptx = get_compile_object_bytes(compiled_module)
+    print(ptx.decode())
+
+
+def compile_module(module, enable_ir_printing=False, print_ptx_=False):
+    if enable_ir_printing:
+        print_ptx_ = True
+    mod = run_pipeline(
+        module,
+        Pipeline().add_pass(
+            "gpu-lower-to-nvvm-pipeline",
+            # https://github.com/llvm/llvm-project/blob/ace69e6b942b8fa7e610d70be2a92e801ceea481/mlir/include/mlir/Dialect/GPU/Pipelines/Passes.h#L18
+            **{
+                "cubin-chip": "sm_80",
+                "cubin-features": "+ptx83",
+                "cubin-format": "isa",
+                "kernel-bare-ptr-calling-convention": "1",
+                "opt-level": "2",
+                # "cubin-format": "fatbin",
+                # "cubin-format": "bin",
+            },
+        ),
+        enable_ir_printing=enable_ir_printing,
+    )
+    if print_ptx_:
+        print_ptx(mod)
+
+    return mod
+
+
 @contextlib.contextmanager
 def time_cuda():
     start_gpu = cp.cuda.Event()
@@ -50,80 +81,254 @@ def time_cuda():

 @gpu.func
 @canonicalize(using=(arith.canonicalizer, scf.canonicalizer))
-def mat_product_kernel[
+def sgemm_naive[
+    M, K, N, dtype
+](A: T.memref(M, K, dtype), B: T.memref(K, N, dtype), C: T.memref(M, N, dtype)):
+    one = arith.constant(1.0, type=dtype)
+    tmp = arith.constant(0, type=dtype)
+
+    # this is from the example and it's basically a mistake:
+    # it increments the row for each adjacent thread id
+    # (uncomment the print to see)
+    r = block_dim.x * block_idx.x + thread_idx.x
+    c = block_dim.y * block_idx.y + thread_idx.y
+    # tid = gpu.thread_id()
+    # gpu.printf("tid: %ld: (%ld, %ld)\n", tid, r, c)
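+    # with a square thread block, thread_idx.x varies fastest, so for block (0, 0)
+    # the print shows roughly: tid: 0: (0, 0), tid: 1: (1, 0), tid: 2: (2, 0), ...
+    # i.e. adjacent threads in a warp touch different rows of A and C (strided, uncoalesced)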
+
+    for k, tmp in range_(K, iter_args=[tmp]):
+        tmp += A[r, k] * B[k, c]
+        tmp = yield tmp
+    C[r, c] = tmp + one
+
+
+@gpu.func
+@canonicalize(using=(arith.canonicalizer, scf.canonicalizer))
+def sgemm_naive_row_order[
     M, K, N, dtype
 ](A: T.memref(M, K, dtype), B: T.memref(K, N, dtype), C: T.memref(M, N, dtype)):
-    x = block_dim.x * block_id.x + thread_id.x
-    y = block_dim.y * block_id.y + thread_id.y
+    one = arith.constant(1.0, type=dtype)
+    tmp = arith.constant(0, type=dtype)
+
+    # increment along the cols (i.e. preserve row-order access)
+    c = block_dim.x * block_idx.x + thread_idx.x
+    r = block_dim.y * block_idx.y + thread_idx.y
+    # tid = gpu.thread_id()
+    # gpu.printf("tid: %ld: (%ld, %ld)\n", tid, r, c)
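+    # now for block (0, 0) the print shows roughly: tid: 0: (0, 0), tid: 1: (0, 1), tid: 2: (0, 2), ...
+    # adjacent threads stay on one row and walk consecutive columns of B and C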
+
+    for k, tmp in range_(K, iter_args=[tmp]):
+        tmp += A[r, k] * B[k, c]
+        tmp = yield tmp
+    C[r, c] = tmp + one
+
+
+@gpu.func
+@canonicalize(using=(arith.canonicalizer, scf.canonicalizer))
+def sgemm_coalesce[
+    M, K, N, dtype, BLOCK_SIZE
+](A: T.memref(M, K, dtype), B: T.memref(K, N, dtype), C: T.memref(M, N, dtype)):
+
+    tid = gpu.thread_id()
+    # `/` here is actually floordiv (integer division)
+    r = block_idx.x * BLOCK_SIZE + (tid / BLOCK_SIZE)
+    c = block_idx.y * BLOCK_SIZE + (tid % BLOCK_SIZE)
+    # gpu.printf("tid: %ld: (%ld, %ld)\n", tid, r, c)
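+    # adjacent tids now map to the same row and consecutive columns, e.g. with
+    # BLOCK_SIZE=32: tid: 0: (0, 0), tid: 1: (0, 1), ..., tid: 32: (1, 0), ...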
+
+    one = arith.constant(1.0, type=dtype)
+    tmp = arith.constant(0, type=dtype)
+
+    for k, tmp in range_(K, iter_args=[tmp]):
+        # k varies per iteration while c varies with tid; that's fine because
+        # A[r, k] is a single (broadcast) address for the whole warp and B[k, c]
+        # is consecutive across threads, so the warp's loads coalesce
+        tmp += A[r, k] * B[k, c]
+        tmp = yield tmp
+    C[r, c] = tmp + one
+
+
+# So if you try to load something like:
+#
+# B.T:
+#
+# 0 0 0 0 0 0 0 0
+# 1 1 1 1 1 1 1 1
+# 2 2 2 2 2 2 2 2
+#
+# vs
+#
+# B:
+# 0 1 2 3 4 5 6 7 8
+# 0 1 2 3 4 5 6 7 8
+# 0 1 2 3 4 5 6 7 8
+#
+# In B, you are feeding all threads with a single load (say a warp can load 8 elements at a time) and then you increment k.
+#
+# In B.T, a single load feeds only a single thread, so the others are probably waiting for their load to happen.
+# These are the loads issued by each thread:
+#
+# 0: (0, 0), (1, 0), (2, 0)
+# 1: (0, 1), (1, 1), (2, 1)
+# 2: (0, 2), (1, 2), (2, 2)
+#
+# the warp receives these issues:
+#
+# (0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2), (2, 0), (2, 1), (2, 2)
+#
+# and the warp issues coalesced reads:
+#
+# (0, 0:2), (1, 0:2), (2, 0:2)
+#
+# so even though each thread has a bad memory access pattern,
+# the warp has a good memory access pattern, and since the actual
+# load happens at warp level, it's good
+@gpu.func
+@canonicalize(using=(arith.canonicalizer, scf.canonicalizer))
+def sgemm_coalesce_transpose_B[
+    M, K, N, dtype, BLOCK_SIZE
+](A: T.memref(M, K, dtype), B: T.memref(K, N, dtype), C: T.memref(M, N, dtype)):
+
+    tid = gpu.thread_id()
+    r = block_idx.x * BLOCK_SIZE + (tid / BLOCK_SIZE)
+    c = block_idx.y * BLOCK_SIZE + (tid % BLOCK_SIZE)

     one = arith.constant(1.0, type=dtype)
     tmp = arith.constant(0, type=dtype)
+
     for k, tmp in range_(K, iter_args=[tmp]):
-        tmp += A[x, k] * B[k, y]
+        # this is slower because c is incremented with each tid,
+        # so you break memory coalescing
+        # but k now being on the row-order dim doesn't help?
+        tmp += A[r, k] * B[c, k]
+        tmp = yield tmp
+    C[r, c] = tmp + one
+
+
+@gpu.func
+@canonicalize(using=(arith.canonicalizer, scf.canonicalizer))
+def sgemm_shared_mem_block[
+    M, K, N, dtype, BLOCK_SIZE
+](A: T.memref(M, K, dtype), B: T.memref(K, N, dtype), C: T.memref(M, N, dtype)):
+    # allocate a buffer for the current block in fast shared mem;
+    # shared mem is shared between all threads in a block
+    base = gpu.dynamic_shared_memory()
+    A_shared = memref.view(base, (BLOCK_SIZE, BLOCK_SIZE), dtype=dtype)
+    B_shared = memref.view(
+        base, (BLOCK_SIZE, BLOCK_SIZE), dtype=dtype, shift=BLOCK_SIZE * BLOCK_SIZE
+    )
+
+    # the inner row & col that we're accessing in this thread
+    tid = gpu.thread_id()
+    thread_row = tid / BLOCK_SIZE
+    thread_col = tid % BLOCK_SIZE
+
+    # the output block that we want to compute in this threadblock
+    c_row = block_idx.x * BLOCK_SIZE
+    c_col = block_idx.y * BLOCK_SIZE
+
+    one = arith.constant(1.0, type=dtype)
+    tmp = arith.constant(0, type=dtype)
+
+    for bk_idx, tmp in range_(0, K, BLOCK_SIZE, iter_args=[tmp]):
+        A_ = A[c_row : c_row + BLOCK_SIZE, bk_idx : bk_idx + BLOCK_SIZE]
+        B_ = B[bk_idx : bk_idx + BLOCK_SIZE, c_col : c_col + BLOCK_SIZE]
+
+        # have each thread load one of the elements in A & B;
+        # make thread_col (= tid % BLOCK_SIZE) the consecutive index
+        # to allow global memory access coalescing
+        A_shared[thread_row, thread_col] = A_[thread_row, thread_col]
+        B_shared[thread_row, thread_col] = B_[thread_row, thread_col]
+
+        # block threads in this block until the cache is fully populated
+        gpu.barrier()
+
+        # execute the dot product on the currently cached block
+        for k, tmp in range_(BLOCK_SIZE, iter_args=[tmp]):
+            tmp += A_shared[thread_row, k] * B_shared[k, thread_col]
+            tmp = yield tmp
+
+        # need to sync again at the end, to avoid faster threads
+        # fetching the next block into the cache before slower threads are done
+        gpu.barrier()
+
         tmp = yield tmp
-    C[x, y] = tmp + one
+
+    C_ = C[c_row : c_row + BLOCK_SIZE, c_col : c_col + BLOCK_SIZE]
+    C_[thread_row, thread_col] = tmp + one


-def main(ctx: MLIRContext, M, K, N, BLOCK_SIZE=32, repeat_times=50):
+def main(ctx: MLIRContext, M, K, N, BLOCK_SIZE=32, repeat_times=None):
+    if repeat_times is None:
+        repeat_times = 50
     dtype = T.f32()
     npy_dtype = np.float32

     gpu.set_container_module(ctx.module)

-    @gpu.module("naive", ["#nvvm.target"])
-    def _():
-        mat_product_kernel[M, K, N, dtype].emit()
+    @gpu.module("matmul", ["#nvvm.target"])
+    def matmul_mod():
+        sgemm_shared_mem_block[M, K, N, dtype, BLOCK_SIZE].emit()

     # print(ctx.module)
-    ctx.module.operation.verify()
+    # print(ctx.module.operation.verify())
+    # exit()

-    compiled_module = run_pipeline(
-        ctx.module,
-        Pipeline().add_pass(
-            "gpu-lower-to-nvvm-pipeline",
-            # https://github.com/llvm/llvm-project/blob/ace69e6b942b8fa7e610d70be2a92e801ceea481/mlir/include/mlir/Dialect/GPU/Pipelines/Passes.h#L18
-            **{
-                "cubin-chip": "sm_80",
-                "cubin-features": "+ptx83",
-                "cubin-format": "isa",
-                "kernel-bare-ptr-calling-convention": "1",
-                # "cubin-format": "fatbin",
-                # "cubin-format": "bin",
-            },
-        ),
-    )
-    cuda_func = build_cuda_func(compiled_module)
-    # print(compiled_module)
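+    # grab the emitted kernel's symbol name so build_cuda_func pulls the
+    # matching function out of the loaded PTX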
+    kernel_name = matmul_mod.opview.body.operations[0].attributes["sym_name"].value
+    compiled_module = compile_module(ctx.module)
+    cuda_func = build_cuda_func(compiled_module, kernel_name)
     # print_ptx(compiled_module)

     A = np.random.randint(0, 10, (M, K)).astype(npy_dtype)
     B = np.random.randint(0, 10, (K, N)).astype(npy_dtype)
     C = np.zeros((M, N)).astype(npy_dtype)

     dA = cp.asarray(A)
-    dB = cp.asarray(B)
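+    # the transpose_B kernel indexes B as B[c, k], so hand it B transposed
+    # (made contiguous so the row-major layout matches the indexing)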
284+ if "transpose_B" in kernel_name :
285+ dB = cp .asarray (np .ascontiguousarray (B .T ))
286+ else :
287+ dB = cp .asarray (B )
105288 dC = cp .asarray (C )
106289
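+    # one thread per element of C: ceil(M/BLOCK_SIZE) x ceil(N/BLOCK_SIZE) blocks
+    # of BLOCK_SIZE x BLOCK_SIZE threads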
+    grid_dims = (math.ceil(M / BLOCK_SIZE), math.ceil(N / BLOCK_SIZE))
+    block_dims = (BLOCK_SIZE, BLOCK_SIZE)
+
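+    # the shared-memory kernel carves its dynamic shared memory into two
+    # BLOCK_SIZE x BLOCK_SIZE tiles (A_shared and B_shared), hence the factor of 2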
293+ if "shared" in kernel_name :
294+ shared_mem = 2 * BLOCK_SIZE * BLOCK_SIZE * npy_dtype ().nbytes
295+ else :
296+ shared_mem = None
297+
298+ cuda_func (
299+ grid_dims ,
300+ block_dims ,
301+ (dA .data .ptr , dB .data .ptr , dC .data .ptr ),
302+ shared_mem = shared_mem ,
303+ )
304+ C = cp .asnumpy (dC )
305+ if not np .array_equal (C , A @ B + 1 ):
306+ print (A @ B + 1 )
307+ print (C )
308+ assert False
309+ if repeat_times < 1 :
310+ return
311+
     with time_cuda() as (start_gpu, end_gpu):
         for _ in range(repeat_times):
             cuda_func(
-                (math.ceil(M / BLOCK_SIZE), math.ceil(N / BLOCK_SIZE), 1),
-                (BLOCK_SIZE, BLOCK_SIZE, 1),
+                grid_dims,
+                block_dims,
                 (dA.data.ptr, dB.data.ptr, dC.data.ptr),
+                shared_mem=shared_mem,
             )

     t_gpu = cp.cuda.get_elapsed_time(start_gpu, end_gpu)

     print(f"t_gpu={t_gpu / repeat_times:.6f} ms")

-    if not cp.array_equal(dC, dA @ dB + 1):
-        print(dA @ dB + 1)
-        print(dC)

+sizes = [128, 256, 512, 1024]
+repeats = None

-for s in [128, 256, 512, 1024]:
+for s in sizes:
     with (
         mlir_mod_ctx() as ctx,
         # enable_debug()
     ):
-        main(ctx, s, s, s)
+        main(ctx, s, s, s, repeat_times=repeats)