Skip to content

Commit 674d096

Browse files
committed
do closures right-er
1 parent 2d4dba2 commit 674d096

File tree

5 files changed

+190
-57
lines changed

5 files changed

+190
-57
lines changed

examples/cuda_matmul_opt.py

Lines changed: 30 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
from __future__ import annotations
2+
23
import ast
4+
import contextlib
35
import math
46
import re
5-
import time
67

78
import cupy as cp
89
import mlir.extras.types as T
@@ -20,7 +21,7 @@
2021
thread_id,
2122
block_dim,
2223
)
23-
from mlir.extras.dialects.ext.nvgpu import get_ptx, print_ptx
24+
from mlir.extras.dialects.ext.nvgpu import get_ptx
2425
from mlir.extras.dialects.ext.scf import range_
2526
from mlir.extras.runtime.passes import Pipeline, run_pipeline
2627

@@ -40,12 +41,22 @@ def build_cuda_func(compiled_module, kernel_name="mat_product_kernel"):
4041
return mod.get_function(kernel_name)
4142

4243

44+
@contextlib.contextmanager
45+
def time_cuda():
46+
start_gpu = cp.cuda.Event()
47+
end_gpu = cp.cuda.Event()
48+
49+
start_gpu.record()
50+
yield start_gpu, end_gpu
51+
end_gpu.record()
52+
end_gpu.synchronize()
53+
54+
4355
@gpu.func
4456
@canonicalize(using=(arith.canonicalizer, scf.canonicalizer))
4557
def mat_product_kernel[
4658
M, K, N, dtype
4759
](A: T.memref(M, K, dtype), B: T.memref(K, N, dtype), C: T.memref(M, N, dtype)):
48-
M, K, N, dtype
4960
x = block_dim.x * block_id.x + thread_id.x
5061
y = block_dim.y * block_id.y + thread_id.y
5162

@@ -57,9 +68,7 @@ def mat_product_kernel[
5768
C[x, y] = tmp + one
5869

5970

60-
def main(ctx: MLIRContext):
61-
M, K, N = 2048, 2048, 2048
62-
BLOCK_SIZE = 32
71+
def main(ctx: MLIRContext, M, K, N, BLOCK_SIZE=32, repeat_times=50):
6372
dtype = T.f32()
6473
npy_dtype = np.float32
6574

@@ -87,6 +96,7 @@ def _():
8796
},
8897
),
8998
)
99+
cuda_func = build_cuda_func(compiled_module)
90100
# print(compiled_module)
91101
# print_ptx(compiled_module)
92102

@@ -98,34 +108,26 @@ def _():
98108
dB = cp.asarray(B)
99109
dC = cp.asarray(C)
100110

101-
start_gpu = cp.cuda.Event()
102-
end_gpu = cp.cuda.Event()
103-
104-
cuda_func = build_cuda_func(compiled_module)
105-
start_gpu.record()
106-
start_cpu = time.perf_counter()
107-
cuda_func(
108-
(math.ceil(M / BLOCK_SIZE), math.ceil(N / BLOCK_SIZE), 1),
109-
(BLOCK_SIZE, BLOCK_SIZE, 1),
110-
(dA.data.ptr, dB.data.ptr, dC.data.ptr),
111-
)
112-
end_cpu = time.perf_counter()
113-
end_gpu.record()
114-
end_gpu.synchronize()
111+
with time_cuda() as (start_gpu, end_gpu):
112+
for _ in range(repeat_times):
113+
cuda_func(
114+
(math.ceil(M / BLOCK_SIZE), math.ceil(N / BLOCK_SIZE), 1),
115+
(BLOCK_SIZE, BLOCK_SIZE, 1),
116+
(dA.data.ptr, dB.data.ptr, dC.data.ptr),
117+
)
115118

116119
t_gpu = cp.cuda.get_elapsed_time(start_gpu, end_gpu)
117-
t_cpu = end_cpu - start_cpu
118120

119-
print(f"{t_gpu=}ms")
120-
print(f"t_cpu={t_cpu / 1000}ms")
121+
print(f"t_gpu={t_gpu / repeat_times:.6f} ms")
121122

122123
if not cp.array_equal(dC, dA @ dB + 1):
123124
print(dA @ dB + 1)
124125
print(dC)
125126

126127

127-
with (
128-
mlir_mod_ctx() as ctx,
129-
# enable_debug()
130-
):
131-
main(ctx)
128+
for s in [128, 256, 512, 1024]:
129+
with (
130+
mlir_mod_ctx() as ctx,
131+
# enable_debug()
132+
):
133+
main(ctx, s, s, s)

mlir/extras/ast/canonicalize.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ def transform_func(f, *transformer_ctors: type(Transformer)):
6060
return module
6161

6262

63+
# TODO(max): unify with `replace_closure` in ast/util.py
6364
def insert_closed_vars(f, module):
6465
enclosing_mod = ast.FunctionDef(
6566
name="enclosing_mod",

mlir/extras/ast/util.py

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,14 @@
11
import ast
2+
import functools
23
import inspect
4+
import types
35
from itertools import dropwhile
6+
from opcode import opmap
47
from textwrap import dedent
8+
from typing import Dict
9+
10+
from bytecode import ConcreteBytecode
11+
from cloudpickle import cloudpickle
512

613

714
def set_lineno(node, n=1):
@@ -42,6 +49,99 @@ def bind(func, instance, as_name=None):
4249
return bound_method
4350

4451

52+
# based on https://github.com/python/cpython/blob/6078f2033ea15a16cf52fe8d644a95a3be72d2e3/Tools/build/deepfreeze.py#L48
53+
def get_localsplus_name_to_idx(code: types.CodeType):
54+
localsplus = code.co_varnames + code.co_cellvars + code.co_freevars
55+
return localsplus, {v: i for i, v in enumerate(localsplus)}
56+
57+
58+
class _empty_cell_value:
59+
"""Sentinel for empty closures."""
60+
61+
@classmethod
62+
def __reduce__(cls):
63+
return cls.__name__
64+
65+
66+
_empty_cell_value = _empty_cell_value()
67+
68+
69+
# based on https://github.com/cloudpipe/cloudpickle/blob/f111f7ab6d302e9b1e2a568d0e4c574895db6a6e/cloudpickle/cloudpickle.py#L513
70+
def make_empty_cell():
71+
if False:
72+
# trick the compiler into creating an empty cell in our lambda
73+
cell = None
74+
raise AssertionError("this route should not be executed")
75+
76+
return (lambda: cell).__closure__[0]
77+
78+
79+
def make_cell(value=_empty_cell_value):
80+
cell = make_empty_cell()
81+
if value is not _empty_cell_value:
82+
cell.cell_contents = value
83+
return cell
84+
85+
86+
# based on https://github.com/python/cpython/blob/a4b44d39cd6941cc03590fee7538776728bdfd0a/Lib/test/test_code.py#L197
87+
def replace_closure(code, new_closure: Dict):
88+
COPY_FREE_VARS = opmap["COPY_FREE_VARS"]
89+
LOAD_DEREF = opmap["LOAD_DEREF"]
90+
91+
# get the original localsplus, which the original bytecode's LOAD_DEREF args (arg_i) index into
92+
localsplus, localsplus_name_to_idx = get_localsplus_name_to_idx(code)
93+
94+
# closure vars go into co_freevars
95+
new_code = code.replace(co_freevars=tuple(new_closure.keys()))
96+
# closure is a tuple of cells
97+
closure = tuple(
98+
make_cell(v) if not isinstance(v, types.CellType) else v
99+
for v in new_closure.values()
100+
)
101+
102+
new_code = ConcreteBytecode.from_code(new_code)
103+
# update how many closure vars are loaded from frame
104+
# see https://github.com/python/cpython/blob/6078f2033ea15a16cf52fe8d644a95a3be72d2e3/Python/bytecodes.c#L1571
105+
assert new_code[0].opcode == COPY_FREE_VARS
106+
new_code[0].arg = len(closure)
107+
108+
# map each original localsplus index (arg_i) to its position (arg_i) in the new localsplus
109+
new_localsplus = new_code.varnames + new_code.cellvars + new_code.freevars
110+
new_localsplus_name_to_idx = {v: i for i, v in enumerate(new_localsplus)}
111+
for c in new_code:
112+
if c.opcode == LOAD_DEREF and c.arg < len(localsplus):
113+
c.arg = new_localsplus_name_to_idx[localsplus[c.arg]]
114+
new_code = new_code.to_code()
115+
116+
return new_code, closure
117+
118+
119+
# Based on http://stackoverflow.com/a/6528148/190597 (Glenn Maynard);
120+
# potentially more complete approach https://stackoverflow.com/a/56901529/9045206
121+
def copy_func(f, new_closure: Dict = None):
122+
if new_closure is not None:
123+
code, closure = replace_closure(f.__code__, new_closure)
124+
else:
125+
code, closure = f.__code__, f.__closure__
126+
127+
g = types.FunctionType(
128+
code=code,
129+
globals=f.__globals__,
130+
name=f.__name__,
131+
argdefs=f.__defaults__,
132+
# see https://github.com/cloudpipe/cloudpickle/blob/f111f7ab6d302e9b1e2a568d0e4c574895db6a6e/cloudpickle/cloudpickle.py#L813
133+
# for how this trick is accomplished (dill and pickle both fail to pickle eg generic typevars)
134+
closure=cloudpickle.loads(cloudpickle.dumps(closure)),
135+
)
136+
g.__kwdefaults__ = f.__kwdefaults__
137+
g.__dict__.update(f.__dict__)
138+
g = functools.update_wrapper(g, f)
139+
140+
if inspect.ismethod(f):
141+
g = bind(g, f.__self__)
142+
return g
143+
144+
45145
def append_hidden_node(node_body, new_node):
46146
last_statement = node_body[-1]
47147
new_node = ast.fix_missing_locations(

mlir/extras/dialects/ext/func.py

Lines changed: 57 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
import inspect
22
import sys
3+
from dataclasses import dataclass
34
from functools import update_wrapper
5+
from typing import Optional
46

7+
from ...ast.util import copy_func
58
from ...meta import op_region_builder
69
from ...util import get_user_code_loc, make_maybe_no_args_decorator
710
from ....dialects._ods_common import get_op_result_or_op_results
@@ -122,6 +125,12 @@ def prep_func_types(sig, return_types):
122125
return input_types, return_types, user_locs
123126

124127

128+
@dataclass
129+
class ReifiedTypeParams:
130+
name: str
131+
val: object
132+
133+
125134
class FuncBase:
126135
def __init__(
127136
self,
@@ -193,18 +202,20 @@ def emit(self, *call_args, decl=False, force=False) -> FuncOp:
193202
if self._func_op is None or decl or force:
194203
if len(call_args) == 0:
195204
input_types = self.input_types[:]
196-
env = dict(self.body_builder.__globals__)
197-
if self.body_builder.__closure__:
198-
closure = dict(
199-
zip(
200-
self.body_builder.__code__.co_freevars,
201-
[c.cell_contents for c in self.body_builder.__closure__],
202-
)
203-
)
204-
env.update(closure)
205+
locals = {}
206+
if (
207+
hasattr(self.body_builder, "__type_params__")
208+
and self.body_builder.__type_params__
209+
):
210+
for t in self.body_builder.__type_params__:
211+
if not isinstance(t, ReifiedTypeParams):
212+
raise RuntimeError(f"{t=} must reified")
213+
locals[t.name] = t.val
205214
for i, v in enumerate(input_types):
206215
if isinstance(v, str):
207-
input_types[i] = Type(eval(v, env))
216+
input_types[i] = Type(
217+
eval(v, self.body_builder.__globals__, locals)
218+
)
208219
elif isalambda(v):
209220
input_types[i] = v()
210221
else:
@@ -256,24 +267,42 @@ def __call__(self, *call_args):
256267
return call(self.emit(*call_args), call_args)
257268

258269
def __getitem__(self, item):
259-
if self.body_builder.__code__.co_freevars and self.body_builder.__closure__:
260-
closure = {
261-
k: v
262-
for k, v in zip(
263-
self.body_builder.__code__.co_freevars,
264-
self.body_builder.__closure__,
265-
)
266-
if v.cell_contents in self.body_builder.__type_params__
267-
}
268-
269-
for i, t in enumerate(self.body_builder.__type_params__):
270-
if t.__bound__ is not None:
271-
v = t.__bound__
272-
else:
273-
v = item[i]
274-
closure[t.__name__].cell_contents = v
275-
276-
return self
270+
if (
271+
not hasattr(self.body_builder, "__type_params__")
272+
or not self.body_builder.__type_params__
273+
):
274+
raise RuntimeError(
275+
"using a generic call requires the func be generic (i.e., have type_params)"
276+
)
277+
# this also copies the function so that the original body_builder remains "generic"
278+
body_builder = copy_func(self.body_builder)
279+
reified_type_params = []
280+
for i, t in enumerate(body_builder.__type_params__):
281+
if t.__bound__ is not None:
282+
r = ReifiedTypeParams(t.__name__, t.__bound__)
283+
else:
284+
r = ReifiedTypeParams(t.__name__, item[i])
285+
reified_type_params.append(r)
286+
if r.name in body_builder.__code__.co_freevars:
287+
free_i = body_builder.__code__.co_freevars.index(r.name)
288+
body_builder.__closure__[free_i].cell_contents = r.val
289+
290+
body_builder.__type_params__ = tuple(reified_type_params)
291+
292+
return FuncBase(
293+
body_builder,
294+
self.func_op_ctor,
295+
self.return_op_ctor,
296+
self.call_op_ctor,
297+
self.return_types,
298+
self.sym_visibility,
299+
self.arg_attrs,
300+
self.res_attrs,
301+
self.func_attrs,
302+
self.loc,
303+
self.ip,
304+
self.qualname,
305+
)
277306

278307

279308
@make_maybe_no_args_decorator

requirements.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,5 @@ black
44
bytecode
55
inflection
66
numpy
7-
astunparse
7+
astunparse
8+
cloudpickle

0 commit comments

Comments
 (0)