@@ -460,3 +460,33 @@ func.func @cast_f16_to_f32_write(%arg0: memref<16x16xf16>, %arg1: memref<16x16xf
460460 vector.transfer_write %cast , %arg3 [%c0 , %c0 ] {in_bounds = [true , true ]} : vector <16 x16 xf32 >, memref <16 x16 xf32 >
461461 return
462462}
463+
464+ // -----
465+
// Indexing maps for the contraction below, over iteration dims (d0, d1, d2):
// LHS is indexed by (d0, d2), RHS by (d2, d1), and the accumulator/result by
// (d0, d1); d2 is the reduction dimension.
#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
#map2 = affine_map<(d0, d1, d2) -> (d2, d1)>
#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>

// The input reads a 64x128 f16 tile, extends it to f32, transposes it to
// 128x64 and feeds it as the RHS of a contraction. The expected output has the
// transpose absorbed into the transfer_read as a (d0, d1) -> (d1, d0)
// permutation map, followed by the extf, with no vector.transpose remaining
// ahead of the contraction.
// CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0, d1) -> (d1, d0)>
// CHECK-LABEL: func @fold_transpose_into_transfer_read(
// CHECK-SAME: %[[ALLOC:.+]]: memref<64x128xf16>
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[CST:.+]] = arith.constant 0.000000e+00 : f16
// CHECK: %[[READ:.+]] = vector.transfer_read %[[ALLOC]][%[[C0]], %[[C0]]], %[[CST]] {in_bounds = [true, true], permutation_map = #[[$MAP]]}
// CHECK: %[[EXTF1:.+]] = arith.extf %[[READ]]
// CHECK-NOT: vector.transpose
// CHECK: %[[RESULT:.+]] = vector.contract
func.func @fold_transpose_into_transfer_read(%alloc: memref<64x128xf16>, %vector: vector<32x128xf16>, %alloc2: memref<32x64xf32>) {
  %c0 = arith.constant 0 : index
  %cst = arith.constant 0.000000e+00 : f16
  // Zero-initialized f32 accumulator for the contraction.
  %init = arith.constant dense<0.000000e+00> : vector<32x64xf32>
  // Read the whole 64x128 f16 buffer; this is the read the transpose is folded into.
  %0 = vector.transfer_read %alloc[%c0, %c0], %cst {in_bounds = [true, true]} : memref<64x128xf16>, vector<64x128xf16>
  // Both contraction operands are widened from f16 to f32.
  %1 = arith.extf %0 : vector<64x128xf16> to vector<64x128xf32>
  %2 = arith.extf %vector : vector<32x128xf16> to vector<32x128xf32>
  // Transpose sits between the extf and the contraction in the input IR.
  %3 = vector.transpose %1, [1, 0] : vector<64x128xf32> to vector<128x64xf32>
  // (32x128) * (128x64) -> 32x64, reducing over the shared 128-sized dimension.
  %4 = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %2, %3, %init : vector<32x128xf32>, vector<128x64xf32> into vector<32x64xf32>
  vector.transfer_write %4, %alloc2[%c0, %c0] {in_bounds = [true, true]} : vector<32x64xf32>, memref<32x64xf32>
  return
}
491+
492+ // -----
0 commit comments