diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h index f1100d5cf8b68..34a94e6ea7051 100644 --- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h +++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h @@ -407,13 +407,22 @@ void populateVectorTransposeNarrowTypeRewritePatterns( RewritePatternSet &patterns, PatternBenefit benefit = 1); /// Initialize `typeConverter` and `conversionTarget` for vector linearization. -/// This registers (1) which operations are legal and hence should not be -/// linearized, (2) what converted types are (rank-1 vectors) and how to +/// +/// Definition: here 'linearization' means converting a single operation with +/// 1+ vector operand/result of rank>1, into a new single operation whose +/// vector operands and results are all of rank<=1. +/// +/// This function registers (1) which operations are legal, and hence should not +/// be linearized, (2) what the converted types are (rank-1 vectors) and how to /// materialze the conversion (with shape_cast) /// /// Note: the set of legal operations can be extended by a user if for example -/// certain rank>1 vectors are considered valid, but adding additional +/// certain rank>1 vectors are considered valid, by adding additional /// dynamically legal ops to `conversionTarget`. +/// +/// Further note: the choice to use a dialect conversion design for +/// linearization is to make it easy to reuse generic structural type +/// conversions for linearizing scf/cf/func operations void populateForVectorLinearize(TypeConverter &typeConverter, ConversionTarget &conversionTarget); diff --git a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp index 40d2e254fb7dd..09326242eec2a 100644 --- a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp @@ -99,7 +99,7 @@ class ConvertForOpTypes // PR47938 tracks this issue, but it seems hard to fix. Instead, we need // to clone the op. // - // 2. We need to resue the original region instead of cloning it, otherwise + // 2. We need to reuse the original region instead of cloning it, otherwise // the dialect conversion framework thinks that we just inserted all the // cloned child ops. But what we want is to "take" the child regions and let // the dialect conversion framework continue recursively into ops inside diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp index 90e0479a515d5..060ce7d1d6643 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp @@ -626,45 +626,49 @@ struct LinearizeVectorCreateMask final } // namespace -/// Return true if the operation `op` does not support scalable vectors and -/// has at least 1 scalable vector result. These ops should all eventually -/// support scalable vectors, and this function should be removed. -static bool isNotLinearizableBecauseScalable(Operation *op) { - - bool unsupported = - isa(op); - if (!unsupported) - return false; - - // Check if any of the results is a scalable vector type. - auto types = op->getResultTypes(); - bool containsScalableResult = - std::any_of(types.begin(), types.end(), [](Type type) { - auto vecType = dyn_cast(type); - return vecType && vecType.isScalable(); - }); - - return containsScalableResult; -} - -static bool isNotLinearizable(Operation *op) { +/// This method defines the set of operations that are linearizable, and hence +/// that are considered illegal for the conversion target. +static bool isLinearizable(Operation *op) { // Only ops that are in the vector dialect, are ConstantLike, or // are Vectorizable might be linearized currently. StringLiteral vectorDialect = vector::VectorDialect::getDialectNamespace(); StringRef opDialect = op->getDialect()->getNamespace(); - bool unsupported = (opDialect != vectorDialect) && - !op->hasTrait() && - !op->hasTrait(); - if (unsupported) - return true; - - // Some ops currently don't support scalable vectors. - if (isNotLinearizableBecauseScalable(op)) - return true; + bool supported = (opDialect == vectorDialect) || + op->hasTrait() || + op->hasTrait(); + if (!supported) + return false; - return false; + return TypeSwitch(op) + // As type legalization is done with vector.shape_cast, shape_cast + // itself cannot be linearized (will create new shape_casts to linearize + // ad infinitum). + .Case([&](auto) { return false; }) + // The operations + // - vector.extract_strided_slice + // - vector.extract + // - vector.insert_strided_slice + // - vector.insert + // are linearized to a rank-1 vector.shuffle by the current patterns. + // vector.shuffle only supports fixed size vectors, so it is impossible to + // use this approach to linearize these ops if they operate on scalable + // vectors. + .Case( + [&](vector::ExtractStridedSliceOp extractOp) { + return !extractOp.getType().isScalable(); + }) + .Case( + [&](vector::InsertStridedSliceOp insertOp) { + return !insertOp.getType().isScalable(); + }) + .Case([&](vector::InsertOp insertOp) { + return !insertOp.getType().isScalable(); + }) + .Case([&](vector::ExtractOp extractOp) { + return !extractOp.getSourceVectorType().isScalable(); + }) + .Default([&](auto) { return true; }); } void mlir::vector::populateForVectorLinearize(TypeConverter &typeConverter, @@ -698,7 +702,7 @@ void mlir::vector::populateForVectorLinearize(TypeConverter &typeConverter, target.markUnknownOpDynamicallyLegal( [=](Operation *op) -> std::optional { - if (isNotLinearizable(op)) + if (!isLinearizable(op)) return true; // This will return true if, for all operand and result types `t`, // convertType(t) = t. This is true if there are no rank>=2 vectors. diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir index 40445d3781228..9cbf319ffddb2 100644 --- a/mlir/test/Dialect/Vector/linearize.mlir +++ b/mlir/test/Dialect/Vector/linearize.mlir @@ -392,6 +392,28 @@ func.func @test_vector_bitcast(%arg0: vector<[4]x2xf32>) -> vector<[4]x4xf16> { // ----- +// CHECK-LABEL: test_linearize_across_for +func.func @test_linearize_across_for(%arg0 : vector<4xi8>) -> vector<4xi8> { + %0 = vector.shape_cast %arg0 : vector<4xi8> to vector<2x2xi8> + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + + // CHECK: scf.for {{.*}} -> (vector<4xi8>) + %1 = scf.for %i = %c0 to %c4 step %c1 iter_args(%arg1 = %0) -> (vector<2x2xi8>) { + + // CHECK: arith.addi {{.*}} : vector<4xi8> + %2 = arith.addi %arg1, %0 : vector<2x2xi8> + + // CHECK: scf.yield {{.*}} : vector<4xi8> + scf.yield %2 : vector<2x2xi8> + } + %3 = vector.shape_cast %1 : vector<2x2xi8> to vector<4xi8> + return %3 : vector<4xi8> +} + +// ----- + // CHECK-LABEL: linearize_vector_splat // CHECK-SAME: (%[[ARG:.*]]: i32) -> vector<4x2xi32> func.func @linearize_vector_splat(%arg0: i32) -> vector<4x2xi32> { @@ -414,6 +436,7 @@ func.func @linearize_scalable_vector_splat(%arg0: i32) -> vector<4x[2]xi32> { // CHECK: return %[[CAST]] : vector<4x[2]xi32> %0 = vector.splat %arg0 : vector<4x[2]xi32> return %0 : vector<4x[2]xi32> + } // ----- diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp index 54defd949c264..ccba2e2806862 100644 --- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp +++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp @@ -17,6 +17,7 @@ #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h" #include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/SCF/Transforms/Patterns.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" @@ -836,9 +837,6 @@ struct TestVectorEmulateMaskedLoadStore final } }; -// TODO: move this code into the user project. -namespace vendor { - /// Get the set of operand/result types to check for sufficiently /// small inner-most dimension size. static SmallVector> @@ -960,8 +958,6 @@ struct TestVectorBitWidthLinearize final } }; -} // namespace vendor - struct TestVectorLinearize final : public PassWrapper> { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorLinearize) @@ -987,6 +983,8 @@ struct TestVectorLinearize final vector::populateVectorLinearizeBasePatterns(converter, target, patterns); vector::populateVectorLinearizeShuffleLikeOpsPatterns(converter, target, patterns); + mlir::scf::populateSCFStructuralTypeConversionsAndLegality( + converter, patterns, target); if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) @@ -1067,7 +1065,7 @@ void registerTestVectorLowerings() { PassRegistration(); - PassRegistration(); + PassRegistration(); PassRegistration(); }