Skip to content

Commit ad3c3f7

Browse files
authored
Adding docs for RandAugment (#4349)
* Adding docs for RandAugment. * Fix docs.
1 parent 5a81554 commit ad3c3f7

File tree

3 files changed

+16
-3
lines changed

3 files changed

+16
-3
lines changed

docs/source/transforms.rst

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -214,8 +214,8 @@ Generic Transforms
214214
:members:
215215

216216

217-
AutoAugment Transforms
218-
----------------------
217+
Automatic Augmentation Transforms
218+
---------------------------------
219219

220220
`AutoAugment <https://arxiv.org/pdf/1805.09501.pdf>`_ is a common Data Augmentation technique that can improve the accuracy of Image Classification models.
221221
Though the data augmentation policies are directly linked to their trained dataset, empirical studies show that
@@ -229,6 +229,10 @@ The new transform can be used standalone or mixed-and-matched with existing tran
229229
.. autoclass:: AutoAugment
230230
:members:
231231

232+
`RandAugment <https://arxiv.org/abs/1909.13719>`_ is a simple high-performing Data Augmentation technique which improves the accuracy of Image Classification models.
233+
234+
.. autoclass:: RandAugment
235+
:members:
232236

233237
.. _functional_transforms:
234238

gallery/plot_transforms.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,14 @@ def plot(imgs, with_orig=True, row_title=None, **imshow_kwargs):
245245
row_title = [str(policy).split('.')[-1] for policy in policies]
246246
plot(imgs, row_title=row_title)
247247

248+
####################################
249+
# RandAugment
250+
# ~~~~~~~~~~~
251+
# The :class:`~torchvision.transforms.RandAugment` transform automatically augments the data.
252+
augmenter = T.RandAugment()
253+
imgs = [augmenter(orig_img) for _ in range(4)]
254+
plot(imgs)
255+
248256
####################################
249257
# Randomly-applied transforms
250258
# ---------------------------

torchvision/transforms/autoaugment.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -245,7 +245,7 @@ def __repr__(self) -> str:
245245
class RandAugment(torch.nn.Module):
246246
r"""RandAugment data augmentation method based on
247247
`"RandAugment: Practical automated data augmentation with a reduced search space"
248-
<https://arxiv.org/abs/1909.13719>`.
248+
<https://arxiv.org/abs/1909.13719>`_.
249249
If the image is torch Tensor, it should be of type torch.uint8, and it is expected
250250
to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
251251
If img is PIL Image, it is expected to be in mode "L" or "RGB".
@@ -293,6 +293,7 @@ def _augmentation_space(self, num_bins: int, image_size: List[int]) -> Dict[str,
293293
def forward(self, img: Tensor) -> Tensor:
294294
"""
295295
img (PIL Image or Tensor): Image to be transformed.
296+
296297
Returns:
297298
PIL Image or Tensor: Transformed image.
298299
"""

0 commit comments

Comments
 (0)