diff --git a/mlir/include/mlir/Interfaces/CallInterfaces.h b/mlir/include/mlir/Interfaces/CallInterfaces.h index 7dbcddb01b241..58c37f01caef0 100644 --- a/mlir/include/mlir/Interfaces/CallInterfaces.h +++ b/mlir/include/mlir/Interfaces/CallInterfaces.h @@ -23,10 +23,20 @@ namespace mlir { struct CallInterfaceCallable : public PointerUnion { using PointerUnion::PointerUnion; }; -} // namespace mlir -/// Include the generated interface declarations. -#include "mlir/Interfaces/CallInterfaces.h.inc" +class CallOpInterface; + +namespace call_interface_impl { + +/// Resolve the callable operation for given callee to a CallableOpInterface, or +/// nullptr if a valid callable was not resolved. `symbolTable` is an optional +/// parameter that will allow for using a cached symbol table for symbol lookups +/// instead of performing an O(N) scan. +Operation *resolveCallable(CallOpInterface call, SymbolTableCollection *symbolTable = nullptr); + +} // namespace call_interface_impl + +} // namespace mlir namespace llvm { @@ -41,4 +51,7 @@ struct CastInfo } // namespace llvm +/// Include the generated interface declarations. +#include "mlir/Interfaces/CallInterfaces.h.inc" + #endif // MLIR_INTERFACES_CALLINTERFACES_H diff --git a/mlir/include/mlir/Interfaces/CallInterfaces.td b/mlir/include/mlir/Interfaces/CallInterfaces.td index 752de74e6e4d7..c6002da0d491c 100644 --- a/mlir/include/mlir/Interfaces/CallInterfaces.td +++ b/mlir/include/mlir/Interfaces/CallInterfaces.td @@ -59,17 +59,29 @@ def CallOpInterface : OpInterface<"CallOpInterface"> { Returns the operands within this call that are used as arguments to the callee as a mutable range. }], - "::mlir::MutableOperandRange", "getArgOperandsMutable">, + "::mlir::MutableOperandRange", "getArgOperandsMutable" + >, + InterfaceMethod<[{ + Resolve the callable operation for given callee to a + CallableOpInterface, or nullptr if a valid callable was not resolved. + `symbolTable` parameter allow for using a cached symbol table for symbol + lookups instead of performing an O(N) scan. + }], + "::mlir::Operation *", "resolveCallableInTable", (ins "::mlir::SymbolTableCollection *":$symbolTable), + /*methodBody=*/[{}], /*defaultImplementation=*/[{ + return ::mlir::call_interface_impl::resolveCallable($_op, symbolTable); + }] + >, + InterfaceMethod<[{ + Resolve the callable operation for given callee to a + CallableOpInterface, or nullptr if a valid callable was not resolved. + }], + "::mlir::Operation *", "resolveCallable", (ins), + /*methodBody=*/[{}], /*defaultImplementation=*/[{ + return ::mlir::call_interface_impl::resolveCallable($_op); + }] + > ]; - - let extraClassDeclaration = [{ - /// Resolve the callable operation for given callee to a - /// CallableOpInterface, or nullptr if a valid callable was not resolved. - /// `symbolTable` is an optional parameter that will allow for using a - /// cached symbol table for symbol lookups instead of performing an O(N) - /// scan. - ::mlir::Operation *resolveCallable(::mlir::SymbolTableCollection *symbolTable = nullptr); - }]; } /// Interface for callable operations. diff --git a/mlir/lib/Analysis/CallGraph.cpp b/mlir/lib/Analysis/CallGraph.cpp index ccd4676632136..780c7caee767c 100644 --- a/mlir/lib/Analysis/CallGraph.cpp +++ b/mlir/lib/Analysis/CallGraph.cpp @@ -146,7 +146,7 @@ CallGraphNode *CallGraph::lookupNode(Region *region) const { CallGraphNode * CallGraph::resolveCallable(CallOpInterface call, SymbolTableCollection &symbolTable) const { - Operation *callable = call.resolveCallable(&symbolTable); + Operation *callable = call.resolveCallableInTable(&symbolTable); if (auto callableOp = dyn_cast_or_null(callable)) if (auto *node = lookupNode(callableOp.getCallableRegion())) return node; diff --git a/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp index 532480b6fad57..beb68018a3b16 100644 --- a/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp +++ b/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp @@ -297,7 +297,7 @@ LogicalResult DeadCodeAnalysis::visit(ProgramPoint point) { } void DeadCodeAnalysis::visitCallOperation(CallOpInterface call) { - Operation *callableOp = call.resolveCallable(&symbolTable); + Operation *callableOp = call.resolveCallableInTable(&symbolTable); // A call to a externally-defined callable has unknown predecessors. const auto isExternalCallable = [this](Operation *op) { diff --git a/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp index 37f4ceaaa56ce..300c6e5f9b891 100644 --- a/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp +++ b/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp @@ -284,7 +284,7 @@ void AbstractDenseBackwardDataFlowAnalysis::visitCallOperation( CallOpInterface call, const AbstractDenseLattice &after, AbstractDenseLattice *before) { // Find the callee. - Operation *callee = call.resolveCallable(&symbolTable); + Operation *callee = call.resolveCallableInTable(&symbolTable); auto callable = dyn_cast_or_null(callee); // No region means the callee is only declared in this module. diff --git a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp index 4a73f21a18aae..1bd6defef90be 100644 --- a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp +++ b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp @@ -442,7 +442,7 @@ AbstractSparseBackwardDataFlowAnalysis::visitOperation(Operation *op) { // For function calls, connect the arguments of the entry blocks to the // operands of the call op that are forwarded to these arguments. if (auto call = dyn_cast(op)) { - Operation *callableOp = call.resolveCallable(&symbolTable); + Operation *callableOp = call.resolveCallableInTable(&symbolTable); if (auto callable = dyn_cast_or_null(callableOp)) { // Not all operands of a call op forward to arguments. Such operands are // stored in `unaccounted`. diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp index ca5d0688b5b59..b973618004497 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp @@ -824,7 +824,7 @@ FailureOr BufferDeallocation::handleInterface(CallOpInterface op) { // the function is referenced by SSA value instead of a Symbol, it's assumed // to be public. (And we cannot easily change the type of the SSA value // anyway.) - Operation *funcOp = op.resolveCallable(state.getSymbolTable()); + Operation *funcOp = op.resolveCallableInTable(state.getSymbolTable()); bool isPrivate = false; if (auto symbol = dyn_cast_or_null(funcOp)) isPrivate = symbol.isPrivate() && !symbol.isDeclaration(); diff --git a/mlir/lib/Interfaces/CallInterfaces.cpp b/mlir/lib/Interfaces/CallInterfaces.cpp index 455684d8e2ea7..47f8021f50cd2 100644 --- a/mlir/lib/Interfaces/CallInterfaces.cpp +++ b/mlir/lib/Interfaces/CallInterfaces.cpp @@ -14,21 +14,17 @@ using namespace mlir; // CallOpInterface //===----------------------------------------------------------------------===// -/// Resolve the callable operation for given callee to a CallableOpInterface, or -/// nullptr if a valid callable was not resolved. `symbolTable` is an optional -/// parameter that will allow for using a cached symbol table for symbol lookups -/// instead of performing an O(N) scan. Operation * -CallOpInterface::resolveCallable(SymbolTableCollection *symbolTable) { - CallInterfaceCallable callable = getCallableForCallee(); +call_interface_impl::resolveCallable(CallOpInterface call, SymbolTableCollection *symbolTable) { + CallInterfaceCallable callable = call.getCallableForCallee(); if (auto symbolVal = dyn_cast(callable)) return symbolVal.getDefiningOp(); // If the callable isn't a value, lookup the symbol reference. auto symbolRef = callable.get(); if (symbolTable) - return symbolTable->lookupNearestSymbolFrom(getOperation(), symbolRef); - return SymbolTable::lookupNearestSymbolFrom(getOperation(), symbolRef); + return symbolTable->lookupNearestSymbolFrom(call.getOperation(), symbolRef); + return SymbolTable::lookupNearestSymbolFrom(call.getOperation(), symbolRef); } //===----------------------------------------------------------------------===//