Skip to content

Commit cc0f9d0

Browse files
pmeierNicolasHug
andauthored
improve UX for v2 Compose (#7758)
Co-authored-by: Nicolas Hug <[email protected]>
1 parent a6dea86 commit cc0f9d0

File tree

2 files changed

+69
-3
lines changed

2 files changed

+69
-3
lines changed

test/test_transforms_v2_refactored.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@
2626
make_video,
2727
set_rng_seed,
2828
)
29+
30+
from torch import nn
2931
from torch.testing import assert_close
3032
from torchvision import datapoints
3133

@@ -1634,3 +1636,64 @@ def test_transform_negative_degrees_error(self):
16341636
def test_transform_unknown_fill_error(self):
16351637
with pytest.raises(TypeError, match="Got inappropriate fill arg"):
16361638
transforms.RandomAffine(degrees=0, fill="fill")
1639+
1640+
1641+
class TestCompose:
1642+
class BuiltinTransform(transforms.Transform):
1643+
def _transform(self, inpt, params):
1644+
return inpt
1645+
1646+
class PackedInputTransform(nn.Module):
1647+
def forward(self, sample):
1648+
assert len(sample) == 2
1649+
return sample
1650+
1651+
class UnpackedInputTransform(nn.Module):
1652+
def forward(self, image, label):
1653+
return image, label
1654+
1655+
@pytest.mark.parametrize(
1656+
"transform_clss",
1657+
[
1658+
[BuiltinTransform],
1659+
[PackedInputTransform],
1660+
[UnpackedInputTransform],
1661+
[BuiltinTransform, BuiltinTransform],
1662+
[PackedInputTransform, PackedInputTransform],
1663+
[UnpackedInputTransform, UnpackedInputTransform],
1664+
[BuiltinTransform, PackedInputTransform, BuiltinTransform],
1665+
[BuiltinTransform, UnpackedInputTransform, BuiltinTransform],
1666+
[PackedInputTransform, BuiltinTransform, PackedInputTransform],
1667+
[UnpackedInputTransform, BuiltinTransform, UnpackedInputTransform],
1668+
],
1669+
)
1670+
@pytest.mark.parametrize("unpack", [True, False])
1671+
def test_packed_unpacked(self, transform_clss, unpack):
1672+
needs_packed_inputs = any(issubclass(cls, self.PackedInputTransform) for cls in transform_clss)
1673+
needs_unpacked_inputs = any(issubclass(cls, self.UnpackedInputTransform) for cls in transform_clss)
1674+
assert not (needs_packed_inputs and needs_unpacked_inputs)
1675+
1676+
transform = transforms.Compose([cls() for cls in transform_clss])
1677+
1678+
image = make_image()
1679+
label = 3
1680+
packed_input = (image, label)
1681+
1682+
def call_transform():
1683+
if unpack:
1684+
return transform(*packed_input)
1685+
else:
1686+
return transform(packed_input)
1687+
1688+
if needs_unpacked_inputs and not unpack:
1689+
with pytest.raises(TypeError, match="missing 1 required positional argument"):
1690+
call_transform()
1691+
elif needs_packed_inputs and unpack:
1692+
with pytest.raises(TypeError, match="takes 2 positional arguments but 3 were given"):
1693+
call_transform()
1694+
else:
1695+
output = call_transform()
1696+
1697+
assert isinstance(output, tuple) and len(output) == 2
1698+
assert output[0] is image
1699+
assert output[1] is label

torchvision/transforms/v2/_container.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,13 +43,16 @@ def __init__(self, transforms: Sequence[Callable]) -> None:
4343
super().__init__()
4444
if not isinstance(transforms, Sequence):
4545
raise TypeError("Argument transforms should be a sequence of callables")
46+
elif not transforms:
47+
raise ValueError("Pass at least one transform")
4648
self.transforms = transforms
4749

4850
def forward(self, *inputs: Any) -> Any:
49-
sample = inputs if len(inputs) > 1 else inputs[0]
51+
needs_unpacking = len(inputs) > 1
5052
for transform in self.transforms:
51-
sample = transform(sample)
52-
return sample
53+
outputs = transform(*inputs)
54+
inputs = outputs if needs_unpacking else (outputs,)
55+
return outputs
5356

5457
def extra_repr(self) -> str:
5558
format_string = []

0 commit comments

Comments
 (0)