Skip to content

Commit 3df9742

Browse files
authored
Use scheduler_utils to cache inputs and outputs in schedulePointwise (#1811)
1 parent 03180aa commit 3df9742

File tree

3 files changed

+11
-42
lines changed

3 files changed

+11
-42
lines changed

torch/csrc/jit/codegen/cuda/scheduler/normalization.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -983,9 +983,8 @@ TORCH_CUDA_CU_API void schedulePersistentKernel(
983983
// Cache inputs if unrolled
984984
auto cached_inputs = scheduler_utils::cacheInputs(fusion, unroll);
985985

986-
// Cache and fork outputs
987-
std::vector<std::pair<TensorView*, TensorView*>> cached_outputs =
988-
scheduler_utils::cacheAndForkOutputs(fusion, unroll);
986+
// Cache and fork outputs
987+
auto cached_outputs = scheduler_utils::cacheAndForkOutputs(fusion, unroll);
989988

990989
// Make sure we don't have global memory set on intermediate tensors from
991990
// fusion segmentation

torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp

Lines changed: 7 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -592,15 +592,11 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) {
592592
ir_utils::getReductionOps(fusion /*, ignore_trivial=true */).empty(),
593593
"This scheduler only handles pointwise ops.");
594594

595-
// For intermediate outputs, apply cacheFork
596-
auto outs = fusion->outputs();
597-
for (const auto output : outs) {
598-
if (!output->uses().empty() && output->definition() != nullptr) {
599-
if (output->getValType().value() == ValType::TensorView) {
600-
output->as<TensorView>()->cacheFork();
601-
}
602-
}
603-
}
595+
// Cache inputs
596+
auto cached_inputs = scheduler_utils::cacheInputs(fusion, true);
597+
598+
// Cache and fork outputs
599+
auto cached_outputs = scheduler_utils::cacheAndForkOutputs(fusion, true);
604600

605601
std::vector<TensorView*> input_tvs;
606602
{
@@ -637,31 +633,6 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) {
637633
reference_tv != nullptr,
638634
"Could not find a fully broadcasted output to reference schedule on.");
639635

640-
// Caches of inputs
641-
std::vector<TensorView*> cached_inputs;
642-
643-
// Output, cacheBefore of output
644-
std::vector<std::pair<TensorView*, TensorView*>> cached_outputs;
645-
646-
// Track what should be vectorized versus unrolled
647-
std::unordered_set<TensorView*> vectorized_tensor;
648-
649-
// Figure out which inputs to cache for unrolling or vectorization
650-
for (auto inp : input_tvs) {
651-
if (inp->uses().empty() || inp->isFusionOutput()) {
652-
continue;
653-
}
654-
cached_inputs.emplace_back(inp->cacheAfter());
655-
}
656-
657-
// Figure out which outputs to cache for unrolling or vectorization
658-
for (auto out : output_tvs) {
659-
if (out->definition() == nullptr) {
660-
continue;
661-
}
662-
cached_outputs.emplace_back(std::make_pair(out, out->cacheBefore()));
663-
}
664-
665636
auto all_tvs = ir_utils::allTvs(fusion);
666637

667638
// Merge right side of break point
@@ -929,8 +900,8 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams& params) {
929900
// Compute at cached outputs
930901
//[BIDx, Unswitch, Vectorization, TIDx]
931902
for (auto entry : cached_outputs) {
932-
auto cached_output = entry.second;
933-
auto output = entry.first;
903+
auto cached_output = entry.first;
904+
auto output = entry.second;
934905

935906
auto unswitch_it = std::find_if(
936907
output->domain()->domain().begin(),

torch/csrc/jit/codegen/cuda/scheduler/reduction.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1002,9 +1002,8 @@ void scheduleReduction(Fusion* fusion, const ReductionParams& rparams) {
10021002
// Cache inputs if unrolled
10031003
auto cached_inputs = scheduler_utils::cacheInputs(fusion, unroll);
10041004

1005-
// Cache and fork outputs
1006-
std::vector<std::pair<TensorView*, TensorView*>> cached_outputs =
1007-
scheduler_utils::cacheAndForkOutputs(fusion, unroll);
1005+
// Cache and fork outputs
1006+
auto cached_outputs = scheduler_utils::cacheAndForkOutputs(fusion, unroll);
10081007

10091008
// Make sure we don't have global memory set on intermediate tensors from
10101009
// fusion segmentation

Comments: 0 commit comments