Skip to content

Commit 7a7a055

Browse files
committed
[mlir] Argument and result attribute handling during inlining.
The revision adds the handleArgument and handleResult handlers that allow users of the inlining interface to implement argument and result conversions that take argument and result attributes into account. The motivating use cases for this revision are taken from the LLVM dialect inliner, which has to copy arguments that are marked as byval and that also has to consider zeroext / signext when converting integers. All type conversions are currently handled by the materializeCallConversion hook. It runs before isLegalToInline and supports only the introduction of a single cast operation since it may have to rollback. The new handlers run shortly before and after inlining and cannot fail. As a result, they can introduce more complex ir such as copying a struct argument. At the moment, the new hooks cannot be used to perform type conversions since all type conversions have to be done using the materializeCallConversion. A follow up revision will either relax this constraint or drop materializeCallConversion in favor of the new and more flexible handlers. The revision also extends the CallableOpInterface to provide access to the argument and result attributes if available. Differential Revision: https://reviews.llvm.org/D145582
1 parent 0ffea21 commit 7a7a055

File tree

7 files changed

+242
-1
lines changed

7 files changed

+242
-1
lines changed

mlir/docs/Interfaces.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -731,6 +731,8 @@ interface section goes as follows:
731731
* `CallableOpInterface` - Used to represent the target callee of call.
732732
- `Region * getCallableRegion()`
733733
- `ArrayRef<Type> getCallableResults()`
734+
- `ArrayAttr getCallableArgAttrs()`
735+
- `ArrayAttr getCallableResAttrs()`
734736

735737
##### RegionKindInterfaces
736738

mlir/include/mlir/Dialect/Func/IR/FuncOps.td

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -299,6 +299,18 @@ def FuncOp : Func_Op<"func", [
299299
/// executed.
300300
ArrayRef<Type> getCallableResults() { return getFunctionType().getResults(); }
301301

302+
/// Returns the argument attributes for all callable region arguments or
303+
/// null if there are none.
304+
::mlir::ArrayAttr getCallableArgAttrs() {
305+
return getArgAttrs().value_or(nullptr);
306+
}
307+
308+
/// Returns the result attributes for all callable region results or
309+
/// null if there are none.
310+
::mlir::ArrayAttr getCallableResAttrs() {
311+
return getResAttrs().value_or(nullptr);
312+
}
313+
302314
//===------------------------------------------------------------------===//
303315
// FunctionOpInterface Methods
304316
//===------------------------------------------------------------------===//

mlir/include/mlir/Interfaces/CallInterfaces.td

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,28 @@ def CallableOpInterface : OpInterface<"CallableOpInterface"> {
8484
}],
8585
"::llvm::ArrayRef<::mlir::Type>", "getCallableResults"
8686
>,
87+
InterfaceMethod<
88+
/*desc=*/[{
89+
Returns the argument attributes for all callable region arguments or
90+
null if there are none.
91+
}],
92+
/*retType=*/"::mlir::ArrayAttr",
93+
/*methodName=*/"getCallableArgAttrs",
94+
/*args=*/(ins),
95+
/*methodBody=*/"",
96+
/*defaultImplementation=*/[{ return {}; }]
97+
>,
98+
InterfaceMethod<
99+
/*desc=*/[{
100+
Returns the result attributes for all callable region results or null
101+
if there are none.
102+
}],
103+
/*retType=*/"::mlir::ArrayAttr",
104+
/*methodName=*/"getCallableResAttrs",
105+
/*args=*/(ins),
106+
/*methodBody=*/"",
107+
/*defaultImplementation=*/[{ return {}; }]
108+
>,
87109
];
88110
}
89111

mlir/include/mlir/Transforms/InliningUtils.h

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#ifndef MLIR_TRANSFORMS_INLININGUTILS_H
1414
#define MLIR_TRANSFORMS_INLININGUTILS_H
1515

16+
#include "mlir/IR/BuiltinAttributes.h"
1617
#include "mlir/IR/DialectInterface.h"
1718
#include "mlir/IR/Location.h"
1819
#include "mlir/IR/Region.h"
@@ -141,6 +142,40 @@ class DialectInlinerInterface
141142
return nullptr;
142143
}
143144

145+
/// Hook to transform the call arguments before using them to replace the
146+
/// callee arguments. It returns the transformation result or `argument`
147+
/// itself if the hook did not change anything. The type of the returned value
148+
/// has to match `targetType`, and the `argumentAttrs` dictionary is non-null
149+
/// even if no attribute is present. The hook is called after converting the
150+
/// callsite argument types using the materializeCallConversion callback, and
151+
/// right before inlining the callee region. Any operations created using the
152+
/// provided `builder` are inserted right before the inlined callee region.
153+
/// Example use cases are the insertion of copies for by value arguments, or
154+
/// integer conversions that require signedness information.
155+
virtual Value handleArgument(OpBuilder &builder, Operation *call,
156+
Operation *callable, Value argument,
157+
Type targetType,
158+
DictionaryAttr argumentAttrs) const {
159+
return argument;
160+
}
161+
162+
/// Hook to transform the callee results before using them to replace the call
163+
/// results. It returns the transformation result or the `result` itself if
164+
/// the hook did not change anything. The type of the returned values has to
165+
/// match `targetType`, and the `resultAttrs` dictionary is non-null even if
166+
/// no attribute is present. The hook is called right before handling
167+
/// terminators, and obtains the callee result before converting its type
168+
/// using the `materializeCallConversion` callback. Any operations created
169+
/// using the provided `builder` are inserted right after the inlined callee
170+
/// region. Example use cases are the insertion of copies for by value results
171+
/// or integer conversions that require signedness information.
172+
/// NOTE: This hook is invoked after inlining the `callable` region.
173+
virtual Value handleResult(OpBuilder &builder, Operation *call,
174+
Operation *callable, Value result, Type targetType,
175+
DictionaryAttr resultAttrs) const {
176+
return result;
177+
}
178+
144179
/// Process a set of blocks that have been inlined for a call. This callback
145180
/// is invoked before inlined terminator operations have been processed.
146181
virtual void processInlinedCallBlocks(
@@ -183,6 +218,15 @@ class InlinerInterface
183218
virtual void handleTerminator(Operation *op, Block *newDest) const;
184219
virtual void handleTerminator(Operation *op,
185220
ArrayRef<Value> valuesToRepl) const;
221+
222+
virtual Value handleArgument(OpBuilder &builder, Operation *call,
223+
Operation *callable, Value argument,
224+
Type targetType,
225+
DictionaryAttr argumentAttrs) const;
226+
virtual Value handleResult(OpBuilder &builder, Operation *call,
227+
Operation *callable, Value result, Type targetType,
228+
DictionaryAttr resultAttrs) const;
229+
186230
virtual void processInlinedCallBlocks(
187231
Operation *call, iterator_range<Region::iterator> inlinedBlocks) const;
188232
};

mlir/lib/Transforms/Utils/InliningUtils.cpp

Lines changed: 107 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,26 @@ void InlinerInterface::handleTerminator(Operation *op,
103103
handler->handleTerminator(op, valuesToRepl);
104104
}
105105

106+
Value InlinerInterface::handleArgument(OpBuilder &builder, Operation *call,
107+
Operation *callable, Value argument,
108+
Type targetType,
109+
DictionaryAttr argumentAttrs) const {
110+
auto *handler = getInterfaceFor(call);
111+
assert(handler && "expected valid dialect handler");
112+
return handler->handleArgument(builder, call, callable, argument, targetType,
113+
argumentAttrs);
114+
}
115+
116+
Value InlinerInterface::handleResult(OpBuilder &builder, Operation *call,
117+
Operation *callable, Value result,
118+
Type targetType,
119+
DictionaryAttr resultAttrs) const {
120+
auto *handler = getInterfaceFor(call);
121+
assert(handler && "expected valid dialect handler");
122+
return handler->handleResult(builder, call, callable, result, targetType,
123+
resultAttrs);
124+
}
125+
106126
void InlinerInterface::processInlinedCallBlocks(
107127
Operation *call, iterator_range<Region::iterator> inlinedBlocks) const {
108128
auto *handler = getInterfaceFor(call);
@@ -141,6 +161,77 @@ static bool isLegalToInline(InlinerInterface &interface, Region *src,
141161
// Inline Methods
142162
//===----------------------------------------------------------------------===//
143163

164+
static void handleArgumentImpl(InlinerInterface &interface, OpBuilder &builder,
165+
CallOpInterface call,
166+
CallableOpInterface callable,
167+
IRMapping &mapper) {
168+
if (!call || !callable)
169+
return;
170+
171+
// Unpack the argument attributes if there are any.
172+
SmallVector<DictionaryAttr> argAttrs(
173+
callable.getCallableRegion()->getNumArguments(),
174+
builder.getDictionaryAttr({}));
175+
if (ArrayAttr arrayAttr = callable.getCallableArgAttrs()) {
176+
assert(arrayAttr.size() == argAttrs.size());
177+
for (auto [idx, attr] : llvm::enumerate(arrayAttr))
178+
argAttrs[idx] = dyn_cast<DictionaryAttr>(attr);
179+
}
180+
181+
// Run the argument attribute handler for the given argument and attribute.
182+
for (auto [blockArg, argAttr] :
183+
llvm::zip(callable.getCallableRegion()->getArguments(), argAttrs)) {
184+
Value newArgument = interface.handleArgument(builder, call, callable,
185+
mapper.lookup(blockArg),
186+
blockArg.getType(), argAttr);
187+
assert(newArgument.getType() == blockArg.getType() &&
188+
"expected the handled argument type to match the target type");
189+
190+
// Update the mapping to point the new argument returned by the handler.
191+
mapper.map(blockArg, newArgument);
192+
}
193+
}
194+
195+
static void handleResultImpl(InlinerInterface &interface, OpBuilder &builder,
196+
CallOpInterface call, CallableOpInterface callable,
197+
ValueRange results) {
198+
if (!call || !callable)
199+
return;
200+
201+
// Unpack the result attributes if there are any.
202+
SmallVector<DictionaryAttr> resAttrs(results.size(),
203+
builder.getDictionaryAttr({}));
204+
if (ArrayAttr arrayAttr = callable.getCallableResAttrs()) {
205+
assert(arrayAttr.size() == resAttrs.size());
206+
for (auto [idx, attr] : llvm::enumerate(arrayAttr))
207+
resAttrs[idx] = dyn_cast<DictionaryAttr>(attr);
208+
}
209+
210+
// Run the result attribute handler for the given result and attribute.
211+
SmallVector<DictionaryAttr> resultAttributes;
212+
for (auto [result, resAttr] : llvm::zip(results, resAttrs)) {
213+
// Store the original result users before running the handler.
214+
DenseSet<Operation *> resultUsers;
215+
for (Operation *user : result.getUsers())
216+
resultUsers.insert(user);
217+
218+
// TODO: Use the type of the call result to replace once the hook can be
219+
// used for type conversions. At the moment, all type conversions have to be
220+
// done using materializeCallConversion.
221+
Type targetType = result.getType();
222+
223+
Value newResult = interface.handleResult(builder, call, callable, result,
224+
targetType, resAttr);
225+
assert(newResult.getType() == targetType &&
226+
"expected the handled result type to match the target type");
227+
228+
// Replace the result uses except for the ones introduce by the handler.
229+
result.replaceUsesWithIf(newResult, [&](OpOperand &operand) {
230+
return resultUsers.count(operand.getOwner());
231+
});
232+
}
233+
}
234+
144235
static LogicalResult
145236
inlineRegionImpl(InlinerInterface &interface, Region *src, Block *inlineBlock,
146237
Block::iterator inlinePoint, IRMapping &mapper,
@@ -166,6 +257,11 @@ inlineRegionImpl(InlinerInterface &interface, Region *src, Block *inlineBlock,
166257
mapper))
167258
return failure();
168259

260+
// Run the argument attribute handler before inlining the callable region.
261+
OpBuilder builder(inlineBlock, inlinePoint);
262+
auto callable = dyn_cast<CallableOpInterface>(src->getParentOp());
263+
handleArgumentImpl(interface, builder, call, callable, mapper);
264+
169265
// Check to see if the region is being cloned, or moved inline. In either
170266
// case, move the new blocks after the 'insertBlock' to improve IR
171267
// readability.
@@ -199,8 +295,13 @@ inlineRegionImpl(InlinerInterface &interface, Region *src, Block *inlineBlock,
199295

200296
// Handle the case where only a single block was inlined.
201297
if (std::next(newBlocks.begin()) == newBlocks.end()) {
298+
// Run the result attribute handler on the terminator operands.
299+
Operation *firstBlockTerminator = firstNewBlock->getTerminator();
300+
builder.setInsertionPoint(firstBlockTerminator);
301+
handleResultImpl(interface, builder, call, callable,
302+
firstBlockTerminator->getOperands());
303+
202304
// Have the interface handle the terminator of this block.
203-
auto *firstBlockTerminator = firstNewBlock->getTerminator();
204305
interface.handleTerminator(firstBlockTerminator,
205306
llvm::to_vector<6>(resultsToReplace));
206307
firstBlockTerminator->erase();
@@ -218,6 +319,11 @@ inlineRegionImpl(InlinerInterface &interface, Region *src, Block *inlineBlock,
218319
resultToRepl.value().getLoc()));
219320
}
220321

322+
// Run the result attribute handler on the post insertion block arguments.
323+
builder.setInsertionPointToStart(postInsertBlock);
324+
handleResultImpl(interface, builder, call, callable,
325+
postInsertBlock->getArguments());
326+
221327
/// Handle the terminators for each of the new blocks.
222328
for (auto &newBlock : newBlocks)
223329
interface.handleTerminator(newBlock.getTerminator(), postInsertBlock);

mlir/test/Transforms/inlining.mlir

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -226,3 +226,40 @@ func.func @func_with_block_args_location_callee2(%arg0 : i32) {
226226
call @func_with_block_args_location(%arg0) : (i32) -> ()
227227
return
228228
}
229+
230+
// Check that we can handle argument and result attributes.
231+
func.func @handle_attr_callee_fn_multi_arg(%arg0 : i16, %arg1 : i16 {"test.handle_argument"}) -> (i16 {"test.handle_result"}, i16) {
232+
%0 = arith.addi %arg0, %arg1 : i16
233+
%1 = arith.subi %arg0, %arg1 : i16
234+
return %0, %1 : i16, i16
235+
}
236+
func.func @handle_attr_callee_fn(%arg0 : i32 {"test.handle_argument"}) -> (i32 {"test.handle_result"}) {
237+
return %arg0 : i32
238+
}
239+
240+
// CHECK-LABEL: func @inline_handle_attr_call
241+
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
242+
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
243+
func.func @inline_handle_attr_call(%arg0 : i16, %arg1 : i16) -> (i16, i16) {
244+
245+
// CHECK: %[[CHANGE_INPUT:.*]] = "test.type_changer"(%[[ARG1]]) : (i16) -> i16
246+
// CHECK: %[[SUM:.*]] = arith.addi %[[ARG0]], %[[CHANGE_INPUT]]
247+
// CHECK: %[[DIFF:.*]] = arith.subi %[[ARG0]], %[[CHANGE_INPUT]]
248+
// CHECK: %[[CHANGE_RESULT:.*]] = "test.type_changer"(%[[SUM]]) : (i16) -> i16
249+
// CHECK-NEXT: return %[[CHANGE_RESULT]], %[[DIFF]]
250+
%res0, %res1 = "test.conversion_call_op"(%arg0, %arg1) { callee=@handle_attr_callee_fn_multi_arg } : (i16, i16) -> (i16, i16)
251+
return %res0, %res1 : i16, i16
252+
}
253+
254+
// CHECK-LABEL: func @inline_convert_and_handle_attr_call
255+
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
256+
func.func @inline_convert_and_handle_attr_call(%arg0 : i16) -> (i16) {
257+
258+
// CHECK: %[[CAST_INPUT:.*]] = "test.cast"(%[[ARG0]]) : (i16) -> i32
259+
// CHECK: %[[CHANGE_INPUT:.*]] = "test.type_changer"(%[[CAST_INPUT]]) : (i32) -> i32
260+
// CHECK: %[[CHANGE_RESULT:.*]] = "test.type_changer"(%[[CHANGE_INPUT]]) : (i32) -> i32
261+
// CHECK: %[[CAST_RESULT:.*]] = "test.cast"(%[[CHANGE_RESULT]]) : (i32) -> i16
262+
// CHECK: return %[[CAST_RESULT]]
263+
%res = "test.conversion_call_op"(%arg0) { callee=@handle_attr_callee_fn } : (i16) -> (i16)
264+
return %res : i16
265+
}

mlir/test/lib/Dialect/Test/TestDialect.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,24 @@ struct TestInlinerInterface : public DialectInlinerInterface {
245245
return builder.create<TestCastOp>(conversionLoc, resultType, input);
246246
}
247247

248+
Value handleArgument(OpBuilder &builder, Operation *call, Operation *callable,
249+
Value argument, Type targetType,
250+
DictionaryAttr argumentAttrs) const final {
251+
if (!argumentAttrs.contains("test.handle_argument"))
252+
return argument;
253+
return builder.create<TestTypeChangerOp>(call->getLoc(), targetType,
254+
argument);
255+
}
256+
257+
Value handleResult(OpBuilder &builder, Operation *call, Operation *callable,
258+
Value result, Type targetType,
259+
DictionaryAttr resultAttrs) const final {
260+
if (!resultAttrs.contains("test.handle_result"))
261+
return result;
262+
return builder.create<TestTypeChangerOp>(call->getLoc(), targetType,
263+
result);
264+
}
265+
248266
void processInlinedCallBlocks(
249267
Operation *call,
250268
iterator_range<Region::iterator> inlinedBlocks) const final {

0 commit comments

Comments
 (0)