Skip to content

Commit df9201d

Browse files
authored
fix func and add module/context convenience classes (#52)
1 parent b43ca57 commit df9201d

File tree

4 files changed

+123
-2
lines changed

4 files changed

+123
-2
lines changed

mlir/extras/context.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,43 @@ def mlir_mod_ctx(
4141
context._clear_live_operations()
4242

4343

44+
class RAIIMLIRContext:
45+
context: ir.Context
46+
location: ir.Location
47+
48+
def __init__(self, location: Optional[ir.Location] = None):
49+
self.context = ir.Context()
50+
self.context.__enter__()
51+
if location is None:
52+
location = ir.Location.unknown()
53+
self.location = location
54+
self.location.__enter__()
55+
56+
def __del__(self):
57+
self.location.__exit__(None, None, None)
58+
self.context.__exit__(None, None, None)
59+
assert ir.Context is None
60+
61+
62+
class ExplicitlyManagedModule:
63+
module: ir.Module
64+
_ip: ir.InsertionPoint
65+
66+
def __init__(self, src: Optional[str] = None):
67+
if src is not None:
68+
self.module = ir.Module.parse(src)
69+
else:
70+
self.module = ir.Module.create()
71+
self._ip = ir.InsertionPoint(self.module.body)
72+
self._ip.__enter__()
73+
74+
def finish(self):
75+
self._ip.__exit__(None, None, None)
76+
77+
def __str__(self):
78+
return str(self.module)
79+
80+
4481
@contextlib.contextmanager
4582
def enable_multithreading(context=None):
4683
from ..ir import Context

mlir/extras/dialects/ext/func.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,7 @@ def emit(self, *call_args, decl=False, force=False) -> FuncOp:
182182
input_types = self.input_types[:]
183183
for i, v in enumerate(input_types):
184184
if isinstance(v, str):
185-
input_types[i] = Type(eval(v, {"T": T}))
185+
input_types[i] = Type(eval(v, self.body_builder.__globals__))
186186
elif isalambda(v):
187187
input_types[i] = v()
188188
else:

tests/test_regions.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -521,6 +521,42 @@ def mod():
521521
filecheck(correct, ctx.module)
522522

523523

524+
M, K, N = 64, 32, 64
525+
526+
527+
@func(emit=False, sym_visibility="private")
528+
def matmul_i16_i16(
529+
A: "T.memref(M, K, T.i16())",
530+
B: "T.memref(K, N, T.i16())",
531+
C: "T.memref(M, N, T.i16())",
532+
):
533+
linalg.matmul(A, B, C)
534+
535+
536+
def test_defer_emit_3(ctx: MLIRContext):
537+
538+
matmul_i16_i16.emit(force=True)
539+
540+
@module
541+
def mod():
542+
matmul_i16_i16.emit(decl=True)
543+
544+
correct = dedent(
545+
"""\
546+
module {
547+
func.func private @matmul_i16_i16(%arg0: memref<64x32xi16>, %arg1: memref<32x64xi16>, %arg2: memref<64x64xi16>) {
548+
linalg.matmul {cast = #linalg.type_fn<cast_signed>} ins(%arg0, %arg1 : memref<64x32xi16>, memref<32x64xi16>) outs(%arg2 : memref<64x64xi16>)
549+
return
550+
}
551+
module {
552+
func.func private @matmul_i16_i16(memref<64x32xi16>, memref<32x64xi16>, memref<64x64xi16>)
553+
}
554+
}
555+
"""
556+
)
557+
filecheck(correct, ctx.module)
558+
559+
524560
def test_successor_ctx_manager(ctx: MLIRContext):
525561
@func
526562
def foo1():

tests/test_runtime.py

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,12 @@
55

66
import numpy as np
77
import pytest
8-
from mlir.ir import UnitAttr, Module, StridedLayoutAttr
8+
from mlir.ir import UnitAttr, Module, StridedLayoutAttr, InsertionPoint, Context
99
from mlir.runtime import get_unranked_memref_descriptor, get_ranked_memref_descriptor
1010

1111
import mlir.extras.types as T
1212
from mlir.extras.ast.canonicalize import canonicalize
13+
from mlir.extras.context import RAIIMLIRContext, ExplicitlyManagedModule
1314
from mlir.extras.dialects.ext import linalg
1415
from mlir.dialects.arith import sitofp, index_cast
1516
from mlir.extras.dialects.ext.arith import constant
@@ -872,3 +873,50 @@ def tenfoo(
872873
#
873874
# invoker.memfoo_capi_wrapper(AA, BB, CC)
874875
# # assert np.array_equal(A + B, C)
876+
877+
878+
def test_raii_context():
879+
def foo():
880+
ctx = RAIIMLIRContext()
881+
mod = Module.create()
882+
with InsertionPoint(mod.body):
883+
884+
@func(emit=True)
885+
def foo(x: T.i32()):
886+
return x
887+
888+
correct = dedent(
889+
"""\
890+
module {
891+
func.func @foo(%arg0: i32) -> i32 {
892+
return %arg0 : i32
893+
}
894+
}
895+
"""
896+
)
897+
filecheck(correct, mod)
898+
899+
foo()
900+
assert Context.current is None
901+
902+
903+
def test_explicit_module():
904+
ctx = RAIIMLIRContext()
905+
mod = ExplicitlyManagedModule()
906+
907+
@func(emit=True)
908+
def foo(x: T.i32()):
909+
return x
910+
911+
mod.finish()
912+
913+
correct = dedent(
914+
"""\
915+
module {
916+
func.func @foo(%arg0: i32) -> i32 {
917+
return %arg0 : i32
918+
}
919+
}
920+
"""
921+
)
922+
filecheck(correct, mod)

0 commit comments

Comments
 (0)