Commit a4effa6

Enable output allocation cache (#2010)
Fixes #2002: checks all IterDomains on the fusion outputs and disables the output allocation cache unless it can verify that no output extent value is a consumer of fusion inputs.
1 parent 35440b7 commit a4effa6
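
For orientation, the new check in executor.cpp collects the extent of every allocated IterDomain on the fusion outputs, walks those extents back to their producers, and sets disable_parameter_cache_ whenever one of them traces back to a fusion input (e.g. the size argument of a tensor-factory op). The following is a minimal, self-contained sketch of that dependency check, not the nvFuser implementation: Val, collectLeaves, and outputSizesDependOnFusionInputs below are simplified stand-ins for the real Val/IterDomain/TensorView classes and InputsOf::outputs.

#include <algorithm>
#include <vector>

// Simplified stand-in for a fusion value: either a fusion input (e.g. a
// runtime scalar) or a value computed from other values.
struct Val {
  bool is_fusion_input;
  std::vector<Val*> producers; // values this one was computed from
};

// Simplified stand-in for InputsOf::outputs(): collect the leaf values a
// given value ultimately depends on.
void collectLeaves(Val* v, std::vector<Val*>& leaves) {
  if (v->producers.empty()) {
    leaves.push_back(v);
    return;
  }
  for (Val* p : v->producers) {
    collectLeaves(p, leaves);
  }
}

// Mirrors the check added to FusionExecutor::compileFusion: given the extents
// of the allocated output IterDomains, decide whether any of them depends on
// a fusion input, in which case shape-keyed caching of output allocations is
// unsafe and has to be disabled.
bool outputSizesDependOnFusionInputs(const std::vector<Val*>& output_extents) {
  std::vector<Val*> leaves;
  for (Val* extent : output_extents) {
    collectLeaves(extent, leaves);
  }
  return std::any_of(leaves.begin(), leaves.end(), [](Val* v) {
    return v->is_fusion_input;
  });
}

int main() {
  Val scalar_input{/*is_fusion_input=*/true, {}}; // e.g. a tensor-factory size argument
  Val derived_extent{/*is_fusion_input=*/false, {&scalar_input}};
  // The output extent derives from the scalar input, so caching must be disabled.
  return outputSizesDependOnFusionInputs({&derived_extent}) ? 0 : 1;
}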

4 files changed: +41 -2 lines changed

torch/csrc/jit/codegen/cuda/executor.cpp

Lines changed: 29 additions & 0 deletions
@@ -183,6 +183,35 @@ void FusionExecutor::compileFusion(
     TORCH_INTERNAL_ASSERT(
         out->getValType() == ValType::TensorView,
         "Output types from fusions that are not tensors are not supported at this point.");
+
+    const auto maybe_rfactor_domain =
+        out->as<TensorView>()->getMaybeRFactorDomain();
+    // Walk the outputs to see whether any output shape depends on a
+    // non-tensor input. In that case output allocation caching must be
+    // disabled, since the caching id only looks at tensor shapes.
+    // See issue https://github.com/csarofeen/pytorch/issues/2002
+    std::vector<Val*> output_extents;
+    for (const auto id : maybe_rfactor_domain) {
+      Val* extent = nullptr;
+      if (id->isReduction() || id->isStride()) {
+        continue;
+      } else if (id->isBroadcast() && id->hasExpandedExtent()) {
+        extent = id->expandedExtent();
+      } else {
+        extent = id->extent();
+      }
+      output_extents.emplace_back(extent);
+    }
+    auto dependencies = InputsOf::outputs(fusion, output_extents);
+    if (std::any_of(dependencies.begin(), dependencies.end(), [](Val* val) {
+          return val->isFusionInput();
+        })) {
+      // TODO: the parameter cache is too big a hammer here. We should consider
+      // separating the caching logic of output sizes & launch params, since an
+      // output-size dependency should only invalidate the output sizes.
+      disable_parameter_cache_ = true;
+      break;
+    }
   }

   if (isDebugDumpEnabled(DebugDumpOption::FusionIr)) {

torch/csrc/jit/codegen/cuda/executor.h

Lines changed: 4 additions & 1 deletion
@@ -310,7 +310,10 @@ class TORCH_CUDA_CU_API FusionExecutor : public NonCopyable {
   // Profiling support: the last launch param used
   LaunchParams launch_params_;

-  // Profiling support: knob to disable caching of launch params
+  // Profiling support: disable caching of launch params and output allocation.
+  // Output allocation is also disabled when output sizes depend on runtime
+  // scalar inputs, such as in the tensor factory case. See
+  // https://github.com/csarofeen/pytorch/issues/2002
   bool disable_parameter_cache_ = false;

   // Profiling support: kept copy of the cuda kernel

torch/csrc/jit/codegen/cuda/executor_kernel_arg.h

Lines changed: 3 additions & 1 deletion
@@ -301,7 +301,9 @@ class TORCH_CUDA_CU_API KernelArgumentHolder {
       : index_mode_(index_mode) {}

   KernelArgumentHolder(const KernelArgumentHolder& self)
-      : device_index_(self.getDeviceIndex()), index_mode_(self.getIndexMode()) {
+      : device_index_(self.getDeviceIndex()),
+        cache_id_(self.getCacheId()),
+        index_mode_(self.getIndexMode()) {
     for (const auto& arg : self.arguments_) {
       push(arg.get());
     }

torch/csrc/jit/codegen/cuda/kernel_cache.cpp

Lines changed: 5 additions & 0 deletions
@@ -621,11 +621,16 @@ std::vector<at::Tensor> FusionKernelRuntime::runWithInput(
               << std::endl;
   }

+  // group should share cache id.
+  auto group_cache_id = args.getCacheId();
   for (auto group_to_run : runtime_workspace_.group_run_order) {
     // TODO: index mode should be updated per segmented kernel
     // Prepare input vector
     KernelArgumentHolder group_runtime_inputs(args.getIndexMode());
     group_runtime_inputs.setDeviceIndex(args.getDeviceIndex());
+    if (group_cache_id.has_value()) {
+      group_runtime_inputs.setCacheId(group_cache_id.value());
+    }
     for (auto input : group_to_run->inputs()) {
       group_runtime_inputs.push(tensor_map.at(input));
     }
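
Together with the copy-constructor change in executor_kernel_arg.h above, this makes the fusion-level cache id survive into each segmented group's KernelArgumentHolder rather than being dropped when the holder is built or copied. A rough sketch of that propagation, using simplified stand-in types (ArgHolder and prepareGroupArgs are illustrative names, not the nvFuser API):

#include <optional>
#include <vector>

// Simplified stand-in for KernelArgumentHolder: just carries an optional
// cache id that downstream caches can key on.
struct ArgHolder {
  std::optional<size_t> cache_id;
  void setCacheId(size_t id) {
    cache_id = id;
  }
};

// All groups of one runWithInput() call share the fusion-level cache id, if
// one was assigned, mirroring the setCacheId() call added above.
std::vector<ArgHolder> prepareGroupArgs(
    std::optional<size_t> group_cache_id,
    size_t num_groups) {
  std::vector<ArgHolder> per_group(num_groups);
  for (ArgHolder& args : per_group) {
    if (group_cache_id.has_value()) {
      args.setCacheId(group_cache_id.value());
    }
  }
  return per_group;
}

int main() {
  // A cache id assigned for the whole fusion is forwarded to, say, 3 groups.
  auto groups = prepareGroupArgs(std::optional<size_t>{42}, 3);
  return (groups[2].cache_id.has_value() && groups[2].cache_id.value() == 42) ? 0 : 1;
}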
