Skip to content

Commit 2d4dba2

Browse files
committed
support generics (and do closures right)
1 parent 10c8824 commit 2d4dba2

File tree

4 files changed

+94
-68
lines changed

4 files changed

+94
-68
lines changed

examples/cuda_matmul_opt.py

Lines changed: 28 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,27 @@
1+
from __future__ import annotations
12
import ast
23
import math
34
import re
5+
import time
46

57
import cupy as cp
68
import mlir.extras.types as T
79
import numpy as np
810
from cupy.cuda import Module
9-
from mlir.dialects import math as math_dialect
1011

12+
from mlir.extras.ast.canonicalize import canonicalize
1113
from mlir.extras.context import (
1214
mlir_mod_ctx,
1315
MLIRContext,
1416
)
15-
from mlir.extras.dialects.ext import arith, memref, gpu
17+
from mlir.extras.dialects.ext import arith, memref, gpu, scf
1618
from mlir.extras.dialects.ext.gpu import (
1719
block_id,
1820
thread_id,
1921
block_dim,
2022
)
2123
from mlir.extras.dialects.ext.nvgpu import get_ptx, print_ptx
22-
from mlir.extras.dialects.ext.scf import range_, yield_
24+
from mlir.extras.dialects.ext.scf import range_
2325
from mlir.extras.runtime.passes import Pipeline, run_pipeline
2426

2527
# noinspection PyUnresolvedReferences
@@ -39,26 +41,24 @@ def build_cuda_func(compiled_module, kernel_name="mat_product_kernel"):
3941

4042

4143
@gpu.func
42-
# @canonicalize(using=(arith.canonicalizer, scf.canonicalizer))
44+
@canonicalize(using=(arith.canonicalizer, scf.canonicalizer))
4345
def mat_product_kernel[
4446
M, K, N, dtype
45-
](A: "T.memref(M, K, dtype)", B: "T.memref(K, N, dtype)", C: "T.memref(M, N, dtype)"):
47+
](A: T.memref(M, K, dtype), B: T.memref(K, N, dtype), C: T.memref(M, N, dtype)):
4648
M, K, N, dtype
4749
x = block_dim.x * block_id.x + thread_id.x
4850
y = block_dim.y * block_id.y + thread_id.y
4951

5052
one = arith.constant(1.0, type=dtype)
5153
tmp = arith.constant(0, type=dtype)
5254
for k, tmp in range_(K, iter_args=[tmp]):
53-
# tmp += A[x, k] * B[k, y]
54-
# tmp = yield tmp
55-
tmp = math_dialect.fma(A[x, k], B[k, y], tmp)
56-
tmp = yield_(tmp)
55+
tmp += A[x, k] * B[k, y]
56+
tmp = yield tmp
5757
C[x, y] = tmp + one
5858

5959

6060
def main(ctx: MLIRContext):
61-
M, K, N = 256, 256, 256
61+
M, K, N = 2048, 2048, 2048
6262
BLOCK_SIZE = 32
6363
dtype = T.f32()
6464
npy_dtype = np.float32
@@ -69,8 +69,8 @@ def main(ctx: MLIRContext):
6969
def _():
7070
mat_product_kernel[M, K, N, dtype].emit()
7171

72-
print(ctx.module)
73-
print(ctx.module.operation.verify())
72+
# print(ctx.module)
73+
ctx.module.operation.verify()
7474

7575
compiled_module = run_pipeline(
7676
ctx.module,
@@ -87,8 +87,8 @@ def _():
8787
},
8888
),
8989
)
90-
print(compiled_module)
91-
print_ptx(compiled_module)
90+
# print(compiled_module)
91+
# print_ptx(compiled_module)
9292

9393
A = np.random.randint(0, 10, (M, K)).astype(npy_dtype)
9494
B = np.random.randint(0, 10, (K, N)).astype(npy_dtype)
@@ -98,12 +98,26 @@ def _():
9898
dB = cp.asarray(B)
9999
dC = cp.asarray(C)
100100

101+
start_gpu = cp.cuda.Event()
102+
end_gpu = cp.cuda.Event()
103+
101104
cuda_func = build_cuda_func(compiled_module)
105+
start_gpu.record()
106+
start_cpu = time.perf_counter()
102107
cuda_func(
103108
(math.ceil(M / BLOCK_SIZE), math.ceil(N / BLOCK_SIZE), 1),
104109
(BLOCK_SIZE, BLOCK_SIZE, 1),
105110
(dA.data.ptr, dB.data.ptr, dC.data.ptr),
106111
)
112+
end_cpu = time.perf_counter()
113+
end_gpu.record()
114+
end_gpu.synchronize()
115+
116+
t_gpu = cp.cuda.get_elapsed_time(start_gpu, end_gpu)
117+
t_cpu = end_cpu - start_cpu
118+
119+
print(f"{t_gpu=}ms")
120+
print(f"t_cpu={t_cpu / 1000}ms")
107121

108122
if not cp.array_equal(dC, dA @ dB + 1):
109123
print(dA @ dB + 1)

mlir/extras/ast/canonicalize.py

Lines changed: 47 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import inspect
55
import logging
66
import types
7+
import warnings
78
from abc import ABC, abstractmethod
89
from dis import findlinestarts
910
from opcode import opmap
@@ -13,7 +14,7 @@
1314
import astunparse
1415
from bytecode import ConcreteBytecode
1516

16-
from ..ast.util import get_module_cst, copy_func
17+
from ..ast.util import get_module_cst
1718

1819
logger = logging.getLogger(__name__)
1920

@@ -59,28 +60,62 @@ def transform_func(f, *transformer_ctors: type(Transformer)):
5960
return module
6061

6162

63+
def insert_closed_vars(f, module):
64+
enclosing_mod = ast.FunctionDef(
65+
name="enclosing_mod",
66+
args=ast.arguments(
67+
posonlyargs=[], args=[], kwonlyargs=[], kw_defaults=[], defaults=[]
68+
),
69+
body=[],
70+
decorator_list=[],
71+
)
72+
for var in f.__code__.co_freevars:
73+
enclosing_mod.body.append(
74+
ast.Assign(
75+
targets=[ast.Name(var, ctx=ast.Store())],
76+
value=ast.Constant(None, kind="None"),
77+
)
78+
)
79+
enclosing_mod.body.extend(module.body)
80+
module.body = [enclosing_mod]
81+
return module
82+
83+
84+
def find_func_in_code_object(co, func_name):
85+
for c in co.co_consts:
86+
if type(c) is CodeType:
87+
if c.co_name == func_name:
88+
return c
89+
else:
90+
f = find_func_in_code_object(c, func_name)
91+
if f is not None:
92+
return f
93+
94+
6295
def transform_ast(
6396
f, transformers: List[Union[type(Transformer), type(StrictTransformer)]] = None
6497
):
6598
if transformers is None:
6699
return f
67100

68101
module = transform_func(f, *transformers)
102+
if f.__closure__:
103+
module = insert_closed_vars(f, module)
69104
module = ast.fix_missing_locations(module)
70105
module = ast.increment_lineno(module, f.__code__.co_firstlineno - 1)
71106
module_code_o = compile(module, f.__code__.co_filename, "exec")
72-
new_f_code_o = next(
73-
c
74-
for c in module_code_o.co_consts
75-
if type(c) is CodeType and c.co_name == f.__name__
76-
)
107+
new_f_code_o = find_func_in_code_object(module_code_o, f.__name__)
77108
n_lines = len(inspect.getsource(f).splitlines())
78109
line_starts = list(findlinestarts(new_f_code_o))
79-
assert (
110+
if (
80111
max([l for _, l in line_starts]) - min([l for _, l in line_starts]) + 1
81-
<= n_lines
82-
), f"something went wrong with the line numbers for the rewritten/canonicalized function"
83-
return copy_func(f, new_f_code_o)
112+
> n_lines
113+
) or (f.__code__.co_firstlineno != min([l for _, l in line_starts])):
114+
warnings.warn(
115+
"something went wrong with the line numbers for the rewritten/canonicalized function"
116+
)
117+
f.__code__ = new_f_code_o
118+
return f
84119

85120

86121
# this is like this because i couldn't figure out how to subclass
@@ -117,7 +152,8 @@ def patch_bytecode(f, patchers: List[type(BytecodePatcher)] = None):
117152
for patcher in patchers:
118153
code = patcher(context).patch_bytecode(code, f)
119154

120-
return copy_func(f, code.to_code())
155+
f.__code__ = code.to_code()
156+
return f
121157

122158

123159
class Canonicalizer(ABC):

mlir/extras/ast/util.py

Lines changed: 3 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import ast
2-
import functools
32
import inspect
4-
import types
3+
from itertools import dropwhile
54
from textwrap import dedent
65

76

@@ -26,8 +25,8 @@ def ast_call(name, args=None, keywords=None):
2625

2726

2827
def get_module_cst(f):
29-
f_src = dedent(inspect.getsource(f))
30-
# tree = cst.parse_module(f_src)
28+
lines, _lnum = inspect.getsourcelines(f)
29+
f_src = dedent("".join(list(dropwhile(lambda l: l.startswith("@"), lines))))
3130
tree = ast.parse(f_src)
3231
assert isinstance(
3332
tree.body[0], ast.FunctionDef
@@ -43,31 +42,6 @@ def bind(func, instance, as_name=None):
4342
return bound_method
4443

4544

46-
def copy_func(f, new_code):
47-
"""Based on http://stackoverflow.com/a/6528148/190597 (Glenn Maynard)"""
48-
g = types.FunctionType(
49-
code=new_code,
50-
globals={
51-
**f.__globals__,
52-
**{
53-
fr: f.__closure__[i].cell_contents
54-
for i, fr in enumerate(f.__code__.co_freevars)
55-
},
56-
},
57-
name=f.__name__,
58-
argdefs=f.__defaults__,
59-
# TODO(max): ValueError: foo requires closure of length 0, not 1
60-
# closure=f.__closure__ if f.__closure__ is not None else None,
61-
)
62-
g.__kwdefaults__ = f.__kwdefaults__
63-
g.__dict__.update(f.__dict__)
64-
g = functools.update_wrapper(g, f)
65-
66-
if inspect.ismethod(f):
67-
g = bind(g, f.__self__)
68-
return g
69-
70-
7145
def append_hidden_node(node_body, new_node):
7246
last_statement = node_body[-1]
7347
new_node = ast.fix_missing_locations(

mlir/extras/dialects/ext/func.py

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -256,20 +256,22 @@ def __call__(self, *call_args):
256256
return call(self.emit(*call_args), call_args)
257257

258258
def __getitem__(self, item):
259-
closure = {
260-
k: v
261-
for k, v in zip(
262-
self.body_builder.__code__.co_freevars, self.body_builder.__closure__
263-
)
264-
if v.cell_contents in self.body_builder.__type_params__
265-
}
266-
267-
for i, t in enumerate(self.body_builder.__type_params__):
268-
if t.__bound__ is not None:
269-
v = t.__bound__
270-
else:
271-
v = item[i]
272-
closure[t.__name__].cell_contents = v
259+
if self.body_builder.__code__.co_freevars and self.body_builder.__closure__:
260+
closure = {
261+
k: v
262+
for k, v in zip(
263+
self.body_builder.__code__.co_freevars,
264+
self.body_builder.__closure__,
265+
)
266+
if v.cell_contents in self.body_builder.__type_params__
267+
}
268+
269+
for i, t in enumerate(self.body_builder.__type_params__):
270+
if t.__bound__ is not None:
271+
v = t.__bound__
272+
else:
273+
v = item[i]
274+
closure[t.__name__].cell_contents = v
273275

274276
return self
275277

0 commit comments

Comments
 (0)