|
1 | 1 |
|
2 | 2 | """
|
3 | 3 | ===========================
|
4 |
| -How to use Cutmix and Mixup |
| 4 | +How to use CutMix and MixUp |
5 | 5 | ===========================
|
6 | 6 |
|
7 |
| -TODO |
| 7 | +:class:`~torchvision.transforms.v2.Cutmix` and |
| 8 | +:class:`~torchvision.transforms.v2.Mixup` are popular augmentation strategies |
| 9 | +that can improve classification accuracy. |
| 10 | +
|
| 11 | +These transforms are slightly different from the rest of the Torchvision |
| 12 | +transforms, because they expect |
| 13 | +**batches** of samples as input, not individual images. In this example we'll |
| 14 | +explain how to use them: after the ``DataLoader``, or as part of a collation |
| 15 | +function. |
8 | 16 | """
|
| 17 | + |
| 18 | +# %% |
| 19 | +import torch |
| 20 | +import torchvision |
| 21 | +from torchvision.datasets import FakeData |
| 22 | + |
| 23 | +# We are using BETA APIs, so we deactivate the associated warning, thereby acknowledging that |
| 24 | +# some APIs may slightly change in the future |
| 25 | +torchvision.disable_beta_transforms_warning() |
| 26 | + |
| 27 | +from torchvision.transforms import v2 |
| 28 | + |
| 29 | + |
| 30 | +NUM_CLASSES = 100 |
| 31 | + |
| 32 | +# %% |
| 33 | +# Pre-processing pipeline |
| 34 | +# ----------------------- |
| 35 | +# |
| 36 | +# We'll use a simple but typical image classification pipeline: |
| 37 | + |
| 38 | +preproc = v2.Compose([ |
| 39 | + v2.PILToTensor(), |
| 40 | + v2.RandomResizedCrop(size=(224, 224), antialias=True), |
| 41 | + v2.RandomHorizontalFlip(p=0.5), |
| 42 | + v2.ToDtype(torch.float32, scale=True), # to float32 in [0, 1] |
| 43 | + v2.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), # typically from ImageNet |
| 44 | +]) |
| 45 | + |
| 46 | +dataset = FakeData(size=1000, num_classes=NUM_CLASSES, transform=preproc) |
| 47 | + |
| 48 | +img, label = dataset[0] |
| 49 | +print(f"{type(img) = }, {img.dtype = }, {img.shape = }, {label = }") |
| 50 | + |
| 51 | +# %% |
| 52 | +# |
| 53 | +# One important thing to note is that neither CutMix nor MixUp are part of this |
| 54 | +# pre-processing pipeline. We'll add them a bit later once we define the |
| 55 | +# DataLoader. Just as a refresher, this is what the DataLoader and training loop |
| 56 | +# would look like if we weren't using CutMix or MixUp: |
| 57 | + |
| 58 | +from torch.utils.data import DataLoader |
| 59 | + |
| 60 | +dataloader = DataLoader(dataset, batch_size=4, shuffle=True) |
| 61 | + |
| 62 | +for images, labels in dataloader: |
| 63 | + print(f"{images.shape = }, {labels.shape = }") |
| 64 | + print(labels.dtype) |
| 65 | + # <rest of the training loop here> |
| 66 | + break |
| 67 | +# %% |
| 68 | + |
| 69 | +# %% |
| 70 | +# Where to use MixUp and CutMix |
| 71 | +# ----------------------------- |
| 72 | +# |
| 73 | +# After the DataLoader |
| 74 | +# ^^^^^^^^^^^^^^^^^^^^ |
| 75 | +# |
| 76 | +# Now let's add CutMix and MixUp. The simplest way to do this right after the |
| 77 | +# DataLoader: the Dataloader has already batched the images and labels for us, |
| 78 | +# and this is exactly what these transforms expect as input: |
| 79 | + |
| 80 | +dataloader = DataLoader(dataset, batch_size=4, shuffle=True) |
| 81 | + |
| 82 | +cutmix = v2.Cutmix(num_classes=NUM_CLASSES) |
| 83 | +mixup = v2.Mixup(num_classes=NUM_CLASSES) |
| 84 | +cutmix_or_mixup = v2.RandomChoice([cutmix, mixup]) |
| 85 | + |
| 86 | +for images, labels in dataloader: |
| 87 | + print(f"Before CutMix/MixUp: {images.shape = }, {labels.shape = }") |
| 88 | + images, labels = cutmix_or_mixup(images, labels) |
| 89 | + print(f"After CutMix/MixUp: {images.shape = }, {labels.shape = }") |
| 90 | + |
| 91 | + # <rest of the training loop here> |
| 92 | + break |
| 93 | +# %% |
| 94 | +# |
| 95 | +# Note how the labels were also transformed: we went from a batched label of |
| 96 | +# shape (batch_size,) to a tensor of shape (batch_size, num_classes). The |
| 97 | +# transformed labels can still be passed as-is to a loss function like |
| 98 | +# :func:`torch.nn.functional.cross_entropy`. |
| 99 | +# |
| 100 | +# As part of the collation function |
| 101 | +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ |
| 102 | +# |
| 103 | +# Passing the transforms after the DataLoader is the simplest way to use CutMix |
| 104 | +# and MixUp, but one disadvantage is that it does not take advantage of the |
| 105 | +# DataLoader multi-processing. For that, we can pass those transforms as part of |
| 106 | +# the collation function (refer to the `PyTorch docs |
| 107 | +# <https://pytorch.org/docs/stable/data.html#dataloader-collate-fn>`_ to learn |
| 108 | +# more about collation). |
| 109 | + |
| 110 | +from torch.utils.data import default_collate |
| 111 | + |
| 112 | + |
| 113 | +def collate_fn(batch): |
| 114 | + return cutmix_or_mixup(*default_collate(batch)) |
| 115 | + |
| 116 | + |
| 117 | +dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=2, collate_fn=collate_fn) |
| 118 | + |
| 119 | +for images, labels in dataloader: |
| 120 | + print(f"{images.shape = }, {labels.shape = }") |
| 121 | + # No need to call cutmix_or_mixup, it's already been called as part of the DataLoader! |
| 122 | + # <rest of the training loop here> |
| 123 | + break |
| 124 | + |
| 125 | +# %% |
| 126 | +# Non-standard input format |
| 127 | +# ------------------------- |
| 128 | +# |
| 129 | +# So far we've used a typical sample structure where we pass ``(images, |
| 130 | +# labels)`` as inputs. MixUp and CutMix will magically work by default with most |
| 131 | +# common sample structures: tuples where the second parameter is a tensor label, |
| 132 | +# or dict with a "label[s]" key. Look at the documentation of the |
| 133 | +# ``labels_getter`` parameter for more details. |
| 134 | +# |
| 135 | +# If your samples have a different structure, you can still use CutMix and MixUp |
| 136 | +# by passing a callable to the ``labels_getter`` parameter. For example: |
| 137 | + |
| 138 | +batch = { |
| 139 | + "imgs": torch.rand(4, 3, 224, 224), |
| 140 | + "target": { |
| 141 | + "classes": torch.randint(0, NUM_CLASSES, size=(4,)), |
| 142 | + "some_other_key": "this is going to be passed-through" |
| 143 | + } |
| 144 | +} |
| 145 | + |
| 146 | + |
| 147 | +def labels_getter(batch): |
| 148 | + return batch["target"]["classes"] |
| 149 | + |
| 150 | + |
| 151 | +out = v2.Cutmix(num_classes=NUM_CLASSES, labels_getter=labels_getter)(batch) |
| 152 | +print(f"{out['imgs'].shape = }, {out['target']['classes'].shape = }") |
0 commit comments