Skip to content

Commit 99c3594

Browse files
committed
Merge branch 'main' of github.com:pytorch/vision into proto-perf-feature-improvements
2 parents 8441dbc + e1aacdd commit 99c3594

File tree

7 files changed

+31
-21
lines changed

7 files changed

+31
-21
lines changed

test/test_prototype_transforms.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1379,10 +1379,9 @@ def test__transform(self, mocker):
13791379

13801380

13811381
class TestRandomShortestSize:
1382-
def test__get_params(self, mocker):
1382+
@pytest.mark.parametrize("min_size,max_size", [([5, 9], 20), ([5, 9], None)])
1383+
def test__get_params(self, min_size, max_size, mocker):
13831384
spatial_size = (3, 10)
1384-
min_size = [5, 9]
1385-
max_size = 20
13861385

13871386
transform = transforms.RandomShortestSize(min_size=min_size, max_size=max_size)
13881387

@@ -1395,10 +1394,9 @@ def test__get_params(self, mocker):
13951394
assert isinstance(size, tuple) and len(size) == 2
13961395

13971396
longer = max(size)
1398-
assert longer <= max_size
1399-
14001397
shorter = min(size)
1401-
if longer == max_size:
1398+
if max_size is not None:
1399+
assert longer <= max_size
14021400
assert shorter <= max_size
14031401
else:
14041402
assert shorter in min_size

test/test_prototype_transforms_consistency.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import torch
1313
from prototype_common_utils import (
1414
ArgsKwargs,
15+
assert_close,
1516
assert_equal,
1617
make_bounding_box,
1718
make_detection_mask,
@@ -40,13 +41,15 @@ def __init__(
4041
make_images_kwargs=None,
4142
supports_pil=True,
4243
removed_params=(),
44+
closeness_kwargs=None,
4345
):
4446
self.prototype_cls = prototype_cls
4547
self.legacy_cls = legacy_cls
4648
self.args_kwargs = args_kwargs
4749
self.make_images_kwargs = make_images_kwargs or DEFAULT_MAKE_IMAGES_KWARGS
4850
self.supports_pil = supports_pil
4951
self.removed_params = removed_params
52+
self.closeness_kwargs = closeness_kwargs or dict(rtol=0, atol=0)
5053

5154

5255
# These are here since both the prototype and legacy transform need to be constructed with the same random parameters
@@ -491,10 +494,14 @@ def test_signature_consistency(config):
491494
assert prototype_kinds == legacy_kinds
492495

493496

494-
def check_call_consistency(prototype_transform, legacy_transform, images=None, supports_pil=True):
497+
def check_call_consistency(
498+
prototype_transform, legacy_transform, images=None, supports_pil=True, closeness_kwargs=None
499+
):
495500
if images is None:
496501
images = make_images(**DEFAULT_MAKE_IMAGES_KWARGS)
497502

503+
closeness_kwargs = closeness_kwargs or dict()
504+
498505
for image in images:
499506
image_repr = f"[{tuple(image.shape)}, {str(image.dtype).rsplit('.')[-1]}]"
500507

@@ -520,10 +527,11 @@ def check_call_consistency(prototype_transform, legacy_transform, images=None, s
520527
f"`is_simple_tensor` path in `_transform`."
521528
) from exc
522529

523-
assert_equal(
530+
assert_close(
524531
output_prototype_tensor,
525532
output_legacy_tensor,
526533
msg=lambda msg: f"Tensor image consistency check failed with: \n\n{msg}",
534+
**closeness_kwargs,
527535
)
528536

529537
try:
@@ -536,10 +544,11 @@ def check_call_consistency(prototype_transform, legacy_transform, images=None, s
536544
f"`features.Image` path in `_transform`."
537545
) from exc
538546

539-
assert_equal(
547+
assert_close(
540548
output_prototype_image,
541549
output_prototype_tensor,
542550
msg=lambda msg: f"Output for feature and tensor images is not equal: \n\n{msg}",
551+
**closeness_kwargs,
543552
)
544553

545554
if image.ndim == 3 and supports_pil:
@@ -565,10 +574,11 @@ def check_call_consistency(prototype_transform, legacy_transform, images=None, s
565574
f"`PIL.Image.Image` path in `_transform`."
566575
) from exc
567576

568-
assert_equal(
577+
assert_close(
569578
output_prototype_pil,
570579
output_legacy_pil,
571580
msg=lambda msg: f"PIL image consistency check failed with: \n\n{msg}",
581+
**closeness_kwargs,
572582
)
573583

574584

@@ -606,6 +616,7 @@ def test_call_consistency(config, args_kwargs):
606616
legacy_transform,
607617
images=make_images(**config.make_images_kwargs),
608618
supports_pil=config.supports_pil,
619+
closeness_kwargs=config.closeness_kwargs,
609620
)
610621

611622

torchvision/prototype/transforms/_geometry.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -730,7 +730,7 @@ class RandomShortestSize(Transform):
730730
def __init__(
731731
self,
732732
min_size: Union[List[int], Tuple[int], int],
733-
max_size: int,
733+
max_size: Optional[int] = None,
734734
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
735735
antialias: Optional[bool] = None,
736736
):
@@ -744,7 +744,9 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
744744
orig_height, orig_width = query_spatial_size(flat_inputs)
745745

746746
min_size = self.min_size[int(torch.randint(len(self.min_size), ()))]
747-
r = min(min_size / min(orig_height, orig_width), self.max_size / max(orig_height, orig_width))
747+
r = min_size / min(orig_height, orig_width)
748+
if self.max_size is not None:
749+
r = min(r, self.max_size / max(orig_height, orig_width))
748750

749751
new_width = int(orig_width * r)
750752
new_height = int(orig_height * r)

torchvision/prototype/transforms/_misc.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,10 @@ def __init__(self, dtype: Union[torch.dtype, Dict[Type, torch.dtype]]) -> None:
157157
self.dtype = dtype
158158

159159
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
160-
return inpt.to(self.dtype[type(inpt)])
160+
dtype = self.dtype[type(inpt)]
161+
if dtype is None:
162+
return inpt
163+
return inpt.to(dtype=dtype)
161164

162165

163166
class RemoveSmallBoundingBoxes(Transform):

torchvision/prototype/transforms/_transform.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,19 +5,15 @@
55
import torch
66
from torch import nn
77
from torch.utils._pytree import tree_flatten, tree_unflatten
8-
from torchvision.prototype import features
98
from torchvision.prototype.transforms._utils import _isinstance
109
from torchvision.utils import _log_api_usage_once
1110

1211

1312
class Transform(nn.Module):
1413

1514
# Class attribute defining transformed types. Other types are passed-through without any transformation
16-
_transformed_types: Tuple[Union[Type, Callable[[Any], bool]], ...] = (
17-
features.is_simple_tensor,
18-
features._Feature,
19-
PIL.Image.Image,
20-
)
15+
# We support both Types and callables that are able to do further checks on the type of the input.
16+
_transformed_types: Tuple[Union[Type, Callable[[Any], bool]], ...] = (torch.Tensor, PIL.Image.Image)
2117

2218
def __init__(self) -> None:
2319
super().__init__()

torchvision/prototype/transforms/functional/_color.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,7 @@ def equalize_image_tensor(image: torch.Tensor) -> torch.Tensor:
227227
if image.numel() == 0:
228228
return image
229229

230-
return _equalize_image_tensor_vec(image.view(-1, height, width)).view(image.shape)
230+
return _equalize_image_tensor_vec(image.view(-1, height, width)).reshape(image.shape)
231231

232232

233233
equalize_image_pil = _FP.equalize

torchvision/transforms/functional_tensor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -875,7 +875,7 @@ def _scale_channel(img_chan: Tensor) -> Tensor:
875875
if img_chan.is_cuda:
876876
hist = torch.histc(img_chan.to(torch.float32), bins=256, min=0, max=255)
877877
else:
878-
hist = torch.bincount(img_chan.view(-1), minlength=256)
878+
hist = torch.bincount(img_chan.reshape(-1), minlength=256)
879879

880880
nonzero_hist = hist[hist != 0]
881881
step = torch.div(nonzero_hist[:-1].sum(), 255, rounding_mode="floor")

0 commit comments

Comments
 (0)