diff --git a/torchvision/models/detection/keypoint_rcnn.py b/torchvision/models/detection/keypoint_rcnn.py index ad19af8cb72..53fc80396d2 100644 --- a/torchvision/models/detection/keypoint_rcnn.py +++ b/torchvision/models/detection/keypoint_rcnn.py @@ -1,8 +1,6 @@ import torch from torch import nn -from torchvision.ops import misc as misc_nn_ops - from torchvision.ops import MultiScaleRoIAlign from ..utils import load_state_dict_from_url @@ -253,7 +251,7 @@ def __init__(self, in_channels, num_keypoints): def forward(self, x): x = self.kps_score_lowres(x) - x = misc_nn_ops.interpolate( + x = torch.nn.functional.interpolate( x, scale_factor=float(self.up_scale), mode="bilinear", align_corners=False ) return x diff --git a/torchvision/models/detection/roi_heads.py b/torchvision/models/detection/roi_heads.py index 74ac5267393..cc6f6083e67 100644 --- a/torchvision/models/detection/roi_heads.py +++ b/torchvision/models/detection/roi_heads.py @@ -5,7 +5,6 @@ from torch import nn, Tensor from torchvision.ops import boxes as box_ops -from torchvision.ops import misc as misc_nn_ops from torchvision.ops import roi_align @@ -175,8 +174,8 @@ def _onnx_heatmaps_to_keypoints(maps, maps_i, roi_map_width, roi_map_height, width_correction = widths_i / roi_map_width height_correction = heights_i / roi_map_height - roi_map = torch.nn.functional.interpolate( - maps_i[None], size=(int(roi_map_height), int(roi_map_width)), mode='bicubic', align_corners=False)[0] + roi_map = F.interpolate( + maps_i[:, None], size=(int(roi_map_height), int(roi_map_width)), mode='bicubic', align_corners=False)[:, 0] w = torch.scalar_tensor(roi_map.size(2), dtype=torch.int64) pos = roi_map.reshape(num_keypoints, -1).argmax(dim=1) @@ -256,8 +255,8 @@ def heatmaps_to_keypoints(maps, rois): roi_map_height = int(heights_ceil[i].item()) width_correction = widths[i] / roi_map_width height_correction = heights[i] / roi_map_height - roi_map = torch.nn.functional.interpolate( - maps[i][None], size=(roi_map_height, roi_map_width), mode='bicubic', align_corners=False)[0] + roi_map = F.interpolate( + maps[i][:, None], size=(roi_map_height, roi_map_width), mode='bicubic', align_corners=False)[:, 0] # roi_map_probs = scores_to_probs(roi_map.copy()) w = roi_map.shape[2] pos = roi_map.reshape(num_keypoints, -1).argmax(dim=1) @@ -392,7 +391,7 @@ def paste_mask_in_image(mask, box, im_h, im_w): mask = mask.expand((1, 1, -1, -1)) # Resize mask - mask = misc_nn_ops.interpolate(mask, size=(h, w), mode='bilinear', align_corners=False) + mask = F.interpolate(mask, size=(h, w), mode='bilinear', align_corners=False) mask = mask[0][0] im_mask = torch.zeros((im_h, im_w), dtype=mask.dtype, device=mask.device) @@ -420,7 +419,7 @@ def _onnx_paste_mask_in_image(mask, box, im_h, im_w): mask = mask.expand((1, 1, mask.size(0), mask.size(1))) # Resize mask - mask = torch.nn.functional.interpolate(mask, size=(int(h), int(w)), mode='bilinear', align_corners=False) + mask = F.interpolate(mask, size=(int(h), int(w)), mode='bilinear', align_corners=False) mask = mask[0][0] x_0 = torch.max(torch.cat((box[0].unsqueeze(0), zero))) diff --git a/torchvision/models/detection/transform.py b/torchvision/models/detection/transform.py index 29a2e07d5a2..5564866c571 100644 --- a/torchvision/models/detection/transform.py +++ b/torchvision/models/detection/transform.py @@ -2,10 +2,10 @@ import math import torch from torch import nn, Tensor +from torch.nn import functional as F import torchvision from torch.jit.annotations import List, Tuple, Dict, Optional -from torchvision.ops import misc as misc_nn_ops from .image_list import ImageList from .roi_heads import paste_masks_in_image @@ -28,7 +28,7 @@ def _resize_image_and_masks_onnx(image, self_min_size, self_max_size, target): if "masks" in target: mask = target["masks"] - mask = misc_nn_ops.interpolate(mask[None].float(), scale_factor=scale_factor)[0].byte() + mask = F.interpolate(mask[:, None].float(), scale_factor=scale_factor)[:, 0].byte() target["masks"] = mask return image, target @@ -50,7 +50,7 @@ def _resize_image_and_masks(image, self_min_size, self_max_size, target): if "masks" in target: mask = target["masks"] - mask = misc_nn_ops.interpolate(mask[None].float(), scale_factor=scale_factor)[0].byte() + mask = F.interpolate(mask[:, None].float(), scale_factor=scale_factor)[:, 0].byte() target["masks"] = mask return image, target diff --git a/torchvision/ops/misc.py b/torchvision/ops/misc.py index ccc82e63cf2..61fab3edd7a 100644 --- a/torchvision/ops/misc.py +++ b/torchvision/ops/misc.py @@ -1,7 +1,3 @@ -from collections import OrderedDict -from torch.jit.annotations import Optional, List -from torch import Tensor - """ helper class that supports empty tensors on some nn functions. @@ -12,10 +8,8 @@ is implemented """ -import math import warnings import torch -from torchvision.ops import _new_empty_tensor class Conv2d(torch.nn.Conv2d): @@ -42,51 +36,7 @@ def __init__(self, *args, **kwargs): "removed in future versions, use torch.nn.BatchNorm2d instead.", FutureWarning) -def _check_size_scale_factor(dim, size, scale_factor): - # type: (int, Optional[List[int]], Optional[float]) -> None - if size is None and scale_factor is None: - raise ValueError("either size or scale_factor should be defined") - if size is not None and scale_factor is not None: - raise ValueError("only one of size or scale_factor should be defined") - if scale_factor is not None: - if isinstance(scale_factor, (list, tuple)): - if len(scale_factor) != dim: - raise ValueError( - "scale_factor shape must match input shape. " - "Input is {}D, scale_factor size is {}".format(dim, len(scale_factor)) - ) - - -def _output_size(dim, input, size, scale_factor): - # type: (int, Tensor, Optional[List[int]], Optional[float]) -> List[int] - assert dim == 2 - _check_size_scale_factor(dim, size, scale_factor) - if size is not None: - return size - # if dim is not 2 or scale_factor is iterable use _ntuple instead of concat - assert scale_factor is not None and isinstance(scale_factor, (int, float)) - scale_factors = [scale_factor, scale_factor] - # math.floor might return float in py2.7 - return [ - int(math.floor(input.size(i + 2) * scale_factors[i])) for i in range(dim) - ] - - -def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corners=None): - # type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor - """ - Equivalent to nn.functional.interpolate, but with support for empty batch sizes. - This will eventually be supported natively by PyTorch, and this - class can go away. - """ - if input.numel() > 0: - return torch.nn.functional.interpolate( - input, size, scale_factor, mode, align_corners - ) - - output_shape = _output_size(2, input, size, scale_factor) - output_shape = list(input.shape[:-2]) + list(output_shape) - return _new_empty_tensor(input, output_shape) +interpolate = torch.nn.functional.interpolate # This is not in nn