Commit 379ed22

NicolasHug authored and facebook-github-bot committed
[fbsync] Add gallery example for MixUp and CutMix (#7772)
Reviewed By: matteobettini
Differential Revision: D48642291
fbshipit-source-id: edecd2575c086cf66e7f2c4b9f6edb3db262ee2d
1 parent 1c744be commit 379ed22

4 files changed: +171 -13 lines changed


docs/source/transforms.rst

Lines changed: 4 additions & 4 deletions
@@ -261,13 +261,13 @@ The new transform can be used standalone or mixed-and-matched with existing tran
     AugMix
     v2.AugMix

-Cutmix - Mixup
+CutMix - MixUp
 --------------

-Cutmix and Mixup are special transforms that
+CutMix and MixUp are special transforms that
 are meant to be used on batches rather than on individual images, because they
-are combining pairs of images together. These can be used after the dataloader,
-or part of a collation function. See
+are combining pairs of images together. These can be used after the dataloader
+(once the samples are batched), or part of a collation function. See
 :ref:`sphx_glr_auto_examples_plot_cutmix_mixup.py` for detailed usage examples.

 .. autosummary::

gallery/plot_cutmix_mixup.py

Lines changed: 146 additions & 2 deletions
@@ -1,8 +1,152 @@

 """
 ===========================
-How to use Cutmix and Mixup
+How to use CutMix and MixUp
 ===========================

-TODO
+:class:`~torchvision.transforms.v2.Cutmix` and
+:class:`~torchvision.transforms.v2.Mixup` are popular augmentation strategies
+that can improve classification accuracy.
+
+These transforms are slightly different from the rest of the Torchvision
+transforms, because they expect
+**batches** of samples as input, not individual images. In this example we'll
+explain how to use them: after the ``DataLoader``, or as part of a collation
+function.
 """
+
+# %%
+import torch
+import torchvision
+from torchvision.datasets import FakeData
+
+# We are using BETA APIs, so we deactivate the associated warning, thereby acknowledging that
+# some APIs may slightly change in the future
+torchvision.disable_beta_transforms_warning()
+
+from torchvision.transforms import v2
+
+
+NUM_CLASSES = 100
+
+# %%
+# Pre-processing pipeline
+# -----------------------
+#
+# We'll use a simple but typical image classification pipeline:
+
+preproc = v2.Compose([
+    v2.PILToTensor(),
+    v2.RandomResizedCrop(size=(224, 224), antialias=True),
+    v2.RandomHorizontalFlip(p=0.5),
+    v2.ToDtype(torch.float32, scale=True),  # to float32 in [0, 1]
+    v2.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),  # typically from ImageNet
+])
+
+dataset = FakeData(size=1000, num_classes=NUM_CLASSES, transform=preproc)
+
+img, label = dataset[0]
+print(f"{type(img) = }, {img.dtype = }, {img.shape = }, {label = }")
+
+# %%
+#
+# One important thing to note is that neither CutMix nor MixUp are part of this
+# pre-processing pipeline. We'll add them a bit later once we define the
+# DataLoader. Just as a refresher, this is what the DataLoader and training loop
+# would look like if we weren't using CutMix or MixUp:
+
+from torch.utils.data import DataLoader
+
+dataloader = DataLoader(dataset, batch_size=4, shuffle=True)
+
+for images, labels in dataloader:
+    print(f"{images.shape = }, {labels.shape = }")
+    print(labels.dtype)
+    # <rest of the training loop here>
+    break
+# %%
+
+# %%
+# Where to use MixUp and CutMix
+# -----------------------------
+#
+# After the DataLoader
+# ^^^^^^^^^^^^^^^^^^^^
+#
+# Now let's add CutMix and MixUp. The simplest way to do this right after the
+# DataLoader: the Dataloader has already batched the images and labels for us,
+# and this is exactly what these transforms expect as input:
+
+dataloader = DataLoader(dataset, batch_size=4, shuffle=True)
+
+cutmix = v2.Cutmix(num_classes=NUM_CLASSES)
+mixup = v2.Mixup(num_classes=NUM_CLASSES)
+cutmix_or_mixup = v2.RandomChoice([cutmix, mixup])
+
+for images, labels in dataloader:
+    print(f"Before CutMix/MixUp: {images.shape = }, {labels.shape = }")
+    images, labels = cutmix_or_mixup(images, labels)
+    print(f"After CutMix/MixUp: {images.shape = }, {labels.shape = }")
+
+    # <rest of the training loop here>
+    break
+# %%
+#
+# Note how the labels were also transformed: we went from a batched label of
+# shape (batch_size,) to a tensor of shape (batch_size, num_classes). The
+# transformed labels can still be passed as-is to a loss function like
+# :func:`torch.nn.functional.cross_entropy`.
+#
+# As part of the collation function
+# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+#
+# Passing the transforms after the DataLoader is the simplest way to use CutMix
+# and MixUp, but one disadvantage is that it does not take advantage of the
+# DataLoader multi-processing. For that, we can pass those transforms as part of
+# the collation function (refer to the `PyTorch docs
+# <https://pytorch.org/docs/stable/data.html#dataloader-collate-fn>`_ to learn
+# more about collation).
+
+from torch.utils.data import default_collate
+
+
+def collate_fn(batch):
+    return cutmix_or_mixup(*default_collate(batch))
+
+
+dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=2, collate_fn=collate_fn)
+
+for images, labels in dataloader:
+    print(f"{images.shape = }, {labels.shape = }")
+    # No need to call cutmix_or_mixup, it's already been called as part of the DataLoader!
+    # <rest of the training loop here>
+    break
+
+# %%
+# Non-standard input format
+# -------------------------
+#
+# So far we've used a typical sample structure where we pass ``(images,
+# labels)`` as inputs. MixUp and CutMix will magically work by default with most
+# common sample structures: tuples where the second parameter is a tensor label,
+# or dict with a "label[s]" key. Look at the documentation of the
+# ``labels_getter`` parameter for more details.
+#
+# If your samples have a different structure, you can still use CutMix and MixUp
+# by passing a callable to the ``labels_getter`` parameter. For example:
+
+batch = {
+    "imgs": torch.rand(4, 3, 224, 224),
+    "target": {
+        "classes": torch.randint(0, NUM_CLASSES, size=(4,)),
+        "some_other_key": "this is going to be passed-through"
+    }
+}
+
+
+def labels_getter(batch):
+    return batch["target"]["classes"]
+
+
+out = v2.Cutmix(num_classes=NUM_CLASSES, labels_getter=labels_getter)(batch)
+print(f"{out['imgs'].shape = }, {out['target']['classes'].shape = }")
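
Note on the gallery example: it states that the transformed labels, of shape (batch_size, num_classes), can be passed as-is to a cross-entropy loss. Here is a minimal, dependency-free sketch of why that works. It assumes labels are mixed as lam * one_hot(a) + (1 - lam) * one_hot(b), which is a common MixUp formulation and not necessarily torchvision's exact code. Cross-entropy is linear in the target, so the loss on the soft label equals the same weighted sum of the two hard-label losses. All helper names below are hypothetical.

```python
import math
import random


def one_hot(label, num_classes):
    """Hard label as a probability vector."""
    return [1.0 if i == label else 0.0 for i in range(num_classes)]


def log_softmax(logits):
    """Numerically stable log-softmax over a list of floats."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def cross_entropy(logits, target_probs):
    """Cross-entropy against a (possibly soft) target distribution."""
    return -sum(p * lp for p, lp in zip(target_probs, log_softmax(logits)))


def mix_labels(label_a, label_b, lam, num_classes):
    """Assumed mixing rule: lam * one_hot(a) + (1 - lam) * one_hot(b)."""
    a = one_hot(label_a, num_classes)
    b = one_hot(label_b, num_classes)
    return [lam * x + (1.0 - lam) * y for x, y in zip(a, b)]


rng = random.Random(0)
NUM_CLASSES = 5
alpha = 1.0
lam = rng.betavariate(alpha, alpha)  # mixing coefficient drawn from Beta(alpha, alpha)

logits = [rng.uniform(-2.0, 2.0) for _ in range(NUM_CLASSES)]
soft = mix_labels(1, 3, lam, NUM_CLASSES)  # one row of a (batch_size, num_classes) target

loss_soft = cross_entropy(logits, soft)
loss_split = (
    lam * cross_entropy(logits, one_hot(1, NUM_CLASSES))
    + (1.0 - lam) * cross_entropy(logits, one_hot(3, NUM_CLASSES))
)
print(f"{lam = :.3f}, {loss_soft = :.4f}, {loss_split = :.4f}")
```

The two printed losses agree up to floating-point rounding, so a loss function that accepts class probabilities as targets needs no special handling for mixed batches.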

test/test_transforms_v2_refactored.py

Lines changed: 1 addition & 1 deletion
@@ -1922,7 +1922,7 @@ def test_supported_input_structure(self, T):

         dataset = self.DummyDataset(size=batch_size, num_classes=num_classes)

-        cutmix_mixup = T(alpha=0.5, num_classes=num_classes)
+        cutmix_mixup = T(num_classes=num_classes)

         dl = DataLoader(dataset, batch_size=batch_size)


torchvision/transforms/v2/_augment.py

Lines changed: 20 additions & 6 deletions
@@ -141,9 +141,9 @@ def _transform(


 class _BaseMixupCutmix(Transform):
-    def __init__(self, *, alpha: float = 1, num_classes: int, labels_getter="default") -> None:
+    def __init__(self, *, alpha: float = 1.0, num_classes: int, labels_getter="default") -> None:
         super().__init__()
-        self.alpha = alpha
+        self.alpha = float(alpha)
         self._dist = torch.distributions.Beta(torch.tensor([alpha]), torch.tensor([alpha]))

         self.num_classes = num_classes
@@ -204,13 +204,20 @@ def _mixup_label(self, label: torch.Tensor, *, lam: float) -> torch.Tensor:


 class Mixup(_BaseMixupCutmix):
-    """[BETA] Apply Mixup to the provided batch of images and labels.
+    """[BETA] Apply MixUp to the provided batch of images and labels.

     .. v2betastatus:: Mixup transform

     Paper: `mixup: Beyond Empirical Risk Minimization <https://arxiv.org/abs/1710.09412>`_.

-    See :ref:`sphx_glr_auto_examples_plot_cutmix_mixup.py` for detailed usage examples.
+    .. note::
+        This transform is meant to be used on **batches** of samples, not
+        individual images. See
+        :ref:`sphx_glr_auto_examples_plot_cutmix_mixup.py` for detailed usage
+        examples.
+        The sample pairing is deterministic and done by matching consecutive
+        samples in the batch, so the batch needs to be shuffled (this is an
+        implementation detail, not a guaranteed convention.)

     In the input, the labels are expected to be a tensor of shape ``(batch_size,)``. They will be transformed
     into a tensor of shape ``(batch_size, num_classes)``.
@@ -246,14 +253,21 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:


 class Cutmix(_BaseMixupCutmix):
-    """[BETA] Apply Cutmix to the provided batch of images and labels.
+    """[BETA] Apply CutMix to the provided batch of images and labels.

     .. v2betastatus:: Cutmix transform

     Paper: `CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features
     <https://arxiv.org/abs/1905.04899>`_.

-    See :ref:`sphx_glr_auto_examples_plot_cutmix_mixup.py` for detailed usage examples.
+    .. note::
+        This transform is meant to be used on **batches** of samples, not
+        individual images. See
+        :ref:`sphx_glr_auto_examples_plot_cutmix_mixup.py` for detailed usage
+        examples.
+        The sample pairing is deterministic and done by matching consecutive
+        samples in the batch, so the batch needs to be shuffled (this is an
+        implementation detail, not a guaranteed convention.)

     In the input, the labels are expected to be a tensor of shape ``(batch_size,)``. They will be transformed
     into a tensor of shape ``(batch_size, num_classes)``.
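
Note on the docstring change: the new note says the sample pairing is deterministic and matches consecutive samples in the batch, which is why the batch needs to be shuffled. The sketch below illustrates the consequence in plain Python. It assumes the pairing is a shift by one position, which is one way to match consecutive samples and may differ from torchvision's actual code. The helper names are hypothetical.

```python
import random


def roll_by_one(items):
    """Shift a list by one position, so item i is paired with item i - 1."""
    return items[-1:] + items[:-1]


def count_cross_class_pairs(labels):
    """Number of samples whose partner (under the assumed shift) has a different label."""
    return sum(a != b for a, b in zip(labels, roll_by_one(labels)))


# A batch sorted by class, as an unshuffled dataset might produce.
sorted_labels = [0, 0, 0, 0, 1, 1, 1, 1]

# The same labels in shuffled order (seeded for reproducibility).
shuffled_labels = sorted_labels[:]
random.Random(0).shuffle(shuffled_labels)

# Sorted batch: only the two samples at the class boundaries get a different-class partner.
print(count_cross_class_pairs(sorted_labels))  # 2
# Shuffled batch: never fewer than the sorted case for this label set, and usually more.
print(count_cross_class_pairs(shuffled_labels))
```

When a sample is mixed with a partner of the same class, its label stays effectively hard, so an unshuffled, class-sorted batch gets little of the label-mixing effect.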
