Skip to content

Commit fc4485b

Browse files
authored
Revert "[mlir][ArmSME] Pattern to swap shape_cast(tranpose) with transpose(shape_cast) (#100731)" (#102457)
This reverts commit 88accd9. This change can be dropped in favor of just #102017.
1 parent 94473f4 commit fc4485b

File tree

2 files changed

+1
-116
lines changed

2 files changed

+1
-116
lines changed

mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp

Lines changed: 1 addition & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -774,94 +774,6 @@ struct ConvertIllegalShapeCastOpsToTransposes
774774
}
775775
};
776776

777-
/// Returns an iterator over the dims (including scalability) of a VectorType.
778-
static auto getDims(VectorType vType) {
779-
return llvm::zip_equal(vType.getShape(), vType.getScalableDims());
780-
}
781-
782-
/// Helper to drop (fixed-size) unit dims from a VectorType.
783-
static VectorType dropUnitDims(VectorType vType) {
784-
SmallVector<bool> scalableFlags;
785-
SmallVector<int64_t> dimSizes;
786-
for (auto dim : getDims(vType)) {
787-
if (dim == std::make_tuple(1, false))
788-
continue;
789-
auto [size, scalableFlag] = dim;
790-
dimSizes.push_back(size);
791-
scalableFlags.push_back(scalableFlag);
792-
}
793-
return VectorType::get(dimSizes, vType.getElementType(), scalableFlags);
794-
}
795-
796-
/// A pattern to swap shape_cast(transpose) with transpose(shape_cast) if the
797-
/// shape_cast only drops unit dimensions.
798-
///
799-
/// This simplifies the transpose making it possible for other legalization
800-
/// rewrites to handle it.
801-
///
802-
/// Example:
803-
///
804-
/// BEFORE:
805-
/// ```mlir
806-
/// %0 = vector.transpose %vector, [3, 0, 1, 2]
807-
/// : vector<1x1x4x[4]xf32> to vector<[4]x1x1x4xf32>
808-
/// %1 = vector.shape_cast %0 : vector<[4]x1x1x4xf32> to vector<[4]x4xf32>
809-
/// ```
810-
///
811-
/// AFTER:
812-
/// ```mlir
813-
/// %0 = vector.shape_cast %arg0 : vector<1x1x4x[4]xf32> to vector<4x[4]xf32>
814-
/// %1 = vector.transpose %0, [1, 0] : vector<4x[4]xf32> to vector<[4]x4xf32>
815-
/// ```
816-
struct SwapShapeCastOfTranspose : public OpRewritePattern<vector::ShapeCastOp> {
817-
using OpRewritePattern::OpRewritePattern;
818-
819-
LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
820-
PatternRewriter &rewriter) const override {
821-
auto transposeOp =
822-
shapeCastOp.getSource().getDefiningOp<vector::TransposeOp>();
823-
if (!transposeOp)
824-
return rewriter.notifyMatchFailure(shapeCastOp, "not TransposeOp");
825-
826-
auto resultType = shapeCastOp.getResultVectorType();
827-
if (resultType.getRank() <= 1)
828-
return rewriter.notifyMatchFailure(shapeCastOp, "result rank too low");
829-
830-
if (resultType != dropUnitDims(shapeCastOp.getSourceVectorType()))
831-
return rewriter.notifyMatchFailure(
832-
shapeCastOp, "ShapeCastOp changes non-unit dimension(s)");
833-
834-
auto transposeSourceVectorType = transposeOp.getSourceVectorType();
835-
auto transposeSourceDims =
836-
llvm::to_vector(getDims(transposeSourceVectorType));
837-
838-
// Construct a map from dimIdx -> number of dims dropped before dimIdx.
839-
SmallVector<int64_t> droppedDimsBefore(transposeSourceVectorType.getRank());
840-
int64_t droppedDims = 0;
841-
for (auto [i, dim] : llvm::enumerate(transposeSourceDims)) {
842-
droppedDimsBefore[i] = droppedDims;
843-
if (dim == std::make_tuple(1, false))
844-
++droppedDims;
845-
}
846-
847-
// Drop unit dims from transpose permutation.
848-
auto perm = transposeOp.getPermutation();
849-
SmallVector<int64_t> newPerm;
850-
for (int64_t idx : perm) {
851-
if (transposeSourceDims[idx] == std::make_tuple(1, false))
852-
continue;
853-
newPerm.push_back(idx - droppedDimsBefore[idx]);
854-
}
855-
856-
auto loc = shapeCastOp.getLoc();
857-
auto newShapeCastOp = rewriter.create<vector::ShapeCastOp>(
858-
loc, dropUnitDims(transposeSourceVectorType), transposeOp.getVector());
859-
rewriter.replaceOpWithNewOp<vector::TransposeOp>(shapeCastOp,
860-
newShapeCastOp, newPerm);
861-
return success();
862-
}
863-
};
864-
865777
/// Rewrites an illegal/unsupported SVE transfer_write(transpose) to instead use
866778
/// the ZA state. This is a workaround rewrite to support these transposes when ZA is
867779
/// available.
@@ -1027,8 +939,7 @@ struct VectorLegalizationPass
1027939
patterns.add<FoldExtractFromVectorOfSMELikeCreateMasks,
1028940
LiftIllegalVectorTransposeToMemory,
1029941
ConvertIllegalShapeCastOpsToTransposes,
1030-
SwapShapeCastOfTranspose, LowerIllegalTransposeStoreViaZA>(
1031-
context);
942+
LowerIllegalTransposeStoreViaZA>(context);
1032943
// Note: These two patterns are added with a high benefit to ensure:
1033944
// - Masked outer products are handled before unmasked ones
1034945
// - Multi-tile writes are lowered as a store loop (if possible)

mlir/test/Dialect/ArmSME/vector-legalization.mlir

Lines changed: 0 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -646,29 +646,3 @@ func.func @negative_transpose_store_scalable_via_za__bad_source_shape(%vec: vect
646646
vector.transfer_write %tr, %dest[%i, %j] {in_bounds = [true, true]} : vector<[7]x2xf32>, memref<?x?xf32>
647647
return
648648
}
649-
650-
// -----
651-
652-
// CHECK-LABEL: @swap_shape_cast_of_transpose(
653-
// CHECK-SAME: %[[VEC:.*]]: vector<1x1x4x[4]xf32>)
654-
func.func @swap_shape_cast_of_transpose(%vector: vector<1x1x4x[4]xf32>) -> vector<[4]x4xf32> {
655-
// CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[VEC]] : vector<1x1x4x[4]xf32> to vector<4x[4]xf32>
656-
// CHECK: %[[TRANSPOSE:.*]] = vector.transpose %[[SHAPE_CAST]], [1, 0] : vector<4x[4]xf32> to vector<[4]x4xf32>
657-
// CHECK: return %[[TRANSPOSE]]
658-
%0 = vector.transpose %vector, [3, 0, 1, 2] : vector<1x1x4x[4]xf32> to vector<[4]x1x1x4xf32>
659-
%1 = vector.shape_cast %0 : vector<[4]x1x1x4xf32> to vector<[4]x4xf32>
660-
return %1 : vector<[4]x4xf32>
661-
}
662-
663-
// -----
664-
665-
// CHECK-LABEL: @swap_shape_cast_of_transpose_units_dims_before_and_after(
666-
// CHECK-SAME: %[[VEC:.*]]: vector<1x1x1x4x[4]x1xf32>)
667-
func.func @swap_shape_cast_of_transpose_units_dims_before_and_after(%vector: vector<1x1x1x4x[4]x1xf32>) -> vector<[4]x4xf32> {
668-
// CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[VEC]] : vector<1x1x1x4x[4]x1xf32> to vector<4x[4]xf32>
669-
// CHECK: %[[TRANSPOSE:.*]] = vector.transpose %[[SHAPE_CAST]], [1, 0] : vector<4x[4]xf32> to vector<[4]x4xf32>
670-
// CHECK: return %[[TRANSPOSE]]
671-
%0 = vector.transpose %vector, [4, 1, 0, 2, 3, 5] : vector<1x1x1x4x[4]x1xf32> to vector<[4]x1x1x1x4x1xf32>
672-
%1 = vector.shape_cast %0 : vector<[4]x1x1x1x4x1xf32> to vector<[4]x4xf32>
673-
return %1 : vector<[4]x4xf32>
674-
}

0 commit comments

Comments
 (0)