@@ -14,52 +14,48 @@ using namespace mlir;
14
14
using namespace mlir ::func;
15
15
16
16
// ===----------------------------------------------------------------------===//
17
- // ValueDecomposer
17
+ // Helper functions
18
18
// ===----------------------------------------------------------------------===//
19
19
20
- void ValueDecomposer::decomposeValue (OpBuilder &builder, Location loc,
21
- Type type, Value value,
22
- SmallVectorImpl<Value> &results) {
23
- for (auto &conversion : decomposeValueConversions)
24
- if (conversion (builder, loc, type, value, results))
25
- return ;
26
- results.push_back (value);
20
+ // / If the given value can be decomposed with the type converter, decompose it.
21
+ // / Otherwise, return the given value.
22
+ // TODO: Value decomposition should happen automatically through a 1:N adaptor.
23
+ // This function will disappear when the 1:1 and 1:N drivers are merged.
24
+ static SmallVector<Value> decomposeValue (OpBuilder &builder, Location loc,
25
+ Value value,
26
+ const TypeConverter *converter) {
27
+ // Try to convert the given value's type. If that fails, just return the
28
+ // given value.
29
+ SmallVector<Type> convertedTypes;
30
+ if (failed (converter->convertType (value.getType (), convertedTypes)))
31
+ return {value};
32
+ if (convertedTypes.empty ())
33
+ return {};
34
+
35
+ // If the given value's type is already legal, just return the given value.
36
+ TypeRange convertedTypeRange (convertedTypes);
37
+ if (convertedTypeRange == TypeRange (value.getType ()))
38
+ return {value};
39
+
40
+ // Try to materialize a target conversion. If the materialization did not
41
+ // produce values of the requested type, the materialization failed. Just
42
+ // return the given value in that case.
43
+ SmallVector<Value> result = converter->materializeTargetConversion (
44
+ builder, loc, convertedTypeRange, value);
45
+ if (result.empty ())
46
+ return {value};
47
+ return result;
27
48
}
28
49
29
- // ===----------------------------------------------------------------------===//
30
- // DecomposeCallGraphTypesOpConversionPattern
31
- // ===----------------------------------------------------------------------===//
32
-
33
- namespace {
34
- // / Base OpConversionPattern class to make a ValueDecomposer available to
35
- // / inherited patterns.
36
- template <typename SourceOp>
37
- class DecomposeCallGraphTypesOpConversionPattern
38
- : public OpConversionPattern<SourceOp> {
39
- public:
40
- DecomposeCallGraphTypesOpConversionPattern (const TypeConverter &typeConverter,
41
- MLIRContext *context,
42
- ValueDecomposer &decomposer,
43
- PatternBenefit benefit = 1 )
44
- : OpConversionPattern<SourceOp>(typeConverter, context, benefit),
45
- decomposer (decomposer) {}
46
-
47
- protected:
48
- ValueDecomposer &decomposer;
49
- };
50
- } // namespace
51
-
52
50
// ===----------------------------------------------------------------------===//
53
51
// DecomposeCallGraphTypesForFuncArgs
54
52
// ===----------------------------------------------------------------------===//
55
53
56
54
namespace {
57
- // / Expand function arguments according to the provided TypeConverter and
58
- // / ValueDecomposer.
55
+ // / Expand function arguments according to the provided TypeConverter.
59
56
struct DecomposeCallGraphTypesForFuncArgs
60
- : public DecomposeCallGraphTypesOpConversionPattern<func::FuncOp> {
61
- using DecomposeCallGraphTypesOpConversionPattern::
62
- DecomposeCallGraphTypesOpConversionPattern;
57
+ : public OpConversionPattern<func::FuncOp> {
58
+ using OpConversionPattern::OpConversionPattern;
63
59
64
60
LogicalResult
65
61
matchAndRewrite (func::FuncOp op, OpAdaptor adaptor,
@@ -100,19 +96,22 @@ struct DecomposeCallGraphTypesForFuncArgs
100
96
// ===----------------------------------------------------------------------===//
101
97
102
98
namespace {
103
- // / Expand return operands according to the provided TypeConverter and
104
- // / ValueDecomposer.
99
+ // / Expand return operands according to the provided TypeConverter.
105
100
struct DecomposeCallGraphTypesForReturnOp
106
- : public DecomposeCallGraphTypesOpConversionPattern <ReturnOp> {
107
- using DecomposeCallGraphTypesOpConversionPattern::
108
- DecomposeCallGraphTypesOpConversionPattern;
101
+ : public OpConversionPattern <ReturnOp> {
102
+ using OpConversionPattern::OpConversionPattern;
103
+
109
104
LogicalResult
110
105
matchAndRewrite (ReturnOp op, OpAdaptor adaptor,
111
106
ConversionPatternRewriter &rewriter) const final {
112
107
SmallVector<Value, 2 > newOperands;
113
- for (Value operand : adaptor.getOperands ())
114
- decomposer.decomposeValue (rewriter, op.getLoc (), operand.getType (),
115
- operand, newOperands);
108
+ for (Value operand : adaptor.getOperands ()) {
109
+ // TODO: We can directly take the values from the adaptor once this is a
110
+ // 1:N conversion pattern.
111
+ llvm::append_range (newOperands,
112
+ decomposeValue (rewriter, operand.getLoc (), operand,
113
+ getTypeConverter ()));
114
+ }
116
115
rewriter.replaceOpWithNewOp <ReturnOp>(op, newOperands);
117
116
return success ();
118
117
}
@@ -124,22 +123,23 @@ struct DecomposeCallGraphTypesForReturnOp
124
123
// ===----------------------------------------------------------------------===//
125
124
126
125
namespace {
127
- // / Expand call op operands and results according to the provided TypeConverter
128
- // / and ValueDecomposer.
129
- struct DecomposeCallGraphTypesForCallOp
130
- : public DecomposeCallGraphTypesOpConversionPattern<CallOp> {
131
- using DecomposeCallGraphTypesOpConversionPattern::
132
- DecomposeCallGraphTypesOpConversionPattern;
126
+ // / Expand call op operands and results according to the provided TypeConverter.
127
+ struct DecomposeCallGraphTypesForCallOp : public OpConversionPattern <CallOp> {
128
+ using OpConversionPattern::OpConversionPattern;
133
129
134
130
LogicalResult
135
131
matchAndRewrite (CallOp op, OpAdaptor adaptor,
136
132
ConversionPatternRewriter &rewriter) const final {
137
133
138
134
// Create the operands list of the new `CallOp`.
139
135
SmallVector<Value, 2 > newOperands;
140
- for (Value operand : adaptor.getOperands ())
141
- decomposer.decomposeValue (rewriter, op.getLoc (), operand.getType (),
142
- operand, newOperands);
136
+ for (Value operand : adaptor.getOperands ()) {
137
+ // TODO: We can directly take the values from the adaptor once this is a
138
+ // 1:N conversion pattern.
139
+ llvm::append_range (newOperands,
140
+ decomposeValue (rewriter, operand.getLoc (), operand,
141
+ getTypeConverter ()));
142
+ }
143
143
144
144
// Create the new result types for the new `CallOp` and track the indices in
145
145
// the new call op's results that correspond to the old call op's results.
@@ -189,9 +189,8 @@ struct DecomposeCallGraphTypesForCallOp
189
189
190
190
void mlir::populateDecomposeCallGraphTypesPatterns (
191
191
MLIRContext *context, const TypeConverter &typeConverter,
192
- ValueDecomposer &decomposer, RewritePatternSet &patterns) {
192
+ RewritePatternSet &patterns) {
193
193
patterns
194
194
.add <DecomposeCallGraphTypesForCallOp, DecomposeCallGraphTypesForFuncArgs,
195
- DecomposeCallGraphTypesForReturnOp>(typeConverter, context,
196
- decomposer);
195
+ DecomposeCallGraphTypesForReturnOp>(typeConverter, context);
197
196
}
0 commit comments