Skip to content

Commit 049e7e2

Browse files
NicolasHug authored and facebook-github-bot committed
[fbsync] [PoC] compatibility layer between stable datasets and prototype transforms (#6663)
Reviewed By: vmoens Differential Revision: D44416279 fbshipit-source-id: a3c1ba2048917c5af3005beef6cec77896ab20f8
1 parent 6a4283b commit 049e7e2

File tree

4 files changed

+446
-5
lines changed

4 files changed

+446
-5
lines changed

test/datasets_utils.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import torchvision.datasets
2626
import torchvision.io
2727
from common_utils import disable_console_output, get_tmp_dir
28+
from torch.utils._pytree import tree_any
2829
from torchvision.transforms.functional import get_dimensions
2930

3031

@@ -581,6 +582,28 @@ def test_transforms(self, config):
581582

582583
mock.assert_called()
583584

585+
@test_all_configs
586+
def test_transforms_v2_wrapper(self, config):
587+
# Although this is a stable test, we unconditionally import from `torchvision.prototype` here. The wrapper needs
588+
# to be available with the next release when v2 is released. Thus, if this import somehow fails on the release
589+
# branch, we screwed up the roll-out
590+
from torchvision.prototype.datapoints import wrap_dataset_for_transforms_v2
591+
from torchvision.prototype.datapoints._datapoint import Datapoint
592+
593+
try:
594+
with self.create_dataset(config) as (dataset, _):
595+
wrapped_dataset = wrap_dataset_for_transforms_v2(dataset)
596+
wrapped_sample = wrapped_dataset[0]
597+
assert tree_any(lambda item: isinstance(item, (Datapoint, PIL.Image.Image)), wrapped_sample)
598+
except TypeError as error:
599+
if str(error).startswith(f"No wrapper exist for dataset class {type(dataset).__name__}"):
600+
return
601+
raise error
602+
except RuntimeError as error:
603+
if "currently not supported by this wrapper" in str(error):
604+
return
605+
raise error
606+
584607

585608
class ImageDatasetTestCase(DatasetTestCase):
586609
"""Abstract base class for image dataset testcases.
@@ -662,6 +685,15 @@ def wrapper(tmpdir, config):
662685

663686
return wrapper
664687

688+
@test_all_configs
689+
def test_transforms_v2_wrapper(self, config):
690+
# `output_format == "THWC"` is not supported by the wrapper. Thus, we skip the `config` if it is set explicitly
691+
# or use the supported `"TCHW"`
692+
if config.setdefault("output_format", "TCHW") == "THWC":
693+
return
694+
695+
super().test_transforms_v2_wrapper.__wrapped__(self, config)
696+
665697

666698
def create_image_or_video_tensor(size: Sequence[int]) -> torch.Tensor:
667699
r"""Create a random uint8 tensor.

test/test_datasets.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -763,11 +763,19 @@ def _create_annotation_file(self, root, name, file_names, num_annotations_per_im
763763
return info
764764

765765
def _create_annotations(self, image_ids, num_annotations_per_image):
766-
annotations = datasets_utils.combinations_grid(
767-
image_id=image_ids, bbox=([1.0, 2.0, 3.0, 4.0],) * num_annotations_per_image
768-
)
769-
for id, annotation in enumerate(annotations):
770-
annotation["id"] = id
766+
annotations = []
767+
annotion_id = 0
768+
for image_id in itertools.islice(itertools.cycle(image_ids), len(image_ids) * num_annotations_per_image):
769+
annotations.append(
770+
dict(
771+
image_id=image_id,
772+
id=annotion_id,
773+
bbox=torch.rand(4).tolist(),
774+
segmentation=[torch.rand(8).tolist()],
775+
category_id=int(torch.randint(91, ())),
776+
)
777+
)
778+
annotion_id += 1
771779
return annotations, dict()
772780

773781
def _create_json(self, root, name, content):

torchvision/prototype/datapoints/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,5 @@
44
from ._label import Label, OneHotLabel
55
from ._mask import Mask
66
from ._video import TensorVideoType, TensorVideoTypeJIT, Video, VideoType, VideoTypeJIT
7+
8+
from ._dataset_wrapper import wrap_dataset_for_transforms_v2 # type: ignore[attr-defined] # usort: skip

0 commit comments

Comments (0)