11import inspect
22import sys
33from textwrap import dedent
4+ from typing import TypeVar
45
56import pytest
67
78import mlir .extras .types as T
9+
10+ from mlir .extras .ast .canonicalize import canonicalize
811from mlir .extras .context import mlir_mod_ctx
912from mlir .extras .dialects .ext .arith import constant
1013from mlir .extras .dialects .ext .func import func
11- from mlir .extras .dialects .ext import linalg
14+ from mlir .extras .dialects .ext import linalg , arith , scf , memref
1215
1316# noinspection PyUnresolvedReferences
1417from mlir .extras .testing import mlir_ctx as ctx , filecheck , MLIRContext
@@ -40,8 +43,7 @@ def demo_fun1():
4043
4144
4245def test_declare_byte_rep (ctx : MLIRContext ):
43- def demo_fun1 ():
44- ...
46+ def demo_fun1 (): ...
4547
4648 if sys .version_info .minor == 12 :
4749 assert demo_fun1 .__code__ .co_code == b"\x97 \x00 y\x00 "
@@ -55,20 +57,16 @@ def demo_fun1():
5557
5658def test_declare (ctx : MLIRContext ):
5759 @func
58- def demo_fun1 () -> T .i32 ():
59- ...
60+ def demo_fun1 () -> T .i32 (): ...
6061
6162 @func
62- def demo_fun2 () -> (T .i32 (), T .i32 ()):
63- ...
63+ def demo_fun2 () -> (T .i32 (), T .i32 ()): ...
6464
6565 @func
66- def demo_fun3 (x : T .i32 ()) -> (T .i32 (), T .i32 ()):
67- ...
66+ def demo_fun3 (x : T .i32 ()) -> (T .i32 (), T .i32 ()): ...
6867
6968 @func
70- def demo_fun4 (x : T .i32 (), y : T .i32 ()) -> (T .i32 (), T .i32 ()):
71- ...
69+ def demo_fun4 (x : T .i32 (), y : T .i32 ()) -> (T .i32 (), T .i32 ()): ...
7270
7371 demo_fun1 ()
7472 demo_fun2 ()
@@ -197,3 +195,99 @@ def test_func_no_context_2(ctx: MLIRContext):
197195 """
198196 )
199197 filecheck (correct , ctx .module )
198+
199+
200+ def test_generics_just_args (ctx : MLIRContext ):
201+ @func (generics = ["M" , "K" , "N" , "dtype" ])
202+ def mat_product_kernel (
203+ A : "T.memref(M, K, dtype)" ,
204+ B : "T.memref(K, N, dtype)" ,
205+ C : "T.memref(M, N, dtype)" ,
206+ ):
207+ one = arith .constant (1.0 )
208+
209+ mat_product_kernel [32 , 32 , 32 , T .i32 ()].emit ()
210+ correct = dedent (
211+ """\
212+ module {
213+ func.func @mat_product_kernel(%arg0: memref<32x32xi32>, %arg1: memref<32x32xi32>, %arg2: memref<32x32xi32>) {
214+ %cst = arith.constant 1.000000e+00 : f32
215+ return
216+ }
217+ }
218+ """
219+ )
220+ filecheck (correct , ctx .module )
221+
222+
223+ def test_generics_closure (ctx : MLIRContext ):
224+ dtype = None
225+
226+ @func (generics = ["M" , "K" , "N" , "dtype" ])
227+ def mat_product_kernel (
228+ A : "T.memref(M, K, dtype)" ,
229+ B : "T.memref(K, N, dtype)" ,
230+ C : "T.memref(M, N, dtype)" ,
231+ ):
232+ one = arith .constant (1 , dtype )
233+
234+ mat_product_kernel [32 , 32 , 32 , T .i32 ()].emit ()
235+ correct = dedent (
236+ """\
237+ module {
238+ func.func @mat_product_kernel(%arg0: memref<32x32xi32>, %arg1: memref<32x32xi32>, %arg2: memref<32x32xi32>) {
239+ %c1_i32 = arith.constant 1 : i32
240+ return
241+ }
242+ }
243+ """
244+ )
245+ filecheck (correct , ctx .module )
246+
247+
248+ def test_generics_with_canonicalizations (ctx : MLIRContext ):
249+ dtype = None
250+ K = None
251+
252+ @func (generics = ["M" , "K" , "N" , "dtype" ])
253+ @canonicalize (using = (arith .canonicalizer , scf .canonicalizer ))
254+ def mat_product_kernel (
255+ A : "T.memref(M, K, dtype)" ,
256+ B : "T.memref(K, N, dtype)" ,
257+ C : "T.memref(M, N, dtype)" ,
258+ ):
259+ x = arith .constant (1 , index = True )
260+ y = arith .constant (1 , index = True )
261+ one = arith .constant (1.0 , type = dtype )
262+ tmp = arith .constant (0 , type = dtype )
263+ for k , tmp in scf .range_ (K , iter_args = [tmp ]):
264+ tmp += A [x , k ] * B [k , y ]
265+ tmp = yield tmp
266+ C [x , y ] = tmp + one
267+
268+ mat_product_kernel [32 , 32 , 32 , T .f32 ()].emit ()
269+ correct = dedent (
270+ """\
271+ module {
272+ func.func @mat_product_kernel(%arg0: memref<32x32xf32>, %arg1: memref<32x32xf32>, %arg2: memref<32x32xf32>) {
273+ %c1 = arith.constant 1 : index
274+ %c1_0 = arith.constant 1 : index
275+ %cst = arith.constant 1.000000e+00 : f32
276+ %cst_1 = arith.constant 0.000000e+00 : f32
277+ %c0 = arith.constant 0 : index
278+ %c32 = arith.constant 32 : index
279+ %c1_2 = arith.constant 1 : index
280+ %0 = scf.for %arg3 = %c0 to %c32 step %c1_2 iter_args(%arg4 = %cst_1) -> (f32) {
281+ %2 = memref.load %arg0[%c1, %arg3] : memref<32x32xf32>
282+ %3 = memref.load %arg1[%arg3, %c1_0] : memref<32x32xf32>
283+ %4 = math.fma %2, %3, %arg4 : f32
284+ scf.yield %4 : f32
285+ }
286+ %1 = arith.addf %0, %cst : f32
287+ memref.store %1, %arg2[%c1, %c1_0] : memref<32x32xf32>
288+ return
289+ }
290+ }
291+ """
292+ )
293+ filecheck (correct , ctx .module )
0 commit comments