Skip to content

Commit 8d4ab82

Browse files
committed
use comments
1 parent 3b6adc9 commit 8d4ab82

File tree

12 files changed

+2730
-3369
lines changed

12 files changed

+2730
-3369
lines changed

mlir/extras/testing/testing.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
import difflib
22
import inspect
33
import platform
4+
import re
45
import shutil
56
import sys
67
import tempfile
78
from pathlib import Path
89
from subprocess import PIPE, Popen
9-
from textwrap import dedent
10+
from textwrap import dedent, indent
1011

1112
import pytest
1213

@@ -16,6 +17,25 @@
1617
from ...ir import Module
1718

1819

20+
def replace_correct_str_with_comments(fun, correct_with_checks):
21+
# fun = inspect.currentframe().f_back.f_code
22+
lines, lnum = inspect.findsource(fun)
23+
fun_src = inspect.getsource(fun)
24+
fun_src = re.sub(
25+
r'dedent\(\s+""".*"""\s+\)',
26+
"#####"
27+
+ indent(correct_with_checks, " ")
28+
+ "\n filecheck_with_comments(ctx.module)\n#####",
29+
fun_src,
30+
flags=re.DOTALL,
31+
)
32+
fun_src = fun_src.splitlines(keepends=True)
33+
lines[lnum : lnum + len(fun_src)] = fun_src
34+
35+
with open(inspect.getfile(fun), "w") as f:
36+
f.writelines(lines)
37+
38+
1939
def filecheck(correct: str, module):
2040
if isinstance(module, Module):
2141
assert module.operation.verify()

tests/test_arith.py

Lines changed: 154 additions & 188 deletions
Large diffs are not rendered by default.

tests/test_func.py

Lines changed: 105 additions & 156 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,12 @@
1616
from mlir.ir import FunctionType
1717

1818
# noinspection PyUnresolvedReferences
19-
from mlir.extras.testing import mlir_ctx as ctx, filecheck, MLIRContext
19+
from mlir.extras.testing import (
20+
mlir_ctx as ctx,
21+
filecheck,
22+
filecheck_with_comments,
23+
MLIRContext,
24+
)
2025

2126
# needed since the fix isn't defined here nor conftest.py
2227
pytest.mark.usefixtures("ctx")
@@ -31,17 +36,13 @@ def demo_fun1():
3136
assert hasattr(demo_fun1, "emit")
3237
assert inspect.ismethod(demo_fun1.emit)
3338
demo_fun1.emit()
34-
correct = dedent(
35-
"""\
36-
module {
37-
func.func @demo_fun1() -> i32 {
38-
%c1_i32 = arith.constant 1 : i32
39-
return %c1_i32 : i32
40-
}
41-
}
42-
"""
43-
)
44-
filecheck(correct, ctx.module)
39+
40+
# CHECK: func.func @demo_fun1() -> i32 {
41+
# CHECK: %[[VAL_0:.*]] = arith.constant 1 : i32
42+
# CHECK: return %[[VAL_0]] : i32
43+
# CHECK: }
44+
45+
filecheck_with_comments(ctx.module)
4546

4647

4748
def test_declare_byte_rep(ctx: MLIRContext):
@@ -79,22 +80,18 @@ def demo_fun4(x: T.i32(), y: T.i32()) -> (T.i32(), T.i32()): ...
7980
demo_fun4(one, one)
8081

8182
ctx.module.operation.verify()
82-
correct = dedent(
83-
"""\
84-
module {
85-
func.func private @demo_fun1() -> i32
86-
func.func private @demo_fun2() -> (i32, i32)
87-
func.func private @demo_fun3(i32) -> (i32, i32)
88-
func.func private @demo_fun4(i32, i32) -> (i32, i32)
89-
%0 = func.call @demo_fun1() : () -> i32
90-
%1:2 = func.call @demo_fun2() : () -> (i32, i32)
91-
%c1_i32 = arith.constant 1 : i32
92-
%2:2 = func.call @demo_fun3(%c1_i32) : (i32) -> (i32, i32)
93-
%3:2 = func.call @demo_fun4(%c1_i32, %c1_i32) : (i32, i32) -> (i32, i32)
94-
}
95-
"""
96-
)
97-
filecheck(correct, ctx.module)
83+
84+
# CHECK: func.func private @demo_fun1() -> i32
85+
# CHECK: func.func private @demo_fun2() -> (i32, i32)
86+
# CHECK: func.func private @demo_fun3(i32) -> (i32, i32)
87+
# CHECK: func.func private @demo_fun4(i32, i32) -> (i32, i32)
88+
# CHECK: %[[VAL_0:.*]] = func.call @demo_fun1() : () -> i32
89+
# CHECK: %[[VAL_1:.*]]:2 = func.call @demo_fun2() : () -> (i32, i32)
90+
# CHECK: %[[VAL_2:.*]] = arith.constant 1 : i32
91+
# CHECK: %[[VAL_3:.*]]:2 = func.call @demo_fun3(%[[VAL_2]]) : (i32) -> (i32, i32)
92+
# CHECK: %[[VAL_4:.*]]:2 = func.call @demo_fun4(%[[VAL_2]], %[[VAL_2]]) : (i32, i32) -> (i32, i32)
93+
94+
filecheck_with_comments(ctx.module)
9895

9996

10097
def test_func_base_meta(ctx: MLIRContext):
@@ -104,31 +101,15 @@ def foo1():
104101
return one
105102

106103
foo1.emit()
107-
correct = dedent(
108-
"""\
109-
module {
110-
func.func @foo1() -> i32 {
111-
%c1_i32 = arith.constant 1 : i32
112-
return %c1_i32 : i32
113-
}
114-
}
115-
"""
116-
)
117-
filecheck(correct, ctx.module)
118-
119104
foo1()
120-
correct = dedent(
121-
"""\
122-
module {
123-
func.func @foo1() -> i32 {
124-
%c1_i32 = arith.constant 1 : i32
125-
return %c1_i32 : i32
126-
}
127-
%0 = func.call @foo1() : () -> i32
128-
}
129-
"""
130-
)
131-
filecheck(correct, ctx.module)
105+
106+
# CHECK: func.func @foo1() -> i32 {
107+
# CHECK: %[[VAL_0:.*]] = arith.constant 1 : i32
108+
# CHECK: return %[[VAL_0]] : i32
109+
# CHECK: }
110+
# CHECK: %[[VAL_1:.*]] = func.call @foo1() : () -> i32
111+
112+
filecheck_with_comments(ctx.module)
132113

133114

134115
def test_func_base_meta2(ctx: MLIRContext):
@@ -138,18 +119,14 @@ def foo1():
138119
return one
139120

140121
foo1()
141-
correct = dedent(
142-
"""\
143-
module {
144-
func.func @foo1() -> i32 {
145-
%c1_i32 = arith.constant 1 : i32
146-
return %c1_i32 : i32
147-
}
148-
%0 = func.call @foo1() : () -> i32
149-
}
150-
"""
151-
)
152-
filecheck(correct, ctx.module)
122+
123+
# CHECK: func.func @foo1() -> i32 {
124+
# CHECK: %[[VAL_0:.*]] = arith.constant 1 : i32
125+
# CHECK: return %[[VAL_0]] : i32
126+
# CHECK: }
127+
# CHECK: %[[VAL_1:.*]] = func.call @foo1() : () -> i32
128+
129+
filecheck_with_comments(ctx.module)
153130

154131

155132
def test_func_no_context():
@@ -160,18 +137,14 @@ def foo1():
160137

161138
with mlir_mod_ctx() as mod_ctx:
162139
foo1()
163-
correct = dedent(
164-
"""\
165-
module {
166-
func.func @foo1() -> i32 {
167-
%c1_i32 = arith.constant 1 : i32
168-
return %c1_i32 : i32
169-
}
170-
%0 = func.call @foo1() : () -> i32
171-
}
172-
"""
173-
)
174-
filecheck(correct, mod_ctx.module)
140+
141+
# CHECK: func.func @foo1() -> i32 {
142+
# CHECK: %[[VAL_0:.*]] = arith.constant 1 : i32
143+
# CHECK: return %[[VAL_0]] : i32
144+
# CHECK: }
145+
# CHECK: %[[VAL_1:.*]] = func.call @foo1() : () -> i32
146+
147+
filecheck_with_comments(mod_ctx.module)
175148

176149

177150
generics = M, K, N, dtype = list(map(TypeVar, ["M", "K", "N", "dtype"]))
@@ -188,17 +161,13 @@ def matmul_i32_i32(
188161

189162
def test_func_no_context_2(ctx: MLIRContext):
190163
matmul_i32_i32[16, 16].emit()
191-
correct = dedent(
192-
"""\
193-
module {
194-
func.func @matmul_i32_i32(%arg0: memref<16x16xi32>, %arg1: memref<16x16xi32>, %arg2: memref<16x16xi32>) {
195-
linalg.matmul {cast = #linalg.type_fn<cast_signed>} ins(%arg0, %arg1 : memref<16x16xi32>, memref<16x16xi32>) outs(%arg2 : memref<16x16xi32>)
196-
return
197-
}
198-
}
199-
"""
200-
)
201-
filecheck(correct, ctx.module)
164+
165+
# CHECK: func.func @matmul_i32_i32(%[[VAL_0:.*]]: memref<16x16xi32>, %[[VAL_1:.*]]: memref<16x16xi32>, %[[VAL_2:.*]]: memref<16x16xi32>) {
166+
# CHECK: linalg.matmul {cast = #linalg.type_fn<cast_signed>} ins(%[[VAL_0]], %[[VAL_1]] : memref<16x16xi32>, memref<16x16xi32>) outs(%[[VAL_2]] : memref<16x16xi32>)
167+
# CHECK: return
168+
# CHECK: }
169+
170+
filecheck_with_comments(ctx.module)
202171

203172

204173
def test_generics_just_args(ctx: MLIRContext):
@@ -212,17 +181,13 @@ def mat_product_kernel(
212181
one = arith.constant(1.0, dtype)
213182

214183
mat_product_kernel[32, 32, 32, T.f32()].emit()
215-
correct = dedent(
216-
"""\
217-
module {
218-
func.func @mat_product_kernel(%arg0: memref<32x32xf32>, %arg1: memref<32x32xf32>, %arg2: memref<32x32xf32>) {
219-
%cst = arith.constant 1.000000e+00 : f32
220-
return
221-
}
222-
}
223-
"""
224-
)
225-
filecheck(correct, ctx.module)
184+
185+
# CHECK: func.func @mat_product_kernel(%[[VAL_0:.*]]: memref<32x32xf32>, %[[VAL_1:.*]]: memref<32x32xf32>, %[[VAL_2:.*]]: memref<32x32xf32>) {
186+
# CHECK: %[[VAL_3:.*]] = arith.constant 1.000000e+00 : f32
187+
# CHECK: return
188+
# CHECK: }
189+
190+
filecheck_with_comments(ctx.module)
226191

227192

228193
def test_generics_closure(ctx: MLIRContext):
@@ -237,17 +202,13 @@ def mat_product_kernel(
237202
one = arith.constant(1, dtype)
238203

239204
mat_product_kernel[32, 32, 32, T.i32()].emit()
240-
correct = dedent(
241-
"""\
242-
module {
243-
func.func @mat_product_kernel(%arg0: memref<32x32xi32>, %arg1: memref<32x32xi32>, %arg2: memref<32x32xi32>) {
244-
%c1_i32 = arith.constant 1 : i32
245-
return
246-
}
247-
}
248-
"""
249-
)
250-
filecheck(correct, ctx.module)
205+
206+
# CHECK: func.func @mat_product_kernel(%[[VAL_0:.*]]: memref<32x32xi32>, %[[VAL_1:.*]]: memref<32x32xi32>, %[[VAL_2:.*]]: memref<32x32xi32>) {
207+
# CHECK: %[[VAL_3:.*]] = arith.constant 1 : i32
208+
# CHECK: return
209+
# CHECK: }
210+
211+
filecheck_with_comments(ctx.module)
251212

252213

253214
def test_generics_with_canonicalizations(ctx: MLIRContext):
@@ -271,31 +232,27 @@ def mat_product_kernel(
271232
C[x, y] = tmp + one
272233

273234
mat_product_kernel[32, 32, 32, T.f32()].emit()
274-
correct = dedent(
275-
"""\
276-
module {
277-
func.func @mat_product_kernel(%arg0: memref<32x32xf32>, %arg1: memref<32x32xf32>, %arg2: memref<32x32xf32>) {
278-
%c1 = arith.constant 1 : index
279-
%c1_0 = arith.constant 1 : index
280-
%cst = arith.constant 1.000000e+00 : f32
281-
%cst_1 = arith.constant 0.000000e+00 : f32
282-
%c0 = arith.constant 0 : index
283-
%c32 = arith.constant 32 : index
284-
%c1_2 = arith.constant 1 : index
285-
%0 = scf.for %arg3 = %c0 to %c32 step %c1_2 iter_args(%arg4 = %cst_1) -> (f32) {
286-
%2 = memref.load %arg0[%c1, %arg3] : memref<32x32xf32>
287-
%3 = memref.load %arg1[%arg3, %c1_0] : memref<32x32xf32>
288-
%4 = math.fma %2, %3, %arg4 : f32
289-
scf.yield %4 : f32
290-
}
291-
%1 = arith.addf %0, %cst : f32
292-
memref.store %1, %arg2[%c1, %c1_0] : memref<32x32xf32>
293-
return
294-
}
295-
}
296-
"""
297-
)
298-
filecheck(correct, ctx.module)
235+
236+
# CHECK: func.func @mat_product_kernel(%[[VAL_0:.*]]: memref<32x32xf32>, %[[VAL_1:.*]]: memref<32x32xf32>, %[[VAL_2:.*]]: memref<32x32xf32>) {
237+
# CHECK: %[[VAL_3:.*]] = arith.constant 1 : index
238+
# CHECK: %[[VAL_4:.*]] = arith.constant 1 : index
239+
# CHECK: %[[VAL_5:.*]] = arith.constant 1.000000e+00 : f32
240+
# CHECK: %[[VAL_6:.*]] = arith.constant 0.000000e+00 : f32
241+
# CHECK: %[[VAL_7:.*]] = arith.constant 0 : index
242+
# CHECK: %[[VAL_8:.*]] = arith.constant 32 : index
243+
# CHECK: %[[VAL_9:.*]] = arith.constant 1 : index
244+
# CHECK: %[[VAL_10:.*]] = scf.for %[[VAL_11:.*]] = %[[VAL_7]] to %[[VAL_8]] step %[[VAL_9]] iter_args(%[[VAL_12:.*]] = %[[VAL_6]]) -> (f32) {
245+
# CHECK: %[[VAL_13:.*]] = memref.load %[[VAL_0]]{{\[}}%[[VAL_3]], %[[VAL_11]]] : memref<32x32xf32>
246+
# CHECK: %[[VAL_14:.*]] = memref.load %[[VAL_1]]{{\[}}%[[VAL_11]], %[[VAL_4]]] : memref<32x32xf32>
247+
# CHECK: %[[VAL_15:.*]] = math.fma %[[VAL_13]], %[[VAL_14]], %[[VAL_12]] : f32
248+
# CHECK: scf.yield %[[VAL_15]] : f32
249+
# CHECK: }
250+
# CHECK: %[[VAL_16:.*]] = arith.addf %[[VAL_17:.*]], %[[VAL_5]] : f32
251+
# CHECK: memref.store %[[VAL_16]], %[[VAL_2]]{{\[}}%[[VAL_3]], %[[VAL_4]]] : memref<32x32xf32>
252+
# CHECK: return
253+
# CHECK: }
254+
255+
filecheck_with_comments(ctx.module)
299256

300257

301258
def test_raii_mlir_context_module():
@@ -310,17 +267,13 @@ def demo_fun1():
310267
assert hasattr(demo_fun1, "emit")
311268
assert inspect.ismethod(demo_fun1.emit)
312269
demo_fun1.emit()
313-
correct = dedent(
314-
"""\
315-
module {
316-
func.func @demo_fun1() -> i32 {
317-
%c1_i32 = arith.constant 1 : i32
318-
return %c1_i32 : i32
319-
}
320-
}
321-
"""
322-
)
323-
filecheck(correct, tls.ctx.module)
270+
271+
# CHECK: func.func @demo_fun1() -> i32 {
272+
# CHECK: %[[VAL_0:.*]] = arith.constant 1 : i32
273+
# CHECK: return %[[VAL_0]] : i32
274+
# CHECK: }
275+
276+
filecheck_with_comments(tls.ctx.module)
324277

325278

326279
def test_explicit_function_type(ctx: MLIRContext):
@@ -334,14 +287,10 @@ def demo_fun1(a, b):
334287
return one
335288

336289
demo_fun1.emit()
337-
correct = dedent(
338-
"""\
339-
module {
340-
func.func @demo_fun1(%arg0: i32, %arg1: i32) -> i32 {
341-
%c1_i32 = arith.constant 1 : i32
342-
return %c1_i32 : i32
343-
}
344-
}
345-
"""
346-
)
347-
filecheck(correct, ctx.module)
290+
291+
# CHECK: func.func @demo_fun1(%[[VAL_0:.*]]: i32, %[[VAL_1:.*]]: i32) -> i32 {
292+
# CHECK: %[[VAL_2:.*]] = arith.constant 1 : i32
293+
# CHECK: return %[[VAL_2]] : i32
294+
# CHECK: }
295+
296+
filecheck_with_comments(ctx.module)

0 commit comments

Comments
 (0)