@@ -774,94 +774,6 @@ struct ConvertIllegalShapeCastOpsToTransposes
774
774
}
775
775
};
776
776
777
- // / Returns an iterator over the dims (inc scalability) of a VectorType.
778
- static auto getDims (VectorType vType) {
779
- return llvm::zip_equal (vType.getShape (), vType.getScalableDims ());
780
- }
781
-
782
- // / Helper to drop (fixed-size) unit dims from a VectorType.
783
- static VectorType dropUnitDims (VectorType vType) {
784
- SmallVector<bool > scalableFlags;
785
- SmallVector<int64_t > dimSizes;
786
- for (auto dim : getDims (vType)) {
787
- if (dim == std::make_tuple (1 , false ))
788
- continue ;
789
- auto [size, scalableFlag] = dim;
790
- dimSizes.push_back (size);
791
- scalableFlags.push_back (scalableFlag);
792
- }
793
- return VectorType::get (dimSizes, vType.getElementType (), scalableFlags);
794
- }
795
-
796
- // / A pattern to swap shape_cast(transpose) with transpose(shape_cast) if the
797
- // / shape_cast only drops unit dimensions.
798
- // /
799
- // / This simplifies the transpose making it possible for other legalization
800
- // / rewrites to handle it.
801
- // /
802
- // / Example:
803
- // /
804
- // / BEFORE:
805
- // / ```mlir
806
- // / %0 = vector.transpose %vector, [3, 0, 1, 2]
807
- // / : vector<1x1x4x[4]xf32> to vector<[4]x1x1x4xf32>
808
- // / %1 = vector.shape_cast %0 : vector<[4]x1x1x4xf32> to vector<[4]x4xf32>
809
- // / ```
810
- // /
811
- // / AFTER:
812
- // / ```mlir
813
- // / %0 = vector.shape_cast %arg0 : vector<1x1x4x[4]xf32> to vector<4x[4]xf32>
814
- // / %1 = vector.transpose %0, [1, 0] : vector<4x[4]xf32> to vector<[4]x4xf32>
815
- // / ```
816
- struct SwapShapeCastOfTranspose : public OpRewritePattern <vector::ShapeCastOp> {
817
- using OpRewritePattern::OpRewritePattern;
818
-
819
- LogicalResult matchAndRewrite (vector::ShapeCastOp shapeCastOp,
820
- PatternRewriter &rewriter) const override {
821
- auto transposeOp =
822
- shapeCastOp.getSource ().getDefiningOp <vector::TransposeOp>();
823
- if (!transposeOp)
824
- return rewriter.notifyMatchFailure (shapeCastOp, " not TransposeOp" );
825
-
826
- auto resultType = shapeCastOp.getResultVectorType ();
827
- if (resultType.getRank () <= 1 )
828
- return rewriter.notifyMatchFailure (shapeCastOp, " result rank too low" );
829
-
830
- if (resultType != dropUnitDims (shapeCastOp.getSourceVectorType ()))
831
- return rewriter.notifyMatchFailure (
832
- shapeCastOp, " ShapeCastOp changes non-unit dimension(s)" );
833
-
834
- auto transposeSourceVectorType = transposeOp.getSourceVectorType ();
835
- auto transposeSourceDims =
836
- llvm::to_vector (getDims (transposeSourceVectorType));
837
-
838
- // Construct a map from dimIdx -> number of dims dropped before dimIdx.
839
- SmallVector<int64_t > droppedDimsBefore (transposeSourceVectorType.getRank ());
840
- int64_t droppedDims = 0 ;
841
- for (auto [i, dim] : llvm::enumerate (transposeSourceDims)) {
842
- droppedDimsBefore[i] = droppedDims;
843
- if (dim == std::make_tuple (1 , false ))
844
- ++droppedDims;
845
- }
846
-
847
- // Drop unit dims from transpose permutation.
848
- auto perm = transposeOp.getPermutation ();
849
- SmallVector<int64_t > newPerm;
850
- for (int64_t idx : perm) {
851
- if (transposeSourceDims[idx] == std::make_tuple (1 , false ))
852
- continue ;
853
- newPerm.push_back (idx - droppedDimsBefore[idx]);
854
- }
855
-
856
- auto loc = shapeCastOp.getLoc ();
857
- auto newShapeCastOp = rewriter.create <vector::ShapeCastOp>(
858
- loc, dropUnitDims (transposeSourceVectorType), transposeOp.getVector ());
859
- rewriter.replaceOpWithNewOp <vector::TransposeOp>(shapeCastOp,
860
- newShapeCastOp, newPerm);
861
- return success ();
862
- }
863
- };
864
-
865
777
// / Rewrites an illegal/unsupported SVE transfer_write(transpose) to instead use
866
778
// / the ZA state. This is a workaround rewrite to support these transposes when ZA is
867
779
// / available.
@@ -1027,8 +939,7 @@ struct VectorLegalizationPass
1027
939
patterns.add <FoldExtractFromVectorOfSMELikeCreateMasks,
1028
940
LiftIllegalVectorTransposeToMemory,
1029
941
ConvertIllegalShapeCastOpsToTransposes,
1030
- SwapShapeCastOfTranspose, LowerIllegalTransposeStoreViaZA>(
1031
- context);
942
+ LowerIllegalTransposeStoreViaZA>(context);
1032
943
// Note: These two patterns are added with a high benefit to ensure:
1033
944
// - Masked outer products are handled before unmasked ones
1034
945
// - Multi-tile writes are lowered as a store loop (if possible)
0 commit comments