|
25 | 25 | import torchvision.datasets
|
26 | 26 | import torchvision.io
|
27 | 27 | from common_utils import disable_console_output, get_tmp_dir
|
| 28 | +from torch.utils._pytree import tree_any |
28 | 29 | from torchvision.transforms.functional import get_dimensions
|
29 | 30 |
|
30 | 31 |
|
@@ -581,6 +582,28 @@ def test_transforms(self, config):
|
581 | 582 |
|
582 | 583 | mock.assert_called()
|
583 | 584 |
|
| 585 | + @test_all_configs |
| 586 | + def test_transforms_v2_wrapper(self, config): |
| 587 | + # Although this is a stable test, we unconditionally import from `torchvision.prototype` here. The wrapper needs |
| 588 | + # to be available with the next release when v2 is released. Thus, if this import somehow fails on the release |
| 589 | + # branch, we screwed up the roll-out |
| 590 | + from torchvision.prototype.datapoints import wrap_dataset_for_transforms_v2 |
| 591 | + from torchvision.prototype.datapoints._datapoint import Datapoint |
| 592 | + |
| 593 | + try: |
| 594 | + with self.create_dataset(config) as (dataset, _): |
| 595 | + wrapped_dataset = wrap_dataset_for_transforms_v2(dataset) |
| 596 | + wrapped_sample = wrapped_dataset[0] |
| 597 | + assert tree_any(lambda item: isinstance(item, (Datapoint, PIL.Image.Image)), wrapped_sample) |
| 598 | + except TypeError as error: |
| 599 | + if str(error).startswith(f"No wrapper exist for dataset class {type(dataset).__name__}"): |
| 600 | + return |
| 601 | + raise error |
| 602 | + except RuntimeError as error: |
| 603 | + if "currently not supported by this wrapper" in str(error): |
| 604 | + return |
| 605 | + raise error |
| 606 | + |
584 | 607 |
|
585 | 608 | class ImageDatasetTestCase(DatasetTestCase):
|
586 | 609 | """Abstract base class for image dataset testcases.
|
@@ -662,6 +685,15 @@ def wrapper(tmpdir, config):
|
662 | 685 |
|
663 | 686 | return wrapper
|
664 | 687 |
|
| 688 | + @test_all_configs |
| 689 | + def test_transforms_v2_wrapper(self, config): |
| 690 | + # `output_format == "THWC"` is not supported by the wrapper. Thus, we skip the `config` if it is set explicitly |
| 691 | + # or use the supported `"TCHW"` |
| 692 | + if config.setdefault("output_format", "TCHW") == "THWC": |
| 693 | + return |
| 694 | + |
| 695 | + super().test_transforms_v2_wrapper.__wrapped__(self, config) |
| 696 | + |
665 | 697 |
|
666 | 698 | def create_image_or_video_tensor(size: Sequence[int]) -> torch.Tensor:
|
667 | 699 | r"""Create a random uint8 tensor.
|
|
0 commit comments