Skip to content

Commit c78166b

Browse files
committed
[mlir][vector] Add n-d deinterleave lowering
This patch implements the lowering of vector.deinterleave for n-dimensional vectors. The n-D vector is unrolled into a series of one-dimensional vectors, and the deinterleave operation is then applied to each of those vectors.

From:
```
%0, %1 = vector.deinterleave %a : vector<2x8xf32> -> vector<2x4xf32>
```
To:
```
%2 = llvm.extractvalue %0[0] : !llvm.array<2 x vector<8xf32>>
%3 = llvm.mlir.poison : vector<8xf32>
%4 = llvm.shufflevector %2, %3 [0, 2, 4, 6] : vector<8xf32>
%5 = llvm.shufflevector %2, %3 [1, 3, 5, 7] : vector<8xf32>
%6 = llvm.insertvalue %4, %1[0] : !llvm.array<2 x vector<4xf32>>
%7 = llvm.insertvalue %5, %1[0] : !llvm.array<2 x vector<4xf32>>
%8 = llvm.extractvalue %0[1] : !llvm.array<2 x vector<8xf32>>
...etc.
```
1 parent fecf5c7 commit c78166b

File tree

2 files changed

+75
-0
lines changed

2 files changed

+75
-0
lines changed

mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp

+39
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,43 @@ class UnrollInterleaveOp final : public OpRewritePattern<vector::InterleaveOp> {
7979
int64_t targetRank = 1;
8080
};
8181

82+
class UnrollDeinterleaveOp final
83+
: public OpRewritePattern<vector::DeinterleaveOp> {
84+
public:
85+
UnrollDeinterleaveOp(int64_t targetRank, MLIRContext *context,
86+
PatternBenefit benefit = 1)
87+
: OpRewritePattern(context, benefit), targetRank(targetRank){};
88+
89+
LogicalResult matchAndRewrite(vector::DeinterleaveOp op,
90+
PatternRewriter &rewriter) const override {
91+
VectorType resultType = op.getResultVectorType();
92+
auto unrollIterator = vector::createUnrollIterator(resultType, targetRank);
93+
if (!unrollIterator)
94+
return failure();
95+
96+
auto loc = op.getLoc();
97+
Value evenResult = rewriter.create<arith::ConstantOp>(
98+
loc, resultType, rewriter.getZeroAttr(resultType));
99+
Value oddResult = rewriter.create<arith::ConstantOp>(
100+
loc, resultType, rewriter.getZeroAttr(resultType));
101+
102+
for (auto position : *unrollIterator) {
103+
auto extractSrc =
104+
rewriter.create<vector::ExtractOp>(loc, op.getSource(), position);
105+
auto deinterleave =
106+
rewriter.create<vector::DeinterleaveOp>(loc, extractSrc);
107+
evenResult = rewriter.create<vector::InsertOp>(
108+
loc, deinterleave.getRes1(), evenResult, position);
109+
oddResult = rewriter.create<vector::InsertOp>(loc, deinterleave.getRes2(),
110+
oddResult, position);
111+
}
112+
rewriter.replaceOp(op, ValueRange{evenResult, oddResult});
113+
return success();
114+
}
115+
116+
private:
117+
int64_t targetRank = 1;
118+
};
82119
/// Rewrite vector.interleave op into an equivalent vector.shuffle op, when
83120
/// applicable: `sourceType` must be 1D and non-scalable.
84121
///
@@ -117,6 +154,8 @@ struct InterleaveToShuffle final : OpRewritePattern<vector::InterleaveOp> {
117154
void mlir::vector::populateVectorInterleaveLoweringPatterns(
118155
RewritePatternSet &patterns, int64_t targetRank, PatternBenefit benefit) {
119156
patterns.add<UnrollInterleaveOp>(targetRank, patterns.getContext(), benefit);
157+
patterns.add<UnrollDeinterleaveOp>(targetRank, patterns.getContext(),
158+
benefit);
120159
}
121160

122161
void mlir::vector::populateVectorInterleaveToShufflePatterns(

mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir

+36
Original file line numberDiff line numberDiff line change
@@ -2564,3 +2564,39 @@ func.func @vector_deinterleave_1d_scalable(%a: vector<[4]xi32>) -> (vector<[2]xi
25642564
%0, %1 = vector.deinterleave %a : vector<[4]xi32> -> vector<[2]xi32>
25652565
return %0, %1 : vector<[2]xi32>, vector<[2]xi32>
25662566
}
2567+
2568+
// Fixed-size 2-D case: each of the two rows is extracted from the source
// array, split into its even and odd lanes with a pair of shufflevector ops,
// and the halves are inserted at the same row index of the two result arrays.
// CHECK-LABEL: @vector_deinterleave_2d
2569+
// CHECK-SAME: %[[SRC:.*]]: vector<2x8xf32>) -> (vector<2x4xf32>, vector<2x4xf32>)
2570+
func.func @vector_deinterleave_2d(%a: vector<2x8xf32>) -> (vector<2x4xf32>, vector<2x4xf32>) {
2571+
// CHECK: %[[EXTRACT_ONE:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<8xf32>>
2572+
// CHECK: %[[POISON_ONE:.*]] = llvm.mlir.poison : vector<8xf32>
2573+
// CHECK: %[[SHUFFLE_A:.*]] = llvm.shufflevector %[[EXTRACT_ONE]], %[[POISON_ONE]] [0, 2, 4, 6] : vector<8xf32>
2574+
// CHECK: %[[SHUFFLE_B:.*]] = llvm.shufflevector %[[EXTRACT_ONE]], %[[POISON_ONE]] [1, 3, 5, 7] : vector<8xf32>
2575+
// CHECK: %[[INSERT_A:.*]] = llvm.insertvalue %[[SHUFFLE_A]], %{{.*}}[0] : !llvm.array<2 x vector<4xf32>>
2576+
// CHECK: %[[INSERT_B:.*]] = llvm.insertvalue %[[SHUFFLE_B]], %{{.*}}[0] : !llvm.array<2 x vector<4xf32>>
2577+
// CHECK: %[[EXTRACT_TWO:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<8xf32>>
2578+
// CHECK: %[[POISON_TWO:.*]] = llvm.mlir.poison : vector<8xf32>
2579+
// CHECK: %[[SHUFFLE_C:.*]] = llvm.shufflevector %[[EXTRACT_TWO]], %[[POISON_TWO]] [0, 2, 4, 6] : vector<8xf32>
2580+
// CHECK: %[[SHUFFLE_D:.*]] = llvm.shufflevector %[[EXTRACT_TWO]], %[[POISON_TWO]] [1, 3, 5, 7] : vector<8xf32>
2581+
// CHECK: %[[INSERT_C:.*]] = llvm.insertvalue %[[SHUFFLE_C]], %[[INSERT_A]][1] : !llvm.array<2 x vector<4xf32>>
2582+
// CHECK: %[[INSERT_D:.*]] = llvm.insertvalue %[[SHUFFLE_D]], %[[INSERT_B]][1] : !llvm.array<2 x vector<4xf32>>
2583+
%0, %1 = vector.deinterleave %a : vector<2x8xf32> -> vector<2x4xf32>
2584+
return %0, %1 : vector<2x4xf32>, vector<2x4xf32>
2585+
}
2586+
2587+
// Scalable 2-D case: each row is extracted from the source array, passed to
// the llvm.intr.vector.deinterleave2 intrinsic, and the two struct members are
// inserted at the same row index of the two result arrays.
// CHECK-LABEL: @vector_deinterleave_2d_scalable
// CHECK-SAME:  %[[SRC:.*]]: vector<2x[8]xf32>) -> (vector<2x[4]xf32>, vector<2x[4]xf32>)
func.func @vector_deinterleave_2d_scalable(%a: vector<2x[8]xf32>) -> (vector<2x[4]xf32>, vector<2x[4]xf32>) {
  // CHECK: %[[EXTRACT_A:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<2 x vector<[8]xf32>>
  // CHECK: %[[VECTOR_ONE:.*]] = "llvm.intr.vector.deinterleave2"(%[[EXTRACT_A]]) : (vector<[8]xf32>) -> !llvm.struct<(vector<[4]xf32>, vector<[4]xf32>)>
  // CHECK: %[[EXTRACT_B:.*]] = llvm.extractvalue %[[VECTOR_ONE]][0] : !llvm.struct<(vector<[4]xf32>, vector<[4]xf32>)>
  // CHECK: %[[EXTRACT_C:.*]] = llvm.extractvalue %[[VECTOR_ONE]][1] : !llvm.struct<(vector<[4]xf32>, vector<[4]xf32>)>
  // CHECK: %[[INSERT_A:.*]] = llvm.insertvalue %[[EXTRACT_B]], %{{.*}}[0] : !llvm.array<2 x vector<[4]xf32>>
  // CHECK: %[[INSERT_B:.*]] = llvm.insertvalue %[[EXTRACT_C]], %{{.*}}[0] : !llvm.array<2 x vector<[4]xf32>>
  // CHECK: %[[EXTRACT_D:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<[8]xf32>>
  // CHECK: %[[VECTOR_TWO:.*]] = "llvm.intr.vector.deinterleave2"(%[[EXTRACT_D]]) : (vector<[8]xf32>) -> !llvm.struct<(vector<[4]xf32>, vector<[4]xf32>)>
  // CHECK: %[[EXTRACT_E:.*]] = llvm.extractvalue %[[VECTOR_TWO]][0] : !llvm.struct<(vector<[4]xf32>, vector<[4]xf32>)>
  // CHECK: %[[EXTRACT_F:.*]] = llvm.extractvalue %[[VECTOR_TWO]][1] : !llvm.struct<(vector<[4]xf32>, vector<[4]xf32>)>
  // CHECK: %[[INSERT_C:.*]] = llvm.insertvalue %[[EXTRACT_E]], %[[INSERT_A]][1] : !llvm.array<2 x vector<[4]xf32>>
  // CHECK: %[[INSERT_D:.*]] = llvm.insertvalue %[[EXTRACT_F]], %[[INSERT_B]][1] : !llvm.array<2 x vector<[4]xf32>>
  %0, %1 = vector.deinterleave %a : vector<2x[8]xf32> -> vector<2x[4]xf32>
  return %0, %1 : vector<2x[4]xf32>, vector<2x[4]xf32>
}

0 commit comments

Comments
 (0)