
Initial version of segmentation reference scripts #820

Merged · 20 commits · May 10, 2019

111 changes: 111 additions & 0 deletions references/segmentation/coco_utils.py
@@ -0,0 +1,111 @@
import copy
import os

import torch
import torch.utils.data
import torchvision
from PIL import Image
from pycocotools import mask as coco_mask

from transforms import Compose


class FilterAndRemapCocoCategories(object):
def __init__(self, categories, remap=True):
self.categories = categories
self.remap = remap

def __call__(self, image, anno):
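        # keep only the annotations whose category id is in the given whitelist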
anno = [obj for obj in anno if obj["category_id"] in self.categories]
if not self.remap:
return image, anno
anno = copy.deepcopy(anno)
for obj in anno:
obj["category_id"] = self.categories.index(obj["category_id"])
return image, anno


def convert_coco_poly_to_mask(segmentations, height, width):
masks = []
for polygons in segmentations:
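        # convert this instance's polygon(s) to COCO run-length encodings,
        # then decode them into a binary H x W x K mask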
rles = coco_mask.frPyObjects(polygons, height, width)
mask = coco_mask.decode(rles)
if len(mask.shape) < 3:
mask = mask[..., None]
mask = torch.as_tensor(mask, dtype=torch.uint8)
mask = mask.any(dim=2)
masks.append(mask)
if masks:
masks = torch.stack(masks, dim=0)
else:
masks = torch.zeros((0, height, width), dtype=torch.uint8)
return masks


class ConvertCocoPolysToMask(object):
def __call__(self, image, anno):
w, h = image.size
segmentations = [obj["segmentation"] for obj in anno]
cats = [obj["category_id"] for obj in anno]
if segmentations:
masks = convert_coco_poly_to_mask(segmentations, h, w)
cats = torch.as_tensor(cats, dtype=masks.dtype)
            # merge all the instance masks into a single segmentation map,
            # labelling each pixel with its instance's category id
target, _ = (masks * cats[:, None, None]).max(dim=0)
# discard overlapping instances
target[masks.sum(0) > 1] = 255
else:
target = torch.zeros((h, w), dtype=torch.uint8)
target = Image.fromarray(target.numpy())
return image, target
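
# A minimal sketch (hypothetical tensors) of the merge trick used above: two
# 2x2 instance masks with categories 1 and 2; the pixel covered by both
# instances is set to 255 so the loss can ignore it later:
#
#   masks = torch.tensor([[[1, 1], [0, 0]],
#                         [[1, 0], [0, 1]]], dtype=torch.uint8)
#   cats = torch.tensor([1, 2], dtype=torch.uint8)
#   target, _ = (masks * cats[:, None, None]).max(dim=0)
#   target[masks.sum(0) > 1] = 255
#   # target is now [[255, 1], [0, 2]]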


def _coco_remove_images_without_annotations(dataset, cat_list=None):
def _has_valid_annotation(anno):
# if it's empty, there is no annotation
if len(anno) == 0:
return False
        # keep the image only if its annotations cover more than 1k pixels
return sum(obj["area"] for obj in anno) > 1000

assert isinstance(dataset, torchvision.datasets.CocoDetection)
ids = []
for ds_idx, img_id in enumerate(dataset.ids):
ann_ids = dataset.coco.getAnnIds(imgIds=img_id, iscrowd=None)
anno = dataset.coco.loadAnns(ann_ids)
if cat_list:
anno = [obj for obj in anno if obj["category_id"] in cat_list]
if _has_valid_annotation(anno):
ids.append(ds_idx)

dataset = torch.utils.data.Subset(dataset, ids)
return dataset


def get_coco(root, image_set, transforms):
PATHS = {
"train": ("train2017", os.path.join("annotations", "instances_train2017.json")),
"val": ("val2017", os.path.join("annotations", "instances_val2017.json")),
# "train": ("val2017", os.path.join("annotations", "instances_val2017.json"))
}
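    # COCO category ids for the background plus the 20 PASCAL VOC classes, in VOC label order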
CAT_LIST = [0, 5, 2, 16, 9, 44, 6, 3, 17, 62, 21, 67, 18, 19, 4,
1, 64, 20, 63, 7, 72]

transforms = Compose([
FilterAndRemapCocoCategories(CAT_LIST, remap=True),
ConvertCocoPolysToMask(),
transforms
])

img_folder, ann_file = PATHS[image_set]
img_folder = os.path.join(root, img_folder)
ann_file = os.path.join(root, ann_file)

dataset = torchvision.datasets.CocoDetection(img_folder, ann_file, transforms=transforms)

if image_set == "train":
dataset = _coco_remove_images_without_annotations(dataset, CAT_LIST)

return dataset
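
# Example usage (a sketch; the dataset root and the paired ToTensor transform
# from the accompanying transforms module are assumptions):
#
#   from transforms import ToTensor
#   dataset = get_coco("/path/to/coco", image_set="val", transforms=ToTensor())
#   image, target = dataset[0]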
219 changes: 219 additions & 0 deletions references/segmentation/train.py
@@ -0,0 +1,219 @@
import datetime
import os
import time

import torch
import torch.utils.data
from torch import nn
import torchvision

from coco_utils import get_coco
import transforms as T
import utils


def get_dataset(name, image_set, transform):
def sbd(*args, **kwargs):
return torchvision.datasets.SBDataset(*args, mode='segmentation', **kwargs)
paths = {
"voc": ('/datasets01/VOC/060817/', torchvision.datasets.VOCSegmentation, 21),
"voc_aug": ('/datasets01/SBDD/072318/', sbd, 21),
"coco": ('/datasets01/COCO/022719/', get_coco, 21)
}
p, ds_fn, num_classes = paths[name]

ds = ds_fn(p, image_set=image_set, transforms=transform)
return ds, num_classes


def get_transform(train):
base_size = 520
crop_size = 480

min_size = int((0.5 if train else 1.0) * base_size)
max_size = int((2.0 if train else 1.0) * base_size)
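    # at train time, randomly rescale between 0.5x and 2.0x of the base size;
    # at eval time, min_size == max_size, so the resize is deterministic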
transforms = []
transforms.append(T.RandomResize(min_size, max_size))
if train:
transforms.append(T.RandomHorizontalFlip(0.5))
transforms.append(T.RandomCrop(crop_size))
transforms.append(T.ToTensor())
transforms.append(T.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]))

return T.Compose(transforms)


def criterion(inputs, target):
losses = {}
for name, x in inputs.items():
losses[name] = nn.functional.cross_entropy(x, target, ignore_index=255)

if len(losses) == 1:
return losses['out']

return losses['out'] + 0.5 * losses['aux']
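
# The torchvision segmentation models return a dict of outputs; with
# aux_loss=True there are two heads (shapes N x num_classes x H x W):
#
#   output = model(image)              # {'out': ..., 'aux': ...}
#   loss = criterion(output, target)   # losses['out'] + 0.5 * losses['aux']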


def evaluate(model, data_loader, device, num_classes):
model.eval()
confmat = utils.ConfusionMatrix(num_classes)
metric_logger = utils.MetricLogger(delimiter=" ")
header = 'Test:'
with torch.no_grad():
for image, target in metric_logger.log_every(data_loader, 100, header):
image, target = image.to(device), target.to(device)
output = model(image)
output = output['out']

confmat.update(target.flatten(), output.argmax(1).flatten())

confmat.reduce_from_all_processes()

return confmat


def train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, device, epoch, print_freq):
model.train()
metric_logger = utils.MetricLogger(delimiter=" ")
metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value}'))
header = 'Epoch: [{}]'.format(epoch)
for image, target in metric_logger.log_every(data_loader, print_freq, header):
image, target = image.to(device), target.to(device)
output = model(image)
loss = criterion(output, target)

optimizer.zero_grad()
loss.backward()
optimizer.step()

lr_scheduler.step()
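        # note: the lr schedule is stepped once per training iteration, not once per epoch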

metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"])


def main(args):
if args.output_dir:
utils.mkdir(args.output_dir)

utils.init_distributed_mode(args)
print(args)

device = torch.device(args.device)

dataset, num_classes = get_dataset(args.dataset, "train", get_transform(train=True))
dataset_test, _ = get_dataset(args.dataset, "val", get_transform(train=False))

if args.distributed:
train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test)
else:
train_sampler = torch.utils.data.RandomSampler(dataset)
test_sampler = torch.utils.data.SequentialSampler(dataset_test)

data_loader = torch.utils.data.DataLoader(
dataset, batch_size=args.batch_size,
sampler=train_sampler, num_workers=args.workers,
collate_fn=utils.collate_fn, drop_last=True)

data_loader_test = torch.utils.data.DataLoader(
dataset_test, batch_size=1,
sampler=test_sampler, num_workers=args.workers,
collate_fn=utils.collate_fn)

model = torchvision.models.segmentation.__dict__[args.model](num_classes=num_classes, aux_loss=args.aux_loss)
model.to(device)
if args.distributed:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

if args.resume:
checkpoint = torch.load(args.resume, map_location='cpu')
model.load_state_dict(checkpoint['model'])

model_without_ddp = model
if args.distributed:
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
model_without_ddp = model.module

if args.test_only:
confmat = evaluate(model, data_loader_test, device=device, num_classes=num_classes)
print(confmat)
return

params_to_optimize = [
{"params": [p for p in model_without_ddp.backbone.parameters() if p.requires_grad]},
{"params": [p for p in model_without_ddp.classifier.parameters() if p.requires_grad]},
]
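    # backbone and classifier are trained with the base lr; the auxiliary
    # classifier (if enabled) gets a 10x larger lr below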
if args.aux_loss:
params = [p for p in model_without_ddp.aux_classifier.parameters() if p.requires_grad]
params_to_optimize.append({"params": params, "lr": args.lr * 10})
optimizer = torch.optim.SGD(
params_to_optimize,
lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)

lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
optimizer,
lambda x: (1 - x / (len(data_loader) * args.epochs)) ** 0.9)
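    # the lambda above implements the "poly" policy: the lr multiplier decays
    # as (1 - iter / max_iters) ** 0.9, e.g. ~0.54 halfway through training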

start_time = time.time()
for epoch in range(args.epochs):
if args.distributed:
train_sampler.set_epoch(epoch)
train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, device, epoch, args.print_freq)
confmat = evaluate(model, data_loader_test, device=device, num_classes=num_classes)
print(confmat)
utils.save_on_master(
{
'model': model_without_ddp.state_dict(),
'optimizer': optimizer.state_dict(),
'epoch': epoch,
'args': args
},
os.path.join(args.output_dir, 'model_{}.pth'.format(epoch)))

total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print('Training time {}'.format(total_time_str))


def parse_args():
import argparse
parser = argparse.ArgumentParser(description='PyTorch Segmentation Training')

    parser.add_argument('--dataset', default='voc', help='dataset: voc, voc_aug or coco')
    parser.add_argument('--model', default='fcn_resnet101', help='model name (from torchvision.models.segmentation)')
    parser.add_argument('--aux-loss', action='store_true', help='train with auxiliary loss')
parser.add_argument('--device', default='cuda', help='device')
parser.add_argument('-b', '--batch-size', default=8, type=int)
parser.add_argument('--epochs', default=30, type=int, metavar='N',
help='number of total epochs to run')

parser.add_argument('-j', '--workers', default=16, type=int, metavar='N',
help='number of data loading workers (default: 16)')
parser.add_argument('--lr', default=0.01, type=float, help='initial learning rate')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
help='momentum')
parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
metavar='W', help='weight decay (default: 1e-4)',
dest='weight_decay')
parser.add_argument('--print-freq', default=10, type=int, help='print frequency')
    parser.add_argument('--output-dir', default='.', help='directory in which to save checkpoints')
parser.add_argument('--resume', default='', help='resume from checkpoint')
parser.add_argument(
"--test-only",
dest="test_only",
help="Only test the model",
action="store_true",
)
# distributed training parameters
parser.add_argument('--world-size', default=1, type=int,
help='number of distributed processes')
parser.add_argument('--dist-url', default='env://', help='url used to set up distributed training')

args = parser.parse_args()
return args


if __name__ == "__main__":
args = parse_args()
main(args)
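
# Example invocations (sketches; the dataset paths in get_dataset are
# hard-coded, and the distributed launcher flags are assumptions):
#
#   python train.py --dataset voc --model fcn_resnet101 --aux-loss
#   python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py \
#       --dataset coco --model fcn_resnet101 --aux-loss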