51 changes: 34 additions & 17 deletions mlir/extras/dialects/ext/gpu.py
@@ -49,43 +49,43 @@ def __get__(self, owner_self, owner_cls):
class block_idx:
@classproperty
def x(cls):
return _block_id("x")
return _block_id("x", loc=get_user_code_loc())

@classproperty
def y(cls):
return _block_id("y")
return _block_id("y", loc=get_user_code_loc())

@classproperty
def z(cls):
return _block_id("z")
return _block_id("z", loc=get_user_code_loc())


class block_dim:
@classproperty
def x(cls):
return _block_dim("x")
return _block_dim("x", loc=get_user_code_loc())

@classproperty
def y(cls):
return _block_dim("y")
return _block_dim("y", loc=get_user_code_loc())

@classproperty
def z(cls):
return _block_dim("z")
return _block_dim("z", loc=get_user_code_loc())


class thread_idx:
@classproperty
def x(cls):
return _thread_id("x")
return _thread_id("x", loc=get_user_code_loc())

@classproperty
def y(cls):
return _thread_id("y")
return _thread_id("y", loc=get_user_code_loc())

@classproperty
def z(cls):
return _thread_id("z")
return _thread_id("z", loc=get_user_code_loc())


def thread_id():
@@ -222,6 +222,8 @@ def __init__(
loc=None,
ip=None,
):
if loc is None:
loc = get_user_code_loc()
super().__init__(
function_type=function_type,
arg_attrs=arg_attrs,
@@ -301,10 +303,10 @@ def launch_(
):
if loc is None:
loc = get_user_code_loc()
for size in [grid_size, block_size]:
for i, s in enumerate(size):
if isinstance(s, int):
size[i] = constant(s, index=True)
for size in [grid_size, block_size]:
for i, s in enumerate(size):
if isinstance(s, int):
size[i] = constant(s, index=True)
launch_op = LaunchOp(
grid_size,
block_size,
@@ -371,13 +373,16 @@ def __call__(
async_dependencies=None,
dynamic_shared_memory_size: Optional[Value] = None,
stream=None,
loc=None,
ip=None,
):
for size in [grid_size, block_size]:
for i, s in enumerate(size):
if isinstance(s, int):
size[i] = constant(s, index=True)

loc = get_user_code_loc()
if loc is None:
loc = get_user_code_loc()
return get_op_result_or_op_results(
LaunchFuncOp(
(
@@ -469,6 +474,8 @@ def all_reduce__(value: Value, *, op=None, uniform=None, loc=None, ip=None):


def all_reduce_(value: Value, *, op=None, uniform=None, loc=None, ip=None):
if loc is None:
loc = get_user_code_loc()
return get_op_result_or_op_results(
all_reduce__(value, op=op, uniform=uniform, loc=loc, ip=ip)
)
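
all_reduce_ gets the same defaulting treatment. A hedged usage sketch (val is illustrative, and AllReduceOperation is assumed to be the enum exported by the upstream gpu dialect bindings):

# reduce a per-thread value across the workgroup; loc defaults to this line
total = all_reduce_(val, op=AllReduceOperation.ADD, uniform=True)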
@@ -577,15 +584,18 @@ def get_compile_object_bytes(compiled_module):
_printf = printf


def printf(format, *args):
loc = get_user_code_loc()
return _printf(format=format, args=args, loc=loc)
def printf(format, *args, loc=None, ip=None):
if loc is None:
loc = get_user_code_loc()
return _printf(format=format, args=args, loc=loc, ip=ip)


_dynamic_shared_memory = dynamic_shared_memory


def dynamic_shared_memory(*, int=False, loc=None, ip=None):
if loc is None:
loc = get_user_code_loc()
return _dynamic_shared_memory(
T.memref(
ShapedType.get_dynamic_size(),
@@ -611,3 +621,10 @@ def memset(dst, value, async_dependencies=None, *, loc=None, ip=None):
if isinstance(value, (int, float, bool)):
value = constant(value, type=dst.type.element_type)
return _memset(async_token, async_dependencies, dst, value, loc=loc, ip=ip)


def barrier(*, loc=None, ip=None):
if loc is None:
loc = get_user_code_loc()

return BarrierOp(loc=loc, ip=ip)
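
All of the gpu.py wrappers above now share one pattern: when the caller does not pass loc, it is resolved with get_user_code_loc(), so emitted ops carry the location of the user's code rather than a line inside the wrapper. A minimal usage sketch under that assumption (the kernel body and names are illustrative):

# hypothetical kernel body relying on the new defaults
from mlir.extras.dialects.ext import gpu

def kernel_body():
    # each op records the location of the calling line below,
    # not a line inside gpu.py
    gpu.printf("tid = %ld\n", gpu.thread_idx.x)
    gpu.barrier()  # new helper wrapping BarrierOp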
12 changes: 11 additions & 1 deletion mlir/extras/dialects/ext/memref.py
@@ -281,6 +281,8 @@ def _canonicalize_start_stop(start, stop, step):
elif isinstance(start, int) and isinstance(stop, int):
return stop - start

raise NotImplementedError


def _subview(
mem: MemRef,
@@ -362,6 +364,8 @@ def _copy_to_subview(


def dim(source, index, *, loc=None, ip=None):
if loc is None:
loc = get_user_code_loc()
if isinstance(index, int):
index = constant(index, index=True)
return _dim(source=source, index=index, loc=loc, ip=ip)
@@ -412,7 +416,9 @@ def global_(
).opview


def view(source, shape, dtype=None, shift=0, memory_space=None):
def view(source, shape, dtype=None, shift=0, memory_space=None, loc=None, ip=None):
if loc is None:
loc = get_user_code_loc()
if dtype is None:
dtype = source.type.element_type
byte_width_dtype = dtype.width // 8
@@ -425,6 +431,8 @@ def view(source, shape, dtype=None, shift=0, memory_space=None):
source,
byte_shift,
[],
loc=loc,
ip=ip,
)


@@ -434,6 +442,8 @@ def view(source, shape, dtype=None, shift=0, memory_space=None):
def get_global(
name_or_global, *, name=None, global_=None, result=None, loc=None, ip=None
):
if loc is None:
loc = get_user_code_loc()
if isinstance(name_or_global, GlobalOp):
global_ = name_or_global
elif isinstance(name_or_global, str):
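
A sketch of the widened view signature (buffer and type names are illustrative; T is mlir.extras.types): shift is measured in elements of dtype and scaled to a byte offset internally via dtype.width // 8, and loc/ip now flow through to the underlying view op:

# reinterpret `buf` as a 16x16 f16 tile starting 32 elements in;
# the call site's location is attached automatically
tile = view(buf, (16, 16), dtype=T.f16(), shift=32)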
12 changes: 10 additions & 2 deletions mlir/extras/dialects/ext/rocdl.py
@@ -24,6 +24,8 @@ class WMMA_F16_16X16X16_F16(ir.OpView):
_ODS_REGIONS = (0, True)

def __init__(self, res, args, *, loc=None, ip=None):
if loc is None:
loc = get_user_code_loc()
operands = []
results = []
attributes = {}
@@ -56,5 +58,11 @@ def res(self):
return self.operation.results[0]


def wmma_f16_16x16x16_f16(res, args, *, loc=None, ip=None) -> ir.Value:
return WMMA_F16_16X16X16_F16(res=res, args=args, loc=loc, ip=ip).result
def wmma_f16_16x16x16_f16(A, B, C, *, OPSEL=False, loc=None, ip=None) -> ir.Value:
if loc is None:
loc = get_user_code_loc()

opsel = arith.constant(OPSEL, ir.IntegerType.get_signless(1))
args = [A, B, C, opsel]
v16 = ir.VectorType.get((16,), ir.F16Type.get())
return WMMA_F16_16X16X16_F16(res=v16, args=args, loc=loc, ip=ip).result
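
The new wrapper absorbs the boilerplate: it materializes OPSEL as an i1 constant, appends it to the operand list, and fixes the result type to vector<16xf16>, so a call reads like the intrinsic's D = A*B + C. Before/after at a call site (mirroring the test change further down):

# before: caller built the opsel flag and result type by hand
false = arith.constant(False, T.bool())
c_frag = rocdl.wmma_f16_16x16x16_f16(v16f16, [a_frag, b_frag, c_frag, false])

# after: the wrapper supplies both
c_frag = rocdl.wmma_f16_16x16x16_f16(a_frag, b_frag, c_frag)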
17 changes: 15 additions & 2 deletions mlir/extras/dialects/ext/vector.py
@@ -251,6 +251,8 @@ def extract_strided_slice(vector, offsets, sizes, strides, *, loc=None, ip=None)


def outerproduct(lhs, rhs, acc=None, *, kind=None, loc=None, ip=None):
if loc is None:
loc = get_user_code_loc()
if kind is None:
kind = CombiningKind.ADD
result_shape = [lhs.shape[0], rhs.shape[0]]
@@ -262,6 +264,8 @@ def outer(lhs, rhs, acc=None, *, kind=None, loc=None, ip=None):

@Infix
def outer(lhs, rhs, acc=None, *, kind=None, loc=None, ip=None):
if loc is None:
loc = get_user_code_loc()
return outerproduct(lhs, rhs, acc, kind=kind, loc=loc, ip=ip)


@@ -270,14 +274,20 @@ def outer(lhs, rhs, acc=None, *, kind=None, loc=None, ip=None):

@Infix
def shuffle(v1, v2, mask, *, loc=None, ip=None):
if loc is None:
loc = get_user_code_loc()
return ShuffleOp(v1=v1, v2=v2, mask=mask, loc=loc, ip=ip).result


_load = load


@Infix
def load(base, indices, result, *, nontemporal=None, loc=None, ip=None):
def load_(base, indices, result, *, nontemporal=None, loc=None, ip=None):
if loc is None:
loc = get_user_code_loc()
for j, i in enumerate(indices):
if isinstance(i, int):
indices[j] = constant(i, index=True)
return LoadOp(
result=result,
base=base,
@@ -286,3 +296,6 @@ def load(base, indices, result, *, nontemporal=None, loc=None, ip=None):
loc=loc,
ip=ip,
).result


load = Infix(load_)
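
The net effect of the rename is that the plain function stays reachable as load_ while the exported load is the Infix-wrapped version, and integer indices are now promoted to arith index constants inside the wrapper. A usage sketch (buffer and index values are illustrative; T.vector/T.f32 are assumed to come from mlir.extras.types):

# plain ints in the indices list become index constants automatically
v = load(buf, [0, 16], T.vector(4, T.f32()))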
2 changes: 1 addition & 1 deletion mlir/extras/runtime/passes.py
@@ -31,7 +31,7 @@ def run_pipeline(
print_pipeline=False,
verify=True,
):
module = Module.parse(str(module))
module = Module.parse(module.operation.get_asm(enable_debug_info=True))

if isinstance(pipeline, Pipeline):
pipeline = str(pipeline)
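
This run_pipeline change matters for everything above: str(module) prints the IR without location attributes, so the round-trip re-parse silently dropped every loc the wrappers attach, while get_asm(enable_debug_info=True) preserves them. Schematically, the difference in the printed asm:

# str(module):
#   %c0 = arith.constant 0 : index
# module.operation.get_asm(enable_debug_info=True):
#   %c0 = arith.constant 0 : index loc("user_code.py":12:8)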
25 changes: 15 additions & 10 deletions tests/test_gpu.py
@@ -15,7 +15,7 @@
from mlir.dialects.memref import cast

from mlir.extras.ast.canonicalize import canonicalize
from mlir.extras.dialects.ext import arith, scf, memref, rocdl
from mlir.extras.dialects.ext import arith, scf, memref, rocdl, gpu
from mlir.extras.dialects.ext.func import func

# noinspection PyUnresolvedReferences
@@ -758,7 +758,7 @@ def mat_product_kernel(

props = hip.hipDeviceProp_t()
hip_check(hip.hipGetDeviceProperties(props, 0))
arch = props.gcnArchName.decode()
arch = props.gcnArchName.decode().split(":")[0]

@module("naive", [f'#rocdl.target<chip = "{arch}", abi = "500">'])
def gpu_module():
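
hipGetDeviceProperties can report the arch name with feature flags appended, which #rocdl.target does not accept as a chip name, so the tests keep only the part before the first colon (the same fix recurs in each test below):

arch = props.gcnArchName.decode().split(":")[0]
# e.g. "gfx1100:sramecc+:xnack-" -> "gfx1100"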
@@ -869,7 +869,7 @@ def mat_product_kernel(

props = hip.hipDeviceProp_t()
hip_check(hip.hipGetDeviceProperties(props, 0))
arch = props.gcnArchName.decode()
arch = props.gcnArchName.decode().split(":")[0]

@module("naive", [f'#rocdl.target<chip = "{arch}", abi = "500">'])
def gpu_module():
@@ -996,7 +996,7 @@ def smol_matmul(

props = hip.hipDeviceProp_t()
hip_check(hip.hipGetDeviceProperties(props, 0))
arch = props.gcnArchName.decode()
arch = props.gcnArchName.decode().split(":")[0]

@module("naive", [f'#rocdl.target<chip = "{arch}", abi = "500">'])
def gpu_module():
@@ -1104,7 +1104,7 @@ def all_bank_conflicts(A: T.memref(M, M, T.f32()), B: T.memref(M, M, T.f32())):

props = hip.hipDeviceProp_t()
hip_check(hip.hipGetDeviceProperties(props, 0))
arch = props.gcnArchName.decode()
arch = props.gcnArchName.decode().split(":")[0]

@module("naive", [f'#rocdl.target<chip = "{arch}", abi = "500">'])
def gpu_module():
@@ -1228,9 +1228,10 @@ def smol_matmul(
a_frag[ele] = a[lane, ele]
a_frag, b_frag = yield a_frag, b_frag

# call the WMMA intrinsic
false = arith.constant(False, T.bool())
c_frag = rocdl.wmma_f16_16x16x16_f16(v16f16, [a_frag, b_frag, c_frag, false])
c_frag = rocdl.wmma_f16_16x16x16_f16(a_frag, b_frag, c_frag)

for i in scf.range_(v_len):
gpu.printf("(%02ld, %02ld, %02ld), %f\n", lIdx, lane, i, c_frag[i])

for ele in scf.range_(v_len // 2):
r = ele * 2 + (lIdx // v_len)
@@ -1239,7 +1240,7 @@

props = hip.hipDeviceProp_t()
hip_check(hip.hipGetDeviceProperties(props, 0))
arch = props.gcnArchName.decode()
arch = props.gcnArchName.decode().split(":")[0]

@module("naive", [f'#rocdl.target<chip = "{arch}", abi = "500">'])
def gpu_module():
@@ -1250,7 +1251,11 @@ def gpu_module():
lowered_module = run_pipeline(
gpu_module,
Pipeline()
.Gpu(Pipeline().convert_gpu_to_rocdl(use_bare_ptr_memref_call_conv=True))
.Gpu(
Pipeline().convert_gpu_to_rocdl(
use_bare_ptr_memref_call_conv=True, runtime="HIP"
)
)
.rocdl_attach_target(chip=arch, abi="500")
.gpu_to_llvm()
.lower_to_llvm()
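
One more note on the pipeline change: convert-gpu-to-rocdl needs to know which runtime provides printf support, and runtime="HIP" selects the HIP hostcall implementation, which is what lets the new gpu.printf in the kernel lower. The relevant fragment, annotated:

Pipeline().Gpu(
    Pipeline().convert_gpu_to_rocdl(
        use_bare_ptr_memref_call_conv=True,
        runtime="HIP",  # required so gpu.printf has a lowering target
    )
)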