@@ -196,7 +196,12 @@ void OneShotAnalysisState::gatherUndefinedTensorUses(Operation *op) {
196
196
197
197
// If there is no preceding definition, the tensor contents are
198
198
// undefined.
199
- if (findDefinitionsCached (opResult).empty ())
199
+ if (opResult.getUses ().empty ())
200
+ continue ;
201
+ // It does not really matter which use to take to search about
202
+ // the value's definitions.
203
+ OpOperand *opOperand = &(*opResult.getUses ().begin ());
204
+ if (findDefinitionsCached (opOperand).empty ())
200
205
for (OpOperand &use : opResult.getUses ())
201
206
undefinedTensorUses.insert (&use);
202
207
}
@@ -464,7 +469,8 @@ static void annotateConflict(OpOperand *uRead, OpOperand *uConflictingWrite,
464
469
// / indexing. I.e., the tensor types do not change along the use-def chain,
465
470
// / apart from static <-> dynamic dim casts.
466
471
static bool hasEquivalentValueInReverseUseDefChain (AnalysisState &state,
467
- Value start, Value other) {
472
+ OpOperand *start,
473
+ Value other) {
468
474
TraversalConfig config;
469
475
config.followEquivalentOnly = true ;
470
476
config.alwaysIncludeLeaves = false ;
@@ -475,9 +481,10 @@ static bool hasEquivalentValueInReverseUseDefChain(AnalysisState &state,
475
481
.empty ();
476
482
}
477
483
478
- // / Return "true" if `value` is originating from a subset that is equivalent to
479
- // / the subset that `subsetOp` inserts into.
480
- static bool matchesInsertDestination (const AnalysisState &state, Value value,
484
+ // / Return "true" if the given operand's value is originating from a subset
485
+ // / that is equivalent to the subset that `subsetOp` inserts into.
486
+ static bool matchesInsertDestination (const AnalysisState &state,
487
+ OpOperand *opOperand,
481
488
SubsetInsertionOpInterface subsetOp) {
482
489
auto matchingSubset = [&](Value val) {
483
490
if (auto opResult = dyn_cast<OpResult>(val))
@@ -490,7 +497,7 @@ static bool matchesInsertDestination(const AnalysisState &state, Value value,
490
497
// There may be multiple leaves at which the reverse SSA use-def chain lookup
491
498
// terminates. All of them must be equivalent subsets.
492
499
SetVector<Value> backwardSlice =
493
- state.findValueInReverseUseDefChain (value , matchingSubset);
500
+ state.findValueInReverseUseDefChain (opOperand , matchingSubset);
494
501
return static_cast <bool >(llvm::all_of (backwardSlice, matchingSubset));
495
502
}
496
503
@@ -516,7 +523,7 @@ static bool areNonConflictingSubsets(OpOperand *uRead,
516
523
// {inplace= [true] }
517
524
518
525
if (uRead == &subsetOp.getDestinationOperand () &&
519
- matchesInsertDestination (state, uConflictingWrite-> get () , subsetOp))
526
+ matchesInsertDestination (state, uConflictingWrite, subsetOp))
520
527
// Case 1: The main insight is that InsertSliceOp reads only part of
521
528
// the destination tensor. The overwritten area is not read. If
522
529
// uConflictingWrite writes into exactly the memory location that is
@@ -533,7 +540,7 @@ static bool areNonConflictingSubsets(OpOperand *uRead,
533
540
534
541
if (uRead == &subsetOp.getSourceOperand () &&
535
542
uConflictingWrite == &subsetOp.getDestinationOperand () &&
536
- matchesInsertDestination (state, uRead-> get () , subsetOp))
543
+ matchesInsertDestination (state, uRead, subsetOp))
537
544
// Case 2: The read of the source tensor and the write to the dest
538
545
// tensor via an InsertSliceOp is not a conflict if the read is
539
546
// reading exactly that part of an equivalent tensor that the
@@ -567,8 +574,7 @@ static bool areNonConflictingSubsets(OpOperand *uRead,
567
574
if (uConflictingWrite == &subsetOp.getDestinationOperand () &&
568
575
state.areEquivalentBufferizedValues (
569
576
uRead->get (), subsetOp.getSourceOperand ().get ()) &&
570
- matchesInsertDestination (state, subsetOp.getSourceOperand ().get (),
571
- subsetOp))
577
+ matchesInsertDestination (state, &subsetOp.getSourceOperand (), subsetOp))
572
578
return true ;
573
579
574
580
return false ;
@@ -600,9 +606,9 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead,
600
606
// even though that op just bufferizes to an allocation but does define
601
607
// the contents of the buffer.
602
608
SetVector<Value> definitionsOrLeaves =
603
- state.findValueInReverseUseDefChain (
604
- uConflictingWrite-> get (),
605
- [&](Value v) { return state. bufferizesToMemoryWrite (v); });
609
+ state.findValueInReverseUseDefChain (uConflictingWrite, [&](Value v) {
610
+ return state. bufferizesToMemoryWrite (v);
611
+ });
606
612
assert (!definitionsOrLeaves.empty () &&
607
613
" expected at least one definition or leaf" );
608
614
@@ -641,8 +647,7 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead,
641
647
// In the above example, if uRead is the OpOperand of reading_op, the
642
648
// definition is %0. Note that operations that create an alias but do not
643
649
// bufferize to a memory write (such as ExtractSliceOp) are skipped.
644
- const SetVector<Value> &definitions =
645
- state.findDefinitionsCached (uRead->get ());
650
+ const SetVector<Value> &definitions = state.findDefinitionsCached (uRead);
646
651
if (definitions.empty ()) {
647
652
// Fast path: No conflict if there are no definitions.
648
653
LLVM_DEBUG (llvm::dbgs ()
@@ -714,9 +719,9 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead,
714
719
if (bufferizableOp.bufferizesToElementwiseAccess (
715
720
state, {uRead, uConflictingWrite})) {
716
721
if (hasEquivalentValueInReverseUseDefChain (
717
- state, uRead-> get () , uConflictingWrite->get ()) ||
722
+ state, uRead, uConflictingWrite->get ()) ||
718
723
hasEquivalentValueInReverseUseDefChain (
719
- state, uConflictingWrite-> get () , uRead->get ())) {
724
+ state, uConflictingWrite, uRead->get ())) {
720
725
LLVM_DEBUG (
721
726
llvm::dbgs ()
722
727
<< " no conflict: op bufferizes to element-wise access\n " );
@@ -965,11 +970,12 @@ wouldCreateWriteToNonWritableBuffer(OpOperand &operand,
965
970
// Bufferization analyses.
966
971
// ===----------------------------------------------------------------------===//
967
972
968
- // Find the values that define the contents of the given value.
973
+ // Find the values that define the contents of the given operand's value.
969
974
const llvm::SetVector<Value> &
970
- OneShotAnalysisState::findDefinitionsCached (Value value) {
975
+ OneShotAnalysisState::findDefinitionsCached (OpOperand *opOperand) {
976
+ Value value = opOperand->get ();
971
977
if (!cachedDefinitions.count (value))
972
- cachedDefinitions[value] = findDefinitions (value );
978
+ cachedDefinitions[value] = findDefinitions (opOperand );
973
979
return cachedDefinitions[value];
974
980
}
975
981
0 commit comments