Skip to content

Commit 4e9396e

Browse files
committed
context fixture
1 parent 2d62353 commit 4e9396e

File tree

6 files changed

+20
-26
lines changed

6 files changed

+20
-26
lines changed

mlir/extras/context.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,24 @@ def __str__(self):
1515
return str(self.module)
1616

1717

18+
@contextmanager
19+
def mlir_mod(
20+
src: Optional[str] = None,
21+
location: ir.Location = None,
22+
) -> ir.Module:
23+
with ExitStack() as stack:
24+
if location is None:
25+
location = ir.Location.unknown()
26+
stack.enter_context(location)
27+
if src is not None:
28+
module = ir.Module.parse(src)
29+
else:
30+
module = ir.Module.create()
31+
ip = ir.InsertionPoint(module.body)
32+
stack.enter_context(ip)
33+
yield module
34+
35+
1836
@contextmanager
1937
def mlir_mod_ctx(
2038
src: Optional[str] = None,
@@ -112,24 +130,6 @@ def __str__(self):
112130
return str(self.module)
113131

114132

115-
@contextmanager
116-
def mlir_mod(
117-
src: Optional[str] = None,
118-
location: ir.Location = None,
119-
) -> ir.Module:
120-
with ExitStack() as stack:
121-
if location is None:
122-
location = ir.Location.unknown()
123-
stack.enter_context(location)
124-
if src is not None:
125-
module = ir.Module.parse(src)
126-
else:
127-
module = ir.Module.create()
128-
ip = ir.InsertionPoint(module.body)
129-
stack.enter_context(ip)
130-
yield module
131-
132-
133133
@contextlib.contextmanager
134134
def enable_multithreading(context=None):
135135
from ..ir import Context

mlir/extras/testing/testing.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -123,12 +123,12 @@ def filecheck_with_comments(module):
123123
raise ValueError(f"\n{err}")
124124

125125

126-
@pytest.fixture
126+
@pytest.fixture(scope="function")
127127
def mlir_ctx() -> MLIRContext:
128128
with mlir_mod_ctx(allow_unregistered_dialects=True) as ctx:
129129
yield ctx
130130

131131

132-
@pytest.fixture
132+
@pytest.fixture(scope="function")
133133
def backend() -> LLVMJITBackend:
134134
return LLVMJITBackend()

tests/test_gpu.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1220,8 +1220,6 @@ def smol_matmul(
12201220
def gpu_module():
12211221
smol_matmul.emit()
12221222

1223-
print(gpu_module)
1224-
12251223
lowered_module = run_pipeline(
12261224
gpu_module,
12271225
Pipeline()

tests/test_nvgpu_nvvm.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -413,7 +413,6 @@ def main(module: any_op_t()):
413413
# CHECK: }
414414
# CHECK: }
415415

416-
print(mod)
417416
filecheck_with_comments(mod)
418417

419418

tests/test_runtime.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -522,7 +522,6 @@ def memfoo(mem: ranked_memref_kxk_f32):
522522
mem[i, i] = mem[i, i] + mem[i, i] * sitofp(T.f32(), index_cast(T.i32(), i))
523523

524524
memfoo.emit()
525-
print(ctx.module)
526525

527526
module = backend.compile(
528527
ctx.module,

tests/test_vector.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -120,8 +120,6 @@ def pats():
120120
),
121121
)
122122

123-
print(vectorized_module)
124-
125123
compiled_module = backend.compile(
126124
find_ops(
127125
vectorized_module.operation,

0 commit comments

Comments
 (0)