1+ import sys
12from textwrap import dedent
23
34import pytest
@@ -138,6 +139,7 @@ def mat_product_kernel(
138139 filecheck (correct , ctx .module )
139140
140141
142+ @pytest .mark .skipif (sys .version_info < (3 , 10 ), reason = "requires python3.10 or higher" )
141143def test_class_call (ctx : MLIRContext ):
142144 scale = 1
143145 M , N , K = 4 * scale , 16 * scale , 8 * scale
@@ -163,7 +165,10 @@ def mat_product_kernel(
163165 b = alloc ((N , K ), T .f32 ())
164166 c = alloc ((M , K ), T .f32 ())
165167
166- MyClass1 .mat_product_kernel [grid_size := [4 , 4 , 1 ], block_size := [1 , 1 , 1 ]](a , b , c )
168+ # this is to avoid python 3.8 parser
169+ eval (
170+ "MyClass1.mat_product_kernel[grid_size:= [4, 4, 1], block_size:= [1, 1, 1]](a, b, c)"
171+ )
167172
168173 correct = dedent (
169174 """\
@@ -196,6 +201,7 @@ def mat_product_kernel(
196201 filecheck (correct , ctx .module )
197202
198203
204+ @pytest .mark .skipif (sys .version_info < (3 , 10 ), reason = "requires python3.10 or higher" )
199205def test_class_call_from_func (ctx : MLIRContext ):
200206 scale = 1
201207 M , N , K = 4 * scale , 16 * scale , 8 * scale
@@ -227,8 +233,9 @@ def main():
227233 b = alloc ((N , K ), T .f32 ())
228234 c = alloc ((M , K ), T .f32 ())
229235
230- MyClass1 .mat_product_kernel [grid_size := [4 , 4 , 1 ], block_size := [1 , 1 , 1 ]](
231- a , b , c
236+ MyClass1
237+ eval (
238+ "MyClass1.mat_product_kernel[grid_size:= [4, 4, 1], block_size:= [1, 1, 1]](a, b, c)"
232239 )
233240
234241 ctx .module .operation .verify ()
@@ -267,6 +274,7 @@ def main():
267274 filecheck (correct , ctx .module )
268275
269276
277+ @pytest .mark .skipif (sys .version_info < (3 , 10 ), reason = "requires python3.10 or higher" )
270278def test_async_object (ctx : MLIRContext ):
271279 scale = 1
272280 M , N , K = 4 * scale , 16 * scale , 8 * scale
@@ -300,12 +308,9 @@ def main():
300308
301309 w = wait ()
302310 stream = mlir_zero (llvm_ptr_t ())
303- MyClass1 .mat_product_kernel [grid_size := [4 , 4 , 1 ], block_size := [1 , 1 , 1 ]](
304- a ,
305- b ,
306- c ,
307- async_dependencies = [w ],
308- stream = stream ,
311+ MyClass1
312+ eval (
313+ "MyClass1.mat_product_kernel[grid_size:= [4, 4, 1], block_size:= [1, 1, 1]](a, b, c, async_dependencies=[w], stream=stream)"
309314 )
310315
311316 correct = dedent (
0 commit comments