From 0d29db341ca17005bc4fcafa3c9b7da2fd1f1672 Mon Sep 17 00:00:00 2001
From: sallysyw
Date: Tue, 7 Dec 2021 22:16:17 +0000
Subject: [PATCH 1/5] Adding repeated data-augmentation sampler

---
 references/classification/ra_sampler.py | 58 +++++++++++++++++++++++++
 references/classification/train.py      | 23 ++++++----
 2 files changed, 72 insertions(+), 9 deletions(-)
 create mode 100644 references/classification/ra_sampler.py

diff --git a/references/classification/ra_sampler.py b/references/classification/ra_sampler.py
new file mode 100644
index 00000000000..0d54e50a5cb
--- /dev/null
+++ b/references/classification/ra_sampler.py
@@ -0,0 +1,58 @@
+import math
+
+import torch
+import torch.distributed as dist
+
+
+class RASampler(torch.utils.data.Sampler):
+    """Sampler that restricts data loading to a subset of the dataset for distributed training,
+    with repeated augmentation.
+    It ensures that each augmented version of a sample will be visible to a
+    different process (GPU)
+    Heavily based on torch.utils.data.DistributedSampler
+    """
+
+    def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True):
+        if num_replicas is None:
+            if not dist.is_available():
+                raise RuntimeError("Requires distributed package to be available")
+            num_replicas = dist.get_world_size()
+        if rank is None:
+            if not dist.is_available():
+                raise RuntimeError("Requires distributed package to be available")
+            rank = dist.get_rank()
+        self.dataset = dataset
+        self.num_replicas = num_replicas
+        self.rank = rank
+        self.epoch = 0
+        self.num_samples = int(math.ceil(len(self.dataset) * 3.0 / self.num_replicas))
+        self.total_size = self.num_samples * self.num_replicas
+        # self.num_selected_samples = int(math.ceil(len(self.dataset) / self.num_replicas))
+        self.num_selected_samples = int(math.floor(len(self.dataset) // 256 * 256 / self.num_replicas))
+        self.shuffle = shuffle
+
+    def __iter__(self):
+        # deterministically shuffle based on epoch
+        g = torch.Generator()
+        g.manual_seed(self.epoch)
+        if self.shuffle:
+            indices = torch.randperm(len(self.dataset), generator=g).tolist()
+        else:
+            indices = list(range(len(self.dataset)))
+
+        # add extra samples to make it evenly divisible
+        indices = [ele for ele in indices for i in range(3)]
+        indices += indices[: (self.total_size - len(indices))]
+        assert len(indices) == self.total_size
+
+        # subsample
+        indices = indices[self.rank : self.total_size : self.num_replicas]
+        assert len(indices) == self.num_samples
+
+        return iter(indices[: self.num_selected_samples])
+
+    def __len__(self):
+        return self.num_selected_samples
+
+    def set_epoch(self, epoch):
+        self.epoch = epoch
diff --git a/references/classification/train.py b/references/classification/train.py
index b2c6844df9b..3950dd756c2 100644
--- a/references/classification/train.py
+++ b/references/classification/train.py
@@ -9,11 +9,11 @@
 import torchvision
 import transforms
 import utils
+from references.classification.ra_sampler import RASampler
 from torch import nn
 from torch.utils.data.dataloader import default_collate
 from torchvision.transforms.functional import InterpolationMode
 
-
 try:
     from torchvision.prototype import models as PM
 except ImportError:
@@ -30,12 +30,12 @@ def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, arg
     for i, (image, target) in enumerate(metric_logger.log_every(data_loader, args.print_freq, header)):
         start_time = time.time()
         image, target = image.to(device), target.to(device)
-        with torch.cuda.amp.autocast(enabled=scaler is not None):
+        with torch.cuda.amp.autocast(enabled=args.amp):
             output = model(image)
             loss = criterion(output, target)
 
         optimizer.zero_grad()
-        if scaler is not None:
+        if args.amp:
             scaler.scale(loss).backward()
             if args.clip_grad_norm is not None:
                 # we should unscale the gradients of optimizer's assigned params if we do gradient clipping
@@ -158,7 +158,8 @@ def load_data(traindir, valdir, args):
             crop_size=val_crop_size, resize_size=val_resize_size, interpolation=interpolation
         )
     else:
-        weights = PM.get_weight(args.weights)
+        fn = PM.quantization.__dict__[args.model] if hasattr(args, "backend") else PM.__dict__[args.model]
+        weights = PM._api.get_weight(fn, args.weights)
         preprocessing = weights.transforms()
 
     dataset_test = torchvision.datasets.ImageFolder(
@@ -172,7 +173,10 @@ def load_data(traindir, valdir, args):
 
     print("Creating data loaders")
     if args.distributed:
-        train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
+        if args.ra_sampler:
+            train_sampler = RASampler(dataset, shuffle=True)
+        else:
+            train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
         test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test, shuffle=False)
     else:
         train_sampler = torch.utils.data.RandomSampler(dataset)
@@ -225,10 +229,10 @@ def main(args):
     )
 
     print("Creating model")
-    if not args.weights:
-        model = torchvision.models.__dict__[args.model](pretrained=args.pretrained, num_classes=num_classes)
-    else:
-        model = PM.__dict__[args.model](weights=args.weights, num_classes=num_classes)
+    # if not args.weights:
+    #     model = torchvision.models.__dict__[args.model](pretrained=args.pretrained, num_classes=num_classes)
+    # else:
+    model = PM.__dict__[args.model](weights=args.weights, num_classes=num_classes)
     model.to(device)
 
     if args.distributed and args.sync_bn:
@@ -484,6 +488,7 @@ def get_args_parser(add_help=True):
 
     # Prototype models only
     parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")
+    parser.add_argument("--ra-sampler", action="store_true", help="whether to use ra_sampler in training")
 
     return parser
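
Note on the index math above: each shuffled index is repeated three times, the
list is padded to a multiple of the world size, and every rank then takes a
stride-`num_replicas` slice, so consecutive copies of the same sample land on
different GPUs. A small self-contained sketch of that logic (the dataset size,
world size, and seed are illustrative stand-ins, not part of the patch):

    import torch

    dataset_len, world_size = 10, 4
    g = torch.Generator()
    g.manual_seed(0)  # mirrors set_epoch(0)

    indices = torch.randperm(dataset_len, generator=g).tolist()
    indices = [ele for ele in indices for _ in range(3)]  # repeat each sample 3x
    num_samples = -(-dataset_len * 3 // world_size)  # ceil division, as in __init__
    total_size = num_samples * world_size
    indices += indices[: total_size - len(indices)]  # pad to a multiple of world_size

    for rank in range(world_size):
        print(f"rank {rank}: {indices[rank:total_size:world_size]}")
    # Positions 0, 1, 2 hold three copies of the same sample and go to ranks
    # 0, 1, 2, so each GPU sees a different augmented version in the same epoch.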

From b9545970710dd4c0a7a3ffe4670ab49b7ad1ce9b Mon Sep 17 00:00:00 2001
From: sallysyw
Date: Tue, 7 Dec 2021 22:32:11 +0000
Subject: [PATCH 2/5] rebase on top of latest main

---
 references/classification/train.py | 17 ++++++++---------
 1 file changed, 8 insertions(+), 9 deletions(-)

diff --git a/references/classification/train.py b/references/classification/train.py
index 3950dd756c2..065ae51a94a 100644
--- a/references/classification/train.py
+++ b/references/classification/train.py
@@ -30,12 +30,12 @@ def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, arg
     for i, (image, target) in enumerate(metric_logger.log_every(data_loader, args.print_freq, header)):
         start_time = time.time()
         image, target = image.to(device), target.to(device)
-        with torch.cuda.amp.autocast(enabled=args.amp):
+        with torch.cuda.amp.autocast(enabled=scaler is not None):
             output = model(image)
             loss = criterion(output, target)
 
         optimizer.zero_grad()
-        if args.amp:
+        if scaler is not None:
             scaler.scale(loss).backward()
             if args.clip_grad_norm is not None:
                 # we should unscale the gradients of optimizer's assigned params if we do gradient clipping
@@ -158,8 +158,7 @@ def load_data(traindir, valdir, args):
             crop_size=val_crop_size, resize_size=val_resize_size, interpolation=interpolation
        )
     else:
-        fn = PM.quantization.__dict__[args.model] if hasattr(args, "backend") else PM.__dict__[args.model]
-        weights = PM._api.get_weight(fn, args.weights)
+        weights = PM.get_weight(args.weights)
         preprocessing = weights.transforms()
 
     dataset_test = torchvision.datasets.ImageFolder(
@@ -229,10 +228,10 @@ def main(args):
     )
 
     print("Creating model")
-    # if not args.weights:
-    #     model = torchvision.models.__dict__[args.model](pretrained=args.pretrained, num_classes=num_classes)
-    # else:
-    model = PM.__dict__[args.model](weights=args.weights, num_classes=num_classes)
+    if not args.weights:
+        model = torchvision.models.__dict__[args.model](pretrained=args.pretrained, num_classes=num_classes)
+    else:
+        model = PM.__dict__[args.model](weights=args.weights, num_classes=num_classes)
     model.to(device)
 
     if args.distributed and args.sync_bn:
@@ -485,10 +484,10 @@ def get_args_parser(add_help=True):
         "--train-crop-size", default=224, type=int, help="the random crop size used for training (default: 224)"
     )
     parser.add_argument("--clip-grad-norm", default=None, type=float, help="the maximum gradient norm (default None)")
+    parser.add_argument("--ra-sampler", action="store_true", help="whether to use ra_sampler in training")
 
     # Prototype models only
     parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")
-    parser.add_argument("--ra-sampler", action="store_true", help="whether to use ra_sampler in training")
 
     return parser
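
The rebase restores the convention from main where `scaler is not None`
(rather than `args.amp`) gates both autocast and the scaled backward pass. A
minimal sketch of that pattern with a toy model; `use_amp`, `max_norm`, and
the toy tensors are illustrative stand-ins, not code from this patch:

    import torch

    use_amp = torch.cuda.is_available()
    device = "cuda" if use_amp else "cpu"
    model = torch.nn.Linear(8, 2).to(device)
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    scaler = torch.cuda.amp.GradScaler() if use_amp else None
    max_norm = 1.0

    image = torch.randn(4, 8, device=device)
    target = torch.randint(0, 2, (4,), device=device)

    with torch.cuda.amp.autocast(enabled=scaler is not None):
        output = model(image)
        loss = criterion(output, target)

    optimizer.zero_grad()
    if scaler is not None:
        scaler.scale(loss).backward()
        # Unscale before clipping, as the comment in the hunk above notes
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
        scaler.step(optimizer)
        scaler.update()
    else:
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
        optimizer.step()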

From 1803440055eef3fdefc20f52782da10c8f867a33 Mon Sep 17 00:00:00 2001
From: sallysyw
Date: Tue, 7 Dec 2021 22:45:22 +0000
Subject: [PATCH 3/5] fix formatting

---
 references/classification/ra_sampler.py | 5 ++---
 references/classification/train.py      | 1 +
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/references/classification/ra_sampler.py b/references/classification/ra_sampler.py
index 0d54e50a5cb..8695cf4ac3c 100644
--- a/references/classification/ra_sampler.py
+++ b/references/classification/ra_sampler.py
@@ -27,7 +27,6 @@ def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True):
         self.epoch = 0
         self.num_samples = int(math.ceil(len(self.dataset) * 3.0 / self.num_replicas))
         self.total_size = self.num_samples * self.num_replicas
-        # self.num_selected_samples = int(math.ceil(len(self.dataset) / self.num_replicas))
         self.num_selected_samples = int(math.floor(len(self.dataset) // 256 * 256 / self.num_replicas))
         self.shuffle = shuffle
 
@@ -40,12 +39,12 @@ def __iter__(self):
         else:
             indices = list(range(len(self.dataset)))
 
-        # add extra samples to make it evenly divisible
+        # Add extra samples to make it evenly divisible
         indices = [ele for ele in indices for i in range(3)]
         indices += indices[: (self.total_size - len(indices))]
         assert len(indices) == self.total_size
 
-        # subsample
+        # Subsample
         indices = indices[self.rank : self.total_size : self.num_replicas]
         assert len(indices) == self.num_samples
 
diff --git a/references/classification/train.py b/references/classification/train.py
index 065ae51a94a..470cd5e4c9a 100644
--- a/references/classification/train.py
+++ b/references/classification/train.py
@@ -14,6 +14,7 @@
 from torch.utils.data.dataloader import default_collate
 from torchvision.transforms.functional import InterpolationMode
 
+
 try:
     from torchvision.prototype import models as PM
 except ImportError:

From 87ed084b0e5e5d6a9012265fdef8e0b2cbaed8ce Mon Sep 17 00:00:00 2001
From: sallysyw
Date: Tue, 7 Dec 2021 22:48:32 +0000
Subject: [PATCH 4/5] rename file

---
 .../classification/{ra_sampler.py => sampler.py} | 10 +++++-----
 references/classification/train.py               |  2 +-
 2 files changed, 6 insertions(+), 6 deletions(-)
 rename references/classification/{ra_sampler.py => sampler.py} (91%)

diff --git 
a/references/classification/ra_sampler.py b/references/classification/sampler.py
similarity index 91%
rename from references/classification/ra_sampler.py
rename to references/classification/sampler.py
index 8695cf4ac3c..d64a50c6366 100644
--- a/references/classification/ra_sampler.py
+++ b/references/classification/sampler.py
@@ -8,18 +8,18 @@ class RASampler(torch.utils.data.Sampler):
     """Sampler that restricts data loading to a subset of the dataset for distributed training,
     with repeated augmentation.
     It ensures that each augmented version of a sample will be visible to a
-    different process (GPU)
-    Heavily based on torch.utils.data.DistributedSampler
+    different process (GPU).
+    Heavily based on 'torch.utils.data.DistributedSampler'.
     """
 
     def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True):
         if num_replicas is None:
             if not dist.is_available():
-                raise RuntimeError("Requires distributed package to be available")
+                raise RuntimeError("Requires distributed package to be available!")
             num_replicas = dist.get_world_size()
         if rank is None:
             if not dist.is_available():
-                raise RuntimeError("Requires distributed package to be available")
+                raise RuntimeError("Requires distributed package to be available!")
             rank = dist.get_rank()
         self.dataset = dataset
         self.num_replicas = num_replicas
@@ -31,7 +31,7 @@ def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True):
         self.shuffle = shuffle
 
     def __iter__(self):
-        # deterministically shuffle based on epoch
+        # Deterministically shuffle based on epoch
         g = torch.Generator()
         g.manual_seed(self.epoch)
         if self.shuffle:
diff --git a/references/classification/train.py b/references/classification/train.py
index 470cd5e4c9a..689735d8717 100644
--- a/references/classification/train.py
+++ b/references/classification/train.py
@@ -9,7 +9,7 @@
 import torchvision
 import transforms
 import utils
-from references.classification.ra_sampler import RASampler
+from references.classification.sampler import RASampler
 from torch import nn
 from torch.utils.data.dataloader import default_collate
 from torchvision.transforms.functional import InterpolationMode

From 7e5fffc9aff2cacc38cc4eed19fe84b593854a30 Mon Sep 17 00:00:00 2001
From: sallysyw
Date: Wed, 8 Dec 2021 00:35:55 +0000
Subject: [PATCH 5/5] adding code source

---
 references/classification/sampler.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/references/classification/sampler.py b/references/classification/sampler.py
index d64a50c6366..cfe95dd085a 100644
--- a/references/classification/sampler.py
+++ b/references/classification/sampler.py
@@ -10,6 +10,9 @@ class RASampler(torch.utils.data.Sampler):
     It ensures that each augmented version of a sample will be visible to a
     different process (GPU).
     Heavily based on 'torch.utils.data.DistributedSampler'.
+
+    This is borrowed from the DeiT Repo:
+    https://github.com/facebookresearch/deit/blob/main/samplers.py
     """
 
     def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True):
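
Taken together, the series adds a `--ra-sampler` flag to the classification
reference script and a `sampler.py` module housing `RASampler`. A usage sketch
outside torchrun follows; the toy dataset and the explicit `num_replicas`/
`rank` arguments are illustrative only, since in the reference script the
sampler is built inside load_data when --ra-sampler is passed in distributed
mode:

    import torch
    from sampler import RASampler  # references/classification/sampler.py

    # Stand-in dataset; note num_selected_samples floors to a multiple of 256,
    # so a dataset smaller than 256 samples would yield an empty epoch here.
    class ToyDataset(torch.utils.data.Dataset):
        def __len__(self):
            return 512

        def __getitem__(self, idx):
            return torch.randn(3), idx

    dataset = ToyDataset()
    sampler = RASampler(dataset, num_replicas=2, rank=0, shuffle=True)
    loader = torch.utils.data.DataLoader(dataset, batch_size=64, sampler=sampler)

    for epoch in range(2):
        sampler.set_epoch(epoch)  # re-seeds the deterministic shuffle
        seen = [i for _, idx in loader for i in idx.tolist()]
        print(f"epoch {epoch}: rank 0 draws {len(seen)} of {len(dataset)} indices")

Despite the 3x repetition, `__len__` keeps the per-rank epoch length at roughly
len(dataset) / num_replicas, so training sees the usual number of batches per
epoch while adjacent ranks receive differently augmented copies of the same
images.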