@@ -644,24 +644,173 @@ std::pair<TensorDomain*, unsigned int> TransformReplay::replayCasP(
644
644
return replayCasP (consumer, producer, compute_at_axis, root_map);
645
645
}
646
646
647
+ namespace {
648
+
649
+ int getMatchedLeafPosWithoutReplay (
650
+ const TensorView* producer,
651
+ const TensorView* consumer,
652
+ int consumer_or_producer_pos,
653
+ bool consumer_pos = true ) {
654
+ FUSER_PERF_SCOPE (" transform_replay.cpp::getMatchedLeafPosWithoutReplay" );
655
+
656
+ const auto c2p_root_map =
657
+ PairwiseRootDomainMap (producer, consumer)
658
+ .mapConsumerToProducer (consumer->domain (), producer->domain ());
659
+
660
+ // IterDomains in consumer root also in producer root
661
+ std::unordered_set<Val*> mapped_consumer_roots;
662
+ for (auto entry : c2p_root_map) {
663
+ mapped_consumer_roots.emplace (entry.first );
664
+ }
665
+
666
+ const auto consumer_domain = consumer->domain ()->domain ();
667
+
668
+ auto mapped_consumer_domain_ids_vec = DependencyCheck::getAllValsBetween (
669
+ mapped_consumer_roots, {consumer_domain.begin (), consumer_domain.end ()});
670
+
671
+ std::unordered_set<Val*> mapped_consumer_domain_ids (
672
+ mapped_consumer_domain_ids_vec.begin (),
673
+ mapped_consumer_domain_ids_vec.end ());
674
+
675
+ const auto producer_domain = producer->domain ()->domain ();
676
+
677
+ auto it_consumer = consumer_domain.begin ();
678
+ auto it_producer = producer_domain.begin ();
679
+
680
+ auto best_effort_PasC = BestEffortReplay::replayPasC (
681
+ producer, consumer, -1 , PairwiseRootDomainMap (producer, consumer));
682
+
683
+ auto c2p_map = best_effort_PasC.getReplay ();
684
+
685
+ int mismatched_consumer_pos = 0 ;
686
+ int mismatched_producer_pos = 0 ;
687
+ while (it_consumer != consumer_domain.end ()) {
688
+ auto consumer_id = *it_consumer;
689
+ if (!mapped_consumer_domain_ids.count (consumer_id)) {
690
+ ++it_consumer;
691
+ mismatched_consumer_pos++;
692
+ continue ;
693
+ }
694
+
695
+ auto c2p_it = c2p_map.find (consumer_id);
696
+ if (c2p_it == c2p_map.end ()) {
697
+ break ;
698
+ }
699
+
700
+ if (it_producer == producer_domain.end ()) {
701
+ break ;
702
+ }
703
+
704
+ auto producer_id = *it_producer;
705
+
706
+ if (c2p_it->second == producer_id) {
707
+ ++mismatched_consumer_pos;
708
+ ++mismatched_producer_pos;
709
+ ++it_consumer;
710
+ ++it_producer;
711
+ if (consumer_pos) {
712
+ if (consumer_or_producer_pos == mismatched_consumer_pos) {
713
+ return mismatched_producer_pos;
714
+ }
715
+ } else {
716
+ if (consumer_or_producer_pos == mismatched_producer_pos) {
717
+ return mismatched_consumer_pos;
718
+ }
719
+ }
720
+ } else {
721
+ break ;
722
+ }
723
+ }
724
+ return -1 ;
725
+ }
726
+
727
+ } // namespace
728
+
729
+ int TransformReplay::getMatchedLeafPosWithoutReplayPasC (
730
+ const TensorView* producer,
731
+ const TensorView* consumer,
732
+ int consumer_pos) {
733
+ return getMatchedLeafPosWithoutReplay (producer, consumer, consumer_pos, true );
734
+ }
735
+
736
+ int TransformReplay::getMatchedLeafPosWithoutReplayCasP (
737
+ const TensorView* consumer,
738
+ const TensorView* producer,
739
+ int producer_pos) {
740
+ return getMatchedLeafPosWithoutReplay (
741
+ producer, consumer, producer_pos, false );
742
+ }
743
+
744
+ bool TransformReplay::fullSelfMatching (
745
+ const TensorView* replay,
746
+ const TensorView* target) {
747
+ auto replay_root = replay->getRootDomain ();
748
+ auto replay_dom = replay->domain ()->domain ();
749
+ auto target_root = target->getRootDomain ();
750
+ auto target_dom = target->domain ()->domain ();
751
+ std::unordered_map<IterDomain*, IterDomain*> target2replay_map;
752
+ if (replay_root.size () != target_root.size ()) {
753
+ return false ;
754
+ }
755
+ target2replay_map.reserve (replay_root.size ());
756
+ std::transform (
757
+ target_root.begin (),
758
+ target_root.end (),
759
+ replay_root.begin (),
760
+ std::inserter (target2replay_map, target2replay_map.begin ()),
761
+ [](auto a, auto b) { return std::make_pair (a, b); });
762
+ BestEffortReplay replay_ (replay_dom, target_dom, target2replay_map);
763
+ auto r = replay_.getReplay ();
764
+ for (int64_t i = 0 ; i < replay_dom.size (); i++) {
765
+ auto target_id = target_dom[i];
766
+ auto replay_it = r.find (target_id);
767
+ if (replay_it == r.end () || replay_it->second != replay_dom[i]) {
768
+ return false ;
769
+ }
770
+ }
771
+ return true ;
772
+ }
773
+
647
774
// Propagates the transformations of consumer `from` to producer `to` and
// records the resulting position for `to`.
void TransformPropagator::propagateTvPasC(TensorView* from, TensorView* to) {
  const int from_pos = replayed_pos_.at(from);

  // Note: [Using multiple TransformPropagators]
  // Several TransformPropagators, each with its own reference, may be run
  // over different spanning trees of the same fusion, and those spanning
  // trees can overlap. On overlapping nodes a propagator has to respect what
  // the others already replayed, because it might not be the one holding the
  // most information on how to do the correct transformation. Hence the
  // replay below is skipped whenever `to` already matches `from` up to the
  // requested position.
  int to_pos =
      TransformReplay::getMatchedLeafPosWithoutReplayPasC(to, from, from_pos);

  const bool needs_replay = to_pos < 0;
  if (needs_replay) {
    auto replayed = TransformReplay::replayPasC(to, from, from_pos);
    to->setDomain(replayed.first);
    to_pos = replayed.second;
  }

  replayed_pos_[to] = to_pos;
}
653
793
654
794
// Propagates the transformations of producer `from` to consumer `to` and
// records the resulting position for `to`.
void TransformPropagator::propagateTvCasP(TensorView* from, TensorView* to) {
  const int from_pos = replayed_pos_.at(from);

  // See note [Using multiple TransformPropagators]: the replay is skipped
  // when `to` already matches `from` up to the requested position.
  int to_pos =
      TransformReplay::getMatchedLeafPosWithoutReplayCasP(to, from, from_pos);

  const bool needs_replay = to_pos < 0;
  if (needs_replay) {
    auto replayed = TransformReplay::replayCasP(to, from, from_pos);
    to->setDomain(replayed.first);
    to_pos = replayed.second;
  }

  replayed_pos_[to] = to_pos;
}
660
806
661
807
// Makes sibling `to` carry the same transformations as `from` and records the
// position of `from` for `to`.
void TransformPropagator::propagateTvSibling(TensorView* from, TensorView* to) {
  const int from_pos = replayed_pos_.at(from);

  // See note [Using multiple TransformPropagators]: only replay when the
  // leaf domain of `to` does not already fully match that of `from`.
  const bool already_matching = TransformReplay::fullSelfMatching(to, from);
  if (!already_matching) {
    auto replayed_domain =
        TransformReplay::fullSelfReplay(to->domain(), from->domain());
    to->setDomain(replayed_domain);
  }

  replayed_pos_[to] = from_pos;
}
667
816
0 commit comments