diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h index 46bb3ddec0baf..453fa73429dd1 100644 --- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h +++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h @@ -387,7 +387,7 @@ void populateVectorTransposeNarrowTypeRewritePatterns( /// the ops to get converted properly. void populateVectorLinearizeTypeConversionsAndLegality( TypeConverter &typeConverter, RewritePatternSet &patterns, - ConversionTarget &target); + ConversionTarget &target, unsigned targetBitWidth); } // namespace vector } // namespace mlir diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp index c535204395557..7ca0353704981 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp @@ -19,10 +19,30 @@ using namespace mlir; +static bool isLessThanTargetBitWidth(Operation *op, unsigned targetBitWidth) { + auto resultTypes = op->getResultTypes(); + for (auto resType : resultTypes) { + VectorType vecType = cast(resType); + // Reject index since getElementTypeBitWidth will abort for Index types. + if (vecType.getElementType().isIndex()) + return false; + unsigned trailingVecDimBitWidth = + vecType.getShape().back() * vecType.getElementTypeBitWidth(); + if (trailingVecDimBitWidth >= targetBitWidth) + return false; + } + return true; +} + namespace { struct LinearizeConstant final : OpConversionPattern { using OpConversionPattern::OpConversionPattern; - + LinearizeConstant( + const TypeConverter &typeConverter, MLIRContext *context, + unsigned targetVectBitWidth = std::numeric_limits::max(), + PatternBenefit benefit = 1) + : OpConversionPattern(typeConverter, context, benefit), + targetVectorBitWidth(targetVectBitWidth) {} LogicalResult matchAndRewrite(arith::ConstantOp constOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { @@ -31,7 +51,9 @@ struct LinearizeConstant final : OpConversionPattern { getTypeConverter()->convertType(constOp.getType()); if (!resType) return rewriter.notifyMatchFailure(loc, "can't convert return type"); - + if (!isLessThanTargetBitWidth(constOp, targetVectorBitWidth)) + return rewriter.notifyMatchFailure( + loc, "Can't flatten since targetBitWidth <= OpSize"); auto dstElementsAttr = dyn_cast(constOp.getValue()); if (!dstElementsAttr) return rewriter.notifyMatchFailure(loc, "unsupported attr type"); @@ -41,15 +63,28 @@ struct LinearizeConstant final : OpConversionPattern { dstElementsAttr); return success(); } + +private: + unsigned targetVectorBitWidth; }; struct LinearizeVectorizable final : OpTraitConversionPattern { using OpTraitConversionPattern::OpTraitConversionPattern; +public: + LinearizeVectorizable( + const TypeConverter &typeConverter, MLIRContext *context, + unsigned targetVectBitWidth = std::numeric_limits::max(), + PatternBenefit benefit = 1) + : OpTraitConversionPattern(typeConverter, context, benefit), + targetVectorBitWidth(targetVectBitWidth) {} LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { + if (!isLessThanTargetBitWidth(op, targetVectorBitWidth)) + return rewriter.notifyMatchFailure( + op->getLoc(), "Can't flatten since targetBitWidth <= OpSize"); FailureOr newOp = convertOpResultTypes(op, operands, *getTypeConverter(), rewriter); if (failed(newOp)) @@ -58,12 +93,16 @@ struct LinearizeVectorizable final rewriter.replaceOp(op, (*newOp)->getResults()); return success(); } + +private: + unsigned targetVectorBitWidth; }; } // namespace void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality( TypeConverter &typeConverter, RewritePatternSet &patterns, - ConversionTarget &target) { + ConversionTarget &target, unsigned targetBitWidth) { + typeConverter.addConversion([](VectorType type) -> std::optional { // Ignore scalable vectors for now. if (type.getRank() <= 1 || type.isScalable()) @@ -83,15 +122,17 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality( typeConverter.addArgumentMaterialization(materializeCast); typeConverter.addSourceMaterialization(materializeCast); typeConverter.addTargetMaterialization(materializeCast); - target.markUnknownOpDynamicallyLegal( - [&](Operation *op) -> std::optional { - if (isa(op) || op->hasTrait()) - return typeConverter.isLegal(op); - + [=](Operation *op) -> std::optional { + if ((isa(op) || + op->hasTrait())) { + return (isLessThanTargetBitWidth(op, targetBitWidth) + ? typeConverter.isLegal(op) + : true); + } return std::nullopt; }); - patterns.add(typeConverter, - patterns.getContext()); + patterns.add( + typeConverter, patterns.getContext(), targetBitWidth); } diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir index 85e23103eaedb..2cbf9bec7a413 100644 --- a/mlir/test/Dialect/Vector/linearize.mlir +++ b/mlir/test/Dialect/Vector/linearize.mlir @@ -1,19 +1,92 @@ // RUN: mlir-opt %s -split-input-file -test-vector-linearize | FileCheck %s +// RUN: mlir-opt %s -split-input-file -test-vector-linearize=target-vector-bitwidth=128 | FileCheck %s --check-prefix=CHECK128 +// RUN: mlir-opt %s -split-input-file -test-vector-linearize=target-vector-bitwidth=0 | FileCheck %s --check-prefix=CHECK0 // CHECK-LABEL: test_linearize +// CHECK128-LABEL: test_linearize +// CHECK0-LABEL: test_linearize // CHECK-SAME: (%[[ORIG_ARG:.*]]: vector<2x2xf32>) +// CHECK128-SAME: (%[[ORIG_ARG:.*]]: vector<2x2xf32>) // CHECK: %[[ARG:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<2x2xf32> to vector<4xf32> +// CHECK128: %[[ARG:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<2x2xf32> to vector<4xf32> func.func @test_linearize(%arg0: vector<2x2xf32>) -> vector<2x2xf32> { // CHECK: %[[C1:.*]] = arith.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : vector<4xf32> +// CHECK128: %[[C1:.*]] = arith.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : vector<4xf32> +// CHECK0: %[[C1:.*]] = arith.constant dense<{{.*}}> : vector<2x2xf32> + %0 = arith.constant dense<[[1.0, 2.0], [3.0, 4.0]]> : vector<2x2xf32> // CHECK: %[[RES:.*]] = vector.shape_cast %[[C1]] : vector<4xf32> to vector<2x2xf32> - +// CHECK128: %[[RES:.*]] = vector.shape_cast %[[C1]] : vector<4xf32> to vector<2x2xf32> // Arith and math ops are handled in generic way, check some of them // CHECK: %{{.*}} = math.sin %[[ARG]] : vector<4xf32> +// CHECK128: %{{.*}} = math.sin %[[ARG]] : vector<4xf32> +// CHECK0: %{{.*}} = math.sin %{{.*}} : vector<2x2xf32> + %1 = math.sin %arg0 : vector<2x2xf32> +// CHECK: %{{.*}} = arith.addf %[[ARG]], %[[C1]] : vector<4xf32> +// CHECK128: %{{.*}} = arith.addf %[[ARG]], %[[C1]] : vector<4xf32> +// CHECK0: %{{.*}} = arith.addf %{{.*}} : vector<2x2xf32> + + %2 = arith.addf %arg0, %0 : vector<2x2xf32> + +// CHECK: return %[[RES]] : vector<2x2xf32> +// CHECK128: return %[[RES]] : vector<2x2xf32> + return %0 : vector<2x2xf32> +} + +// CHECK-LABEL: test_partial_linearize +// CHECK128-LABEL: test_partial_linearize +// CHECK0-LABEL: test_partial_linearize +// CHECK-SAME: (%[[ORIG_ARG:.*]]: vector<2x2xf32>, %[[ORIG_ARG2:.*]]: vector<4x4xf32>) +// CHECK128-SAME: (%[[ORIG_ARG:.*]]: vector<2x2xf32>, %[[ORIG_ARG2:.*]]: vector<4x4xf32>) +// CHECK0-SAME: (%[[ORIG_ARG:.*]]: vector<2x2xf32>, %[[ORIG_ARG2:.*]]: vector<4x4xf32>) +// CHECK: %[[ARG:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<2x2xf32> to vector<4xf32> +// CHECK128: %[[ARG:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<2x2xf32> to vector<4xf32> +// CHECK: %[[ARG2:.*]] = vector.shape_cast %[[ORIG_ARG2]] : vector<4x4xf32> to vector<16xf32> +func.func @test_partial_linearize(%arg0: vector<2x2xf32>, %arg1: vector<4x4xf32>) -> vector<2x2xf32> { +// CHECK: %[[C1:.*]] = arith.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : vector<4xf32> +// CHECK128: %[[C1:.*]] = arith.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : vector<4xf32> +// CHECK0: %[[C1:.*]] = arith.constant dense<{{.*}}> : vector<2x2xf32> + + %0 = arith.constant dense<[[1.0, 2.0], [3.0, 4.0]]> : vector<2x2xf32> +// CHECK: %[[RES:.*]] = vector.shape_cast %[[C1]] : vector<4xf32> to vector<2x2xf32> +// CHECK128: %[[RES:.*]] = vector.shape_cast %[[C1]] : vector<4xf32> to vector<2x2xf32> + + // CHECK: %[[C2:.*]] = arith.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 1.000000e+00, 2.000000e+00, 5.000000e+00, 6.000000e+00]> : vector<16xf32> + // CHECK128: %[[C2:.*]] = arith.constant dense<{{.*}}> : vector<4x4xf32> + // CHECK0: %[[C2:.*]] = arith.constant dense<{{.*}}> : vector<4x4xf32> + %5 = arith.constant dense<[[1.0, 2.0, 3.0, 4.0], [1.0, 2.0,3.0, 4.0], [1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 5.0, 6.0]]> : vector<4x4xf32> +// Arith and math ops are handled in generic way, check some of them +// CHECK: %[[SIN:.*]] = math.sin %[[ARG]] : vector<4xf32> +// CHECK128: %[[SIN:.*]] = math.sin %[[ARG]] : vector<4xf32> +// CHECK0: %[[SIN:.*]] = math.sin %[[ORIG_ARG]] : vector<2x2xf32> %1 = math.sin %arg0 : vector<2x2xf32> + + // CHECK: %[[SIN1:.*]] = math.sin %[[ARG2]] : vector<16xf32> +// CHECK128: %[[SIN1:.*]] = math.sin %[[ORIG_ARG2]] : vector<4x4xf32> +// CHECK0: %[[SIN1:.*]] = math.sin %[[ORIG_ARG2]] : vector<4x4xf32> + %6 = math.sin %arg1 : vector<4x4xf32> // CHECK: %{{.*}} = arith.addf %[[ARG]], %[[C1]] : vector<4xf32> +// CHECK128: %{{.*}} = arith.addf %[[ARG]], %[[C1]] : vector<4xf32> +// CHECK0: %{{.*}} = arith.addf %{{.*}} : vector<2x2xf32> + %2 = arith.addf %arg0, %0 : vector<2x2xf32> + // CHECK: %[[ADD2:.*]] = arith.addf %[[ARG2]], %[[C2]] : vector<16xf32> + // CHECK128: %[[ADD2:.*]] = arith.addf %[[ORIG_ARG2]], %[[C2]] : vector<4x4xf32> + // CHECK0: %[[ADD2:.*]] = arith.addf %[[ORIG_ARG2]], %[[C2]] : vector<4x4xf32> + %7 = arith.addf %arg1, %5 : vector<4x4xf32> // CHECK: return %[[RES]] : vector<2x2xf32> +// CHECK128: return %[[RES]] : vector<2x2xf32> return %0 : vector<2x2xf32> } + +// CHECK-LABEL: test_index_no_linearize +// CHECK128-LABEL: test_index_no_linearize +// CHECK0-LABEL: test_index_no_linearize +func.func @test_index_no_linearize(%arg0: vector<2x2xindex>, %arg1: vector<2x2xindex>) -> vector<2x2xindex> { + // CHECK: %[[ADD:.*]] = arith.addi {{.*}} : vector<2x2xindex> + // CHECK128: %[[ADD:.*]] = arith.addi {{.*}} : vector<2x2xindex> + // CHECK0: %[[ADD:.*]] = arith.addi {{.*}} : vector<2x2xindex> + %0 = arith.addi %arg0, %arg1 : vector<2x2xindex> + return %0 : vector<2x2xindex> +} diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp index 178a58e796b24..74d2dfa44f4fe 100644 --- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp +++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp @@ -842,6 +842,9 @@ struct TestVectorLinearize final : public PassWrapper> { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorLinearize) + TestVectorLinearize() = default; + TestVectorLinearize(const TestVectorLinearize &pass) : PassWrapper(pass) {} + StringRef getArgument() const override { return "test-vector-linearize"; } StringRef getDescription() const override { return "Linearizes ND vectors for N >= 2 into 1D vectors"; @@ -850,6 +853,11 @@ struct TestVectorLinearize final registry.insert(); } + Option targetVectorBitwidth{ + *this, "target-vector-bitwidth", + llvm::cl::desc( + "Minimum vector bitwidth to enable the flattening transformation"), + llvm::cl::init(std::numeric_limits::max())}; void runOnOperation() override { auto *context = &getContext(); @@ -857,8 +865,8 @@ struct TestVectorLinearize final RewritePatternSet patterns(context); ConversionTarget target(*context); - vector::populateVectorLinearizeTypeConversionsAndLegality(typeConverter, - patterns, target); + vector::populateVectorLinearizeTypeConversionsAndLegality( + typeConverter, patterns, target, targetVectorBitwidth); if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) return signalPassFailure();