@@ -774,94 +774,6 @@ struct ConvertIllegalShapeCastOpsToTransposes
774774 }
775775};
776776
777- // / Returns an iterator over the dims (inc scalability) of a VectorType.
778- static auto getDims (VectorType vType) {
779- return llvm::zip_equal (vType.getShape (), vType.getScalableDims ());
780- }
781-
782- // / Helper to drop (fixed-size) unit dims from a VectorType.
783- static VectorType dropUnitDims (VectorType vType) {
784- SmallVector<bool > scalableFlags;
785- SmallVector<int64_t > dimSizes;
786- for (auto dim : getDims (vType)) {
787- if (dim == std::make_tuple (1 , false ))
788- continue ;
789- auto [size, scalableFlag] = dim;
790- dimSizes.push_back (size);
791- scalableFlags.push_back (scalableFlag);
792- }
793- return VectorType::get (dimSizes, vType.getElementType (), scalableFlags);
794- }
795-
796- // / A pattern to swap shape_cast(tranpose) with transpose(shape_cast) if the
797- // / shape_cast only drops unit dimensions.
798- // /
799- // / This simplifies the transpose making it possible for other legalization
800- // / rewrites to handle it.
801- // /
802- // / Example:
803- // /
804- // / BEFORE:
805- // / ```mlir
806- // / %0 = vector.transpose %vector, [3, 0, 1, 2]
807- // / : vector<1x1x4x[4]xf32> to vector<[4]x1x1x4xf32>
808- // / %1 = vector.shape_cast %0 : vector<[4]x1x1x4xf32> to vector<[4]x4xf32>
809- // / ```
810- // /
811- // / AFTER:
812- // / ```mlir
813- // / %0 = vector.shape_cast %arg0 : vector<1x1x4x[4]xf32> to vector<4x[4]xf32>
814- // / %1 = vector.transpose %0, [1, 0] : vector<4x[4]xf32> to vector<[4]x4xf32>
815- // / ```
816- struct SwapShapeCastOfTranspose : public OpRewritePattern <vector::ShapeCastOp> {
817- using OpRewritePattern::OpRewritePattern;
818-
819- LogicalResult matchAndRewrite (vector::ShapeCastOp shapeCastOp,
820- PatternRewriter &rewriter) const override {
821- auto transposeOp =
822- shapeCastOp.getSource ().getDefiningOp <vector::TransposeOp>();
823- if (!transposeOp)
824- return rewriter.notifyMatchFailure (shapeCastOp, " not TransposeOp" );
825-
826- auto resultType = shapeCastOp.getResultVectorType ();
827- if (resultType.getRank () <= 1 )
828- return rewriter.notifyMatchFailure (shapeCastOp, " result rank too low" );
829-
830- if (resultType != dropUnitDims (shapeCastOp.getSourceVectorType ()))
831- return rewriter.notifyMatchFailure (
832- shapeCastOp, " ShapeCastOp changes non-unit dimension(s)" );
833-
834- auto transposeSourceVectorType = transposeOp.getSourceVectorType ();
835- auto transposeSourceDims =
836- llvm::to_vector (getDims (transposeSourceVectorType));
837-
838- // Construct a map from dimIdx -> number of dims dropped before dimIdx.
839- SmallVector<int64_t > droppedDimsBefore (transposeSourceVectorType.getRank ());
840- int64_t droppedDims = 0 ;
841- for (auto [i, dim] : llvm::enumerate (transposeSourceDims)) {
842- droppedDimsBefore[i] = droppedDims;
843- if (dim == std::make_tuple (1 , false ))
844- ++droppedDims;
845- }
846-
847- // Drop unit dims from transpose permutation.
848- auto perm = transposeOp.getPermutation ();
849- SmallVector<int64_t > newPerm;
850- for (int64_t idx : perm) {
851- if (transposeSourceDims[idx] == std::make_tuple (1 , false ))
852- continue ;
853- newPerm.push_back (idx - droppedDimsBefore[idx]);
854- }
855-
856- auto loc = shapeCastOp.getLoc ();
857- auto newShapeCastOp = rewriter.create <vector::ShapeCastOp>(
858- loc, dropUnitDims (transposeSourceVectorType), transposeOp.getVector ());
859- rewriter.replaceOpWithNewOp <vector::TransposeOp>(shapeCastOp,
860- newShapeCastOp, newPerm);
861- return success ();
862- }
863- };
864-
865777// / Rewrites an illegal/unsupported SVE transfer_write(transpose) to instead use
866778// / the ZA state. This workaround rewrite to support these transposes when ZA is
867779// / available.
@@ -1027,8 +939,7 @@ struct VectorLegalizationPass
1027939 patterns.add <FoldExtractFromVectorOfSMELikeCreateMasks,
1028940 LiftIllegalVectorTransposeToMemory,
1029941 ConvertIllegalShapeCastOpsToTransposes,
1030- SwapShapeCastOfTranspose, LowerIllegalTransposeStoreViaZA>(
1031- context);
942+ LowerIllegalTransposeStoreViaZA>(context);
1032943 // Note: These two patterns are added with a high benefit to ensure:
1033944 // - Masked outer products are handled before unmasked ones
1034945 // - Multi-tile writes are lowered as a store loop (if possible)
0 commit comments