diff --git a/imagenet/main.py b/imagenet/main.py
index 829567f488..9ca57ca3b3 100644
--- a/imagenet/main.py
+++ b/imagenet/main.py
@@ -3,8 +3,8 @@
 import random
 import shutil
 import time
-import warnings
 import sys
+import logging
 
 import torch
 import torch.nn as nn
@@ -20,8 +20,8 @@ import torchvision.models as models
 
 model_names = sorted(name for name in models.__dict__
-    if name.islower() and not name.startswith("__")
-    and callable(models.__dict__[name]))
+                     if name.islower() and not name.startswith("__")
+                     and callable(models.__dict__[name]))
 
 parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
 parser.add_argument('data', metavar='DIR',
@@ -29,8 +29,8 @@ parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',
                     choices=model_names,
                     help='model architecture: ' +
-                    ' | '.join(model_names) +
-                    ' (default: resnet18)')
+                         ' | '.join(model_names) +
+                         ' (default: resnet18)')
 parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                     help='number of data loading workers (default: 4)')
 parser.add_argument('--epochs', default=90, type=int, metavar='N',
@@ -74,6 +74,9 @@
                          'N processes per node, which has N GPUs. This is the '
                          'fastest way to use PyTorch for either single node or '
                          'multi node data parallel training')
+parser.add_argument('-l', '--log-file', default='imagenet.log', type=str,
+                    metavar='PATH',
+                    help='path to log file (default: imagenet.log)')
 
 best_acc1 = 0
 
@@ -81,19 +84,21 @@ def main():
     args = parser.parse_args()
 
+    logger = configure_logging(args)
+
     if args.seed is not None:
         random.seed(args.seed)
         torch.manual_seed(args.seed)
         cudnn.deterministic = True
-        warnings.warn('You have chosen to seed training. '
-                      'This will turn on the CUDNN deterministic setting, '
-                      'which can slow down your training considerably! '
-                      'You may see unexpected behavior when restarting '
-                      'from checkpoints.')
+        logger.warning('You have chosen to seed training. '
+                       'This will turn on the CUDNN deterministic setting, '
+                       'which can slow down your training considerably! '
+                       'You may see unexpected behavior when restarting '
+                       'from checkpoints.')
 
     if args.gpu is not None:
-        warnings.warn('You have chosen a specific GPU. This will completely '
-                      'disable data parallelism.')
+        logger.warning('You have chosen a specific GPU. This will completely '
+                       'disable data parallelism.')
 
     if args.dist_url == "env://" and args.world_size == -1:
         args.world_size = int(os.environ["WORLD_SIZE"])
@@ -114,11 +119,12 @@ def main():
 
 
 def main_worker(gpu, ngpus_per_node, args):
+    logger = logging.getLogger()
     global best_acc1
     args.gpu = gpu
 
     if args.gpu is not None:
-        print("Use GPU: {} for training".format(args.gpu))
+        logger.info("Use GPU: {} for training".format(args.gpu))
 
     if args.distributed:
         if args.dist_url == "env://" and args.rank == -1:
@@ -131,10 +137,10 @@ def main_worker(gpu, ngpus_per_node, args):
                                 world_size=args.world_size, rank=args.rank)
     # create model
     if args.pretrained:
-        print("=> using pre-trained model '{}'".format(args.arch))
+        logger.info("Using pre-trained model '{}'".format(args.arch))
         model = models.__dict__[args.arch](pretrained=True)
     else:
-        print("=> creating model '{}'".format(args.arch))
+        logger.info("Creating model '{}'".format(args.arch))
         model = models.__dict__[args.arch]()
 
     if args.distributed:
@@ -149,7 +155,8 @@ def main_worker(gpu, ngpus_per_node, args):
             # ourselves based on the total number of GPUs we have
             args.batch_size = int(args.batch_size / ngpus_per_node)
             args.workers = int(args.workers / ngpus_per_node)
-            model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
+            model = torch.nn.parallel.DistributedDataParallel(model,
+                                                              device_ids=[args.gpu])
         else:
             model.cuda()
             # DistributedDataParallel will divide and allocate batch_size to all
@@ -176,7 +183,7 @@ def main_worker(gpu, ngpus_per_node, args):
     # optionally resume from a checkpoint
     if args.resume:
         if os.path.isfile(args.resume):
-            print("=> loading checkpoint '{}'".format(args.resume))
+            logger.info("Loading checkpoint '{}' ...".format(args.resume))
             checkpoint = torch.load(args.resume)
             args.start_epoch = checkpoint['epoch']
             best_acc1 = checkpoint['best_acc1']
@@ -185,10 +192,11 @@ def main_worker(gpu, ngpus_per_node, args):
                 best_acc1 = best_acc1.to(args.gpu)
             model.load_state_dict(checkpoint['state_dict'])
             optimizer.load_state_dict(checkpoint['optimizer'])
-            print("=> loaded checkpoint '{}' (epoch {})"
-                  .format(args.resume, checkpoint['epoch']))
+            logger.info("Loaded checkpoint '{}' (epoch {})"
+                        .format(args.resume, args.start_epoch))
         else:
-            print("=> no checkpoint found at '{}'".format(args.resume))
+            logger.warning("No checkpoint found at '{}'. "
+                           "Continuing without one.".format(args.resume))
".format(args.resume), + "Continuing without one.") cudnn.benchmark = True @@ -246,17 +254,18 @@ def main_worker(gpu, ngpus_per_node, args): best_acc1 = max(acc1, best_acc1) if not args.multiprocessing_distributed or (args.multiprocessing_distributed - and args.rank % ngpus_per_node == 0): + and args.rank % ngpus_per_node == 0): save_checkpoint({ 'epoch': epoch + 1, 'arch': args.arch, 'state_dict': model.state_dict(), 'best_acc1': best_acc1, - 'optimizer' : optimizer.state_dict(), + 'optimizer': optimizer.state_dict(), }, is_best) def train(train_loader, model, criterion, optimizer, epoch, args): + logger = logging.getLogger() batch_time = AverageMeter() data_time = AverageMeter() losses = AverageMeter() @@ -295,17 +304,19 @@ def train(train_loader, model, criterion, optimizer, epoch, args): end = time.time() if i % args.print_freq == 0: - print('Epoch: [{0}][{1}/{2}]\t' - 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' - 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' - 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' - 'Acc@1 {top1.val:.3f} ({top1.avg:.3f})\t' - 'Acc@5 {top5.val:.3f} ({top5.avg:.3f})'.format( - epoch, i, len(train_loader), batch_time=batch_time, - data_time=data_time, loss=losses, top1=top1, top5=top5)) + msg = ('Epoch: [{0}][{1}/{2}]\t' + 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' + 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' + 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' + 'Acc@1 {top1.val:.3f} ({top1.avg:.3f})\t' + 'Acc@5 {top5.val:.3f} ({top5.avg:.3f})'.format( + epoch, i, len(train_loader), batch_time=batch_time, + data_time=data_time, loss=losses, top1=top1, top5=top5)) + logger.info(msg) def validate(val_loader, model, criterion, args): + logger = logging.getLogger() batch_time = AverageMeter() losses = AverageMeter() top1 = AverageMeter() @@ -336,20 +347,46 @@ def validate(val_loader, model, criterion, args): end = time.time() if i % args.print_freq == 0: - print('Test: [{0}/{1}]\t' - 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' - 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' - 'Acc@1 {top1.val:.3f} ({top1.avg:.3f})\t' - 'Acc@5 {top5.val:.3f} ({top5.avg:.3f})'.format( - i, len(val_loader), batch_time=batch_time, loss=losses, - top1=top1, top5=top5)) - - print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}' - .format(top1=top1, top5=top5)) + msg = ('Test: [{0}/{1}]\t' + 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' + 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' + 'Acc@1 {top1.val:.3f} ({top1.avg:.3f})\t' + 'Acc@5 {top5.val:.3f} ({top5.avg:.3f})'.format( + i, len(val_loader), batch_time=batch_time, loss=losses, + top1=top1, top5=top5)) + logger.info(msg) + + logger.info(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}' + .format(top1=top1, top5=top5)) return top1.avg +def configure_logging(args): + logger = logging.getLogger() + logger.setLevel(logging.INFO) + + fmt = logging.Formatter("%(asctime)s| %(levelname)s: %(message)s") + + ch = logging.StreamHandler(sys.stdout) + ch.setLevel(logging.INFO) + ch.addFilter(lambda record: record.levelno <= logging.INFO) + ch.setFormatter(fmt) + logger.addHandler(ch) + + ch = logging.StreamHandler(sys.stderr) + ch.setLevel(logging.WARNING) + ch.setFormatter(fmt) + logger.addHandler(ch) + + fh = logging.FileHandler(args.log_file) + fh.setLevel(logging.INFO) + fh.setFormatter(fmt) + logger.addHandler(fh) + + return logger + + def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'): torch.save(state, filename) if is_best: @@ -358,6 +395,7 @@ def save_checkpoint(state, is_best, 
 
 class AverageMeter(object):
     """Computes and stores the average and current value"""
+
     def __init__(self):
         self.reset()
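
Note on the configure_logging handler layout: the root logger fans each record out to three handlers. INFO records reach stdout and the log file, WARNING and above reach stderr and the log file, and the lambda filter on the stdout handler is what keeps warnings from also being echoed to stdout. A minimal standalone sketch of the same routing (the file handler is omitted; the names out and err are illustrative, not from the patch):

    import logging
    import sys

    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    fmt = logging.Formatter("%(asctime)s| %(levelname)s: %(message)s")

    out = logging.StreamHandler(sys.stdout)
    out.addFilter(lambda record: record.levelno <= logging.INFO)  # drop WARNING+ on stdout
    out.setFormatter(fmt)
    logger.addHandler(out)

    err = logging.StreamHandler(sys.stderr)
    err.setLevel(logging.WARNING)  # stderr sees WARNING and above only
    err.setFormatter(fmt)
    logger.addHandler(err)

    logger.info("training started")    # emitted once, on stdout
    logger.warning("seed is set")      # emitted once, on stderr

One caveat on the design: with --multiprocessing-distributed, mp.spawn starts fresh worker processes, so the logging.getLogger() call in main_worker only inherits these handlers when the worker runs in the same process as main().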