From 4584807ecef2e8b526b9e2fe9ca941fc5101d07a Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Mon, 6 Sep 2021 15:16:51 +0100 Subject: [PATCH] Resolving tracing problem on StochasticDepth iterator. --- torchvision/ops/stochastic_depth.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/torchvision/ops/stochastic_depth.py b/torchvision/ops/stochastic_depth.py index 0b95e7cca67..9ba6394326b 100644 --- a/torchvision/ops/stochastic_depth.py +++ b/torchvision/ops/stochastic_depth.py @@ -28,9 +28,10 @@ def stochastic_depth(input: Tensor, p: float, mode: str, training: bool = True) return input survival_rate = 1.0 - p - size = [1] * input.ndim if mode == "row": - size[0] = input.shape[0] + size = [input.shape[0]] + [1] * (input.ndim - 1) + else: + size = [1] * input.ndim noise = torch.empty(size, dtype=input.dtype, device=input.device) noise = noise.bernoulli_(survival_rate).div_(survival_rate) return input * noise