diff --git a/mlir/extras/dialects/ext/gpu.py b/mlir/extras/dialects/ext/gpu.py
index 176fe87..9b12348 100644
--- a/mlir/extras/dialects/ext/gpu.py
+++ b/mlir/extras/dialects/ext/gpu.py
@@ -49,43 +49,43 @@ def __get__(self, owner_self, owner_cls):
 class block_idx:
     @classproperty
     def x(cls):
-        return _block_id("x")
+        return _block_id("x", loc=get_user_code_loc())
 
     @classproperty
     def y(cls):
-        return _block_id("y")
+        return _block_id("y", loc=get_user_code_loc())
 
     @classproperty
     def z(cls):
-        return _block_id("z")
+        return _block_id("z", loc=get_user_code_loc())
 
 
 class block_dim:
     @classproperty
     def x(cls):
-        return _block_dim("x")
+        return _block_dim("x", loc=get_user_code_loc())
 
     @classproperty
     def y(cls):
-        return _block_dim("y")
+        return _block_dim("y", loc=get_user_code_loc())
 
     @classproperty
     def z(cls):
-        return _block_dim("z")
+        return _block_dim("z", loc=get_user_code_loc())
 
 
 class thread_idx:
     @classproperty
     def x(cls):
-        return _thread_id("x")
+        return _thread_id("x", loc=get_user_code_loc())
 
     @classproperty
     def y(cls):
-        return _thread_id("y")
+        return _thread_id("y", loc=get_user_code_loc())
 
     @classproperty
     def z(cls):
-        return _thread_id("z")
+        return _thread_id("z", loc=get_user_code_loc())
 
 
 def thread_id():
@@ -222,6 +222,8 @@ def __init__(
         loc=None,
         ip=None,
     ):
+        if loc is None:
+            loc = get_user_code_loc()
         super().__init__(
             function_type=function_type,
             arg_attrs=arg_attrs,
@@ -301,10 +303,10 @@ def launch_(
 ):
     if loc is None:
         loc = get_user_code_loc()
-        for size in [grid_size, block_size]:
-            for i, s in enumerate(size):
-                if isinstance(s, int):
-                    size[i] = constant(s, index=True)
+    for size in [grid_size, block_size]:
+        for i, s in enumerate(size):
+            if isinstance(s, int):
+                size[i] = constant(s, index=True)
     launch_op = LaunchOp(
         grid_size,
         block_size,
@@ -371,13 +373,16 @@ def __call__(
         async_dependencies=None,
         dynamic_shared_memory_size: Optional[Value] = None,
         stream=None,
+        loc=None,
+        ip=None,
     ):
         for size in [grid_size, block_size]:
             for i, s in enumerate(size):
                 if isinstance(s, int):
                     size[i] = constant(s, index=True)
 
-        loc = get_user_code_loc()
+        if loc is None:
+            loc = get_user_code_loc()
         return get_op_result_or_op_results(
             LaunchFuncOp(
                 (
@@ -469,6 +474,8 @@ def all_reduce__(value: Value, *, op=None, uniform=None, loc=None, ip=None):
 
 
 def all_reduce_(value: Value, *, op=None, uniform=None, loc=None, ip=None):
+    if loc is None:
+        loc = get_user_code_loc()
     return get_op_result_or_op_results(
         all_reduce__(value, op=op, uniform=uniform, loc=loc, ip=ip)
     )
@@ -577,15 +584,18 @@ def get_compile_object_bytes(compiled_module):
 _printf = printf
 
 
-def printf(format, *args):
-    loc = get_user_code_loc()
-    return _printf(format=format, args=args, loc=loc)
+def printf(format, *args, loc=None, ip=None):
+    if loc is None:
+        loc = get_user_code_loc()
+    return _printf(format=format, args=args, loc=loc, ip=ip)
 
 
 _dynamic_shared_memory = dynamic_shared_memory
 
 
 def dynamic_shared_memory(*, int=False, loc=None, ip=None):
+    if loc is None:
+        loc = get_user_code_loc()
     return _dynamic_shared_memory(
         T.memref(
             ShapedType.get_dynamic_size(),
@@ -611,3 +621,10 @@ def memset(dst, value, async_dependencies=None, *, loc=None, ip=None):
     if isinstance(value, (int, float, bool)):
         value = constant(value, type=dst.type.element_type)
     return _memset(async_token, async_dependencies, dst, value, loc=loc, ip=ip)
+
+
+def barrier(*, loc=None, ip=None):
+    if loc is None:
+        loc = get_user_code_loc()
+
+    return BarrierOp(loc=loc, ip=ip)
diff --git a/mlir/extras/dialects/ext/memref.py b/mlir/extras/dialects/ext/memref.py
index 2a710e7..de6a8c6 100644
--- a/mlir/extras/dialects/ext/memref.py
+++ b/mlir/extras/dialects/ext/memref.py
@@ -281,6 +281,8 @@ def _canonicalize_start_stop(start, stop, step):
     elif isinstance(start, int) and isinstance(stop, int):
         return stop - start
 
+    raise NotImplementedError
+
 
 def _subview(
     mem: MemRef,
@@ -362,6 +364,8 @@ def _copy_to_subview(
 
 
 def dim(source, index, *, loc=None, ip=None):
+    if loc is None:
+        loc = get_user_code_loc()
     if isinstance(index, int):
         index = constant(index, index=True)
     return _dim(source=source, index=index, loc=loc, ip=ip)
@@ -412,7 +416,9 @@ def global_(
     ).opview
 
 
-def view(source, shape, dtype=None, shift=0, memory_space=None):
+def view(source, shape, dtype=None, shift=0, memory_space=None, loc=None, ip=None):
+    if loc is None:
+        loc = get_user_code_loc()
     if dtype is None:
         dtype = source.type.element_type
     byte_width_dtype = dtype.width // 8
@@ -425,6 +431,8 @@ def view(source, shape, dtype=None, shift=0, memory_space=None):
         source,
         byte_shift,
         [],
+        loc=loc,
+        ip=ip,
     )
 
 
@@ -434,6 +442,8 @@ def view(source, shape, dtype=None, shift=0, memory_space=None):
 def get_global(
     name_or_global, *, name=None, global_=None, result=None, loc=None, ip=None
 ):
+    if loc is None:
+        loc = get_user_code_loc()
     if isinstance(name_or_global, GlobalOp):
         global_ = name_or_global
     elif isinstance(name_or_global, str):
diff --git a/mlir/extras/dialects/ext/rocdl.py b/mlir/extras/dialects/ext/rocdl.py
index 28b19ae..3b89210 100644
--- a/mlir/extras/dialects/ext/rocdl.py
+++ b/mlir/extras/dialects/ext/rocdl.py
@@ -24,6 +24,8 @@ class WMMA_F16_16X16X16_F16(ir.OpView):
     _ODS_REGIONS = (0, True)
 
     def __init__(self, res, args, *, loc=None, ip=None):
+        if loc is None:
+            loc = get_user_code_loc()
         operands = []
         results = []
         attributes = {}
@@ -56,5 +58,11 @@ def res(self):
         return self.operation.results[0]
 
 
-def wmma_f16_16x16x16_f16(res, args, *, loc=None, ip=None) -> ir.Value:
-    return WMMA_F16_16X16X16_F16(res=res, args=args, loc=loc, ip=ip).result
+def wmma_f16_16x16x16_f16(A, B, C, *, OPSEL=False, loc=None, ip=None) -> ir.Value:
+    if loc is None:
+        loc = get_user_code_loc()
+
+    opsel = arith.constant(OPSEL, ir.IntegerType.get_signless(1))
+    args = [A, B, C, opsel]
+    v16 = ir.VectorType.get((16,), ir.F16Type.get())
+    return WMMA_F16_16X16X16_F16(res=v16, args=args, loc=loc, ip=ip).result
diff --git a/mlir/extras/dialects/ext/vector.py b/mlir/extras/dialects/ext/vector.py
index 2a7b118..4ab7d09 100644
--- a/mlir/extras/dialects/ext/vector.py
+++ b/mlir/extras/dialects/ext/vector.py
@@ -251,6 +251,8 @@ def extract_strided_slice(vector, offsets, sizes, strides, *, loc=None, ip=None)
 
 
 def outerproduct(lhs, rhs, acc=None, *, kind=None, loc=None, ip=None):
+    if loc is None:
+        loc = get_user_code_loc()
     if kind is None:
         kind = CombiningKind.ADD
     result_shape = [lhs.shape[0], rhs.shape[0]]
@@ -262,6 +264,8 @@ def outer(lhs, rhs, acc=None, *, kind=None, loc=None, ip=None):
+    if loc is None:
+        loc = get_user_code_loc()
     return outerproduct(lhs, rhs, acc, kind=kind, loc=loc, ip=ip)
 
 
@@ -270,14 +274,20 @@ def outer(lhs, rhs, acc=None, *, kind=None, loc=None, ip=None):
 
 @Infix
 def shuffle(v1, v2, mask, *, loc=None, ip=None):
+    if loc is None:
+        loc = get_user_code_loc()
     return ShuffleOp(v1=v1, v2=v2, mask=mask, loc=loc, ip=ip).result
 
 
 _load = load
 
 
-@Infix
-def load(base, indices, result, *, nontemporal=None, loc=None, ip=None):
+def load_(base, indices, result, *, nontemporal=None, loc=None, ip=None):
+    if loc is None:
+        loc = get_user_code_loc()
+    for j, i in enumerate(indices):
+        if isinstance(i, int):
+            indices[j] = constant(i, index=True)
     return LoadOp(
         result=result,
         base=base,
@@ -286,3 +296,6 @@ def load(base, indices, result, *, nontemporal=None, loc=None, ip=None):
         loc=loc,
         ip=ip,
     ).result
+
+
+load = Infix(load_)
diff --git a/mlir/extras/runtime/passes.py b/mlir/extras/runtime/passes.py
index 43823e4..823bd7f 100644
--- a/mlir/extras/runtime/passes.py
+++ b/mlir/extras/runtime/passes.py
@@ -31,7 +31,7 @@ def run_pipeline(
     print_pipeline=False,
     verify=True,
 ):
-    module = Module.parse(str(module))
+    module = Module.parse(module.operation.get_asm(enable_debug_info=True))
 
     if isinstance(pipeline, Pipeline):
         pipeline = str(pipeline)
diff --git a/tests/test_gpu.py b/tests/test_gpu.py
index 615f721..7abccb4 100644
--- a/tests/test_gpu.py
+++ b/tests/test_gpu.py
@@ -15,7 +15,7 @@ from mlir.dialects.memref import cast
 
 from mlir.extras.ast.canonicalize import canonicalize
-from mlir.extras.dialects.ext import arith, scf, memref, rocdl
+from mlir.extras.dialects.ext import arith, scf, memref, rocdl, gpu
 from mlir.extras.dialects.ext.func import func
 
 # noinspection PyUnresolvedReferences
@@ -758,7 +758,7 @@ def mat_product_kernel(
 
     props = hip.hipDeviceProp_t()
     hip_check(hip.hipGetDeviceProperties(props, 0))
-    arch = props.gcnArchName.decode()
+    arch = props.gcnArchName.decode().split(":")[0]
 
     @module("naive", [f'#rocdl.target'])
     def gpu_module():
@@ -869,7 +869,7 @@ def mat_product_kernel(
 
     props = hip.hipDeviceProp_t()
     hip_check(hip.hipGetDeviceProperties(props, 0))
-    arch = props.gcnArchName.decode()
+    arch = props.gcnArchName.decode().split(":")[0]
 
     @module("naive", [f'#rocdl.target'])
     def gpu_module():
@@ -996,7 +996,7 @@ def smol_matmul(
 
     props = hip.hipDeviceProp_t()
     hip_check(hip.hipGetDeviceProperties(props, 0))
-    arch = props.gcnArchName.decode()
+    arch = props.gcnArchName.decode().split(":")[0]
 
     @module("naive", [f'#rocdl.target'])
     def gpu_module():
@@ -1104,7 +1104,7 @@ def all_bank_conflicts(A: T.memref(M, M, T.f32()), B: T.memref(M, M, T.f32())):
 
     props = hip.hipDeviceProp_t()
     hip_check(hip.hipGetDeviceProperties(props, 0))
-    arch = props.gcnArchName.decode()
+    arch = props.gcnArchName.decode().split(":")[0]
 
     @module("naive", [f'#rocdl.target'])
     def gpu_module():
@@ -1228,9 +1228,10 @@ def smol_matmul(
             a_frag[ele] = a[lane, ele]
 
         a_frag, b_frag = yield a_frag, b_frag
 
-        # call the WMMA intrinsic
-        false = arith.constant(False, T.bool())
-        c_frag = rocdl.wmma_f16_16x16x16_f16(v16f16, [a_frag, b_frag, c_frag, false])
+        c_frag = rocdl.wmma_f16_16x16x16_f16(a_frag, b_frag, c_frag)
+
+        for i in scf.range_(v_len):
+            gpu.printf("(%02ld, %02ld, %02ld), %f\n", lIdx, lane, i, c_frag[i])
 
         for ele in scf.range_(v_len // 2):
             r = ele * 2 + (lIdx // v_len)
@@ -1239,7 +1240,7 @@
 
     props = hip.hipDeviceProp_t()
     hip_check(hip.hipGetDeviceProperties(props, 0))
-    arch = props.gcnArchName.decode()
+    arch = props.gcnArchName.decode().split(":")[0]
 
     @module("naive", [f'#rocdl.target'])
     def gpu_module():
@@ -1250,7 +1251,11 @@ def gpu_module():
     lowered_module = run_pipeline(
         gpu_module,
         Pipeline()
-        .Gpu(Pipeline().convert_gpu_to_rocdl(use_bare_ptr_memref_call_conv=True))
+        .Gpu(
+            Pipeline().convert_gpu_to_rocdl(
+                use_bare_ptr_memref_call_conv=True, runtime="HIP"
+            )
+        )
         .rocdl_attach_target(chip=arch, abi="500")
         .gpu_to_llvm()
         .lower_to_llvm()