Remove interpolate in favor of PyTorch's implementation #2252

Merged: 3 commits, merged on Jun 1, 2020
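The custom `torchvision.ops.misc.interpolate` wrapper existed only so that the detection models could interpolate tensors with an empty batch dimension; its own docstring (visible in the last hunk below) notes it "can go away" once PyTorch supports this natively. That support has landed, so every call site is switched to `torch.nn.functional.interpolate`. A minimal sketch of the case the shim used to cover, assuming a PyTorch build from around this PR's merge date (1.5 or later):

```python
import torch
import torch.nn.functional as F

# An empty batch: zero images, 3 channels, 8x8 spatial size.
x = torch.empty(0, 3, 8, 8)

# Native interpolate now handles the zero-sized batch directly,
# which is exactly what the removed torchvision shim special-cased.
y = F.interpolate(x, scale_factor=2.0, mode="nearest")
print(y.shape)  # torch.Size([0, 3, 16, 16])
```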
torchvision/models/detection/keypoint_rcnn.py (1 addition, 3 deletions):

```diff
@@ -1,8 +1,6 @@
 import torch
 from torch import nn
 
-from torchvision.ops import misc as misc_nn_ops
 from torchvision.ops import MultiScaleRoIAlign
 
 from ..utils import load_state_dict_from_url
@@ -253,7 +251,7 @@ def __init__(self, in_channels, num_keypoints):
 
     def forward(self, x):
         x = self.kps_score_lowres(x)
-        x = misc_nn_ops.interpolate(
+        x = torch.nn.functional.interpolate(
             x, scale_factor=float(self.up_scale), mode="bilinear", align_corners=False
         )
         return x
```
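The keypoint head's forward pass is unchanged apart from the callee: it still bilinearly upsamples the low-resolution keypoint heatmaps by `self.up_scale`. A standalone sketch of the same call, with hypothetical shapes (2 ROIs, 17 COCO keypoints, 14x14 maps):

```python
import torch

heatmaps = torch.randn(2, 17, 14, 14)  # hypothetical (ROIs, keypoints, H, W)
up_scale = 2

# The call the keypoint head now makes directly:
out = torch.nn.functional.interpolate(
    heatmaps, scale_factor=float(up_scale), mode="bilinear", align_corners=False
)
print(out.shape)  # torch.Size([2, 17, 28, 28])
```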
torchvision/models/detection/roi_heads.py (6 additions, 7 deletions):

```diff
@@ -5,7 +5,6 @@
 from torch import nn, Tensor
 
 from torchvision.ops import boxes as box_ops
-from torchvision.ops import misc as misc_nn_ops
 from torchvision.ops import roi_align
@@ -175,8 +174,8 @@ def _onnx_heatmaps_to_keypoints(maps, maps_i, roi_map_width, roi_map_height,
     width_correction = widths_i / roi_map_width
     height_correction = heights_i / roi_map_height
 
-    roi_map = torch.nn.functional.interpolate(
-        maps_i[None], size=(int(roi_map_height), int(roi_map_width)), mode='bicubic', align_corners=False)[0]
+    roi_map = F.interpolate(
+        maps_i[:, None], size=(int(roi_map_height), int(roi_map_width)), mode='bicubic', align_corners=False)[:, 0]
 
     w = torch.scalar_tensor(roi_map.size(2), dtype=torch.int64)
     pos = roi_map.reshape(num_keypoints, -1).argmax(dim=1)
@@ -256,8 +255,8 @@ def heatmaps_to_keypoints(maps, rois):
         roi_map_height = int(heights_ceil[i].item())
         width_correction = widths[i] / roi_map_width
         height_correction = heights[i] / roi_map_height
-        roi_map = torch.nn.functional.interpolate(
-            maps[i][None], size=(roi_map_height, roi_map_width), mode='bicubic', align_corners=False)[0]
+        roi_map = F.interpolate(
+            maps[i][:, None], size=(roi_map_height, roi_map_width), mode='bicubic', align_corners=False)[:, 0]
         # roi_map_probs = scores_to_probs(roi_map.copy())
         w = roi_map.shape[2]
         pos = roi_map.reshape(num_keypoints, -1).argmax(dim=1)
```
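Note that these two call sites do more than swap modules: the indexing changes from `maps[i][None]` to `maps[i][:, None]`. `maps[i]` has shape `(num_keypoints, H, W)`, so the old code built a `(1, num_keypoints, H, W)` input (one image with K channels) while the new code builds `(num_keypoints, 1, H, W)` (K single-channel images). Moving the keypoints into the batch dimension presumably lets the zero-keypoint case fall through as an empty batch, which the native `interpolate` accepts. A sketch of the new layout with assumed sizes:

```python
import torch
import torch.nn.functional as F

K = 17  # assumed number of keypoints per ROI
roi_heatmaps = torch.randn(K, 14, 14)

# Old: roi_heatmaps[None]     -> (1, K, 14, 14), one image with K channels.
# New: roi_heatmaps[:, None]  -> (K, 1, 14, 14), K single-channel images.
roi_map = F.interpolate(
    roi_heatmaps[:, None], size=(28, 28), mode="bicubic", align_corners=False
)[:, 0]
print(roi_map.shape)  # torch.Size([17, 28, 28])
```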
torchvision/models/detection/roi_heads.py (continued):

```diff
@@ -392,7 +391,7 @@ def paste_mask_in_image(mask, box, im_h, im_w):
     mask = mask.expand((1, 1, -1, -1))
 
     # Resize mask
-    mask = misc_nn_ops.interpolate(mask, size=(h, w), mode='bilinear', align_corners=False)
+    mask = F.interpolate(mask, size=(h, w), mode='bilinear', align_corners=False)
     mask = mask[0][0]
 
     im_mask = torch.zeros((im_h, im_w), dtype=mask.dtype, device=mask.device)
@@ -420,7 +419,7 @@ def _onnx_paste_mask_in_image(mask, box, im_h, im_w):
     mask = mask.expand((1, 1, mask.size(0), mask.size(1)))
 
     # Resize mask
-    mask = torch.nn.functional.interpolate(mask, size=(int(h), int(w)), mode='bilinear', align_corners=False)
+    mask = F.interpolate(mask, size=(int(h), int(w)), mode='bilinear', align_corners=False)
     mask = mask[0][0]
 
     x_0 = torch.max(torch.cat((box[0].unsqueeze(0), zero)))
```
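In the mask-pasting helpers only the callee changes; the surrounding code already massages a single `(H, W)` mask into the `(N, C, H, W)` layout `interpolate` expects. A sketch of that resize step with an assumed 28x28 mask and box size:

```python
import torch
import torch.nn.functional as F

mask = torch.rand(28, 28)  # hypothetical mask logits from the mask head
h, w = 50, 75              # assumed box height/width in image pixels

mask = mask.expand((1, 1, -1, -1))  # (28, 28) -> (1, 1, 28, 28)
mask = F.interpolate(mask, size=(h, w), mode="bilinear", align_corners=False)
mask = mask[0][0]                   # back to (h, w)
print(mask.shape)  # torch.Size([50, 75])
```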
torchvision/models/detection/transform.py (3 additions, 3 deletions):

```diff
@@ -2,10 +2,10 @@
 import math
 import torch
 from torch import nn, Tensor
+from torch.nn import functional as F
 import torchvision
 from torch.jit.annotations import List, Tuple, Dict, Optional
 
-from torchvision.ops import misc as misc_nn_ops
 from .image_list import ImageList
 from .roi_heads import paste_masks_in_image
@@ -28,7 +28,7 @@ def _resize_image_and_masks_onnx(image, self_min_size, self_max_size, target):
 
     if "masks" in target:
         mask = target["masks"]
-        mask = misc_nn_ops.interpolate(mask[None].float(), scale_factor=scale_factor)[0].byte()
+        mask = F.interpolate(mask[:, None].float(), scale_factor=scale_factor)[:, 0].byte()
         target["masks"] = mask
     return image, target
@@ -50,7 +50,7 @@ def _resize_image_and_masks(image, self_min_size, self_max_size, target):
 
     if "masks" in target:
         mask = target["masks"]
-        mask = misc_nn_ops.interpolate(mask[None].float(), scale_factor=scale_factor)[0].byte()
+        mask = F.interpolate(mask[:, None].float(), scale_factor=scale_factor)[:, 0].byte()
         target["masks"] = mask
     return image, target
```
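The same `[:, None]` pattern appears here: `target["masks"]` is a `(num_objects, H, W)` stack, and inserting the channel dimension makes each object its own batch entry, so an image with no annotated objects resizes to an empty stack instead of erroring. A sketch with assumed sizes and scale factor:

```python
import torch
import torch.nn.functional as F

masks = torch.zeros(3, 100, 100, dtype=torch.uint8)  # hypothetical instance masks
scale_factor = 0.5  # assumed factor computed from the min/max image size

# (3, 100, 100) -> (3, 1, 100, 100), resize, then drop the channel dim.
resized = F.interpolate(masks[:, None].float(), scale_factor=scale_factor)[:, 0].byte()
print(resized.shape)  # torch.Size([3, 50, 50])

# The zero-object case now works natively:
empty = torch.zeros(0, 100, 100)
print(F.interpolate(empty[:, None], scale_factor=scale_factor).shape)
# torch.Size([0, 1, 50, 50])
```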
torchvision/ops/misc.py (1 addition, 51 deletions):

```diff
@@ -1,7 +1,3 @@
-from collections import OrderedDict
-from torch.jit.annotations import Optional, List
-from torch import Tensor
-
 """
 helper class that supports empty tensors on some nn functions.
@@ -12,10 +8,8 @@
 is implemented
 """
 
-import math
 import warnings
 import torch
-from torchvision.ops import _new_empty_tensor
 
 
 class Conv2d(torch.nn.Conv2d):
@@ -42,51 +36,7 @@ def __init__(self, *args, **kwargs):
             "removed in future versions, use torch.nn.BatchNorm2d instead.", FutureWarning)
 
 
-def _check_size_scale_factor(dim, size, scale_factor):
-    # type: (int, Optional[List[int]], Optional[float]) -> None
-    if size is None and scale_factor is None:
-        raise ValueError("either size or scale_factor should be defined")
-    if size is not None and scale_factor is not None:
-        raise ValueError("only one of size or scale_factor should be defined")
-    if scale_factor is not None:
-        if isinstance(scale_factor, (list, tuple)):
-            if len(scale_factor) != dim:
-                raise ValueError(
-                    "scale_factor shape must match input shape. "
-                    "Input is {}D, scale_factor size is {}".format(dim, len(scale_factor))
-                )
-
-
-def _output_size(dim, input, size, scale_factor):
-    # type: (int, Tensor, Optional[List[int]], Optional[float]) -> List[int]
-    assert dim == 2
-    _check_size_scale_factor(dim, size, scale_factor)
-    if size is not None:
-        return size
-    # if dim is not 2 or scale_factor is iterable use _ntuple instead of concat
-    assert scale_factor is not None and isinstance(scale_factor, (int, float))
-    scale_factors = [scale_factor, scale_factor]
-    # math.floor might return float in py2.7
-    return [
-        int(math.floor(input.size(i + 2) * scale_factors[i])) for i in range(dim)
-    ]
-
-
-def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corners=None):
-    # type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor
-    """
-    Equivalent to nn.functional.interpolate, but with support for empty batch sizes.
-    This will eventually be supported natively by PyTorch, and this
-    class can go away.
-    """
-    if input.numel() > 0:
-        return torch.nn.functional.interpolate(
-            input, size, scale_factor, mode, align_corners
-        )
-
-    output_shape = _output_size(2, input, size, scale_factor)
-    output_shape = list(input.shape[:-2]) + list(output_shape)
-    return _new_empty_tensor(input, output_shape)
+interpolate = torch.nn.functional.interpolate
 
 
 # This is not in nn
```
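The deleted helper forwarded to `torch.nn.functional.interpolate` whenever the input was non-empty and otherwise computed the output shape by hand via `_output_size` and `_new_empty_tensor`. With that logic gone, the module keeps `interpolate` as a bare alias, presumably so existing `misc_nn_ops.interpolate` call sites keep working. A sketch of what the alias means in practice, assuming a torchvision build that includes this change:

```python
import torch
from torchvision.ops import misc as misc_nn_ops

# After this PR, the torchvision name is the PyTorch function itself.
assert misc_nn_ops.interpolate is torch.nn.functional.interpolate

# The empty-input path the old shim special-cased:
x = torch.empty(0, 256, 7, 7)
out = misc_nn_ops.interpolate(x, size=(14, 14), mode="bilinear", align_corners=False)
print(out.shape)  # torch.Size([0, 256, 14, 14])
```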