Skip to content

improve UX for v2 Compose #7758

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jul 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 63 additions & 0 deletions test/test_transforms_v2_refactored.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
make_video,
set_rng_seed,
)

from torch import nn
from torch.testing import assert_close
from torchvision import datapoints

Expand Down Expand Up @@ -1634,3 +1636,64 @@ def test_transform_negative_degrees_error(self):
def test_transform_unknown_fill_error(self):
with pytest.raises(TypeError, match="Got inappropriate fill arg"):
transforms.RandomAffine(degrees=0, fill="fill")


class TestCompose:
class BuiltinTransform(transforms.Transform):
def _transform(self, inpt, params):
return inpt

class PackedInputTransform(nn.Module):
def forward(self, sample):
assert len(sample) == 2
return sample

class UnpackedInputTransform(nn.Module):
def forward(self, image, label):
return image, label

@pytest.mark.parametrize(
"transform_clss",
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not call this transform_class ?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would have to be transform_classes, since it is a list of classes. And since we use cls for singular, I'm usually just append an s to it. I'll leave it up to you.

[
[BuiltinTransform],
[PackedInputTransform],
[UnpackedInputTransform],
[BuiltinTransform, BuiltinTransform],
[PackedInputTransform, PackedInputTransform],
[UnpackedInputTransform, UnpackedInputTransform],
[BuiltinTransform, PackedInputTransform, BuiltinTransform],
[BuiltinTransform, UnpackedInputTransform, BuiltinTransform],
[PackedInputTransform, BuiltinTransform, PackedInputTransform],
[UnpackedInputTransform, BuiltinTransform, UnpackedInputTransform],
],
)
@pytest.mark.parametrize("unpack", [True, False])
def test_packed_unpacked(self, transform_clss, unpack):
needs_packed_inputs = any(issubclass(cls, self.PackedInputTransform) for cls in transform_clss)
needs_unpacked_inputs = any(issubclass(cls, self.UnpackedInputTransform) for cls in transform_clss)
assert not (needs_packed_inputs and needs_unpacked_inputs)

transform = transforms.Compose([cls() for cls in transform_clss])

image = make_image()
label = 3
packed_input = (image, label)

def call_transform():
if unpack:
return transform(*packed_input)
else:
return transform(packed_input)

if needs_unpacked_inputs and not unpack:
with pytest.raises(TypeError, match="missing 1 required positional argument"):
call_transform()
elif needs_packed_inputs and unpack:
with pytest.raises(TypeError, match="takes 2 positional arguments but 3 were given"):
call_transform()
else:
output = call_transform()

assert isinstance(output, tuple) and len(output) == 2
assert output[0] is image
assert output[1] is label
9 changes: 6 additions & 3 deletions torchvision/transforms/v2/_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,16 @@ def __init__(self, transforms: Sequence[Callable]) -> None:
super().__init__()
if not isinstance(transforms, Sequence):
raise TypeError("Argument transforms should be a sequence of callables")
elif not transforms:
raise ValueError("Pass at least one transform")
self.transforms = transforms

def forward(self, *inputs: Any) -> Any:
sample = inputs if len(inputs) > 1 else inputs[0]
needs_unpacking = len(inputs) > 1
for transform in self.transforms:
sample = transform(sample)
return sample
outputs = transform(*inputs)
inputs = outputs if needs_unpacking else (outputs,)
return outputs

def extra_repr(self) -> str:
format_string = []
Expand Down