Skip to content

Commit 8da2efc

Browse files
committed
[mlir][vector] Add n-d deinterleave lowering
This patch implements the lowering for vector deinterleave for vector of n-dimensions. Process involves unrolling the n-d vector to a series of one-dimensional vectors. The deinterleave operation is then used on these vectors. From: ``` %0, %1 = vector.deinterleave %a : vector<2x[4]xi8> -> vector<2x[2]xi8> ``` To: ``` %2 = llvm.extractvalue %0[0] : !llvm.array<2 x vector<8xf32>> %3 = llvm.mlir.poison : vector<8xf32> %4 = llvm.shufflevector %2, %3 [0, 2, 4, 6] : vector<8xf32> %5 = llvm.shufflevector %2, %3 [1, 3, 5, 7] : vector<8xf32> %6 = llvm.insertvalue %4, %1[0] : !llvm.array<2 x vector<4xf32>> %7 = llvm.insertvalue %5, %1[0] : !llvm.array<2 x vector<4xf32>> %8 = llvm.extractvalue %0[1] : !llvm.array<2 x vector<8xf32>> ...etc. ```
1 parent fc5254c commit 8da2efc

File tree

3 files changed

+153
-1
lines changed

3 files changed

+153
-1
lines changed

mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp

Lines changed: 69 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,73 @@ class UnrollInterleaveOp final : public OpRewritePattern<vector::InterleaveOp> {
7979
int64_t targetRank = 1;
8080
};
8181

82+
/// A one-shot unrolling of vector.deinterleave to the `targetRank`.
83+
///
84+
/// Example:
85+
///
86+
/// ```mlir
87+
/// %0, %1 = vector.deinterleave %a : vector<1x2x3x8xi64> -> vector<1x2x3x4xi64>
88+
/// ```
89+
/// Would be unrolled to:
90+
/// ```mlir
91+
/// %result = arith.constant dense<0> : vector<1x2x3x4xi64>
92+
/// %0 = vector.extract %a[0, 0, 0] ─┐
93+
/// : vector<8xi64> from vector<1x2x3x8xi64> |
94+
/// %1, %2 = vector.deinterleave %0 |
95+
/// : vector<8xi64> -> vector<4xi64> | -- Initial deinterleave
96+
/// %3 = vector.insert %1, %result [0, 0, 0] | operation unrolled.
97+
/// : vector<4xi64> into vector<1x2x3x4xi64> |
98+
/// %4 = vector.insert %2, %result [0, 0, 0] |
99+
/// : vector<4xi64> into vector<1x2x3x4xi64> ┘
100+
/// %5 = vector.extract %a[0, 0, 1] ─┐
101+
/// : vector<8xi64> from vector<1x2x3x8xi64> |
102+
/// %6, %7 = vector.deinterleave %5 |
103+
/// : vector<8xi64> -> vector<4xi64> | -- Recursive pattern for
104+
/// %8 = vector.insert %6, %3 [0, 0, 1] | subsequent unrolled
105+
/// : vector<4xi64> into vector<1x2x3x4xi64> | deinterleave
106+
/// %9 = vector.insert %7, %4 [0, 0, 1] | operations. Repeated
107+
/// : vector<4xi64> into vector<1x2x3x4xi64> ┘ 5x in this case.
108+
/// ```
109+
///
110+
/// Note: If any leading dimension before the `targetRank` is scalable the
111+
/// unrolling will stop before the scalable dimension.
112+
class UnrollDeinterleaveOp final
113+
: public OpRewritePattern<vector::DeinterleaveOp> {
114+
public:
115+
UnrollDeinterleaveOp(int64_t targetRank, MLIRContext *context,
116+
PatternBenefit benefit = 1)
117+
: OpRewritePattern(context, benefit), targetRank(targetRank) {};
118+
119+
LogicalResult matchAndRewrite(vector::DeinterleaveOp op,
120+
PatternRewriter &rewriter) const override {
121+
VectorType resultType = op.getResultVectorType();
122+
auto unrollIterator = vector::createUnrollIterator(resultType, targetRank);
123+
if (!unrollIterator)
124+
return failure();
125+
126+
auto loc = op.getLoc();
127+
Value emptyResult = rewriter.create<arith::ConstantOp>(
128+
loc, resultType, rewriter.getZeroAttr(resultType));
129+
Value evenResult = emptyResult;
130+
Value oddResult = emptyResult;
131+
132+
for (auto position : *unrollIterator) {
133+
auto extractSrc =
134+
rewriter.create<vector::ExtractOp>(loc, op.getSource(), position);
135+
auto deinterleave =
136+
rewriter.create<vector::DeinterleaveOp>(loc, extractSrc);
137+
evenResult = rewriter.create<vector::InsertOp>(
138+
loc, deinterleave.getRes1(), evenResult, position);
139+
oddResult = rewriter.create<vector::InsertOp>(loc, deinterleave.getRes2(),
140+
oddResult, position);
141+
}
142+
rewriter.replaceOp(op, ValueRange{evenResult, oddResult});
143+
return success();
144+
}
145+
146+
private:
147+
int64_t targetRank = 1;
148+
};
82149
/// Rewrite vector.interleave op into an equivalent vector.shuffle op, when
83150
/// applicable: `sourceType` must be 1D and non-scalable.
84151
///
@@ -116,7 +183,8 @@ struct InterleaveToShuffle final : OpRewritePattern<vector::InterleaveOp> {
116183

117184
void mlir::vector::populateVectorInterleaveLoweringPatterns(
118185
RewritePatternSet &patterns, int64_t targetRank, PatternBenefit benefit) {
119-
patterns.add<UnrollInterleaveOp>(targetRank, patterns.getContext(), benefit);
186+
patterns.add<UnrollInterleaveOp, UnrollDeinterleaveOp>(
187+
targetRank, patterns.getContext(), benefit);
120188
}
121189

122190
void mlir::vector::populateVectorInterleaveToShufflePatterns(

mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2565,6 +2565,22 @@ func.func @vector_deinterleave_1d_scalable(%a: vector<[4]xi32>) -> (vector<[2]xi
25652565
return %0, %1 : vector<[2]xi32>, vector<[2]xi32>
25662566
}
25672567

2568+
// CHECK-LABEL: @vector_deinterleave_2d
2569+
// CHECK-SAME: %[[SRC:.*]]: vector<2x8xf32>) -> (vector<2x4xf32>, vector<2x4xf32>)
2570+
func.func @vector_deinterleave_2d(%a: vector<2x8xf32>) -> (vector<2x4xf32>, vector<2x4xf32>) {
2571+
// CHECK: llvm.shufflevector
2572+
// CHECK-NOT: vector.deinterleave %{{.*}} : vector<2x8xf32>
2573+
%0, %1 = vector.deinterleave %a : vector<2x8xf32> -> vector<2x4xf32>
2574+
return %0, %1 : vector<2x4xf32>, vector<2x4xf32>
2575+
}
2576+
2577+
func.func @vector_deinterleave_2d_scalable(%a: vector<2x[8]xf32>) -> (vector<2x[4]xf32>, vector<2x[4]xf32>) {
2578+
// CHECK: llvm.intr.vector.deinterleave2
2579+
// CHECK-NOT: vector.deinterleave %{{.*}} : vector<2x[8]xf32>
2580+
%0, %1 = vector.deinterleave %a : vector<2x[8]xf32> -> vector<2x[4]xf32>
2581+
return %0, %1 : vector<2x[4]xf32>, vector<2x[4]xf32>
2582+
}
2583+
25682584
// -----
25692585

25702586
// CHECK-LABEL: func.func @vector_bitcast_2d
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
// RUN: mlir-opt %s --transform-interpreter | FileCheck %s
2+
3+
// CHECK-LABEL: @vector_deinterleave_2d
4+
// CHECK-SAME: %[[SRC:.*]]: vector<2x8xi32>) -> (vector<2x4xi32>, vector<2x4xi32>)
5+
func.func @vector_deinterleave_2d(%a: vector<2x8xi32>) -> (vector<2x4xi32>, vector<2x4xi32>) {
6+
// CHECK: %[[CST:.*]] = arith.constant dense<0>
7+
// CHECK: %[[SRC_0:.*]] = vector.extract %[[SRC]][0]
8+
// CHECK: %[[UNZIP_0:.*]], %[[UNZIP_1:.*]] = vector.deinterleave %[[SRC_0]]
9+
// CHECK: %[[RES_0:.*]] = vector.insert %[[UNZIP_0]], %[[CST]] [0]
10+
// CHECK: %[[RES_1:.*]] = vector.insert %[[UNZIP_1]], %[[CST]] [0]
11+
// CHECK: %[[SRC_1:.*]] = vector.extract %[[SRC]][1]
12+
// CHECK: %[[UNZIP_2:.*]], %[[UNZIP_3:.*]] = vector.deinterleave %[[SRC_1]]
13+
// CHECK: %[[RES_2:.*]] = vector.insert %[[UNZIP_2]], %[[RES_0]] [1]
14+
// CHECK: %[[RES_3:.*]] = vector.insert %[[UNZIP_3]], %[[RES_1]] [1]
15+
// CHECK-NEXT: return %[[RES_2]], %[[RES_3]] : vector<2x4xi32>, vector<2x4xi32>
16+
%0, %1 = vector.deinterleave %a : vector<2x8xi32> -> vector<2x4xi32>
17+
return %0, %1 : vector<2x4xi32>, vector<2x4xi32>
18+
}
19+
20+
// CHECK-LABEL: @vector_deinterleave_2d_scalable
21+
// CHECK-SAME: %[[SRC:.*]]: vector<2x[8]xi32>) -> (vector<2x[4]xi32>, vector<2x[4]xi32>)
22+
func.func @vector_deinterleave_2d_scalable(%a: vector<2x[8]xi32>) -> (vector<2x[4]xi32>, vector<2x[4]xi32>) {
23+
// CHECK: %[[CST:.*]] = arith.constant dense<0>
24+
// CHECK: %[[SRC_0:.*]] = vector.extract %[[SRC]][0]
25+
// CHECK: %[[UNZIP_0:.*]], %[[UNZIP_1:.*]] = vector.deinterleave %[[SRC_0]]
26+
// CHECK: %[[RES_0:.*]] = vector.insert %[[UNZIP_0]], %[[CST]] [0]
27+
// CHECK: %[[RES_1:.*]] = vector.insert %[[UNZIP_1]], %[[CST]] [0]
28+
// CHECK: %[[SRC_1:.*]] = vector.extract %[[SRC]][1]
29+
// CHECK: %[[UNZIP_2:.*]], %[[UNZIP_3:.*]] = vector.deinterleave %[[SRC_1]]
30+
// CHECK: %[[RES_2:.*]] = vector.insert %[[UNZIP_2]], %[[RES_0]] [1]
31+
// CHECK: %[[RES_3:.*]] = vector.insert %[[UNZIP_3]], %[[RES_1]] [1]
32+
// CHECK-NEXT: return %[[RES_2]], %[[RES_3]] : vector<2x[4]xi32>, vector<2x[4]xi32>
33+
%0, %1 = vector.deinterleave %a : vector<2x[8]xi32> -> vector<2x[4]xi32>
34+
return %0, %1 : vector<2x[4]xi32>, vector<2x[4]xi32>
35+
}
36+
37+
// CHECK-LABEL: @vector_deinterleave_4d
38+
// CHECK-SAME: %[[SRC:.*]]: vector<1x2x3x8xi64>) -> (vector<1x2x3x4xi64>, vector<1x2x3x4xi64>)
39+
func.func @vector_deinterleave_4d(%a: vector<1x2x3x8xi64>) -> (vector<1x2x3x4xi64>, vector<1x2x3x4xi64>) {
40+
// CHECK: %[[SRC_0:.*]] = vector.extract %[[SRC]][0, 0, 0] : vector<8xi64> from vector<1x2x3x8xi64>
41+
// CHECK: %[[UNZIP_0:.*]], %[[UNZIP_1:.*]] = vector.deinterleave %[[SRC_0]] : vector<8xi64> -> vector<4xi64>
42+
// CHECK: %[[RES_0:.*]] = vector.insert %[[UNZIP_0]], %{{.*}} [0, 0, 0] : vector<4xi64> into vector<1x2x3x4xi64>
43+
// CHECK: %[[RES_1:.*]] = vector.insert %[[UNZIP_1]], %{{.*}} [0, 0, 0] : vector<4xi64> into vector<1x2x3x4xi64>
44+
// CHECK-COUNT-5: vector.deinterleave %{{.*}} : vector<8xi64> -> vector<4xi64>
45+
%0, %1 = vector.deinterleave %a : vector<1x2x3x8xi64> -> vector<1x2x3x4xi64>
46+
return %0, %1 : vector<1x2x3x4xi64>, vector<1x2x3x4xi64>
47+
}
48+
49+
// CHECK-LABEL: @vector_deinterleave_nd_with_scalable_dim
50+
func.func @vector_deinterleave_nd_with_scalable_dim(
51+
%a: vector<1x3x[2]x2x3x8xf16>) -> (vector<1x3x[2]x2x3x4xf16>, vector<1x3x[2]x2x3x4xf16>) {
52+
// The scalable dim blocks unrolling so only the first two dims are unrolled.
53+
// CHECK-COUNT-3: vector.deinterleave %{{.*}} : vector<[2]x2x3x8xf16>
54+
%0, %1 = vector.deinterleave %a: vector<1x3x[2]x2x3x8xf16> -> vector<1x3x[2]x2x3x4xf16>
55+
return %0, %1 : vector<1x3x[2]x2x3x4xf16>, vector<1x3x[2]x2x3x4xf16>
56+
}
57+
58+
module attributes {transform.with_named_sequence} {
59+
transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
60+
%f = transform.structured.match ops{["func.func"]} in %module_op
61+
: (!transform.any_op) -> !transform.any_op
62+
63+
transform.apply_patterns to %f {
64+
transform.apply_patterns.vector.lower_interleave
65+
} : !transform.any_op
66+
transform.yield
67+
}
68+
}

0 commit comments

Comments
 (0)