-
Notifications
You must be signed in to change notification settings - Fork 13.6k
[mlir][vector] Add Vector-dialect interleave-to-shuffle pattern, enable in VectorToSPIRV #92012
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir-spirv @llvm/pr-subscribers-mlir Author: Benoit Jacob (bjacob) ChangesThis is the second attempt at merging #91800, which bounced due to a linker error apparently caused by an undeclared dependency. Context: iree-org/iree#17346. Test IREE integrate showing it's fixing the problem it's intended to fix, i.e. it allows IREE to drop its local revert of #89131: This is added to VectorToSPIRV because SPIRV doesn't currently handle This is limited to 1D, non-scalable vectors. Full diff: https://github.com/llvm/llvm-project/pull/92012.diff 7 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
index f6371f39c3944..bc3c16d40520e 100644
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -306,6 +306,20 @@ def ApplyLowerInterleavePatternsOp : Op<Transform_Dialect,
let assemblyFormat = "attr-dict";
}
+def ApplyInterleaveToShufflePatternsOp : Op<Transform_Dialect,
+ "apply_patterns.vector.interleave_to_shuffle",
+ [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+ let description = [{
+ Indicates that 1D vector interleave operations should be rewritten as
+ vector shuffle operations.
+
+ This is motivated by some current codegen backends not handling vector
+ interleave operations.
+ }];
+
+ let assemblyFormat = "attr-dict";
+}
+
def ApplyRewriteNarrowTypePatternsOp : Op<Transform_Dialect,
"apply_patterns.vector.rewrite_narrow_types",
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index 350d2777cadf5..8fd9904fabc0e 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -273,6 +273,9 @@ void populateVectorInterleaveLoweringPatterns(RewritePatternSet &patterns,
int64_t targetRank = 1,
PatternBenefit benefit = 1);
+void populateVectorInterleaveToShufflePatterns(RewritePatternSet &patterns,
+ PatternBenefit benefit = 1);
+
} // namespace vector
} // namespace mlir
#endif // MLIR_DIALECT_VECTOR_TRANSFORMS_LOWERINGPATTERNS_H
diff --git a/mlir/lib/Conversion/VectorToSPIRV/CMakeLists.txt b/mlir/lib/Conversion/VectorToSPIRV/CMakeLists.txt
index bb9f793d7fe0f..113983146f5be 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/CMakeLists.txt
+++ b/mlir/lib/Conversion/VectorToSPIRV/CMakeLists.txt
@@ -14,5 +14,6 @@ add_mlir_conversion_library(MLIRVectorToSPIRV
MLIRSPIRVDialect
MLIRSPIRVConversion
MLIRVectorDialect
+ MLIRVectorTransforms
MLIRTransforms
)
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 868a3521e7a0f..c2dd37f481466 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -18,6 +18,7 @@
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
@@ -828,6 +829,9 @@ void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
// than the generic one that extracts all elements.
patterns.add<VectorReductionToFPDotProd>(typeConverter, patterns.getContext(),
PatternBenefit(2));
+
+ // Need this until vector.interleave is handled.
+ vector::populateVectorInterleaveToShufflePatterns(patterns);
}
void mlir::populateVectorReductionToSPIRVDotProductPatterns(
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index 885644864c0f7..61fd6bd972e3a 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -164,6 +164,11 @@ void transform::ApplyLowerInterleavePatternsOp::populatePatterns(
vector::populateVectorInterleaveLoweringPatterns(patterns);
}
+void transform::ApplyInterleaveToShufflePatternsOp::populatePatterns(
+ RewritePatternSet &patterns) {
+ vector::populateVectorInterleaveToShufflePatterns(patterns);
+}
+
void transform::ApplyRewriteNarrowTypePatternsOp::populatePatterns(
RewritePatternSet &patterns) {
populateVectorNarrowTypeRewritePatterns(patterns);
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp
index 3a456076f8fba..5326760c9b4eb 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp
@@ -16,6 +16,7 @@
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
+#include "mlir/Support/LogicalResult.h"
#define DEBUG_TYPE "vector-interleave-lowering"
@@ -77,9 +78,49 @@ class UnrollInterleaveOp : public OpRewritePattern<vector::InterleaveOp> {
int64_t targetRank = 1;
};
+/// Rewrite vector.interleave op into an equivalent vector.shuffle op, when
+/// applicable: `sourceType` must be 1D and non-scalable.
+///
+/// Example:
+///
+/// ```mlir
+/// vector.interleave %a, %b : vector<7xi16>
+/// ```
+///
+/// Is rewritten into:
+///
+/// ```mlir
+/// vector.shuffle %arg0, %arg1 [0, 7, 1, 8, 2, 9, 3, 10, 4, 11, 5, 12, 6, 13]
+/// : vector<7xi16>, vector<7xi16>
+/// ```
+class InterleaveToShuffle : public OpRewritePattern<vector::InterleaveOp> {
+public:
+ InterleaveToShuffle(MLIRContext *context, PatternBenefit benefit = 1)
+ : OpRewritePattern(context, benefit) {};
+
+ LogicalResult matchAndRewrite(vector::InterleaveOp op,
+ PatternRewriter &rewriter) const override {
+ VectorType sourceType = op.getSourceVectorType();
+ if (sourceType.getRank() != 1 || sourceType.isScalable()) {
+ return failure();
+ }
+ int64_t n = sourceType.getNumElements();
+ auto seq = llvm::seq<int64_t>(2 * n);
+ auto zip = llvm::to_vector(llvm::map_range(
+ seq, [n](int64_t i) { return (i % 2 ? n : 0) + i / 2; }));
+ rewriter.replaceOpWithNewOp<ShuffleOp>(op, op.getLhs(), op.getRhs(), zip);
+ return success();
+ }
+};
+
} // namespace
void mlir::vector::populateVectorInterleaveLoweringPatterns(
RewritePatternSet &patterns, int64_t targetRank, PatternBenefit benefit) {
patterns.add<UnrollInterleaveOp>(targetRank, patterns.getContext(), benefit);
}
+
+void mlir::vector::populateVectorInterleaveToShufflePatterns(
+ RewritePatternSet &patterns, PatternBenefit benefit) {
+ patterns.add<InterleaveToShuffle>(patterns.getContext(), benefit);
+}
diff --git a/mlir/test/Dialect/Vector/vector-interleave-to-shuffle.mlir b/mlir/test/Dialect/Vector/vector-interleave-to-shuffle.mlir
new file mode 100644
index 0000000000000..ed3b3396bf3ea
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-interleave-to-shuffle.mlir
@@ -0,0 +1,21 @@
+// RUN: mlir-opt %s --transform-interpreter | FileCheck %s
+
+// CHECK-LABEL: @vector_interleave_to_shuffle
+func.func @vector_interleave_to_shuffle(%a: vector<7xi16>, %b: vector<7xi16>) -> vector<14xi16>
+{
+ %0 = vector.interleave %a, %b : vector<7xi16>
+ return %0 : vector<14xi16>
+}
+// CHECK: vector.shuffle %arg0, %arg1 [0, 7, 1, 8, 2, 9, 3, 10, 4, 11, 5, 12, 6, 13] : vector<7xi16>, vector<7xi16>
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
+ %f = transform.structured.match ops{["func.func"]} in %module_op
+ : (!transform.any_op) -> !transform.any_op
+
+ transform.apply_patterns to %f {
+ transform.apply_patterns.vector.interleave_to_shuffle
+ } : !transform.any_op
+ transform.yield
+ }
+}
|
@llvm/pr-subscribers-mlir-vector Author: Benoit Jacob (bjacob) ChangesThis is the second attempt at merging #91800, which bounced due to a linker error apparently caused by an undeclared dependency. Context: iree-org/iree#17346. Test IREE integrate showing it's fixing the problem it's intended to fix, i.e. it allows IREE to drop its local revert of #89131: This is added to VectorToSPIRV because SPIRV doesn't currently handle This is limited to 1D, non-scalable vectors. Full diff: https://github.com/llvm/llvm-project/pull/92012.diff 7 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
index f6371f39c3944..bc3c16d40520e 100644
--- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td
@@ -306,6 +306,20 @@ def ApplyLowerInterleavePatternsOp : Op<Transform_Dialect,
let assemblyFormat = "attr-dict";
}
+def ApplyInterleaveToShufflePatternsOp : Op<Transform_Dialect,
+ "apply_patterns.vector.interleave_to_shuffle",
+ [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+ let description = [{
+ Indicates that 1D vector interleave operations should be rewritten as
+ vector shuffle operations.
+
+ This is motivated by some current codegen backends not handling vector
+ interleave operations.
+ }];
+
+ let assemblyFormat = "attr-dict";
+}
+
def ApplyRewriteNarrowTypePatternsOp : Op<Transform_Dialect,
"apply_patterns.vector.rewrite_narrow_types",
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index 350d2777cadf5..8fd9904fabc0e 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -273,6 +273,9 @@ void populateVectorInterleaveLoweringPatterns(RewritePatternSet &patterns,
int64_t targetRank = 1,
PatternBenefit benefit = 1);
+void populateVectorInterleaveToShufflePatterns(RewritePatternSet &patterns,
+ PatternBenefit benefit = 1);
+
} // namespace vector
} // namespace mlir
#endif // MLIR_DIALECT_VECTOR_TRANSFORMS_LOWERINGPATTERNS_H
diff --git a/mlir/lib/Conversion/VectorToSPIRV/CMakeLists.txt b/mlir/lib/Conversion/VectorToSPIRV/CMakeLists.txt
index bb9f793d7fe0f..113983146f5be 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/CMakeLists.txt
+++ b/mlir/lib/Conversion/VectorToSPIRV/CMakeLists.txt
@@ -14,5 +14,6 @@ add_mlir_conversion_library(MLIRVectorToSPIRV
MLIRSPIRVDialect
MLIRSPIRVConversion
MLIRVectorDialect
+ MLIRVectorTransforms
MLIRTransforms
)
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 868a3521e7a0f..c2dd37f481466 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -18,6 +18,7 @@
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
@@ -828,6 +829,9 @@ void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
// than the generic one that extracts all elements.
patterns.add<VectorReductionToFPDotProd>(typeConverter, patterns.getContext(),
PatternBenefit(2));
+
+ // Need this until vector.interleave is handled.
+ vector::populateVectorInterleaveToShufflePatterns(patterns);
}
void mlir::populateVectorReductionToSPIRVDotProductPatterns(
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index 885644864c0f7..61fd6bd972e3a 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -164,6 +164,11 @@ void transform::ApplyLowerInterleavePatternsOp::populatePatterns(
vector::populateVectorInterleaveLoweringPatterns(patterns);
}
+void transform::ApplyInterleaveToShufflePatternsOp::populatePatterns(
+ RewritePatternSet &patterns) {
+ vector::populateVectorInterleaveToShufflePatterns(patterns);
+}
+
void transform::ApplyRewriteNarrowTypePatternsOp::populatePatterns(
RewritePatternSet &patterns) {
populateVectorNarrowTypeRewritePatterns(patterns);
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp
index 3a456076f8fba..5326760c9b4eb 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp
@@ -16,6 +16,7 @@
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
+#include "mlir/Support/LogicalResult.h"
#define DEBUG_TYPE "vector-interleave-lowering"
@@ -77,9 +78,49 @@ class UnrollInterleaveOp : public OpRewritePattern<vector::InterleaveOp> {
int64_t targetRank = 1;
};
+/// Rewrite vector.interleave op into an equivalent vector.shuffle op, when
+/// applicable: `sourceType` must be 1D and non-scalable.
+///
+/// Example:
+///
+/// ```mlir
+/// vector.interleave %a, %b : vector<7xi16>
+/// ```
+///
+/// Is rewritten into:
+///
+/// ```mlir
+/// vector.shuffle %arg0, %arg1 [0, 7, 1, 8, 2, 9, 3, 10, 4, 11, 5, 12, 6, 13]
+/// : vector<7xi16>, vector<7xi16>
+/// ```
+class InterleaveToShuffle : public OpRewritePattern<vector::InterleaveOp> {
+public:
+ InterleaveToShuffle(MLIRContext *context, PatternBenefit benefit = 1)
+ : OpRewritePattern(context, benefit) {};
+
+ LogicalResult matchAndRewrite(vector::InterleaveOp op,
+ PatternRewriter &rewriter) const override {
+ VectorType sourceType = op.getSourceVectorType();
+ if (sourceType.getRank() != 1 || sourceType.isScalable()) {
+ return failure();
+ }
+ int64_t n = sourceType.getNumElements();
+ auto seq = llvm::seq<int64_t>(2 * n);
+ auto zip = llvm::to_vector(llvm::map_range(
+ seq, [n](int64_t i) { return (i % 2 ? n : 0) + i / 2; }));
+ rewriter.replaceOpWithNewOp<ShuffleOp>(op, op.getLhs(), op.getRhs(), zip);
+ return success();
+ }
+};
+
} // namespace
void mlir::vector::populateVectorInterleaveLoweringPatterns(
RewritePatternSet &patterns, int64_t targetRank, PatternBenefit benefit) {
patterns.add<UnrollInterleaveOp>(targetRank, patterns.getContext(), benefit);
}
+
+void mlir::vector::populateVectorInterleaveToShufflePatterns(
+ RewritePatternSet &patterns, PatternBenefit benefit) {
+ patterns.add<InterleaveToShuffle>(patterns.getContext(), benefit);
+}
diff --git a/mlir/test/Dialect/Vector/vector-interleave-to-shuffle.mlir b/mlir/test/Dialect/Vector/vector-interleave-to-shuffle.mlir
new file mode 100644
index 0000000000000..ed3b3396bf3ea
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-interleave-to-shuffle.mlir
@@ -0,0 +1,21 @@
+// RUN: mlir-opt %s --transform-interpreter | FileCheck %s
+
+// CHECK-LABEL: @vector_interleave_to_shuffle
+func.func @vector_interleave_to_shuffle(%a: vector<7xi16>, %b: vector<7xi16>) -> vector<14xi16>
+{
+ %0 = vector.interleave %a, %b : vector<7xi16>
+ return %0 : vector<14xi16>
+}
+// CHECK: vector.shuffle %arg0, %arg1 [0, 7, 1, 8, 2, 9, 3, 10, 4, 11, 5, 12, 6, 13] : vector<7xi16>, vector<7xi16>
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
+ %f = transform.structured.match ops{["func.func"]} in %module_op
+ : (!transform.any_op) -> !transform.any_op
+
+ transform.apply_patterns to %f {
+ transform.apply_patterns.vector.interleave_to_shuffle
+ } : !transform.any_op
+ transform.yield
+ }
+}
|
This is the second attempt at merging #91800, which bounced due to a linker error apparently caused by an undeclared dependency.
MLIRVectorToSPIRV
needed to depend onMLIRVectorTransforms
. In fact that was a preexisting issue already flagged by the tool in https://discourse.llvm.org/t/ninja-can-now-check-for-missing-cmake-dependencies-on-generated-files/74344.Context: iree-org/iree#17346.
Test IREE integrate showing it's fixing the problem it's intended to fix, i.e. it allows IREE to drop its local revert of #89131:
iree-org/iree#17359
This is added to VectorToSPIRV because SPIRV doesn't currently handle
vector.interleave
(see motivating context above).This is limited to 1D, non-scalable vectors.