Skip to content

Commit 4d6b992

Browse files
authored
[mlir][ArmSME] Fold MoveTileSliceToVector + TransferWrite to StoreTileSlice (#95907)
1 parent 5dde495 commit 4d6b992

File tree

3 files changed

+117
-6
lines changed

3 files changed

+117
-6
lines changed

mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp

Lines changed: 61 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -666,14 +666,69 @@ struct VectorPrintToArmSMELowering : public OpRewritePattern<vector::PrintOp> {
666666
}
667667
};
668668

669+
/// Folds a MoveTileSliceToVectorOp + TransferWriteOp to a StoreTileSliceOp.
670+
///
671+
/// BEFORE:
672+
/// ```mlir
673+
/// %slice = arm_sme.move_tile_slice_to_vector %tile[%index]
674+
/// : vector<[4]xf32> from vector<[4]x[4]xf32>
675+
/// vector.transfer_write %slice, %memref[%i, %j], %mask {in_bounds = [true]}
676+
/// : vector<[4]xf32>, memref<?x?xf32>
677+
/// ```
678+
/// AFTER:
679+
/// ```mlir
680+
/// arm_sme.store_tile_slice %tile, %index, %mask, %memref[%i, %j]
681+
/// : memref<?x?xf32>, vector<[4]xi1>, vector<[4]x[4]xf32>
682+
/// ```
683+
struct FoldTransferWriteOfExtractTileSlice
684+
: public OpRewritePattern<vector::TransferWriteOp> {
685+
using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;
686+
687+
LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
688+
PatternRewriter &rewriter) const final {
689+
if (!isa<MemRefType>(writeOp.getSource().getType()))
690+
return rewriter.notifyMatchFailure(writeOp, "destination not a memref");
691+
692+
if (writeOp.hasOutOfBoundsDim())
693+
return rewriter.notifyMatchFailure(writeOp,
694+
"not inbounds transfer write");
695+
696+
auto moveTileSlice =
697+
writeOp.getVector().getDefiningOp<arm_sme::MoveTileSliceToVectorOp>();
698+
if (!moveTileSlice)
699+
return rewriter.notifyMatchFailure(
700+
writeOp, "vector to store not from MoveTileSliceToVectorOp");
701+
702+
AffineMap map = writeOp.getPermutationMap();
703+
if (!map.isMinorIdentity())
704+
return rewriter.notifyMatchFailure(writeOp,
705+
"unsupported permutation map");
706+
707+
Value mask = writeOp.getMask();
708+
if (!mask) {
709+
auto maskType = writeOp.getVectorType().clone(rewriter.getI1Type());
710+
mask = rewriter.create<arith::ConstantOp>(
711+
writeOp.getLoc(), maskType, DenseElementsAttr::get(maskType, true));
712+
}
713+
714+
rewriter.replaceOpWithNewOp<arm_sme::StoreTileSliceOp>(
715+
writeOp, moveTileSlice.getTile(), moveTileSlice.getTileSliceIndex(),
716+
mask, writeOp.getSource(), writeOp.getIndices(),
717+
moveTileSlice.getLayout());
718+
return success();
719+
}
720+
};
721+
669722
} // namespace
670723

671724
void mlir::populateVectorToArmSMEPatterns(RewritePatternSet &patterns,
672725
MLIRContext &ctx) {
673-
patterns.add<BroadcastOpToArmSMELowering, SplatOpToArmSMELowering,
674-
TransferReadToArmSMELowering, TransferWriteToArmSMELowering,
675-
TransposeOpToArmSMELowering, VectorLoadToArmSMELowering,
676-
VectorStoreToArmSMELowering, VectorOuterProductToArmSMELowering,
677-
VectorExtractToArmSMELowering, VectorInsertToArmSMELowering,
678-
VectorPrintToArmSMELowering>(&ctx);
726+
patterns
727+
.add<BroadcastOpToArmSMELowering, SplatOpToArmSMELowering,
728+
TransferReadToArmSMELowering, TransferWriteToArmSMELowering,
729+
TransposeOpToArmSMELowering, VectorLoadToArmSMELowering,
730+
VectorStoreToArmSMELowering, VectorOuterProductToArmSMELowering,
731+
VectorExtractToArmSMELowering, VectorInsertToArmSMELowering,
732+
VectorPrintToArmSMELowering, FoldTransferWriteOfExtractTileSlice>(
733+
&ctx);
679734
}

mlir/test/Conversion/VectorToArmSME/unsupported.mlir

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,18 @@ func.func @transfer_write_2d__out_of_bounds(%vector : vector<[4]x[4]xf32>, %dest
145145
return
146146
}
147147

148+
// -----
149+
150+
// CHECK-LABEL: func.func @transfer_write_slice_unsupported_permutation
151+
// CHECK-NOT: arm_sme.store_tile_slice
152+
func.func @transfer_write_slice_unsupported_permutation(%vector: vector<[4]x[4]xf32>, %dest : memref<?x?xf32>, %slice_index: index) {
153+
%c0 = arith.constant 0 : index
154+
%slice = vector.extract %vector[%slice_index] : vector<[4]xf32> from vector<[4]x[4]xf32>
155+
vector.transfer_write %slice, %dest[%slice_index, %c0] { permutation_map = affine_map<(d0, d1) -> (d0)>, in_bounds = [true] }: vector<[4]xf32>, memref<?x?xf32>
156+
return
157+
}
158+
159+
148160
//===----------------------------------------------------------------------===//
149161
// vector.outerproduct
150162
//===----------------------------------------------------------------------===//

mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -334,6 +334,50 @@ func.func @transfer_write_2d_transpose_with_mask_bf16(%vector : vector<[8]x[8]xb
334334
return
335335
}
336336

337+
// -----
338+
339+
// CHECK-LABEL: func.func @transfer_write_slice(
340+
// CHECK-SAME: %[[VECTOR:.*]]: vector<[4]x[4]xf32>,
341+
// CHECK-SAME: %[[DEST:.*]]: memref<?x?xf32>,
342+
// CHECK-SAME: %[[INDEX:.*]]: index) {
343+
// CHECK: %[[C0:.*]] = arith.constant 0 : index
344+
// CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<[4]xi1>
345+
// CHECK: arm_sme.store_tile_slice %[[VECTOR]], %[[INDEX]], %[[MASK]], %[[DEST]][%[[INDEX]], %[[C0]]] : memref<?x?xf32>, vector<[4]xi1>, vector<[4]x[4]xf32>
346+
func.func @transfer_write_slice(%vector: vector<[4]x[4]xf32>, %dest : memref<?x?xf32>, %slice_index: index) {
347+
%c0 = arith.constant 0 : index
348+
%slice = vector.extract %vector[%slice_index] : vector<[4]xf32> from vector<[4]x[4]xf32>
349+
vector.transfer_write %slice, %dest[%slice_index, %c0] { in_bounds = [true] }: vector<[4]xf32>, memref<?x?xf32>
350+
return
351+
}
352+
353+
// -----
354+
355+
// CHECK-LABEL: func.func @transfer_write_slice_with_mask(
356+
// CHECK-SAME: %[[VECTOR:.*]]: vector<[4]x[4]xf32>,
357+
// CHECK-SAME: %[[DEST:.*]]: memref<?x?xf32>,
358+
// CHECK-SAME: %[[MASK:.*]]: vector<[4]xi1>,
359+
// CHECK-SAME: %[[INDEX:.*]]: index) {
360+
// CHECK: %[[C0:.*]] = arith.constant 0 : index
361+
// CHECK: arm_sme.store_tile_slice %[[VECTOR]], %[[INDEX]], %[[MASK]], %[[DEST]][%[[INDEX]], %[[C0]]] : memref<?x?xf32>, vector<[4]xi1>, vector<[4]x[4]xf32>
362+
func.func @transfer_write_slice_with_mask(%vector: vector<[4]x[4]xf32>, %dest : memref<?x?xf32>, %mask: vector<[4]xi1>, %slice_index: index) {
363+
%c0 = arith.constant 0 : index
364+
%slice = vector.extract %vector[%slice_index] : vector<[4]xf32> from vector<[4]x[4]xf32>
365+
vector.transfer_write %slice, %dest[%slice_index, %c0], %mask { in_bounds = [true] }: vector<[4]xf32>, memref<?x?xf32>
366+
return
367+
}
368+
369+
// -----
370+
371+
// CHECK-LABEL: func.func @transfer_write_vertical_slice
372+
// CHECK: arm_sme.store_tile_slice {{.*}} layout<vertical>
373+
func.func @transfer_write_vertical_slice(%vector: vector<[4]x[4]xf32>, %dest : memref<?x?xf32>, %slice_index: index) {
374+
%c0 = arith.constant 0 : index
375+
%slice = arm_sme.move_tile_slice_to_vector %vector[%slice_index] layout<vertical>
376+
: vector<[4]xf32> from vector<[4]x[4]xf32>
377+
vector.transfer_write %slice, %dest[%slice_index, %c0] { in_bounds = [true] }: vector<[4]xf32>, memref<?x?xf32>
378+
return
379+
}
380+
337381
//===----------------------------------------------------------------------===//
338382
// vector.broadcast
339383
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)