diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index 983f7a29cb220..d1a102e2a6e4e 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -456,7 +456,7 @@ class AnalysisState {
   /// read by themselves (e.g., ExtractSliceOp).
   bool isValueRead(Value value) const;
 
-  /// Starting from `value`, follow the use-def chain in reverse, always
+  /// Starting from `opOperand`, follow the use-def chain in reverse, always
   /// selecting the aliasing OpOperands. Find and return Values for which
   /// `condition` evaluates to true. OpOperands of such matching Values are not
   /// traversed any further, the visited aliasing opOperands will be preserved
   /// through `visitedOpOperands`.
@@ -484,7 +484,7 @@ class AnalysisState {
   /// Additional stopping conditions for the traversal can be specified in
   /// `config`.
   SetVector<Value> findValueInReverseUseDefChain(
-      Value value, llvm::function_ref<bool(Value)> condition,
+      OpOperand *opOperand, llvm::function_ref<bool(Value)> condition,
      TraversalConfig config = TraversalConfig(),
      llvm::DenseSet<OpOperand *> *visitedOpOperands = nullptr) const;
 
@@ -520,7 +520,7 @@ class AnalysisState {
   ///
   /// Note: OpResults of unknown ops are handled conservatively and assumed to
   /// be definitions.
-  SetVector<Value> findDefinitions(Value value) const;
+  SetVector<Value> findDefinitions(OpOperand *opOperand) const;
 
   /// Return `true` if the given OpResult has been decided to bufferize inplace.
   virtual bool isInPlace(OpOperand &opOperand) const;
diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
index d50a3042aeeac..bd23a19f74728 100644
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h
@@ -127,9 +127,9 @@ class OneShotAnalysisState : public AnalysisState {
   /// Return true if the buffer of the given tensor value is writable.
   bool isWritable(Value value) const;
 
-  /// Find the definitions of the given tensor value or retrieve them from the
-  /// cache.
-  const SetVector<Value> &findDefinitionsCached(Value value);
+  /// Find the definitions of the given operand's value or
+  /// retrieve them from the cache.
+  const SetVector<Value> &findDefinitionsCached(OpOperand *opOperand);
 
   /// Reset cached data structures.
   void resetCache() override;
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index 349841f06959c..1eb27e44810b0 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -480,18 +480,21 @@ bool AnalysisState::isValueRead(Value value) const {
   return false;
 }
 
-// Starting from `value`, follow the use-def chain in reverse, always selecting
-// the aliasing OpOperands. Find and return Values for which `condition`
-// evaluates to true. OpOperands of such matching Values are not traversed any
-// further, the visited aliasing opOperands will be preserved through
-// `visitedOpOperands`.
+// Starting from `opOperand`, follow the use-def chain in reverse, always
+// selecting the aliasing OpOperands. Find and return Values for which
+// `condition` evaluates to true. Uses of such matching Values are not
+// traversed any further, the visited aliasing opOperands will be preserved
+// through `visitedOpOperands`.
 llvm::SetVector<Value> AnalysisState::findValueInReverseUseDefChain(
-    Value value, llvm::function_ref<bool(Value)> condition,
+    OpOperand *opOperand, llvm::function_ref<bool(Value)> condition,
     TraversalConfig config,
     llvm::DenseSet<OpOperand *> *visitedOpOperands) const {
   llvm::DenseSet<Value> visited;
   llvm::SetVector<Value> result, workingSet;
-  workingSet.insert(value);
+  workingSet.insert(opOperand->get());
+
+  if (visitedOpOperands)
+    visitedOpOperands->insert(opOperand);
 
   while (!workingSet.empty()) {
     Value value = workingSet.pop_back_val();
@@ -563,12 +566,14 @@ llvm::SetVector<Value> AnalysisState::findValueInReverseUseDefChain(
   return result;
 }
 
-// Find the values that define the contents of the given value.
-llvm::SetVector<Value> AnalysisState::findDefinitions(Value value) const {
+// Find the values that define the contents of the given operand's value.
+llvm::SetVector<Value>
+AnalysisState::findDefinitions(OpOperand *opOperand) const {
   TraversalConfig config;
   config.alwaysIncludeLeaves = false;
   return findValueInReverseUseDefChain(
-      value, [&](Value v) { return this->bufferizesToMemoryWrite(v); }, config);
+      opOperand, [&](Value v) { return this->bufferizesToMemoryWrite(v); },
+      config);
 }
 
 AnalysisState::AnalysisState(const BufferizationOptions &options)
@@ -892,7 +897,7 @@ bool bufferization::detail::defaultResultBufferizesToMemoryWrite(
   config.alwaysIncludeLeaves = false;
   for (AliasingOpOperand alias : opOperands) {
     if (!state
-             .findValueInReverseUseDefChain(alias.opOperand->get(),
+             .findValueInReverseUseDefChain(alias.opOperand,
                                             isMemoryWriteInsideOp, config)
              .empty())
       return true;
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
index 98c3d8d0adc6d..2c4e362101f8f 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
@@ -143,7 +143,7 @@ LogicalResult mlir::bufferization::eliminateEmptyTensors(
     //   %3 = tensor.insert_slice %2 into ...
     config.followSameTypeOrCastsOnly = true;
     SetVector<Value> emptyTensors = state.findValueInReverseUseDefChain(
-        source.get(), /*condition=*/
+        &source, /*condition=*/
        [&](Value val) { return val.getDefiningOp<tensor::EmptyOp>(); }, config,
        &visitedOpOperands);
@@ -155,10 +155,8 @@ LogicalResult mlir::bufferization::eliminateEmptyTensors(
           visitedOpOperands, [&emptyTensorOp](OpOperand *opOperand) {
            return llvm::count(emptyTensorOp->getUses(), *opOperand);
          });
-      // This could be achieved when a use of `emptyTensorOp` is being
-      // consumed by `SubsetInsertionOpInterface`'s source directly.
-      if (iter == visitedOpOperands.end())
-        continue;
+
+      assert(iter != visitedOpOperands.end() && "could not find use");
       OpOperand *useToBeReplaced = *iter;
       Operation *user = useToBeReplaced->getOwner();
       auto replacement = subsetsExtractionFn(rewriter, op, emptyTensorOp, user);
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
index d1e6acef324fb..fc1b221b4f036 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
@@ -196,7 +196,12 @@ void OneShotAnalysisState::gatherUndefinedTensorUses(Operation *op) {
       // If there is no preceding definition, the tensor contents are
       // undefined.
-      if (findDefinitionsCached(opResult).empty())
+      if (opResult.getUses().empty())
+        continue;
+      // It does not really matter which use is taken to search for the
+      // value's definitions.
+      OpOperand *opOperand = &(*opResult.getUses().begin());
+      if (findDefinitionsCached(opOperand).empty())
         for (OpOperand &use : opResult.getUses())
           undefinedTensorUses.insert(&use);
     }
@@ -464,7 +469,8 @@ static void annotateConflict(OpOperand *uRead, OpOperand *uConflictingWrite,
 /// indexing. I.e., the tensor types do not change along the use-def chain,
 /// apart from static <-> dynamic dim casts.
 static bool hasEquivalentValueInReverseUseDefChain(AnalysisState &state,
-                                                   Value start, Value other) {
+                                                   OpOperand *start,
+                                                   Value other) {
   TraversalConfig config;
   config.followEquivalentOnly = true;
   config.alwaysIncludeLeaves = false;
@@ -475,9 +481,10 @@ static bool hasEquivalentValueInReverseUseDefChain(AnalysisState &state,
           .empty();
 }
 
-/// Return "true" if `value` is originating from a subset that is equivalent to
-/// the subset that `subsetOp` inserts into.
-static bool matchesInsertDestination(const AnalysisState &state, Value value,
+/// Return "true" if the given operand's value is originating from a subset
+/// that is equivalent to the subset that `subsetOp` inserts into.
+static bool matchesInsertDestination(const AnalysisState &state,
+                                     OpOperand *opOperand,
                                      SubsetInsertionOpInterface subsetOp) {
   auto matchingSubset = [&](Value val) {
     if (auto opResult = dyn_cast<OpResult>(val))
@@ -490,7 +497,7 @@ static bool matchesInsertDestination(const AnalysisState &state, Value value,
   // There may be multiple leaves at which the reverse SSA use-def chain lookup
   // terminates. All of them must be equivalent subsets.
   SetVector<Value> backwardSlice =
-      state.findValueInReverseUseDefChain(value, matchingSubset);
+      state.findValueInReverseUseDefChain(opOperand, matchingSubset);
   return static_cast<bool>(llvm::all_of(backwardSlice, matchingSubset));
 }
@@ -516,7 +523,7 @@ static bool areNonConflictingSubsets(OpOperand *uRead,
   //   {inplace= [true] }
   if (uRead == &subsetOp.getDestinationOperand() &&
-      matchesInsertDestination(state, uConflictingWrite->get(), subsetOp))
+      matchesInsertDestination(state, uConflictingWrite, subsetOp))
     // Case 1: The main insight is that InsertSliceOp reads only part of
     // the destination tensor. The overwritten area is not read. If
     // uConflictingWrite writes into exactly the memory location that is
@@ -533,7 +540,7 @@ static bool areNonConflictingSubsets(OpOperand *uRead,
   if (uRead == &subsetOp.getSourceOperand() &&
       uConflictingWrite == &subsetOp.getDestinationOperand() &&
-      matchesInsertDestination(state, uRead->get(), subsetOp))
+      matchesInsertDestination(state, uRead, subsetOp))
     // Case 2: The read of the source tensor and the write to the dest
     // tensor via an InsertSliceOp is not a conflict if the read is
     // reading exactly that part of an equivalent tensor that the
@@ -567,8 +574,7 @@ static bool areNonConflictingSubsets(OpOperand *uRead,
   if (uConflictingWrite == &subsetOp.getDestinationOperand() &&
       state.areEquivalentBufferizedValues(
           uRead->get(), subsetOp.getSourceOperand().get()) &&
-      matchesInsertDestination(state, subsetOp.getSourceOperand().get(),
-                               subsetOp))
+      matchesInsertDestination(state, &subsetOp.getSourceOperand(), subsetOp))
     return true;
 
   return false;
@@ -600,9 +606,9 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead,
       // even though that op just bufferizes to an allocation but does define
       // the contents of the buffer.
      SetVector<Value> definitionsOrLeaves =
-          state.findValueInReverseUseDefChain(
-              uConflictingWrite->get(),
-              [&](Value v) { return state.bufferizesToMemoryWrite(v); });
+          state.findValueInReverseUseDefChain(uConflictingWrite, [&](Value v) {
+            return state.bufferizesToMemoryWrite(v);
+          });
      assert(!definitionsOrLeaves.empty() &&
             "expected at least one definition or leaf");
@@ -641,8 +647,7 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead,
   // In the above example, if uRead is the OpOperand of reading_op, the
   // definition is %0. Note that operations that create an alias but do not
   // bufferize to a memory write (such as ExtractSliceOp) are skipped.
-  const SetVector<Value> &definitions =
-      state.findDefinitionsCached(uRead->get());
+  const SetVector<Value> &definitions = state.findDefinitionsCached(uRead);
   if (definitions.empty()) {
     // Fast path: No conflict if there are no definitions.
     LLVM_DEBUG(llvm::dbgs()
@@ -714,9 +719,9 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead,
       if (bufferizableOp.bufferizesToElementwiseAccess(
               state, {uRead, uConflictingWrite})) {
         if (hasEquivalentValueInReverseUseDefChain(
-                state, uRead->get(), uConflictingWrite->get()) ||
+                state, uRead, uConflictingWrite->get()) ||
             hasEquivalentValueInReverseUseDefChain(
-                state, uConflictingWrite->get(), uRead->get())) {
+                state, uConflictingWrite, uRead->get())) {
           LLVM_DEBUG(
               llvm::dbgs()
               << "  no conflict: op bufferizes to element-wise access\n");
@@ -965,11 +970,12 @@ wouldCreateWriteToNonWritableBuffer(OpOperand &operand,
 // Bufferization analyses.
 //===----------------------------------------------------------------------===//
 
-// Find the values that define the contents of the given value.
+// Find the values that define the contents of the given operand's value.
 const llvm::SetVector<Value> &
-OneShotAnalysisState::findDefinitionsCached(Value value) {
+OneShotAnalysisState::findDefinitionsCached(OpOperand *opOperand) {
+  Value value = opOperand->get();
   if (!cachedDefinitions.count(value))
-    cachedDefinitions[value] = findDefinitions(value);
+    cachedDefinitions[value] = findDefinitions(opOperand);
   return cachedDefinitions[value];
 }
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
index 6801b68a85381..6c1087730ebba 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
@@ -553,7 +553,7 @@ Value linalg::bufferizeToAllocation(
     Value alloc = createAllocationForTensor(
         rewriter, op->getLoc(), operand->get(), options, memorySpace);
     allocs.push_back(alloc);
-    if (!state.findDefinitions(operand->get()).empty()) {
+    if (!state.findDefinitions(operand).empty()) {
      // Initialize buffer with a copy of the operand data. Not needed if the
      // tensor is uninitialized.
      createMemcpy(rewriter, op->getLoc(), operand->get(), alloc, options);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/EliminateEmptyTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/EliminateEmptyTensors.cpp
index 4776883ed95c5..b710bde87f9f3 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/EliminateEmptyTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/EliminateEmptyTensors.cpp
@@ -59,7 +59,7 @@ LogicalResult linalg::linalgOpAnchoredEmptyTensorEliminationStep(
     config.followEquivalentOnly = true;
     config.alwaysIncludeLeaves = false;
     SetVector<Value> emptyTensors = state.findValueInReverseUseDefChain(
-        in->get(), /*condition=*/
+        in, /*condition=*/
        [&](Value val) {
          return val.getDefiningOp<tensor::EmptyOp>() &&
                 val.getType() == in->get().getType();
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir
index 26434774730e1..820fb3dfa5e5e 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir
@@ -465,3 +465,14 @@ func.func @mutli_use_of_the_same_tensor_empty_creates_non_existent_read(%arg1: t
     : tensor<5x6x64xf32> into tensor<5x6x128xf32>
   return %inserted_slice_1, %res_2 : tensor<5x6x128xf32>, tensor<5x6x64xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func.func @direct_use_of_tensor_empty
+func.func @direct_use_of_tensor_empty(%arg0: tensor<5x6x128xf32>) -> tensor<5x6x128xf32> {
+  // CHECK-NOT: memref.alloc
+  %empty_1 = tensor.empty() : tensor<5x6x64xf32>
+  %inserted_slice_1 = tensor.insert_slice %empty_1 into %arg0[0, 0, 0][5, 6, 64][1, 1, 1]
+    : tensor<5x6x64xf32> into tensor<5x6x128xf32>
+  return %inserted_slice_1 : tensor<5x6x128xf32>
+}