From 13fd7a28cd7bd0e06b61c3f56c563e28c6104c7e Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Tue, 10 Sep 2024 10:23:27 +0200 Subject: [PATCH] =?UTF-8?q?Revert=20"[MLIR]=20Make=20`resolveCallable`=20c?= =?UTF-8?q?ustomizable=20in=20`CallOpInterface`=20(#100=E2=80=A6"?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This reverts commit 958f59d90fbc1cb2ae186246c8a64fec9e3ecd6e. --- mlir/include/mlir/Interfaces/CallInterfaces.h | 19 ++--------- .../include/mlir/Interfaces/CallInterfaces.td | 32 ++++++------------- mlir/lib/Analysis/CallGraph.cpp | 2 +- .../Analysis/DataFlow/DeadCodeAnalysis.cpp | 2 +- mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp | 2 +- mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp | 2 +- .../OwnershipBasedBufferDeallocation.cpp | 2 +- mlir/lib/Interfaces/CallInterfaces.cpp | 12 ++++--- 8 files changed, 26 insertions(+), 47 deletions(-) diff --git a/mlir/include/mlir/Interfaces/CallInterfaces.h b/mlir/include/mlir/Interfaces/CallInterfaces.h index 58c37f01caef0..7dbcddb01b241 100644 --- a/mlir/include/mlir/Interfaces/CallInterfaces.h +++ b/mlir/include/mlir/Interfaces/CallInterfaces.h @@ -23,21 +23,11 @@ namespace mlir { struct CallInterfaceCallable : public PointerUnion { using PointerUnion::PointerUnion; }; - -class CallOpInterface; - -namespace call_interface_impl { - -/// Resolve the callable operation for given callee to a CallableOpInterface, or -/// nullptr if a valid callable was not resolved. `symbolTable` is an optional -/// parameter that will allow for using a cached symbol table for symbol lookups -/// instead of performing an O(N) scan. -Operation *resolveCallable(CallOpInterface call, SymbolTableCollection *symbolTable = nullptr); - -} // namespace call_interface_impl - } // namespace mlir +/// Include the generated interface declarations. +#include "mlir/Interfaces/CallInterfaces.h.inc" + namespace llvm { // Allow llvm::cast style functions. @@ -51,7 +41,4 @@ struct CastInfo } // namespace llvm -/// Include the generated interface declarations. -#include "mlir/Interfaces/CallInterfaces.h.inc" - #endif // MLIR_INTERFACES_CALLINTERFACES_H diff --git a/mlir/include/mlir/Interfaces/CallInterfaces.td b/mlir/include/mlir/Interfaces/CallInterfaces.td index c6002da0d491c..752de74e6e4d7 100644 --- a/mlir/include/mlir/Interfaces/CallInterfaces.td +++ b/mlir/include/mlir/Interfaces/CallInterfaces.td @@ -59,29 +59,17 @@ def CallOpInterface : OpInterface<"CallOpInterface"> { Returns the operands within this call that are used as arguments to the callee as a mutable range. }], - "::mlir::MutableOperandRange", "getArgOperandsMutable" - >, - InterfaceMethod<[{ - Resolve the callable operation for given callee to a - CallableOpInterface, or nullptr if a valid callable was not resolved. - `symbolTable` parameter allow for using a cached symbol table for symbol - lookups instead of performing an O(N) scan. - }], - "::mlir::Operation *", "resolveCallableInTable", (ins "::mlir::SymbolTableCollection *":$symbolTable), - /*methodBody=*/[{}], /*defaultImplementation=*/[{ - return ::mlir::call_interface_impl::resolveCallable($_op, symbolTable); - }] - >, - InterfaceMethod<[{ - Resolve the callable operation for given callee to a - CallableOpInterface, or nullptr if a valid callable was not resolved. - }], - "::mlir::Operation *", "resolveCallable", (ins), - /*methodBody=*/[{}], /*defaultImplementation=*/[{ - return ::mlir::call_interface_impl::resolveCallable($_op); - }] - > + "::mlir::MutableOperandRange", "getArgOperandsMutable">, ]; + + let extraClassDeclaration = [{ + /// Resolve the callable operation for given callee to a + /// CallableOpInterface, or nullptr if a valid callable was not resolved. + /// `symbolTable` is an optional parameter that will allow for using a + /// cached symbol table for symbol lookups instead of performing an O(N) + /// scan. + ::mlir::Operation *resolveCallable(::mlir::SymbolTableCollection *symbolTable = nullptr); + }]; } /// Interface for callable operations. diff --git a/mlir/lib/Analysis/CallGraph.cpp b/mlir/lib/Analysis/CallGraph.cpp index 780c7caee767c..ccd4676632136 100644 --- a/mlir/lib/Analysis/CallGraph.cpp +++ b/mlir/lib/Analysis/CallGraph.cpp @@ -146,7 +146,7 @@ CallGraphNode *CallGraph::lookupNode(Region *region) const { CallGraphNode * CallGraph::resolveCallable(CallOpInterface call, SymbolTableCollection &symbolTable) const { - Operation *callable = call.resolveCallableInTable(&symbolTable); + Operation *callable = call.resolveCallable(&symbolTable); if (auto callableOp = dyn_cast_or_null(callable)) if (auto *node = lookupNode(callableOp.getCallableRegion())) return node; diff --git a/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp index beb68018a3b16..532480b6fad57 100644 --- a/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp +++ b/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp @@ -297,7 +297,7 @@ LogicalResult DeadCodeAnalysis::visit(ProgramPoint point) { } void DeadCodeAnalysis::visitCallOperation(CallOpInterface call) { - Operation *callableOp = call.resolveCallableInTable(&symbolTable); + Operation *callableOp = call.resolveCallable(&symbolTable); // A call to a externally-defined callable has unknown predecessors. const auto isExternalCallable = [this](Operation *op) { diff --git a/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp index 300c6e5f9b891..37f4ceaaa56ce 100644 --- a/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp +++ b/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp @@ -284,7 +284,7 @@ void AbstractDenseBackwardDataFlowAnalysis::visitCallOperation( CallOpInterface call, const AbstractDenseLattice &after, AbstractDenseLattice *before) { // Find the callee. - Operation *callee = call.resolveCallableInTable(&symbolTable); + Operation *callee = call.resolveCallable(&symbolTable); auto callable = dyn_cast_or_null(callee); // No region means the callee is only declared in this module. diff --git a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp index 1bd6defef90be..4a73f21a18aae 100644 --- a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp +++ b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp @@ -442,7 +442,7 @@ AbstractSparseBackwardDataFlowAnalysis::visitOperation(Operation *op) { // For function calls, connect the arguments of the entry blocks to the // operands of the call op that are forwarded to these arguments. if (auto call = dyn_cast(op)) { - Operation *callableOp = call.resolveCallableInTable(&symbolTable); + Operation *callableOp = call.resolveCallable(&symbolTable); if (auto callable = dyn_cast_or_null(callableOp)) { // Not all operands of a call op forward to arguments. Such operands are // stored in `unaccounted`. diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp index b973618004497..ca5d0688b5b59 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp @@ -824,7 +824,7 @@ FailureOr BufferDeallocation::handleInterface(CallOpInterface op) { // the function is referenced by SSA value instead of a Symbol, it's assumed // to be public. (And we cannot easily change the type of the SSA value // anyway.) - Operation *funcOp = op.resolveCallableInTable(state.getSymbolTable()); + Operation *funcOp = op.resolveCallable(state.getSymbolTable()); bool isPrivate = false; if (auto symbol = dyn_cast_or_null(funcOp)) isPrivate = symbol.isPrivate() && !symbol.isDeclaration(); diff --git a/mlir/lib/Interfaces/CallInterfaces.cpp b/mlir/lib/Interfaces/CallInterfaces.cpp index 47f8021f50cd2..455684d8e2ea7 100644 --- a/mlir/lib/Interfaces/CallInterfaces.cpp +++ b/mlir/lib/Interfaces/CallInterfaces.cpp @@ -14,17 +14,21 @@ using namespace mlir; // CallOpInterface //===----------------------------------------------------------------------===// +/// Resolve the callable operation for given callee to a CallableOpInterface, or +/// nullptr if a valid callable was not resolved. `symbolTable` is an optional +/// parameter that will allow for using a cached symbol table for symbol lookups +/// instead of performing an O(N) scan. Operation * -call_interface_impl::resolveCallable(CallOpInterface call, SymbolTableCollection *symbolTable) { - CallInterfaceCallable callable = call.getCallableForCallee(); +CallOpInterface::resolveCallable(SymbolTableCollection *symbolTable) { + CallInterfaceCallable callable = getCallableForCallee(); if (auto symbolVal = dyn_cast(callable)) return symbolVal.getDefiningOp(); // If the callable isn't a value, lookup the symbol reference. auto symbolRef = callable.get(); if (symbolTable) - return symbolTable->lookupNearestSymbolFrom(call.getOperation(), symbolRef); - return SymbolTable::lookupNearestSymbolFrom(call.getOperation(), symbolRef); + return symbolTable->lookupNearestSymbolFrom(getOperation(), symbolRef); + return SymbolTable::lookupNearestSymbolFrom(getOperation(), symbolRef); } //===----------------------------------------------------------------------===//