Skip to content

Commit b13e937

Browse files
committed
backport generics
1 parent e61b129 commit b13e937

File tree

7 files changed

+240
-67
lines changed

7 files changed

+240
-67
lines changed

examples/cuda_matmul_opt.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,7 @@
11
from __future__ import annotations
22

3-
import ast
43
import contextlib
54
import math
6-
import re
75

86
import cupy as cp
97
import mlir.extras.types as T
@@ -20,8 +18,8 @@
2018
block_id,
2119
thread_id,
2220
block_dim,
21+
get_compile_object_bytes,
2322
)
24-
from mlir.extras.dialects.ext.nvgpu import get_ptx
2523
from mlir.extras.dialects.ext.scf import range_
2624
from mlir.extras.runtime.passes import Pipeline, run_pipeline
2725

@@ -33,9 +31,7 @@
3331

3432

3533
def build_cuda_func(compiled_module, kernel_name="mat_product_kernel"):
36-
ptx = get_ptx(compiled_module)
37-
ptx = re.sub(r"\\(\w\w)", lambda m: r"\x" + m.groups(0)[0].lower(), ptx)
38-
ptx = ast.literal_eval(rf"b'{ptx}'")
34+
ptx = get_compile_object_bytes(compiled_module)
3935
mod = Module()
4036
mod.load(ptx)
4137
return mod.get_function(kernel_name)

mlir/extras/dialects/ext/func.py

Lines changed: 30 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import sys
33
from dataclasses import dataclass
44
from functools import update_wrapper
5-
from typing import Optional
5+
from typing import Optional, List, Union, TypeVar
66

77
from ...ast.util import copy_func
88
from ...meta import op_region_builder
@@ -138,14 +138,16 @@ def __init__(
138138
func_op_ctor,
139139
return_op_ctor,
140140
call_op_ctor,
141+
*,
141142
return_types=None,
142143
sym_visibility=None,
143144
arg_attrs=None,
144145
res_attrs=None,
145146
func_attrs=None,
147+
generics: List[Union[TypeVar, ReifiedTypeParams]] = None,
148+
qualname=None,
146149
loc=None,
147150
ip=None,
148-
qualname=None,
149151
):
150152
assert inspect.isfunction(body_builder), body_builder
151153
assert inspect.isclass(func_op_ctor), func_op_ctor
@@ -159,6 +161,7 @@ def __init__(
159161
self.call_op_ctor = call_op_ctor
160162
self.arg_attrs = arg_attrs
161163
self.res_attrs = res_attrs
164+
self.generics = generics
162165
self.loc = loc
163166
self.ip = ip
164167
self._func_op = None
@@ -203,11 +206,8 @@ def emit(self, *call_args, decl=False, force=False) -> FuncOp:
203206
if len(call_args) == 0:
204207
input_types = self.input_types[:]
205208
locals = {}
206-
if (
207-
hasattr(self.body_builder, "__type_params__")
208-
and self.body_builder.__type_params__
209-
):
210-
for t in self.body_builder.__type_params__:
209+
if self.generics is not None:
210+
for t in self.generics:
211211
if not isinstance(t, ReifiedTypeParams):
212212
raise RuntimeError(f"{t=} must reified")
213213
locals[t.name] = t.val
@@ -267,17 +267,14 @@ def __call__(self, *call_args):
267267
return call(self.emit(*call_args), call_args)
268268

269269
def __getitem__(self, item):
270-
if (
271-
not hasattr(self.body_builder, "__type_params__")
272-
or not self.body_builder.__type_params__
273-
):
270+
if self.generics is None:
274271
raise RuntimeError(
275272
"using a generic call requires the func be generic (i.e., have type_params)"
276273
)
277-
# this also copies the function so that the original body_builder remains "generic"
274+
# this also copies the function so that the original body_builder remains "generic" (via its closure)
278275
body_builder = copy_func(self.body_builder)
279276
reified_type_params = []
280-
for i, t in enumerate(body_builder.__type_params__):
277+
for i, t in enumerate(self.generics):
281278
if t.__bound__ is not None:
282279
r = ReifiedTypeParams(t.__name__, t.__bound__)
283280
else:
@@ -287,21 +284,22 @@ def __getitem__(self, item):
287284
free_i = body_builder.__code__.co_freevars.index(r.name)
288285
body_builder.__closure__[free_i].cell_contents = r.val
289286

290-
body_builder.__type_params__ = tuple(reified_type_params)
287+
generics = tuple(reified_type_params)
291288

292289
return FuncBase(
293290
body_builder,
294291
self.func_op_ctor,
295292
self.return_op_ctor,
296293
self.call_op_ctor,
297-
self.return_types,
298-
self.sym_visibility,
299-
self.arg_attrs,
300-
self.res_attrs,
301-
self.func_attrs,
302-
self.loc,
303-
self.ip,
304-
self.qualname,
294+
return_types=self.return_types,
295+
sym_visibility=self.sym_visibility,
296+
arg_attrs=self.arg_attrs,
297+
res_attrs=self.res_attrs,
298+
func_attrs=self.func_attrs,
299+
generics=generics,
300+
qualname=self.qualname,
301+
loc=self.loc,
302+
ip=self.ip,
305303
)
306304

307305

@@ -314,11 +312,20 @@ def func(
314312
res_attrs=None,
315313
func_attrs=None,
316314
emit=False,
315+
generics=None,
317316
loc=None,
318317
ip=None,
319318
) -> FuncBase:
320319
if loc is None:
321320
loc = get_user_code_loc()
321+
if hasattr(f, "__type_params__") and f.__type_params__:
322+
assert generics is None, "generics XOR type params"
323+
generics = f.__type_params__
324+
if generics is not None:
325+
for i, g in enumerate(generics):
326+
if isinstance(g, str):
327+
generics[i] = TypeVar(g)
328+
322329
func_ = FuncBase(
323330
body_builder=f,
324331
func_op_ctor=FuncOp.__base__,
@@ -328,6 +335,7 @@ def func(
328335
arg_attrs=arg_attrs,
329336
res_attrs=res_attrs,
330337
func_attrs=func_attrs,
338+
generics=generics,
331339
loc=loc,
332340
ip=ip,
333341
)

mlir/extras/dialects/ext/gpu.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
_get_previous_frame_idents,
1414
get_user_code_loc,
1515
make_maybe_no_args_decorator,
16+
find_ops,
1617
)
1718
from ....dialects._gpu_ops_gen import _Dialect
1819
from ....dialects._ods_common import (
@@ -414,12 +415,16 @@ def func(
414415
res_attrs=None,
415416
func_attrs=None,
416417
emit=False,
418+
generics=None,
417419
loc=None,
418420
ip=None,
419421
emit_grid=False,
420422
) -> Grid:
421423
if loc is None:
422424
loc = get_user_code_loc()
425+
if hasattr(f, "__type_params__") and f.__type_params__:
426+
assert generics is None, "generics XOR type params"
427+
generics = f.__type_params__
423428
func_ = GPUFunc(
424429
body_builder=f,
425430
func_op_ctor=GPUFuncOp,
@@ -429,6 +434,7 @@ def func(
429434
arg_attrs=arg_attrs,
430435
res_attrs=res_attrs,
431436
func_attrs=func_attrs,
437+
generics=generics,
432438
loc=loc,
433439
ip=ip,
434440
)
@@ -544,3 +550,9 @@ def memcpy(dst, src, async_dependencies=None, *, loc=None, ip=None):
544550
loc=loc,
545551
ip=ip,
546552
)
553+
554+
555+
def get_compile_object_bytes(compiled_module):
556+
binary = find_ops(compiled_module, lambda o: isinstance(o, BinaryOp), single=True)
557+
objects = list(map(ObjectAttr, binary.objects))
558+
return objects[-1].object

mlir/extras/dialects/ext/nvgpu.py

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
1-
import re
21
from textwrap import dedent
32

4-
from ...util import find_ops
53
from ....dialects.nvgpu import *
64
from ....ir import Type
75

@@ -26,15 +24,3 @@ def tensormap_descriptor(
2624
),
2725
context=context,
2826
)
29-
30-
31-
def get_ptx(compiled_module):
32-
binary = find_ops(compiled_module, lambda o: o.name == "gpu.binary", single=True)
33-
r = re.findall(r'"(.*?)"', str(binary.objects[1]))
34-
return r[-1]
35-
36-
37-
def print_ptx(compiled_module):
38-
ptx = get_ptx(compiled_module)
39-
ptx = str(ptx).replace("\\0A", "\n").replace("\\09", "\t")
40-
print(ptx)

tests/test_func.py

Lines changed: 105 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,17 @@
11
import inspect
22
import sys
33
from textwrap import dedent
4+
from typing import TypeVar
45

56
import pytest
67

78
import mlir.extras.types as T
9+
10+
from mlir.extras.ast.canonicalize import canonicalize
811
from mlir.extras.context import mlir_mod_ctx
912
from mlir.extras.dialects.ext.arith import constant
1013
from mlir.extras.dialects.ext.func import func
11-
from mlir.extras.dialects.ext import linalg
14+
from mlir.extras.dialects.ext import linalg, arith, scf, memref
1215

1316
# noinspection PyUnresolvedReferences
1417
from mlir.extras.testing import mlir_ctx as ctx, filecheck, MLIRContext
@@ -40,8 +43,7 @@ def demo_fun1():
4043

4144

4245
def test_declare_byte_rep(ctx: MLIRContext):
43-
def demo_fun1():
44-
...
46+
def demo_fun1(): ...
4547

4648
if sys.version_info.minor == 12:
4749
assert demo_fun1.__code__.co_code == b"\x97\x00y\x00"
@@ -55,20 +57,16 @@ def demo_fun1():
5557

5658
def test_declare(ctx: MLIRContext):
5759
@func
58-
def demo_fun1() -> T.i32():
59-
...
60+
def demo_fun1() -> T.i32(): ...
6061

6162
@func
62-
def demo_fun2() -> (T.i32(), T.i32()):
63-
...
63+
def demo_fun2() -> (T.i32(), T.i32()): ...
6464

6565
@func
66-
def demo_fun3(x: T.i32()) -> (T.i32(), T.i32()):
67-
...
66+
def demo_fun3(x: T.i32()) -> (T.i32(), T.i32()): ...
6867

6968
@func
70-
def demo_fun4(x: T.i32(), y: T.i32()) -> (T.i32(), T.i32()):
71-
...
69+
def demo_fun4(x: T.i32(), y: T.i32()) -> (T.i32(), T.i32()): ...
7270

7371
demo_fun1()
7472
demo_fun2()
@@ -197,3 +195,99 @@ def test_func_no_context_2(ctx: MLIRContext):
197195
"""
198196
)
199197
filecheck(correct, ctx.module)
198+
199+
200+
def test_generics_just_args(ctx: MLIRContext):
201+
@func(generics=["M", "K", "N", "dtype"])
202+
def mat_product_kernel(
203+
A: "T.memref(M, K, dtype)",
204+
B: "T.memref(K, N, dtype)",
205+
C: "T.memref(M, N, dtype)",
206+
):
207+
one = arith.constant(1.0)
208+
209+
mat_product_kernel[32, 32, 32, T.i32()].emit()
210+
correct = dedent(
211+
"""\
212+
module {
213+
func.func @mat_product_kernel(%arg0: memref<32x32xi32>, %arg1: memref<32x32xi32>, %arg2: memref<32x32xi32>) {
214+
%cst = arith.constant 1.000000e+00 : f32
215+
return
216+
}
217+
}
218+
"""
219+
)
220+
filecheck(correct, ctx.module)
221+
222+
223+
def test_generics_closure(ctx: MLIRContext):
224+
dtype = None
225+
226+
@func(generics=["M", "K", "N", "dtype"])
227+
def mat_product_kernel(
228+
A: "T.memref(M, K, dtype)",
229+
B: "T.memref(K, N, dtype)",
230+
C: "T.memref(M, N, dtype)",
231+
):
232+
one = arith.constant(1, dtype)
233+
234+
mat_product_kernel[32, 32, 32, T.i32()].emit()
235+
correct = dedent(
236+
"""\
237+
module {
238+
func.func @mat_product_kernel(%arg0: memref<32x32xi32>, %arg1: memref<32x32xi32>, %arg2: memref<32x32xi32>) {
239+
%c1_i32 = arith.constant 1 : i32
240+
return
241+
}
242+
}
243+
"""
244+
)
245+
filecheck(correct, ctx.module)
246+
247+
248+
def test_generics_with_canonicalizations(ctx: MLIRContext):
249+
dtype = None
250+
K = None
251+
252+
@func(generics=["M", "K", "N", "dtype"])
253+
@canonicalize(using=(arith.canonicalizer, scf.canonicalizer))
254+
def mat_product_kernel(
255+
A: "T.memref(M, K, dtype)",
256+
B: "T.memref(K, N, dtype)",
257+
C: "T.memref(M, N, dtype)",
258+
):
259+
x = arith.constant(1, index=True)
260+
y = arith.constant(1, index=True)
261+
one = arith.constant(1.0, type=dtype)
262+
tmp = arith.constant(0, type=dtype)
263+
for k, tmp in scf.range_(K, iter_args=[tmp]):
264+
tmp += A[x, k] * B[k, y]
265+
tmp = yield tmp
266+
C[x, y] = tmp + one
267+
268+
mat_product_kernel[32, 32, 32, T.f32()].emit()
269+
correct = dedent(
270+
"""\
271+
module {
272+
func.func @mat_product_kernel(%arg0: memref<32x32xf32>, %arg1: memref<32x32xf32>, %arg2: memref<32x32xf32>) {
273+
%c1 = arith.constant 1 : index
274+
%c1_0 = arith.constant 1 : index
275+
%cst = arith.constant 1.000000e+00 : f32
276+
%cst_1 = arith.constant 0.000000e+00 : f32
277+
%c0 = arith.constant 0 : index
278+
%c32 = arith.constant 32 : index
279+
%c1_2 = arith.constant 1 : index
280+
%0 = scf.for %arg3 = %c0 to %c32 step %c1_2 iter_args(%arg4 = %cst_1) -> (f32) {
281+
%2 = memref.load %arg0[%c1, %arg3] : memref<32x32xf32>
282+
%3 = memref.load %arg1[%arg3, %c1_0] : memref<32x32xf32>
283+
%4 = math.fma %2, %3, %arg4 : f32
284+
scf.yield %4 : f32
285+
}
286+
%1 = arith.addf %0, %cst : f32
287+
memref.store %1, %arg2[%c1, %c1_0] : memref<32x32xf32>
288+
return
289+
}
290+
}
291+
"""
292+
)
293+
filecheck(correct, ctx.module)

0 commit comments

Comments
 (0)