@@ -666,14 +666,69 @@ struct VectorPrintToArmSMELowering : public OpRewritePattern<vector::PrintOp> {
666
666
}
667
667
};
668
668
669
+ // / Folds a MoveTileSliceToVectorOp + TransferWriteOp to a StoreTileSliceOp.
670
+ // /
671
+ // / BEFORE:
672
+ // / ```mlir
673
+ // / %slice = arm_sme.move_tile_slice_to_vector %tile[%index]
674
+ // / : vector<[4]xf32> from vector<[4]x[4]xf32>
675
+ // / vector.transfer_write %slice, %memref[%i, %j], %mask {in_bounds = [true]}
676
+ // / : vector<[4]xf32>, memref<?x?xf32>
677
+ // / ```
678
+ // / AFTER:
679
+ // / ```mlir
680
+ // / arm_sme.store_tile_slice %tile, %index, %mask, %memref[%i, %j]
681
+ // / : memref<?x?xf32>, vector<[4]xi1>, vector<[4]x[4]xf32>
682
+ // / ```
683
+ struct FoldTransferWriteOfExtractTileSlice
684
+ : public OpRewritePattern<vector::TransferWriteOp> {
685
+ using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;
686
+
687
+ LogicalResult matchAndRewrite (vector::TransferWriteOp writeOp,
688
+ PatternRewriter &rewriter) const final {
689
+ if (!isa<MemRefType>(writeOp.getSource ().getType ()))
690
+ return rewriter.notifyMatchFailure (writeOp, " destination not a memref" );
691
+
692
+ if (writeOp.hasOutOfBoundsDim ())
693
+ return rewriter.notifyMatchFailure (writeOp,
694
+ " not inbounds transfer write" );
695
+
696
+ auto moveTileSlice =
697
+ writeOp.getVector ().getDefiningOp <arm_sme::MoveTileSliceToVectorOp>();
698
+ if (!moveTileSlice)
699
+ return rewriter.notifyMatchFailure (
700
+ writeOp, " vector to store not from MoveTileSliceToVectorOp" );
701
+
702
+ AffineMap map = writeOp.getPermutationMap ();
703
+ if (!map.isMinorIdentity ())
704
+ return rewriter.notifyMatchFailure (writeOp,
705
+ " unsupported permutation map" );
706
+
707
+ Value mask = writeOp.getMask ();
708
+ if (!mask) {
709
+ auto maskType = writeOp.getVectorType ().clone (rewriter.getI1Type ());
710
+ mask = rewriter.create <arith::ConstantOp>(
711
+ writeOp.getLoc (), maskType, DenseElementsAttr::get (maskType, true ));
712
+ }
713
+
714
+ rewriter.replaceOpWithNewOp <arm_sme::StoreTileSliceOp>(
715
+ writeOp, moveTileSlice.getTile (), moveTileSlice.getTileSliceIndex (),
716
+ mask, writeOp.getSource (), writeOp.getIndices (),
717
+ moveTileSlice.getLayout ());
718
+ return success ();
719
+ }
720
+ };
721
+
669
722
} // namespace
670
723
671
724
void mlir::populateVectorToArmSMEPatterns (RewritePatternSet &patterns,
672
725
MLIRContext &ctx) {
673
- patterns.add <BroadcastOpToArmSMELowering, SplatOpToArmSMELowering,
674
- TransferReadToArmSMELowering, TransferWriteToArmSMELowering,
675
- TransposeOpToArmSMELowering, VectorLoadToArmSMELowering,
676
- VectorStoreToArmSMELowering, VectorOuterProductToArmSMELowering,
677
- VectorExtractToArmSMELowering, VectorInsertToArmSMELowering,
678
- VectorPrintToArmSMELowering>(&ctx);
726
+ patterns
727
+ .add <BroadcastOpToArmSMELowering, SplatOpToArmSMELowering,
728
+ TransferReadToArmSMELowering, TransferWriteToArmSMELowering,
729
+ TransposeOpToArmSMELowering, VectorLoadToArmSMELowering,
730
+ VectorStoreToArmSMELowering, VectorOuterProductToArmSMELowering,
731
+ VectorExtractToArmSMELowering, VectorInsertToArmSMELowering,
732
+ VectorPrintToArmSMELowering, FoldTransferWriteOfExtractTileSlice>(
733
+ &ctx);
679
734
}
0 commit comments