@@ -728,6 +728,52 @@ func.func @async_tma_load_multicast(
   func.return
 }
 
+func.func @async_tma_store(%tensorMap1d: !tensorMap1d, %tensorMap2d: !tensorMap2d, %tensorMap3d: !tensorMap3d, %tensorMap4d: !tensorMap4d, %tensorMap5d: !tensorMap5d,
+                           %buffer1d: memref<128xf32,3>,
+                           %buffer2d: memref<32x32xf32,3>,
+                           %buffer3d: memref<2x32x32xf32,3>,
+                           %buffer4d: memref<2x2x32x32xf32,3>,
+                           %buffer5d: memref<2x2x2x32x32xf32,3>) {
+  %c0 = arith.constant 0 : index
+  %crd0 = arith.constant 0 : index
+  %crd1 = arith.constant 0 : index
+  // CHECK: nvvm.cp.async.bulk.tensor.global.shared.cta %{{.*}} %{{.*}}, box[%{{.*}}]
+  nvgpu.tma.async.store %buffer1d to %tensorMap1d[%crd0] : memref<128xf32,3> -> !tensorMap1d
+  // CHECK: nvvm.cp.async.bulk.tensor.global.shared.cta %{{.*}} %{{.*}}, box[%{{.*}}, %{{.*}}]
+  nvgpu.tma.async.store %buffer2d to %tensorMap2d[%crd0, %crd1] : memref<32x32xf32,3> -> !tensorMap2d
+  // CHECK: nvvm.cp.async.bulk.tensor.global.shared.cta %{{.*}} %{{.*}}, box[%{{.*}}, %{{.*}}, %{{.*}}]
+  nvgpu.tma.async.store %buffer3d to %tensorMap3d[%crd0, %crd1, %crd0] : memref<2x32x32xf32,3> -> !tensorMap3d
+  // CHECK: nvvm.cp.async.bulk.tensor.global.shared.cta %{{.*}} %{{.*}}, box[%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}]
+  nvgpu.tma.async.store %buffer4d to %tensorMap4d[%crd0, %crd1, %crd1, %crd0] : memref<2x2x32x32xf32,3> -> !tensorMap4d
+  // CHECK: nvvm.cp.async.bulk.tensor.global.shared.cta %{{.*}} %{{.*}}, box[%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}]
+  nvgpu.tma.async.store %buffer5d to %tensorMap5d[%crd0, %crd1, %crd1, %crd0, %crd0] : memref<2x2x2x32x32xf32,3> -> !tensorMap5d
+  func.return
+}
+
+
+func.func @async_tma_store_predicate(%tensorMap1d: !tensorMap1d, %tensorMap2d: !tensorMap2d, %tensorMap3d: !tensorMap3d, %tensorMap4d: !tensorMap4d, %tensorMap5d: !tensorMap5d,
+                                     %buffer1d: memref<128xf32,3>,
+                                     %buffer2d: memref<32x32xf32,3>,
+                                     %buffer3d: memref<2x32x32xf32,3>,
+                                     %buffer4d: memref<2x2x32x32xf32,3>,
+                                     %buffer5d: memref<2x2x2x32x32xf32,3>,
+                                     %p: i1) {
+  %c0 = arith.constant 0 : index
+  %crd0 = arith.constant 0 : index
+  %crd1 = arith.constant 0 : index
+  // CHECK: nvvm.cp.async.bulk.tensor.global.shared.cta %{{.*}} %{{.*}}, box[%{{.*}}], predicate = %{{.*}}
+  nvgpu.tma.async.store %buffer1d to %tensorMap1d[%crd0], predicate = %p : memref<128xf32,3> -> !tensorMap1d
+  // CHECK: nvvm.cp.async.bulk.tensor.global.shared.cta %{{.*}} %{{.*}}, box[%{{.*}}, %{{.*}}], predicate = %{{.*}}
+  nvgpu.tma.async.store %buffer2d to %tensorMap2d[%crd0, %crd1], predicate = %p : memref<32x32xf32,3> -> !tensorMap2d
+  // CHECK: nvvm.cp.async.bulk.tensor.global.shared.cta %{{.*}} %{{.*}}, box[%{{.*}}, %{{.*}}, %{{.*}}], predicate = %{{.*}}
+  nvgpu.tma.async.store %buffer3d to %tensorMap3d[%crd0, %crd1, %crd0], predicate = %p : memref<2x32x32xf32,3> -> !tensorMap3d
+  // CHECK: nvvm.cp.async.bulk.tensor.global.shared.cta %{{.*}} %{{.*}}, box[%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}], predicate = %{{.*}}
+  nvgpu.tma.async.store %buffer4d to %tensorMap4d[%crd0, %crd1, %crd1, %crd0], predicate = %p : memref<2x2x32x32xf32,3> -> !tensorMap4d
+  // CHECK: nvvm.cp.async.bulk.tensor.global.shared.cta %{{.*}} %{{.*}}, box[%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}], predicate = %{{.*}}
+  nvgpu.tma.async.store %buffer5d to %tensorMap5d[%crd0, %crd1, %crd1, %crd0, %crd0], predicate = %p : memref<2x2x2x32x32xf32,3> -> !tensorMap5d
+  func.return
+}
+
 func.func @create_tensor_map(%devicePtr2d : memref<64x128xf32>, %devicePtr1d : memref<128xf32>) {
   %crd0 = arith.constant 64 : index
   %crd1 = arith.constant 128 : index