Skip to content

[mlir][vector] Add n-d deinterleave lowering #94237

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 7, 2024

Conversation

mub-at-arm
Copy link
Contributor

@mub-at-arm mub-at-arm commented Jun 3, 2024

This patch implements the lowering for vector
deinterleave for vector of n-dimensions. Process
involves unrolling the n-d vector to a series
of one-dimensional vectors. The deinterleave
operation is then used on these vectors.

From:

%0, %1 = vector.deinterleave %a : vector<2x8xi8> -> vector<2x4xi8>

To:

%cst = arith.constant dense<0> : vector<2x4xi32>
%0 = vector.extract %arg0[0] : vector<8xi32> from vector<2x8xi32>
%res1, %res2 = vector.deinterleave %0 : vector<8xi32> -> vector<4xi32>
%1 = vector.insert %res1, %cst [0] : vector<4xi32> into vector<2x4xi32>
%2 = vector.insert %res2, %cst [0] : vector<4xi32> into vector<2x4xi32>
%3 = vector.extract %arg0[1] : vector<8xi32> from vector<2x8xi32>
%res1_0, %res2_1 = vector.deinterleave %3 : vector<8xi32> -> vector<4xi32>
%4 = vector.insert %res1_0, %1 [1] : vector<4xi32> into vector<2x4xi32>
%5 = vector.insert %res2_1, %2 [1] : vector<4xi32> into vector<2x4xi32>
...etc.

@llvmbot
Copy link
Member

llvmbot commented Jun 3, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-vector

Author: Mubashar Ahmad (mub-at-arm)

Changes

This patch implements the lowering for vector
deinterleave for vector of n-dimensions. Process
involves unrolling the n-d vector to a series
of one-dimensional vectors. The deinterleave
operation is then used on these vectors.

From:

%0, %1 = vector.deinterleave %a : vector&lt;2x[4]xi8&gt; -&gt; vector&lt;2x[2]xi8&gt;

To:

%2 = llvm.extractvalue %0[0] : !llvm.array&lt;2 x vector&lt;8xf32&gt;&gt;
%3 = llvm.mlir.poison : vector&lt;8xf32&gt;
%4 = llvm.shufflevector %2, %3 [0, 2, 4, 6] : vector&lt;8xf32&gt;
%5 = llvm.shufflevector %2, %3 [1, 3, 5, 7] : vector&lt;8xf32&gt;
%6 = llvm.insertvalue %4, %1[0] : !llvm.array&lt;2 x vector&lt;4xf32&gt;&gt;
%7 = llvm.insertvalue %5, %1[0] : !llvm.array&lt;2 x vector&lt;4xf32&gt;&gt;
%8 = llvm.extractvalue %0[1] : !llvm.array&lt;2 x vector&lt;8xf32&gt;&gt;
...etc.

Full diff: https://github.com/llvm/llvm-project/pull/94237.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp (+37)
  • (modified) mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir (+36)
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp
index 77c97b2f1497c..557837426d855 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp
@@ -79,6 +79,42 @@ class UnrollInterleaveOp final : public OpRewritePattern<vector::InterleaveOp> {
   int64_t targetRank = 1;
 };
 
+class UnrollDeinterleaveOp final : public OpRewritePattern<vector::DeinterleaveOp> {
+public:
+  UnrollDeinterleaveOp(int64_t targetRank, MLIRContext *context,
+                       PatternBenefit benefit = 1)
+      : OpRewritePattern(context, benefit), targetRank(targetRank) {};
+
+  LogicalResult matchAndRewrite(vector::DeinterleaveOp op,
+                                PatternRewriter &rewriter) const override {
+    VectorType resultType = op.getResultVectorType();
+    auto unrollIterator = vector::createUnrollIterator(resultType, targetRank);
+    if (!unrollIterator)
+      return failure();
+
+    auto loc = op.getLoc();
+    Value evenResult = rewriter.create<arith::ConstantOp>(
+        loc, resultType, rewriter.getZeroAttr(resultType));
+    Value oddResult = rewriter.create<arith::ConstantOp>(
+        loc, resultType, rewriter.getZeroAttr(resultType));
+
+    for (auto position : *unrollIterator) {
+      auto extractSrc = rewriter.create<vector::ExtractOp>(
+        loc, op.getSource(), position);
+      auto deinterleave = rewriter.create<vector::DeinterleaveOp>(
+        loc, extractSrc);
+      evenResult = rewriter.create<vector::InsertOp>(
+        loc, deinterleave.getRes1(), evenResult, position);
+      oddResult = rewriter.create<vector::InsertOp>(
+        loc, deinterleave.getRes2(), oddResult, position);
+    }
+    rewriter.replaceOp(op, ValueRange{evenResult, oddResult});
+    return success();
+  }
+
+private:
+  int64_t targetRank = 1;
+};
 /// Rewrite vector.interleave op into an equivalent vector.shuffle op, when
 /// applicable: `sourceType` must be 1D and non-scalable.
 ///
@@ -117,6 +153,7 @@ struct InterleaveToShuffle final : OpRewritePattern<vector::InterleaveOp> {
 void mlir::vector::populateVectorInterleaveLoweringPatterns(
     RewritePatternSet &patterns, int64_t targetRank, PatternBenefit benefit) {
   patterns.add<UnrollInterleaveOp>(targetRank, patterns.getContext(), benefit);
+  patterns.add<UnrollDeinterleaveOp>(targetRank, patterns.getContext(), benefit);
 }
 
 void mlir::vector::populateVectorInterleaveToShufflePatterns(
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 245edb6789d30..21f4872bb2cd9 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -2564,3 +2564,39 @@ func.func @vector_deinterleave_1d_scalable(%a: vector<[4]xi32>) -> (vector<[2]xi
     %0, %1 = vector.deinterleave %a : vector<[4]xi32> -> vector<[2]xi32>
     return %0, %1 : vector<[2]xi32>, vector<[2]xi32>
 }
+
+// CHECK-LABEL: @vector_deinterleave_2d
+// CHECK-SAME: %[[SRC:.*]]: vector<2x8xf32>) -> (vector<2x4xf32>, vector<2x4xf32>)
+func.func @vector_deinterleave_2d(%a: vector<2x8xf32>) -> (vector<2x4xf32>, vector<2x4xf32>) {
+   // CHECK: %[[EXTRACT_ONE:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<8xf32>> 
+   // CHECK: %[[POISON_ONE:.*]] = llvm.mlir.poison : vector<8xf32>
+   // CHECK: %[[SHUFFLE_A:.*]] = llvm.shufflevector %[[EXTRACT_ONE]], %[[POISON_ONE]] [0, 2, 4, 6] : vector<8xf32> 
+   // CHECK: %[[SHUFFLE_B:.*]] = llvm.shufflevector %[[EXTRACT_ONE]], %[[POISON_ONE]] [1, 3, 5, 7] : vector<8xf32> 
+   // CHECK: %[[INSERT_A:.*]] = llvm.insertvalue %[[SHUFFLE_A]], %{{.*}}[0] : !llvm.array<2 x vector<4xf32>> 
+   // CHECK: %[[INSERT_B:.*]] = llvm.insertvalue %[[SHUFFLE_B]], %{{.*}}[0] : !llvm.array<2 x vector<4xf32>> 
+   // CHECK: %[[EXTRACT_TWO:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<8xf32>> 
+   // CHECK: %[[POISON_TWO:.*]] = llvm.mlir.poison : vector<8xf32>
+   // CHECK: %[[SHUFFLE_C:.*]] = llvm.shufflevector %[[EXTRACT_TWO]], %[[POISON_TWO]] [0, 2, 4, 6] : vector<8xf32> 
+   // CHECK: %[[SHUFFLE_D:.*]] = llvm.shufflevector %[[EXTRACT_TWO]], %[[POISON_TWO]] [1, 3, 5, 7] : vector<8xf32> 
+   // CHECK: %[[INSERT_C:.*]] = llvm.insertvalue %[[SHUFFLE_C]], %[[INSERT_A]][1] : !llvm.array<2 x vector<4xf32>> 
+   // CHECK: %[[INSERT_D:.*]] = llvm.insertvalue %[[SHUFFLE_D]], %[[INSERT_B]][1] : !llvm.array<2 x vector<4xf32>> 
+  %0, %1 = vector.deinterleave %a : vector<2x8xf32> -> vector<2x4xf32>
+  return %0, %1 : vector<2x4xf32>, vector<2x4xf32>
+}
+
+func.func @vector_deinterleave_2d_scalable(%a: vector<2x[8]xf32>) -> (vector<2x[4]xf32>, vector<2x[4]xf32>) {
+    // CHECK: %[[EXTRACT_A:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<[8]xf32>>
+    // CHECK: %[[VECTOR_ONE:.*]] = "llvm.intr.vector.deinterleave2"(%[[EXTRACT_ONE]]) : (vector<[8]xf32>) -> !llvm.struct<(vector<[4]xf32>, vector<[4]xf32>)>
+    // CHECK: %[[EXTRACT_B:.*]] = llvm.extractvalue %[[VECTOR_ONE]][0] : !llvm.struct<(vector<[4]xf32>, vector<[4]xf32>)> 
+    // CHECK: %[[EXTRACT_C:.*]] = llvm.extractvalue %[[VECTOR_ONE]][1] : !llvm.struct<(vector<[4]xf32>, vector<[4]xf32>)> 
+    // CHECK: %[[INSERT_A:.*]] = llvm.insertvalue %[[EXTRACT_B]], %{{.*}}[0] : !llvm.array<2 x vector<[4]xf32>> 
+    // CHECK: %[[INSERT_B:.*]] = llvm.insertvalue %[[EXTRACT_C]], %{{.*}}[0] : !llvm.array<2 x vector<[4]xf32>> 
+    // CHECK: %[[EXTRACT_D:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<[8]xf32>> 
+    // CHECK: %[[VECTOR_TWO:.*]] = "llvm.intr.vector.deinterleave2"(%[[EXTRACT_D]]) : (vector<[8]xf32>) -> !llvm.struct<(vector<[4]xf32>, vector<[4]xf32>)>
+    // CHECK: %[[EXTRACT_E:.*]] = llvm.extractvalue %[[VECTOR_TWO]][0] : !llvm.struct<(vector<[4]xf32>, vector<[4]xf32>)> 
+    // CHECK: %[[EXTRACT_F:.*]] = llvm.extractvalue %[[VECTOR_TWO]][1] : !llvm.struct<(vector<[4]xf32>, vector<[4]xf32>)> 
+    // CHECK: %[[INSERT_C:.*]] = llvm.insertvalue %[[EXTRACT_E]], %[[INSERT_A]][1] : !llvm.array<2 x vector<[4]xf32>>
+    // CHECK: %[[INSERT_D:.*]] = llvm.insertvalue %[[EXTRACT_F]], %[[INSERT_B]][1] : !llvm.array<2 x vector<[4]xf32>>
+    %0, %1 = vector.deinterleave %a : vector<2x[8]xf32> -> vector<2x[4]xf32>
+    return %0, %1 : vector<2x[4]xf32>, vector<2x[4]xf32>
+}

Copy link

github-actions bot commented Jun 3, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

@mub-at-arm mub-at-arm force-pushed the nd_deinterleave_lowering branch 2 times, most recently from c78166b to 8471eb6 Compare June 4, 2024 10:53
@mub-at-arm
Copy link
Contributor Author

@c-rhodes @MacDue Please review also.

@MacDue MacDue requested review from MacDue and c-rhodes June 4, 2024 11:13
@mub-at-arm mub-at-arm force-pushed the nd_deinterleave_lowering branch from 8471eb6 to b95894d Compare June 4, 2024 11:23
@mub-at-arm mub-at-arm force-pushed the nd_deinterleave_lowering branch from b95894d to c602bf3 Compare June 5, 2024 15:10
@mub-at-arm mub-at-arm force-pushed the nd_deinterleave_lowering branch 3 times, most recently from 471d9a8 to b0649b5 Compare June 5, 2024 18:53
@mub-at-arm mub-at-arm force-pushed the nd_deinterleave_lowering branch 2 times, most recently from 64bfb63 to 815a1d7 Compare June 6, 2024 11:02
@mub-at-arm mub-at-arm force-pushed the nd_deinterleave_lowering branch 2 times, most recently from a422db0 to 166b6c9 Compare June 6, 2024 12:30
Copy link
Member

@MacDue MacDue left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, spotted a few more things 😅

This patch implements the lowering for vector
deinterleave for vector of n-dimensions. Process
involves unrolling the n-d vector to a series
of one-dimensional vectors. The deinterleave
operation is then used on these vectors.

From:
```
%0, %1 = vector.deinterleave %a : vector<2x[4]xi8> -> vector<2x[2]xi8>
```

To:
```
%2 = llvm.extractvalue %0[0] : !llvm.array<2 x vector<8xf32>>
%3 = llvm.mlir.poison : vector<8xf32>
%4 = llvm.shufflevector %2, %3 [0, 2, 4, 6] : vector<8xf32>
%5 = llvm.shufflevector %2, %3 [1, 3, 5, 7] : vector<8xf32>
%6 = llvm.insertvalue %4, %1[0] : !llvm.array<2 x vector<4xf32>>
%7 = llvm.insertvalue %5, %1[0] : !llvm.array<2 x vector<4xf32>>
%8 = llvm.extractvalue %0[1] : !llvm.array<2 x vector<8xf32>>
...etc.
```
@mub-at-arm mub-at-arm force-pushed the nd_deinterleave_lowering branch from 166b6c9 to 8da2efc Compare June 6, 2024 14:17
Copy link
Member

@MacDue MacDue left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM! Thanks for the changes!

@MacDue MacDue merged commit b87a80d into llvm:main Jun 7, 2024
7 checks passed
@HerrCai0907 HerrCai0907 mentioned this pull request Jun 13, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants