diff --git a/references/classification/train.py b/references/classification/train.py new file mode 100644 index 00000000000..a5d10d3a2d4 --- /dev/null +++ b/references/classification/train.py @@ -0,0 +1,216 @@ +import datetime +import os +import time + +import torch +import torch.utils.data +from torch import nn +import torchvision +from torchvision import transforms + +import utils + + +def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, print_freq): + model.train() + metric_logger = utils.MetricLogger(delimiter=" ") + metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value}')) + header = 'Epoch: [{}]'.format(epoch) + for image, target in metric_logger.log_every(data_loader, print_freq, header): + image, target = image.to(device), target.to(device) + output = model(image) + loss = criterion(output, target) + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + acc1, acc5 = utils.accuracy(output, target, topk=(1, 5)) + batch_size = image.shape[0] + metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"]) + metric_logger.meters['acc1'].update(acc1.item(), n=batch_size) + metric_logger.meters['acc5'].update(acc5.item(), n=batch_size) + + +def evaluate(model, criterion, data_loader, device): + model.eval() + metric_logger = utils.MetricLogger(delimiter=" ") + header = 'Test:' + with torch.no_grad(): + for image, target in metric_logger.log_every(data_loader, 100, header): + image = image.to(device, non_blocking=True) + target = target.to(device, non_blocking=True) + output = model(image) + loss = criterion(output, target) + + acc1, acc5 = utils.accuracy(output, target, topk=(1, 5)) + # FIXME need to take into account that the datasets + # could have been padded in distributed setup + batch_size = image.shape[0] + metric_logger.update(loss=loss.item()) + metric_logger.meters['acc1'].update(acc1.item(), n=batch_size) + metric_logger.meters['acc5'].update(acc5.item(), n=batch_size) + # gather the stats from all processes + metric_logger.synchronize_between_processes() + + print(' * Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f}' + .format(top1=metric_logger.acc1, top5=metric_logger.acc5)) + return metric_logger.acc1.global_avg + + +def main(args): + args.gpu = args.local_rank + + if args.distributed: + args.rank = int(os.environ["RANK"]) + torch.cuda.set_device(args.gpu) + args.dist_backend = 'nccl' + dist_url = 'env://' + print('| distributed init (rank {}): {}'.format( + args.rank, dist_url), flush=True) + torch.distributed.init_process_group(backend=args.dist_backend, init_method=dist_url) + utils.setup_for_distributed(args.rank == 0) + + device = torch.device(args.device) + + torch.backends.cudnn.benchmark = True + + # Data loading code + print("Loading data") + traindir = os.path.join(args.data_path, 'train') + valdir = os.path.join(args.data_path, 'val') + normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]) + + print("Loading training data") + st = time.time() + scale = (0.08, 1.0) + if args.model == 'mobilenet_v2': + scale = (0.2, 1.0) + dataset = torchvision.datasets.ImageFolder( + traindir, + transforms.Compose([ + transforms.RandomResizedCrop(224, scale=scale), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + normalize, + ])) + print("Took", time.time() - st) + + print("Loading validation data") + dataset_test = torchvision.datasets.ImageFolder( + valdir, + transforms.Compose([ + transforms.Resize(256), + transforms.CenterCrop(224), + transforms.ToTensor(), + normalize, + ])) + + print("Creating data loaders") + if args.distributed: + train_sampler = torch.utils.data.distributed.DistributedSampler(dataset) + test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test) + else: + train_sampler = torch.utils.data.RandomSampler(dataset) + test_sampler = torch.utils.data.SequentialSampler(dataset_test) + + data_loader = torch.utils.data.DataLoader( + dataset, batch_size=args.batch_size, + sampler=train_sampler, num_workers=args.workers, pin_memory=True) + + data_loader_test = torch.utils.data.DataLoader( + dataset_test, batch_size=args.batch_size, + sampler=test_sampler, num_workers=args.workers, pin_memory=True) + + print("Creating model") + model = torchvision.models.__dict__[args.model]() + model.to(device) + if args.distributed: + model = torch.nn.utils.convert_sync_batchnorm(model) + + model_without_ddp = model + if args.distributed: + model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) + model_without_ddp = model.module + + criterion = nn.CrossEntropyLoss() + + optimizer = torch.optim.SGD( + model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay) + + # if using mobilenet, step_size=2 and gamma=0.94 + lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma) + + if args.resume: + checkpoint = torch.load(args.resume, map_location='cpu') + model_without_ddp.load_state_dict(checkpoint['model']) + optimizer.load_state_dict(checkpoint['optimizer']) + lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) + + if args.test_only: + evaluate(model, criterion, data_loader_test, device=device) + return + + print("Start training") + start_time = time.time() + for epoch in range(args.epochs): + if args.distributed: + train_sampler.set_epoch(epoch) + lr_scheduler.step() + train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args.print_freq) + evaluate(model, criterion, data_loader_test, device=device) + if args.output_dir: + utils.save_on_master({ + 'model': model_without_ddp.state_dict(), + 'optimizer': optimizer.state_dict(), + 'lr_scheduler': lr_scheduler.state_dict(), + 'args': args}, + os.path.join(args.output_dir, 'model_{}.pth'.format(epoch))) + + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + print('Training time {}'.format(total_time_str)) + + +if __name__ == "__main__": + import argparse + parser = argparse.ArgumentParser(description='PyTorch Classification Training') + + parser.add_argument('--data-path', default='/datasets01/imagenet_full_size/061417/', help='dataset') + parser.add_argument('--model', default='resnet18', help='model') + parser.add_argument('--device', default='cuda', help='device') + parser.add_argument('-b', '--batch-size', default=32, type=int) + parser.add_argument('--epochs', default=90, type=int, metavar='N', + help='number of total epochs to run') + parser.add_argument('-j', '--workers', default=16, type=int, metavar='N', + help='number of data loading workers (default: 16)') + parser.add_argument('--lr', default=0.1, type=float, help='initial learning rate') + parser.add_argument('--momentum', default=0.9, type=float, metavar='M', + help='momentum') + parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float, + metavar='W', help='weight decay (default: 1e-4)', + dest='weight_decay') + parser.add_argument('--lr-step-size', default=30, type=int, help='decrease lr every step-size epochs') + parser.add_argument('--lr-gamma', default=0.1, type=float, help='decrease lr by a factor of lr-gamma') + parser.add_argument('--print-freq', default=10, type=int, help='print frequency') + parser.add_argument('--output-dir', default='.', help='path where to save') + parser.add_argument('--resume', default='', help='resume from checkpoint') + parser.add_argument( + "--test-only", + dest="test_only", + help="Only test the model", + action="store_true", + ) + parser.add_argument('--local_rank', default=0, type=int, help='print frequency') + + args = parser.parse_args() + print(args) + + if args.output_dir: + utils.mkdir(args.output_dir) + + num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1 + args.distributed = num_gpus > 1 + + main(args) diff --git a/references/classification/utils.py b/references/classification/utils.py new file mode 100644 index 00000000000..de1a313c3d2 --- /dev/null +++ b/references/classification/utils.py @@ -0,0 +1,212 @@ +from collections import defaultdict, deque +import datetime +import time +import torch +import torch.distributed as dist + +import errno +import os + + +class SmoothedValue(object): + """Track a series of values and provide access to smoothed values over a + window or the global series average. + """ + + def __init__(self, window_size=20, fmt=None): + if fmt is None: + fmt = "{median:.4f} ({global_avg:.4f})" + self.deque = deque(maxlen=window_size) + self.total = 0.0 + self.count = 0 + self.fmt = fmt + + def update(self, value, n=1): + self.deque.append(value) + self.count += n + self.total += value * n + + def synchronize_between_processes(self): + """ + Warning: does not synchronize the deque! + """ + if not is_dist_avail_and_initialized(): + return + t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') + dist.barrier() + dist.all_reduce(t) + t = t.tolist() + self.count = int(t[0]) + self.total = t[1] + + @property + def median(self): + d = torch.tensor(list(self.deque)) + return d.median().item() + + @property + def avg(self): + d = torch.tensor(list(self.deque), dtype=torch.float32) + return d.mean().item() + + @property + def global_avg(self): + return self.total / self.count + + @property + def max(self): + return max(self.deque) + + @property + def value(self): + return self.deque[-1] + + def __str__(self): + return self.fmt.format( + median=self.median, + avg=self.avg, + global_avg=self.global_avg, + max=self.max, + value=self.value) + + +class MetricLogger(object): + def __init__(self, delimiter="\t"): + self.meters = defaultdict(SmoothedValue) + self.delimiter = delimiter + + def update(self, **kwargs): + for k, v in kwargs.items(): + if isinstance(v, torch.Tensor): + v = v.item() + assert isinstance(v, (float, int)) + self.meters[k].update(v) + + def __getattr__(self, attr): + if attr in self.meters: + return self.meters[attr] + if attr in self.__dict__: + return self.__dict__[attr] + raise AttributeError("'{}' object has no attribute '{}'".format( + type(self).__name__, attr)) + + def __str__(self): + loss_str = [] + for name, meter in self.meters.items(): + loss_str.append( + "{}: {}".format(name, str(meter)) + ) + return self.delimiter.join(loss_str) + + def synchronize_between_processes(self): + for meter in self.meters.values(): + meter.synchronize_between_processes() + + def add_meter(self, name, meter): + self.meters[name] = meter + + def log_every(self, iterable, print_freq, header=None): + i = 0 + if not header: + header = '' + start_time = time.time() + end = time.time() + iter_time = SmoothedValue(fmt='{avg:.4f}') + data_time = SmoothedValue(fmt='{avg:.4f}') + space_fmt = ':' + str(len(str(len(iterable)))) + 'd' + log_msg = self.delimiter.join([ + header, + '[{0' + space_fmt + '}/{1}]', + 'eta: {eta}', + '{meters}', + 'time: {time}', + 'data: {data}', + 'max mem: {memory:.0f}' + ]) + MB = 1024.0 * 1024.0 + for obj in iterable: + data_time.update(time.time() - end) + yield obj + iter_time.update(time.time() - end) + if i % print_freq == 0: + eta_seconds = iter_time.global_avg * (len(iterable) - i) + eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) + print(log_msg.format( + i, len(iterable), eta=eta_string, + meters=str(self), + time=str(iter_time), data=str(data_time), + memory=torch.cuda.max_memory_allocated() / MB)) + i += 1 + end = time.time() + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + print('{} Total time: {}'.format(header, total_time_str)) + + +def accuracy(output, target, topk=(1,)): + """Computes the accuracy over the k top predictions for the specified values of k""" + with torch.no_grad(): + maxk = max(topk) + batch_size = target.size(0) + + _, pred = output.topk(maxk, 1, True, True) + pred = pred.t() + correct = pred.eq(target[None]) + + res = [] + for k in topk: + correct_k = correct[:k].flatten().sum(dtype=torch.float32) + res.append(correct_k * (100.0 / batch_size)) + return res + + +def mkdir(path): + try: + os.makedirs(path) + except OSError as e: + if e.errno != errno.EEXIST: + raise + + +def setup_for_distributed(is_master): + """ + This function disables printing when not in master process + """ + import builtins as __builtin__ + builtin_print = __builtin__.print + + def print(*args, **kwargs): + force = kwargs.pop('force', False) + if is_master or force: + builtin_print(*args, **kwargs) + + __builtin__.print = print + + +def is_dist_avail_and_initialized(): + if not dist.is_available(): + return False + if not dist.is_initialized(): + return False + return True + + +def get_world_size(): + if not is_dist_avail_and_initialized(): + return 1 + return dist.get_world_size() + + +def get_rank(): + if not is_dist_avail_and_initialized(): + return 0 + return dist.get_rank() + + +def is_main_process(): + return get_rank() == 0 + + +def save_on_master(*args, **kwargs): + if is_main_process(): + torch.save(*args, **kwargs)