diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td index f6371f39c3944..bc3c16d40520e 100644 --- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td +++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td @@ -306,6 +306,20 @@ def ApplyLowerInterleavePatternsOp : Op]> { + let description = [{ + Indicates that 1D vector interleave operations should be rewritten as + vector shuffle operations. + + This is motivated by some current codegen backends not handling vector + interleave operations. + }]; + + let assemblyFormat = "attr-dict"; +} + def ApplyRewriteNarrowTypePatternsOp : Op]> { diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h index 350d2777cadf5..8fd9904fabc0e 100644 --- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h +++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h @@ -273,6 +273,9 @@ void populateVectorInterleaveLoweringPatterns(RewritePatternSet &patterns, int64_t targetRank = 1, PatternBenefit benefit = 1); +void populateVectorInterleaveToShufflePatterns(RewritePatternSet &patterns, + PatternBenefit benefit = 1); + } // namespace vector } // namespace mlir #endif // MLIR_DIALECT_VECTOR_TRANSFORMS_LOWERINGPATTERNS_H diff --git a/mlir/lib/Conversion/VectorToSPIRV/CMakeLists.txt b/mlir/lib/Conversion/VectorToSPIRV/CMakeLists.txt index bb9f793d7fe0f..113983146f5be 100644 --- a/mlir/lib/Conversion/VectorToSPIRV/CMakeLists.txt +++ b/mlir/lib/Conversion/VectorToSPIRV/CMakeLists.txt @@ -14,5 +14,6 @@ add_mlir_conversion_library(MLIRVectorToSPIRV MLIRSPIRVDialect MLIRSPIRVConversion MLIRVectorDialect + MLIRVectorTransforms MLIRTransforms ) diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp index 868a3521e7a0f..c2dd37f481466 100644 --- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp +++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp @@ -18,6 +18,7 @@ #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" @@ -828,6 +829,9 @@ void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter, // than the generic one that extracts all elements. patterns.add(typeConverter, patterns.getContext(), PatternBenefit(2)); + + // Need this until vector.interleave is handled. + vector::populateVectorInterleaveToShufflePatterns(patterns); } void mlir::populateVectorReductionToSPIRVDotProductPatterns( diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp index 885644864c0f7..61fd6bd972e3a 100644 --- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp +++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp @@ -164,6 +164,11 @@ void transform::ApplyLowerInterleavePatternsOp::populatePatterns( vector::populateVectorInterleaveLoweringPatterns(patterns); } +void transform::ApplyInterleaveToShufflePatternsOp::populatePatterns( + RewritePatternSet &patterns) { + vector::populateVectorInterleaveToShufflePatterns(patterns); +} + void transform::ApplyRewriteNarrowTypePatternsOp::populatePatterns( RewritePatternSet &patterns) { populateVectorNarrowTypeRewritePatterns(patterns); diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp index 3a456076f8fba..5326760c9b4eb 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp @@ -16,6 +16,7 @@ #include "mlir/Dialect/Vector/Utils/VectorUtils.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/PatternMatch.h" +#include "mlir/Support/LogicalResult.h" #define DEBUG_TYPE "vector-interleave-lowering" @@ -77,9 +78,49 @@ class UnrollInterleaveOp : public OpRewritePattern { int64_t targetRank = 1; }; +/// Rewrite vector.interleave op into an equivalent vector.shuffle op, when +/// applicable: `sourceType` must be 1D and non-scalable. +/// +/// Example: +/// +/// ```mlir +/// vector.interleave %a, %b : vector<7xi16> +/// ``` +/// +/// Is rewritten into: +/// +/// ```mlir +/// vector.shuffle %arg0, %arg1 [0, 7, 1, 8, 2, 9, 3, 10, 4, 11, 5, 12, 6, 13] +/// : vector<7xi16>, vector<7xi16> +/// ``` +class InterleaveToShuffle : public OpRewritePattern { +public: + InterleaveToShuffle(MLIRContext *context, PatternBenefit benefit = 1) + : OpRewritePattern(context, benefit) {}; + + LogicalResult matchAndRewrite(vector::InterleaveOp op, + PatternRewriter &rewriter) const override { + VectorType sourceType = op.getSourceVectorType(); + if (sourceType.getRank() != 1 || sourceType.isScalable()) { + return failure(); + } + int64_t n = sourceType.getNumElements(); + auto seq = llvm::seq(2 * n); + auto zip = llvm::to_vector(llvm::map_range( + seq, [n](int64_t i) { return (i % 2 ? n : 0) + i / 2; })); + rewriter.replaceOpWithNewOp(op, op.getLhs(), op.getRhs(), zip); + return success(); + } +}; + } // namespace void mlir::vector::populateVectorInterleaveLoweringPatterns( RewritePatternSet &patterns, int64_t targetRank, PatternBenefit benefit) { patterns.add(targetRank, patterns.getContext(), benefit); } + +void mlir::vector::populateVectorInterleaveToShufflePatterns( + RewritePatternSet &patterns, PatternBenefit benefit) { + patterns.add(patterns.getContext(), benefit); +} diff --git a/mlir/test/Dialect/Vector/vector-interleave-to-shuffle.mlir b/mlir/test/Dialect/Vector/vector-interleave-to-shuffle.mlir new file mode 100644 index 0000000000000..ed3b3396bf3ea --- /dev/null +++ b/mlir/test/Dialect/Vector/vector-interleave-to-shuffle.mlir @@ -0,0 +1,21 @@ +// RUN: mlir-opt %s --transform-interpreter | FileCheck %s + +// CHECK-LABEL: @vector_interleave_to_shuffle +func.func @vector_interleave_to_shuffle(%a: vector<7xi16>, %b: vector<7xi16>) -> vector<14xi16> +{ + %0 = vector.interleave %a, %b : vector<7xi16> + return %0 : vector<14xi16> +} +// CHECK: vector.shuffle %arg0, %arg1 [0, 7, 1, 8, 2, 9, 3, 10, 4, 11, 5, 12, 6, 13] : vector<7xi16>, vector<7xi16> + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) { + %f = transform.structured.match ops{["func.func"]} in %module_op + : (!transform.any_op) -> !transform.any_op + + transform.apply_patterns to %f { + transform.apply_patterns.vector.interleave_to_shuffle + } : !transform.any_op + transform.yield + } +} diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel index 6304b7b548d81..debd8daf55497 100644 --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -5877,6 +5877,7 @@ cc_library( ":Support", ":TransformUtils", ":VectorDialect", + ":VectorTransforms", "//llvm:Support", ], )