diff --git a/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp index 08d89d6db788c..d4b9134ab9eea 100644 --- a/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp +++ b/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp @@ -283,17 +283,23 @@ void AbstractDenseBackwardDataFlowAnalysis::visitCallOperation( AbstractDenseLattice *before) { // Find the callee. Operation *callee = call.resolveCallable(&symbolTable); - auto callable = dyn_cast_or_null(callee); - if (!callable) - return setToExitState(before); + auto callable = dyn_cast_or_null(callee); // No region means the callee is only declared in this module. - Region *region = callable.getCallableRegion(); - if (!region || region->empty() || !getSolverConfig().isInterprocedural()) { + // If that is the case or if the solver is not interprocedural, + // let the hook handle it. + if (!getSolverConfig().isInterprocedural() || + (callable && (!callable.getCallableRegion() || + callable.getCallableRegion()->empty()))) { return visitCallControlFlowTransfer( call, CallControlFlowAction::ExternalCallee, after, before); } + if (!callable) + return setToExitState(before); + + Region *region = callable.getCallableRegion(); + // Call-level control flow specifies the data flow here. // // func.func @callee() { diff --git a/mlir/test/Analysis/DataFlow/test-next-access.mlir b/mlir/test/Analysis/DataFlow/test-next-access.mlir index 70069b10a9398..8825c699dd130 100644 --- a/mlir/test/Analysis/DataFlow/test-next-access.mlir +++ b/mlir/test/Analysis/DataFlow/test-next-access.mlir @@ -575,3 +575,21 @@ func.func @call_opaque_callee(%arg0: memref) { memref.load %arg0[] {name = "post"} : memref return } + +// ----- + +// CHECK-LABEL: @indirect_call +func.func @indirect_call(%arg0: memref, %arg1: (memref) -> ()) { + // IP: name = "pre" + // IP-SAME: next_access = ["unknown"] + // IP_AR: name = "pre" + // IP_AR-SAME: next_access = ["unknown"] + // LOCAL: name = "pre" + // LOCAL-SAME: next_access = ["unknown"] + // LC_AR: name = "pre" + // LC_AR-SAME: next_access = {{\[}}["call"]] + memref.load %arg0[] {name = "pre"} : memref + func.call_indirect %arg1(%arg0) {name = "call"} : (memref) -> () + memref.load %arg0[] {name = "post"} : memref + return +}