Skip to content

Commit d790a21

Browse files
committed
[mlir] Add getArgOperandsMutable method to CallOpInterface
Add a method to the CallOpInterface to get a mutable operand range over the function arguments. This allows to add, remove, or change the type of call arguments in a generic manner without having to assume that the argument operand range is at the end of the operand list, or having to type switch on all supported concrete operation kinds. Alternatively, a new OpInterface could be added which inherits from CallOpInterface and appends it with the mutable variants of the base interface. There will be two users of this new function in the beginning: (1) A few passes in the Arc dialect in CIRCT already use a downstream implementation of the alternative case mentioned above: https://github.com/llvm/circt/blob/main/include/circt/Dialect/Arc/ArcInterfaces.td#L15 (2) The BufferDeallocation pass will be modified to be able to pass ownership of memrefs to called private functions if the caller does not need the memref anymore by appending the function argument list with a boolean value per memref, thus enabling earlier deallocation of the memref which can lead to lower peak memory usage. Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D156675
1 parent c2093b8 commit d790a21

File tree

13 files changed

+70
-1
lines changed

13 files changed

+70
-1
lines changed

flang/include/flang/Optimizer/Dialect/FIROps.td

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2347,6 +2347,12 @@ def fir_CallOp : fir_Op<"call",
23472347
return {arg_operand_begin() + 1, arg_operand_end()};
23482348
}
23492349

2350+
mlir::MutableOperandRange getArgOperandsMutable() {
2351+
if ((*this)->getAttrOfType<mlir::SymbolRefAttr>(getCalleeAttrName()))
2352+
return getArgsMutable();
2353+
return mlir::MutableOperandRange(*this, 1, getArgs().size() - 1);
2354+
}
2355+
23502356
operand_iterator arg_operand_begin() { return operand_begin(); }
23512357
operand_iterator arg_operand_end() { return operand_end(); }
23522358

mlir/examples/toy/Ch4/mlir/Dialect.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -348,6 +348,12 @@ void GenericCallOp::setCalleeFromCallable(CallInterfaceCallable callee) {
348348
/// call interface.
349349
Operation::operand_range GenericCallOp::getArgOperands() { return getInputs(); }
350350

351+
/// Get the argument operands to the called function as a mutable range, this is
352+
/// required by the call interface.
353+
MutableOperandRange GenericCallOp::getArgOperandsMutable() {
354+
return getInputsMutable();
355+
}
356+
351357
//===----------------------------------------------------------------------===//
352358
// MulOp
353359
//===----------------------------------------------------------------------===//

mlir/examples/toy/Ch5/mlir/Dialect.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -348,6 +348,12 @@ void GenericCallOp::setCalleeFromCallable(CallInterfaceCallable callee) {
348348
/// call interface.
349349
Operation::operand_range GenericCallOp::getArgOperands() { return getInputs(); }
350350

351+
/// Get the argument operands to the called function as a mutable range, this is
352+
/// required by the call interface.
353+
MutableOperandRange GenericCallOp::getArgOperandsMutable() {
354+
return getInputsMutable();
355+
}
356+
351357
//===----------------------------------------------------------------------===//
352358
// MulOp
353359
//===----------------------------------------------------------------------===//

mlir/examples/toy/Ch6/mlir/Dialect.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -348,6 +348,12 @@ void GenericCallOp::setCalleeFromCallable(CallInterfaceCallable callee) {
348348
/// call interface.
349349
Operation::operand_range GenericCallOp::getArgOperands() { return getInputs(); }
350350

351+
/// Get the argument operands to the called function as a mutable range, this is
352+
/// required by the call interface.
353+
MutableOperandRange GenericCallOp::getArgOperandsMutable() {
354+
return getInputsMutable();
355+
}
356+
351357
//===----------------------------------------------------------------------===//
352358
// MulOp
353359
//===----------------------------------------------------------------------===//

mlir/examples/toy/Ch7/mlir/Dialect.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -377,6 +377,12 @@ void GenericCallOp::setCalleeFromCallable(CallInterfaceCallable callee) {
377377
/// call interface.
378378
Operation::operand_range GenericCallOp::getArgOperands() { return getInputs(); }
379379

380+
/// Get the argument operands to the called function as a mutable range, this is
381+
/// required by the call interface.
382+
MutableOperandRange GenericCallOp::getArgOperandsMutable() {
383+
return getInputsMutable();
384+
}
385+
380386
//===----------------------------------------------------------------------===//
381387
// MulOp
382388
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/Async/IR/AsyncOps.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -264,6 +264,10 @@ def Async_CallOp : Async_Op<"call",
264264
return {arg_operand_begin(), arg_operand_end()};
265265
}
266266

267+
MutableOperandRange getArgOperandsMutable() {
268+
return getOperandsMutable();
269+
}
270+
267271
operand_iterator arg_operand_begin() { return operand_begin(); }
268272
operand_iterator arg_operand_end() { return operand_end(); }
269273

mlir/include/mlir/Dialect/Func/IR/FuncOps.td

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,10 @@ def CallOp : Func_Op<"call",
8383
return {arg_operand_begin(), arg_operand_end()};
8484
}
8585

86+
MutableOperandRange getArgOperandsMutable() {
87+
return getOperandsMutable();
88+
}
89+
8690
operand_iterator arg_operand_begin() { return operand_begin(); }
8791
operand_iterator arg_operand_end() { return operand_end(); }
8892

@@ -152,6 +156,10 @@ def CallIndirectOp : Func_Op<"call_indirect", [
152156
return {arg_operand_begin(), arg_operand_end()};
153157
}
154158

159+
MutableOperandRange getArgOperandsMutable() {
160+
return getCalleeOperandsMutable();
161+
}
162+
155163
operand_iterator arg_operand_begin() { return ++operand_begin(); }
156164
operand_iterator arg_operand_end() { return operand_end(); }
157165

mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -616,7 +616,7 @@ def LLVM_CallOp : LLVM_MemAccessOpBase<"call",
616616
}];
617617

618618
dag args = (ins OptionalAttr<FlatSymbolRefAttr>:$callee,
619-
Variadic<LLVM_Type>,
619+
Variadic<LLVM_Type>:$callee_operands,
620620
DefaultValuedAttr<LLVM_FastmathFlagsAttr,
621621
"{}">:$fastmathFlags,
622622
OptionalAttr<DenseI32ArrayAttr>:$branch_weights);

mlir/include/mlir/Dialect/Transform/IR/TransformOps.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -632,6 +632,10 @@ def IncludeOp : TransformDialectOp<"include",
632632
::mlir::Operation::operand_range getArgOperands() {
633633
return getOperands();
634634
}
635+
636+
::mlir::MutableOperandRange getArgOperandsMutable() {
637+
return getOperandsMutable();
638+
}
635639
}];
636640
}
637641

mlir/include/mlir/Interfaces/CallInterfaces.td

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,11 @@ def CallOpInterface : OpInterface<"CallOpInterface"> {
5555
}],
5656
"::mlir::Operation::operand_range", "getArgOperands"
5757
>,
58+
InterfaceMethod<[{
59+
Returns the operands within this call that are used as arguments to the
60+
callee as a mutable range.
61+
}],
62+
"::mlir::MutableOperandRange", "getArgOperandsMutable">,
5863
];
5964

6065
let extraClassDeclaration = [{

mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1003,6 +1003,11 @@ Operation::operand_range CallOp::getArgOperands() {
10031003
return getOperands().drop_front(getCallee().has_value() ? 0 : 1);
10041004
}
10051005

1006+
MutableOperandRange CallOp::getArgOperandsMutable() {
1007+
return MutableOperandRange(*this, getCallee().has_value() ? 0 : 1,
1008+
getCalleeOperands().size());
1009+
}
1010+
10061011
LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
10071012
if (getNumResults() > 1)
10081013
return emitOpError("must have 0 or 1 result");
@@ -1237,6 +1242,11 @@ Operation::operand_range InvokeOp::getArgOperands() {
12371242
return getOperands().drop_front(getCallee().has_value() ? 0 : 1);
12381243
}
12391244

1245+
MutableOperandRange InvokeOp::getArgOperandsMutable() {
1246+
return MutableOperandRange(*this, getCallee().has_value() ? 0 : 1,
1247+
getCalleeOperands().size());
1248+
}
1249+
12401250
LogicalResult InvokeOp::verify() {
12411251
if (getNumResults() > 1)
12421252
return emitOpError("must have 0 or 1 result");

mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,10 @@ Operation::operand_range FunctionCallOp::getArgOperands() {
208208
return getArguments();
209209
}
210210

211+
MutableOperandRange FunctionCallOp::getArgOperandsMutable() {
212+
return getArgumentsMutable();
213+
}
214+
211215
//===----------------------------------------------------------------------===//
212216
// spirv.mlir.loop
213217
//===----------------------------------------------------------------------===//

mlir/test/lib/Dialect/Test/TestDialect.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1263,6 +1263,10 @@ Operation::operand_range TestCallAndStoreOp::getArgOperands() {
12631263
return getCalleeOperands();
12641264
}
12651265

1266+
MutableOperandRange TestCallAndStoreOp::getArgOperandsMutable() {
1267+
return getCalleeOperandsMutable();
1268+
}
1269+
12661270
void TestStoreWithARegion::getSuccessorRegions(
12671271
std::optional<unsigned> index, ArrayRef<Attribute> operands,
12681272
SmallVectorImpl<RegionSuccessor> &regions) {

0 commit comments

Comments
 (0)