
Commit b56b5e1

Merge branch 'main' of github.com:pytorch/vision into proto-bbox-center-crop

2 parents 28c380d + f079f5a

16 files changed (+436, -30 lines)

docs/source/models.rst

Lines changed: 8 additions & 0 deletions
@@ -61,6 +61,8 @@ You can construct a model with random weights by calling its constructor:
     mobilenet_v3_large = models.mobilenet_v3_large()
     mobilenet_v3_small = models.mobilenet_v3_small()
     resnext50_32x4d = models.resnext50_32x4d()
+    resnext101_32x8d = models.resnext101_32x8d()
+    resnext101_64x4d = models.resnext101_64x4d()
     wide_resnet50_2 = models.wide_resnet50_2()
     mnasnet = models.mnasnet1_0()
     efficientnet_b0 = models.efficientnet_b0()
@@ -185,6 +187,7 @@ MobileNet V3 Large 74.042 91.340
 MobileNet V3 Small 67.668 87.402
 ResNeXt-50-32x4d 77.618 93.698
 ResNeXt-101-32x8d 79.312 94.526
+ResNeXt-101-64x4d 83.246 96.454
 Wide ResNet-50-2 78.468 94.086
 Wide ResNet-101-2 78.848 94.284
 MNASNet 1.0 73.456 91.510
@@ -366,6 +369,7 @@ ResNext

     resnext50_32x4d
     resnext101_32x8d
+    resnext101_64x4d

 Wide ResNet
 -----------
@@ -481,8 +485,11 @@ a model with random weights by calling its constructor:
     resnet18 = models.quantization.resnet18()
     resnet50 = models.quantization.resnet50()
     resnext101_32x8d = models.quantization.resnext101_32x8d()
+    resnext101_64x4d = models.quantization.resnext101_64x4d()
     shufflenet_v2_x0_5 = models.quantization.shufflenet_v2_x0_5()
     shufflenet_v2_x1_0 = models.quantization.shufflenet_v2_x1_0()
+    shufflenet_v2_x1_5 = models.quantization.shufflenet_v2_x1_5()
+    shufflenet_v2_x2_0 = models.quantization.shufflenet_v2_x2_0()

 Obtaining a pre-trained quantized model can be done with a few lines of code:

@@ -508,6 +515,7 @@ ShuffleNet V2 x2.0 75.354 92.488
 ResNet 18 69.494 88.882
 ResNet 50 75.920 92.814
 ResNext 101 32x8d 78.986 94.480
+ResNext 101 64x4d 82.898 96.326
 Inception V3 77.176 93.354
 GoogleNet 69.826 89.404
 ================================ ============= =============
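
The documented hunk above sits next to the note that obtaining a pre-trained quantized model takes only a few lines of code. As a minimal sketch (assuming the pretrained/quantize keyword API described in this version of models.rst; newer releases use explicit weights enums instead), loading one of the newly listed quantized models could look like this; the input shape is illustrative:

    import torch
    import torchvision.models as models

    # Minimal sketch: build a quantized ResNeXt-101 64x4d with pre-trained weights.
    model = models.quantization.resnext101_64x4d(pretrained=True, quantize=True)
    model.eval()

    with torch.no_grad():
        logits = model(torch.rand(1, 3, 224, 224))  # dummy ImageNet-sized batch
    print(logits.shape)  # expected: torch.Size([1, 1000])
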
2 binary files changed (contents not shown)

test/test_models.py

Lines changed: 1 addition & 0 deletions
@@ -315,6 +315,7 @@ def _check_input_backprop(model, inputs):
     "convnext_base",
     "convnext_large",
     "resnext101_32x8d",
+    "resnext101_64x4d",
     "wide_resnet101_2",
     "efficientnet_b6",
     "efficientnet_b7",

test/test_onnx.py

Lines changed: 2 additions & 1 deletion
@@ -412,12 +412,13 @@ def forward(self_module, images, features):
     def get_image(self, rel_path: str, size: Tuple[int, int]) -> torch.Tensor:
         import os

+        import torchvision.transforms._pil_constants as _pil_constants
         from PIL import Image
         from torchvision.transforms import functional as F

         data_dir = os.path.join(os.path.dirname(__file__), "assets")
         path = os.path.join(data_dir, *rel_path.split("/"))
-        image = Image.open(path).convert("RGB").resize(size, Image.BILINEAR)
+        image = Image.open(path).convert("RGB").resize(size, _pil_constants.BILINEAR)

         return F.convert_image_dtype(F.pil_to_tensor(image))

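The switch from Image.BILINEAR to _pil_constants.BILINEAR here (and in the test files below) guards against Pillow moving its resampling filters into enums. A rough sketch of the kind of compatibility shim torchvision.transforms._pil_constants provides, assuming it resolves the constants once at import time (the version check shown is illustrative, not the actual module):

    import PIL
    from PIL import Image

    # Sketch only: expose one set of names regardless of Pillow version.
    if tuple(int(p) for p in PIL.__version__.split(".")[:2]) >= (9, 1):
        # Pillow >= 9.1 exposes the filters as enums and deprecates the module-level names.
        NEAREST = Image.Resampling.NEAREST
        BILINEAR = Image.Resampling.BILINEAR
        LINEAR = Image.Resampling.BILINEAR  # Image.LINEAR was an alias for BILINEAR
        AFFINE = Image.Transform.AFFINE
    else:
        NEAREST = Image.NEAREST
        BILINEAR = Image.BILINEAR
        LINEAR = Image.BILINEAR
        AFFINE = Image.AFFINE
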
test/test_prototype_transforms_functional.py

Lines changed: 191 additions & 5 deletions
@@ -11,6 +11,7 @@
 from torch.nn.functional import one_hot
 from torchvision.prototype import features
 from torchvision.prototype.transforms.functional._meta import convert_bounding_box_format
+from torchvision.transforms.functional import _get_perspective_coeffs
 from torchvision.transforms.functional_tensor import _max_value as get_max_value

 make_tensor = functools.partial(torch.testing.make_tensor, device="cpu")
@@ -380,6 +381,37 @@ def pad_segmentation_mask():
         yield SampleInput(mask, padding=padding, padding_mode=padding_mode)


+@register_kernel_info_from_sample_inputs_fn
+def perspective_bounding_box():
+    for bounding_box, perspective_coeffs in itertools.product(
+        make_bounding_boxes(),
+        [
+            [1.2405, 0.1772, -6.9113, 0.0463, 1.251, -5.235, 0.00013, 0.0018],
+            [0.7366, -0.11724, 1.45775, -0.15012, 0.73406, 2.6019, -0.0072, -0.0063],
+        ],
+    ):
+        yield SampleInput(
+            bounding_box,
+            format=bounding_box.format,
+            perspective_coeffs=perspective_coeffs,
+        )
+
+
+@register_kernel_info_from_sample_inputs_fn
+def perspective_segmentation_mask():
+    for mask, perspective_coeffs in itertools.product(
+        make_segmentation_masks(extra_dims=((), (4,))),
+        [
+            [1.2405, 0.1772, -6.9113, 0.0463, 1.251, -5.235, 0.00013, 0.0018],
+            [0.7366, -0.11724, 1.45775, -0.15012, 0.73406, 2.6019, -0.0072, -0.0063],
+        ],
+    ):
+        yield SampleInput(
+            mask,
+            perspective_coeffs=perspective_coeffs,
+        )
+
+
 @register_kernel_info_from_sample_inputs_fn
 def center_crop_bounding_box():
     for bounding_box, output_size in itertools.product(make_bounding_boxes(), [(24, 12), [16, 18], [46, 48], [12]]):
@@ -993,7 +1025,7 @@ def test_correctness_vertical_flip_segmentation_mask_on_fixed_input(device):
     ],
 )
 def test_correctness_resized_crop_bounding_box(device, format, top, left, height, width, size):
-    def _compute_expected(bbox, top_, left_, height_, width_, size_):
+    def _compute_expected_bbox(bbox, top_, left_, height_, width_, size_):
         # bbox should be xyxy
         bbox[0] = (bbox[0] - left_) * size_[1] / width_
         bbox[1] = (bbox[1] - top_) * size_[0] / height_
@@ -1009,7 +1041,7 @@ def _compute_expected(bbox, top_, left_, height_, width_, size_):
     ]
     expected_bboxes = []
     for in_box in in_boxes:
-        expected_bboxes.append(_compute_expected(list(in_box), top, left, height, width, size))
+        expected_bboxes.append(_compute_expected_bbox(list(in_box), top, left, height, width, size))
     expected_bboxes = torch.tensor(expected_bboxes, device=device)

     in_boxes = features.BoundingBox(
@@ -1035,7 +1067,7 @@ def _compute_expected(bbox, top_, left_, height_, width_, size_):
     ],
 )
 def test_correctness_resized_crop_segmentation_mask(device, top, left, height, width, size):
-    def _compute_expected(mask, top_, left_, height_, width_, size_):
+    def _compute_expected_mask(mask, top_, left_, height_, width_, size_):
         output = mask.clone()
         output = output[:, top_ : top_ + height_, left_ : left_ + width_]
         output = torch.nn.functional.interpolate(output[None, :].float(), size=size_, mode="nearest")
@@ -1046,7 +1078,7 @@ def _compute_expected(mask, top_, left_, height_, width_, size_):
     in_mask[0, 10:20, 10:20] = 1
     in_mask[0, 5:15, 12:23] = 2

-    expected_mask = _compute_expected(in_mask, top, left, height, width, size)
+    expected_mask = _compute_expected_mask(in_mask, top, left, height, width, size)
     output_mask = F.resized_crop_segmentation_mask(in_mask, top, left, height, width, size)
     torch.testing.assert_close(output_mask, expected_mask)

@@ -1095,6 +1127,161 @@ def parse_padding():
     torch.testing.assert_close(out_mask, expected_mask)


+@pytest.mark.parametrize("device", cpu_and_gpu())
+@pytest.mark.parametrize(
+    "startpoints, endpoints",
+    [
+        [[[0, 0], [33, 0], [33, 25], [0, 25]], [[3, 2], [32, 3], [30, 24], [2, 25]]],
+        [[[3, 2], [32, 3], [30, 24], [2, 25]], [[0, 0], [33, 0], [33, 25], [0, 25]]],
+        [[[3, 2], [32, 3], [30, 24], [2, 25]], [[5, 5], [30, 3], [33, 19], [4, 25]]],
+    ],
+)
+def test_correctness_perspective_bounding_box(device, startpoints, endpoints):
+    def _compute_expected_bbox(bbox, pcoeffs_):
+        m1 = np.array(
+            [
+                [pcoeffs_[0], pcoeffs_[1], pcoeffs_[2]],
+                [pcoeffs_[3], pcoeffs_[4], pcoeffs_[5]],
+            ]
+        )
+        m2 = np.array(
+            [
+                [pcoeffs_[6], pcoeffs_[7], 1.0],
+                [pcoeffs_[6], pcoeffs_[7], 1.0],
+            ]
+        )
+
+        bbox_xyxy = convert_bounding_box_format(
+            bbox, old_format=bbox.format, new_format=features.BoundingBoxFormat.XYXY
+        )
+        points = np.array(
+            [
+                [bbox_xyxy[0].item(), bbox_xyxy[1].item(), 1.0],
+                [bbox_xyxy[2].item(), bbox_xyxy[1].item(), 1.0],
+                [bbox_xyxy[0].item(), bbox_xyxy[3].item(), 1.0],
+                [bbox_xyxy[2].item(), bbox_xyxy[3].item(), 1.0],
+            ]
+        )
+        numer = np.matmul(points, m1.T)
+        denom = np.matmul(points, m2.T)
+        transformed_points = numer / denom
+        out_bbox = [
+            np.min(transformed_points[:, 0]),
+            np.min(transformed_points[:, 1]),
+            np.max(transformed_points[:, 0]),
+            np.max(transformed_points[:, 1]),
+        ]
+        out_bbox = features.BoundingBox(
+            out_bbox,
+            format=features.BoundingBoxFormat.XYXY,
+            image_size=bbox.image_size,
+            dtype=torch.float32,
+            device=bbox.device,
+        )
+        return convert_bounding_box_format(
+            out_bbox, old_format=features.BoundingBoxFormat.XYXY, new_format=bbox.format, copy=False
+        )
+
+    image_size = (32, 38)
+
+    pcoeffs = _get_perspective_coeffs(startpoints, endpoints)
+    inv_pcoeffs = _get_perspective_coeffs(endpoints, startpoints)
+
+    for bboxes in make_bounding_boxes(
+        image_sizes=[
+            image_size,
+        ],
+        extra_dims=((4,),),
+    ):
+        bboxes = bboxes.to(device)
+        bboxes_format = bboxes.format
+        bboxes_image_size = bboxes.image_size
+
+        output_bboxes = F.perspective_bounding_box(
+            bboxes,
+            bboxes_format,
+            perspective_coeffs=pcoeffs,
+        )
+
+        if bboxes.ndim < 2:
+            bboxes = [bboxes]
+
+        expected_bboxes = []
+        for bbox in bboxes:
+            bbox = features.BoundingBox(bbox, format=bboxes_format, image_size=bboxes_image_size)
+            expected_bboxes.append(_compute_expected_bbox(bbox, inv_pcoeffs))
+        if len(expected_bboxes) > 1:
+            expected_bboxes = torch.stack(expected_bboxes)
+        else:
+            expected_bboxes = expected_bboxes[0]
+        torch.testing.assert_close(output_bboxes, expected_bboxes, rtol=1e-5, atol=1e-5)


+@pytest.mark.parametrize("device", cpu_and_gpu())
+@pytest.mark.parametrize(
+    "startpoints, endpoints",
+    [
+        [[[0, 0], [33, 0], [33, 25], [0, 25]], [[3, 2], [32, 3], [30, 24], [2, 25]]],
+        [[[3, 2], [32, 3], [30, 24], [2, 25]], [[0, 0], [33, 0], [33, 25], [0, 25]]],
+        [[[3, 2], [32, 3], [30, 24], [2, 25]], [[5, 5], [30, 3], [33, 19], [4, 25]]],
+    ],
+)
+def test_correctness_perspective_segmentation_mask(device, startpoints, endpoints):
+    def _compute_expected_mask(mask, pcoeffs_):
+        assert mask.ndim == 3 and mask.shape[0] == 1
+        m1 = np.array(
+            [
+                [pcoeffs_[0], pcoeffs_[1], pcoeffs_[2]],
+                [pcoeffs_[3], pcoeffs_[4], pcoeffs_[5]],
+            ]
+        )
+        m2 = np.array(
+            [
+                [pcoeffs_[6], pcoeffs_[7], 1.0],
+                [pcoeffs_[6], pcoeffs_[7], 1.0],
+            ]
+        )
+
+        expected_mask = torch.zeros_like(mask.cpu())
+        for out_y in range(expected_mask.shape[1]):
+            for out_x in range(expected_mask.shape[2]):
+                output_pt = np.array([out_x + 0.5, out_y + 0.5, 1.0])
+
+                numer = np.matmul(output_pt, m1.T)
+                denom = np.matmul(output_pt, m2.T)
+                input_pt = np.floor(numer / denom).astype(np.int32)
+
+                in_x, in_y = input_pt[:2]
+                if 0 <= in_x < mask.shape[2] and 0 <= in_y < mask.shape[1]:
+                    expected_mask[0, out_y, out_x] = mask[0, in_y, in_x]
+        return expected_mask.to(mask.device)
+
+    pcoeffs = _get_perspective_coeffs(startpoints, endpoints)
+
+    for mask in make_segmentation_masks(extra_dims=((), (4,))):
+        mask = mask.to(device)
+
+        output_mask = F.perspective_segmentation_mask(
+            mask,
+            perspective_coeffs=pcoeffs,
+        )
+
+        if mask.ndim < 4:
+            masks = [mask]
+        else:
+            masks = [m for m in mask]
+
+        expected_masks = []
+        for mask in masks:
+            expected_mask = _compute_expected_mask(mask, pcoeffs)
+            expected_masks.append(expected_mask)
+        if len(expected_masks) > 1:
+            expected_masks = torch.stack(expected_masks)
+        else:
+            expected_masks = expected_masks[0]
+        torch.testing.assert_close(output_mask, expected_masks)


 @pytest.mark.parametrize("device", cpu_and_gpu())
 @pytest.mark.parametrize(
     "output_size",
@@ -1148,5 +1335,4 @@ def _compute_expected_bbox(bbox, output_size_):
             expected_bboxes = torch.stack(expected_bboxes)
         else:
             expected_bboxes = expected_bboxes[0]
-        expected_bboxes = expected_bboxes.to(device=device)
         torch.testing.assert_close(output_boxes, expected_bboxes)

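For orientation: both new correctness tests build their expected outputs from the standard eight-parameter projective map, which the m1/m2 matrices above encode in homogeneous form. The bounding-box test applies the coefficients computed from (endpoints, startpoints) because the kernel's coefficients describe the output-to-input mapping, while the mask test walks output pixel centres through the same map. A small illustrative helper (the name apply_perspective is not part of the test suite) that applies one of the coefficient sets from the sample inputs to a single point:

    import numpy as np

    def apply_perspective(x, y, coeffs):
        # coeffs = [a, b, c, d, e, f, g, h], the layout returned by _get_perspective_coeffs:
        # (x, y) -> ((a*x + b*y + c) / (g*x + h*y + 1), (d*x + e*y + f) / (g*x + h*y + 1))
        a, b, c, d, e, f, g, h = coeffs
        denom = g * x + h * y + 1.0
        return np.array([(a * x + b * y + c) / denom, (d * x + e * y + f) / denom])

    coeffs = [1.2405, 0.1772, -6.9113, 0.0463, 1.251, -5.235, 0.00013, 0.0018]
    print(apply_perspective(10.0, 20.0, coeffs))
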
test/test_transforms.py

Lines changed: 2 additions & 1 deletion
@@ -8,6 +8,7 @@
 import pytest
 import torch
 import torchvision.transforms as transforms
+import torchvision.transforms._pil_constants as _pil_constants
 import torchvision.transforms.functional as F
 import torchvision.transforms.functional_tensor as F_t
 from PIL import Image
@@ -173,7 +174,7 @@ def test_accimage_pil_to_tensor(self):
     def test_accimage_resize(self):
         trans = transforms.Compose(
             [
-                transforms.Resize(256, interpolation=Image.LINEAR),
+                transforms.Resize(256, interpolation=_pil_constants.LINEAR),
                 transforms.PILToTensor(),
                 transforms.ConvertImageDtype(dtype=torch.float),
             ]

test/test_transforms_tensor.py

Lines changed: 4 additions & 4 deletions
@@ -4,6 +4,7 @@
 import numpy as np
 import pytest
 import torch
+import torchvision.transforms._pil_constants as _pil_constants
 from common_utils import (
     get_tmp_dir,
     int_dtypes,
@@ -15,7 +16,6 @@
     cpu_and_gpu,
     assert_equal,
 )
-from PIL import Image
 from torchvision import transforms as T
 from torchvision.transforms import InterpolationMode
 from torchvision.transforms import functional as F
@@ -771,13 +771,13 @@ def shear(pil_img, level, mode, resample):
             matrix = (1, level, 0, 0, 1, 0)
         elif mode == "Y":
             matrix = (1, 0, 0, level, 1, 0)
-        return pil_img.transform((image_size, image_size), Image.AFFINE, matrix, resample=resample)
+        return pil_img.transform((image_size, image_size), _pil_constants.AFFINE, matrix, resample=resample)

     t_img, pil_img = _create_data(image_size, image_size)

     resample_pil = {
-        F.InterpolationMode.NEAREST: Image.NEAREST,
-        F.InterpolationMode.BILINEAR: Image.BILINEAR,
+        F.InterpolationMode.NEAREST: _pil_constants.NEAREST,
+        F.InterpolationMode.BILINEAR: _pil_constants.BILINEAR,
     }[interpolation]

     level = 0.3

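The shear helper in the last hunk relies on PIL's affine transform semantics: the 6-tuple (a, b, c, d, e, f) tells PIL where in the input to sample each output pixel, so (1, level, 0, 0, 1, 0) shears along x. A small illustrative snippet, reusing the same _pil_constants names the tests import (the image size and shear level are arbitrary):

    from PIL import Image
    import torchvision.transforms._pil_constants as _pil_constants

    # Illustrative only: for each output pixel (x, y), PIL samples the input at
    # (a*x + b*y + c, d*x + e*y + f). With b = 0.3 this is an x-shear, matching
    # the matrix built by the `shear` helper above for mode == "X".
    img = Image.new("RGB", (32, 32))
    sheared = img.transform(
        (32, 32),
        _pil_constants.AFFINE,
        (1, 0.3, 0, 0, 1, 0),
        resample=_pil_constants.BILINEAR,
    )
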