Skip to content

Commit b0688ed

Browse files
committed
[mlir][bufferization] Add DeallocOp canonicalizer to remove memrefs also present in the retained list
Since memrefs in the retained list will never be deallocated, we can remove them from the list of memrefs to be deallocated. If the list of memrefs to deallocate becomes empty, we can just delete the dealloc operation. Reviewed By: springerm Differential Revision: https://reviews.llvm.org/D156186
1 parent 9451233 commit b0688ed

File tree

2 files changed

+53
-11
lines changed

2 files changed

+53
-11
lines changed

mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp

Lines changed: 34 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -791,11 +791,19 @@ struct DeallocRemoveDuplicates : public OpRewritePattern<DeallocOp> {
791791
LogicalResult matchAndRewrite(DeallocOp deallocOp,
792792
PatternRewriter &rewriter) const override {
793793
// Unique memrefs to be deallocated.
794+
DenseSet<Value> retained(deallocOp.getRetained().begin(),
795+
deallocOp.getRetained().end());
794796
DenseMap<Value, unsigned> memrefToCondition;
795797
SmallVector<Value> newMemrefs, newConditions, newRetained;
796-
SmallVector<unsigned> resultIndices;
797-
for (auto [memref, cond] :
798-
llvm::zip(deallocOp.getMemrefs(), deallocOp.getConditions())) {
798+
SmallVector<int32_t> resultIndices(deallocOp.getMemrefs().size(), -1);
799+
for (auto [i, memref, cond] :
800+
llvm::enumerate(deallocOp.getMemrefs(), deallocOp.getConditions())) {
801+
if (retained.contains(memref)) {
802+
rewriter.replaceAllUsesWith(deallocOp.getResult(i),
803+
deallocOp.getConditions()[i]);
804+
continue;
805+
}
806+
799807
if (memrefToCondition.count(memref)) {
800808
// If the dealloc conditions don't match, we need to make sure that the
801809
// dealloc happens on the union of cases.
@@ -808,7 +816,7 @@ struct DeallocRemoveDuplicates : public OpRewritePattern<DeallocOp> {
808816
newMemrefs.push_back(memref);
809817
newConditions.push_back(cond);
810818
}
811-
resultIndices.push_back(memrefToCondition[memref]);
819+
resultIndices[i] = memrefToCondition[memref];
812820
}
813821

814822
// Unique retained values
@@ -831,19 +839,38 @@ struct DeallocRemoveDuplicates : public OpRewritePattern<DeallocOp> {
831839
auto newDealloc = rewriter.create<DeallocOp>(deallocOp.getLoc(), newMemrefs,
832840
newConditions, newRetained);
833841
for (auto [i, newIdx] : llvm::enumerate(resultIndices))
834-
rewriter.replaceAllUsesWith(deallocOp.getResult(i),
835-
newDealloc.getResult(newIdx));
842+
if (newIdx != -1)
843+
rewriter.replaceAllUsesWith(deallocOp.getResult(i),
844+
newDealloc.getResult(newIdx));
836845

837846
rewriter.eraseOp(deallocOp);
838847
return success();
839848
}
840849
};
841850

851+
/// Erase deallocation operations where the variadic list of memrefs to
852+
/// deallocate is emtpy. Example:
853+
/// ```mlir
854+
/// bufferization.dealloc retain (%arg0: memref<2xi32>)
855+
/// ```
856+
struct EraseEmptyDealloc : public OpRewritePattern<DeallocOp> {
857+
using OpRewritePattern<DeallocOp>::OpRewritePattern;
858+
859+
LogicalResult matchAndRewrite(DeallocOp deallocOp,
860+
PatternRewriter &rewriter) const override {
861+
if (deallocOp.getMemrefs().empty()) {
862+
rewriter.eraseOp(deallocOp);
863+
return success();
864+
}
865+
return failure();
866+
}
867+
};
868+
842869
} // anonymous namespace
843870

844871
void DeallocOp::getCanonicalizationPatterns(RewritePatternSet &results,
845872
MLIRContext *context) {
846-
results.add<DeallocRemoveDuplicates>(context);
873+
results.add<DeallocRemoveDuplicates, EraseEmptyDealloc>(context);
847874
}
848875

849876
//===----------------------------------------------------------------------===//

mlir/test/Dialect/Bufferization/canonicalize.mlir

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -282,15 +282,30 @@ func.func @dealloc_canonicalize_clone_removal(%arg0: memref<?xindex>) -> memref<
282282

283283
// -----
284284

285-
func.func @dealloc_canonicalize_duplicates(%arg0: memref<2xi32>, %arg1: i1, %arg2: i1, %arg3: memref<2xi32>, %arg4: memref<2xi32>) -> (i1, i1, i1, i1, i1) {
286-
%0:3 = bufferization.dealloc (%arg4, %arg0, %arg0 : memref<2xi32>, memref<2xi32>, memref<2xi32>) if (%arg1, %arg1, %arg1) retain (%arg3, %arg4, %arg3 : memref<2xi32>, memref<2xi32>, memref<2xi32>)
285+
func.func @dealloc_canonicalize_duplicates(%arg0: memref<2xi32>, %arg1: i1, %arg2: i1, %arg3: memref<2xi32>, %arg4: memref<2xi32>, %arg5: memref<2xi32>) -> (i1, i1, i1, i1, i1) {
286+
%0:3 = bufferization.dealloc (%arg4, %arg0, %arg0 : memref<2xi32>, memref<2xi32>, memref<2xi32>) if (%arg1, %arg1, %arg1) retain (%arg3, %arg5, %arg3 : memref<2xi32>, memref<2xi32>, memref<2xi32>)
287287
%1:2 = bufferization.dealloc (%arg0, %arg0 : memref<2xi32>, memref<2xi32>) if (%arg1, %arg2)
288288
return %0#0, %0#1, %0#2, %1#0, %1#1 : i1, i1, i1, i1, i1
289289
}
290290

291291
// CHECK-LABEL: func @dealloc_canonicalize_duplicates
292-
// CHECK-SAME: ([[ARG0:%.+]]: memref<2xi32>, [[ARG1:%.+]]: i1, [[ARG2:%.+]]: i1, [[ARG3:%.+]]: memref<2xi32>, [[ARG4:%.+]]: memref<2xi32>)
293-
// CHECK-NEXT: [[V0:%.+]]:2 = bufferization.dealloc ([[ARG4]], [[ARG0]] : memref<2xi32>, memref<2xi32>) if ([[ARG1]], [[ARG1]]) retain ([[ARG3]], [[ARG4]] : memref<2xi32>, memref<2xi32>)
292+
// CHECK-SAME: ([[ARG0:%.+]]: memref<2xi32>, [[ARG1:%.+]]: i1, [[ARG2:%.+]]: i1, [[ARG3:%.+]]: memref<2xi32>, [[ARG4:%.+]]: memref<2xi32>, [[ARG5:%.+]]: memref<2xi32>)
293+
// CHECK-NEXT: [[V0:%.+]]:2 = bufferization.dealloc ([[ARG4]], [[ARG0]] : memref<2xi32>, memref<2xi32>) if ([[ARG1]], [[ARG1]]) retain ([[ARG3]], [[ARG5]] : memref<2xi32>, memref<2xi32>)
294294
// CHECK-NEXT: [[NEW_COND:%.+]] = arith.ori [[ARG1]], [[ARG2]] : i1
295295
// CHECK-NEXT: [[V1:%.+]] = bufferization.dealloc ([[ARG0]] : memref<2xi32>) if ([[NEW_COND]])
296296
// CHECK-NEXT: return [[V0]]#0, [[V0]]#1, [[V0]]#1, [[V1]], [[V1]] :
297+
298+
// -----
299+
300+
func.func @dealloc_canonicalize_retained_and_deallocated(%arg0: memref<2xi32>, %arg1: i1, %arg2: memref<2xi32>) -> (i1, i1, i1) {
301+
%0 = bufferization.dealloc (%arg0 : memref<2xi32>) if (%arg1) retain (%arg0 : memref<2xi32>)
302+
%1:2 = bufferization.dealloc (%arg0, %arg2 : memref<2xi32>, memref<2xi32>) if (%arg1, %arg1) retain (%arg0 : memref<2xi32>)
303+
bufferization.dealloc
304+
bufferization.dealloc retain (%arg0 : memref<2xi32>)
305+
return %0, %1#0, %1#1 : i1, i1, i1
306+
}
307+
308+
// CHECK-LABEL: func @dealloc_canonicalize_retained_and_deallocated
309+
// CHECK-SAME: ([[ARG0:%.+]]: memref<2xi32>, [[ARG1:%.+]]: i1, [[ARG2:%.+]]: memref<2xi32>)
310+
// CHECK-NEXT: [[V0:%.+]] = bufferization.dealloc ([[ARG2]] : memref<2xi32>) if ([[ARG1]]) retain ([[ARG0]] : memref<2xi32>)
311+
// CHECK-NEXT: return [[ARG1]], [[ARG1]], [[V0]] :

0 commit comments

Comments
 (0)