diff --git a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h index 6ccbc40bdd603..2e9c297f20182 100644 --- a/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h @@ -150,7 +150,7 @@ std::unique_ptr createLowerForeachToSCFPass(); //===----------------------------------------------------------------------===// /// Type converter for iter_space and iterator. -struct SparseIterationTypeConverter : public OneToNTypeConverter { +struct SparseIterationTypeConverter : public TypeConverter { SparseIterationTypeConverter(); }; diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h index 5ff36160dd616..5e5957170e646 100644 --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -173,7 +173,9 @@ class TypeConverter { /// conversion has finished. /// /// Note: Target materializations may optionally accept an additional Type - /// parameter, which is the original type of the SSA value. + /// parameter, which is the original type of the SSA value. Furthermore, `T` + /// can be a TypeRange; in that case, the function must return a + /// SmallVector. /// This method registers a materialization that will be called when /// converting (potentially multiple) block arguments that were the result of @@ -210,6 +212,9 @@ class TypeConverter { /// will be invoked with: outputType = "t3", inputs = "v2", // originalType = "t1". Note that the original type "t1" cannot be recovered /// from just "t3" and "v2"; that's why the originalType parameter exists. + /// + /// Note: During a 1:N conversion, the result types can be a TypeRange. In + /// that case the materialization produces a SmallVector. template >::template arg_t<1>> void addTargetMaterialization(FnT &&callback) { @@ -316,6 +321,11 @@ class TypeConverter { Value materializeTargetConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs, Type originalType = {}) const; + SmallVector materializeTargetConversion(OpBuilder &builder, + Location loc, + TypeRange resultType, + ValueRange inputs, + Type originalType = {}) const; /// Convert an attribute present `attr` from within the type `type` using /// the registered conversion functions. If no applicable conversion has been @@ -340,9 +350,9 @@ class TypeConverter { /// The signature of the callback used to materialize a target conversion. /// - /// Arguments: builder, result type, inputs, location, original type - using TargetMaterializationCallbackFn = - std::function; + /// Arguments: builder, result types, inputs, location, original type + using TargetMaterializationCallbackFn = std::function( + OpBuilder &, TypeRange, ValueRange, Location, Type)>; /// The signature of the callback used to convert a type attribute. using TypeAttributeConversionCallbackFn = @@ -409,22 +419,46 @@ class TypeConverter { /// callback. /// /// With callback of form: - /// `Value(OpBuilder &, T, ValueRange, Location, Type)` + /// - Value(OpBuilder &, T, ValueRange, Location, Type) + /// - SmallVector(OpBuilder &, TypeRange, ValueRange, Location, Type) template std::enable_if_t< std::is_invocable_v, TargetMaterializationCallbackFn> wrapTargetMaterialization(FnT &&callback) const { return [callback = std::forward(callback)]( - OpBuilder &builder, Type resultType, ValueRange inputs, - Location loc, Type originalType) -> Value { - if (T derivedType = dyn_cast(resultType)) - return callback(builder, derivedType, inputs, loc, originalType); - return Value(); + OpBuilder &builder, TypeRange resultTypes, ValueRange inputs, + Location loc, Type originalType) -> SmallVector { + SmallVector result; + if constexpr (std::is_same::value) { + // This is a 1:N target materialization. Return the produces values + // directly. + result = callback(builder, resultTypes, inputs, loc, originalType); + } else if constexpr (std::is_assignable::value) { + // This is a 1:1 target materialization. Invoke the callback only if a + // single SSA value is requested. + if (resultTypes.size() == 1) { + // Invoke the callback only if the type class of the callback matches + // the requested result type. + if (T derivedType = dyn_cast(resultTypes.front())) { + // 1:1 materializations produce single values, but we store 1:N + // target materialization functions in the type converter. Wrap the + // result value in a SmallVector. + Value val = + callback(builder, derivedType, inputs, loc, originalType); + if (val) + result.push_back(val); + } + } + } else { + static_assert(sizeof(T) == 0, "T must be a Type or a TypeRange"); + } + return result; }; } /// With callback of form: - /// `Value(OpBuilder &, T, ValueRange, Location)` + /// - Value(OpBuilder &, T, ValueRange, Location) + /// - SmallVector(OpBuilder &, TypeRange, ValueRange, Location) template std::enable_if_t< std::is_invocable_v, @@ -432,9 +466,9 @@ class TypeConverter { wrapTargetMaterialization(FnT &&callback) const { return wrapTargetMaterialization( [callback = std::forward(callback)]( - OpBuilder &builder, T resultType, ValueRange inputs, Location loc, - Type originalType) -> Value { - return callback(builder, resultType, inputs, loc); + OpBuilder &builder, T resultTypes, ValueRange inputs, Location loc, + Type originalType) { + return callback(builder, resultTypes, inputs, loc); }); } diff --git a/mlir/include/mlir/Transforms/OneToNTypeConversion.h b/mlir/include/mlir/Transforms/OneToNTypeConversion.h index c59a3a52f028f..7b4dd65cbff7b 100644 --- a/mlir/include/mlir/Transforms/OneToNTypeConversion.h +++ b/mlir/include/mlir/Transforms/OneToNTypeConversion.h @@ -33,49 +33,6 @@ namespace mlir { -/// Extends `TypeConverter` with 1:N target materializations. Such -/// materializations have to provide the "reverse" of 1:N type conversions, -/// i.e., they need to materialize N values with target types into one value -/// with a source type (which isn't possible in the base class currently). -class OneToNTypeConverter : public TypeConverter { -public: - /// Callback that expresses user-provided materialization logic from the given - /// value to N values of the given types. This is useful for expressing target - /// materializations for 1:N type conversions, which materialize one value in - /// a source type as N values in target types. - using OneToNMaterializationCallbackFn = - std::function>(OpBuilder &, TypeRange, - Value, Location)>; - - /// Creates the mapping of the given range of original types to target types - /// of the conversion and stores that mapping in the given (signature) - /// conversion. This function simply calls - /// `TypeConverter::convertSignatureArgs` and exists here with a different - /// name to reflect the broader semantic. - LogicalResult computeTypeMapping(TypeRange types, - SignatureConversion &result) const { - return convertSignatureArgs(types, result); - } - - /// Applies one of the user-provided 1:N target materializations. If several - /// exists, they are tried out in the reverse order in which they have been - /// added until the first one succeeds. If none succeeds, the functions - /// returns `std::nullopt`. - std::optional> - materializeTargetConversion(OpBuilder &builder, Location loc, - TypeRange resultTypes, Value input) const; - - /// Adds a 1:N target materialization to the converter. Such materializations - /// build IR that converts N values with target types into 1 value of the - /// source type. - void addTargetMaterialization(OneToNMaterializationCallbackFn &&callback) { - oneToNTargetMaterializations.emplace_back(std::move(callback)); - } - -private: - SmallVector oneToNTargetMaterializations; -}; - /// Stores a 1:N mapping of types and provides several useful accessors. This /// class extends `SignatureConversion`, which already supports 1:N type /// mappings but lacks some accessors into the mapping as well as access to the @@ -295,7 +252,7 @@ class OneToNOpConversionPattern : public OneToNConversionPattern { /// not fail if some ops or types remain unconverted (i.e., the conversion is /// only "partial"). LogicalResult -applyPartialOneToNConversion(Operation *op, OneToNTypeConverter &typeConverter, +applyPartialOneToNConversion(Operation *op, TypeConverter &typeConverter, const FrozenRewritePatternSet &patterns); /// Add a pattern to the given pattern list to convert the signature of a diff --git a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp index 4968c4fc463d0..e908a536e6fb2 100644 --- a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp +++ b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp @@ -921,7 +921,7 @@ struct VectorLegalizationPass : public arm_sme::impl::VectorLegalizationBase { void runOnOperation() override { auto *context = &getContext(); - OneToNTypeConverter converter; + TypeConverter converter; RewritePatternSet patterns(context); converter.addConversion([](Type type) { return type; }); converter.addConversion( diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index 3cfcaa965f354..3d0c81867e0cc 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -2831,11 +2831,29 @@ Value TypeConverter::materializeTargetConversion(OpBuilder &builder, Location loc, Type resultType, ValueRange inputs, Type originalType) const { + SmallVector result = materializeTargetConversion( + builder, loc, TypeRange(resultType), inputs, originalType); + if (result.empty()) + return nullptr; + assert(result.size() == 1 && "expected single result"); + return result.front(); +} + +SmallVector TypeConverter::materializeTargetConversion( + OpBuilder &builder, Location loc, TypeRange resultTypes, ValueRange inputs, + Type originalType) const { for (const TargetMaterializationCallbackFn &fn : - llvm::reverse(targetMaterializations)) - if (Value result = fn(builder, resultType, inputs, loc, originalType)) - return result; - return nullptr; + llvm::reverse(targetMaterializations)) { + SmallVector result = + fn(builder, resultTypes, inputs, loc, originalType); + if (result.empty()) + continue; + assert(TypeRange(result) == resultTypes && + "callback produced incorrect number of values or values with " + "incorrect types"); + return result; + } + return {}; } std::optional diff --git a/mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp b/mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp index 19e29d48623e0..c208716891ef1 100644 --- a/mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp +++ b/mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp @@ -17,20 +17,6 @@ using namespace llvm; using namespace mlir; -std::optional> -OneToNTypeConverter::materializeTargetConversion(OpBuilder &builder, - Location loc, - TypeRange resultTypes, - Value input) const { - for (const OneToNMaterializationCallbackFn &fn : - llvm::reverse(oneToNTargetMaterializations)) { - if (std::optional> result = - fn(builder, resultTypes, input, loc)) - return *result; - } - return std::nullopt; -} - TypeRange OneToNTypeMapping::getConvertedTypes(unsigned originalTypeNo) const { TypeRange convertedTypes = getConvertedTypes(); if (auto mapping = getInputMapping(originalTypeNo)) @@ -268,20 +254,20 @@ Block *OneToNPatternRewriter::applySignatureConversion( LogicalResult OneToNConversionPattern::matchAndRewrite(Operation *op, PatternRewriter &rewriter) const { - auto *typeConverter = getTypeConverter(); + auto *typeConverter = getTypeConverter(); // Construct conversion mapping for results. Operation::result_type_range originalResultTypes = op->getResultTypes(); OneToNTypeMapping resultMapping(originalResultTypes); - if (failed(typeConverter->computeTypeMapping(originalResultTypes, - resultMapping))) + if (failed(typeConverter->convertSignatureArgs(originalResultTypes, + resultMapping))) return failure(); // Construct conversion mapping for operands. Operation::operand_type_range originalOperandTypes = op->getOperandTypes(); OneToNTypeMapping operandMapping(originalOperandTypes); - if (failed(typeConverter->computeTypeMapping(originalOperandTypes, - operandMapping))) + if (failed(typeConverter->convertSignatureArgs(originalOperandTypes, + operandMapping))) return failure(); // Cast operands to target types. @@ -318,7 +304,7 @@ namespace mlir { // inserted by this pass are annotated with a string attribute that also // documents which kind of the cast (source, argument, or target). LogicalResult -applyPartialOneToNConversion(Operation *op, OneToNTypeConverter &typeConverter, +applyPartialOneToNConversion(Operation *op, TypeConverter &typeConverter, const FrozenRewritePatternSet &patterns) { #ifndef NDEBUG // Remember existing unrealized casts. This data structure is only used in @@ -370,15 +356,13 @@ applyPartialOneToNConversion(Operation *op, OneToNTypeConverter &typeConverter, // Target materialization. assert(!areOperandTypesLegal && areResultsTypesLegal && operands.size() == 1 && "found unexpected target cast"); - std::optional> maybeResults = - typeConverter.materializeTargetConversion( - rewriter, castOp->getLoc(), resultTypes, operands.front()); - if (!maybeResults) { + materializedResults = typeConverter.materializeTargetConversion( + rewriter, castOp->getLoc(), resultTypes, operands.front()); + if (materializedResults.empty()) { emitError(castOp->getLoc()) << "failed to create target materialization"; return failure(); } - materializedResults = maybeResults.value(); } else { // Source and argument materializations. assert(areOperandTypesLegal && !areResultsTypesLegal && @@ -427,18 +411,18 @@ class FunctionOpInterfaceSignatureConversion : public OneToNConversionPattern { const OneToNTypeMapping &resultMapping, ValueRange convertedOperands) const override { auto funcOp = cast(op); - auto *typeConverter = getTypeConverter(); + auto *typeConverter = getTypeConverter(); // Construct mapping for function arguments. OneToNTypeMapping argumentMapping(funcOp.getArgumentTypes()); - if (failed(typeConverter->computeTypeMapping(funcOp.getArgumentTypes(), - argumentMapping))) + if (failed(typeConverter->convertSignatureArgs(funcOp.getArgumentTypes(), + argumentMapping))) return failure(); // Construct mapping for function results. OneToNTypeMapping funcResultMapping(funcOp.getResultTypes()); - if (failed(typeConverter->computeTypeMapping(funcOp.getResultTypes(), - funcResultMapping))) + if (failed(typeConverter->convertSignatureArgs(funcOp.getResultTypes(), + funcResultMapping))) return failure(); // Nothing to do if the op doesn't have any non-identity conversions for its diff --git a/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp b/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp index 5c03ac12d1e58..b18dfd8bb22cb 100644 --- a/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp +++ b/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp @@ -147,9 +147,14 @@ populateDecomposeTuplesTestPatterns(const TypeConverter &typeConverter, /// /// This function has been copied (with small adaptions) from /// TestDecomposeCallGraphTypes.cpp. -static std::optional> -buildGetTupleElementOps(OpBuilder &builder, TypeRange resultTypes, Value input, - Location loc) { +static SmallVector buildGetTupleElementOps(OpBuilder &builder, + TypeRange resultTypes, + ValueRange inputs, + Location loc) { + if (inputs.size() != 1) + return {}; + Value input = inputs.front(); + TupleType inputType = dyn_cast(input.getType()); if (!inputType) return {}; @@ -222,7 +227,7 @@ void TestOneToNTypeConversionPass::runOnOperation() { auto *context = &getContext(); // Assemble type converter. - OneToNTypeConverter typeConverter; + TypeConverter typeConverter; typeConverter.addConversion([](Type type) { return type; }); typeConverter.addConversion( @@ -234,6 +239,11 @@ void TestOneToNTypeConversionPass::runOnOperation() { typeConverter.addArgumentMaterialization(buildMakeTupleOp); typeConverter.addSourceMaterialization(buildMakeTupleOp); typeConverter.addTargetMaterialization(buildGetTupleElementOps); + // Test the other target materialization variant that takes the original type + // as additional argument. This materialization function always fails. + typeConverter.addTargetMaterialization( + [](OpBuilder &builder, TypeRange resultTypes, ValueRange inputs, + Location loc, Type originalType) -> SmallVector { return {}; }); // Assemble patterns. RewritePatternSet patterns(context);