1111from mlir .extras .dialects .ext import arith , memref , scf , gpu , linalg , transform
1212from mlir .dialects .transform import any_op_t
1313from mlir .extras .dialects .ext .func import func
14- from mlir .extras .dialects .ext .nvgpu import tensormap_descriptor
14+ from mlir .extras .dialects .ext .nvgpu import (
15+ TensorMapDescriptorType ,
16+ TensorMapSwizzleKind ,
17+ TensorMapL2PromoKind ,
18+ TensorMapOOBKind ,
19+ TensorMapInterleaveKind ,
20+ )
1521from mlir .dialects .transform .structured import MatchInterfaceEnum
1622from mlir .dialects .memref import cast
1723from mlir .dialects .nvgpu import tma_create_descriptor
@@ -37,7 +43,13 @@ def create_tensor_map(
3743 crd0 = arith .constant (64 , index = True )
3844 crd1 = arith .constant (128 , index = True )
3945 device_ptr_2d_unranked = cast (T .memref (element_type = T .f32 ()), device_ptr_2d )
40- tensor_map_2d = tensormap_descriptor (T .memref (32 , 32 , T .f32 (), memory_space = 3 ))
46+ tensor_map_2d = TensorMapDescriptorType .get (
47+ T .memref (32 , 32 , T .f32 (), memory_space = 3 ),
48+ TensorMapSwizzleKind .SWIZZLE_NONE ,
49+ TensorMapL2PromoKind .L2PROMO_NONE ,
50+ TensorMapOOBKind .OOB_NAN ,
51+ TensorMapInterleaveKind .INTERLEAVE_NONE ,
52+ )
4153 tensor_map_2d = tma_create_descriptor (
4254 tensor_map_2d , device_ptr_2d_unranked , [crd0 , crd1 ]
4355 )
@@ -187,8 +199,7 @@ def payload():
187199 compute_linspace_val .emit ()
188200
189201 @func
190- def printMemrefF32 (x : T .memref (T .f32 ())):
191- ...
202+ def printMemrefF32 (x : T .memref (T .f32 ())): ...
192203
193204 printMemrefF32_ .append (printMemrefF32 )
194205
@@ -408,6 +419,7 @@ def main(module: any_op_t()):
408419
409420CUDA_RUNTIME_LIB_PATH = Path (_mlir_libs .__file__ ).parent / f"libmlir_cuda_runtime.so"
410421
422+
411423# based on https://github.com/llvm/llvm-project/blob/9cc2122bf5a81f7063c2a32b2cb78c8d615578a1/mlir/test/Integration/GPU/CUDA/TensorCore/sm80/transform-mma-sync-matmul-f16-f16-accum.mlir#L6
412424@pytest .mark .skipif (not CUDA_RUNTIME_LIB_PATH .exists (), reason = "no cuda library" )
413425def test_transform_mma_sync_matmul_f16_f16_accum_run (ctx : MLIRContext , capfd ):
@@ -536,8 +548,7 @@ def payload():
536548 compute_linspace_val .emit ()
537549
538550 @func
539- def printMemrefF32 (x : T .memref (T .f32 ())):
540- ...
551+ def printMemrefF32 (x : T .memref (T .f32 ())): ...
541552
542553 printMemrefF32_ .append (printMemrefF32 )
543554
0 commit comments