diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index 854c869a0de..49a5a32301b 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -214,8 +214,8 @@ Generic Transforms :members: -AutoAugment Transforms ----------------------- +Automatic Augmentation Transforms +--------------------------------- `AutoAugment <https://arxiv.org/pdf/1805.09501.pdf>`_ is a common Data Augmentation technique that can improve the accuracy of Image Classification models. Though the data augmentation policies are directly linked to their trained dataset, empirical studies show that @@ -229,6 +229,10 @@ The new transform can be used standalone or mixed-and-matched with existing tran .. autoclass:: AutoAugment :members: +`RandAugment <https://arxiv.org/abs/1909.13719>`_ is a simple high-performing Data Augmentation technique which improves the accuracy of Image Classification models. + +.. autoclass:: RandAugment + :members: .. _functional_transforms: diff --git a/gallery/plot_transforms.py b/gallery/plot_transforms.py index 032dd584c26..0a0c1afb479 100644 --- a/gallery/plot_transforms.py +++ b/gallery/plot_transforms.py @@ -245,6 +245,14 @@ def plot(imgs, with_orig=True, row_title=None, **imshow_kwargs): row_title = [str(policy).split('.')[-1] for policy in policies] plot(imgs, row_title=row_title) +#################################### +# RandAugment +# ~~~~~~~~~~~ +# The :class:`~torchvision.transforms.RandAugment` transform automatically augments the data. 
+augmenter = T.RandAugment() +imgs = [augmenter(orig_img) for _ in range(4)] +plot(imgs) + #################################### # Randomly-applied transforms # --------------------------- diff --git a/torchvision/transforms/autoaugment.py b/torchvision/transforms/autoaugment.py index c8b6a543722..3d9c8b6796f 100644 --- a/torchvision/transforms/autoaugment.py +++ b/torchvision/transforms/autoaugment.py @@ -245,7 +245,7 @@ def __repr__(self) -> str: class RandAugment(torch.nn.Module): r"""RandAugment data augmentation method based on `"RandAugment: Practical automated data augmentation with a reduced search space" - <https://arxiv.org/abs/1909.13719>`. + <https://arxiv.org/abs/1909.13719>`_. If the image is torch Tensor, it should be of type torch.uint8, and it is expected to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions. If img is PIL Image, it is expected to be in mode "L" or "RGB". @@ -293,6 +293,7 @@ def _augmentation_space(self, num_bins: int, image_size: List[int]) -> Dict[str, def forward(self, img: Tensor) -> Tensor: """ img (PIL Image or Tensor): Image to be transformed. + Returns: PIL Image or Tensor: Transformed image. """