@@ -592,15 +592,11 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) {
592
592
ir_utils::getReductionOps (fusion /* , ignore_trivial=true */ ).empty (),
593
593
" This scheduler only handles pointwise ops." );
594
594
595
- // For intermediate outputs, apply cacheFork
596
- auto outs = fusion->outputs ();
597
- for (const auto output : outs) {
598
- if (!output->uses ().empty () && output->definition () != nullptr ) {
599
- if (output->getValType ().value () == ValType::TensorView) {
600
- output->as <TensorView>()->cacheFork ();
601
- }
602
- }
603
- }
595
+ // Cache inputs
596
+ auto cached_inputs = scheduler_utils::cacheInputs (fusion, true );
597
+
598
+ // Cache and fork outputs
599
+ auto cached_outputs = scheduler_utils::cacheAndForkOutputs (fusion, true );
604
600
605
601
std::vector<TensorView*> input_tvs;
606
602
{
@@ -637,31 +633,6 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) {
637
633
reference_tv != nullptr ,
638
634
" Could not find a fully broadcasted output to reference schedule on." );
639
635
640
- // Caches of inputs
641
- std::vector<TensorView*> cached_inputs;
642
-
643
- // Output, cacheBefore of output
644
- std::vector<std::pair<TensorView*, TensorView*>> cached_outputs;
645
-
646
- // Track what should be vectorized versus unrolled
647
- std::unordered_set<TensorView*> vectorized_tensor;
648
-
649
- // Figure out which inputs to cache for unrolling or vectorization
650
- for (auto inp : input_tvs) {
651
- if (inp->uses ().empty () || inp->isFusionOutput ()) {
652
- continue ;
653
- }
654
- cached_inputs.emplace_back (inp->cacheAfter ());
655
- }
656
-
657
- // Figure out which outputs to cache for unrolling or vectorization
658
- for (auto out : output_tvs) {
659
- if (out->definition () == nullptr ) {
660
- continue ;
661
- }
662
- cached_outputs.emplace_back (std::make_pair (out, out->cacheBefore ()));
663
- }
664
-
665
636
auto all_tvs = ir_utils::allTvs (fusion);
666
637
667
638
// Merge right side of break point
@@ -929,8 +900,8 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) {
929
900
// Compute at cached outputs
930
901
// [BIDx, Unswitch, Vectorization, TIDx]
931
902
for (auto entry : cached_outputs) {
932
- auto cached_output = entry.second ;
933
- auto output = entry.first ;
903
+ auto cached_output = entry.first ;
904
+ auto output = entry.second ;
934
905
935
906
auto unswitch_it = std::find_if (
936
907
output->domain ()->domain ().begin (),
0 commit comments