1616from mlir .ir import FunctionType
1717
1818# noinspection PyUnresolvedReferences
19- from mlir .extras .testing import mlir_ctx as ctx , filecheck , MLIRContext
19+ from mlir .extras .testing import (
20+ mlir_ctx as ctx ,
21+ filecheck ,
22+ filecheck_with_comments ,
23+ MLIRContext ,
24+ )
2025
2126# needed since the fix isn't defined here nor conftest.py
2227pytest .mark .usefixtures ("ctx" )
@@ -31,17 +36,13 @@ def demo_fun1():
3136 assert hasattr (demo_fun1 , "emit" )
3237 assert inspect .ismethod (demo_fun1 .emit )
3338 demo_fun1 .emit ()
34- correct = dedent (
35- """\
36- module {
37- func.func @demo_fun1() -> i32 {
38- %c1_i32 = arith.constant 1 : i32
39- return %c1_i32 : i32
40- }
41- }
42- """
43- )
44- filecheck (correct , ctx .module )
39+
40+ # CHECK: func.func @demo_fun1() -> i32 {
41+ # CHECK: %[[VAL_0:.*]] = arith.constant 1 : i32
42+ # CHECK: return %[[VAL_0]] : i32
43+ # CHECK: }
44+
45+ filecheck_with_comments (ctx .module )
4546
4647
4748def test_declare_byte_rep (ctx : MLIRContext ):
@@ -79,22 +80,18 @@ def demo_fun4(x: T.i32(), y: T.i32()) -> (T.i32(), T.i32()): ...
7980 demo_fun4 (one , one )
8081
8182 ctx .module .operation .verify ()
82- correct = dedent (
83- """\
84- module {
85- func.func private @demo_fun1() -> i32
86- func.func private @demo_fun2() -> (i32, i32)
87- func.func private @demo_fun3(i32) -> (i32, i32)
88- func.func private @demo_fun4(i32, i32) -> (i32, i32)
89- %0 = func.call @demo_fun1() : () -> i32
90- %1:2 = func.call @demo_fun2() : () -> (i32, i32)
91- %c1_i32 = arith.constant 1 : i32
92- %2:2 = func.call @demo_fun3(%c1_i32) : (i32) -> (i32, i32)
93- %3:2 = func.call @demo_fun4(%c1_i32, %c1_i32) : (i32, i32) -> (i32, i32)
94- }
95- """
96- )
97- filecheck (correct , ctx .module )
83+
84+ # CHECK: func.func private @demo_fun1() -> i32
85+ # CHECK: func.func private @demo_fun2() -> (i32, i32)
86+ # CHECK: func.func private @demo_fun3(i32) -> (i32, i32)
87+ # CHECK: func.func private @demo_fun4(i32, i32) -> (i32, i32)
88+ # CHECK: %[[VAL_0:.*]] = func.call @demo_fun1() : () -> i32
89+ # CHECK: %[[VAL_1:.*]]:2 = func.call @demo_fun2() : () -> (i32, i32)
90+ # CHECK: %[[VAL_2:.*]] = arith.constant 1 : i32
91+ # CHECK: %[[VAL_3:.*]]:2 = func.call @demo_fun3(%[[VAL_2]]) : (i32) -> (i32, i32)
92+ # CHECK: %[[VAL_4:.*]]:2 = func.call @demo_fun4(%[[VAL_2]], %[[VAL_2]]) : (i32, i32) -> (i32, i32)
93+
94+ filecheck_with_comments (ctx .module )
9895
9996
10097def test_func_base_meta (ctx : MLIRContext ):
@@ -104,31 +101,15 @@ def foo1():
104101 return one
105102
106103 foo1 .emit ()
107- correct = dedent (
108- """\
109- module {
110- func.func @foo1() -> i32 {
111- %c1_i32 = arith.constant 1 : i32
112- return %c1_i32 : i32
113- }
114- }
115- """
116- )
117- filecheck (correct , ctx .module )
118-
119104 foo1 ()
120- correct = dedent (
121- """\
122- module {
123- func.func @foo1() -> i32 {
124- %c1_i32 = arith.constant 1 : i32
125- return %c1_i32 : i32
126- }
127- %0 = func.call @foo1() : () -> i32
128- }
129- """
130- )
131- filecheck (correct , ctx .module )
105+
106+ # CHECK: func.func @foo1() -> i32 {
107+ # CHECK: %[[VAL_0:.*]] = arith.constant 1 : i32
108+ # CHECK: return %[[VAL_0]] : i32
109+ # CHECK: }
110+ # CHECK: %[[VAL_1:.*]] = func.call @foo1() : () -> i32
111+
112+ filecheck_with_comments (ctx .module )
132113
133114
134115def test_func_base_meta2 (ctx : MLIRContext ):
@@ -138,18 +119,14 @@ def foo1():
138119 return one
139120
140121 foo1 ()
141- correct = dedent (
142- """\
143- module {
144- func.func @foo1() -> i32 {
145- %c1_i32 = arith.constant 1 : i32
146- return %c1_i32 : i32
147- }
148- %0 = func.call @foo1() : () -> i32
149- }
150- """
151- )
152- filecheck (correct , ctx .module )
122+
123+ # CHECK: func.func @foo1() -> i32 {
124+ # CHECK: %[[VAL_0:.*]] = arith.constant 1 : i32
125+ # CHECK: return %[[VAL_0]] : i32
126+ # CHECK: }
127+ # CHECK: %[[VAL_1:.*]] = func.call @foo1() : () -> i32
128+
129+ filecheck_with_comments (ctx .module )
153130
154131
155132def test_func_no_context ():
@@ -160,18 +137,14 @@ def foo1():
160137
161138 with mlir_mod_ctx () as mod_ctx :
162139 foo1 ()
163- correct = dedent (
164- """\
165- module {
166- func.func @foo1() -> i32 {
167- %c1_i32 = arith.constant 1 : i32
168- return %c1_i32 : i32
169- }
170- %0 = func.call @foo1() : () -> i32
171- }
172- """
173- )
174- filecheck (correct , mod_ctx .module )
140+
141+ # CHECK: func.func @foo1() -> i32 {
142+ # CHECK: %[[VAL_0:.*]] = arith.constant 1 : i32
143+ # CHECK: return %[[VAL_0]] : i32
144+ # CHECK: }
145+ # CHECK: %[[VAL_1:.*]] = func.call @foo1() : () -> i32
146+
147+ filecheck_with_comments (mod_ctx .module )
175148
176149
177150generics = M , K , N , dtype = list (map (TypeVar , ["M" , "K" , "N" , "dtype" ]))
@@ -188,17 +161,13 @@ def matmul_i32_i32(
188161
189162def test_func_no_context_2 (ctx : MLIRContext ):
190163 matmul_i32_i32 [16 , 16 ].emit ()
191- correct = dedent (
192- """\
193- module {
194- func.func @matmul_i32_i32(%arg0: memref<16x16xi32>, %arg1: memref<16x16xi32>, %arg2: memref<16x16xi32>) {
195- linalg.matmul {cast = #linalg.type_fn<cast_signed>} ins(%arg0, %arg1 : memref<16x16xi32>, memref<16x16xi32>) outs(%arg2 : memref<16x16xi32>)
196- return
197- }
198- }
199- """
200- )
201- filecheck (correct , ctx .module )
164+
165+ # CHECK: func.func @matmul_i32_i32(%[[VAL_0:.*]]: memref<16x16xi32>, %[[VAL_1:.*]]: memref<16x16xi32>, %[[VAL_2:.*]]: memref<16x16xi32>) {
166+ # CHECK: linalg.matmul {cast = #linalg.type_fn<cast_signed>} ins(%[[VAL_0]], %[[VAL_1]] : memref<16x16xi32>, memref<16x16xi32>) outs(%[[VAL_2]] : memref<16x16xi32>)
167+ # CHECK: return
168+ # CHECK: }
169+
170+ filecheck_with_comments (ctx .module )
202171
203172
204173def test_generics_just_args (ctx : MLIRContext ):
@@ -212,17 +181,13 @@ def mat_product_kernel(
212181 one = arith .constant (1.0 , dtype )
213182
214183 mat_product_kernel [32 , 32 , 32 , T .f32 ()].emit ()
215- correct = dedent (
216- """\
217- module {
218- func.func @mat_product_kernel(%arg0: memref<32x32xf32>, %arg1: memref<32x32xf32>, %arg2: memref<32x32xf32>) {
219- %cst = arith.constant 1.000000e+00 : f32
220- return
221- }
222- }
223- """
224- )
225- filecheck (correct , ctx .module )
184+
185+ # CHECK: func.func @mat_product_kernel(%[[VAL_0:.*]]: memref<32x32xf32>, %[[VAL_1:.*]]: memref<32x32xf32>, %[[VAL_2:.*]]: memref<32x32xf32>) {
186+ # CHECK: %[[VAL_3:.*]] = arith.constant 1.000000e+00 : f32
187+ # CHECK: return
188+ # CHECK: }
189+
190+ filecheck_with_comments (ctx .module )
226191
227192
228193def test_generics_closure (ctx : MLIRContext ):
@@ -237,17 +202,13 @@ def mat_product_kernel(
237202 one = arith .constant (1 , dtype )
238203
239204 mat_product_kernel [32 , 32 , 32 , T .i32 ()].emit ()
240- correct = dedent (
241- """\
242- module {
243- func.func @mat_product_kernel(%arg0: memref<32x32xi32>, %arg1: memref<32x32xi32>, %arg2: memref<32x32xi32>) {
244- %c1_i32 = arith.constant 1 : i32
245- return
246- }
247- }
248- """
249- )
250- filecheck (correct , ctx .module )
205+
206+ # CHECK: func.func @mat_product_kernel(%[[VAL_0:.*]]: memref<32x32xi32>, %[[VAL_1:.*]]: memref<32x32xi32>, %[[VAL_2:.*]]: memref<32x32xi32>) {
207+ # CHECK: %[[VAL_3:.*]] = arith.constant 1 : i32
208+ # CHECK: return
209+ # CHECK: }
210+
211+ filecheck_with_comments (ctx .module )
251212
252213
253214def test_generics_with_canonicalizations (ctx : MLIRContext ):
@@ -271,31 +232,27 @@ def mat_product_kernel(
271232 C [x , y ] = tmp + one
272233
273234 mat_product_kernel [32 , 32 , 32 , T .f32 ()].emit ()
274- correct = dedent (
275- """\
276- module {
277- func.func @mat_product_kernel(%arg0: memref<32x32xf32>, %arg1: memref<32x32xf32>, %arg2: memref<32x32xf32>) {
278- %c1 = arith.constant 1 : index
279- %c1_0 = arith.constant 1 : index
280- %cst = arith.constant 1.000000e+00 : f32
281- %cst_1 = arith.constant 0.000000e+00 : f32
282- %c0 = arith.constant 0 : index
283- %c32 = arith.constant 32 : index
284- %c1_2 = arith.constant 1 : index
285- %0 = scf.for %arg3 = %c0 to %c32 step %c1_2 iter_args(%arg4 = %cst_1) -> (f32) {
286- %2 = memref.load %arg0[%c1, %arg3] : memref<32x32xf32>
287- %3 = memref.load %arg1[%arg3, %c1_0] : memref<32x32xf32>
288- %4 = math.fma %2, %3, %arg4 : f32
289- scf.yield %4 : f32
290- }
291- %1 = arith.addf %0, %cst : f32
292- memref.store %1, %arg2[%c1, %c1_0] : memref<32x32xf32>
293- return
294- }
295- }
296- """
297- )
298- filecheck (correct , ctx .module )
235+
236+ # CHECK: func.func @mat_product_kernel(%[[VAL_0:.*]]: memref<32x32xf32>, %[[VAL_1:.*]]: memref<32x32xf32>, %[[VAL_2:.*]]: memref<32x32xf32>) {
237+ # CHECK: %[[VAL_3:.*]] = arith.constant 1 : index
238+ # CHECK: %[[VAL_4:.*]] = arith.constant 1 : index
239+ # CHECK: %[[VAL_5:.*]] = arith.constant 1.000000e+00 : f32
240+ # CHECK: %[[VAL_6:.*]] = arith.constant 0.000000e+00 : f32
241+ # CHECK: %[[VAL_7:.*]] = arith.constant 0 : index
242+ # CHECK: %[[VAL_8:.*]] = arith.constant 32 : index
243+ # CHECK: %[[VAL_9:.*]] = arith.constant 1 : index
244+ # CHECK: %[[VAL_10:.*]] = scf.for %[[VAL_11:.*]] = %[[VAL_7]] to %[[VAL_8]] step %[[VAL_9]] iter_args(%[[VAL_12:.*]] = %[[VAL_6]]) -> (f32) {
245+ # CHECK: %[[VAL_13:.*]] = memref.load %[[VAL_0]]{{\[}}%[[VAL_3]], %[[VAL_11]]] : memref<32x32xf32>
246+ # CHECK: %[[VAL_14:.*]] = memref.load %[[VAL_1]]{{\[}}%[[VAL_11]], %[[VAL_4]]] : memref<32x32xf32>
247+ # CHECK: %[[VAL_15:.*]] = math.fma %[[VAL_13]], %[[VAL_14]], %[[VAL_12]] : f32
248+ # CHECK: scf.yield %[[VAL_15]] : f32
249+ # CHECK: }
250+ # CHECK: %[[VAL_16:.*]] = arith.addf %[[VAL_17:.*]], %[[VAL_5]] : f32
251+ # CHECK: memref.store %[[VAL_16]], %[[VAL_2]]{{\[}}%[[VAL_3]], %[[VAL_4]]] : memref<32x32xf32>
252+ # CHECK: return
253+ # CHECK: }
254+
255+ filecheck_with_comments (ctx .module )
299256
300257
301258def test_raii_mlir_context_module ():
@@ -310,17 +267,13 @@ def demo_fun1():
310267 assert hasattr (demo_fun1 , "emit" )
311268 assert inspect .ismethod (demo_fun1 .emit )
312269 demo_fun1 .emit ()
313- correct = dedent (
314- """\
315- module {
316- func.func @demo_fun1() -> i32 {
317- %c1_i32 = arith.constant 1 : i32
318- return %c1_i32 : i32
319- }
320- }
321- """
322- )
323- filecheck (correct , tls .ctx .module )
270+
271+ # CHECK: func.func @demo_fun1() -> i32 {
272+ # CHECK: %[[VAL_0:.*]] = arith.constant 1 : i32
273+ # CHECK: return %[[VAL_0]] : i32
274+ # CHECK: }
275+
276+ filecheck_with_comments (tls .ctx .module )
324277
325278
326279def test_explicit_function_type (ctx : MLIRContext ):
@@ -334,14 +287,10 @@ def demo_fun1(a, b):
334287 return one
335288
336289 demo_fun1 .emit ()
337- correct = dedent (
338- """\
339- module {
340- func.func @demo_fun1(%arg0: i32, %arg1: i32) -> i32 {
341- %c1_i32 = arith.constant 1 : i32
342- return %c1_i32 : i32
343- }
344- }
345- """
346- )
347- filecheck (correct , ctx .module )
290+
291+ # CHECK: func.func @demo_fun1(%[[VAL_0:.*]]: i32, %[[VAL_1:.*]]: i32) -> i32 {
292+ # CHECK: %[[VAL_2:.*]] = arith.constant 1 : i32
293+ # CHECK: return %[[VAL_2]] : i32
294+ # CHECK: }
295+
296+ filecheck_with_comments (ctx .module )
0 commit comments