diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py index 38fc417204c..a5a9f75f543 100644 --- a/torchvision/transforms/transforms.py +++ b/torchvision/transforms/transforms.py @@ -568,21 +568,46 @@ def __call__(self, img): return img -class RandomChoice(RandomTransforms): - """Apply single transformation randomly picked from a list. This transform does not support torchscript.""" +class RandomChoice(torch.nn.Module): + """Apply single transformation randomly picked from a list. + .. note:: + In order to script the transformation, please use ``torch.nn.ModuleList`` as input instead of list/tuple of + transforms as shown below: + >>> transforms = transforms.RandomChoice(torch.nn.ModuleList([ + >>> transforms.ColorJitter(), + >>> ]), p=torch.Tensor([0.3])) + >>> scripted_transforms = torch.jit.script(transforms) + Make sure to use only scriptable transformations, i.e. that work with ``torch.Tensor``, does not require + `lambda` functions or ``PIL.Image``. + Args: + transforms (sequence or torch.nn.Module): list of transformations + p (optional, torch.Tensor): input tensor containing weights. Default: equal weights + """ - def __init__(self, transforms, p=None): - super().__init__(transforms) - if p is not None and not isinstance(p, Sequence): - raise TypeError("Argument p should be a sequence") + def __init__(self, transforms, p: Optional[torch.Tensor] = None): + super().__init__() + _log_api_usage_once(self) + if p is None: + p = torch.ones(len(transforms)) + self.transforms = transforms self.p = p - def __call__(self, *args): - t = random.choices(self.transforms, weights=self.p)[0] - return t(*args) + def forward(self, img): + i = torch.multinomial(self.p, 1) + # self.transforms[i.item()](img) gives Error: Expected integer literal for index, whilw JIT Scripting + # Workaround the ModuleList indexing issue: https://github.com/pytorch/pytorch/issues/16123 + for j,t in enumerate(self.transforms): + if i==j: + return t(img) def __repr__(self) -> str: - return f"{super().__repr__()}(p={self.p})" + format_string = self.__class__.__name__ + "(" + format_string += f"\n p={self.p}" + for t in self.transforms: + format_string += "\n" + format_string += f" {t}" + format_string += "\n)" + return format_string class RandomCrop(torch.nn.Module):