
[mlir][bufferization]-Refactor findValueInReverseUseDefChain to accept opOperand #121304


Merged

Conversation

amirBish
Contributor

Change the `findValueInReverseUseDefChain` method to accept an `OpOperand *` instead of a `Value`. This ensures that the populated `visitedOpOperands` argument is complete and contains the opOperand from which the reverse traversal started.
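
To illustrate why the starting operand matters, here is a minimal, self-contained sketch of a reverse use-def walk. It is a toy model, not the MLIR implementation: `ToyOp`, `ToyOperand`, `findInReverseChain`, the `recordStart` flag, and `visitedCountForDirectUse` are all hypothetical names introduced for this example, and every op is assumed to have a single result.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <set>
#include <vector>

// Hypothetical stand-ins for mlir::Operation / mlir::OpOperand (not the MLIR API).
struct ToyOp;

struct ToyOperand {
  ToyOp *def = nullptr; // op whose single result this operand reads
};

struct ToyOp {
  bool isEmpty = false;             // plays the role of tensor.empty
  std::vector<ToyOperand> operands; // operands followed during the walk
};

// Walks the chain in reverse starting from `start`. `recordStart` toggles
// between the old behavior (starting operand not recorded) and the new one.
std::vector<ToyOp *> findInReverseChain(
    ToyOperand *start, const std::function<bool(ToyOp *)> &condition,
    std::set<ToyOperand *> *visitedOperands, bool recordStart) {
  std::vector<ToyOp *> result, workingSet;
  std::set<ToyOp *> seen;
  if (start->def)
    workingSet.push_back(start->def);
  if (recordStart && visitedOperands)
    visitedOperands->insert(start);

  while (!workingSet.empty()) {
    ToyOp *op = workingSet.back();
    workingSet.pop_back();
    if (!seen.insert(op).second)
      continue; // already processed
    if (condition(op)) {
      result.push_back(op); // match: do not traverse further
      continue;
    }
    for (ToyOperand &operand : op->operands) {
      if (visitedOperands)
        visitedOperands->insert(&operand);
      if (operand.def)
        workingSet.push_back(operand.def);
    }
  }
  return result;
}

// Models an "empty" op consumed directly by the starting operand and returns
// how many operands the walk recorded.
std::size_t visitedCountForDirectUse(bool recordStart) {
  ToyOp empty;
  empty.isEmpty = true;
  ToyOperand source;
  source.def = &empty;
  std::set<ToyOperand *> visited;
  std::vector<ToyOp *> found = findInReverseChain(
      &source, [](ToyOp *op) { return op->isEmpty; }, &visited, recordStart);
  assert(found.size() == 1);
  return visited.size();
}
```

In this toy model, `visitedCountForDirectUse(false)` returns 0 even though the empty op is found, while `visitedCountForDirectUse(true)` returns 1. A caller that looks up the matching use in the visited set therefore only finds it when the starting operand is recorded.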

@llvmbot
Member

llvmbot commented Dec 29, 2024

@llvm/pr-subscribers-mlir-bufferization

@llvm/pr-subscribers-mlir-linalg

Author: Amir Bishara (amirBish)


Full diff: https://github.com/llvm/llvm-project/pull/121304.diff

8 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h (+3-3)
  • (modified) mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h (+3-3)
  • (modified) mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp (+10-7)
  • (modified) mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp (+3-5)
  • (modified) mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp (+22-16)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp (+1-1)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/EliminateEmptyTensors.cpp (+1-1)
  • (modified) mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir (+11)
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index 983f7a29cb2206..d1a102e2a6e4e8 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -456,7 +456,7 @@ class AnalysisState {
   /// read by themselves (e.g., ExtractSliceOp).
   bool isValueRead(Value value) const;
 
-  /// Starting from `value`, follow the use-def chain in reverse, always
+  /// Starting from `opOperand`, follow the use-def chain in reverse, always
   /// selecting the aliasing OpOperands. Find and return Values for which
   /// `condition` evaluates to true. OpOperands of such matching Values are not
   /// traversed any further, the visited aliasing opOperands will be preserved
@@ -484,7 +484,7 @@ class AnalysisState {
   /// Additional stopping conditions for the traversal can be specified in
   /// `config`.
   SetVector<Value> findValueInReverseUseDefChain(
-      Value value, llvm::function_ref<bool(Value)> condition,
+      OpOperand *opOperand, llvm::function_ref<bool(Value)> condition,
       TraversalConfig config = TraversalConfig(),
       llvm::DenseSet<OpOperand *> *visitedOpOperands = nullptr) const;
 
@@ -520,7 +520,7 @@ class AnalysisState {
   ///
   /// Note: OpResults of unknown ops are handled conservatively and assumed to
   /// be definitions.
-  SetVector<Value> findDefinitions(Value value) const;
+  SetVector<Value> findDefinitions(OpOperand *opOperand) const;
 
   /// Return `true` if the given OpResult has been decided to bufferize inplace.
   virtual bool isInPlace(OpOperand &opOperand) const;
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
index d50a3042aeeacf..da3094a6d6f546 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
@@ -127,9 +127,9 @@ class OneShotAnalysisState : public AnalysisState {
   /// Return true if the buffer of the given tensor value is writable.
   bool isWritable(Value value) const;
 
-  /// Find the definitions of the given tensor value or retrieve them from the
-  /// cache.
-  const SetVector<Value> &findDefinitionsCached(Value value);
+  /// Find the definitions of the given tensor value related to `opOperand` or
+  /// retrieve them from the cache.
+  const SetVector<Value> &findDefinitionsCached(OpOperand *opOperand);
 
   /// Reset cached data structures.
   void resetCache() override;
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index 349841f06959c3..7ca9659ef86ee2 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -480,18 +480,21 @@ bool AnalysisState::isValueRead(Value value) const {
   return false;
 }
 
-// Starting from `value`, follow the use-def chain in reverse, always selecting
+// Starting from `opOperand`, follow the use-def chain in reverse, always selecting
 // the aliasing OpOperands. Find and return Values for which `condition`
 // evaluates to true. OpOperands of such matching Values are not traversed any
 // further, the visited aliasing opOperands will be preserved through
 // `visitedOpOperands`.
 llvm::SetVector<Value> AnalysisState::findValueInReverseUseDefChain(
-    Value value, llvm::function_ref<bool(Value)> condition,
+    OpOperand *opOperand, llvm::function_ref<bool(Value)> condition,
     TraversalConfig config,
     llvm::DenseSet<OpOperand *> *visitedOpOperands) const {
   llvm::DenseSet<Value> visited;
   llvm::SetVector<Value> result, workingSet;
-  workingSet.insert(value);
+  workingSet.insert(opOperand->get());
+
+  if (visitedOpOperands)
+    visitedOpOperands->insert(opOperand);
 
   while (!workingSet.empty()) {
     Value value = workingSet.pop_back_val();
@@ -563,12 +566,12 @@ llvm::SetVector<Value> AnalysisState::findValueInReverseUseDefChain(
   return result;
 }
 
-// Find the values that define the contents of the given value.
-llvm::SetVector<Value> AnalysisState::findDefinitions(Value value) const {
+// Find the values that define the contents of the given opOperand.
+llvm::SetVector<Value> AnalysisState::findDefinitions(OpOperand *opOperand) const {
   TraversalConfig config;
   config.alwaysIncludeLeaves = false;
   return findValueInReverseUseDefChain(
-      value, [&](Value v) { return this->bufferizesToMemoryWrite(v); }, config);
+      opOperand, [&](Value v) { return this->bufferizesToMemoryWrite(v); }, config);
 }
 
 AnalysisState::AnalysisState(const BufferizationOptions &options)
@@ -892,7 +895,7 @@ bool bufferization::detail::defaultResultBufferizesToMemoryWrite(
   config.alwaysIncludeLeaves = false;
   for (AliasingOpOperand alias : opOperands) {
     if (!state
-             .findValueInReverseUseDefChain(alias.opOperand->get(),
+             .findValueInReverseUseDefChain(alias.opOperand,
                                             isMemoryWriteInsideOp, config)
              .empty())
       return true;
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
index 98c3d8d0adc6d2..84c2da6df093bd 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
@@ -143,7 +143,7 @@ LogicalResult mlir::bufferization::eliminateEmptyTensors(
     // %3 = tensor.insert_slice %2 into ...
     config.followSameTypeOrCastsOnly = true;
     SetVector<Value> emptyTensors = state.findValueInReverseUseDefChain(
-        source.get(), /*condition=*/
+        &source, /*condition=*/
         [&](Value val) { return val.getDefiningOp<tensor::EmptyOp>(); }, config,
         &visitedOpOperands);
 
@@ -155,10 +155,8 @@ LogicalResult mlir::bufferization::eliminateEmptyTensors(
           visitedOpOperands, [&emptyTensorOp](OpOperand *opOperand) {
             return llvm::count(emptyTensorOp->getUses(), *opOperand);
           });
-      // This could be achieved when a use of `emptyTensorOp` is being
-      // consumed by `SubsetInsertionOpInterface`'s source directly.
-      if (iter == visitedOpOperands.end())
-        continue;
+
+      assert (iter != visitedOpOperands.end());
       OpOperand *useToBeReplaced = *iter;
       Operation *user = useToBeReplaced->getOwner();
       auto replacement = subsetsExtractionFn(rewriter, op, emptyTensorOp, user);
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
index d1e6acef324fbd..2f50b0f02876dd 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
@@ -196,7 +196,12 @@ void OneShotAnalysisState::gatherUndefinedTensorUses(Operation *op) {
 
       // If there is no preceding definition, the tensor contents are
       // undefined.
-      if (findDefinitionsCached(opResult).empty())
+      if (opResult.getUses().empty())
+        return WalkResult::skip();
+      // It does not really matter which use to take to search about
+      // the value's definitions.
+      OpOperand *opOperand = &(*opResult.getUses().begin());
+      if (findDefinitionsCached(opOperand).empty())
         for (OpOperand &use : opResult.getUses())
           undefinedTensorUses.insert(&use);
     }
@@ -464,20 +469,20 @@ static void annotateConflict(OpOperand *uRead, OpOperand *uConflictingWrite,
 /// indexing. I.e., the tensor types do not change along the use-def chain,
 /// apart from static <-> dynamic dim casts.
 static bool hasEquivalentValueInReverseUseDefChain(AnalysisState &state,
-                                                   Value start, Value other) {
+                                                   OpOperand *start, OpOperand *other) {
   TraversalConfig config;
   config.followEquivalentOnly = true;
   config.alwaysIncludeLeaves = false;
   config.followSameTypeOrCastsOnly = true;
   return !state
               .findValueInReverseUseDefChain(
-                  start, [&](Value v) { return v == other; }, config)
+                  start, [&](Value v) { return v == other->get(); }, config)
               .empty();
 }
 
-/// Return "true" if `value` is originating from a subset that is equivalent to
+/// Return "true" if `opOperand` is originating from a subset that is equivalent to
 /// the subset that `subsetOp` inserts into.
-static bool matchesInsertDestination(const AnalysisState &state, Value value,
+static bool matchesInsertDestination(const AnalysisState &state, OpOperand *opOperand,
                                      SubsetInsertionOpInterface subsetOp) {
   auto matchingSubset = [&](Value val) {
     if (auto opResult = dyn_cast<OpResult>(val))
@@ -490,7 +495,7 @@ static bool matchesInsertDestination(const AnalysisState &state, Value value,
   // There may be multiple leaves at which the reverse SSA use-def chain lookup
   // terminates. All of them must be equivalent subsets.
   SetVector<Value> backwardSlice =
-      state.findValueInReverseUseDefChain(value, matchingSubset);
+      state.findValueInReverseUseDefChain(opOperand, matchingSubset);
   return static_cast<bool>(llvm::all_of(backwardSlice, matchingSubset));
 }
 
@@ -516,7 +521,7 @@ static bool areNonConflictingSubsets(OpOperand *uRead,
     //     {inplace= [true] }
 
     if (uRead == &subsetOp.getDestinationOperand() &&
-        matchesInsertDestination(state, uConflictingWrite->get(), subsetOp))
+        matchesInsertDestination(state, uConflictingWrite, subsetOp))
       // Case 1: The main insight is that InsertSliceOp reads only part of
       // the destination tensor. The overwritten area is not read. If
       // uConflictingWrite writes into exactly the memory location that is
@@ -533,7 +538,7 @@ static bool areNonConflictingSubsets(OpOperand *uRead,
 
     if (uRead == &subsetOp.getSourceOperand() &&
         uConflictingWrite == &subsetOp.getDestinationOperand() &&
-        matchesInsertDestination(state, uRead->get(), subsetOp))
+        matchesInsertDestination(state, uRead, subsetOp))
       // Case 2: The read of the source tensor and the write to the dest
       // tensor via an InsertSliceOp is not a conflict if the read is
       // reading exactly that part of an equivalent tensor that the
@@ -567,7 +572,7 @@ static bool areNonConflictingSubsets(OpOperand *uRead,
     if (uConflictingWrite == &subsetOp.getDestinationOperand() &&
         state.areEquivalentBufferizedValues(
             uRead->get(), subsetOp.getSourceOperand().get()) &&
-        matchesInsertDestination(state, subsetOp.getSourceOperand().get(),
+        matchesInsertDestination(state, &subsetOp.getSourceOperand(),
                                  subsetOp))
       return true;
 
@@ -601,7 +606,7 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead,
       // the contents of the buffer.
       SetVector<Value> definitionsOrLeaves =
           state.findValueInReverseUseDefChain(
-              uConflictingWrite->get(),
+              uConflictingWrite,
               [&](Value v) { return state.bufferizesToMemoryWrite(v); });
       assert(!definitionsOrLeaves.empty() &&
              "expected at least one definition or leaf");
@@ -642,7 +647,7 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead,
     // definition is %0. Note that operations that create an alias but do not
     // bufferize to a memory write (such as ExtractSliceOp) are skipped.
     const SetVector<Value> &definitions =
-        state.findDefinitionsCached(uRead->get());
+        state.findDefinitionsCached(uRead);
     if (definitions.empty()) {
       // Fast path: No conflict if there are no definitions.
       LLVM_DEBUG(llvm::dbgs()
@@ -714,9 +719,9 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead,
             if (bufferizableOp.bufferizesToElementwiseAccess(
                     state, {uRead, uConflictingWrite})) {
               if (hasEquivalentValueInReverseUseDefChain(
-                      state, uRead->get(), uConflictingWrite->get()) ||
+                      state, uRead, uConflictingWrite) ||
                   hasEquivalentValueInReverseUseDefChain(
-                      state, uConflictingWrite->get(), uRead->get())) {
+                      state, uConflictingWrite, uRead)) {
                 LLVM_DEBUG(
                     llvm::dbgs()
                     << "  no conflict: op bufferizes to element-wise access\n");
@@ -965,11 +970,12 @@ wouldCreateWriteToNonWritableBuffer(OpOperand &operand,
 // Bufferization analyses.
 //===----------------------------------------------------------------------===//
 
-// Find the values that define the contents of the given value.
+// Find the values that define the contents of the given opOperand.
 const llvm::SetVector<Value> &
-OneShotAnalysisState::findDefinitionsCached(Value value) {
+OneShotAnalysisState::findDefinitionsCached(OpOperand *opOperand) {
+  Value value = opOperand->get();
   if (!cachedDefinitions.count(value))
-    cachedDefinitions[value] = findDefinitions(value);
+    cachedDefinitions[value] = findDefinitions(opOperand);
   return cachedDefinitions[value];
 }
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
index 6801b68a853815..6c1087730ebba8 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
@@ -553,7 +553,7 @@ Value linalg::bufferizeToAllocation(
     Value alloc = createAllocationForTensor(
         rewriter, op->getLoc(), operand->get(), options, memorySpace);
     allocs.push_back(alloc);
-    if (!state.findDefinitions(operand->get()).empty()) {
+    if (!state.findDefinitions(operand).empty()) {
       // Initialize buffer with a copy of the operand data. Not needed if the
       // tensor is uninitialized.
       createMemcpy(rewriter, op->getLoc(), operand->get(), alloc, options);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/EliminateEmptyTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/EliminateEmptyTensors.cpp
index 4776883ed95c5c..b710bde87f9f33 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/EliminateEmptyTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/EliminateEmptyTensors.cpp
@@ -59,7 +59,7 @@ LogicalResult linalg::linalgOpAnchoredEmptyTensorEliminationStep(
       config.followEquivalentOnly = true;
       config.alwaysIncludeLeaves = false;
       SetVector<Value> emptyTensors = state.findValueInReverseUseDefChain(
-          in->get(), /*condition=*/
+          in, /*condition=*/
           [&](Value val) {
             return val.getDefiningOp<tensor::EmptyOp>() &&
                    val.getType() == in->get().getType();
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir
index 26434774730e1b..41ab9cd113b39a 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir
@@ -465,3 +465,14 @@ func.func @mutli_use_of_the_same_tensor_empty_creates_non_existent_read(%arg1: t
       : tensor<5x6x64xf32> into tensor<5x6x128xf32>
   return %inserted_slice_1, %res_2 : tensor<5x6x128xf32>, tensor<5x6x64xf32>
 }
+
+// -----
+
+// CHECK-LABEL:   func.func @direct_use_of_tensor_empty
+func.func @direct_use_of_tensor_empty(%arg0: tensor<5x6x128xf32>) -> tensor<5x6x128xf32> {
+  // CHECK-NOT: memref.alloc
+  %empty_1 = tensor.empty() : tensor<5x6x64xf32>
+  %inserted_slice_1 = tensor.insert_slice %empty_1 into %arg0[0, 0, 0][5, 6, 64][1, 1, 1]
+      : tensor<5x6x64xf32> into tensor<5x6x128xf32>
+  return %inserted_slice_1 : tensor<5x6x128xf32>
+}
\ No newline at end of file
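
As a reading aid for the `findDefinitionsCached` change in the diff above: the function now takes an operand, but the cache is still keyed by the operand's value, so every use of the same value shares one entry. The following is a minimal sketch of that keying under simplified assumptions. `ToyValue`, `ToyUse`, `ToyDefinitionCache`, and `missesForTwoUsesOfOneValue` are hypothetical names, and the "definitions" are plain integers rather than MLIR values.

```cpp
#include <cassert>
#include <map>
#include <vector>

// Hypothetical stand-ins; not MLIR types.
struct ToyValue {
  std::vector<int> definitions; // precomputed "definitions", as plain ids
};

struct ToyUse {
  ToyValue *value; // the value this use reads
};

class ToyDefinitionCache {
public:
  // Accepts a use, but keys the cache by the value that the use reads.
  const std::vector<int> &findDefinitionsCached(const ToyUse *use) {
    ToyValue *value = use->value;
    auto it = cache.find(value);
    if (it == cache.end()) {
      ++misses; // computed once per value, not once per use
      it = cache.emplace(value, value->definitions).first;
    }
    return it->second;
  }

  int misses = 0;

private:
  std::map<ToyValue *, std::vector<int>> cache;
};

// Queries two different uses of the same value and reports the miss count.
int missesForTwoUsesOfOneValue() {
  ToyValue value{{7}};
  ToyUse first{&value}, second{&value};
  ToyDefinitionCache cache;
  const std::vector<int> &a = cache.findDefinitionsCached(&first);
  const std::vector<int> &b = cache.findDefinitionsCached(&second);
  assert(&a == &b); // both uses resolve to the same cache entry
  return cache.misses;
}
```

Here `missesForTwoUsesOfOneValue()` returns 1. This matches the comment in `gatherUndefinedTensorUses` that it does not matter which use is picked when looking up a value's definitions.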

@amirBish amirBish requested a review from cathyzhyi December 29, 2024 22:07

github-actions bot commented Dec 29, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

@amirBish amirBish force-pushed the amirBish/mlir/edit-bufferization-analyisis-method branch from c883d9e to 7b5e2c4 Compare December 29, 2024 22:09
…t opOperand

Edit the `findValueInReverseUseDefChain` method to accept `OpOperand`
instead of the `Value` type. This change ensures that the populated
`visitedOpOperands` argument is fully accurate and contains the
opOperand from which the reverse chain started.
@amirBish amirBish force-pushed the amirBish/mlir/edit-bufferization-analyisis-method branch from 7b5e2c4 to f35c557 Compare December 30, 2024 19:00
@amirBish
Contributor Author

Thanks for the review, fixed the threads. Feel free to have a look again :)

@amirBish amirBish merged commit d9111f1 into llvm:main Dec 30, 2024
8 checks passed
qiaojbao pushed a commit to GPUOpen-Drivers/llvm-project that referenced this pull request Feb 7, 2025
Local branch amd-gfx 00799b0 Merged main:c7d237085bf9 into amd-gfx:4b28d14fac77
Remote branch main d9111f1 [mlir][bufferization]-Refactor findValueInReverseUseDefChain to accept opOperand (llvm#121304)
Labels
mlir:bufferization Bufferization infrastructure mlir:linalg mlir