[mlir][bufferization]-Refactor findValueInReverseUseDefChain to accept opOperand #121304

Merged
@@ -456,7 +456,7 @@ class AnalysisState {
/// read by themselves (e.g., ExtractSliceOp).
bool isValueRead(Value value) const;

/// Starting from `value`, follow the use-def chain in reverse, always
/// Starting from `opOperand`, follow the use-def chain in reverse, always
/// selecting the aliasing OpOperands. Find and return Values for which
/// `condition` evaluates to true. OpOperands of such matching Values are not
/// traversed any further, the visited aliasing opOperands will be preserved
@@ -484,7 +484,7 @@ class AnalysisState {
/// Additional stopping conditions for the traversal can be specified in
/// `config`.
SetVector<Value> findValueInReverseUseDefChain(
Value value, llvm::function_ref<bool(Value)> condition,
OpOperand *opOperand, llvm::function_ref<bool(Value)> condition,
TraversalConfig config = TraversalConfig(),
llvm::DenseSet<OpOperand *> *visitedOpOperands = nullptr) const;

@@ -520,7 +520,7 @@ class AnalysisState {
///
/// Note: OpResults of unknown ops are handled conservatively and assumed to
/// be definitions.
SetVector<Value> findDefinitions(Value value) const;
SetVector<Value> findDefinitions(OpOperand *opOperand) const;

/// Return `true` if the given OpResult has been decided to bufferize inplace.
virtual bool isInPlace(OpOperand &opOperand) const;
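With the new signature, callers hand the traversal the consuming OpOperand rather than its Value. A minimal caller sketch under that assumption (hypothetical helper, not part of this PR; assumes an existing `AnalysisState` and a tensor operand):

```cpp
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"

using namespace mlir;
using namespace mlir::bufferization;

// Hypothetical helper: collect the reverse use-def chain leaves produced by
// tensor.empty, starting from the given operand.
static llvm::SetVector<Value>
findEmptyTensorSources(const AnalysisState &state, OpOperand *operand) {
  // Before this PR: state.findValueInReverseUseDefChain(operand->get(), ...).
  // After this PR: the operand itself is passed and is recorded as visited.
  return state.findValueInReverseUseDefChain(
      operand, /*condition=*/
      [](Value val) { return val.getDefiningOp<tensor::EmptyOp>(); });
}
```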
@@ -127,9 +127,9 @@ class OneShotAnalysisState : public AnalysisState {
/// Return true if the buffer of the given tensor value is writable.
bool isWritable(Value value) const;

/// Find the definitions of the given tensor value or retrieve them from the
/// cache.
const SetVector<Value> &findDefinitionsCached(Value value);
/// Find the definitions of the given operand's value or
/// retrieve them from the cache.
const SetVector<Value> &findDefinitionsCached(OpOperand *opOperand);

/// Reset cached data structures.
void resetCache() override;
27 changes: 16 additions & 11 deletions mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -480,18 +480,21 @@ bool AnalysisState::isValueRead(Value value) const {
return false;
}

// Starting from `value`, follow the use-def chain in reverse, always selecting
// the aliasing OpOperands. Find and return Values for which `condition`
// evaluates to true. OpOperands of such matching Values are not traversed any
// further, the visited aliasing opOperands will be preserved through
// `visitedOpOperands`.
// Starting from `opOperand`, follow the use-def chain in reverse, always
// selecting the aliasing OpOperands. Find and return Values for which
// `condition` evaluates to true. Uses of such matching Values are not
// traversed any further, the visited aliasing opOperands will be preserved
// through `visitedOpOperands`.
llvm::SetVector<Value> AnalysisState::findValueInReverseUseDefChain(
Value value, llvm::function_ref<bool(Value)> condition,
OpOperand *opOperand, llvm::function_ref<bool(Value)> condition,
TraversalConfig config,
llvm::DenseSet<OpOperand *> *visitedOpOperands) const {
llvm::DenseSet<Value> visited;
llvm::SetVector<Value> result, workingSet;
workingSet.insert(value);
workingSet.insert(opOperand->get());

if (visitedOpOperands)
visitedOpOperands->insert(opOperand);

while (!workingSet.empty()) {
Value value = workingSet.pop_back_val();
@@ -563,12 +566,14 @@ llvm::SetVector<Value> AnalysisState::findValueInReverseUseDefChain(
return result;
}

// Find the values that define the contents of the given value.
llvm::SetVector<Value> AnalysisState::findDefinitions(Value value) const {
// Find the values that define the contents of the given operand's value.
llvm::SetVector<Value>
AnalysisState::findDefinitions(OpOperand *opOperand) const {
TraversalConfig config;
config.alwaysIncludeLeaves = false;
return findValueInReverseUseDefChain(
value, [&](Value v) { return this->bufferizesToMemoryWrite(v); }, config);
opOperand, [&](Value v) { return this->bufferizesToMemoryWrite(v); },
config);
}

AnalysisState::AnalysisState(const BufferizationOptions &options)
@@ -892,7 +897,7 @@ bool bufferization::detail::defaultResultBufferizesToMemoryWrite(
config.alwaysIncludeLeaves = false;
for (AliasingOpOperand alias : opOperands) {
if (!state
.findValueInReverseUseDefChain(alias.opOperand->get(),
.findValueInReverseUseDefChain(alias.opOperand,
isMemoryWriteInsideOp, config)
.empty())
return true;
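One behavioral detail of the implementation above matters downstream: the starting operand is inserted into `visitedOpOperands` before the walk begins. A hedged sketch of a caller relying on that, reusing the includes and using-directives from the first sketch (hypothetical helper, not part of this PR):

```cpp
// Hypothetical check: the starting operand is recorded before the reverse walk
// begins, so it is present in the visited set afterwards, even when the
// condition already matches the operand's own value.
static void checkStartingOperandIsVisited(const AnalysisState &state,
                                          OpOperand *operand) {
  llvm::DenseSet<OpOperand *> visitedOpOperands;
  TraversalConfig config;
  (void)state.findValueInReverseUseDefChain(
      operand, /*condition=*/
      [&](Value v) { return state.bufferizesToMemoryWrite(v); }, config,
      &visitedOpOperands);
  assert(visitedOpOperands.contains(operand) &&
         "starting operand is recorded before the walk");
}
```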
@@ -143,7 +143,7 @@ LogicalResult mlir::bufferization::eliminateEmptyTensors(
// %3 = tensor.insert_slice %2 into ...
config.followSameTypeOrCastsOnly = true;
SetVector<Value> emptyTensors = state.findValueInReverseUseDefChain(
source.get(), /*condition=*/
&source, /*condition=*/
[&](Value val) { return val.getDefiningOp<tensor::EmptyOp>(); }, config,
&visitedOpOperands);

@@ -155,10 +155,8 @@ LogicalResult mlir::bufferization::eliminateEmptyTensors(
visitedOpOperands, [&emptyTensorOp](OpOperand *opOperand) {
return llvm::count(emptyTensorOp->getUses(), *opOperand);
});
// This could be achieved when a use of `emptyTensorOp` is being
// consumed by `SubsetInsertionOpInterface`'s source directly.
if (iter == visitedOpOperands.end())
continue;

assert(iter != visitedOpOperands.end() && "could not find use");
OpOperand *useToBeReplaced = *iter;
Operation *user = useToBeReplaced->getOwner();
auto replacement = subsetsExtractionFn(rewriter, op, emptyTensorOp, user);
46 changes: 26 additions & 20 deletions mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
@@ -196,7 +196,12 @@ void OneShotAnalysisState::gatherUndefinedTensorUses(Operation *op) {

// If there is no preceding definition, the tensor contents are
// undefined.
if (findDefinitionsCached(opResult).empty())
if (opResult.getUses().empty())
continue;
// It does not really matter which use to take to search about
// the value's definitions.
OpOperand *opOperand = &(*opResult.getUses().begin());
if (findDefinitionsCached(opOperand).empty())
for (OpOperand &use : opResult.getUses())
undefinedTensorUses.insert(&use);
}
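The gather logic above now needs an OpOperand to start from, even though "has a preceding definition" is a property of the produced value itself, so any use works equally well. A hedged adapter sketch for callers that only have a Value in hand, reusing the includes and using-directives from the first sketch (hypothetical helper, not part of this PR):

```cpp
// Hypothetical adapter: start the reverse walk from any one use of the value.
// A value with no uses has nothing to start from, so an empty set is returned,
// mirroring the early `continue` above.
static llvm::SetVector<Value> findDefinitionsOfValue(const AnalysisState &state,
                                                     Value value) {
  if (value.use_empty())
    return {};
  OpOperand *anyUse = &(*value.getUses().begin());
  return state.findDefinitions(anyUse);
}
```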
@@ -464,7 +469,8 @@ static void annotateConflict(OpOperand *uRead, OpOperand *uConflictingWrite,
/// indexing. I.e., the tensor types do not change along the use-def chain,
/// apart from static <-> dynamic dim casts.
static bool hasEquivalentValueInReverseUseDefChain(AnalysisState &state,
Value start, Value other) {
OpOperand *start,
Value other) {
TraversalConfig config;
config.followEquivalentOnly = true;
config.alwaysIncludeLeaves = false;
@@ -475,9 +481,10 @@ static bool hasEquivalentValueInReverseUseDefChain(AnalysisState &state,
.empty();
}

/// Return "true" if `value` is originating from a subset that is equivalent to
/// the subset that `subsetOp` inserts into.
static bool matchesInsertDestination(const AnalysisState &state, Value value,
/// Return "true" if the given operand's value is originating from a subset
/// that is equivalent to the subset that `subsetOp` inserts into.
static bool matchesInsertDestination(const AnalysisState &state,
OpOperand *opOperand,
SubsetInsertionOpInterface subsetOp) {
auto matchingSubset = [&](Value val) {
if (auto opResult = dyn_cast<OpResult>(val))
@@ -490,7 +497,7 @@ static bool matchesInsertDestination(const AnalysisState &state, Value value,
// There may be multiple leaves at which the reverse SSA use-def chain lookup
// terminates. All of them must be equivalent subsets.
SetVector<Value> backwardSlice =
state.findValueInReverseUseDefChain(value, matchingSubset);
state.findValueInReverseUseDefChain(opOperand, matchingSubset);
return static_cast<bool>(llvm::all_of(backwardSlice, matchingSubset));
}

@@ -516,7 +523,7 @@ static bool areNonConflictingSubsets(OpOperand *uRead,
// {inplace= [true] }

if (uRead == &subsetOp.getDestinationOperand() &&
matchesInsertDestination(state, uConflictingWrite->get(), subsetOp))
matchesInsertDestination(state, uConflictingWrite, subsetOp))
// Case 1: The main insight is that InsertSliceOp reads only part of
// the destination tensor. The overwritten area is not read. If
// uConflictingWrite writes into exactly the memory location that is
@@ -533,7 +540,7 @@ static bool areNonConflictingSubsets(OpOperand *uRead,

if (uRead == &subsetOp.getSourceOperand() &&
uConflictingWrite == &subsetOp.getDestinationOperand() &&
matchesInsertDestination(state, uRead->get(), subsetOp))
matchesInsertDestination(state, uRead, subsetOp))
// Case 2: The read of the source tensor and the write to the dest
// tensor via an InsertSliceOp is not a conflict if the read is
// reading exactly that part of an equivalent tensor that the
@@ -567,8 +574,7 @@ static bool areNonConflictingSubsets(OpOperand *uRead,
if (uConflictingWrite == &subsetOp.getDestinationOperand() &&
state.areEquivalentBufferizedValues(
uRead->get(), subsetOp.getSourceOperand().get()) &&
matchesInsertDestination(state, subsetOp.getSourceOperand().get(),
subsetOp))
matchesInsertDestination(state, &subsetOp.getSourceOperand(), subsetOp))
return true;

return false;
@@ -600,9 +606,9 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead,
// even though that op just bufferizes to an allocation but does define
// the contents of the buffer.
SetVector<Value> definitionsOrLeaves =
state.findValueInReverseUseDefChain(
uConflictingWrite->get(),
[&](Value v) { return state.bufferizesToMemoryWrite(v); });
state.findValueInReverseUseDefChain(uConflictingWrite, [&](Value v) {
return state.bufferizesToMemoryWrite(v);
});
assert(!definitionsOrLeaves.empty() &&
"expected at least one definition or leaf");

@@ -641,8 +647,7 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead,
// In the above example, if uRead is the OpOperand of reading_op, the
// definition is %0. Note that operations that create an alias but do not
// bufferize to a memory write (such as ExtractSliceOp) are skipped.
const SetVector<Value> &definitions =
state.findDefinitionsCached(uRead->get());
const SetVector<Value> &definitions = state.findDefinitionsCached(uRead);
if (definitions.empty()) {
// Fast path: No conflict if there are no definitions.
LLVM_DEBUG(llvm::dbgs()
@@ -714,9 +719,9 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead,
if (bufferizableOp.bufferizesToElementwiseAccess(
state, {uRead, uConflictingWrite})) {
if (hasEquivalentValueInReverseUseDefChain(
state, uRead->get(), uConflictingWrite->get()) ||
state, uRead, uConflictingWrite->get()) ||
hasEquivalentValueInReverseUseDefChain(
state, uConflictingWrite->get(), uRead->get())) {
state, uConflictingWrite, uRead->get())) {
LLVM_DEBUG(
llvm::dbgs()
<< " no conflict: op bufferizes to element-wise access\n");
@@ -965,11 +970,12 @@ wouldCreateWriteToNonWritableBuffer(OpOperand &operand,
// Bufferization analyses.
//===----------------------------------------------------------------------===//

// Find the values that define the contents of the given value.
// Find the values that define the contents of the given operand's value.
const llvm::SetVector<Value> &
OneShotAnalysisState::findDefinitionsCached(Value value) {
OneShotAnalysisState::findDefinitionsCached(OpOperand *opOperand) {
Value value = opOperand->get();
if (!cachedDefinitions.count(value))
cachedDefinitions[value] = findDefinitions(value);
cachedDefinitions[value] = findDefinitions(opOperand);
return cachedDefinitions[value];
}

@@ -553,7 +553,7 @@ Value linalg::bufferizeToAllocation(
Value alloc = createAllocationForTensor(
rewriter, op->getLoc(), operand->get(), options, memorySpace);
allocs.push_back(alloc);
if (!state.findDefinitions(operand->get()).empty()) {
if (!state.findDefinitions(operand).empty()) {
// Initialize buffer with a copy of the operand data. Not needed if the
// tensor is uninitialized.
createMemcpy(rewriter, op->getLoc(), operand->get(), alloc, options);
@@ -59,7 +59,7 @@ LogicalResult linalg::linalgOpAnchoredEmptyTensorEliminationStep(
config.followEquivalentOnly = true;
config.alwaysIncludeLeaves = false;
SetVector<Value> emptyTensors = state.findValueInReverseUseDefChain(
in->get(), /*condition=*/
in, /*condition=*/
[&](Value val) {
return val.getDefiningOp<tensor::EmptyOp>() &&
val.getType() == in->get().getType();
@@ -465,3 +465,14 @@ func.func @mutli_use_of_the_same_tensor_empty_creates_non_existent_read(%arg1: t
: tensor<5x6x64xf32> into tensor<5x6x128xf32>
return %inserted_slice_1, %res_2 : tensor<5x6x128xf32>, tensor<5x6x64xf32>
}

// -----

// CHECK-LABEL: func.func @direct_use_of_tensor_empty
func.func @direct_use_of_tensor_empty(%arg0: tensor<5x6x128xf32>) -> tensor<5x6x128xf32> {
// CHECK-NOT: memref.alloc
%empty_1 = tensor.empty() : tensor<5x6x64xf32>
%inserted_slice_1 = tensor.insert_slice %empty_1 into %arg0[0, 0, 0][5, 6, 64][1, 1, 1]
: tensor<5x6x64xf32> into tensor<5x6x128xf32>
return %inserted_slice_1 : tensor<5x6x128xf32>
}