@@ -103,6 +103,26 @@ void InlinerInterface::handleTerminator(Operation *op,
103
103
handler->handleTerminator (op, valuesToRepl);
104
104
}
105
105
106
+ Value InlinerInterface::handleArgument (OpBuilder &builder, Operation *call,
107
+ Operation *callable, Value argument,
108
+ Type targetType,
109
+ DictionaryAttr argumentAttrs) const {
110
+ auto *handler = getInterfaceFor (call);
111
+ assert (handler && " expected valid dialect handler" );
112
+ return handler->handleArgument (builder, call, callable, argument, targetType,
113
+ argumentAttrs);
114
+ }
115
+
116
+ Value InlinerInterface::handleResult (OpBuilder &builder, Operation *call,
117
+ Operation *callable, Value result,
118
+ Type targetType,
119
+ DictionaryAttr resultAttrs) const {
120
+ auto *handler = getInterfaceFor (call);
121
+ assert (handler && " expected valid dialect handler" );
122
+ return handler->handleResult (builder, call, callable, result, targetType,
123
+ resultAttrs);
124
+ }
125
+
106
126
void InlinerInterface::processInlinedCallBlocks (
107
127
Operation *call, iterator_range<Region::iterator> inlinedBlocks) const {
108
128
auto *handler = getInterfaceFor (call);
@@ -141,6 +161,77 @@ static bool isLegalToInline(InlinerInterface &interface, Region *src,
141
161
// Inline Methods
142
162
// ===----------------------------------------------------------------------===//
143
163
164
+ static void handleArgumentImpl (InlinerInterface &interface, OpBuilder &builder,
165
+ CallOpInterface call,
166
+ CallableOpInterface callable,
167
+ IRMapping &mapper) {
168
+ if (!call || !callable)
169
+ return ;
170
+
171
+ // Unpack the argument attributes if there are any.
172
+ SmallVector<DictionaryAttr> argAttrs (
173
+ callable.getCallableRegion ()->getNumArguments (),
174
+ builder.getDictionaryAttr ({}));
175
+ if (ArrayAttr arrayAttr = callable.getCallableArgAttrs ()) {
176
+ assert (arrayAttr.size () == argAttrs.size ());
177
+ for (auto [idx, attr] : llvm::enumerate (arrayAttr))
178
+ argAttrs[idx] = dyn_cast<DictionaryAttr>(attr);
179
+ }
180
+
181
+ // Run the argument attribute handler for the given argument and attribute.
182
+ for (auto [blockArg, argAttr] :
183
+ llvm::zip (callable.getCallableRegion ()->getArguments (), argAttrs)) {
184
+ Value newArgument = interface.handleArgument (builder, call, callable,
185
+ mapper.lookup (blockArg),
186
+ blockArg.getType (), argAttr);
187
+ assert (newArgument.getType () == blockArg.getType () &&
188
+ " expected the handled argument type to match the target type" );
189
+
190
+ // Update the mapping to point the new argument returned by the handler.
191
+ mapper.map (blockArg, newArgument);
192
+ }
193
+ }
194
+
195
+ static void handleResultImpl (InlinerInterface &interface, OpBuilder &builder,
196
+ CallOpInterface call, CallableOpInterface callable,
197
+ ValueRange results) {
198
+ if (!call || !callable)
199
+ return ;
200
+
201
+ // Unpack the result attributes if there are any.
202
+ SmallVector<DictionaryAttr> resAttrs (results.size (),
203
+ builder.getDictionaryAttr ({}));
204
+ if (ArrayAttr arrayAttr = callable.getCallableResAttrs ()) {
205
+ assert (arrayAttr.size () == resAttrs.size ());
206
+ for (auto [idx, attr] : llvm::enumerate (arrayAttr))
207
+ resAttrs[idx] = dyn_cast<DictionaryAttr>(attr);
208
+ }
209
+
210
+ // Run the result attribute handler for the given result and attribute.
211
+ SmallVector<DictionaryAttr> resultAttributes;
212
+ for (auto [result, resAttr] : llvm::zip (results, resAttrs)) {
213
+ // Store the original result users before running the handler.
214
+ DenseSet<Operation *> resultUsers;
215
+ for (Operation *user : result.getUsers ())
216
+ resultUsers.insert (user);
217
+
218
+ // TODO: Use the type of the call result to replace once the hook can be
219
+ // used for type conversions. At the moment, all type conversions have to be
220
+ // done using materializeCallConversion.
221
+ Type targetType = result.getType ();
222
+
223
+ Value newResult = interface.handleResult (builder, call, callable, result,
224
+ targetType, resAttr);
225
+ assert (newResult.getType () == targetType &&
226
+ " expected the handled result type to match the target type" );
227
+
228
+ // Replace the result uses except for the ones introduce by the handler.
229
+ result.replaceUsesWithIf (newResult, [&](OpOperand &operand) {
230
+ return resultUsers.count (operand.getOwner ());
231
+ });
232
+ }
233
+ }
234
+
144
235
static LogicalResult
145
236
inlineRegionImpl (InlinerInterface &interface, Region *src, Block *inlineBlock,
146
237
Block::iterator inlinePoint, IRMapping &mapper,
@@ -166,6 +257,11 @@ inlineRegionImpl(InlinerInterface &interface, Region *src, Block *inlineBlock,
166
257
mapper))
167
258
return failure ();
168
259
260
+ // Run the argument attribute handler before inlining the callable region.
261
+ OpBuilder builder (inlineBlock, inlinePoint);
262
+ auto callable = dyn_cast<CallableOpInterface>(src->getParentOp ());
263
+ handleArgumentImpl (interface, builder, call, callable, mapper);
264
+
169
265
// Check to see if the region is being cloned, or moved inline. In either
170
266
// case, move the new blocks after the 'insertBlock' to improve IR
171
267
// readability.
@@ -199,8 +295,13 @@ inlineRegionImpl(InlinerInterface &interface, Region *src, Block *inlineBlock,
199
295
200
296
// Handle the case where only a single block was inlined.
201
297
if (std::next (newBlocks.begin ()) == newBlocks.end ()) {
298
+ // Run the result attribute handler on the terminator operands.
299
+ Operation *firstBlockTerminator = firstNewBlock->getTerminator ();
300
+ builder.setInsertionPoint (firstBlockTerminator);
301
+ handleResultImpl (interface, builder, call, callable,
302
+ firstBlockTerminator->getOperands ());
303
+
202
304
// Have the interface handle the terminator of this block.
203
- auto *firstBlockTerminator = firstNewBlock->getTerminator ();
204
305
interface.handleTerminator (firstBlockTerminator,
205
306
llvm::to_vector<6 >(resultsToReplace));
206
307
firstBlockTerminator->erase ();
@@ -218,6 +319,11 @@ inlineRegionImpl(InlinerInterface &interface, Region *src, Block *inlineBlock,
218
319
resultToRepl.value ().getLoc ()));
219
320
}
220
321
322
+ // Run the result attribute handler on the post insertion block arguments.
323
+ builder.setInsertionPointToStart (postInsertBlock);
324
+ handleResultImpl (interface, builder, call, callable,
325
+ postInsertBlock->getArguments ());
326
+
221
327
// / Handle the terminators for each of the new blocks.
222
328
for (auto &newBlock : newBlocks)
223
329
interface.handleTerminator (newBlock.getTerminator (), postInsertBlock);
0 commit comments