From 4cbb3c7fcfc5243c1e4a5de3ff6fb251e398d33b Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Wed, 30 Oct 2024 08:51:44 +0100 Subject: [PATCH] [mlir][func] Replace `ValueDecomposer` with target materialization The `ValueDecomposer` in `DecomposeCallGraphTypes` was a workaround around missing 1:N support in the dialect conversion. Since #113032,tThe dialect conversion infrastructure supports 1:N type conversions and 1:N target materializations. The `ValueDecomposer` class is no longer needed. (However, target materializations must still be inserted manually, until we fully merge the 1:1 and 1:N drivers.) --- .../Func/Transforms/DecomposeCallGraphTypes.h | 62 +--------- .../Transforms/DecomposeCallGraphTypes.cpp | 111 +++++++++--------- .../Func/TestDecomposeCallGraphTypes.cpp | 60 ++++++---- 3 files changed, 93 insertions(+), 140 deletions(-) diff --git a/mlir/include/mlir/Dialect/Func/Transforms/DecomposeCallGraphTypes.h b/mlir/include/mlir/Dialect/Func/Transforms/DecomposeCallGraphTypes.h index 1d311b37b37a4..1be406bf3adf9 100644 --- a/mlir/include/mlir/Dialect/Func/Transforms/DecomposeCallGraphTypes.h +++ b/mlir/include/mlir/Dialect/Func/Transforms/DecomposeCallGraphTypes.h @@ -23,70 +23,10 @@ namespace mlir { -/// This class provides a hook that expands one Value into multiple Value's, -/// with a TypeConverter-inspired callback registration mechanism. -/// -/// For folks that are familiar with the dialect conversion framework / -/// TypeConverter, this is effectively the inverse of a source/argument -/// materialization. A target materialization is not what we want here because -/// it always produces a single Value, but in this case the whole point is to -/// decompose a Value into multiple Value's. -/// -/// The reason we need this inverse is easily understood by looking at what we -/// need to do for decomposing types for a return op. When converting a return -/// op, the dialect conversion framework will give the list of converted -/// operands, and will ensure that each converted operand, even if it expanded -/// into multiple types, is materialized as a single result. We then need to -/// undo that materialization to a single result, which we do with the -/// decomposeValue hooks registered on this object. -/// -/// TODO: Eventually, the type conversion infra should have this hook built-in. -/// See -/// https://llvm.discourse.group/t/extending-type-conversion-infrastructure/779/2 -class ValueDecomposer { -public: - /// This method tries to decompose a value of a certain type using provided - /// decompose callback functions. If it is unable to do so, the original value - /// is returned. - void decomposeValue(OpBuilder &, Location, Type, Value, - SmallVectorImpl &); - - /// This method registers a callback function that will be called to decompose - /// a value of a certain type into 0, 1, or multiple values. - template >::template arg_t<2>> - void addDecomposeValueConversion(FnT &&callback) { - decomposeValueConversions.emplace_back( - wrapDecomposeValueConversionCallback(std::forward(callback))); - } - -private: - using DecomposeValueConversionCallFn = - std::function( - OpBuilder &, Location, Type, Value, SmallVectorImpl &)>; - - /// Generate a wrapper for the given decompose value conversion callback. - template - DecomposeValueConversionCallFn - wrapDecomposeValueConversionCallback(FnT &&callback) { - return - [callback = std::forward(callback)]( - OpBuilder &builder, Location loc, Type type, Value value, - SmallVectorImpl &newValues) -> std::optional { - if (T derivedType = dyn_cast(type)) - return callback(builder, loc, derivedType, value, newValues); - return std::nullopt; - }; - } - - SmallVector decomposeValueConversions; -}; - /// Populates the patterns needed to drive the conversion process for -/// decomposing call graph types with the given `ValueDecomposer`. +/// decomposing call graph types with the given `TypeConverter`. void populateDecomposeCallGraphTypesPatterns(MLIRContext *context, const TypeConverter &typeConverter, - ValueDecomposer &decomposer, RewritePatternSet &patterns); } // namespace mlir diff --git a/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp b/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp index 357f993710a26..de4aba2ed327d 100644 --- a/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp +++ b/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp @@ -14,52 +14,48 @@ using namespace mlir; using namespace mlir::func; //===----------------------------------------------------------------------===// -// ValueDecomposer +// Helper functions //===----------------------------------------------------------------------===// -void ValueDecomposer::decomposeValue(OpBuilder &builder, Location loc, - Type type, Value value, - SmallVectorImpl &results) { - for (auto &conversion : decomposeValueConversions) - if (conversion(builder, loc, type, value, results)) - return; - results.push_back(value); +/// If the given value can be decomposed with the type converter, decompose it. +/// Otherwise, return the given value. +// TODO: Value decomposition should happen automatically through a 1:N adaptor. +// This function will disappear when the 1:1 and 1:N drivers are merged. +static SmallVector decomposeValue(OpBuilder &builder, Location loc, + Value value, + const TypeConverter *converter) { + // Try to convert the given value's type. If that fails, just return the + // given value. + SmallVector convertedTypes; + if (failed(converter->convertType(value.getType(), convertedTypes))) + return {value}; + if (convertedTypes.empty()) + return {}; + + // If the given value's type is already legal, just return the given value. + TypeRange convertedTypeRange(convertedTypes); + if (convertedTypeRange == TypeRange(value.getType())) + return {value}; + + // Try to materialize a target conversion. If the materialization did not + // produce values of the requested type, the materialization failed. Just + // return the given value in that case. + SmallVector result = converter->materializeTargetConversion( + builder, loc, convertedTypeRange, value); + if (result.empty()) + return {value}; + return result; } -//===----------------------------------------------------------------------===// -// DecomposeCallGraphTypesOpConversionPattern -//===----------------------------------------------------------------------===// - -namespace { -/// Base OpConversionPattern class to make a ValueDecomposer available to -/// inherited patterns. -template -class DecomposeCallGraphTypesOpConversionPattern - : public OpConversionPattern { -public: - DecomposeCallGraphTypesOpConversionPattern(const TypeConverter &typeConverter, - MLIRContext *context, - ValueDecomposer &decomposer, - PatternBenefit benefit = 1) - : OpConversionPattern(typeConverter, context, benefit), - decomposer(decomposer) {} - -protected: - ValueDecomposer &decomposer; -}; -} // namespace - //===----------------------------------------------------------------------===// // DecomposeCallGraphTypesForFuncArgs //===----------------------------------------------------------------------===// namespace { -/// Expand function arguments according to the provided TypeConverter and -/// ValueDecomposer. +/// Expand function arguments according to the provided TypeConverter. struct DecomposeCallGraphTypesForFuncArgs - : public DecomposeCallGraphTypesOpConversionPattern { - using DecomposeCallGraphTypesOpConversionPattern:: - DecomposeCallGraphTypesOpConversionPattern; + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(func::FuncOp op, OpAdaptor adaptor, @@ -100,19 +96,22 @@ struct DecomposeCallGraphTypesForFuncArgs //===----------------------------------------------------------------------===// namespace { -/// Expand return operands according to the provided TypeConverter and -/// ValueDecomposer. +/// Expand return operands according to the provided TypeConverter. struct DecomposeCallGraphTypesForReturnOp - : public DecomposeCallGraphTypesOpConversionPattern { - using DecomposeCallGraphTypesOpConversionPattern:: - DecomposeCallGraphTypesOpConversionPattern; + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult matchAndRewrite(ReturnOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const final { SmallVector newOperands; - for (Value operand : adaptor.getOperands()) - decomposer.decomposeValue(rewriter, op.getLoc(), operand.getType(), - operand, newOperands); + for (Value operand : adaptor.getOperands()) { + // TODO: We can directly take the values from the adaptor once this is a + // 1:N conversion pattern. + llvm::append_range(newOperands, + decomposeValue(rewriter, operand.getLoc(), operand, + getTypeConverter())); + } rewriter.replaceOpWithNewOp(op, newOperands); return success(); } @@ -124,12 +123,9 @@ struct DecomposeCallGraphTypesForReturnOp //===----------------------------------------------------------------------===// namespace { -/// Expand call op operands and results according to the provided TypeConverter -/// and ValueDecomposer. -struct DecomposeCallGraphTypesForCallOp - : public DecomposeCallGraphTypesOpConversionPattern { - using DecomposeCallGraphTypesOpConversionPattern:: - DecomposeCallGraphTypesOpConversionPattern; +/// Expand call op operands and results according to the provided TypeConverter. +struct DecomposeCallGraphTypesForCallOp : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(CallOp op, OpAdaptor adaptor, @@ -137,9 +133,13 @@ struct DecomposeCallGraphTypesForCallOp // Create the operands list of the new `CallOp`. SmallVector newOperands; - for (Value operand : adaptor.getOperands()) - decomposer.decomposeValue(rewriter, op.getLoc(), operand.getType(), - operand, newOperands); + for (Value operand : adaptor.getOperands()) { + // TODO: We can directly take the values from the adaptor once this is a + // 1:N conversion pattern. + llvm::append_range(newOperands, + decomposeValue(rewriter, operand.getLoc(), operand, + getTypeConverter())); + } // Create the new result types for the new `CallOp` and track the indices in // the new call op's results that correspond to the old call op's results. @@ -189,9 +189,8 @@ struct DecomposeCallGraphTypesForCallOp void mlir::populateDecomposeCallGraphTypesPatterns( MLIRContext *context, const TypeConverter &typeConverter, - ValueDecomposer &decomposer, RewritePatternSet &patterns) { + RewritePatternSet &patterns) { patterns .add(typeConverter, context, - decomposer); + DecomposeCallGraphTypesForReturnOp>(typeConverter, context); } diff --git a/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp b/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp index 92216da9f201e..de511c58ae6ee 100644 --- a/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp +++ b/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp @@ -21,23 +21,40 @@ namespace { /// given tuple value. If some tuple elements are, in turn, tuples, the elements /// of those are extracted recursively such that the returned values have the /// same types as `resultTypes.getFlattenedTypes()`. -static LogicalResult buildDecomposeTuple(OpBuilder &builder, Location loc, - TupleType resultType, Value value, - SmallVectorImpl &values) { - for (unsigned i = 0, e = resultType.size(); i < e; ++i) { - Type elementType = resultType.getType(i); - Value element = builder.create( - loc, elementType, value, builder.getI32IntegerAttr(i)); - if (auto nestedTupleType = dyn_cast(elementType)) { - // Recurse if the current element is also a tuple. - if (failed(buildDecomposeTuple(builder, loc, nestedTupleType, element, - values))) - return failure(); - } else { - values.push_back(element); +static SmallVector buildDecomposeTuple(OpBuilder &builder, + TypeRange resultTypes, + ValueRange inputs, Location loc) { + // Skip materialization if the single input value is not a tuple. + if (inputs.size() != 1) + return {}; + Value tuple = inputs.front(); + auto tupleType = dyn_cast(tuple.getType()); + if (!tupleType) + return {}; + // Skip materialization if the flattened types do not match the requested + // result types. + SmallVector flattenedTypes; + tupleType.getFlattenedTypes(flattenedTypes); + if (TypeRange(resultTypes) != TypeRange(flattenedTypes)) + return {}; + // Recursively decompose the tuple. + SmallVector result; + std::function decompose = [&](Value tuple) { + auto tupleType = dyn_cast(tuple.getType()); + if (!tupleType) { + // This is not a tuple. + result.push_back(tuple); + return; } - } - return success(); + for (unsigned i = 0, e = tupleType.size(); i < e; ++i) { + Type elementType = tupleType.getType(i); + Value element = builder.create( + loc, elementType, tuple, builder.getI32IntegerAttr(i)); + decompose(element); + } + }; + decompose(tuple); + return result; } /// Creates a `test.make_tuple` op out of the given inputs building a tuple of @@ -82,8 +99,8 @@ static Value buildMakeTupleOp(OpBuilder &builder, TupleType resultType, /// A pass for testing call graph type decomposition. /// -/// This instantiates the patterns with a TypeConverter and ValueDecomposer -/// that splits tuple types into their respective element types. +/// This instantiates the patterns with a TypeConverter that splits tuple types +/// into their respective element types. /// For example, `tuple --> T1, T2, T3`. struct TestDecomposeCallGraphTypes : public PassWrapper> { @@ -123,12 +140,9 @@ struct TestDecomposeCallGraphTypes return success(); }); typeConverter.addArgumentMaterialization(buildMakeTupleOp); + typeConverter.addTargetMaterialization(buildDecomposeTuple); - ValueDecomposer decomposer; - decomposer.addDecomposeValueConversion(buildDecomposeTuple); - - populateDecomposeCallGraphTypesPatterns(context, typeConverter, decomposer, - patterns); + populateDecomposeCallGraphTypesPatterns(context, typeConverter, patterns); if (failed(applyPartialConversion(module, target, std::move(patterns)))) return signalPassFailure();