Commit 3d0ca2c

[mlir][bufferization] Allow cyclic function graphs without tensors (#68632)
Cyclic function call graphs are generally not supported by One-Shot Bufferize. However, they can be allowed when a function has no tensor arguments or results: in that case the callee no longer has to be bufferized before the caller.
1 parent c8b5f4c commit 3d0ca2c
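
For illustration only (not part of this commit): a minimal sketch of the kind of call cycle the change tolerates, assuming the usual One-Shot Bufferize setup with function-boundary bufferization (e.g. mlir-opt -one-shot-bufferize="bufferize-function-boundaries"). The function names are hypothetical. Neither signature contains a tensor, so neither call is treated as an ordering edge between callee and caller.

// Hypothetical example: a call cycle whose signatures are tensor-free.
// One-Shot Bufferize no longer rejects this module, because the calls
// impose no callee-before-caller bufferization order.
func.func @ping(%m: memref<5xf32>, %pos: index, %v: f32) {
  memref.store %v, %m[%pos] : memref<5xf32>
  func.call @pong(%m, %pos, %v) : (memref<5xf32>, index, f32) -> ()
  return
}

func.func @pong(%m: memref<5xf32>, %pos: index, %v: f32) {
  func.call @ping(%m, %pos, %v) : (memref<5xf32>, index, f32) -> ()
  return
}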

3 files changed (+41, -7 lines)

mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp

Lines changed: 14 additions & 1 deletion
@@ -274,6 +274,13 @@ static void equivalenceAnalysis(func::FuncOp funcOp,
   });
 }
 
+/// Return "true" if the given function signature has tensor semantics.
+static bool hasTensorSignature(func::FuncOp funcOp) {
+  auto isaTensor = [](Type t) { return isa<TensorType>(t); };
+  return llvm::any_of(funcOp.getFunctionType().getInputs(), isaTensor) ||
+         llvm::any_of(funcOp.getFunctionType().getResults(), isaTensor);
+}
+
 /// Store all functions of the `moduleOp` in `orderedFuncOps`, sorted by
 /// callee-caller order (i.e. callees without callers first).
 /// Store the map of FuncOp to all its callers in `callerMap`.
@@ -297,10 +304,16 @@ getFuncOpsOrderedByCalls(ModuleOp moduleOp,
                   "without a unique ReturnOp";
     }
 
+    // Collect function calls and populate the caller map.
     numberCallOpsContainedInFuncOp[funcOp] = 0;
     return funcOp.walk([&](func::CallOp callOp) -> WalkResult {
       func::FuncOp calledFunction = getCalledFunction(callOp);
       assert(calledFunction && "could not retrieved called func::FuncOp");
+      // If the called function does not have any tensors in its signature, then
+      // it is not necessary to bufferize the callee before the caller.
+      if (!hasTensorSignature(calledFunction))
+        return WalkResult::skip();
+
       callerMap[calledFunction].insert(callOp);
       if (calledBy[calledFunction].insert(funcOp).second) {
         numberCallOpsContainedInFuncOp[funcOp]++;
@@ -310,7 +323,7 @@ getFuncOpsOrderedByCalls(ModuleOp moduleOp,
   });
   if (res.wasInterrupted())
     return failure();
-  // Iteratively remove function operation that do not call any of the
+  // Iteratively remove function operations that do not call any of the
   // functions remaining in the callCounter map and add them to the worklist.
   while (!numberCallOpsContainedInFuncOp.empty()) {
     auto it = llvm::find_if(numberCallOpsContainedInFuncOp,

mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir

Lines changed: 6 additions & 6 deletions
@@ -27,14 +27,14 @@ func.func @swappy(%cond1 : i1, %cond2 : i1, %t1 : tensor<f32>, %t2 : tensor<f32>
 
 // expected-error @-3 {{expected callgraph to be free of circular dependencies}}
 
-func.func @foo() {
-  call @bar() : () -> ()
-  return
+func.func @foo(%t: tensor<5xf32>) -> tensor<5xf32> {
+  %0 = call @bar(%t) : (tensor<5xf32>) -> (tensor<5xf32>)
+  return %0 : tensor<5xf32>
 }
 
-func.func @bar() {
-  call @foo() : () -> ()
-  return
+func.func @bar(%t: tensor<5xf32>) -> tensor<5xf32>{
+  %0 = call @foo(%t) : (tensor<5xf32>) -> (tensor<5xf32>)
+  return %0 : tensor<5xf32>
 }
 
 // -----

mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir

Lines changed: 21 additions & 0 deletions
@@ -662,3 +662,24 @@ func.func @br_in_func(%t: tensor<5xf32>) -> tensor<5xf32> {
 ^bb1(%arg1 : tensor<5xf32>):
   func.return %arg1 : tensor<5xf32>
 }
+
+// -----
+
+// Cyclic call graphs with tensors are not supported by One-Shot Bufferize.
+// However, if a function signature does not have any tensor arguments or
+// results, calls to that function are not seen as an "edge" in the function
+// call graph.
+
+// CHECK-LABEL: func.func @foo(%{{.*}}: memref<5xf32>) -> memref<5xf32>
+func.func @foo(%m: memref<5xf32>) -> memref<5xf32> {
+  %0 = tensor.empty() : tensor<5xf32>
+  %1 = func.call @bar(%0, %m)
+      : (tensor<5xf32>, memref<5xf32>) -> (memref<5xf32>)
+  return %1 : memref<5xf32>
+}
+
+// CHECK: func.func @bar(%{{.*}}: memref<5xf32, strided<[?], offset: ?>>, %arg1: memref<5xf32>) -> memref<5xf32>
+func.func @bar(%t: tensor<5xf32>, %m: memref<5xf32>) -> memref<5xf32> {
+  %0 = func.call @foo(%m) : (memref<5xf32>) -> (memref<5xf32>)
+  return %0 : memref<5xf32>
+}
