-
Notifications
You must be signed in to change notification settings - Fork 13.6k
[mlir][bufferization]-Refactor findValueInReverseUseDefChain to accept opOperand #121304
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][bufferization]-Refactor findValueInReverseUseDefChain to accept opOperand #121304
Conversation
@llvm/pr-subscribers-mlir-bufferization @llvm/pr-subscribers-mlir-linalg Author: Amir Bishara (amirBish) ChangesEdit the Full diff: https://github.com/llvm/llvm-project/pull/121304.diff 8 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index 983f7a29cb2206..d1a102e2a6e4e8 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -456,7 +456,7 @@ class AnalysisState {
/// read by themselves (e.g., ExtractSliceOp).
bool isValueRead(Value value) const;
- /// Starting from `value`, follow the use-def chain in reverse, always
+ /// Starting from `opOperand`, follow the use-def chain in reverse, always
/// selecting the aliasing OpOperands. Find and return Values for which
/// `condition` evaluates to true. OpOperands of such matching Values are not
/// traversed any further, the visited aliasing opOperands will be preserved
@@ -484,7 +484,7 @@ class AnalysisState {
/// Additional stopping conditions for the traversal can be specified in
/// `config`.
SetVector<Value> findValueInReverseUseDefChain(
- Value value, llvm::function_ref<bool(Value)> condition,
+ OpOperand *opOperand, llvm::function_ref<bool(Value)> condition,
TraversalConfig config = TraversalConfig(),
llvm::DenseSet<OpOperand *> *visitedOpOperands = nullptr) const;
@@ -520,7 +520,7 @@ class AnalysisState {
///
/// Note: OpResults of unknown ops are handled conservatively and assumed to
/// be definitions.
- SetVector<Value> findDefinitions(Value value) const;
+ SetVector<Value> findDefinitions(OpOperand *opOperand) const;
/// Return `true` if the given OpResult has been decided to bufferize inplace.
virtual bool isInPlace(OpOperand &opOperand) const;
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
index d50a3042aeeacf..da3094a6d6f546 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
@@ -127,9 +127,9 @@ class OneShotAnalysisState : public AnalysisState {
/// Return true if the buffer of the given tensor value is writable.
bool isWritable(Value value) const;
- /// Find the definitions of the given tensor value or retrieve them from the
- /// cache.
- const SetVector<Value> &findDefinitionsCached(Value value);
+ /// Find the definitions of the given tensor value related to `opOperand` or
+ /// retrieve them from the cache.
+ const SetVector<Value> &findDefinitionsCached(OpOperand *opOperand);
/// Reset cached data structures.
void resetCache() override;
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index 349841f06959c3..7ca9659ef86ee2 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -480,18 +480,21 @@ bool AnalysisState::isValueRead(Value value) const {
return false;
}
-// Starting from `value`, follow the use-def chain in reverse, always selecting
+// Starting from `opOperand`, follow the use-def chain in reverse, always selecting
// the aliasing OpOperands. Find and return Values for which `condition`
// evaluates to true. OpOperands of such matching Values are not traversed any
// further, the visited aliasing opOperands will be preserved through
// `visitedOpOperands`.
llvm::SetVector<Value> AnalysisState::findValueInReverseUseDefChain(
- Value value, llvm::function_ref<bool(Value)> condition,
+ OpOperand *opOperand, llvm::function_ref<bool(Value)> condition,
TraversalConfig config,
llvm::DenseSet<OpOperand *> *visitedOpOperands) const {
llvm::DenseSet<Value> visited;
llvm::SetVector<Value> result, workingSet;
- workingSet.insert(value);
+ workingSet.insert(opOperand->get());
+
+ if (visitedOpOperands)
+ visitedOpOperands->insert(opOperand);
while (!workingSet.empty()) {
Value value = workingSet.pop_back_val();
@@ -563,12 +566,12 @@ llvm::SetVector<Value> AnalysisState::findValueInReverseUseDefChain(
return result;
}
-// Find the values that define the contents of the given value.
-llvm::SetVector<Value> AnalysisState::findDefinitions(Value value) const {
+// Find the values that define the contents of the given opOperand.
+llvm::SetVector<Value> AnalysisState::findDefinitions(OpOperand *opOperand) const {
TraversalConfig config;
config.alwaysIncludeLeaves = false;
return findValueInReverseUseDefChain(
- value, [&](Value v) { return this->bufferizesToMemoryWrite(v); }, config);
+ opOperand, [&](Value v) { return this->bufferizesToMemoryWrite(v); }, config);
}
AnalysisState::AnalysisState(const BufferizationOptions &options)
@@ -892,7 +895,7 @@ bool bufferization::detail::defaultResultBufferizesToMemoryWrite(
config.alwaysIncludeLeaves = false;
for (AliasingOpOperand alias : opOperands) {
if (!state
- .findValueInReverseUseDefChain(alias.opOperand->get(),
+ .findValueInReverseUseDefChain(alias.opOperand,
isMemoryWriteInsideOp, config)
.empty())
return true;
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
index 98c3d8d0adc6d2..84c2da6df093bd 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
@@ -143,7 +143,7 @@ LogicalResult mlir::bufferization::eliminateEmptyTensors(
// %3 = tensor.insert_slice %2 into ...
config.followSameTypeOrCastsOnly = true;
SetVector<Value> emptyTensors = state.findValueInReverseUseDefChain(
- source.get(), /*condition=*/
+ &source, /*condition=*/
[&](Value val) { return val.getDefiningOp<tensor::EmptyOp>(); }, config,
&visitedOpOperands);
@@ -155,10 +155,8 @@ LogicalResult mlir::bufferization::eliminateEmptyTensors(
visitedOpOperands, [&emptyTensorOp](OpOperand *opOperand) {
return llvm::count(emptyTensorOp->getUses(), *opOperand);
});
- // This could be achieved when a use of `emptyTensorOp` is being
- // consumed by `SubsetInsertionOpInterface`'s source directly.
- if (iter == visitedOpOperands.end())
- continue;
+
+ assert (iter != visitedOpOperands.end());
OpOperand *useToBeReplaced = *iter;
Operation *user = useToBeReplaced->getOwner();
auto replacement = subsetsExtractionFn(rewriter, op, emptyTensorOp, user);
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
index d1e6acef324fbd..2f50b0f02876dd 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
@@ -196,7 +196,12 @@ void OneShotAnalysisState::gatherUndefinedTensorUses(Operation *op) {
// If there is no preceding definition, the tensor contents are
// undefined.
- if (findDefinitionsCached(opResult).empty())
+ if (opResult.getUses().empty())
+ return WalkResult::skip();
+ // It does not really matter which use to take to search about
+ // the value's definitions.
+ OpOperand *opOperand = &(*opResult.getUses().begin());
+ if (findDefinitionsCached(opOperand).empty())
for (OpOperand &use : opResult.getUses())
undefinedTensorUses.insert(&use);
}
@@ -464,20 +469,20 @@ static void annotateConflict(OpOperand *uRead, OpOperand *uConflictingWrite,
/// indexing. I.e., the tensor types do not change along the use-def chain,
/// apart from static <-> dynamic dim casts.
static bool hasEquivalentValueInReverseUseDefChain(AnalysisState &state,
- Value start, Value other) {
+ OpOperand *start, OpOperand *other) {
TraversalConfig config;
config.followEquivalentOnly = true;
config.alwaysIncludeLeaves = false;
config.followSameTypeOrCastsOnly = true;
return !state
.findValueInReverseUseDefChain(
- start, [&](Value v) { return v == other; }, config)
+ start, [&](Value v) { return v == other->get(); }, config)
.empty();
}
-/// Return "true" if `value` is originating from a subset that is equivalent to
+/// Return "true" if `opOperand` is originating from a subset that is equivalent to
/// the subset that `subsetOp` inserts into.
-static bool matchesInsertDestination(const AnalysisState &state, Value value,
+static bool matchesInsertDestination(const AnalysisState &state, OpOperand *opOperand,
SubsetInsertionOpInterface subsetOp) {
auto matchingSubset = [&](Value val) {
if (auto opResult = dyn_cast<OpResult>(val))
@@ -490,7 +495,7 @@ static bool matchesInsertDestination(const AnalysisState &state, Value value,
// There may be multiple leaves at which the reverse SSA use-def chain lookup
// terminates. All of them must be equivalent subsets.
SetVector<Value> backwardSlice =
- state.findValueInReverseUseDefChain(value, matchingSubset);
+ state.findValueInReverseUseDefChain(opOperand, matchingSubset);
return static_cast<bool>(llvm::all_of(backwardSlice, matchingSubset));
}
@@ -516,7 +521,7 @@ static bool areNonConflictingSubsets(OpOperand *uRead,
// {inplace= [true] }
if (uRead == &subsetOp.getDestinationOperand() &&
- matchesInsertDestination(state, uConflictingWrite->get(), subsetOp))
+ matchesInsertDestination(state, uConflictingWrite, subsetOp))
// Case 1: The main insight is that InsertSliceOp reads only part of
// the destination tensor. The overwritten area is not read. If
// uConflictingWrite writes into exactly the memory location that is
@@ -533,7 +538,7 @@ static bool areNonConflictingSubsets(OpOperand *uRead,
if (uRead == &subsetOp.getSourceOperand() &&
uConflictingWrite == &subsetOp.getDestinationOperand() &&
- matchesInsertDestination(state, uRead->get(), subsetOp))
+ matchesInsertDestination(state, uRead, subsetOp))
// Case 2: The read of the source tensor and the write to the dest
// tensor via an InsertSliceOp is not a conflict if the read is
// reading exactly that part of an equivalent tensor that the
@@ -567,7 +572,7 @@ static bool areNonConflictingSubsets(OpOperand *uRead,
if (uConflictingWrite == &subsetOp.getDestinationOperand() &&
state.areEquivalentBufferizedValues(
uRead->get(), subsetOp.getSourceOperand().get()) &&
- matchesInsertDestination(state, subsetOp.getSourceOperand().get(),
+ matchesInsertDestination(state, &subsetOp.getSourceOperand(),
subsetOp))
return true;
@@ -601,7 +606,7 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead,
// the contents of the buffer.
SetVector<Value> definitionsOrLeaves =
state.findValueInReverseUseDefChain(
- uConflictingWrite->get(),
+ uConflictingWrite,
[&](Value v) { return state.bufferizesToMemoryWrite(v); });
assert(!definitionsOrLeaves.empty() &&
"expected at least one definition or leaf");
@@ -642,7 +647,7 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead,
// definition is %0. Note that operations that create an alias but do not
// bufferize to a memory write (such as ExtractSliceOp) are skipped.
const SetVector<Value> &definitions =
- state.findDefinitionsCached(uRead->get());
+ state.findDefinitionsCached(uRead);
if (definitions.empty()) {
// Fast path: No conflict if there are no definitions.
LLVM_DEBUG(llvm::dbgs()
@@ -714,9 +719,9 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead,
if (bufferizableOp.bufferizesToElementwiseAccess(
state, {uRead, uConflictingWrite})) {
if (hasEquivalentValueInReverseUseDefChain(
- state, uRead->get(), uConflictingWrite->get()) ||
+ state, uRead, uConflictingWrite) ||
hasEquivalentValueInReverseUseDefChain(
- state, uConflictingWrite->get(), uRead->get())) {
+ state, uConflictingWrite, uRead)) {
LLVM_DEBUG(
llvm::dbgs()
<< " no conflict: op bufferizes to element-wise access\n");
@@ -965,11 +970,12 @@ wouldCreateWriteToNonWritableBuffer(OpOperand &operand,
// Bufferization analyses.
//===----------------------------------------------------------------------===//
-// Find the values that define the contents of the given value.
+// Find the values that define the contents of the given opOperand.
const llvm::SetVector<Value> &
-OneShotAnalysisState::findDefinitionsCached(Value value) {
+OneShotAnalysisState::findDefinitionsCached(OpOperand *opOperand) {
+ Value value = opOperand->get();
if (!cachedDefinitions.count(value))
- cachedDefinitions[value] = findDefinitions(value);
+ cachedDefinitions[value] = findDefinitions(opOperand);
return cachedDefinitions[value];
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
index 6801b68a853815..6c1087730ebba8 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
@@ -553,7 +553,7 @@ Value linalg::bufferizeToAllocation(
Value alloc = createAllocationForTensor(
rewriter, op->getLoc(), operand->get(), options, memorySpace);
allocs.push_back(alloc);
- if (!state.findDefinitions(operand->get()).empty()) {
+ if (!state.findDefinitions(operand).empty()) {
// Initialize buffer with a copy of the operand data. Not needed if the
// tensor is uninitialized.
createMemcpy(rewriter, op->getLoc(), operand->get(), alloc, options);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/EliminateEmptyTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/EliminateEmptyTensors.cpp
index 4776883ed95c5c..b710bde87f9f33 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/EliminateEmptyTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/EliminateEmptyTensors.cpp
@@ -59,7 +59,7 @@ LogicalResult linalg::linalgOpAnchoredEmptyTensorEliminationStep(
config.followEquivalentOnly = true;
config.alwaysIncludeLeaves = false;
SetVector<Value> emptyTensors = state.findValueInReverseUseDefChain(
- in->get(), /*condition=*/
+ in, /*condition=*/
[&](Value val) {
return val.getDefiningOp<tensor::EmptyOp>() &&
val.getType() == in->get().getType();
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir
index 26434774730e1b..41ab9cd113b39a 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir
@@ -465,3 +465,14 @@ func.func @mutli_use_of_the_same_tensor_empty_creates_non_existent_read(%arg1: t
: tensor<5x6x64xf32> into tensor<5x6x128xf32>
return %inserted_slice_1, %res_2 : tensor<5x6x128xf32>, tensor<5x6x64xf32>
}
+
+// -----
+
+// CHECK-LABEL: func.func @direct_use_of_tensor_empty
+func.func @direct_use_of_tensor_empty(%arg0: tensor<5x6x128xf32>) -> tensor<5x6x128xf32> {
+ // CHECK-NOT: memref.alloc
+ %empty_1 = tensor.empty() : tensor<5x6x64xf32>
+ %inserted_slice_1 = tensor.insert_slice %empty_1 into %arg0[0, 0, 0][5, 6, 64][1, 1, 1]
+ : tensor<5x6x64xf32> into tensor<5x6x128xf32>
+ return %inserted_slice_1 : tensor<5x6x128xf32>
+}
\ No newline at end of file
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
c883d9e
to
7b5e2c4
Compare
mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
Outdated
Show resolved
Hide resolved
mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
Outdated
Show resolved
Hide resolved
mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir
Outdated
Show resolved
Hide resolved
…t opOperand Edit the `findValueInReverseUseDefChain` method to accept `OpOperand` instead of the `Value` type, This change will make sure that the populated `visitedOpOperands` argument is fully accurate and contains the opOperand we have started the reverse chain from.
7b5e2c4
to
f35c557
Compare
Thanks for the review, fixed the threads. Feel free to have a look again :) |
Local branch amd-gfx 00799b0 Merged main:c7d237085bf9 into amd-gfx:4b28d14fac77 Remote branch main d9111f1 [mlir][bufferization]-Refactor findValueInReverseUseDefChain to accept opOperand (llvm#121304)
Edit the
findValueInReverseUseDefChain
method to acceptOpOperand
instead of theValue
type, This change will make sure that the populatedvisitedOpOperands
argument is fully accurate and contains the opOperand we have started the reverse chain from.