Skip to content

Commit d043670

Browse files
[mlir][func] Replace ValueDecomposer with target materialization (#114192)
The `ValueDecomposer` in `DecomposeCallGraphTypes` was a workaround around missing 1:N support in the dialect conversion. Since #113032, the dialect conversion infrastructure supports 1:N type conversions and 1:N target materializations. The `ValueDecomposer` class is no longer needed. (However, target materializations must still be inserted manually, until we fully merge the 1:1 and 1:N drivers.) Note for LLVM integration: Register 1:N target materializations on the type converter instead of "decompose value conversions" on the `ValueDecomposer`.
1 parent 1cecc58 commit d043670

File tree

3 files changed

+93
-140
lines changed

3 files changed

+93
-140
lines changed

mlir/include/mlir/Dialect/Func/Transforms/DecomposeCallGraphTypes.h

Lines changed: 1 addition & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -23,70 +23,10 @@
2323

2424
namespace mlir {
2525

26-
/// This class provides a hook that expands one Value into multiple Value's,
27-
/// with a TypeConverter-inspired callback registration mechanism.
28-
///
29-
/// For folks that are familiar with the dialect conversion framework /
30-
/// TypeConverter, this is effectively the inverse of a source/argument
31-
/// materialization. A target materialization is not what we want here because
32-
/// it always produces a single Value, but in this case the whole point is to
33-
/// decompose a Value into multiple Value's.
34-
///
35-
/// The reason we need this inverse is easily understood by looking at what we
36-
/// need to do for decomposing types for a return op. When converting a return
37-
/// op, the dialect conversion framework will give the list of converted
38-
/// operands, and will ensure that each converted operand, even if it expanded
39-
/// into multiple types, is materialized as a single result. We then need to
40-
/// undo that materialization to a single result, which we do with the
41-
/// decomposeValue hooks registered on this object.
42-
///
43-
/// TODO: Eventually, the type conversion infra should have this hook built-in.
44-
/// See
45-
/// https://llvm.discourse.group/t/extending-type-conversion-infrastructure/779/2
46-
class ValueDecomposer {
47-
public:
48-
/// This method tries to decompose a value of a certain type using provided
49-
/// decompose callback functions. If it is unable to do so, the original value
50-
/// is returned.
51-
void decomposeValue(OpBuilder &, Location, Type, Value,
52-
SmallVectorImpl<Value> &);
53-
54-
/// This method registers a callback function that will be called to decompose
55-
/// a value of a certain type into 0, 1, or multiple values.
56-
template <typename FnT, typename T = typename llvm::function_traits<
57-
std::decay_t<FnT>>::template arg_t<2>>
58-
void addDecomposeValueConversion(FnT &&callback) {
59-
decomposeValueConversions.emplace_back(
60-
wrapDecomposeValueConversionCallback<T>(std::forward<FnT>(callback)));
61-
}
62-
63-
private:
64-
using DecomposeValueConversionCallFn =
65-
std::function<std::optional<LogicalResult>(
66-
OpBuilder &, Location, Type, Value, SmallVectorImpl<Value> &)>;
67-
68-
/// Generate a wrapper for the given decompose value conversion callback.
69-
template <typename T, typename FnT>
70-
DecomposeValueConversionCallFn
71-
wrapDecomposeValueConversionCallback(FnT &&callback) {
72-
return
73-
[callback = std::forward<FnT>(callback)](
74-
OpBuilder &builder, Location loc, Type type, Value value,
75-
SmallVectorImpl<Value> &newValues) -> std::optional<LogicalResult> {
76-
if (T derivedType = dyn_cast<T>(type))
77-
return callback(builder, loc, derivedType, value, newValues);
78-
return std::nullopt;
79-
};
80-
}
81-
82-
SmallVector<DecomposeValueConversionCallFn, 2> decomposeValueConversions;
83-
};
84-
8526
/// Populates the patterns needed to drive the conversion process for
86-
/// decomposing call graph types with the given `ValueDecomposer`.
27+
/// decomposing call graph types with the given `TypeConverter`.
8728
void populateDecomposeCallGraphTypesPatterns(MLIRContext *context,
8829
const TypeConverter &typeConverter,
89-
ValueDecomposer &decomposer,
9030
RewritePatternSet &patterns);
9131

9232
} // namespace mlir

mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp

Lines changed: 55 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -14,52 +14,48 @@ using namespace mlir;
1414
using namespace mlir::func;
1515

1616
//===----------------------------------------------------------------------===//
17-
// ValueDecomposer
17+
// Helper functions
1818
//===----------------------------------------------------------------------===//
1919

20-
void ValueDecomposer::decomposeValue(OpBuilder &builder, Location loc,
21-
Type type, Value value,
22-
SmallVectorImpl<Value> &results) {
23-
for (auto &conversion : decomposeValueConversions)
24-
if (conversion(builder, loc, type, value, results))
25-
return;
26-
results.push_back(value);
20+
/// If the given value can be decomposed with the type converter, decompose it.
21+
/// Otherwise, return the given value.
22+
// TODO: Value decomposition should happen automatically through a 1:N adaptor.
23+
// This function will disappear when the 1:1 and 1:N drivers are merged.
24+
static SmallVector<Value> decomposeValue(OpBuilder &builder, Location loc,
25+
Value value,
26+
const TypeConverter *converter) {
27+
// Try to convert the given value's type. If that fails, just return the
28+
// given value.
29+
SmallVector<Type> convertedTypes;
30+
if (failed(converter->convertType(value.getType(), convertedTypes)))
31+
return {value};
32+
if (convertedTypes.empty())
33+
return {};
34+
35+
// If the given value's type is already legal, just return the given value.
36+
TypeRange convertedTypeRange(convertedTypes);
37+
if (convertedTypeRange == TypeRange(value.getType()))
38+
return {value};
39+
40+
// Try to materialize a target conversion. If the materialization did not
41+
// produce values of the requested type, the materialization failed. Just
42+
// return the given value in that case.
43+
SmallVector<Value> result = converter->materializeTargetConversion(
44+
builder, loc, convertedTypeRange, value);
45+
if (result.empty())
46+
return {value};
47+
return result;
2748
}
2849

29-
//===----------------------------------------------------------------------===//
30-
// DecomposeCallGraphTypesOpConversionPattern
31-
//===----------------------------------------------------------------------===//
32-
33-
namespace {
34-
/// Base OpConversionPattern class to make a ValueDecomposer available to
35-
/// inherited patterns.
36-
template <typename SourceOp>
37-
class DecomposeCallGraphTypesOpConversionPattern
38-
: public OpConversionPattern<SourceOp> {
39-
public:
40-
DecomposeCallGraphTypesOpConversionPattern(const TypeConverter &typeConverter,
41-
MLIRContext *context,
42-
ValueDecomposer &decomposer,
43-
PatternBenefit benefit = 1)
44-
: OpConversionPattern<SourceOp>(typeConverter, context, benefit),
45-
decomposer(decomposer) {}
46-
47-
protected:
48-
ValueDecomposer &decomposer;
49-
};
50-
} // namespace
51-
5250
//===----------------------------------------------------------------------===//
5351
// DecomposeCallGraphTypesForFuncArgs
5452
//===----------------------------------------------------------------------===//
5553

5654
namespace {
57-
/// Expand function arguments according to the provided TypeConverter and
58-
/// ValueDecomposer.
55+
/// Expand function arguments according to the provided TypeConverter.
5956
struct DecomposeCallGraphTypesForFuncArgs
60-
: public DecomposeCallGraphTypesOpConversionPattern<func::FuncOp> {
61-
using DecomposeCallGraphTypesOpConversionPattern::
62-
DecomposeCallGraphTypesOpConversionPattern;
57+
: public OpConversionPattern<func::FuncOp> {
58+
using OpConversionPattern::OpConversionPattern;
6359

6460
LogicalResult
6561
matchAndRewrite(func::FuncOp op, OpAdaptor adaptor,
@@ -100,19 +96,22 @@ struct DecomposeCallGraphTypesForFuncArgs
10096
//===----------------------------------------------------------------------===//
10197

10298
namespace {
103-
/// Expand return operands according to the provided TypeConverter and
104-
/// ValueDecomposer.
99+
/// Expand return operands according to the provided TypeConverter.
105100
struct DecomposeCallGraphTypesForReturnOp
106-
: public DecomposeCallGraphTypesOpConversionPattern<ReturnOp> {
107-
using DecomposeCallGraphTypesOpConversionPattern::
108-
DecomposeCallGraphTypesOpConversionPattern;
101+
: public OpConversionPattern<ReturnOp> {
102+
using OpConversionPattern::OpConversionPattern;
103+
109104
LogicalResult
110105
matchAndRewrite(ReturnOp op, OpAdaptor adaptor,
111106
ConversionPatternRewriter &rewriter) const final {
112107
SmallVector<Value, 2> newOperands;
113-
for (Value operand : adaptor.getOperands())
114-
decomposer.decomposeValue(rewriter, op.getLoc(), operand.getType(),
115-
operand, newOperands);
108+
for (Value operand : adaptor.getOperands()) {
109+
// TODO: We can directly take the values from the adaptor once this is a
110+
// 1:N conversion pattern.
111+
llvm::append_range(newOperands,
112+
decomposeValue(rewriter, operand.getLoc(), operand,
113+
getTypeConverter()));
114+
}
116115
rewriter.replaceOpWithNewOp<ReturnOp>(op, newOperands);
117116
return success();
118117
}
@@ -124,22 +123,23 @@ struct DecomposeCallGraphTypesForReturnOp
124123
//===----------------------------------------------------------------------===//
125124

126125
namespace {
127-
/// Expand call op operands and results according to the provided TypeConverter
128-
/// and ValueDecomposer.
129-
struct DecomposeCallGraphTypesForCallOp
130-
: public DecomposeCallGraphTypesOpConversionPattern<CallOp> {
131-
using DecomposeCallGraphTypesOpConversionPattern::
132-
DecomposeCallGraphTypesOpConversionPattern;
126+
/// Expand call op operands and results according to the provided TypeConverter.
127+
struct DecomposeCallGraphTypesForCallOp : public OpConversionPattern<CallOp> {
128+
using OpConversionPattern::OpConversionPattern;
133129

134130
LogicalResult
135131
matchAndRewrite(CallOp op, OpAdaptor adaptor,
136132
ConversionPatternRewriter &rewriter) const final {
137133

138134
// Create the operands list of the new `CallOp`.
139135
SmallVector<Value, 2> newOperands;
140-
for (Value operand : adaptor.getOperands())
141-
decomposer.decomposeValue(rewriter, op.getLoc(), operand.getType(),
142-
operand, newOperands);
136+
for (Value operand : adaptor.getOperands()) {
137+
// TODO: We can directly take the values from the adaptor once this is a
138+
// 1:N conversion pattern.
139+
llvm::append_range(newOperands,
140+
decomposeValue(rewriter, operand.getLoc(), operand,
141+
getTypeConverter()));
142+
}
143143

144144
// Create the new result types for the new `CallOp` and track the indices in
145145
// the new call op's results that correspond to the old call op's results.
@@ -189,9 +189,8 @@ struct DecomposeCallGraphTypesForCallOp
189189

190190
void mlir::populateDecomposeCallGraphTypesPatterns(
191191
MLIRContext *context, const TypeConverter &typeConverter,
192-
ValueDecomposer &decomposer, RewritePatternSet &patterns) {
192+
RewritePatternSet &patterns) {
193193
patterns
194194
.add<DecomposeCallGraphTypesForCallOp, DecomposeCallGraphTypesForFuncArgs,
195-
DecomposeCallGraphTypesForReturnOp>(typeConverter, context,
196-
decomposer);
195+
DecomposeCallGraphTypesForReturnOp>(typeConverter, context);
197196
}

mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp

Lines changed: 37 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -21,23 +21,40 @@ namespace {
2121
/// given tuple value. If some tuple elements are, in turn, tuples, the elements
2222
/// of those are extracted recursively such that the returned values have the
2323
/// same types as `resultTypes.getFlattenedTypes()`.
24-
static LogicalResult buildDecomposeTuple(OpBuilder &builder, Location loc,
25-
TupleType resultType, Value value,
26-
SmallVectorImpl<Value> &values) {
27-
for (unsigned i = 0, e = resultType.size(); i < e; ++i) {
28-
Type elementType = resultType.getType(i);
29-
Value element = builder.create<test::GetTupleElementOp>(
30-
loc, elementType, value, builder.getI32IntegerAttr(i));
31-
if (auto nestedTupleType = dyn_cast<TupleType>(elementType)) {
32-
// Recurse if the current element is also a tuple.
33-
if (failed(buildDecomposeTuple(builder, loc, nestedTupleType, element,
34-
values)))
35-
return failure();
36-
} else {
37-
values.push_back(element);
24+
static SmallVector<Value> buildDecomposeTuple(OpBuilder &builder,
25+
TypeRange resultTypes,
26+
ValueRange inputs, Location loc) {
27+
// Skip materialization if the single input value is not a tuple.
28+
if (inputs.size() != 1)
29+
return {};
30+
Value tuple = inputs.front();
31+
auto tupleType = dyn_cast<TupleType>(tuple.getType());
32+
if (!tupleType)
33+
return {};
34+
// Skip materialization if the flattened types do not match the requested
35+
// result types.
36+
SmallVector<Type> flattenedTypes;
37+
tupleType.getFlattenedTypes(flattenedTypes);
38+
if (TypeRange(resultTypes) != TypeRange(flattenedTypes))
39+
return {};
40+
// Recursively decompose the tuple.
41+
SmallVector<Value> result;
42+
std::function<void(Value)> decompose = [&](Value tuple) {
43+
auto tupleType = dyn_cast<TupleType>(tuple.getType());
44+
if (!tupleType) {
45+
// This is not a tuple.
46+
result.push_back(tuple);
47+
return;
3848
}
39-
}
40-
return success();
49+
for (unsigned i = 0, e = tupleType.size(); i < e; ++i) {
50+
Type elementType = tupleType.getType(i);
51+
Value element = builder.create<test::GetTupleElementOp>(
52+
loc, elementType, tuple, builder.getI32IntegerAttr(i));
53+
decompose(element);
54+
}
55+
};
56+
decompose(tuple);
57+
return result;
4158
}
4259

4360
/// Creates a `test.make_tuple` op out of the given inputs building a tuple of
@@ -82,8 +99,8 @@ static Value buildMakeTupleOp(OpBuilder &builder, TupleType resultType,
8299

83100
/// A pass for testing call graph type decomposition.
84101
///
85-
/// This instantiates the patterns with a TypeConverter and ValueDecomposer
86-
/// that splits tuple types into their respective element types.
102+
/// This instantiates the patterns with a TypeConverter that splits tuple types
103+
/// into their respective element types.
87104
/// For example, `tuple<T1, T2, T3> --> T1, T2, T3`.
88105
struct TestDecomposeCallGraphTypes
89106
: public PassWrapper<TestDecomposeCallGraphTypes, OperationPass<ModuleOp>> {
@@ -123,12 +140,9 @@ struct TestDecomposeCallGraphTypes
123140
return success();
124141
});
125142
typeConverter.addArgumentMaterialization(buildMakeTupleOp);
143+
typeConverter.addTargetMaterialization(buildDecomposeTuple);
126144

127-
ValueDecomposer decomposer;
128-
decomposer.addDecomposeValueConversion(buildDecomposeTuple);
129-
130-
populateDecomposeCallGraphTypesPatterns(context, typeConverter, decomposer,
131-
patterns);
145+
populateDecomposeCallGraphTypesPatterns(context, typeConverter, patterns);
132146

133147
if (failed(applyPartialConversion(module, target, std::move(patterns))))
134148
return signalPassFailure();

0 commit comments

Comments
 (0)