Skip to content

Commit 496de93

Browse files
committed
Merge branch 'main' into SamuelGabriel_trivialaugment_implementation
1 parent fa8a6d5 commit 496de93

File tree

12 files changed

+176
-34
lines changed

12 files changed

+176
-34
lines changed

docs/requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
sphinx==2.4.4
1+
sphinx==3.5.4
22
sphinx-gallery>=0.9.0
33
sphinx-copybutton>=0.3.1
44
matplotlib

docs/source/transforms.rst

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -214,8 +214,8 @@ Generic Transforms
214214
:members:
215215

216216

217-
AutoAugment Transforms
218-
----------------------
217+
Automatic Augmentation Transforms
218+
---------------------------------
219219

220220
`AutoAugment <https://arxiv.org/pdf/1805.09501.pdf>`_ is a common Data Augmentation technique that can improve the accuracy of Image Classification models.
221221
Though the data augmentation policies are directly linked to their trained dataset, empirical studies show that
@@ -229,6 +229,15 @@ The new transform can be used standalone or mixed-and-matched with existing tran
229229
.. autoclass:: AutoAugment
230230
:members:
231231

232+
`RandAugment <https://arxiv.org/abs/1909.13719>`_ is a simple high-performing Data Augmentation technique which improves the accuracy of Image Classification models.
233+
234+
.. autoclass:: RandAugment
235+
:members:
236+
237+
`TrivialAugmentWide <https://arxiv.org/abs/2103.10158>`_ is a dataset-independent data-augmentation technique which improves the accuracy of Image Classification models.
238+
239+
.. autoclass:: TrivialAugmentWide
240+
:members:
232241

233242
.. _functional_transforms:
234243

gallery/plot_transforms.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,22 @@ def plot(imgs, with_orig=True, row_title=None, **imshow_kwargs):
245245
row_title = [str(policy).split('.')[-1] for policy in policies]
246246
plot(imgs, row_title=row_title)
247247

248+
####################################
249+
# RandAugment
250+
# ~~~~~~~~~~~
251+
# The :class:`~torchvision.transforms.RandAugment` transform automatically augments the data.
252+
augmenter = T.RandAugment()
253+
imgs = [augmenter(orig_img) for _ in range(4)]
254+
plot(imgs)
255+
256+
####################################
257+
# TrivialAugmentWide
258+
# ~~~~~~~~~~~~~~~~~~
259+
# The :class:`~torchvision.transforms.TrivialAugmentWide` transform automatically augments the data.
260+
augmenter = T.TrivialAugmentWide()
261+
imgs = [augmenter(orig_img) for _ in range(4)]
262+
plot(imgs)
263+
248264
####################################
249265
# Randomly-applied transforms
250266
# ---------------------------

gallery/plot_visualization_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -343,7 +343,7 @@ def show(imgs):
343343
print(dog1_output['scores'])
344344

345345
#####################################
346-
# Clearly the model is less confident about the dog detection than it is about
346+
# Clearly the model is more confident about the dog detection than it is about
347347
# the people detections. That's good news. When plotting the masks, we can ask
348348
# for only those that have a good score. Let's use a score threshold of .75
349349
# here, and also plot the masks of the second dog.

references/classification/presets.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,9 @@ def __init__(self, crop_size, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.2
99
if hflip_prob > 0:
1010
trans.append(transforms.RandomHorizontalFlip(hflip_prob))
1111
if auto_augment_policy is not None:
12-
if auto_augment_policy == "ta_wide":
12+
if auto_augment_policy == "ra":
13+
trans.append(autoaugment.RandAugment())
14+
elif auto_augment_policy == "ta_wide":
1315
trans.append(autoaugment.TrivialAugmentWide())
1416
else:
1517
aa_policy = autoaugment.AutoAugmentPolicy(auto_augment_policy)

references/classification/train.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,7 @@ def main(args):
175175
if args.distributed and args.sync_bn:
176176
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
177177

178-
criterion = nn.CrossEntropyLoss()
178+
criterion = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)
179179

180180
opt_name = args.opt.lower()
181181
if opt_name == 'sgd':
@@ -256,6 +256,9 @@ def get_args_parser(add_help=True):
256256
parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
257257
metavar='W', help='weight decay (default: 1e-4)',
258258
dest='weight_decay')
259+
parser.add_argument('--label-smoothing', default=0.0, type=float,
260+
help='label smoothing (default: 0.0)',
261+
dest='label_smoothing')
259262
parser.add_argument('--lr-step-size', default=30, type=int, help='decrease lr every step-size epochs')
260263
parser.add_argument('--lr-gamma', default=0.1, type=float, help='decrease lr by a factor of lr-gamma')
261264
parser.add_argument('--print-freq', default=10, type=int, help='print frequency')

test/test_transforms.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1490,6 +1490,18 @@ def test_autoaugment(policy, fill):
14901490
transform.__repr__()
14911491

14921492

1493+
@pytest.mark.parametrize('num_ops', [1, 2, 3])
1494+
@pytest.mark.parametrize('magnitude', [7, 9, 11])
1495+
@pytest.mark.parametrize('fill', [None, 85, (128, 128, 128)])
1496+
def test_randaugment(num_ops, magnitude, fill):
1497+
random.seed(42)
1498+
img = Image.open(GRACE_HOPPER)
1499+
transform = transforms.RandAugment(num_ops=num_ops, magnitude=magnitude, fill=fill)
1500+
for _ in range(100):
1501+
img = transform(img)
1502+
transform.__repr__()
1503+
1504+
14931505
@pytest.mark.parametrize('fill', [None, 85, (128, 128, 128)])
14941506
@pytest.mark.parametrize('num_magnitude_bins', [10, 13, 30])
14951507
def test_trivialaugmentwide(fill, num_magnitude_bins):

test/test_transforms_tensor.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -525,18 +525,26 @@ def test_autoaugment(device, policy, fill):
525525
tensor = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8, device=device)
526526
batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=device)
527527

528-
s_transform = None
529528
transform = T.AutoAugment(policy=policy, fill=fill)
530529
s_transform = torch.jit.script(transform)
531530
for _ in range(25):
532531
_test_transform_vs_scripted(transform, s_transform, tensor)
533532
_test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)
534533

535534

536-
def test_autoaugment_save(tmpdir):
537-
transform = T.AutoAugment()
535+
@pytest.mark.parametrize('device', cpu_and_gpu())
536+
@pytest.mark.parametrize('num_ops', [1, 2, 3])
537+
@pytest.mark.parametrize('magnitude', [7, 9, 11])
538+
@pytest.mark.parametrize('fill', [None, 85, (10, -10, 10), 0.7, [0.0, 0.0, 0.0], [1, ], 1])
539+
def test_randaugment(device, num_ops, magnitude, fill):
540+
tensor = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8, device=device)
541+
batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=device)
542+
543+
transform = T.RandAugment(num_ops=num_ops, magnitude=magnitude, fill=fill)
538544
s_transform = torch.jit.script(transform)
539-
s_transform.save(os.path.join(tmpdir, "t_autoaugment.pt"))
545+
for _ in range(25):
546+
_test_transform_vs_scripted(transform, s_transform, tensor)
547+
_test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)
540548

541549

542550
@pytest.mark.parametrize('device', cpu_and_gpu())
@@ -552,8 +560,9 @@ def test_trivialaugmentwide(device, fill):
552560
_test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)
553561

554562

555-
def test_trivialaugmentwide_save(tmpdir):
556-
transform = T.TrivialAugmentWide()
563+
@pytest.mark.parametrize('augmentation', [T.AutoAugment, T.RandAugment, T.TrivialAugmentWide])
564+
def test_autoaugment_save(augmentation, tmpdir):
565+
transform = augmentation()
557566
s_transform = torch.jit.script(transform)
558567
s_transform.save(os.path.join(tmpdir, "t_autoaugment.pt"))
559568

torchvision/csrc/io/image/cpu/encode_jpeg.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,12 @@ torch::Tensor encode_jpeg(const torch::Tensor& data, int64_t quality) {
1414

1515
#else
1616
// For libjpeg version <= 9b, the out_size parameter in jpeg_mem_dest() is
17-
// defined as unsigned long, where as in later version, it is defined as size_t.
17+
// defined as unsigned long, whereas in later versions it is defined as size_t.
1818
// For windows backward compatibility, we define JpegSizeType as different types
19-
// according to the libjpeg version used, in order to prevent compilcation
19+
// according to the libjpeg version used, in order to prevent compilation
2020
// errors.
2121
#if defined(_WIN32) || !defined(JPEG_LIB_VERSION_MAJOR) || \
22-
(JPEG_LIB_VERSION_MAJOR < 9) || \
22+
JPEG_LIB_VERSION_MAJOR < 9 || \
2323
(JPEG_LIB_VERSION_MAJOR == 9 && JPEG_LIB_VERSION_MINOR <= 2)
2424
using JpegSizeType = unsigned long;
2525
#else

torchvision/datasets/caltech.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,10 @@ class Caltech101(VisionDataset):
1818
root (string): Root directory of dataset where directory
1919
``caltech101`` exists or will be saved to if download is set to True.
2020
target_type (string or list, optional): Type of target to use, ``category`` or
21-
``annotation``. Can also be a list to output a tuple with all specified target types.
22-
``category`` represents the target class, and ``annotation`` is a list of points
23-
from a hand-generated outline. Defaults to ``category``.
21+
``annotation``. Can also be a list to output a tuple with all specified
22+
target types. ``category`` represents the target class, and
23+
``annotation`` is a list of points from a hand-generated outline.
24+
Defaults to ``category``.
2425
transform (callable, optional): A function/transform that takes in a PIL image
2526
and returns a transformed version. E.g., ``transforms.RandomCrop``
2627
target_transform (callable, optional): A function/transform that takes in the

0 commit comments

Comments
 (0)