
Commit 5029548

NicolasHug authored and facebook-github-bot committed
[fbsync] Make RandomApply torchscriptable in V2 (#7256)
Reviewed By: vmoens
Differential Revision: D44416608
fbshipit-source-id: 1e8afbc880dacacacbd2f3e543d21cb4b90e5fdf
1 parent: 7159d20

3 files changed: +15, -3 lines

test/test_prototype_transforms_consistency.py

Lines changed: 5 additions & 0 deletions
@@ -806,6 +806,11 @@ def test_random_apply(self, p, sequence_type):
 
         check_call_consistency(prototype_transform, legacy_transform)
 
+        if sequence_type is nn.ModuleList:
+            # quick and dirty test that it is jit-scriptable
+            scripted = torch.jit.script(prototype_transform)
+            scripted(torch.rand(1, 3, 300, 300))
+
     # We can't test other values for `p` since the random parameter generation is different
     @pytest.mark.parametrize("probabilities", [(0, 1), (1, 0)])
     def test_random_choice(self, probabilities):
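
The new check only runs for the nn.ModuleList case because transforms held in a plain Python list are not torchscriptable. A minimal usage sketch of what the commit enables (not part of this diff; the wrapped transform and input shape are illustrative, and the import path assumes the prototype namespace as it existed at the time):

    import torch
    from torch import nn
    from torchvision.prototype import transforms as T

    # RandomApply is only scriptable when its transforms are given as an
    # nn.ModuleList, and the wrapped transforms must themselves have a v1
    # (torchvision.transforms) counterpart.
    transform = T.RandomApply(nn.ModuleList([T.RandomHorizontalFlip(p=1.0)]), p=0.5)

    scripted = torch.jit.script(transform)      # swaps in the v1 RandomApply under the hood
    out = scripted(torch.rand(1, 3, 300, 300))  # tensor-only inputs, as in v1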

torchvision/prototype/transforms/_container.py

Lines changed: 7 additions & 1 deletion
@@ -1,9 +1,10 @@
 import warnings
-from typing import Any, Callable, List, Optional, Sequence, Union
+from typing import Any, Callable, Dict, List, Optional, Sequence, Union
 
 import torch
 
 from torch import nn
+from torchvision import transforms as _transforms
 from torchvision.prototype.transforms import Transform
 
 
@@ -28,6 +29,8 @@ def extra_repr(self) -> str:
 
 
 class RandomApply(Transform):
+    _v1_transform_cls = _transforms.RandomApply
+
     def __init__(self, transforms: Union[Sequence[Callable], nn.ModuleList], p: float = 0.5) -> None:
         super().__init__()
 
@@ -39,6 +42,9 @@ def __init__(self, transforms: Union[Sequence[Callable], nn.ModuleList], p: float = 0.5) -> None:
             raise ValueError("`p` should be a floating point value in the interval [0.0, 1.0].")
         self.p = p
 
+    def _extract_params_for_v1_transform(self) -> Dict[str, Any]:
+        return {"transforms": self.transforms, "p": self.p}
+
     def forward(self, *inputs: Any) -> Any:
         sample = inputs if len(inputs) > 1 else inputs[0]
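
These two additions are what the base-class hook shown in the next file consumes: `_v1_transform_cls` names the stable v1 counterpart, and `_extract_params_for_v1_transform` supplies its constructor arguments. A hedged sketch of the resulting swap (the `v2_transform` instance below is illustrative, not from the diff):

    from torch import nn
    from torchvision import transforms as _transforms
    from torchvision.prototype import transforms as T

    v2_transform = T.RandomApply(nn.ModuleList([T.RandomHorizontalFlip()]), p=0.3)

    # Roughly what Transform.__prepare_scriptable__ does when torch.jit.script()
    # is called: instantiate the v1 class from the extracted parameters.
    v1_transform = v2_transform._v1_transform_cls(**v2_transform._extract_params_for_v1_transform())
    assert isinstance(v1_transform, _transforms.RandomApply) and v1_transform.p == 0.3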

torchvision/prototype/transforms/_transform.py

Lines changed: 3 additions & 2 deletions
@@ -141,8 +141,9 @@ def __prepare_scriptable__(self) -> nn.Module:
         if self._v1_transform_cls is None:
             raise RuntimeError(
                 f"Transform {type(self).__name__} cannot be JIT scripted. "
-                f"This is only support for backward compatibility with transforms which already in v1."
-                f"For torchscript support (on tensors only), you can use the functional API instead."
+                "torchscript is only supported for backward compatibility with transforms "
+                "which are already in torchvision.transforms. "
+                "For torchscript support (on tensors only), you can use the functional API instead."
             )
 
         return self._v1_transform_cls(**self._extract_params_for_v1_transform())
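
For transforms that have no v1 counterpart, the reworded message points users to the functional API, which can be scripted directly on tensors. A minimal sketch of that fallback (the helper below is hypothetical, not from the diff):

    import torch
    from torchvision.transforms import functional as F

    def maybe_hflip(img: torch.Tensor, p: float = 0.5) -> torch.Tensor:
        # hypothetical helper: the RandomApply idea expressed with the functional API
        if torch.rand(1).item() < p:
            img = F.hflip(img)
        return img

    scripted = torch.jit.script(maybe_hflip)  # scriptable because it only touches tensors
    out = scripted(torch.rand(3, 32, 32))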
