From 99c670074943b2833f7552b6592ce629c07aef42 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Tue, 25 Jul 2023 11:53:40 +0200 Subject: [PATCH 1/5] improve UX for v2 Compose --- test/test_transforms_v2_refactored.py | 51 +++++++++++++++++++++++++ torchvision/transforms/v2/_container.py | 11 ++++-- 2 files changed, 59 insertions(+), 3 deletions(-) diff --git a/test/test_transforms_v2_refactored.py b/test/test_transforms_v2_refactored.py index 69180b99dbc..fc007bb6d51 100644 --- a/test/test_transforms_v2_refactored.py +++ b/test/test_transforms_v2_refactored.py @@ -26,6 +26,8 @@ make_video, set_rng_seed, ) + +from torch import nn from torch.testing import assert_close from torchvision import datapoints @@ -1634,3 +1636,52 @@ def test_transform_negative_degrees_error(self): def test_transform_unknown_fill_error(self): with pytest.raises(TypeError, match="Got inappropriate fill arg"): transforms.RandomAffine(degrees=0, fill="fill") + + +class TestCompose: + class BuiltinTransform(transforms.Transform): + def _transform(self, inpt, params): + return inpt + + class PackedInputTransform(nn.Module): + def forward(self, sample): + image, label = sample + return image, label + + class UnpackedInputTransform(nn.Module): + def forward(self, image, label): + return image, label + + @pytest.mark.parametrize( + "transform_clss", + [ + [BuiltinTransform], + [PackedInputTransform], + [UnpackedInputTransform], + [BuiltinTransform, BuiltinTransform], + [PackedInputTransform, PackedInputTransform], + [UnpackedInputTransform, UnpackedInputTransform], + [BuiltinTransform, PackedInputTransform, BuiltinTransform], + [BuiltinTransform, UnpackedInputTransform, BuiltinTransform], + [PackedInputTransform, BuiltinTransform, PackedInputTransform], + [UnpackedInputTransform, BuiltinTransform, UnpackedInputTransform], + ], + ) + @pytest.mark.parametrize("unpack", [True, False]) + def test_packed_unpacked(self, transform_clss, unpack): + if (unpack and any(issubclass(cls, self.PackedInputTransform) for cls in transform_clss)) or ( + not unpack and any(issubclass(cls, self.UnpackedInputTransform) for cls in transform_clss) + ): + return + + transform = transforms.Compose([cls() for cls in transform_clss]) + + image = make_image() + label = torch.tensor(3) + packed_input = (image, label) + + output = transform(*packed_input if unpack else (packed_input,)) + + assert isinstance(output, tuple) and len(output) == 2 + assert output[0] is image + assert output[1] is label diff --git a/torchvision/transforms/v2/_container.py b/torchvision/transforms/v2/_container.py index fffef4157bd..39940e44b34 100644 --- a/torchvision/transforms/v2/_container.py +++ b/torchvision/transforms/v2/_container.py @@ -46,10 +46,15 @@ def __init__(self, transforms: Sequence[Callable]) -> None: self.transforms = transforms def forward(self, *inputs: Any) -> Any: - sample = inputs if len(inputs) > 1 else inputs[0] + needs_unpacking = len(inputs) > 1 + + if not self.transforms: + return inputs if needs_unpacking else inputs[0] + for transform in self.transforms: - sample = transform(sample) - return sample + outputs = transform(*inputs) + inputs = outputs if needs_unpacking else (outputs,) + return outputs def extra_repr(self) -> str: format_string = [] From 835f40cbb0a957975400c98a984300fa3c6bd8c3 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Tue, 25 Jul 2023 13:44:04 +0200 Subject: [PATCH 2/5] simplify guard --- test/test_transforms_v2_refactored.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/test/test_transforms_v2_refactored.py b/test/test_transforms_v2_refactored.py index fc007bb6d51..109d1a409ba 100644 --- a/test/test_transforms_v2_refactored.py +++ b/test/test_transforms_v2_refactored.py @@ -1669,8 +1669,12 @@ def forward(self, image, label): ) @pytest.mark.parametrize("unpack", [True, False]) def test_packed_unpacked(self, transform_clss, unpack): - if (unpack and any(issubclass(cls, self.PackedInputTransform) for cls in transform_clss)) or ( - not unpack and any(issubclass(cls, self.UnpackedInputTransform) for cls in transform_clss) + if any( + unpack + and issubclass(cls, self.PackedInputTransform) + or not unpack + and issubclass(cls, self.UnpackedInputTransform) + for cls in transform_clss ): return From 4bcb4888206d7ea24d02352fe75321566c41d7bc Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Tue, 25 Jul 2023 15:57:33 +0200 Subject: [PATCH 3/5] improve test --- test/test_transforms_v2_refactored.py | 34 +++++++++++++++++---------- 1 file changed, 21 insertions(+), 13 deletions(-) diff --git a/test/test_transforms_v2_refactored.py b/test/test_transforms_v2_refactored.py index 109d1a409ba..06d8dc98cf8 100644 --- a/test/test_transforms_v2_refactored.py +++ b/test/test_transforms_v2_refactored.py @@ -1669,23 +1669,31 @@ def forward(self, image, label): ) @pytest.mark.parametrize("unpack", [True, False]) def test_packed_unpacked(self, transform_clss, unpack): - if any( - unpack - and issubclass(cls, self.PackedInputTransform) - or not unpack - and issubclass(cls, self.UnpackedInputTransform) - for cls in transform_clss - ): - return + needs_packed_inputs = any(issubclass(cls, self.PackedInputTransform) for cls in transform_clss) + needs_unpacked_inputs = any(issubclass(cls, self.UnpackedInputTransform) for cls in transform_clss) + assert not (needs_packed_inputs and needs_unpacked_inputs) transform = transforms.Compose([cls() for cls in transform_clss]) image = make_image() - label = torch.tensor(3) + label = 3 packed_input = (image, label) - output = transform(*packed_input if unpack else (packed_input,)) + def call_transform(): + if unpack: + return transform(*packed_input) + else: + return transform(packed_input) + + if needs_unpacked_inputs and not unpack: + with pytest.raises(TypeError, match="missing 1 required positional argument"): + call_transform() + elif needs_packed_inputs and unpack: + with pytest.raises(TypeError, match="takes 2 positional arguments but 3 were given"): + call_transform() + else: + output = call_transform() - assert isinstance(output, tuple) and len(output) == 2 - assert output[0] is image - assert output[1] is label + assert isinstance(output, tuple) and len(output) == 2 + assert output[0] is image + assert output[1] is label From b99647b9d21cc2574d671e826f9c7b235e052caf Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Tue, 25 Jul 2023 16:00:46 +0200 Subject: [PATCH 4/5] enforce at least one transform --- torchvision/transforms/v2/_container.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/torchvision/transforms/v2/_container.py b/torchvision/transforms/v2/_container.py index 39940e44b34..8f591c49707 100644 --- a/torchvision/transforms/v2/_container.py +++ b/torchvision/transforms/v2/_container.py @@ -43,14 +43,12 @@ def __init__(self, transforms: Sequence[Callable]) -> None: super().__init__() if not isinstance(transforms, Sequence): raise TypeError("Argument transforms should be a sequence of callables") + elif not transforms: + raise ValueError("Pass at least one transform") self.transforms = transforms def forward(self, *inputs: Any) -> Any: needs_unpacking = len(inputs) > 1 - - if not self.transforms: - return inputs if needs_unpacking else inputs[0] - for transform in self.transforms: outputs = transform(*inputs) inputs = outputs if needs_unpacking else (outputs,) From f8c326e6566eb377c55b5c47196e7daf78e05dcf Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Tue, 25 Jul 2023 16:21:33 +0200 Subject: [PATCH 5/5] Update test/test_transforms_v2_refactored.py Co-authored-by: Nicolas Hug --- test/test_transforms_v2_refactored.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_transforms_v2_refactored.py b/test/test_transforms_v2_refactored.py index 06d8dc98cf8..64a79262f3e 100644 --- a/test/test_transforms_v2_refactored.py +++ b/test/test_transforms_v2_refactored.py @@ -1645,8 +1645,8 @@ def _transform(self, inpt, params): class PackedInputTransform(nn.Module): def forward(self, sample): - image, label = sample - return image, label + assert len(sample) == 2 + return sample class UnpackedInputTransform(nn.Module): def forward(self, image, label):