Skip to content

Commit ccbb150

Browse files
committed
Conditional cast for onnx.
1 parent e97a6d9 commit ccbb150

File tree

1 file changed

+7
-6
lines changed

1 file changed

+7
-6
lines changed

torchvision/models/detection/transform.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,27 +11,28 @@
1111

1212
@torch.jit.unused
1313
def _get_shape_onnx(image):
14-
# type: (Tensor) -> Tensor
1514
from torch.onnx import operators
1615
return operators.shape_as_tensor(image)[-2:]
1716

17+
@torch.jit.unused
18+
def _float_to_tensor_onnx(v):
19+
return torch.tensor(v)
20+
1821

1922
def _resize_image_and_masks(image, self_min_size, self_max_size, target):
2023
# type: (Tensor, float, float, Optional[Dict[str, Tensor]]) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]
2124

2225
if torchvision._is_tracing():
2326
im_shape = _get_shape_onnx(image)
24-
cast_to_float = False
2527
else:
2628
im_shape = torch.tensor(image.shape[-2:])
27-
cast_to_float = True
2829

2930
min_size = torch.min(im_shape).to(dtype=torch.float32)
3031
max_size = torch.max(im_shape).to(dtype=torch.float32)
31-
scale_factor = torch.min(self_min_size / min_size, self_max_size / max_size)
32+
scale_factor = float(torch.min(self_min_size / min_size, self_max_size / max_size))
3233

33-
if cast_to_float:
34-
scale_factor = float(scale_factor)
34+
if torchvision._is_tracing():
35+
scale_factor = _float_to_tensor_onnx(scale_factor)
3536

3637
image = torch.nn.functional.interpolate(
3738
image[None], scale_factor=scale_factor, mode='bilinear', recompute_scale_factor=True,

0 commit comments

Comments
 (0)