-
Notifications
You must be signed in to change notification settings - Fork 13.5k
[mlir][func] Replace ValueDecomposer
with target materialization
#114192
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][func] Replace ValueDecomposer
with target materialization
#114192
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-func Author: Matthias Springer (matthias-springer) ChangesThe Full diff: https://github.com/llvm/llvm-project/pull/114192.diff 3 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Func/Transforms/DecomposeCallGraphTypes.h b/mlir/include/mlir/Dialect/Func/Transforms/DecomposeCallGraphTypes.h
index 1d311b37b37a4f..1be406bf3adf92 100644
--- a/mlir/include/mlir/Dialect/Func/Transforms/DecomposeCallGraphTypes.h
+++ b/mlir/include/mlir/Dialect/Func/Transforms/DecomposeCallGraphTypes.h
@@ -23,70 +23,10 @@
namespace mlir {
-/// This class provides a hook that expands one Value into multiple Value's,
-/// with a TypeConverter-inspired callback registration mechanism.
-///
-/// For folks that are familiar with the dialect conversion framework /
-/// TypeConverter, this is effectively the inverse of a source/argument
-/// materialization. A target materialization is not what we want here because
-/// it always produces a single Value, but in this case the whole point is to
-/// decompose a Value into multiple Value's.
-///
-/// The reason we need this inverse is easily understood by looking at what we
-/// need to do for decomposing types for a return op. When converting a return
-/// op, the dialect conversion framework will give the list of converted
-/// operands, and will ensure that each converted operand, even if it expanded
-/// into multiple types, is materialized as a single result. We then need to
-/// undo that materialization to a single result, which we do with the
-/// decomposeValue hooks registered on this object.
-///
-/// TODO: Eventually, the type conversion infra should have this hook built-in.
-/// See
-/// https://llvm.discourse.group/t/extending-type-conversion-infrastructure/779/2
-class ValueDecomposer {
-public:
- /// This method tries to decompose a value of a certain type using provided
- /// decompose callback functions. If it is unable to do so, the original value
- /// is returned.
- void decomposeValue(OpBuilder &, Location, Type, Value,
- SmallVectorImpl<Value> &);
-
- /// This method registers a callback function that will be called to decompose
- /// a value of a certain type into 0, 1, or multiple values.
- template <typename FnT, typename T = typename llvm::function_traits<
- std::decay_t<FnT>>::template arg_t<2>>
- void addDecomposeValueConversion(FnT &&callback) {
- decomposeValueConversions.emplace_back(
- wrapDecomposeValueConversionCallback<T>(std::forward<FnT>(callback)));
- }
-
-private:
- using DecomposeValueConversionCallFn =
- std::function<std::optional<LogicalResult>(
- OpBuilder &, Location, Type, Value, SmallVectorImpl<Value> &)>;
-
- /// Generate a wrapper for the given decompose value conversion callback.
- template <typename T, typename FnT>
- DecomposeValueConversionCallFn
- wrapDecomposeValueConversionCallback(FnT &&callback) {
- return
- [callback = std::forward<FnT>(callback)](
- OpBuilder &builder, Location loc, Type type, Value value,
- SmallVectorImpl<Value> &newValues) -> std::optional<LogicalResult> {
- if (T derivedType = dyn_cast<T>(type))
- return callback(builder, loc, derivedType, value, newValues);
- return std::nullopt;
- };
- }
-
- SmallVector<DecomposeValueConversionCallFn, 2> decomposeValueConversions;
-};
-
/// Populates the patterns needed to drive the conversion process for
-/// decomposing call graph types with the given `ValueDecomposer`.
+/// decomposing call graph types with the given `TypeConverter`.
void populateDecomposeCallGraphTypesPatterns(MLIRContext *context,
const TypeConverter &typeConverter,
- ValueDecomposer &decomposer,
RewritePatternSet &patterns);
} // namespace mlir
diff --git a/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp b/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp
index 357f993710a26a..8800ffd0be96dc 100644
--- a/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp
+++ b/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp
@@ -14,52 +14,46 @@ using namespace mlir;
using namespace mlir::func;
//===----------------------------------------------------------------------===//
-// ValueDecomposer
+// Helper functions
//===----------------------------------------------------------------------===//
-void ValueDecomposer::decomposeValue(OpBuilder &builder, Location loc,
- Type type, Value value,
- SmallVectorImpl<Value> &results) {
- for (auto &conversion : decomposeValueConversions)
- if (conversion(builder, loc, type, value, results))
- return;
- results.push_back(value);
+/// If the given value can be decomposed with the type converter, decompose it.
+/// Otherwise, return the given value.
+static SmallVector<Value> decomposeValue(OpBuilder &builder, Location loc,
+ Value value,
+ const TypeConverter *converter) {
+ // Try to convert the given value's type. If that fails, just return the
+ // given value.
+ SmallVector<Type> convertedTypes;
+ if (failed(converter->convertType(value.getType(), convertedTypes)))
+ return {value};
+ if (convertedTypes.empty())
+ return {};
+
+ // If the given value's type is already legal, just return the given value.
+ TypeRange convertedTypeRange(convertedTypes);
+ if (convertedTypeRange == TypeRange(value.getType()))
+ return {value};
+
+ // Try to materialize a target conversion. If the materialization did not
+ // produce values of the requested type, the materialization failed. Just
+ // return the given value in that case.
+ SmallVector<Value> result = converter->materializeTargetConversion(
+ builder, loc, convertedTypeRange, value);
+ if (result.empty())
+ return {value};
+ return result;
}
-//===----------------------------------------------------------------------===//
-// DecomposeCallGraphTypesOpConversionPattern
-//===----------------------------------------------------------------------===//
-
-namespace {
-/// Base OpConversionPattern class to make a ValueDecomposer available to
-/// inherited patterns.
-template <typename SourceOp>
-class DecomposeCallGraphTypesOpConversionPattern
- : public OpConversionPattern<SourceOp> {
-public:
- DecomposeCallGraphTypesOpConversionPattern(const TypeConverter &typeConverter,
- MLIRContext *context,
- ValueDecomposer &decomposer,
- PatternBenefit benefit = 1)
- : OpConversionPattern<SourceOp>(typeConverter, context, benefit),
- decomposer(decomposer) {}
-
-protected:
- ValueDecomposer &decomposer;
-};
-} // namespace
-
//===----------------------------------------------------------------------===//
// DecomposeCallGraphTypesForFuncArgs
//===----------------------------------------------------------------------===//
namespace {
-/// Expand function arguments according to the provided TypeConverter and
-/// ValueDecomposer.
+/// Expand function arguments according to the provided TypeConverter.
struct DecomposeCallGraphTypesForFuncArgs
- : public DecomposeCallGraphTypesOpConversionPattern<func::FuncOp> {
- using DecomposeCallGraphTypesOpConversionPattern::
- DecomposeCallGraphTypesOpConversionPattern;
+ : public OpConversionPattern<func::FuncOp> {
+ using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(func::FuncOp op, OpAdaptor adaptor,
@@ -100,19 +94,22 @@ struct DecomposeCallGraphTypesForFuncArgs
//===----------------------------------------------------------------------===//
namespace {
-/// Expand return operands according to the provided TypeConverter and
-/// ValueDecomposer.
+/// Expand return operands according to the provided TypeConverter.
struct DecomposeCallGraphTypesForReturnOp
- : public DecomposeCallGraphTypesOpConversionPattern<ReturnOp> {
- using DecomposeCallGraphTypesOpConversionPattern::
- DecomposeCallGraphTypesOpConversionPattern;
+ : public OpConversionPattern<ReturnOp> {
+ using OpConversionPattern::OpConversionPattern;
+
LogicalResult
matchAndRewrite(ReturnOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const final {
SmallVector<Value, 2> newOperands;
- for (Value operand : adaptor.getOperands())
- decomposer.decomposeValue(rewriter, op.getLoc(), operand.getType(),
- operand, newOperands);
+ for (Value operand : adaptor.getOperands()) {
+ // TODO: We can directly take the values from the adaptor once this is a
+ // 1:N conversion pattern.
+ llvm::append_range(newOperands,
+ decomposeValue(rewriter, operand.getLoc(), operand,
+ getTypeConverter()));
+ }
rewriter.replaceOpWithNewOp<ReturnOp>(op, newOperands);
return success();
}
@@ -124,12 +121,9 @@ struct DecomposeCallGraphTypesForReturnOp
//===----------------------------------------------------------------------===//
namespace {
-/// Expand call op operands and results according to the provided TypeConverter
-/// and ValueDecomposer.
-struct DecomposeCallGraphTypesForCallOp
- : public DecomposeCallGraphTypesOpConversionPattern<CallOp> {
- using DecomposeCallGraphTypesOpConversionPattern::
- DecomposeCallGraphTypesOpConversionPattern;
+/// Expand call op operands and results according to the provided TypeConverter.
+struct DecomposeCallGraphTypesForCallOp : public OpConversionPattern<CallOp> {
+ using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(CallOp op, OpAdaptor adaptor,
@@ -137,9 +131,13 @@ struct DecomposeCallGraphTypesForCallOp
// Create the operands list of the new `CallOp`.
SmallVector<Value, 2> newOperands;
- for (Value operand : adaptor.getOperands())
- decomposer.decomposeValue(rewriter, op.getLoc(), operand.getType(),
- operand, newOperands);
+ for (Value operand : adaptor.getOperands()) {
+ // TODO: We can directly take the values from the adaptor once this is a
+ // 1:N conversion pattern.
+ llvm::append_range(newOperands,
+ decomposeValue(rewriter, operand.getLoc(), operand,
+ getTypeConverter()));
+ }
// Create the new result types for the new `CallOp` and track the indices in
// the new call op's results that correspond to the old call op's results.
@@ -189,9 +187,8 @@ struct DecomposeCallGraphTypesForCallOp
void mlir::populateDecomposeCallGraphTypesPatterns(
MLIRContext *context, const TypeConverter &typeConverter,
- ValueDecomposer &decomposer, RewritePatternSet &patterns) {
+ RewritePatternSet &patterns) {
patterns
.add<DecomposeCallGraphTypesForCallOp, DecomposeCallGraphTypesForFuncArgs,
- DecomposeCallGraphTypesForReturnOp>(typeConverter, context,
- decomposer);
+ DecomposeCallGraphTypesForReturnOp>(typeConverter, context);
}
diff --git a/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp b/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp
index 92216da9f201e6..de511c58ae6ee0 100644
--- a/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp
+++ b/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp
@@ -21,23 +21,40 @@ namespace {
/// given tuple value. If some tuple elements are, in turn, tuples, the elements
/// of those are extracted recursively such that the returned values have the
/// same types as `resultTypes.getFlattenedTypes()`.
-static LogicalResult buildDecomposeTuple(OpBuilder &builder, Location loc,
- TupleType resultType, Value value,
- SmallVectorImpl<Value> &values) {
- for (unsigned i = 0, e = resultType.size(); i < e; ++i) {
- Type elementType = resultType.getType(i);
- Value element = builder.create<test::GetTupleElementOp>(
- loc, elementType, value, builder.getI32IntegerAttr(i));
- if (auto nestedTupleType = dyn_cast<TupleType>(elementType)) {
- // Recurse if the current element is also a tuple.
- if (failed(buildDecomposeTuple(builder, loc, nestedTupleType, element,
- values)))
- return failure();
- } else {
- values.push_back(element);
+static SmallVector<Value> buildDecomposeTuple(OpBuilder &builder,
+ TypeRange resultTypes,
+ ValueRange inputs, Location loc) {
+ // Skip materialization if the single input value is not a tuple.
+ if (inputs.size() != 1)
+ return {};
+ Value tuple = inputs.front();
+ auto tupleType = dyn_cast<TupleType>(tuple.getType());
+ if (!tupleType)
+ return {};
+ // Skip materialization if the flattened types do not match the requested
+ // result types.
+ SmallVector<Type> flattenedTypes;
+ tupleType.getFlattenedTypes(flattenedTypes);
+ if (TypeRange(resultTypes) != TypeRange(flattenedTypes))
+ return {};
+ // Recursively decompose the tuple.
+ SmallVector<Value> result;
+ std::function<void(Value)> decompose = [&](Value tuple) {
+ auto tupleType = dyn_cast<TupleType>(tuple.getType());
+ if (!tupleType) {
+ // This is not a tuple.
+ result.push_back(tuple);
+ return;
}
- }
- return success();
+ for (unsigned i = 0, e = tupleType.size(); i < e; ++i) {
+ Type elementType = tupleType.getType(i);
+ Value element = builder.create<test::GetTupleElementOp>(
+ loc, elementType, tuple, builder.getI32IntegerAttr(i));
+ decompose(element);
+ }
+ };
+ decompose(tuple);
+ return result;
}
/// Creates a `test.make_tuple` op out of the given inputs building a tuple of
@@ -82,8 +99,8 @@ static Value buildMakeTupleOp(OpBuilder &builder, TupleType resultType,
/// A pass for testing call graph type decomposition.
///
-/// This instantiates the patterns with a TypeConverter and ValueDecomposer
-/// that splits tuple types into their respective element types.
+/// This instantiates the patterns with a TypeConverter that splits tuple types
+/// into their respective element types.
/// For example, `tuple<T1, T2, T3> --> T1, T2, T3`.
struct TestDecomposeCallGraphTypes
: public PassWrapper<TestDecomposeCallGraphTypes, OperationPass<ModuleOp>> {
@@ -123,12 +140,9 @@ struct TestDecomposeCallGraphTypes
return success();
});
typeConverter.addArgumentMaterialization(buildMakeTupleOp);
+ typeConverter.addTargetMaterialization(buildDecomposeTuple);
- ValueDecomposer decomposer;
- decomposer.addDecomposeValueConversion(buildDecomposeTuple);
-
- populateDecomposeCallGraphTypesPatterns(context, typeConverter, decomposer,
- patterns);
+ populateDecomposeCallGraphTypesPatterns(context, typeConverter, patterns);
if (failed(applyPartialConversion(module, target, std::move(patterns))))
return signalPassFailure();
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM! Nice -- more specialized code gone and replace by generic one!
The `ValueDecomposer` in `DecomposeCallGraphTypes` was a workaround around missing 1:N support in the dialect conversion. Since #113032,tThe dialect conversion infrastructure supports 1:N type conversions and 1:N target materializations. The `ValueDecomposer` class is no longer needed. (However, target materializations must still be inserted manually, until we fully merge the 1:1 and 1:N drivers.)
a7dce07
to
4cbb3c7
Compare
…lvm#114192) The `ValueDecomposer` in `DecomposeCallGraphTypes` was a workaround around missing 1:N support in the dialect conversion. Since llvm#113032, the dialect conversion infrastructure supports 1:N type conversions and 1:N target materializations. The `ValueDecomposer` class is no longer needed. (However, target materializations must still be inserted manually, until we fully merge the 1:1 and 1:N drivers.) Note for LLVM integration: Register 1:N target materializations on the type converter instead of "decompose value conversions" on the `ValueDecomposer`.
…lvm#114192) The `ValueDecomposer` in `DecomposeCallGraphTypes` was a workaround around missing 1:N support in the dialect conversion. Since llvm#113032, the dialect conversion infrastructure supports 1:N type conversions and 1:N target materializations. The `ValueDecomposer` class is no longer needed. (However, target materializations must still be inserted manually, until we fully merge the 1:1 and 1:N drivers.) Note for LLVM integration: Register 1:N target materializations on the type converter instead of "decompose value conversions" on the `ValueDecomposer`.
The
ValueDecomposer
inDecomposeCallGraphTypes
was a workaround around missing 1:N support in the dialect conversion. Since #113032, the dialect conversion infrastructure supports 1:N type conversions and 1:N target materializations. TheValueDecomposer
class is no longer needed. (However, target materializations must still be inserted manually, until we fully merge the 1:1 and 1:N drivers.)Note for LLVM integration: Register 1:N target materializations on the type converter instead of "decompose value conversions" on the
ValueDecomposer
.