Skip to content

Commit d9111f1

Browse files
authored
[mlir][bufferization]-Refactor findValueInReverseUseDefChain to accept opOperand (llvm#121304)
Edit the `findValueInReverseUseDefChain` method to accept `OpOperand` instead of the `Value` type, This change will make sure that the populated `visitedOpOperands` argument is fully accurate and contains the opOperand we have started the reverse chain from.
1 parent accd4a4 commit d9111f1

File tree

8 files changed

+64
-44
lines changed

8 files changed

+64
-44
lines changed

mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -456,7 +456,7 @@ class AnalysisState {
456456
/// read by themselves (e.g., ExtractSliceOp).
457457
bool isValueRead(Value value) const;
458458

459-
/// Starting from `value`, follow the use-def chain in reverse, always
459+
/// Starting from `opOperand`, follow the use-def chain in reverse, always
460460
/// selecting the aliasing OpOperands. Find and return Values for which
461461
/// `condition` evaluates to true. OpOperands of such matching Values are not
462462
/// traversed any further, the visited aliasing opOperands will be preserved
@@ -484,7 +484,7 @@ class AnalysisState {
484484
/// Additional stopping conditions for the traversal can be specified in
485485
/// `config`.
486486
SetVector<Value> findValueInReverseUseDefChain(
487-
Value value, llvm::function_ref<bool(Value)> condition,
487+
OpOperand *opOperand, llvm::function_ref<bool(Value)> condition,
488488
TraversalConfig config = TraversalConfig(),
489489
llvm::DenseSet<OpOperand *> *visitedOpOperands = nullptr) const;
490490

@@ -520,7 +520,7 @@ class AnalysisState {
520520
///
521521
/// Note: OpResults of unknown ops are handled conservatively and assumed to
522522
/// be definitions.
523-
SetVector<Value> findDefinitions(Value value) const;
523+
SetVector<Value> findDefinitions(OpOperand *opOperand) const;
524524

525525
/// Return `true` if the given OpResult has been decided to bufferize inplace.
526526
virtual bool isInPlace(OpOperand &opOperand) const;

mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -127,9 +127,9 @@ class OneShotAnalysisState : public AnalysisState {
127127
/// Return true if the buffer of the given tensor value is writable.
128128
bool isWritable(Value value) const;
129129

130-
/// Find the definitions of the given tensor value or retrieve them from the
131-
/// cache.
132-
const SetVector<Value> &findDefinitionsCached(Value value);
130+
/// Find the definitions of the given operand's value or
131+
/// retrieve them from the cache.
132+
const SetVector<Value> &findDefinitionsCached(OpOperand *opOperand);
133133

134134
/// Reset cached data structures.
135135
void resetCache() override;

mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -480,18 +480,21 @@ bool AnalysisState::isValueRead(Value value) const {
480480
return false;
481481
}
482482

483-
// Starting from `value`, follow the use-def chain in reverse, always selecting
484-
// the aliasing OpOperands. Find and return Values for which `condition`
485-
// evaluates to true. OpOperands of such matching Values are not traversed any
486-
// further, the visited aliasing opOperands will be preserved through
487-
// `visitedOpOperands`.
483+
// Starting from `opOperand`, follow the use-def chain in reverse, always
484+
// selecting the aliasing OpOperands. Find and return Values for which
485+
// `condition` evaluates to true. Uses of such matching Values are not
486+
// traversed any further, the visited aliasing opOperands will be preserved
487+
// through `visitedOpOperands`.
488488
llvm::SetVector<Value> AnalysisState::findValueInReverseUseDefChain(
489-
Value value, llvm::function_ref<bool(Value)> condition,
489+
OpOperand *opOperand, llvm::function_ref<bool(Value)> condition,
490490
TraversalConfig config,
491491
llvm::DenseSet<OpOperand *> *visitedOpOperands) const {
492492
llvm::DenseSet<Value> visited;
493493
llvm::SetVector<Value> result, workingSet;
494-
workingSet.insert(value);
494+
workingSet.insert(opOperand->get());
495+
496+
if (visitedOpOperands)
497+
visitedOpOperands->insert(opOperand);
495498

496499
while (!workingSet.empty()) {
497500
Value value = workingSet.pop_back_val();
@@ -563,12 +566,14 @@ llvm::SetVector<Value> AnalysisState::findValueInReverseUseDefChain(
563566
return result;
564567
}
565568

566-
// Find the values that define the contents of the given value.
567-
llvm::SetVector<Value> AnalysisState::findDefinitions(Value value) const {
569+
// Find the values that define the contents of the given operand's value.
570+
llvm::SetVector<Value>
571+
AnalysisState::findDefinitions(OpOperand *opOperand) const {
568572
TraversalConfig config;
569573
config.alwaysIncludeLeaves = false;
570574
return findValueInReverseUseDefChain(
571-
value, [&](Value v) { return this->bufferizesToMemoryWrite(v); }, config);
575+
opOperand, [&](Value v) { return this->bufferizesToMemoryWrite(v); },
576+
config);
572577
}
573578

574579
AnalysisState::AnalysisState(const BufferizationOptions &options)
@@ -892,7 +897,7 @@ bool bufferization::detail::defaultResultBufferizesToMemoryWrite(
892897
config.alwaysIncludeLeaves = false;
893898
for (AliasingOpOperand alias : opOperands) {
894899
if (!state
895-
.findValueInReverseUseDefChain(alias.opOperand->get(),
900+
.findValueInReverseUseDefChain(alias.opOperand,
896901
isMemoryWriteInsideOp, config)
897902
.empty())
898903
return true;

mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ LogicalResult mlir::bufferization::eliminateEmptyTensors(
143143
// %3 = tensor.insert_slice %2 into ...
144144
config.followSameTypeOrCastsOnly = true;
145145
SetVector<Value> emptyTensors = state.findValueInReverseUseDefChain(
146-
source.get(), /*condition=*/
146+
&source, /*condition=*/
147147
[&](Value val) { return val.getDefiningOp<tensor::EmptyOp>(); }, config,
148148
&visitedOpOperands);
149149

@@ -155,10 +155,8 @@ LogicalResult mlir::bufferization::eliminateEmptyTensors(
155155
visitedOpOperands, [&emptyTensorOp](OpOperand *opOperand) {
156156
return llvm::count(emptyTensorOp->getUses(), *opOperand);
157157
});
158-
// This could be achieved when a use of `emptyTensorOp` is being
159-
// consumed by `SubsetInsertionOpInterface`'s source directly.
160-
if (iter == visitedOpOperands.end())
161-
continue;
158+
159+
assert(iter != visitedOpOperands.end() && "could not find use");
162160
OpOperand *useToBeReplaced = *iter;
163161
Operation *user = useToBeReplaced->getOwner();
164162
auto replacement = subsetsExtractionFn(rewriter, op, emptyTensorOp, user);

mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp

Lines changed: 26 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,12 @@ void OneShotAnalysisState::gatherUndefinedTensorUses(Operation *op) {
196196

197197
// If there is no preceding definition, the tensor contents are
198198
// undefined.
199-
if (findDefinitionsCached(opResult).empty())
199+
if (opResult.getUses().empty())
200+
continue;
201+
// It does not really matter which use to take to search about
202+
// the value's definitions.
203+
OpOperand *opOperand = &(*opResult.getUses().begin());
204+
if (findDefinitionsCached(opOperand).empty())
200205
for (OpOperand &use : opResult.getUses())
201206
undefinedTensorUses.insert(&use);
202207
}
@@ -464,7 +469,8 @@ static void annotateConflict(OpOperand *uRead, OpOperand *uConflictingWrite,
464469
/// indexing. I.e., the tensor types do not change along the use-def chain,
465470
/// apart from static <-> dynamic dim casts.
466471
static bool hasEquivalentValueInReverseUseDefChain(AnalysisState &state,
467-
Value start, Value other) {
472+
OpOperand *start,
473+
Value other) {
468474
TraversalConfig config;
469475
config.followEquivalentOnly = true;
470476
config.alwaysIncludeLeaves = false;
@@ -475,9 +481,10 @@ static bool hasEquivalentValueInReverseUseDefChain(AnalysisState &state,
475481
.empty();
476482
}
477483

478-
/// Return "true" if `value` is originating from a subset that is equivalent to
479-
/// the subset that `subsetOp` inserts into.
480-
static bool matchesInsertDestination(const AnalysisState &state, Value value,
484+
/// Return "true" if the given operand's value is originating from a subset
485+
/// that is equivalent to the subset that `subsetOp` inserts into.
486+
static bool matchesInsertDestination(const AnalysisState &state,
487+
OpOperand *opOperand,
481488
SubsetInsertionOpInterface subsetOp) {
482489
auto matchingSubset = [&](Value val) {
483490
if (auto opResult = dyn_cast<OpResult>(val))
@@ -490,7 +497,7 @@ static bool matchesInsertDestination(const AnalysisState &state, Value value,
490497
// There may be multiple leaves at which the reverse SSA use-def chain lookup
491498
// terminates. All of them must be equivalent subsets.
492499
SetVector<Value> backwardSlice =
493-
state.findValueInReverseUseDefChain(value, matchingSubset);
500+
state.findValueInReverseUseDefChain(opOperand, matchingSubset);
494501
return static_cast<bool>(llvm::all_of(backwardSlice, matchingSubset));
495502
}
496503

@@ -516,7 +523,7 @@ static bool areNonConflictingSubsets(OpOperand *uRead,
516523
// {inplace= [true] }
517524

518525
if (uRead == &subsetOp.getDestinationOperand() &&
519-
matchesInsertDestination(state, uConflictingWrite->get(), subsetOp))
526+
matchesInsertDestination(state, uConflictingWrite, subsetOp))
520527
// Case 1: The main insight is that InsertSliceOp reads only part of
521528
// the destination tensor. The overwritten area is not read. If
522529
// uConflictingWrite writes into exactly the memory location that is
@@ -533,7 +540,7 @@ static bool areNonConflictingSubsets(OpOperand *uRead,
533540

534541
if (uRead == &subsetOp.getSourceOperand() &&
535542
uConflictingWrite == &subsetOp.getDestinationOperand() &&
536-
matchesInsertDestination(state, uRead->get(), subsetOp))
543+
matchesInsertDestination(state, uRead, subsetOp))
537544
// Case 2: The read of the source tensor and the write to the dest
538545
// tensor via an InsertSliceOp is not a conflict if the read is
539546
// reading exactly that part of an equivalent tensor that the
@@ -567,8 +574,7 @@ static bool areNonConflictingSubsets(OpOperand *uRead,
567574
if (uConflictingWrite == &subsetOp.getDestinationOperand() &&
568575
state.areEquivalentBufferizedValues(
569576
uRead->get(), subsetOp.getSourceOperand().get()) &&
570-
matchesInsertDestination(state, subsetOp.getSourceOperand().get(),
571-
subsetOp))
577+
matchesInsertDestination(state, &subsetOp.getSourceOperand(), subsetOp))
572578
return true;
573579

574580
return false;
@@ -600,9 +606,9 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead,
600606
// even though that op just bufferizes to an allocation but does define
601607
// the contents of the buffer.
602608
SetVector<Value> definitionsOrLeaves =
603-
state.findValueInReverseUseDefChain(
604-
uConflictingWrite->get(),
605-
[&](Value v) { return state.bufferizesToMemoryWrite(v); });
609+
state.findValueInReverseUseDefChain(uConflictingWrite, [&](Value v) {
610+
return state.bufferizesToMemoryWrite(v);
611+
});
606612
assert(!definitionsOrLeaves.empty() &&
607613
"expected at least one definition or leaf");
608614

@@ -641,8 +647,7 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead,
641647
// In the above example, if uRead is the OpOperand of reading_op, the
642648
// definition is %0. Note that operations that create an alias but do not
643649
// bufferize to a memory write (such as ExtractSliceOp) are skipped.
644-
const SetVector<Value> &definitions =
645-
state.findDefinitionsCached(uRead->get());
650+
const SetVector<Value> &definitions = state.findDefinitionsCached(uRead);
646651
if (definitions.empty()) {
647652
// Fast path: No conflict if there are no definitions.
648653
LLVM_DEBUG(llvm::dbgs()
@@ -714,9 +719,9 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead,
714719
if (bufferizableOp.bufferizesToElementwiseAccess(
715720
state, {uRead, uConflictingWrite})) {
716721
if (hasEquivalentValueInReverseUseDefChain(
717-
state, uRead->get(), uConflictingWrite->get()) ||
722+
state, uRead, uConflictingWrite->get()) ||
718723
hasEquivalentValueInReverseUseDefChain(
719-
state, uConflictingWrite->get(), uRead->get())) {
724+
state, uConflictingWrite, uRead->get())) {
720725
LLVM_DEBUG(
721726
llvm::dbgs()
722727
<< " no conflict: op bufferizes to element-wise access\n");
@@ -965,11 +970,12 @@ wouldCreateWriteToNonWritableBuffer(OpOperand &operand,
965970
// Bufferization analyses.
966971
//===----------------------------------------------------------------------===//
967972

968-
// Find the values that define the contents of the given value.
973+
// Find the values that define the contents of the given operand's value.
969974
const llvm::SetVector<Value> &
970-
OneShotAnalysisState::findDefinitionsCached(Value value) {
975+
OneShotAnalysisState::findDefinitionsCached(OpOperand *opOperand) {
976+
Value value = opOperand->get();
971977
if (!cachedDefinitions.count(value))
972-
cachedDefinitions[value] = findDefinitions(value);
978+
cachedDefinitions[value] = findDefinitions(opOperand);
973979
return cachedDefinitions[value];
974980
}
975981

mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -553,7 +553,7 @@ Value linalg::bufferizeToAllocation(
553553
Value alloc = createAllocationForTensor(
554554
rewriter, op->getLoc(), operand->get(), options, memorySpace);
555555
allocs.push_back(alloc);
556-
if (!state.findDefinitions(operand->get()).empty()) {
556+
if (!state.findDefinitions(operand).empty()) {
557557
// Initialize buffer with a copy of the operand data. Not needed if the
558558
// tensor is uninitialized.
559559
createMemcpy(rewriter, op->getLoc(), operand->get(), alloc, options);

mlir/lib/Dialect/Linalg/Transforms/EliminateEmptyTensors.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ LogicalResult linalg::linalgOpAnchoredEmptyTensorEliminationStep(
5959
config.followEquivalentOnly = true;
6060
config.alwaysIncludeLeaves = false;
6161
SetVector<Value> emptyTensors = state.findValueInReverseUseDefChain(
62-
in->get(), /*condition=*/
62+
in, /*condition=*/
6363
[&](Value val) {
6464
return val.getDefiningOp<tensor::EmptyOp>() &&
6565
val.getType() == in->get().getType();

mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -465,3 +465,14 @@ func.func @mutli_use_of_the_same_tensor_empty_creates_non_existent_read(%arg1: t
465465
: tensor<5x6x64xf32> into tensor<5x6x128xf32>
466466
return %inserted_slice_1, %res_2 : tensor<5x6x128xf32>, tensor<5x6x64xf32>
467467
}
468+
469+
// -----
470+
471+
// CHECK-LABEL: func.func @direct_use_of_tensor_empty
472+
func.func @direct_use_of_tensor_empty(%arg0: tensor<5x6x128xf32>) -> tensor<5x6x128xf32> {
473+
// CHECK-NOT: memref.alloc
474+
%empty_1 = tensor.empty() : tensor<5x6x64xf32>
475+
%inserted_slice_1 = tensor.insert_slice %empty_1 into %arg0[0, 0, 0][5, 6, 64][1, 1, 1]
476+
: tensor<5x6x64xf32> into tensor<5x6x128xf32>
477+
return %inserted_slice_1 : tensor<5x6x128xf32>
478+
}

0 commit comments

Comments
 (0)