Skip to content

Commit b096271

Browse files
authored
Updated annotation to be a Union of Tensor and List (#4416)
* Updated annotation to be a Union of Tensor and List * Updated check_roi_boxes_shape.
1 parent 719e120 commit b096271

File tree

3 files changed

+10
-8
lines changed

3 files changed

+10
-8
lines changed

torchvision/ops/_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import torch
22
from torch import Tensor
3-
from typing import List
3+
from typing import List, Union
44

55

66
def _cat(tensors: List[Tensor], dim: int = 0) -> Tensor:
@@ -24,7 +24,7 @@ def convert_boxes_to_roi_format(boxes: List[Tensor]) -> Tensor:
2424
return rois
2525

2626

27-
def check_roi_boxes_shape(boxes: Tensor):
27+
def check_roi_boxes_shape(boxes: Union[Tensor, List[Tensor]]):
2828
if isinstance(boxes, (list, tuple)):
2929
for _tensor in boxes:
3030
assert _tensor.size(1) == 4, \

torchvision/ops/roi_align.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,17 @@
1+
from typing import List, Union
2+
13
import torch
24
from torch import nn, Tensor
3-
45
from torch.nn.modules.utils import _pair
56
from torch.jit.annotations import BroadcastingList2
6-
77
from torchvision.extension import _assert_has_ops
8+
89
from ._utils import convert_boxes_to_roi_format, check_roi_boxes_shape
910

1011

1112
def roi_align(
1213
input: Tensor,
13-
boxes: Tensor,
14+
boxes: Union[Tensor, List[Tensor]],
1415
output_size: BroadcastingList2[int],
1516
spatial_scale: float = 1.0,
1617
sampling_ratio: int = -1,

torchvision/ops/roi_pool.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,17 @@
1+
from typing import List, Union
2+
13
import torch
24
from torch import nn, Tensor
3-
45
from torch.nn.modules.utils import _pair
56
from torch.jit.annotations import BroadcastingList2
6-
77
from torchvision.extension import _assert_has_ops
8+
89
from ._utils import convert_boxes_to_roi_format, check_roi_boxes_shape
910

1011

1112
def roi_pool(
1213
input: Tensor,
13-
boxes: Tensor,
14+
boxes: Union[Tensor, List[Tensor]],
1415
output_size: BroadcastingList2[int],
1516
spatial_scale: float = 1.0,
1617
) -> Tensor:

0 commit comments

Comments
 (0)