Skip to content

Commit 1c517ce

Browse files
NicolasHugfacebook-github-bot
authored andcommitted
[fbsync] [proto] Compose transform keeps BC (#6391)
Summary: * [proto] Compose keeps BC * Compose -> Compose(Transform) Reviewed By: datumbox Differential Revision: D38824249 fbshipit-source-id: 294949efaee4cb01be5a958799deac9f3018d9e4
1 parent 2b597aa commit 1c517ce

File tree

2 files changed

+23
-4
lines changed

2 files changed

+23
-4
lines changed

test/test_prototype_transforms.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1083,3 +1083,22 @@ def test__transform(self, inpt_type, mocker):
10831083
fn.call_count == 0
10841084
else:
10851085
fn.assert_called_once_with(inpt)
1086+
1087+
1088+
class TestCompose:
1089+
def test_assertions(self):
1090+
with pytest.raises(TypeError, match="Argument transforms should be a sequence of callables"):
1091+
transforms.Compose(123)
1092+
1093+
@pytest.mark.parametrize(
1094+
"trfms",
1095+
[
1096+
[transforms.Pad(2), transforms.RandomCrop(28)],
1097+
[lambda x: 2.0 * x],
1098+
],
1099+
)
1100+
def test_ctor(self, trfms):
1101+
c = transforms.Compose(trfms)
1102+
inpt = torch.rand(1, 3, 32, 32)
1103+
output = c(inpt)
1104+
assert isinstance(output, torch.Tensor)

torchvision/prototype/transforms/_container.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any, Dict, List, Optional
1+
from typing import Any, Callable, Dict, List, Optional, Sequence
22

33
import torch
44
from torchvision.prototype.transforms import Transform
@@ -7,11 +7,11 @@
77

88

99
class Compose(Transform):
10-
def __init__(self, *transforms: Transform) -> None:
10+
def __init__(self, transforms: Sequence[Callable]) -> None:
1111
super().__init__()
12+
if not isinstance(transforms, Sequence):
13+
raise TypeError("Argument transforms should be a sequence of callables")
1214
self.transforms = transforms
13-
for idx, transform in enumerate(transforms):
14-
self.add_module(str(idx), transform)
1515

1616
def forward(self, *inputs: Any) -> Any:
1717
sample = inputs if len(inputs) > 1 else inputs[0]

0 commit comments

Comments
 (0)