diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp
index 77c97b2f1497c..f7e01c7b12e4f 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp
@@ -79,6 +79,73 @@ class UnrollInterleaveOp final : public OpRewritePattern<vector::InterleaveOp> {
   int64_t targetRank = 1;
 };
 
+/// A one-shot unrolling of vector.deinterleave to the `targetRank`.
+///
+/// Example:
+///
+/// ```mlir
+/// %0, %1 = vector.deinterleave %a : vector<1x2x3x8xi64> -> vector<1x2x3x4xi64>
+/// ```
+/// Would be unrolled to:
+/// ```mlir
+/// %result = arith.constant dense<0> : vector<1x2x3x4xi64>
+/// %0 = vector.extract %a[0, 0, 0]                   ─┐
+///        : vector<8xi64> from vector<1x2x3x8xi64>    |
+/// %1, %2 = vector.deinterleave %0                    |
+///        : vector<8xi64> -> vector<4xi64>            | -- Initial deinterleave
+/// %3 = vector.insert %1, %result [0, 0, 0]           |    operation unrolled.
+///        : vector<4xi64> into vector<1x2x3x4xi64>    |
+/// %4 = vector.insert %2, %result [0, 0, 0]           |
+///        : vector<4xi64> into vector<1x2x3x4xi64>   ─┘
+/// %5 = vector.extract %a[0, 0, 1]                   ─┐
+///        : vector<8xi64> from vector<1x2x3x8xi64>    |
+/// %6, %7 = vector.deinterleave %5                    |
+///        : vector<8xi64> -> vector<4xi64>            | -- Recursive pattern for
+/// %8 = vector.insert %6, %3 [0, 0, 1]                |    subsequent unrolled
+///        : vector<4xi64> into vector<1x2x3x4xi64>    |    deinterleave
+/// %9 = vector.insert %7, %4 [0, 0, 1]                |    operations. Repeated
+///        : vector<4xi64> into vector<1x2x3x4xi64>   ─┘    5x in this case.
+/// ```
+///
+/// Note: If any leading dimension before the `targetRank` is scalable, the
+/// unrolling will stop before the scalable dimension.
+class UnrollDeinterleaveOp final
+    : public OpRewritePattern<vector::DeinterleaveOp> {
+public:
+  UnrollDeinterleaveOp(int64_t targetRank, MLIRContext *context,
+                       PatternBenefit benefit = 1)
+      : OpRewritePattern(context, benefit), targetRank(targetRank) {};
+
+  LogicalResult matchAndRewrite(vector::DeinterleaveOp op,
+                                PatternRewriter &rewriter) const override {
+    VectorType resultType = op.getResultVectorType();
+    auto unrollIterator = vector::createUnrollIterator(resultType, targetRank);
+    if (!unrollIterator)
+      return failure();
+
+    auto loc = op.getLoc();
+    Value emptyResult = rewriter.create<arith::ConstantOp>(
+        loc, resultType, rewriter.getZeroAttr(resultType));
+    Value evenResult = emptyResult;
+    Value oddResult = emptyResult;
+
+    for (auto position : *unrollIterator) {
+      auto extractSrc =
+          rewriter.create<vector::ExtractOp>(loc, op.getSource(), position);
+      auto deinterleave =
+          rewriter.create<vector::DeinterleaveOp>(loc, extractSrc);
+      evenResult = rewriter.create<vector::InsertOp>(
+          loc, deinterleave.getRes1(), evenResult, position);
+      oddResult = rewriter.create<vector::InsertOp>(loc, deinterleave.getRes2(),
+                                                    oddResult, position);
+    }
+    rewriter.replaceOp(op, ValueRange{evenResult, oddResult});
+    return success();
+  }
+
+private:
+  int64_t targetRank = 1;
+};
+
 /// Rewrite vector.interleave op into an equivalent vector.shuffle op, when
 /// applicable: `sourceType` must be 1D and non-scalable.
 ///
@@ -116,7 +183,8 @@ struct InterleaveToShuffle final : OpRewritePattern<vector::InterleaveOp> {
 
 void mlir::vector::populateVectorInterleaveLoweringPatterns(
     RewritePatternSet &patterns, int64_t targetRank, PatternBenefit benefit) {
-  patterns.add<UnrollInterleaveOp>(targetRank, patterns.getContext(), benefit);
+  patterns.add<UnrollInterleaveOp, UnrollDeinterleaveOp>(
+      targetRank, patterns.getContext(), benefit);
 }
 
 void mlir::vector::populateVectorInterleaveToShufflePatterns(
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 12121ea0dd70e..54dcf07053906 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -2565,6 +2565,22 @@ func.func @vector_deinterleave_1d_scalable(%a: vector<[4]xi32>) -> (vector<[2]xi
   return %0, %1 : vector<[2]xi32>, vector<[2]xi32>
 }
 
+// CHECK-LABEL: @vector_deinterleave_2d
+// CHECK-SAME: %[[SRC:.*]]: vector<2x8xf32>) -> (vector<2x4xf32>, vector<2x4xf32>)
+func.func @vector_deinterleave_2d(%a: vector<2x8xf32>) -> (vector<2x4xf32>, vector<2x4xf32>) {
+  // CHECK: llvm.shufflevector
+  // CHECK-NOT: vector.deinterleave %{{.*}} : vector<2x8xf32>
+  %0, %1 = vector.deinterleave %a : vector<2x8xf32> -> vector<2x4xf32>
+  return %0, %1 : vector<2x4xf32>, vector<2x4xf32>
+}
+
+func.func @vector_deinterleave_2d_scalable(%a: vector<2x[8]xf32>) -> (vector<2x[4]xf32>, vector<2x[4]xf32>) {
+  // CHECK: llvm.intr.vector.deinterleave2
+  // CHECK-NOT: vector.deinterleave %{{.*}} : vector<2x[8]xf32>
+  %0, %1 = vector.deinterleave %a : vector<2x[8]xf32> -> vector<2x[4]xf32>
+  return %0, %1 : vector<2x[4]xf32>, vector<2x[4]xf32>
+}
+
 // -----
 
 // CHECK-LABEL: func.func @vector_bitcast_2d
diff --git a/mlir/test/Dialect/Vector/vector-deinterleave-lowering-transforms.mlir b/mlir/test/Dialect/Vector/vector-deinterleave-lowering-transforms.mlir
new file mode 100644
index 0000000000000..53f4a8970c794
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-deinterleave-lowering-transforms.mlir
@@ -0,0 +1,68 @@
+// RUN: mlir-opt %s --transform-interpreter | FileCheck %s
+
+// CHECK-LABEL: @vector_deinterleave_2d
+// CHECK-SAME: %[[SRC:.*]]: vector<2x8xi32>) -> (vector<2x4xi32>, vector<2x4xi32>)
+func.func @vector_deinterleave_2d(%a: vector<2x8xi32>) -> (vector<2x4xi32>, vector<2x4xi32>) {
+  // CHECK: %[[CST:.*]] = arith.constant dense<0>
+  // CHECK: %[[SRC_0:.*]] = vector.extract %[[SRC]][0]
+  // CHECK: %[[UNZIP_0:.*]], %[[UNZIP_1:.*]] = vector.deinterleave %[[SRC_0]]
+  // CHECK: %[[RES_0:.*]] = vector.insert %[[UNZIP_0]], %[[CST]] [0]
+  // CHECK: %[[RES_1:.*]] = vector.insert %[[UNZIP_1]], %[[CST]] [0]
+  // CHECK: %[[SRC_1:.*]] = vector.extract %[[SRC]][1]
+  // CHECK: %[[UNZIP_2:.*]], %[[UNZIP_3:.*]] = vector.deinterleave %[[SRC_1]]
+  // CHECK: %[[RES_2:.*]] = vector.insert %[[UNZIP_2]], %[[RES_0]] [1]
+  // CHECK: %[[RES_3:.*]] = vector.insert %[[UNZIP_3]], %[[RES_1]] [1]
+  // CHECK-NEXT: return %[[RES_2]], %[[RES_3]] : vector<2x4xi32>, vector<2x4xi32>
+  %0, %1 = vector.deinterleave %a : vector<2x8xi32> -> vector<2x4xi32>
+  return %0, %1 : vector<2x4xi32>, vector<2x4xi32>
+}
+
+// CHECK-LABEL: @vector_deinterleave_2d_scalable
+// CHECK-SAME: %[[SRC:.*]]: vector<2x[8]xi32>) -> (vector<2x[4]xi32>, vector<2x[4]xi32>)
+func.func @vector_deinterleave_2d_scalable(%a: vector<2x[8]xi32>) -> (vector<2x[4]xi32>, vector<2x[4]xi32>) {
+  // CHECK: %[[CST:.*]] = arith.constant dense<0>
+  // CHECK: %[[SRC_0:.*]] = vector.extract %[[SRC]][0]
+  // CHECK: %[[UNZIP_0:.*]], %[[UNZIP_1:.*]] = vector.deinterleave %[[SRC_0]]
+  // CHECK: %[[RES_0:.*]] = vector.insert %[[UNZIP_0]], %[[CST]] [0]
+  // CHECK: %[[RES_1:.*]] = vector.insert %[[UNZIP_1]], %[[CST]] [0]
+  // CHECK: %[[SRC_1:.*]] = vector.extract %[[SRC]][1]
+  // CHECK: %[[UNZIP_2:.*]], %[[UNZIP_3:.*]] = vector.deinterleave %[[SRC_1]]
+  // CHECK: %[[RES_2:.*]] = vector.insert %[[UNZIP_2]], %[[RES_0]] [1]
+  // CHECK: %[[RES_3:.*]] = vector.insert %[[UNZIP_3]], %[[RES_1]] [1]
+  // CHECK-NEXT: return %[[RES_2]], %[[RES_3]] : vector<2x[4]xi32>, vector<2x[4]xi32>
+  %0, %1 = vector.deinterleave %a : vector<2x[8]xi32> -> vector<2x[4]xi32>
+  return %0, %1 : vector<2x[4]xi32>, vector<2x[4]xi32>
+}
+
+// CHECK-LABEL: @vector_deinterleave_4d
+// CHECK-SAME: %[[SRC:.*]]: vector<1x2x3x8xi64>) -> (vector<1x2x3x4xi64>, vector<1x2x3x4xi64>)
+func.func @vector_deinterleave_4d(%a: vector<1x2x3x8xi64>) -> (vector<1x2x3x4xi64>, vector<1x2x3x4xi64>) {
+  // CHECK: %[[SRC_0:.*]] = vector.extract %[[SRC]][0, 0, 0] : vector<8xi64> from vector<1x2x3x8xi64>
+  // CHECK: %[[UNZIP_0:.*]], %[[UNZIP_1:.*]] = vector.deinterleave %[[SRC_0]] : vector<8xi64> -> vector<4xi64>
+  // CHECK: %[[RES_0:.*]] = vector.insert %[[UNZIP_0]], %{{.*}} [0, 0, 0] : vector<4xi64> into vector<1x2x3x4xi64>
+  // CHECK: %[[RES_1:.*]] = vector.insert %[[UNZIP_1]], %{{.*}} [0, 0, 0] : vector<4xi64> into vector<1x2x3x4xi64>
+  // CHECK-COUNT-5: vector.deinterleave %{{.*}} : vector<8xi64> -> vector<4xi64>
+  %0, %1 = vector.deinterleave %a : vector<1x2x3x8xi64> -> vector<1x2x3x4xi64>
+  return %0, %1 : vector<1x2x3x4xi64>, vector<1x2x3x4xi64>
+}
+
+// CHECK-LABEL: @vector_deinterleave_nd_with_scalable_dim
+func.func @vector_deinterleave_nd_with_scalable_dim(
+  %a: vector<1x3x[2]x2x3x8xf16>) -> (vector<1x3x[2]x2x3x4xf16>, vector<1x3x[2]x2x3x4xf16>) {
+  // The scalable dim blocks unrolling so only the first two dims are unrolled.
+  // CHECK-COUNT-3: vector.deinterleave %{{.*}} : vector<[2]x2x3x8xf16>
+  %0, %1 = vector.deinterleave %a: vector<1x3x[2]x2x3x8xf16> -> vector<1x3x[2]x2x3x4xf16>
+  return %0, %1 : vector<1x3x[2]x2x3x4xf16>, vector<1x3x[2]x2x3x4xf16>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
+    %f = transform.structured.match ops{["func.func"]} in %module_op
+      : (!transform.any_op) -> !transform.any_op
+
+    transform.apply_patterns to %f {
+      transform.apply_patterns.vector.lower_interleave
+    } : !transform.any_op
+    transform.yield
+  }
+}
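
For callers that drive the lowering from C++ rather than through the transform dialect, the new `UnrollDeinterleaveOp` pattern is picked up by the existing entry point `populateVectorInterleaveLoweringPatterns`, the same pattern set that `transform.apply_patterns.vector.lower_interleave` applies in the test above. Below is a minimal, hypothetical sketch of such a caller; the helper name `lowerInterleaveOps` is invented for illustration and is not part of this patch.

```cpp
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

// Hypothetical helper (not part of this patch): greedily applies the
// interleave/deinterleave unrolling patterns to `op`, rewriting rank-n
// vector.interleave / vector.deinterleave ops into rank-1 ops that later
// conversions (e.g. VectorToLLVM) know how to handle.
static LogicalResult lowerInterleaveOps(Operation *op) {
  RewritePatternSet patterns(op->getContext());
  // With this patch, this registers both UnrollInterleaveOp and
  // UnrollDeinterleaveOp.
  vector::populateVectorInterleaveLoweringPatterns(patterns,
                                                   /*targetRank=*/1);
  return applyPatternsAndFoldGreedily(op, std::move(patterns));
}
```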