Skip to content

Commit 496de93

Browse files
committed
Merge branch 'main' into SamuelGabriel_trivialaugment_implementation
1 parent fa8a6d5 commit 496de93

File tree

12 files changed

+176
-34
lines changed

12 files changed

+176
-34
lines changed

docs/requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
sphinx==2.4.4
1+
sphinx==3.5.4
22
sphinx-gallery>=0.9.0
33
sphinx-copybutton>=0.3.1
44
matplotlib

docs/source/transforms.rst

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -214,8 +214,8 @@ Generic Transforms
214214
:members:
215215

216216

217-
AutoAugment Transforms
218-
----------------------
217+
Automatic Augmentation Transforms
218+
---------------------------------
219219

220220
`AutoAugment <https://arxiv.org/pdf/1805.09501.pdf>`_ is a common Data Augmentation technique that can improve the accuracy of Image Classification models.
221221
Though the data augmentation policies are directly linked to their trained dataset, empirical studies show that
@@ -229,6 +229,15 @@ The new transform can be used standalone or mixed-and-matched with existing tran
229229
.. autoclass:: AutoAugment
230230
:members:
231231

232+
`RandAugment <https://arxiv.org/abs/1909.13719>`_ is a simple high-performing Data Augmentation technique which improves the accuracy of Image Classification models.
233+
234+
.. autoclass:: RandAugment
235+
:members:
236+
237+
`TrivialAugmentWide <https://arxiv.org/abs/2103.10158>`_ is a dataset-independent data-augmentation technique which improves the accuracy of Image Classification models.
238+
239+
.. autoclass:: TrivialAugmentWide
240+
:members:
232241

233242
.. _functional_transforms:
234243

gallery/plot_transforms.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,22 @@ def plot(imgs, with_orig=True, row_title=None, **imshow_kwargs):
245245
row_title = [str(policy).split('.')[-1] for policy in policies]
246246
plot(imgs, row_title=row_title)
247247

248+
####################################
249+
# RandAugment
250+
# ~~~~~~~~~~~
251+
# The :class:`~torchvision.transforms.RandAugment` transform automatically augments the data.
252+
augmenter = T.RandAugment()
253+
imgs = [augmenter(orig_img) for _ in range(4)]
254+
plot(imgs)
255+
256+
####################################
257+
# TrivialAugmentWide
258+
# ~~~~~~~~~~~
259+
# The :class:`~torchvision.transforms.TrivialAugmentWide` transform automatically augments the data.
260+
augmenter = T.TrivialAugmentWide()
261+
imgs = [augmenter(orig_img) for _ in range(4)]
262+
plot(imgs)
263+
248264
####################################
249265
# Randomly-applied transforms
250266
# ---------------------------

gallery/plot_visualization_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -343,7 +343,7 @@ def show(imgs):
343343
print(dog1_output['scores'])
344344

345345
#####################################
346-
# Clearly the model is less confident about the dog detection than it is about
346+
# Clearly the model is more confident about the dog detection than it is about
347347
# the people detections. That's good news. When plotting the masks, we can ask
348348
# for only those that have a good score. Let's use a score threshold of .75
349349
# here, and also plot the masks of the second dog.

references/classification/presets.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,9 @@ def __init__(self, crop_size, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.2
99
if hflip_prob > 0:
1010
trans.append(transforms.RandomHorizontalFlip(hflip_prob))
1111
if auto_augment_policy is not None:
12-
if auto_augment_policy == "ta_wide":
12+
if auto_augment_policy == "ra":
13+
trans.append(autoaugment.RandAugment())
14+
elif auto_augment_policy == "ta_wide":
1315
trans.append(autoaugment.TrivialAugmentWide())
1416
else:
1517
aa_policy = autoaugment.AutoAugmentPolicy(auto_augment_policy)

references/classification/train.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,7 @@ def main(args):
175175
if args.distributed and args.sync_bn:
176176
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
177177

178-
criterion = nn.CrossEntropyLoss()
178+
criterion = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)
179179

180180
opt_name = args.opt.lower()
181181
if opt_name == 'sgd':
@@ -256,6 +256,9 @@ def get_args_parser(add_help=True):
256256
parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
257257
metavar='W', help='weight decay (default: 1e-4)',
258258
dest='weight_decay')
259+
parser.add_argument('--label-smoothing', default=0.0, type=float,
260+
help='label smoothing (default: 0.0)',
261+
dest='label_smoothing')
259262
parser.add_argument('--lr-step-size', default=30, type=int, help='decrease lr every step-size epochs')
260263
parser.add_argument('--lr-gamma', default=0.1, type=float, help='decrease lr by a factor of lr-gamma')
261264
parser.add_argument('--print-freq', default=10, type=int, help='print frequency')

test/test_transforms.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1490,6 +1490,18 @@ def test_autoaugment(policy, fill):
14901490
transform.__repr__()
14911491

14921492

1493+
@pytest.mark.parametrize('num_ops', [1, 2, 3])
1494+
@pytest.mark.parametrize('magnitude', [7, 9, 11])
1495+
@pytest.mark.parametrize('fill', [None, 85, (128, 128, 128)])
1496+
def test_randaugment(num_ops, magnitude, fill):
1497+
random.seed(42)
1498+
img = Image.open(GRACE_HOPPER)
1499+
transform = transforms.RandAugment(num_ops=num_ops, magnitude=magnitude, fill=fill)
1500+
for _ in range(100):
1501+
img = transform(img)
1502+
transform.__repr__()
1503+
1504+
14931505
@pytest.mark.parametrize('fill', [None, 85, (128, 128, 128)])
14941506
@pytest.mark.parametrize('num_magnitude_bins', [10, 13, 30])
14951507
def test_trivialaugmentwide(fill, num_magnitude_bins):

test/test_transforms_tensor.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -525,18 +525,26 @@ def test_autoaugment(device, policy, fill):
525525
tensor = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8, device=device)
526526
batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=device)
527527

528-
s_transform = None
529528
transform = T.AutoAugment(policy=policy, fill=fill)
530529
s_transform = torch.jit.script(transform)
531530
for _ in range(25):
532531
_test_transform_vs_scripted(transform, s_transform, tensor)
533532
_test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)
534533

535534

536-
def test_autoaugment_save(tmpdir):
537-
transform = T.AutoAugment()
535+
@pytest.mark.parametrize('device', cpu_and_gpu())
536+
@pytest.mark.parametrize('num_ops', [1, 2, 3])
537+
@pytest.mark.parametrize('magnitude', [7, 9, 11])
538+
@pytest.mark.parametrize('fill', [None, 85, (10, -10, 10), 0.7, [0.0, 0.0, 0.0], [1, ], 1])
539+
def test_randaugment(device, num_ops, magnitude, fill):
540+
tensor = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8, device=device)
541+
batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=device)
542+
543+
transform = T.RandAugment(num_ops=num_ops, magnitude=magnitude, fill=fill)
538544
s_transform = torch.jit.script(transform)
539-
s_transform.save(os.path.join(tmpdir, "t_autoaugment.pt"))
545+
for _ in range(25):
546+
_test_transform_vs_scripted(transform, s_transform, tensor)
547+
_test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)
540548

541549

542550
@pytest.mark.parametrize('device', cpu_and_gpu())
@@ -552,8 +560,9 @@ def test_trivialaugmentwide(device, fill):
552560
_test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)
553561

554562

555-
def test_trivialaugmentwide_save(tmpdir):
556-
transform = T.TrivialAugmentWide()
563+
@pytest.mark.parametrize('augmentation', [T.AutoAugment, T.RandAugment, T.TrivialAugmentWide])
564+
def test_autoaugment_save(augmentation, tmpdir):
565+
transform = augmentation()
557566
s_transform = torch.jit.script(transform)
558567
s_transform.save(os.path.join(tmpdir, "t_autoaugment.pt"))
559568

torchvision/csrc/io/image/cpu/encode_jpeg.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,12 @@ torch::Tensor encode_jpeg(const torch::Tensor& data, int64_t quality) {
1414

1515
#else
1616
// For libjpeg version <= 9b, the out_size parameter in jpeg_mem_dest() is
17-
// defined as unsigned long, where as in later version, it is defined as size_t.
17+
// defined as unsigned long, whereas in later version, it is defined as size_t.
1818
// For windows backward compatibility, we define JpegSizeType as different types
19-
// according to the libjpeg version used, in order to prevent compilcation
19+
// according to the libjpeg version used, in order to prevent compilation
2020
// errors.
2121
#if defined(_WIN32) || !defined(JPEG_LIB_VERSION_MAJOR) || \
22-
(JPEG_LIB_VERSION_MAJOR < 9) || \
22+
JPEG_LIB_VERSION_MAJOR < 9 || \
2323
(JPEG_LIB_VERSION_MAJOR == 9 && JPEG_LIB_VERSION_MINOR <= 2)
2424
using JpegSizeType = unsigned long;
2525
#else

torchvision/datasets/caltech.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,10 @@ class Caltech101(VisionDataset):
1818
root (string): Root directory of dataset where directory
1919
``caltech101`` exists or will be saved to if download is set to True.
2020
target_type (string or list, optional): Type of target to use, ``category`` or
21-
``annotation``. Can also be a list to output a tuple with all specified target types.
22-
``category`` represents the target class, and ``annotation`` is a list of points
23-
from a hand-generated outline. Defaults to ``category``.
21+
``annotation``. Can also be a list to output a tuple with all specified
22+
target types. ``category`` represents the target class, and
23+
``annotation`` is a list of points from a hand-generated outline.
24+
Defaults to ``category``.
2425
transform (callable, optional): A function/transform that takes in an PIL image
2526
and returns a transformed version. E.g, ``transforms.RandomCrop``
2627
target_transform (callable, optional): A function/transform that takes in the

torchvision/models/detection/_utils.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -216,10 +216,14 @@ def decode_single(self, rel_codes, boxes):
216216
pred_w = torch.exp(dw) * widths[:, None]
217217
pred_h = torch.exp(dh) * heights[:, None]
218218

219-
pred_boxes1 = pred_ctr_x - torch.tensor(0.5, dtype=pred_ctr_x.dtype, device=pred_w.device) * pred_w
220-
pred_boxes2 = pred_ctr_y - torch.tensor(0.5, dtype=pred_ctr_y.dtype, device=pred_h.device) * pred_h
221-
pred_boxes3 = pred_ctr_x + torch.tensor(0.5, dtype=pred_ctr_x.dtype, device=pred_w.device) * pred_w
222-
pred_boxes4 = pred_ctr_y + torch.tensor(0.5, dtype=pred_ctr_y.dtype, device=pred_h.device) * pred_h
219+
# Distance from center to box's corner.
220+
c_to_c_h = torch.tensor(0.5, dtype=pred_ctr_y.dtype, device=pred_h.device) * pred_h
221+
c_to_c_w = torch.tensor(0.5, dtype=pred_ctr_x.dtype, device=pred_w.device) * pred_w
222+
223+
pred_boxes1 = pred_ctr_x - c_to_c_w
224+
pred_boxes2 = pred_ctr_y - c_to_c_h
225+
pred_boxes3 = pred_ctr_x + c_to_c_w
226+
pred_boxes4 = pred_ctr_y + c_to_c_h
223227
pred_boxes = torch.stack((pred_boxes1, pred_boxes2, pred_boxes3, pred_boxes4), dim=2).flatten(1)
224228
return pred_boxes
225229

torchvision/transforms/autoaugment.py

Lines changed: 98 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
from . import functional as F, InterpolationMode
99

10-
__all__ = ["AutoAugmentPolicy", "AutoAugment", "TrivialAugmentWide"]
10+
__all__ = ["AutoAugmentPolicy", "AutoAugment", "RandAugment", "TrivialAugmentWide"]
1111

1212

1313
def _apply_op(img: Tensor, op_name: str, magnitude: float,
@@ -58,6 +58,7 @@ class AutoAugmentPolicy(Enum):
5858
SVHN = "svhn"
5959

6060

61+
# FIXME: Eliminate copy-pasted code for fill standardization and _augmentation_space() by moving stuff on a base class
6162
class AutoAugment(torch.nn.Module):
6263
r"""AutoAugment data augmentation method based on
6364
`"AutoAugment: Learning Augmentation Strategies from Data" <https://arxiv.org/pdf/1805.09501.pdf>`_.
@@ -85,9 +86,9 @@ def __init__(
8586
self.policy = policy
8687
self.interpolation = interpolation
8788
self.fill = fill
88-
self.transforms = self._get_transforms(policy)
89+
self.policies = self._get_policies(policy)
8990

90-
def _get_transforms(
91+
def _get_policies(
9192
self,
9293
policy: AutoAugmentPolicy
9394
) -> List[Tuple[Tuple[str, float, Optional[int]], Tuple[str, float, Optional[int]]]]:
@@ -178,9 +179,9 @@ def _get_transforms(
178179
else:
179180
raise ValueError("The provided policy {} is not recognized.".format(policy))
180181

181-
def _get_magnitudes(self, num_bins: int, image_size: List[int]) -> Dict[str, Tuple[Tensor, bool]]:
182+
def _augmentation_space(self, num_bins: int, image_size: List[int]) -> Dict[str, Tuple[Tensor, bool]]:
182183
return {
183-
# name: (magnitudes, signed)
184+
# op_name: (magnitudes, signed)
184185
"ShearX": (torch.linspace(0.0, 0.3, num_bins), True),
185186
"ShearY": (torch.linspace(0.0, 0.3, num_bins), True),
186187
"TranslateX": (torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True),
@@ -224,11 +225,11 @@ def forward(self, img: Tensor) -> Tensor:
224225
elif fill is not None:
225226
fill = [float(f) for f in fill]
226227

227-
transform_id, probs, signs = self.get_params(len(self.transforms))
228+
transform_id, probs, signs = self.get_params(len(self.policies))
228229

229-
for i, (op_name, p, magnitude_id) in enumerate(self.transforms[transform_id]):
230+
for i, (op_name, p, magnitude_id) in enumerate(self.policies[transform_id]):
230231
if probs[i] <= p:
231-
op_meta = self._get_magnitudes(10, F.get_image_size(img))
232+
op_meta = self._augmentation_space(10, F.get_image_size(img))
232233
magnitudes, signed = op_meta[op_name]
233234
magnitude = float(magnitudes[magnitude_id].item()) if magnitude_id is not None else 0.0
234235
if signed and signs[i] == 0:
@@ -241,6 +242,91 @@ def __repr__(self) -> str:
241242
return self.__class__.__name__ + '(policy={}, fill={})'.format(self.policy, self.fill)
242243

243244

245+
class RandAugment(torch.nn.Module):
246+
r"""RandAugment data augmentation method based on
247+
`"RandAugment: Practical automated data augmentation with a reduced search space"
248+
<https://arxiv.org/abs/1909.13719>`_.
249+
If the image is torch Tensor, it should be of type torch.uint8, and it is expected
250+
to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
251+
If img is PIL Image, it is expected to be in mode "L" or "RGB".
252+
253+
Args:
254+
num_ops (int): Number of augmentation transformations to apply sequentially.
255+
magnitude (int): Magnitude for all the transformations.
256+
num_magnitude_bins (int): The number of different magnitude values.
257+
interpolation (InterpolationMode): Desired interpolation enum defined by
258+
:class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
259+
If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
260+
fill (sequence or number, optional): Pixel fill value for the area outside the transformed
261+
image. If given a number, the value is used for all bands respectively.
262+
"""
263+
264+
def __init__(self, num_ops: int = 2, magnitude: int = 9, num_magnitude_bins: int = 30,
265+
interpolation: InterpolationMode = InterpolationMode.NEAREST,
266+
fill: Optional[List[float]] = None) -> None:
267+
super().__init__()
268+
self.num_ops = num_ops
269+
self.magnitude = magnitude
270+
self.num_magnitude_bins = num_magnitude_bins
271+
self.interpolation = interpolation
272+
self.fill = fill
273+
274+
def _augmentation_space(self, num_bins: int, image_size: List[int]) -> Dict[str, Tuple[Tensor, bool]]:
275+
return {
276+
# op_name: (magnitudes, signed)
277+
"ShearX": (torch.linspace(0.0, 0.3, num_bins), True),
278+
"ShearY": (torch.linspace(0.0, 0.3, num_bins), True),
279+
"TranslateX": (torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True),
280+
"TranslateY": (torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True),
281+
"Rotate": (torch.linspace(0.0, 30.0, num_bins), True),
282+
"Brightness": (torch.linspace(0.0, 0.9, num_bins), True),
283+
"Color": (torch.linspace(0.0, 0.9, num_bins), True),
284+
"Contrast": (torch.linspace(0.0, 0.9, num_bins), True),
285+
"Sharpness": (torch.linspace(0.0, 0.9, num_bins), True),
286+
"Posterize": (8 - (torch.arange(num_bins) / ((num_bins - 1) / 4)).round().int(), False),
287+
"Solarize": (torch.linspace(256.0, 0.0, num_bins), False),
288+
"AutoContrast": (torch.tensor(0.0), False),
289+
"Equalize": (torch.tensor(0.0), False),
290+
"Invert": (torch.tensor(0.0), False),
291+
}
292+
293+
def forward(self, img: Tensor) -> Tensor:
294+
"""
295+
img (PIL Image or Tensor): Image to be transformed.
296+
297+
Returns:
298+
PIL Image or Tensor: Transformed image.
299+
"""
300+
fill = self.fill
301+
if isinstance(img, Tensor):
302+
if isinstance(fill, (int, float)):
303+
fill = [float(fill)] * F.get_image_num_channels(img)
304+
elif fill is not None:
305+
fill = [float(f) for f in fill]
306+
307+
for _ in range(self.num_ops):
308+
op_meta = self._augmentation_space(self.num_magnitude_bins, F.get_image_size(img))
309+
op_index = int(torch.randint(len(op_meta), (1,)).item())
310+
op_name = list(op_meta.keys())[op_index]
311+
magnitudes, signed = op_meta[op_name]
312+
magnitude = float(magnitudes[self.magnitude].item()) if magnitudes.ndim > 0 else 0.0
313+
if signed and torch.randint(2, (1,)):
314+
magnitude *= -1.0
315+
img = _apply_op(img, op_name, magnitude, interpolation=self.interpolation, fill=fill)
316+
317+
return img
318+
319+
def __repr__(self) -> str:
320+
s = self.__class__.__name__ + '('
321+
s += 'num_ops={num_ops}'
322+
s += ', magnitude={magnitude}'
323+
s += ', num_magnitude_bins={num_magnitude_bins}'
324+
s += ', interpolation={interpolation}'
325+
s += ', fill={fill}'
326+
s += ')'
327+
return s.format(**self.__dict__)
328+
329+
244330
class TrivialAugmentWide(torch.nn.Module):
245331
r"""Dataset-independent data-augmentation with TrivialAugment Wide, as described in
246332
`"TrivialAugment: Tuning-free Yet State-of-the-Art Data Augmentation" <https://arxiv.org/abs/2103.10158>`.
@@ -264,9 +350,9 @@ def __init__(self, num_magnitude_bins: int = 30, interpolation: InterpolationMod
264350
self.interpolation = interpolation
265351
self.fill = fill
266352

267-
def _get_magnitudes(self, num_bins: int) -> Dict[str, Tuple[Tensor, bool]]:
353+
def _augmentation_space(self, num_bins: int) -> Dict[str, Tuple[Tensor, bool]]:
268354
return {
269-
# name: (magnitudes, signed)
355+
# op_name: (magnitudes, signed)
270356
"ShearX": (torch.linspace(0.0, 0.99, num_bins), True),
271357
"ShearY": (torch.linspace(0.0, 0.99, num_bins), True),
272358
"TranslateX": (torch.linspace(0.0, 32.0, num_bins), True),
@@ -283,7 +369,7 @@ def _get_magnitudes(self, num_bins: int) -> Dict[str, Tuple[Tensor, bool]]:
283369
"Invert": (torch.tensor(0.0), False),
284370
}
285371

286-
def forward(self, img: Tensor):
372+
def forward(self, img: Tensor) -> Tensor:
287373
"""
288374
img (PIL Image or Tensor): Image to be transformed.
289375
@@ -297,7 +383,7 @@ def forward(self, img: Tensor):
297383
elif fill is not None:
298384
fill = [float(f) for f in fill]
299385

300-
op_meta = self._get_magnitudes(self.num_magnitude_bins)
386+
op_meta = self._augmentation_space(self.num_magnitude_bins)
301387
op_index = int(torch.randint(len(op_meta), (1,)).item())
302388
op_name = list(op_meta.keys())[op_index]
303389
magnitudes, signed = op_meta[op_name]

0 commit comments

Comments
 (0)