-
Notifications
You must be signed in to change notification settings - Fork 13.5k
[mlir][vector] Add n-d deinterleave lowering #94237
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
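@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-vector
Author: Mubashar Ahmad (mub-at-arm)
Changes: This patch implements the lowering of vector.deinterleave for n-dimensional vectors. The n-D vector is unrolled into a series of one-dimensional vectors, and the deinterleave operation is then applied to each of them.
From:
```
%0, %1 = vector.deinterleave %a : vector<2x8xf32> -> vector<2x4xf32>
```
To:
```
%2 = llvm.extractvalue %0[0] : !llvm.array<2 x vector<8xf32>>
%3 = llvm.mlir.poison : vector<8xf32>
%4 = llvm.shufflevector %2, %3 [0, 2, 4, 6] : vector<8xf32>
%5 = llvm.shufflevector %2, %3 [1, 3, 5, 7] : vector<8xf32>
%6 = llvm.insertvalue %4, %1[0] : !llvm.array<2 x vector<4xf32>>
%7 = llvm.insertvalue %5, %1[0] : !llvm.array<2 x vector<4xf32>>
%8 = llvm.extractvalue %0[1] : !llvm.array<2 x vector<8xf32>>
...etc.
```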
Full diff: https://github.com/llvm/llvm-project/pull/94237.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp
index 77c97b2f1497c..557837426d855 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp
@@ -79,6 +79,42 @@ class UnrollInterleaveOp final : public OpRewritePattern<vector::InterleaveOp> {
int64_t targetRank = 1;
};
+class UnrollDeinterleaveOp final : public OpRewritePattern<vector::DeinterleaveOp> {
+public:
+ UnrollDeinterleaveOp(int64_t targetRank, MLIRContext *context,
+ PatternBenefit benefit = 1)
+ : OpRewritePattern(context, benefit), targetRank(targetRank) {};
+
+ LogicalResult matchAndRewrite(vector::DeinterleaveOp op,
+ PatternRewriter &rewriter) const override {
+ VectorType resultType = op.getResultVectorType();
+ auto unrollIterator = vector::createUnrollIterator(resultType, targetRank);
+ if (!unrollIterator)
+ return failure();
+
+ auto loc = op.getLoc();
+ Value evenResult = rewriter.create<arith::ConstantOp>(
+ loc, resultType, rewriter.getZeroAttr(resultType));
+ Value oddResult = rewriter.create<arith::ConstantOp>(
+ loc, resultType, rewriter.getZeroAttr(resultType));
+
+ for (auto position : *unrollIterator) {
+ auto extractSrc = rewriter.create<vector::ExtractOp>(
+ loc, op.getSource(), position);
+ auto deinterleave = rewriter.create<vector::DeinterleaveOp>(
+ loc, extractSrc);
+ evenResult = rewriter.create<vector::InsertOp>(
+ loc, deinterleave.getRes1(), evenResult, position);
+ oddResult = rewriter.create<vector::InsertOp>(
+ loc, deinterleave.getRes2(), oddResult, position);
+ }
+ rewriter.replaceOp(op, ValueRange{evenResult, oddResult});
+ return success();
+ }
+
+private:
+ int64_t targetRank = 1;
+};
/// Rewrite vector.interleave op into an equivalent vector.shuffle op, when
/// applicable: `sourceType` must be 1D and non-scalable.
///
@@ -117,6 +153,7 @@ struct InterleaveToShuffle final : OpRewritePattern<vector::InterleaveOp> {
void mlir::vector::populateVectorInterleaveLoweringPatterns(
RewritePatternSet &patterns, int64_t targetRank, PatternBenefit benefit) {
patterns.add<UnrollInterleaveOp>(targetRank, patterns.getContext(), benefit);
+ patterns.add<UnrollDeinterleaveOp>(targetRank, patterns.getContext(), benefit);
}
void mlir::vector::populateVectorInterleaveToShufflePatterns(
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 245edb6789d30..21f4872bb2cd9 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -2564,3 +2564,39 @@ func.func @vector_deinterleave_1d_scalable(%a: vector<[4]xi32>) -> (vector<[2]xi
%0, %1 = vector.deinterleave %a : vector<[4]xi32> -> vector<[2]xi32>
return %0, %1 : vector<[2]xi32>, vector<[2]xi32>
}
+
+// CHECK-LABEL: @vector_deinterleave_2d
+// CHECK-SAME: %[[SRC:.*]]: vector<2x8xf32>) -> (vector<2x4xf32>, vector<2x4xf32>)
+func.func @vector_deinterleave_2d(%a: vector<2x8xf32>) -> (vector<2x4xf32>, vector<2x4xf32>) {
+ // CHECK: %[[EXTRACT_ONE:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<8xf32>>
+ // CHECK: %[[POISON_ONE:.*]] = llvm.mlir.poison : vector<8xf32>
+ // CHECK: %[[SHUFFLE_A:.*]] = llvm.shufflevector %[[EXTRACT_ONE]], %[[POISON_ONE]] [0, 2, 4, 6] : vector<8xf32>
+ // CHECK: %[[SHUFFLE_B:.*]] = llvm.shufflevector %[[EXTRACT_ONE]], %[[POISON_ONE]] [1, 3, 5, 7] : vector<8xf32>
+ // CHECK: %[[INSERT_A:.*]] = llvm.insertvalue %[[SHUFFLE_A]], %{{.*}}[0] : !llvm.array<2 x vector<4xf32>>
+ // CHECK: %[[INSERT_B:.*]] = llvm.insertvalue %[[SHUFFLE_B]], %{{.*}}[0] : !llvm.array<2 x vector<4xf32>>
+ // CHECK: %[[EXTRACT_TWO:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<8xf32>>
+ // CHECK: %[[POISON_TWO:.*]] = llvm.mlir.poison : vector<8xf32>
+ // CHECK: %[[SHUFFLE_C:.*]] = llvm.shufflevector %[[EXTRACT_TWO]], %[[POISON_TWO]] [0, 2, 4, 6] : vector<8xf32>
+ // CHECK: %[[SHUFFLE_D:.*]] = llvm.shufflevector %[[EXTRACT_TWO]], %[[POISON_TWO]] [1, 3, 5, 7] : vector<8xf32>
+ // CHECK: %[[INSERT_C:.*]] = llvm.insertvalue %[[SHUFFLE_C]], %[[INSERT_A]][1] : !llvm.array<2 x vector<4xf32>>
+ // CHECK: %[[INSERT_D:.*]] = llvm.insertvalue %[[SHUFFLE_D]], %[[INSERT_B]][1] : !llvm.array<2 x vector<4xf32>>
+ %0, %1 = vector.deinterleave %a : vector<2x8xf32> -> vector<2x4xf32>
+ return %0, %1 : vector<2x4xf32>, vector<2x4xf32>
+}
+
+func.func @vector_deinterleave_2d_scalable(%a: vector<2x[8]xf32>) -> (vector<2x[4]xf32>, vector<2x[4]xf32>) {
+ // CHECK: %[[EXTRACT_A:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<[8]xf32>>
+ // CHECK: %[[VECTOR_ONE:.*]] = "llvm.intr.vector.deinterleave2"(%[[EXTRACT_ONE]]) : (vector<[8]xf32>) -> !llvm.struct<(vector<[4]xf32>, vector<[4]xf32>)>
+ // CHECK: %[[EXTRACT_B:.*]] = llvm.extractvalue %[[VECTOR_ONE]][0] : !llvm.struct<(vector<[4]xf32>, vector<[4]xf32>)>
+ // CHECK: %[[EXTRACT_C:.*]] = llvm.extractvalue %[[VECTOR_ONE]][1] : !llvm.struct<(vector<[4]xf32>, vector<[4]xf32>)>
+ // CHECK: %[[INSERT_A:.*]] = llvm.insertvalue %[[EXTRACT_B]], %{{.*}}[0] : !llvm.array<2 x vector<[4]xf32>>
+ // CHECK: %[[INSERT_B:.*]] = llvm.insertvalue %[[EXTRACT_C]], %{{.*}}[0] : !llvm.array<2 x vector<[4]xf32>>
+ // CHECK: %[[EXTRACT_D:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<[8]xf32>>
+ // CHECK: %[[VECTOR_TWO:.*]] = "llvm.intr.vector.deinterleave2"(%[[EXTRACT_D]]) : (vector<[8]xf32>) -> !llvm.struct<(vector<[4]xf32>, vector<[4]xf32>)>
+ // CHECK: %[[EXTRACT_E:.*]] = llvm.extractvalue %[[VECTOR_TWO]][0] : !llvm.struct<(vector<[4]xf32>, vector<[4]xf32>)>
+ // CHECK: %[[EXTRACT_F:.*]] = llvm.extractvalue %[[VECTOR_TWO]][1] : !llvm.struct<(vector<[4]xf32>, vector<[4]xf32>)>
+ // CHECK: %[[INSERT_C:.*]] = llvm.insertvalue %[[EXTRACT_E]], %[[INSERT_A]][1] : !llvm.array<2 x vector<[4]xf32>>
+ // CHECK: %[[INSERT_D:.*]] = llvm.insertvalue %[[EXTRACT_F]], %[[INSERT_B]][1] : !llvm.array<2 x vector<[4]xf32>>
+ %0, %1 = vector.deinterleave %a : vector<2x[8]xf32> -> vector<2x[4]xf32>
+ return %0, %1 : vector<2x[4]xf32>, vector<2x[4]xf32>
+}
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
c78166b
to
8471eb6
Compare
8471eb6
to
b95894d
Compare
b95894d
to
c602bf3
Compare
471d9a8
to
b0649b5
Compare
64bfb63
to
815a1d7
Compare
a422db0
to
166b6c9
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry, spotted a few more things 😅
mlir/test/Dialect/Vector/vector-deinterleave-lowering-transforms.mlir
Outdated
Show resolved
Hide resolved
mlir/test/Dialect/Vector/vector-deinterleave-lowering-transforms.mlir
Outdated
Show resolved
Hide resolved
mlir/test/Dialect/Vector/vector-deinterleave-lowering-transforms.mlir
Outdated
Show resolved
Hide resolved
This patch implements the lowering of vector.deinterleave for n-dimensional vectors. The process unrolls the n-D vector into a series of one-dimensional vectors; the deinterleave operation is then applied to each of these vectors.

From:
```
%0, %1 = vector.deinterleave %a : vector<2x8xf32> -> vector<2x4xf32>
```
To:
```
%2 = llvm.extractvalue %0[0] : !llvm.array<2 x vector<8xf32>>
%3 = llvm.mlir.poison : vector<8xf32>
%4 = llvm.shufflevector %2, %3 [0, 2, 4, 6] : vector<8xf32>
%5 = llvm.shufflevector %2, %3 [1, 3, 5, 7] : vector<8xf32>
%6 = llvm.insertvalue %4, %1[0] : !llvm.array<2 x vector<4xf32>>
%7 = llvm.insertvalue %5, %1[0] : !llvm.array<2 x vector<4xf32>>
%8 = llvm.extractvalue %0[1] : !llvm.array<2 x vector<8xf32>>
...etc.
```
166b6c9
to
8da2efc
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM! Thanks for the changes!
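This patch implements the lowering of vector.deinterleave
for n-dimensional vectors. The process unrolls the
n-D vector into a series of one-dimensional vectors;
the deinterleave operation is then applied to each
of these vectors.
From:
```
%0, %1 = vector.deinterleave %a : vector<2x8xf32> -> vector<2x4xf32>
```
To:
```
%2 = llvm.extractvalue %0[0] : !llvm.array<2 x vector<8xf32>>
%3 = llvm.mlir.poison : vector<8xf32>
%4 = llvm.shufflevector %2, %3 [0, 2, 4, 6] : vector<8xf32>
%5 = llvm.shufflevector %2, %3 [1, 3, 5, 7] : vector<8xf32>
%6 = llvm.insertvalue %4, %1[0] : !llvm.array<2 x vector<4xf32>>
%7 = llvm.insertvalue %5, %1[0] : !llvm.array<2 x vector<4xf32>>
%8 = llvm.extractvalue %0[1] : !llvm.array<2 x vector<8xf32>>
...etc.
```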