diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
index 5a6de372e2310..7bbdeab3ea1a8 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
@@ -49,19 +49,49 @@ static LogicalResult updateDeallocIfChanged(DeallocOp deallocOp,
   return success();
 }
 
-/// Checks if 'memref' may or must alias a MemRef in 'memrefList'. It is often a
+/// Given a memref value, return the "base" value by skipping over all
+/// ViewLikeOpInterface ops (if any) in the reverse use-def chain.
+static Value getViewBase(Value value) {
+  while (auto viewLikeOp = value.getDefiningOp<ViewLikeOpInterface>())
+    value = viewLikeOp.getViewSource();
+  return value;
+}
+
+/// Return "true" if the given values are guaranteed to be different (and
+/// non-aliasing) allocations based on the fact that one value is the result
+/// of an allocation and the other value is a block argument of a parent block.
+/// Note: This is a best-effort analysis that will eventually be replaced by a
+/// proper "is same allocation" analysis. This function may return "false" even
+/// though the two values are distinct allocations.
+static bool distinctAllocAndBlockArgument(Value v1, Value v2) {
+  Value v1Base = getViewBase(v1);
+  Value v2Base = getViewBase(v2);
+  auto areDistinct = [](Value v1, Value v2) {
+    if (Operation *op = v1.getDefiningOp())
+      if (hasEffect<MemoryEffects::Allocate>(op, v1))
+        if (auto bbArg = dyn_cast<BlockArgument>(v2))
+          if (bbArg.getOwner()->findAncestorOpInBlock(*op))
+            return true;
+    return false;
+  };
+  return areDistinct(v1Base, v2Base) || areDistinct(v2Base, v1Base);
+}
+
+/// Checks if `memref` may or must alias a MemRef in `otherList`. It is often a
 /// requirement of optimization patterns that there cannot be any aliasing
-/// memref in order to perform the desired simplification. The 'allowSelfAlias'
-/// argument indicates whether 'memref' may be present in 'memrefList' which
+/// memref in order to perform the desired simplification. The `allowSelfAlias`
+/// argument indicates whether `memref` may be present in `otherList` which
 /// makes this helper function applicable to situations where we already know
-/// that 'memref' is in the list but also when we don't want it in the list.
+/// that `memref` is in the list but also when we don't want it in the list.
 static bool potentiallyAliasesMemref(AliasAnalysis &analysis,
-                                     ValueRange memrefList, Value memref,
+                                     ValueRange otherList, Value memref,
                                      bool allowSelfAlias) {
-  for (auto mr : memrefList) {
-    if (allowSelfAlias && mr == memref)
+  for (auto other : otherList) {
+    if (allowSelfAlias && other == memref)
+      continue;
+    if (distinctAllocAndBlockArgument(other, memref))
       continue;
-    if (!analysis.alias(mr, memref).isNo())
+    if (!analysis.alias(other, memref).isNo())
       return true;
   }
   return false;
diff --git a/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-region-branchop-interface.mlir b/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-region-branchop-interface.mlir
index dc372749fc074..ad7c4c783e907 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-region-branchop-interface.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-region-branchop-interface.mlir
@@ -270,7 +270,8 @@ func.func @loop_alloc(
 // CHECK: [[V0:%.+]]:2 = scf.for {{.*}} iter_args([[ARG6:%.+]] = [[ARG3]], [[ARG7:%.+]] = %false
 // CHECK: [[ALLOC1:%.+]] = memref.alloc()
 // CHECK: [[BASE:%[a-zA-Z0-9_]+]],{{.*}} = memref.extract_strided_metadata [[ARG6]]
-// CHECK: bufferization.dealloc ([[BASE]] :{{.*}}) if ([[ARG7]]) retain ([[ALLOC1]] :
+// CHECK: bufferization.dealloc ([[BASE]] :{{.*}}) if ([[ARG7]])
+// CHECK-NOT: retain
 // CHECK: scf.yield [[ALLOC1]], %true
 // CHECK: test.copy
 // CHECK: [[BASE:%[a-zA-Z0-9_]+]],{{.*}} = memref.extract_strided_metadata [[V0]]#0
@@ -563,8 +564,8 @@ func.func @while_two_arg(%arg0: index) {
 // CHECK: ^bb0([[ARG1:%.+]]: memref<?xf32>, [[ARG2:%.+]]: memref<?xf32>, [[ARG3:%.+]]: i1, [[ARG4:%.+]]: i1):
 // CHECK: [[ALLOC1:%.+]] = memref.alloc(
 // CHECK: [[BASE:%[a-zA-Z0-9_]+]],{{.*}} = memref.extract_strided_metadata [[ARG2]]
-// CHECK: [[OWN:%.+]]:2 = bufferization.dealloc ([[BASE]] :{{.*}}) if ([[ARG4]]) retain ([[ARG1]], [[ALLOC1]] :
-// CHECK: [[OWN_AGG:%.+]] = arith.ori [[OWN]]#0, [[ARG3]]
+// CHECK: [[OWN:%.+]] = bufferization.dealloc ([[BASE]] :{{.*}}) if ([[ARG4]]) retain ([[ARG1]] :
+// CHECK: [[OWN_AGG:%.+]] = arith.ori [[OWN]], [[ARG3]]
 // CHECK: scf.yield [[ARG1]], [[ALLOC1]], [[OWN_AGG]], %true
 // CHECK: [[BASE0:%[a-zA-Z0-9_]+]],{{.*}} = memref.extract_strided_metadata [[V0]]#0
 // CHECK: [[BASE1:%[a-zA-Z0-9_]+]],{{.*}} = memref.extract_strided_metadata [[V0]]#1
@@ -594,10 +595,10 @@ func.func @while_three_arg(%arg0: index) {
 // CHECK: [[ALLOC1:%.+]] = memref.alloc(
 // CHECK: [[ALLOC2:%.+]] = memref.alloc(
 // CHECK: [[BASE0:%[a-zA-Z0-9_]+]],{{.*}} = memref.extract_strided_metadata [[ARG1]]
-// CHECK: [[BASE1:%[a-zA-Z0-9_]+]],{{.*}} = memref.extract_strided_metadata [[ARG2]]
 // CHECK: [[BASE2:%[a-zA-Z0-9_]+]],{{.*}} = memref.extract_strided_metadata [[ARG3]]
-// CHECK: [[OWN:%.+]]:3 = bufferization.dealloc ([[BASE0]], [[BASE1]], [[BASE2]], [[ALLOC1]] :{{.*}}) if ([[ARG4]], [[ARG5]], [[ARG6]], %true{{[0-9_]*}}) retain ([[ALLOC2]], [[ALLOC1]], [[ARG2]] :
-// CHECK: scf.yield [[ALLOC2]], [[ALLOC1]], [[ARG2]], %true{{[0-9_]*}}, %true{{[0-9_]*}}, [[OWN]]#2 :
+// CHECK: [[OWN:%.+]] = bufferization.dealloc ([[BASE0]], [[BASE2]] :{{.*}}) if ([[ARG4]], [[ARG6]]) retain ([[ARG2]] :
+// CHECK: [[OWN_AGG:%.+]] = arith.ori [[OWN]], [[ARG5]]
+// CHECK: scf.yield [[ALLOC2]], [[ALLOC1]], [[ARG2]], %true{{[0-9_]*}}, %true{{[0-9_]*}}, [[OWN_AGG]] :
 // CHECK: }
 // CHECK: [[BASE0:%[a-zA-Z0-9_]+]],{{.*}} = memref.extract_strided_metadata [[V0]]#0
 // CHECK: [[BASE1:%[a-zA-Z0-9_]+]],{{.*}} = memref.extract_strided_metadata [[V0]]#1
diff --git a/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation-simplification.mlir b/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation-simplification.mlir
index 98eb038df30a3..e192e9870becd 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation-simplification.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/buffer-deallocation-simplification.mlir
@@ -133,3 +133,30 @@ func.func @dealloc_remove_dealloc_memref_contained_in_retained_with_const_true_c
 // CHECK: [[BASE:%[a-zA-Z0-9_]+]],{{.*}} = memref.extract_strided_metadata [[ARG2]]
 // CHECK: bufferization.dealloc ([[BASE]] :{{.*}}) if (%true{{[0-9_]*}})
 // CHECK-NEXT: return [[ARG0]], [[ARG1]], %true{{[0-9_]*}}, %true{{[0-9_]*}} :
+
+// -----
+
+func.func @alloc_and_bbarg(%arg0: memref<5xf32>, %arg1: index, %arg2: index, %arg3: index) -> f32 {
+  %true = arith.constant true
+  %false = arith.constant false
+  %0:2 = scf.for %arg4 = %arg1 to %arg2 step %arg3 iter_args(%arg5 = %arg0, %arg6 = %false) -> (memref<5xf32>, i1) {
+    %alloc = memref.alloc() : memref<5xf32>
+    memref.copy %arg5, %alloc : memref<5xf32> to memref<5xf32>
+    %base_buffer_0, %offset_1, %sizes_2, %strides_3 = memref.extract_strided_metadata %arg5 : memref<5xf32> -> memref<f32>, index, index, index
+    %2 = bufferization.dealloc (%base_buffer_0, %alloc : memref<f32>, memref<5xf32>) if (%arg6, %true) retain (%alloc : memref<5xf32>)
+    scf.yield %alloc, %2 : memref<5xf32>, i1
+  }
+  %1 = memref.load %0#0[%arg1] : memref<5xf32>
+  %base_buffer, %offset, %sizes, %strides = memref.extract_strided_metadata %0#0 : memref<5xf32> -> memref<f32>, index, index, index
+  bufferization.dealloc (%base_buffer : memref<f32>) if (%0#1)
+  return %1 : f32
+}
+
+// CHECK-LABEL: func @alloc_and_bbarg
+// CHECK: %[[true:.*]] = arith.constant true
+// CHECK: scf.for {{.*}} iter_args(%[[iter:.*]] = %{{.*}}, %{{.*}} = %{{.*}})
+// CHECK: %[[alloc:.*]] = memref.alloc
+// CHECK: %[[view:.*]], %{{.*}}, %{{.*}}, %{{.*}} = memref.extract_strided_metadata %[[iter]]
+// CHECK: bufferization.dealloc (%[[view]] : memref<f32>)
+// CHECK-NOT: retain
+// CHECK: scf.yield %[[alloc]], %[[true]]