Skip to content

Commit b40f49f

Browse files
authored
Remove interpolate in favor of PyTorch's implementation (#2252)
* Remove interpolate in favor of PyTorch's implementation * Bugfix * Bugfix
1 parent 98aa805 commit b40f49f

File tree

4 files changed

+11
-64
lines changed

4 files changed

+11
-64
lines changed

torchvision/models/detection/keypoint_rcnn.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
import torch
22
from torch import nn
33

4-
from torchvision.ops import misc as misc_nn_ops
5-
64
from torchvision.ops import MultiScaleRoIAlign
75

86
from ..utils import load_state_dict_from_url
@@ -253,7 +251,7 @@ def __init__(self, in_channels, num_keypoints):
253251

254252
def forward(self, x):
255253
x = self.kps_score_lowres(x)
256-
x = misc_nn_ops.interpolate(
254+
x = torch.nn.functional.interpolate(
257255
x, scale_factor=float(self.up_scale), mode="bilinear", align_corners=False
258256
)
259257
return x

torchvision/models/detection/roi_heads.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
from torch import nn, Tensor
66

77
from torchvision.ops import boxes as box_ops
8-
from torchvision.ops import misc as misc_nn_ops
98

109
from torchvision.ops import roi_align
1110

@@ -175,8 +174,8 @@ def _onnx_heatmaps_to_keypoints(maps, maps_i, roi_map_width, roi_map_height,
175174
width_correction = widths_i / roi_map_width
176175
height_correction = heights_i / roi_map_height
177176

178-
roi_map = torch.nn.functional.interpolate(
179-
maps_i[None], size=(int(roi_map_height), int(roi_map_width)), mode='bicubic', align_corners=False)[0]
177+
roi_map = F.interpolate(
178+
maps_i[:, None], size=(int(roi_map_height), int(roi_map_width)), mode='bicubic', align_corners=False)[:, 0]
180179

181180
w = torch.scalar_tensor(roi_map.size(2), dtype=torch.int64)
182181
pos = roi_map.reshape(num_keypoints, -1).argmax(dim=1)
@@ -256,8 +255,8 @@ def heatmaps_to_keypoints(maps, rois):
256255
roi_map_height = int(heights_ceil[i].item())
257256
width_correction = widths[i] / roi_map_width
258257
height_correction = heights[i] / roi_map_height
259-
roi_map = torch.nn.functional.interpolate(
260-
maps[i][None], size=(roi_map_height, roi_map_width), mode='bicubic', align_corners=False)[0]
258+
roi_map = F.interpolate(
259+
maps[i][:, None], size=(roi_map_height, roi_map_width), mode='bicubic', align_corners=False)[:, 0]
261260
# roi_map_probs = scores_to_probs(roi_map.copy())
262261
w = roi_map.shape[2]
263262
pos = roi_map.reshape(num_keypoints, -1).argmax(dim=1)
@@ -392,7 +391,7 @@ def paste_mask_in_image(mask, box, im_h, im_w):
392391
mask = mask.expand((1, 1, -1, -1))
393392

394393
# Resize mask
395-
mask = misc_nn_ops.interpolate(mask, size=(h, w), mode='bilinear', align_corners=False)
394+
mask = F.interpolate(mask, size=(h, w), mode='bilinear', align_corners=False)
396395
mask = mask[0][0]
397396

398397
im_mask = torch.zeros((im_h, im_w), dtype=mask.dtype, device=mask.device)
@@ -420,7 +419,7 @@ def _onnx_paste_mask_in_image(mask, box, im_h, im_w):
420419
mask = mask.expand((1, 1, mask.size(0), mask.size(1)))
421420

422421
# Resize mask
423-
mask = torch.nn.functional.interpolate(mask, size=(int(h), int(w)), mode='bilinear', align_corners=False)
422+
mask = F.interpolate(mask, size=(int(h), int(w)), mode='bilinear', align_corners=False)
424423
mask = mask[0][0]
425424

426425
x_0 = torch.max(torch.cat((box[0].unsqueeze(0), zero)))

torchvision/models/detection/transform.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,10 @@
22
import math
33
import torch
44
from torch import nn, Tensor
5+
from torch.nn import functional as F
56
import torchvision
67
from torch.jit.annotations import List, Tuple, Dict, Optional
78

8-
from torchvision.ops import misc as misc_nn_ops
99
from .image_list import ImageList
1010
from .roi_heads import paste_masks_in_image
1111

@@ -28,7 +28,7 @@ def _resize_image_and_masks_onnx(image, self_min_size, self_max_size, target):
2828

2929
if "masks" in target:
3030
mask = target["masks"]
31-
mask = misc_nn_ops.interpolate(mask[None].float(), scale_factor=scale_factor)[0].byte()
31+
mask = F.interpolate(mask[:, None].float(), scale_factor=scale_factor)[:, 0].byte()
3232
target["masks"] = mask
3333
return image, target
3434

@@ -50,7 +50,7 @@ def _resize_image_and_masks(image, self_min_size, self_max_size, target):
5050

5151
if "masks" in target:
5252
mask = target["masks"]
53-
mask = misc_nn_ops.interpolate(mask[None].float(), scale_factor=scale_factor)[0].byte()
53+
mask = F.interpolate(mask[:, None].float(), scale_factor=scale_factor)[:, 0].byte()
5454
target["masks"] = mask
5555
return image, target
5656

torchvision/ops/misc.py

Lines changed: 1 addition & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,3 @@
1-
from collections import OrderedDict
2-
from torch.jit.annotations import Optional, List
3-
from torch import Tensor
4-
51
"""
62
helper class that supports empty tensors on some nn functions.
73
@@ -12,10 +8,8 @@
128
is implemented
139
"""
1410

15-
import math
1611
import warnings
1712
import torch
18-
from torchvision.ops import _new_empty_tensor
1913

2014

2115
class Conv2d(torch.nn.Conv2d):
@@ -42,51 +36,7 @@ def __init__(self, *args, **kwargs):
4236
"removed in future versions, use torch.nn.BatchNorm2d instead.", FutureWarning)
4337

4438

45-
def _check_size_scale_factor(dim, size, scale_factor):
46-
# type: (int, Optional[List[int]], Optional[float]) -> None
47-
if size is None and scale_factor is None:
48-
raise ValueError("either size or scale_factor should be defined")
49-
if size is not None and scale_factor is not None:
50-
raise ValueError("only one of size or scale_factor should be defined")
51-
if scale_factor is not None:
52-
if isinstance(scale_factor, (list, tuple)):
53-
if len(scale_factor) != dim:
54-
raise ValueError(
55-
"scale_factor shape must match input shape. "
56-
"Input is {}D, scale_factor size is {}".format(dim, len(scale_factor))
57-
)
58-
59-
60-
def _output_size(dim, input, size, scale_factor):
61-
# type: (int, Tensor, Optional[List[int]], Optional[float]) -> List[int]
62-
assert dim == 2
63-
_check_size_scale_factor(dim, size, scale_factor)
64-
if size is not None:
65-
return size
66-
# if dim is not 2 or scale_factor is iterable use _ntuple instead of concat
67-
assert scale_factor is not None and isinstance(scale_factor, (int, float))
68-
scale_factors = [scale_factor, scale_factor]
69-
# math.floor might return float in py2.7
70-
return [
71-
int(math.floor(input.size(i + 2) * scale_factors[i])) for i in range(dim)
72-
]
73-
74-
75-
def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corners=None):
76-
# type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor
77-
"""
78-
Equivalent to nn.functional.interpolate, but with support for empty batch sizes.
79-
This will eventually be supported natively by PyTorch, and this
80-
class can go away.
81-
"""
82-
if input.numel() > 0:
83-
return torch.nn.functional.interpolate(
84-
input, size, scale_factor, mode, align_corners
85-
)
86-
87-
output_shape = _output_size(2, input, size, scale_factor)
88-
output_shape = list(input.shape[:-2]) + list(output_shape)
89-
return _new_empty_tensor(input, output_shape)
39+
interpolate = torch.nn.functional.interpolate
9040

9141

9242
# This is not in nn

0 commit comments

Comments
 (0)