Skip to content

Commit c602bf3

Browse files
committed
[mlir][vector] Add n-d deinterleave lowering
This patch implements the lowering of vector.deinterleave for n-dimensional vectors. The n-D vector is unrolled into a series of one-dimensional vectors, and the deinterleave operation is then applied to each of them. From: ``` %0, %1 = vector.deinterleave %a : vector<2x8xf32> -> vector<2x4xf32> ``` To: ``` %2 = llvm.extractvalue %0[0] : !llvm.array<2 x vector<8xf32>> %3 = llvm.mlir.poison : vector<8xf32> %4 = llvm.shufflevector %2, %3 [0, 2, 4, 6] : vector<8xf32> %5 = llvm.shufflevector %2, %3 [1, 3, 5, 7] : vector<8xf32> %6 = llvm.insertvalue %4, %1[0] : !llvm.array<2 x vector<4xf32>> %7 = llvm.insertvalue %5, %1[0] : !llvm.array<2 x vector<4xf32>> %8 = llvm.extractvalue %0[1] : !llvm.array<2 x vector<8xf32>> ...etc. ```
1 parent fc5254c commit c602bf3

File tree

2 files changed

+79
-1
lines changed

2 files changed

+79
-1
lines changed

mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp

+63-1
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,67 @@ class UnrollInterleaveOp final : public OpRewritePattern<vector::InterleaveOp> {
7979
int64_t targetRank = 1;
8080
};
8181

82+
/// A one-shot unrolling of vector.deinterleave to the `targetRank`.
83+
///
84+
/// Example:
85+
///
86+
/// ```mlir
87+
/// vector.deinterleave %a : vector<1x2x3x8xi64> -> vector<1x2x3x4xi64>
88+
/// ```
89+
/// Would be unrolled to:
90+
/// ```mlir
91+
/// %result = arith.constant dense<0> : vector<1x2x3x8xi64>
92+
/// %0 = vector.extract %a[0, 0, 0] ─┐
93+
/// : vector<4xi64> from vector<1x2x3x4xi64> | | - Repeated 6x for
94+
/// %1, %2 = vector.deinterleave %0 : | all leading
95+
/// positions
96+
/// : vector<8xi64> -> vector<4xi64> |
97+
/// %3 = vector.insert %1, %result [0, 0, 0] |
98+
/// : vector<4xi64> into vector<1x2x3x4xi64> |
99+
/// %3 = vector.insert %2, %result [0, 0, 0] |
100+
/// : vector<4xi64> into vector<1x2x3x4xi64> ┘
101+
/// ```
102+
///
103+
/// Note: If any leading dimension before the `targetRank` is scalable the
104+
/// unrolling will stop before the scalable dimension.
105+
106+
class UnrollDeinterleaveOp final
107+
: public OpRewritePattern<vector::DeinterleaveOp> {
108+
public:
109+
UnrollDeinterleaveOp(int64_t targetRank, MLIRContext *context,
110+
PatternBenefit benefit = 1)
111+
: OpRewritePattern(context, benefit), targetRank(targetRank) {};
112+
113+
LogicalResult matchAndRewrite(vector::DeinterleaveOp op,
114+
PatternRewriter &rewriter) const override {
115+
VectorType resultType = op.getResultVectorType();
116+
auto unrollIterator = vector::createUnrollIterator(resultType, targetRank);
117+
if (!unrollIterator)
118+
return failure();
119+
120+
auto loc = op.getLoc();
121+
Value evenResult = rewriter.create<arith::ConstantOp>(
122+
loc, resultType, rewriter.getZeroAttr(resultType));
123+
Value oddResult = rewriter.create<arith::ConstantOp>(
124+
loc, resultType, rewriter.getZeroAttr(resultType));
125+
126+
for (auto position : *unrollIterator) {
127+
auto extractSrc =
128+
rewriter.create<vector::ExtractOp>(loc, op.getSource(), position);
129+
auto deinterleave =
130+
rewriter.create<vector::DeinterleaveOp>(loc, extractSrc);
131+
evenResult = rewriter.create<vector::InsertOp>(
132+
loc, deinterleave.getRes1(), evenResult, position);
133+
oddResult = rewriter.create<vector::InsertOp>(loc, deinterleave.getRes2(),
134+
oddResult, position);
135+
}
136+
rewriter.replaceOp(op, ValueRange{evenResult, oddResult});
137+
return success();
138+
}
139+
140+
private:
141+
int64_t targetRank = 1;
142+
};
82143
/// Rewrite vector.interleave op into an equivalent vector.shuffle op, when
83144
/// applicable: `sourceType` must be 1D and non-scalable.
84145
///
@@ -116,7 +177,8 @@ struct InterleaveToShuffle final : OpRewritePattern<vector::InterleaveOp> {
116177

117178
/// Adds the n-D unrolling patterns to `patterns`. After this change both
/// UnrollInterleaveOp and UnrollDeinterleaveOp are registered, each
/// constructed with the same `targetRank`, the pattern set's context and
/// `benefit`.
void mlir::vector::populateVectorInterleaveLoweringPatterns(
118179
RewritePatternSet &patterns, int64_t targetRank, PatternBenefit benefit) {
119-
patterns.add<UnrollInterleaveOp>(targetRank, patterns.getContext(), benefit);
180+
patterns.add<UnrollInterleaveOp, UnrollDeinterleaveOp>(
181+
targetRank, patterns.getContext(), benefit);
120182
}
121183

122184
void mlir::vector::populateVectorInterleaveToShufflePatterns(

mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir

+16
Original file line numberDiff line numberDiff line change
@@ -2565,6 +2565,22 @@ func.func @vector_deinterleave_1d_scalable(%a: vector<[4]xi32>) -> (vector<[2]xi
25652565
return %0, %1 : vector<[2]xi32>, vector<[2]xi32>
25662566
}
25672567

2568+
// Fixed-size 2-D deinterleave: the output must contain an llvm.shufflevector
// and no 2-D vector.deinterleave may remain after it.
// CHECK-LABEL: @vector_deinterleave_2d
// CHECK-SAME: %[[SRC:.*]]: vector<2x8xf32>) -> (vector<2x4xf32>, vector<2x4xf32>)
func.func @vector_deinterleave_2d(%src: vector<2x8xf32>) -> (vector<2x4xf32>, vector<2x4xf32>) {
  // CHECK: llvm.shufflevector
  // CHECK-NOT: vector.deinterleave %{{.*}} : vector<2x8xf32>
  %even, %odd = vector.deinterleave %src : vector<2x8xf32> -> vector<2x4xf32>
  return %even, %odd : vector<2x4xf32>, vector<2x4xf32>
}
2576+
2577+
// 2-D deinterleave with a scalable trailing dimension: the output must use
// the llvm.intr.vector.deinterleave2 intrinsic and no 2-D vector.deinterleave
// may remain after it. The label anchors the directives below to this
// function instead of letting them match output of a neighbouring test.
// CHECK-LABEL: @vector_deinterleave_2d_scalable
// CHECK-SAME: %[[SRC:.*]]: vector<2x[8]xf32>) -> (vector<2x[4]xf32>, vector<2x[4]xf32>)
func.func @vector_deinterleave_2d_scalable(%a: vector<2x[8]xf32>) -> (vector<2x[4]xf32>, vector<2x[4]xf32>) {
  // CHECK: llvm.intr.vector.deinterleave2
  // CHECK-NOT: vector.deinterleave %{{.*}} : vector<2x[8]xf32>
  %0, %1 = vector.deinterleave %a : vector<2x[8]xf32> -> vector<2x[4]xf32>
  return %0, %1 : vector<2x[4]xf32>, vector<2x[4]xf32>
}
2583+
25682584
// -----
25692585

25702586
// CHECK-LABEL: func.func @vector_bitcast_2d

0 commit comments

Comments
 (0)