Add logging to ImageNet training #530

Closed
wants to merge 3 commits
118 changes: 78 additions & 40 deletions imagenet/main.py
@@ -3,8 +3,8 @@
import random
import shutil
import time
import warnings
import sys
import logging

import torch
import torch.nn as nn
@@ -20,17 +20,17 @@
import torchvision.models as models

model_names = sorted(name for name in models.__dict__
if name.islower() and not name.startswith("__")
and callable(models.__dict__[name]))

parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
parser.add_argument('data', metavar='DIR',
help='path to dataset')
parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',
choices=model_names,
help='model architecture: ' +
' | '.join(model_names) +
' (default: resnet18)')
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
help='number of data loading workers (default: 4)')
parser.add_argument('--epochs', default=90, type=int, metavar='N',
@@ -74,26 +74,31 @@
'N processes per node, which has N GPUs. This is the '
'fastest way to use PyTorch for either single node or '
'multi node data parallel training')
parser.add_argument('-l', '--log-file', default='imagenet.log', type=str,
metavar='PATH',
help='path to log file (default: imagenet.log)')

best_acc1 = 0


def main():
args = parser.parse_args()

logger = configure_logging(args)

if args.seed is not None:
random.seed(args.seed)
torch.manual_seed(args.seed)
cudnn.deterministic = True
warnings.warn('You have chosen to seed training. '
'This will turn on the CUDNN deterministic setting, '
'which can slow down your training considerably! '
'You may see unexpected behavior when restarting '
'from checkpoints.')
logger.warning('You have chosen to seed training. '
'This will turn on the CUDNN deterministic setting, '
'which can slow down your training considerably! '
'You may see unexpected behavior when restarting '
'from checkpoints.')

if args.gpu is not None:
warnings.warn('You have chosen a specific GPU. This will completely '
'disable data parallelism.')
logger.warning('You have chosen a specific GPU. This will completely '
'disable data parallelism.')

if args.dist_url == "env://" and args.world_size == -1:
args.world_size = int(os.environ["WORLD_SIZE"])
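The hunk above replaces the two warnings.warn calls with logger.warning. If other code paths or imported libraries still emit messages through the warnings module, the standard library can route those into the same log with logging.captureWarnings; the sketch below is illustrative only and not part of this diff.

import logging
import warnings

# Forward warnings.warn(...) messages to the logging system.  Captured
# warnings are emitted on the 'py.warnings' logger at WARNING level, so they
# propagate to whatever handlers the root logger has (here the one installed
# by basicConfig; in main.py it would be the handlers from configure_logging).
logging.captureWarnings(True)
logging.basicConfig(level=logging.INFO,
                    format="%(asctime)s| %(levelname)s: %(message)s")

warnings.warn("this message now shows up in the log output")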
@@ -114,11 +119,12 @@ def main():


def main_worker(gpu, ngpus_per_node, args):
logger = logging.getLogger()
global best_acc1
args.gpu = gpu

if args.gpu is not None:
print("Use GPU: {} for training".format(args.gpu))
logger.info("Use GPU: {} for training".format(args.gpu))

if args.distributed:
if args.dist_url == "env://" and args.rank == -1:
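main_worker obtains its logger with a bare logging.getLogger() call, which returns the per-process root logger, so within the same process the handlers installed once by configure_logging are already attached and nothing needs to be re-wired here. A small standalone sketch of that behaviour (the function names are made up):

import logging

def setup():
    # Configure the per-process root logger once, as configure_logging() does.
    root = logging.getLogger()
    root.setLevel(logging.INFO)
    root.addHandler(logging.StreamHandler())

def worker():
    # Any later bare getLogger() call in the same process returns that same
    # root logger, handlers included.
    logger = logging.getLogger()
    logger.info("same logger object: %s", logger is logging.getLogger())

setup()
worker()  # -> "same logger object: True", via the handler added in setup()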
@@ -131,10 +137,10 @@ def main_worker(gpu, ngpus_per_node, args):
world_size=args.world_size, rank=args.rank)
# create model
if args.pretrained:
print("=> using pre-trained model '{}'".format(args.arch))
logger.info("Using pre-trained model '{}'".format(args.arch))
model = models.__dict__[args.arch](pretrained=True)
else:
print("=> creating model '{}'".format(args.arch))
logger.info("Creating model '{}'".format(args.arch))
model = models.__dict__[args.arch]()

if args.distributed:
@@ -149,7 +155,8 @@ def main_worker(gpu, ngpus_per_node, args):
# ourselves based on the total number of GPUs we have
args.batch_size = int(args.batch_size / ngpus_per_node)
args.workers = int(args.workers / ngpus_per_node)
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
model = torch.nn.parallel.DistributedDataParallel(model,
device_ids=[args.gpu])
else:
model.cuda()
# DistributedDataParallel will divide and allocate batch_size to all
@@ -176,7 +183,7 @@ def main_worker(gpu, ngpus_per_node, args):
# optionally resume from a checkpoint
if args.resume:
if os.path.isfile(args.resume):
print("=> loading checkpoint '{}'".format(args.resume))
logger.info("Loading checkpoint '{}' ...".format(args.resume))
checkpoint = torch.load(args.resume)
args.start_epoch = checkpoint['epoch']
best_acc1 = checkpoint['best_acc1']
@@ -185,10 +192,11 @@ def main_worker(gpu, ngpus_per_node, args):
best_acc1 = best_acc1.to(args.gpu)
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
print("=> loaded checkpoint '{}' (epoch {})"
.format(args.resume, checkpoint['epoch']))
logger.info("Loaded checkpoint '{}' (epoch {})"
.format(args.resume, args.start_epoch))
else:
print("=> no checkpoint found at '{}'".format(args.resume))
logger.warning("No checkpoint found at '{}'. ".format(args.resume),
"Continuing without one.")

cudnn.benchmark = True

@@ -246,17 +254,18 @@ def main_worker(gpu, ngpus_per_node, args):
best_acc1 = max(acc1, best_acc1)

if not args.multiprocessing_distributed or (args.multiprocessing_distributed
and args.rank % ngpus_per_node == 0):
save_checkpoint({
'epoch': epoch + 1,
'arch': args.arch,
'state_dict': model.state_dict(),
'best_acc1': best_acc1,
'optimizer' : optimizer.state_dict(),
'optimizer': optimizer.state_dict(),
}, is_best)


def train(train_loader, model, criterion, optimizer, epoch, args):
logger = logging.getLogger()
batch_time = AverageMeter()
data_time = AverageMeter()
losses = AverageMeter()
@@ -295,17 +304,19 @@ def train(train_loader, model, criterion, optimizer, epoch, args):
end = time.time()

if i % args.print_freq == 0:
print('Epoch: [{0}][{1}/{2}]\t'
'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
'Acc@1 {top1.val:.3f} ({top1.avg:.3f})\t'
'Acc@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
epoch, i, len(train_loader), batch_time=batch_time,
data_time=data_time, loss=losses, top1=top1, top5=top5))
msg = ('Epoch: [{0}][{1}/{2}]\t'
'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
'Acc@1 {top1.val:.3f} ({top1.avg:.3f})\t'
'Acc@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
epoch, i, len(train_loader), batch_time=batch_time,
data_time=data_time, loss=losses, top1=top1, top5=top5))
logger.info(msg)
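
The loop above builds msg eagerly with str.format and then hands the finished string to logger.info. For reference, logging also accepts deferred %-style arguments, in which case the interpolation only happens when a handler actually emits the record; a minimal sketch of the two styles with made-up values:

import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger()

loss_val, loss_avg = 6.907, 6.912  # stand-ins for losses.val / losses.avg

# Style used in this PR: format first, then log the finished string.
msg = 'Loss {0:.4f} ({1:.4f})'.format(loss_val, loss_avg)
logger.info(msg)

# Deferred %-style alternative: the string work is skipped entirely if the
# record is filtered out before reaching a handler.
logger.info('Loss %.4f (%.4f)', loss_val, loss_avg)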


def validate(val_loader, model, criterion, args):
logger = logging.getLogger()
batch_time = AverageMeter()
losses = AverageMeter()
top1 = AverageMeter()
@@ -336,20 +347,46 @@ def validate(val_loader, model, criterion, args):
end = time.time()

if i % args.print_freq == 0:
print('Test: [{0}/{1}]\t'
'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
'Acc@1 {top1.val:.3f} ({top1.avg:.3f})\t'
'Acc@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
i, len(val_loader), batch_time=batch_time, loss=losses,
top1=top1, top5=top5))

print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
.format(top1=top1, top5=top5))
msg = ('Test: [{0}/{1}]\t'
'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
'Acc@1 {top1.val:.3f} ({top1.avg:.3f})\t'
'Acc@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
i, len(val_loader), batch_time=batch_time, loss=losses,
top1=top1, top5=top5))
logger.info(msg)

logger.info(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
.format(top1=top1, top5=top5))

return top1.avg


def configure_logging(args):
logger = logging.getLogger()
logger.setLevel(logging.INFO)

fmt = logging.Formatter("%(asctime)s| %(levelname)s: %(message)s")

ch = logging.StreamHandler(sys.stdout)
ch.setLevel(logging.INFO)
ch.addFilter(lambda record: record.levelno <= logging.INFO)
ch.setFormatter(fmt)
logger.addHandler(ch)

ch = logging.StreamHandler(sys.stderr)
ch.setLevel(logging.WARNING)
ch.setFormatter(fmt)
logger.addHandler(ch)

fh = logging.FileHandler(args.log_file)
fh.setLevel(logging.INFO)
fh.setFormatter(fmt)
logger.addHandler(fh)

return logger
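
configure_logging splits the root logger across three handlers: INFO and below go to stdout (the lambda filter keeps warnings off that stream), WARNING and above go to stderr, and everything from INFO up is appended to the file named by --log-file. A standalone sketch of the same routing, with 'demo.log' as a made-up stand-in for args.log_file:

import logging
import sys

fmt = logging.Formatter("%(asctime)s| %(levelname)s: %(message)s")

out = logging.StreamHandler(sys.stdout)
out.setLevel(logging.INFO)
out.addFilter(lambda record: record.levelno <= logging.INFO)  # no WARNING+ on stdout
out.setFormatter(fmt)

err = logging.StreamHandler(sys.stderr)
err.setLevel(logging.WARNING)
err.setFormatter(fmt)

logfile = logging.FileHandler("demo.log")  # FileHandler appends by default
logfile.setLevel(logging.INFO)
logfile.setFormatter(fmt)

root = logging.getLogger()
root.setLevel(logging.INFO)
for handler in (out, err, logfile):
    root.addHandler(handler)

root.info("goes to stdout and demo.log")
root.warning("goes to stderr and demo.log")
# Lines come out as e.g.: 2019-04-01 12:34:56,789| INFO: goes to stdout and demo.log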


def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
torch.save(state, filename)
if is_best:
@@ -358,6 +395,7 @@ def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):

class AverageMeter(object):
"""Computes and stores the average and current value"""

def __init__(self):
self.reset()
