Skip to content

Commit add8434

Browse files
committed
vectors in gpu
1 parent 5ab8373 commit add8434

File tree

7 files changed

+289
-60
lines changed

7 files changed

+289
-60
lines changed

examples/cuda_matmul_opt.py

Lines changed: 148 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,14 @@
1313
mlir_mod_ctx,
1414
MLIRContext,
1515
)
16-
from mlir.extras.dialects.ext import arith, memref, gpu, scf, linalg
16+
from mlir.extras.dialects.ext import arith, memref, gpu, scf, linalg, vector
1717
from mlir.extras.dialects.ext.gpu import (
1818
block_idx,
1919
thread_idx,
2020
block_dim,
2121
get_compile_object_bytes,
2222
)
23+
from mlir.extras.dialects.ext.memref import S
2324
from mlir.extras.dialects.ext.scf import range_
2425
from mlir.extras.runtime.passes import Pipeline, run_pipeline
2526

@@ -47,23 +48,62 @@ def compile_module(module, enable_ir_printing=False, print_ptx_=False):
4748
print_ptx_ = True
4849
mod = run_pipeline(
4950
module,
51+
# if you're not using vectors you can just uncomment the gpu-lower-to-nvvm-pipeline below
5052
Pipeline()
5153
.convert_linalg_to_loops()
54+
.convert_nvgpu_to_nvvm()
55+
.gpu_kernel_outlining()
56+
.convert_vector_to_scf()
57+
.convert_scf_to_cf()
58+
.convert_nvvm_to_llvm()
59+
.convert_func_to_llvm()
60+
.expand_strided_metadata()
5261
.add_pass(
53-
"gpu-lower-to-nvvm-pipeline",
54-
# https://github.com/llvm/llvm-project/blob/ace69e6b942b8fa7e610d70be2a92e801ceea481/mlir/include/mlir/Dialect/GPU/Pipelines/Passes.h#L18
62+
"nvvm-attach-target",
5563
**{
56-
"cubin-chip": "sm_80",
57-
"cubin-features": "+ptx83",
58-
"cubin-format": "isa",
59-
"kernel-bare-ptr-calling-convention": "1",
60-
"opt-level": "2",
61-
# "cubin-format": "fatbin",
62-
# "cubin-format": "bin",
64+
"chip": "sm_80",
65+
"features": "+ptx83",
66+
"O": "2",
6367
},
64-
),
68+
)
69+
.lower_affine()
70+
.convert_arith_to_llvm()
71+
.convert_index_to_llvm()
72+
.canonicalize()
73+
.cse()
74+
.Gpu(
75+
Pipeline()
76+
.strip_debuginfo()
77+
# TODO(max): upstream this (add to gpu pipeline)
78+
# vector.transfer
79+
.convert_vector_to_llvm()
80+
.convert_gpu_to_nvvm(use_bare_ptr_memref_call_conv=True)
81+
.canonicalize()
82+
.cse()
83+
.reconcile_unrealized_casts()
84+
)
85+
.gpu_to_llvm(use_bare_pointers_for_kernels=True)
86+
.gpu_module_to_binary(format="isa")
87+
.canonicalize()
88+
.cse()
89+
.reconcile_unrealized_casts()
90+
# .add_pass(
91+
# "gpu-lower-to-nvvm-pipeline",
92+
# # https://github.com/llvm/llvm-project/blob/ace69e6b942b8fa7e610d70be2a92e801ceea481/mlir/include/mlir/Dialect/GPU/Pipelines/Passes.h#L18
93+
# **{
94+
# "cubin-chip": "sm_80",
95+
# "cubin-features": "+ptx83",
96+
# "cubin-format": "isa",
97+
# "kernel-bare-ptr-calling-convention": "1",
98+
# "opt-level": "2",
99+
# # "cubin-format": "fatbin",
100+
# # "cubin-format": "bin",
101+
# },
102+
# )
103+
,
65104
enable_ir_printing=enable_ir_printing,
66105
)
106+
67107
if print_ptx_:
68108
print_ptx(mod)
69109

@@ -420,6 +460,102 @@ def sgemm_shared_mem_2d_block_tiling[
420460
)
421461

422462

463+
@gpu.func
464+
@canonicalize(using=(arith.canonicalizer, scf.canonicalizer))
465+
def sgemm_shared_mem_2d_block_tiling_vectorize[
466+
M, K, N, dtype, BM, BN, BK, TM, TN
467+
](A: T.memref(M, K, dtype), B: T.memref(K, N, dtype), C: T.memref(M, N, dtype)):
468+
VECTOR_WIDTH = 4
469+
DTYPE_WIDTH = dtype.width // 8
470+
471+
# ld.global.v4.u32 and st.global.v4.f32 emitted only input args are aligned
472+
# alignment for cupy is 512 bytes https://github.com/cupy/cupy/blob/59e6c2b2e0c722b09c7a7af13f908942ef7806cc/cupy/cuda/memory.pyx#L805-L809
473+
# so we're good
474+
memref.assume_alignment(A, VECTOR_WIDTH * DTYPE_WIDTH)
475+
memref.assume_alignment(B, VECTOR_WIDTH * DTYPE_WIDTH)
476+
memref.assume_alignment(C, VECTOR_WIDTH * DTYPE_WIDTH)
477+
478+
base = gpu.dynamic_shared_memory()
479+
base = memref.memory_space_cast(T.memref(S, element_type=T.i8()), base)
480+
481+
# transpose A
482+
A_shared = memref.view(base, (BK, BM), dtype=dtype)
483+
B_shared = memref.view(base, (BK, BN), dtype=dtype, shift=BM * BK)
484+
485+
c_row = block_idx.y * BM
486+
c_col = block_idx.x * BN
487+
488+
tid = gpu.thread_id()
489+
# BN/TN are the number of threads to span a column
490+
thread_col = tid % (BN // TN)
491+
thread_row = tid / (BN // TN)
492+
493+
# calculating the indices that this thread will load into SMEM
494+
# we'll load 128bit / 32bit = 4 elements per thread at each step
495+
inner_col_A = tid % (BK // VECTOR_WIDTH) # warp-level GMEM coalescing
496+
inner_row_A = tid / (BK // VECTOR_WIDTH)
497+
inner_col_B = tid % (BN // VECTOR_WIDTH) # warp-level GMEM coalescing
498+
inner_row_B = tid / (BN // VECTOR_WIDTH)
499+
500+
thread_results = memref.alloca((TM, TN), dtype)
501+
linalg.fill(0, thread_results)
502+
503+
reg_M = memref.alloca((TM,), dtype)
504+
linalg.fill(0, reg_M)
505+
506+
reg_N = memref.alloca((TN,), dtype)
507+
linalg.fill(0, reg_N)
508+
509+
for bk_idx in range_(0, K, BK):
510+
A_ = A[c_row : c_row + BM, bk_idx : bk_idx + BK]
511+
B_ = B[bk_idx : bk_idx + BK, c_col : c_col + BN]
512+
513+
A_vec = vector.load(
514+
T.vector(VECTOR_WIDTH, dtype), A_, [inner_row_A, inner_col_A * VECTOR_WIDTH]
515+
)
516+
for j in range(VECTOR_WIDTH):
517+
# transpose A while loading it
518+
A_shared[inner_col_A * VECTOR_WIDTH + j, inner_row_A] = A_vec[j]
519+
520+
B_vec = vector.load(
521+
T.vector(VECTOR_WIDTH, dtype), B_, [inner_row_B, inner_col_B * VECTOR_WIDTH]
522+
)
523+
vector.store(B_vec, B_shared, [inner_row_B, inner_col_B * VECTOR_WIDTH])
524+
525+
gpu.barrier()
526+
527+
for dot_idx in range_(BK):
528+
for i in range_(TM):
529+
reg_M[i] = A_shared[dot_idx, thread_row * TM + i]
530+
531+
for i in range_(TN):
532+
reg_N[i] = B_shared[dot_idx, thread_col * TN + i]
533+
534+
for res_idx_m in range_(TM):
535+
for res_idx_n in range_(TN):
536+
thread_results[res_idx_m, res_idx_n] += (
537+
reg_M[res_idx_m] * reg_N[res_idx_n]
538+
)
539+
540+
gpu.barrier()
541+
542+
one = arith.constant(1.0, type=dtype)
543+
C_ = C[c_row : c_row + BM, c_col : c_col + BN]
544+
545+
for res_idx_m in range_(TM):
546+
for res_idx_n in range_(0, TN, VECTOR_WIDTH):
547+
tmp = vector.load(
548+
T.vector(VECTOR_WIDTH, dtype),
549+
C_,
550+
[thread_row * TM + res_idx_m, thread_col * TN + res_idx_n],
551+
)
552+
for j in range(VECTOR_WIDTH):
553+
tmp[j] = thread_results[res_idx_m, res_idx_n + j] + one
554+
vector.store(
555+
tmp, C_, [thread_row * TM + res_idx_m, thread_col * TN + res_idx_n]
556+
)
557+
558+
423559
def prepare_tiled_kernel(ctx: MLIRContext, kernel, M, K, N):
424560
dtype = T.f32()
425561
npy_dtype = np.float32
@@ -560,6 +696,7 @@ def run_eval(
560696
for k in [
561697
sgemm_shared_mem_1d_block_tiling,
562698
sgemm_shared_mem_2d_block_tiling,
699+
sgemm_shared_mem_2d_block_tiling_vectorize,
563700
]:
564701
print(f"\n{k.__name__}")
565702
for s in sizes:

examples/mlir_python_extras.ipynb

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@
5656
"from mlir.extras.dialects.ext.arith import constant\n",
5757
"from mlir.extras.dialects.ext.memref import S\n",
5858
"from mlir.extras.dialects.ext.func import func\n",
59-
"from mlir.extras.dialects.ext.scf import canonicalizer as scf, range_ as range\n",
59+
"from mlir.extras.dialects.ext.scf import canonicalizer as scf, range_\n",
6060
"from mlir.extras.runtime.passes import Pipeline, run_pipeline\n",
6161
"from mlir.extras.runtime.refbackend import LLVMJITBackend\n",
6262
"from mlir.ir import StridedLayoutAttr\n",
@@ -102,8 +102,8 @@
102102
" if one > two:\n",
103103
" C[0, 0] = constant(3, T.i64())\n",
104104
" else:\n",
105-
" for i in range(0, K):\n",
106-
" for j in range(0, K):\n",
105+
" for i in range_(0, K):\n",
106+
" for j in range_(0, K):\n",
107107
" C[i, j] = A[i, j] * B[i, j]"
108108
]
109109
},
@@ -457,17 +457,17 @@
457457
"def tile(\n",
458458
" A: ranked_memref_dxd_f32, B: ranked_memref_dxd_f32, C: ranked_memref_dxd_f32\n",
459459
"):\n",
460-
" for i in range(0, D):\n",
461-
" for j in range(0, D):\n",
460+
" for i in range_(0, D):\n",
461+
" for j in range_(0, D):\n",
462462
" C[i, j] = A[i, j] + B[i, j]\n",
463463
"\n",
464464
"@func(emit=True)\n",
465465
"@canonicalize(using=scf)\n",
466466
"def tiled_memfoo(\n",
467467
" A: ranked_memref_kxk_f32, B: ranked_memref_kxk_f32, C: ranked_memref_kxk_f32\n",
468468
"):\n",
469-
" for i in range(0, F):\n",
470-
" for j in range(0, F):\n",
469+
" for i in range_(0, F):\n",
470+
" for j in range_(0, F):\n",
471471
" l = lambda l: l * D\n",
472472
" r = lambda r: (r + 1) * D\n",
473473
" a, b, c = (\n",
@@ -797,8 +797,8 @@
797797
"def linalg_memfoo(\n",
798798
" A: ranked_memref_kxk_f32, B: ranked_memref_kxk_f32, C: ranked_memref_kxk_f32\n",
799799
"):\n",
800-
" for i in range(0, F):\n",
801-
" for j in range(0, F):\n",
800+
" for i in range_(0, F):\n",
801+
" for j in range_(0, F):\n",
802802
" l = lambda l: l * D\n",
803803
" r = lambda r: (r + 1) * D\n",
804804
" a, b, c = (\n",

mlir/extras/dialects/ext/gpu.py

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
from functools import partial
33
from typing import Any, List, Optional, Tuple, Union
44

5+
from mlir.dialects._gpu_enum_gen import AddressSpace
6+
57
from .arith import constant
68
from .func import FuncBase
79
from ... import types as T
@@ -117,32 +119,39 @@ def get_device_mapping_array_attr(
117119
return ArrayAttr.get(mapping, context=context)
118120

119121

120-
def device_mapping_attr(mnemonic, mapping_id_enum: MappingId):
122+
def gpu_attr(mnemonic, mapping_id_enum: MappingId):
121123
return Attribute.parse(f"#gpu.{mnemonic}<{mapping_id_enum}>")
122124

123125

124126
def thread_attr(thread):
125-
return device_mapping_attr("thread", thread)
127+
return gpu_attr("thread", thread)
126128

127129

128130
def block_attr(block):
129-
return device_mapping_attr("block", block)
131+
return gpu_attr("block", block)
130132

131133

132134
def warp_attr(warp):
133-
return device_mapping_attr("warp", warp)
135+
return gpu_attr("warp", warp)
134136

135137

136138
def warpgroup_attr(warpgroup):
137-
return device_mapping_attr("warpgroup", warpgroup)
139+
return gpu_attr("warpgroup", warpgroup)
138140

139141

140142
def address_space_attr(address_space: AddressSpace):
141-
return device_mapping_attr("address_space", address_space)
143+
return gpu_attr("address_space", address_space)
144+
145+
146+
_int = int
147+
142148

149+
def smem_space(int=False):
150+
a = AddressSpace.Workgroup
151+
if int:
152+
return _int(a)
143153

144-
def smem_space():
145-
return address_space_attr(AddressSpace.Workgroup)
154+
return address_space_attr(a)
146155

147156

148157
@_cext.register_operation(_Dialect, replace=True)
@@ -577,12 +586,12 @@ def printf(format, *args):
577586
_dynamic_shared_memory = dynamic_shared_memory
578587

579588

580-
def dynamic_shared_memory(*, loc=None, ip=None):
589+
def dynamic_shared_memory(*, int=False, loc=None, ip=None):
581590
return _dynamic_shared_memory(
582591
T.memref(
583592
ShapedType.get_dynamic_size(),
584593
element_type=T.i8(),
585-
memory_space=smem_space(),
594+
memory_space=smem_space(int),
586595
),
587596
loc=loc,
588597
ip=ip,

0 commit comments

Comments
 (0)