diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py index 2a6e0ce12c0..8d1228d743f 100644 --- a/torchvision/transforms/transforms.py +++ b/torchvision/transforms/transforms.py @@ -550,16 +550,44 @@ def __repr__(self) -> str: return format_string -class RandomOrder(RandomTransforms): - """Apply a list of transformations in a random order. This transform does not support torchscript.""" +class RandomOrder(torch.nn.Module): + """Apply a list of transformations in a random order. - def __call__(self, img): - order = list(range(len(self.transforms))) - random.shuffle(order) + .. note:: + In order to script the transformation, please use ``torch.nn.ModuleList`` as input instead of list/tuple of + transforms as shown below: + + >>> transforms = transforms.RandomOrder(torch.nn.ModuleList([ + >>> transforms.ColorJitter(), + >>> ])) + >>> scripted_transforms = torch.jit.script(transforms) + + Make sure to use only scriptable transformations, i.e. that work with ``torch.Tensor``, does not require + `lambda` functions or ``PIL.Image``. + + Args: + transforms (sequence or torch.nn.Module): list of transformations + """ + + def __init__(self, transforms): + super().__init__() + _log_api_usage_once(self) + self.transforms = transforms + + def forward(self, img): + order = torch.randperm(len(self.transforms)) for i in order: - img = self.transforms[i](img) + img = self.transforms[i.item()](img) return img + def __repr__(self) -> str: + format_string = self.__class__.__name__ + "(" + for t in self.transforms: + format_string += "\n" + format_string += f" {t}" + format_string += "\n)" + return format_string + class RandomChoice(RandomTransforms): """Apply single transformation randomly picked from a list. This transform does not support torchscript."""