Commit 950f094

[mlir][bufferization] Factor out bufferization.dealloc lowering into separate pass
Moves the lowering of `bufferization.dealloc` to memref into a separate pass, but still registers the pattern in the conversion pass. This is helpful when some tensor values (and thus `to_memref` or `to_tensor` operations) still remain, e.g., when the function boundaries are not converted, or when constant tensors are converted to `memref.get_global` at a later point.

However, it is still recommended to perform all bufferization before deallocation to avoid memory leaks, because all memref allocations inserted after the deallocation pass has been applied have to be handled manually.

Note: The buffer deallocation pass assumes that memref values defined by `bufferization.to_memref` don't return ownership and don't have to be deallocated. `bufferization.to_tensor` operations are handled similarly to `bufferization.clone` operations, with the exception that the result value is not handled because it's a tensor (not a memref).

Reviewed By: springerm

Differential Revision: https://reviews.llvm.org/D159180
1 parent 4fdc019 commit 950f094

11 files changed: +840 -687 lines changed

mlir/include/mlir/Conversion/Passes.td

Lines changed: 8 additions & 6 deletions
@@ -195,12 +195,14 @@ def ConvertBufferizationToMemRef : Pass<"convert-bufferization-to-memref"> {
     This pass converts bufferization operations into memref operations.
 
     In the current state, this pass only transforms a `bufferization.clone`
-    operation into `memref.alloc` and `memref.copy` operations. This conversion
-    is needed, since some clone operations could remain after applying several
-    transformation processes. Currently, only `canonicalize` transforms clone
-    operations or even eliminates them. This can lead to errors if any clone op
-    survived after all conversion passes (starting from the bufferization
-    dialect) are performed.
+    operation into `memref.alloc` and `memref.copy` operations and
+    `bufferization.dealloc` operations (the same way as the
+    `-bufferization-lower-deallocations` pass). The conversion of `clone`
+    operations is needed, since some clone operations could remain after
+    applying several transformation processes. Currently, only `canonicalize`
+    transforms clone operations or even eliminates them. This can lead to errors
+    if any clone op survived after all conversion passes (starting from the
+    bufferization dialect) are performed.
 
     See:
     https://llvm.discourse.group/t/bufferization-error-related-to-memref-clone/4665

mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h

Lines changed: 95 additions & 0 deletions
@@ -5,6 +5,9 @@
 
 namespace mlir {
 class ModuleOp;
+class RewritePatternSet;
+class OpBuilder;
+class SymbolTable;
 
 namespace func {
 class FuncOp;
@@ -29,6 +32,98 @@ std::unique_ptr<Pass> createBufferDeallocationPass();
 /// static alias analysis.
 std::unique_ptr<Pass> createBufferDeallocationSimplificationPass();
 
+/// Creates an instance of the LowerDeallocations pass to lower
+/// `bufferization.dealloc` operations to the `memref` dialect.
+std::unique_ptr<Pass> createLowerDeallocationsPass();
+
+/// Adds the conversion pattern of the `bufferization.dealloc` operation to the
+/// given pattern set for use in other transformation passes.
+void populateBufferizationDeallocLoweringPattern(
+    RewritePatternSet &patterns, func::FuncOp deallocLibraryFunc);
+
+/// Construct the library function needed for the fully generic
+/// `bufferization.dealloc` lowering implemented in the LowerDeallocations pass.
+/// The function can then be called at bufferization dealloc sites to determine
+/// aliasing and ownership.
+///
+/// The generated function takes two memrefs of indices and three memrefs of
+/// booleans as arguments:
+///   * The first argument A should contain the result of the
+///     extract_aligned_pointer_as_index operation applied to the memrefs to be
+///     deallocated
+///   * The second argument B should contain the result of the
+///     extract_aligned_pointer_as_index operation applied to the memrefs to be
+///     retained
+///   * The third argument C should contain the conditions as passed directly
+///     to the deallocation operation.
+///   * The fourth argument D is used to pass results to the caller. Those
+///     represent the condition under which the memref at the corresponding
+///     position in A should be deallocated.
+///   * The fifth argument E is used to pass results to the caller. It
+///     provides the ownership value corresponding the the memref at the same
+///     position in B
+///
+/// This helper function is supposed to be called once for each
+/// `bufferization.dealloc` operation to determine the deallocation need and new
+/// ownership indicator for the retained values, but does not perform the
+/// deallocation itself.
+///
+/// Generated code:
+/// ```
+/// func.func @dealloc_helper(
+///     %dyn_dealloc_base_pointer_list: memref<?xindex>,
+///     %dyn_retain_base_pointer_list: memref<?xindex>,
+///     %dyn_cond_list: memref<?xi1>,
+///     %dyn_dealloc_cond_out: memref<?xi1>,
+///     %dyn_ownership_out: memref<?xi1>) {
+///   %c0 = arith.constant 0 : index
+///   %c1 = arith.constant 1 : index
+///   %true = arith.constant true
+///   %false = arith.constant false
+///   %num_dealloc_memrefs = memref.dim %dyn_dealloc_base_pointer_list, %c0
+///   %num_retain_memrefs = memref.dim %dyn_retain_base_pointer_list, %c0
+///   // Zero initialize result buffer.
+///   scf.for %i = %c0 to %num_retain_memrefs step %c1 {
+///     memref.store %false, %dyn_ownership_out[%i] : memref<?xi1>
+///   }
+///   scf.for %i = %c0 to %num_dealloc_memrefs step %c1 {
+///     %dealloc_bp = memref.load %dyn_dealloc_base_pointer_list[%i]
+///     %cond = memref.load %dyn_cond_list[%i]
+///     // Check for aliasing with retained memrefs.
+///     %does_not_alias_retained = scf.for %j = %c0 to %num_retain_memrefs
+///         step %c1 iter_args(%does_not_alias_aggregated = %true) -> (i1) {
+///       %retain_bp = memref.load %dyn_retain_base_pointer_list[%j]
+///       %does_alias = arith.cmpi eq, %retain_bp, %dealloc_bp : index
+///       scf.if %does_alias {
+///         %curr_ownership = memref.load %dyn_ownership_out[%j]
+///         %updated_ownership = arith.ori %curr_ownership, %cond : i1
+///         memref.store %updated_ownership, %dyn_ownership_out[%j]
+///       }
+///       %does_not_alias = arith.cmpi ne, %retain_bp, %dealloc_bp : index
+///       %updated_aggregate = arith.andi %does_not_alias_aggregated,
+///                                       %does_not_alias : i1
+///       scf.yield %updated_aggregate : i1
+///     }
+///     // Check for aliasing with dealloc memrefs in the list before the
+///     // current one, i.e.,
+///     // `fix i, forall j < i: check_aliasing(%dyn_dealloc_base_pointer[j],
+///     // %dyn_dealloc_base_pointer[i])`
+///     %does_not_alias_any = scf.for %j = %c0 to %i step %c1
+///         iter_args(%does_not_alias_agg = %does_not_alias_retained) -> (i1) {
+///       %prev_dealloc_bp = memref.load %dyn_dealloc_base_pointer_list[%j]
+///       %does_not_alias = arith.cmpi ne, %prev_dealloc_bp, %dealloc_bp
+///       %updated_alias_agg = arith.andi %does_not_alias_agg, %does_not_alias
+///       scf.yield %updated_alias_agg : i1
+///     }
+///     %dealloc_cond = arith.andi %does_not_alias_any, %cond : i1
+///     memref.store %dealloc_cond, %dyn_dealloc_cond_out[%i] : memref<?xi1>
+///   }
+///   return
+/// }
+/// ```
+func::FuncOp buildDeallocationLibraryFunction(OpBuilder &builder, Location loc,
+                                              SymbolTable &symbolTable);
+
 /// Run buffer deallocation.
 LogicalResult deallocateBuffers(Operation *op);

mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td

Lines changed: 23 additions & 0 deletions
@@ -108,6 +108,29 @@ def BufferDeallocationSimplification :
   ];
 }
 
+def LowerDeallocations : Pass<"bufferization-lower-deallocations"> {
+  let summary = "Lowers `bufferization.dealloc` operations to `memref.dealloc`"
+                "operations";
+  let description = [{
+    This pass lowers `bufferization.dealloc` operations to the `memref` dialect.
+    It can be applied to a `builtin.module` or operations implementing the
+    `FunctionOpInterface`. For the latter, only simple `dealloc` operations can
+    be lowered because the library function necessary for the fully generic
+    lowering cannot be inserted. In this case, an error will be emitted.
+    Next to `memref.dealloc` operations, it may also emit operations from the
+    `arith`, `scf`, and `func` dialects to build conditional deallocations and
+    library functions to avoid code-size blow-up.
+  }];
+
+  let constructor =
+    "mlir::bufferization::createLowerDeallocationsPass()";
+
+  let dependentDialects = [
+    "arith::ArithDialect", "memref::MemRefDialect", "scf::SCFDialect",
+    "func::FuncDialect"
+  ];
+}
+
 def BufferHoisting : Pass<"buffer-hoisting", "func::FuncOp"> {
   let summary = "Optimizes placement of allocation operations by moving them "
                 "into common dominators and out of nested regions";
