|
26 | 26 | make_video,
|
27 | 27 | set_rng_seed,
|
28 | 28 | )
|
| 29 | + |
| 30 | +from torch import nn |
29 | 31 | from torch.testing import assert_close
|
30 | 32 | from torchvision import datapoints
|
31 | 33 |
|
@@ -1634,3 +1636,64 @@ def test_transform_negative_degrees_error(self):
|
1634 | 1636 | def test_transform_unknown_fill_error(self):
|
1635 | 1637 | with pytest.raises(TypeError, match="Got inappropriate fill arg"):
|
1636 | 1638 | transforms.RandomAffine(degrees=0, fill="fill")
|
| 1639 | + |
| 1640 | + |
| 1641 | +class TestCompose: |
| 1642 | + class BuiltinTransform(transforms.Transform): |
| 1643 | + def _transform(self, inpt, params): |
| 1644 | + return inpt |
| 1645 | + |
| 1646 | + class PackedInputTransform(nn.Module): |
| 1647 | + def forward(self, sample): |
| 1648 | + assert len(sample) == 2 |
| 1649 | + return sample |
| 1650 | + |
| 1651 | + class UnpackedInputTransform(nn.Module): |
| 1652 | + def forward(self, image, label): |
| 1653 | + return image, label |
| 1654 | + |
| 1655 | + @pytest.mark.parametrize( |
| 1656 | + "transform_clss", |
| 1657 | + [ |
| 1658 | + [BuiltinTransform], |
| 1659 | + [PackedInputTransform], |
| 1660 | + [UnpackedInputTransform], |
| 1661 | + [BuiltinTransform, BuiltinTransform], |
| 1662 | + [PackedInputTransform, PackedInputTransform], |
| 1663 | + [UnpackedInputTransform, UnpackedInputTransform], |
| 1664 | + [BuiltinTransform, PackedInputTransform, BuiltinTransform], |
| 1665 | + [BuiltinTransform, UnpackedInputTransform, BuiltinTransform], |
| 1666 | + [PackedInputTransform, BuiltinTransform, PackedInputTransform], |
| 1667 | + [UnpackedInputTransform, BuiltinTransform, UnpackedInputTransform], |
| 1668 | + ], |
| 1669 | + ) |
| 1670 | + @pytest.mark.parametrize("unpack", [True, False]) |
| 1671 | + def test_packed_unpacked(self, transform_clss, unpack): |
| 1672 | + needs_packed_inputs = any(issubclass(cls, self.PackedInputTransform) for cls in transform_clss) |
| 1673 | + needs_unpacked_inputs = any(issubclass(cls, self.UnpackedInputTransform) for cls in transform_clss) |
| 1674 | + assert not (needs_packed_inputs and needs_unpacked_inputs) |
| 1675 | + |
| 1676 | + transform = transforms.Compose([cls() for cls in transform_clss]) |
| 1677 | + |
| 1678 | + image = make_image() |
| 1679 | + label = 3 |
| 1680 | + packed_input = (image, label) |
| 1681 | + |
| 1682 | + def call_transform(): |
| 1683 | + if unpack: |
| 1684 | + return transform(*packed_input) |
| 1685 | + else: |
| 1686 | + return transform(packed_input) |
| 1687 | + |
| 1688 | + if needs_unpacked_inputs and not unpack: |
| 1689 | + with pytest.raises(TypeError, match="missing 1 required positional argument"): |
| 1690 | + call_transform() |
| 1691 | + elif needs_packed_inputs and unpack: |
| 1692 | + with pytest.raises(TypeError, match="takes 2 positional arguments but 3 were given"): |
| 1693 | + call_transform() |
| 1694 | + else: |
| 1695 | + output = call_transform() |
| 1696 | + |
| 1697 | + assert isinstance(output, tuple) and len(output) == 2 |
| 1698 | + assert output[0] is image |
| 1699 | + assert output[1] is label |
0 commit comments