-
Notifications
You must be signed in to change notification settings - Fork 13.7k
[mlir][vector] Add pattern to drop unit dims from vector.transpose #102017
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir-vector @llvm/pr-subscribers-mlir Author: Benjamin Maxwell (MacDue) ChangesExample: BEFORE: %transpose = vector.transpose %vector, [3, 0, 1, 2]
: vector<1x1x4x[4]xf32> to vector<[4]x1x1x4xf32> AFTER: %dropDims = vector.shape_cast %vector
: vector<1x1x4x[4]xf32> to vector<4x[4]xf32>
%transpose = vector.transpose %0, [1, 0]
: vector<4x[4]xf32> to vector<[4]x4xf32>
%restoreDims = vector.shape_cast %transpose
: vector<[4]x4xf32> to vector<[4]x1x1x4xf32> Full diff: https://github.com/llvm/llvm-project/pull/102017.diff 3 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
index 40e04b76593a0..67c36bfa06ded 100644
--- a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
+++ b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
@@ -120,6 +120,11 @@ inline auto makeVscaleConstantBuilder(PatternRewriter &rewriter, Location loc) {
};
}
+/// Returns an iterator over the dims (inc scalability) of a VectorType.
+inline auto getDims(VectorType vType) {
+ return llvm::zip_equal(vType.getShape(), vType.getScalableDims());
+}
+
/// A wrapper for getMixedSizes for vector.transfer_read and
/// vector.transfer_write Ops (for source and destination, respectively).
///
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 6777e589795c8..6b39cef7899d9 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1720,6 +1720,72 @@ struct DropUnitDimFromElementwiseOps final
}
};
+/// A pattern to drop unit dims from vector.transpose.
+///
+/// Example:
+///
+/// BEFORE:
+/// ```mlir
+/// %transpose = vector.transpose %vector, [3, 0, 1, 2]
+/// : vector<1x1x4x[4]xf32> to vector<[4]x1x1x4xf32>
+/// ```
+///
+/// AFTER:
+/// ```mlir
+/// %dropDims = vector.shape_cast %vector
+/// : vector<1x1x4x[4]xf32> to vector<4x[4]xf32>
+/// %transpose = vector.transpose %0, [1, 0]
+/// : vector<4x[4]xf32> to vector<[4]x4xf32>
+/// %restoreDims = vector.shape_cast %transpose
+/// : vector<[4]x4xf32> to vector<[4]x1x1x4xf32>
+/// ```
+struct DropUnitDimsFromTransposeOp final
+ : OpRewritePattern<vector::TransposeOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::TransposeOp op,
+ PatternRewriter &rewriter) const override {
+ VectorType sourceType = op.getSourceVectorType();
+ VectorType sourceTypeWithoutUnitDims =
+ dropNonScalableUnitDimFromType(sourceType);
+
+ if (sourceType == sourceTypeWithoutUnitDims)
+ return failure();
+
+ // Construct a map from dimIdx -> number of dims dropped before dimIdx.
+ auto sourceDims = llvm::to_vector(vector::getDims(sourceType));
+ SmallVector<int64_t> droppedDimsBefore(sourceType.getRank());
+ int64_t droppedDims = 0;
+ for (auto [i, dim] : llvm::enumerate(sourceDims)) {
+ droppedDimsBefore[i] = droppedDims;
+ if (dim == std::make_tuple(1, false))
+ ++droppedDims;
+ }
+
+ // Drop unit dims from transpose permutation.
+ ArrayRef<int64_t> perm = op.getPermutation();
+ SmallVector<int64_t> newPerm;
+ for (int64_t idx : perm) {
+ if (sourceDims[idx] == std::make_tuple(1, false))
+ continue;
+ newPerm.push_back(idx - droppedDimsBefore[idx]);
+ }
+
+ auto loc = op.getLoc();
+ // Drop the unit dims via shape_cast.
+ auto dropDimsShapeCast = rewriter.create<vector::ShapeCastOp>(
+ loc, sourceTypeWithoutUnitDims, op.getVector());
+ // Create the new transpose.
+ auto tranposeWithoutUnitDims =
+ rewriter.create<vector::TransposeOp>(loc, dropDimsShapeCast, newPerm);
+ // Restore the unit dims via shape cast.
+ rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
+ op, op.getResultVectorType(), tranposeWithoutUnitDims);
+
+ return failure();
+ }
+};
+
/// Pattern to eliminate redundant zero-constants added to reduction operands.
/// It's enough for there to be one initial zero value, so we can eliminate the
/// extra ones that feed into `vector.reduction <add>`. These get created by the
@@ -1924,8 +1990,8 @@ void mlir::vector::populateShapeCastFoldingPatterns(RewritePatternSet &patterns,
void mlir::vector::populateDropUnitDimWithShapeCastPatterns(
RewritePatternSet &patterns, PatternBenefit benefit) {
- patterns.add<DropUnitDimFromElementwiseOps, ShapeCastOpFolder>(
- patterns.getContext(), benefit);
+ patterns.add<DropUnitDimFromElementwiseOps, DropUnitDimsFromTransposeOp,
+ ShapeCastOpFolder>(patterns.getContext(), benefit);
}
void mlir::vector::populateBubbleVectorBitCastOpPatterns(
diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index 9d16aa46a9f2a..222a05ff70d02 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -700,3 +700,36 @@ func.func @negative_out_of_bound_transfer_write(
}
// CHECK: func.func @negative_out_of_bound_transfer_write
// CHECK-NOT: memref.collapse_shape
+
+// -----
+
+///----------------------------------------------------------------------------------------
+/// [Pattern: DropUnitDimsFromTransposeOp]
+/// TODO: Move to a dedicated file - there's no "flattening" in the following tests
+///----------------------------------------------------------------------------------------
+
+func.func @transpose_with_internal_unit_dims(%vector: vector<1x1x4x[4]xf32>) -> vector<[4]x1x1x4xf32> {
+ %0 = vector.transpose %vector, [3, 0, 1, 2] : vector<1x1x4x[4]xf32> to vector<[4]x1x1x4xf32>
+ return %0 : vector<[4]x1x1x4xf32>
+}
+
+// CHECK-LABEL: func.func @transpose_with_internal_unit_dims(
+// CHECK-SAME: %[[VEC:.*]]: vector<1x1x4x[4]xf32>)
+// CHECK-NEXT: %[[DROP_DIMS:.*]] = vector.shape_cast %arg0 : vector<1x1x4x[4]xf32> to vector<4x[4]xf32>
+// CHECK-NEXT: %[[TRANSPOSE:.*]] = vector.transpose %0, [1, 0] : vector<4x[4]xf32> to vector<[4]x4xf32>
+// CHECK-NEXT: %[[RESTORE_DIMS:.*]] = vector.shape_cast %1 : vector<[4]x4xf32> to vector<[4]x1x1x4xf32>
+// CHECK-NEXT: return %[[RESTORE_DIMS]] : vector<[4]x1x1x4xf32>
+
+// -----
+
+func.func @transpose_with_units_dims_before_and_after(%vector: vector<1x1x1x4x[4]x1xf32>) -> vector<[4]x1x1x1x4x1xf32> {
+ %0 = vector.transpose %vector, [4, 1, 0, 2, 3, 5] : vector<1x1x1x4x[4]x1xf32> to vector<[4]x1x1x1x4x1xf32>
+ return %0 : vector<[4]x1x1x1x4x1xf32>
+}
+
+// CHECK-LABEL: func.func @transpose_with_units_dims_before_and_after(
+// CHECK-SAME: %[[VEC:.*]]: vector<1x1x1x4x[4]x1xf32>)
+// CHECK-NEXT: %[[DROP_DIMS:.*]] = vector.shape_cast %arg0 : vector<1x1x1x4x[4]x1xf32> to vector<4x[4]xf32>
+// CHECK-NEXT: %[[TRANSPOSE:.*]] = vector.transpose %0, [1, 0] : vector<4x[4]xf32> to vector<[4]x4xf32>
+// CHECK-NEXT: %[[RESTORE_DIMS:.*]] = vector.shape_cast %1 : vector<[4]x4xf32> to vector<[4]x1x1x1x4x1xf32>
+// CHECK-NEXT: return %[[RESTORE_DIMS]] : vector<[4]x1x1x1x4x1xf32>
|
Example: BEFORE: ```mlir %transpose = vector.transpose %vector, [3, 0, 1, 2] : vector<1x1x4x[4]xf32> to vector<[4]x1x1x4xf32> ``` AFTER: ```mlir %dropDims = vector.shape_cast %vector : vector<1x1x4x[4]xf32> to vector<4x[4]xf32> %transpose = vector.transpose %0, [1, 0] : vector<4x[4]xf32> to vector<[4]x4xf32> %restoreDims = vector.shape_cast %transpose : vector<[4]x4xf32> to vector<[4]x1x1x4xf32> ```
This can supersede #100933 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Argh, sorry, forgot to send yesterday.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't foresee this causing any issues related to the SPIR-V lowering.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Great! Thanks for bearing with me. Appreciate it
Example:
BEFORE:
AFTER: