Skip to content

Commit fe93bf5

Browse files
authored
Transform propagator skip replay when possible (#1782)
This comment in the code describes what this PR is doing: ```C++ // Note: [Using multiple TransformPropagators] // There are cases that we use multiple TransformPropagators along different // spanning trees with different references in the same fusion. Some of these // spanning trees could overlap. In cases when there are overlapping nodes, // TransformPropagator needs to respect the replay of others, because the // current TransformPropagator might not contain the most amount of // information on how to do the correct transformation. The logic below tells // TransformPropagator to skip the replay when not necessary. ```
1 parent ebf23a5 commit fe93bf5

File tree

4 files changed

+225
-110
lines changed

4 files changed

+225
-110
lines changed

torch/csrc/jit/codegen/cuda/compute_at.cpp

Lines changed: 7 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -342,96 +342,6 @@ void ComputeAt::runWith(
342342
ca.runPass();
343343
}
344344

345-
namespace {
346-
347-
// Checks if producer and consumer are transformed consistently so that to
348-
// satisfy the provided compute at position. This means no replay is actually
349-
// necessary for the compute at requested. If consumer_pos then
350-
// consumer_or_producer_pos is relative to the consumer and skipReplay returns
351-
// the associated position in producer.
352-
//
353-
// If producer and consumer are not transformed consistently with provided
354-
// postition, returns -1.
355-
int skipReplay(
356-
const TensorView* producer,
357-
const TensorView* consumer,
358-
int consumer_or_producer_pos,
359-
bool consumer_pos = true) {
360-
FUSER_PERF_SCOPE("transform_replay.cpp::skipReplay");
361-
362-
const auto c2p_root_map =
363-
PairwiseRootDomainMap(producer, consumer)
364-
.mapConsumerToProducer(consumer->domain(), producer->domain());
365-
366-
// IterDomains in consumer root also in producer root
367-
std::unordered_set<Val*> mapped_consumer_roots;
368-
for (auto entry : c2p_root_map) {
369-
mapped_consumer_roots.emplace(entry.first);
370-
}
371-
372-
const auto consumer_domain = consumer->domain()->domain();
373-
374-
auto mapped_consumer_domain_ids_vec = DependencyCheck::getAllValsBetween(
375-
mapped_consumer_roots, {consumer_domain.begin(), consumer_domain.end()});
376-
377-
std::unordered_set<Val*> mapped_consumer_domain_ids(
378-
mapped_consumer_domain_ids_vec.begin(),
379-
mapped_consumer_domain_ids_vec.end());
380-
381-
const auto producer_domain = producer->domain()->domain();
382-
383-
auto it_consumer = consumer_domain.begin();
384-
auto it_producer = producer_domain.begin();
385-
386-
auto best_effort_PasC = BestEffortReplay::replayPasC(
387-
producer, consumer, -1, PairwiseRootDomainMap(producer, consumer));
388-
389-
auto c2p_map = best_effort_PasC.getReplay();
390-
391-
int mismatched_consumer_pos = 0;
392-
int mismatched_producer_pos = 0;
393-
while (it_consumer != consumer_domain.end()) {
394-
auto consumer_id = *it_consumer;
395-
if (!mapped_consumer_domain_ids.count(consumer_id)) {
396-
++it_consumer;
397-
mismatched_consumer_pos++;
398-
continue;
399-
}
400-
401-
auto c2p_it = c2p_map.find(consumer_id);
402-
if (c2p_it == c2p_map.end()) {
403-
break;
404-
}
405-
406-
if (it_producer == producer_domain.end()) {
407-
break;
408-
}
409-
410-
auto producer_id = *it_producer;
411-
412-
if (c2p_it->second == producer_id) {
413-
++mismatched_consumer_pos;
414-
++mismatched_producer_pos;
415-
++it_consumer;
416-
++it_producer;
417-
if (consumer_pos) {
418-
if (consumer_or_producer_pos == mismatched_consumer_pos) {
419-
return mismatched_producer_pos;
420-
}
421-
} else {
422-
if (consumer_or_producer_pos == mismatched_producer_pos) {
423-
return mismatched_consumer_pos;
424-
}
425-
}
426-
} else {
427-
break;
428-
}
429-
}
430-
return -1;
431-
}
432-
433-
} // namespace
434-
435345
// Actually applies transformation
436346
unsigned int ComputeAt::backwardComputeAt_impl(
437347
TensorView* producer,
@@ -460,9 +370,11 @@ unsigned int ComputeAt::backwardComputeAt_impl(
460370
max_consumer_compute_at_pos);
461371
}
462372

463-
// Short cut if no replay is necessary
464-
auto maybe_producer_pos =
465-
skipReplay(producer, consumer, (int)consumer_compute_at_pos, true);
373+
// Checks if producer and consumer are transformed consistently so that to
374+
// satisfy the provided compute at position. This means no replay is actually
375+
// necessary for the compute at requested.
376+
auto maybe_producer_pos = TransformReplay::getMatchedLeafPosWithoutReplayPasC(
377+
producer, consumer, consumer_compute_at_pos);
466378
if (maybe_producer_pos >= 0) {
467379
if (!producer->isFusionInput()) {
468380
producer->setComputeAt((unsigned int)maybe_producer_pos);
@@ -536,8 +448,8 @@ unsigned int ComputeAt::forwardComputeAt_impl(
536448
}
537449

538450
// Short cut if no replay is necessary
539-
auto maybe_consumer_pos =
540-
skipReplay(producer, consumer, (int)producer_compute_at_pos, false);
451+
auto maybe_consumer_pos = TransformReplay::getMatchedLeafPosWithoutReplayCasP(
452+
consumer, producer, producer_compute_at_pos);
541453
if (maybe_consumer_pos > -1) {
542454
if (!producer->isFusionInput()) {
543455
producer->setComputeAt(producer_compute_at_pos);

torch/csrc/jit/codegen/cuda/test/test_gpu.cpp

Lines changed: 40 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23710,7 +23710,7 @@ TEST_F(NVFuserTest, FusionTransformPropagateSibling_CUDA) {
2371023710
for (auto tensors : siblings) {
2371123711
for (auto t1 : tensors) {
2371223712
for (auto t2 : tensors) {
23713-
checkSiblingConsistency(t1, t2);
23713+
TORCH_CHECK(TransformReplay::fullSelfMatching(t1, t2));
2371423714
}
2371523715
}
2371623716
}
@@ -23769,7 +23769,7 @@ TEST_F(NVFuserTest, FusionTransformPropagateSelectorSibling_CUDA) {
2376923769
for (auto tensors : siblings) {
2377023770
for (auto t1 : tensors) {
2377123771
for (auto t2 : tensors) {
23772-
checkSiblingConsistency(t1, t2);
23772+
TORCH_CHECK(TransformReplay::fullSelfMatching(t1, t2));
2377323773
}
2377423774
}
2377523775
}
@@ -23922,7 +23922,7 @@ TEST_F(NVFuserTest, FusionTransformPropagatorSelector) {
2392223922
TORCH_CHECK(tv4->nDims() == 1);
2392323923
}
2392423924

23925-
TEST_F(NVFuserTest, FusionTransormPropagatorPos_CUDA) {
23925+
TEST_F(NVFuserTest, FusionTransformPropagatorPos_CUDA) {
2392623926
auto fusion = std::make_unique<Fusion>();
2392723927
FusionGuard fg(fusion.get());
2392823928

@@ -23939,10 +23939,9 @@ TEST_F(NVFuserTest, FusionTransormPropagatorPos_CUDA) {
2393923939
TransformPropagator propagator(tv1, 2);
2394023940
MaxRootDomainInfoSpanningTree(tv1, 2).traverse(&propagator);
2394123941

23942-
TORCH_CHECK(tv0->nDims() == 3);
23943-
TORCH_CHECK(tv0->axis(0)->extent()->evaluateInt() == 11);
23944-
TORCH_CHECK(tv0->axis(1)->extent()->evaluateInt() == 2);
23945-
TORCH_CHECK(tv0->axis(2)->extent()->evaluateInt() == 105);
23942+
auto expect = makeConcreteTensor({22, 105});
23943+
expect->split(0, 2);
23944+
TORCH_CHECK(TransformReplay::fullSelfMatching(expect, tv0));
2394623945
}
2394723946

2394823947
TEST_F(NVFuserTest, FusionMaxRootDomainInfoSpanningTreePrintTwice_CUDA) {
@@ -23996,6 +23995,40 @@ to: 2
2399623995
TORCH_CHECK(printer2.ss.str() == expect);
2399723996
}
2399823997

23998+
TEST_F(NVFuserTest, FusionTransformPropagatorNoOverwrite_CUDA) {
23999+
auto fusion = std::make_unique<Fusion>();
24000+
FusionGuard fg(fusion.get());
24001+
24002+
auto tv0 = makeSymbolicTensor(1);
24003+
fusion->addInput(tv0);
24004+
auto tv1 = broadcast(tv0, {true, false, true});
24005+
auto tv2 = sin(tv1);
24006+
fusion->addOutput(tv2);
24007+
24008+
tv0->split(0, 2);
24009+
tv2->split(1, 2);
24010+
tv2->split(0, 4);
24011+
24012+
MaxRootDomainInfoSpanningTree path1(tv2);
24013+
TransformPropagator propagator1(tv2);
24014+
path1.traverse(&propagator1);
24015+
24016+
MaxRootDomainInfoSpanningTree path2(tv0);
24017+
TransformPropagator propagator2(tv0);
24018+
path2.traverse(&propagator2);
24019+
24020+
TORCH_CHECK(tv1->axis(0)->isBroadcast());
24021+
TORCH_CHECK(tv1->axis(1)->isBroadcast());
24022+
TORCH_CHECK(!tv1->axis(2)->isBroadcast());
24023+
TORCH_CHECK(!tv1->axis(3)->isBroadcast());
24024+
TORCH_CHECK(tv1->axis(4)->isBroadcast());
24025+
24026+
auto expect = makeSymbolicTensor(3);
24027+
expect->split(1, 2);
24028+
expect->split(0, 4);
24029+
TORCH_CHECK(TransformReplay::fullSelfMatching(expect, tv1));
24030+
}
24031+
2399924032
TEST_F(NVFuserTest, FusionIssue1785Repro_CUDA) {
2400024033
Fusion fusion;
2400124034
FusionGuard fg(&fusion);

torch/csrc/jit/codegen/cuda/transform_replay.cpp

Lines changed: 157 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -644,24 +644,173 @@ std::pair<TensorDomain*, unsigned int> TransformReplay::replayCasP(
644644
return replayCasP(consumer, producer, compute_at_axis, root_map);
645645
}
646646

647+
namespace {
648+
649+
int getMatchedLeafPosWithoutReplay(
650+
const TensorView* producer,
651+
const TensorView* consumer,
652+
int consumer_or_producer_pos,
653+
bool consumer_pos = true) {
654+
FUSER_PERF_SCOPE("transform_replay.cpp::getMatchedLeafPosWithoutReplay");
655+
656+
const auto c2p_root_map =
657+
PairwiseRootDomainMap(producer, consumer)
658+
.mapConsumerToProducer(consumer->domain(), producer->domain());
659+
660+
// IterDomains in consumer root also in producer root
661+
std::unordered_set<Val*> mapped_consumer_roots;
662+
for (auto entry : c2p_root_map) {
663+
mapped_consumer_roots.emplace(entry.first);
664+
}
665+
666+
const auto consumer_domain = consumer->domain()->domain();
667+
668+
auto mapped_consumer_domain_ids_vec = DependencyCheck::getAllValsBetween(
669+
mapped_consumer_roots, {consumer_domain.begin(), consumer_domain.end()});
670+
671+
std::unordered_set<Val*> mapped_consumer_domain_ids(
672+
mapped_consumer_domain_ids_vec.begin(),
673+
mapped_consumer_domain_ids_vec.end());
674+
675+
const auto producer_domain = producer->domain()->domain();
676+
677+
auto it_consumer = consumer_domain.begin();
678+
auto it_producer = producer_domain.begin();
679+
680+
auto best_effort_PasC = BestEffortReplay::replayPasC(
681+
producer, consumer, -1, PairwiseRootDomainMap(producer, consumer));
682+
683+
auto c2p_map = best_effort_PasC.getReplay();
684+
685+
int mismatched_consumer_pos = 0;
686+
int mismatched_producer_pos = 0;
687+
while (it_consumer != consumer_domain.end()) {
688+
auto consumer_id = *it_consumer;
689+
if (!mapped_consumer_domain_ids.count(consumer_id)) {
690+
++it_consumer;
691+
mismatched_consumer_pos++;
692+
continue;
693+
}
694+
695+
auto c2p_it = c2p_map.find(consumer_id);
696+
if (c2p_it == c2p_map.end()) {
697+
break;
698+
}
699+
700+
if (it_producer == producer_domain.end()) {
701+
break;
702+
}
703+
704+
auto producer_id = *it_producer;
705+
706+
if (c2p_it->second == producer_id) {
707+
++mismatched_consumer_pos;
708+
++mismatched_producer_pos;
709+
++it_consumer;
710+
++it_producer;
711+
if (consumer_pos) {
712+
if (consumer_or_producer_pos == mismatched_consumer_pos) {
713+
return mismatched_producer_pos;
714+
}
715+
} else {
716+
if (consumer_or_producer_pos == mismatched_producer_pos) {
717+
return mismatched_consumer_pos;
718+
}
719+
}
720+
} else {
721+
break;
722+
}
723+
}
724+
return -1;
725+
}
726+
727+
} // namespace
728+
729+
int TransformReplay::getMatchedLeafPosWithoutReplayPasC(
730+
const TensorView* producer,
731+
const TensorView* consumer,
732+
int consumer_pos) {
733+
return getMatchedLeafPosWithoutReplay(producer, consumer, consumer_pos, true);
734+
}
735+
736+
int TransformReplay::getMatchedLeafPosWithoutReplayCasP(
737+
const TensorView* consumer,
738+
const TensorView* producer,
739+
int producer_pos) {
740+
return getMatchedLeafPosWithoutReplay(
741+
producer, consumer, producer_pos, false);
742+
}
743+
744+
bool TransformReplay::fullSelfMatching(
745+
const TensorView* replay,
746+
const TensorView* target) {
747+
auto replay_root = replay->getRootDomain();
748+
auto replay_dom = replay->domain()->domain();
749+
auto target_root = target->getRootDomain();
750+
auto target_dom = target->domain()->domain();
751+
std::unordered_map<IterDomain*, IterDomain*> target2replay_map;
752+
if (replay_root.size() != target_root.size()) {
753+
return false;
754+
}
755+
target2replay_map.reserve(replay_root.size());
756+
std::transform(
757+
target_root.begin(),
758+
target_root.end(),
759+
replay_root.begin(),
760+
std::inserter(target2replay_map, target2replay_map.begin()),
761+
[](auto a, auto b) { return std::make_pair(a, b); });
762+
BestEffortReplay replay_(replay_dom, target_dom, target2replay_map);
763+
auto r = replay_.getReplay();
764+
for (int64_t i = 0; i < replay_dom.size(); i++) {
765+
auto target_id = target_dom[i];
766+
auto replay_it = r.find(target_id);
767+
if (replay_it == r.end() || replay_it->second != replay_dom[i]) {
768+
return false;
769+
}
770+
}
771+
return true;
772+
}
773+
647774
void TransformPropagator::propagateTvPasC(TensorView* from, TensorView* to) {
648775
int pos = replayed_pos_.at(from);
649-
auto replay = TransformReplay::replayPasC(to, from, pos);
650-
to->setDomain(replay.first);
651-
replayed_pos_[to] = replay.second;
776+
// Note: [Using multiple TransformPropagators]
777+
// There are cases that we use multiple TransformPropagators along different
778+
// spanning trees with different references in the same fusion. Some of these
779+
// spanning trees could overlap. In cases when there are overlapping nodes,
780+
// TransformPropagator needs to respect the replay of others, because the
781+
// current TransformPropagator might not contain the most amount of
782+
// information on how to do the correct transformation. The logic below tells
783+
// TransformPropagator to skip the replay when not necessary.
784+
int new_pos =
785+
TransformReplay::getMatchedLeafPosWithoutReplayPasC(to, from, pos);
786+
if (new_pos < 0) {
787+
auto replay = TransformReplay::replayPasC(to, from, pos);
788+
to->setDomain(replay.first);
789+
new_pos = replay.second;
790+
}
791+
replayed_pos_[to] = new_pos;
652792
}
653793

654794
void TransformPropagator::propagateTvCasP(TensorView* from, TensorView* to) {
655795
int pos = replayed_pos_.at(from);
656-
auto replay = TransformReplay::replayCasP(to, from, pos);
657-
to->setDomain(replay.first);
658-
replayed_pos_[to] = replay.second;
796+
// See note [Using multiple TransformPropagators]
797+
int new_pos =
798+
TransformReplay::getMatchedLeafPosWithoutReplayCasP(to, from, pos);
799+
if (new_pos < 0) {
800+
auto replay = TransformReplay::replayCasP(to, from, pos);
801+
to->setDomain(replay.first);
802+
new_pos = replay.second;
803+
}
804+
replayed_pos_[to] = new_pos;
659805
}
660806

661807
void TransformPropagator::propagateTvSibling(TensorView* from, TensorView* to) {
662808
int pos = replayed_pos_.at(from);
663-
auto replay = TransformReplay::fullSelfReplay(to->domain(), from->domain());
664-
to->setDomain(replay);
809+
// See note [Using multiple TransformPropagators]
810+
if (!TransformReplay::fullSelfMatching(to, from)) {
811+
auto replay = TransformReplay::fullSelfReplay(to->domain(), from->domain());
812+
to->setDomain(replay);
813+
}
665814
replayed_pos_[to] = pos;
666815
}
667816

0 commit comments

Comments
 (0)