Skip to content

Commit e97a6d9

Browse files
committed
Conditional cast.
1 parent c83beaf commit e97a6d9

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

torchvision/models/detection/transform.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,17 @@ def _resize_image_and_masks(image, self_min_size, self_max_size, target):
2121

2222
if torchvision._is_tracing():
2323
im_shape = _get_shape_onnx(image)
24+
cast_to_float = False
2425
else:
2526
im_shape = torch.tensor(image.shape[-2:])
27+
cast_to_float = True
2628

2729
min_size = torch.min(im_shape).to(dtype=torch.float32)
2830
max_size = torch.max(im_shape).to(dtype=torch.float32)
29-
scale_factor = float(torch.min(self_min_size / min_size, self_max_size / max_size))
31+
scale_factor = torch.min(self_min_size / min_size, self_max_size / max_size)
32+
33+
if cast_to_float:
34+
scale_factor = float(scale_factor)
3035

3136
image = torch.nn.functional.interpolate(
3237
image[None], scale_factor=scale_factor, mode='bilinear', recompute_scale_factor=True,

0 commit comments

Comments
 (0)