diff --git a/references/segmentation/coco_utils.py b/references/segmentation/coco_utils.py new file mode 100644 index 00000000000..c86d5495247 --- /dev/null +++ b/references/segmentation/coco_utils.py @@ -0,0 +1,111 @@ +import copy +import torch +import torch.utils.data +import torchvision +from PIL import Image + +import os + +from pycocotools import mask as coco_mask + +from transforms import Compose + + +class FilterAndRemapCocoCategories(object): + def __init__(self, categories, remap=True): + self.categories = categories + self.remap = remap + + def __call__(self, image, anno): + anno = [obj for obj in anno if obj["category_id"] in self.categories] + if not self.remap: + return image, anno + anno = copy.deepcopy(anno) + for obj in anno: + obj["category_id"] = self.categories.index(obj["category_id"]) + return image, anno + + +def convert_coco_poly_to_mask(segmentations, height, width): + masks = [] + for polygons in segmentations: + rles = coco_mask.frPyObjects(polygons, height, width) + mask = coco_mask.decode(rles) + if len(mask.shape) < 3: + mask = mask[..., None] + mask = torch.as_tensor(mask, dtype=torch.uint8) + mask = mask.any(dim=2) + masks.append(mask) + if masks: + masks = torch.stack(masks, dim=0) + else: + masks = torch.zeros((0, height, width), dtype=torch.uint8) + return masks + + +class ConvertCocoPolysToMask(object): + def __call__(self, image, anno): + w, h = image.size + segmentations = [obj["segmentation"] for obj in anno] + cats = [obj["category_id"] for obj in anno] + if segmentations: + masks = convert_coco_poly_to_mask(segmentations, h, w) + cats = torch.as_tensor(cats, dtype=masks.dtype) + # merge all instance masks into a single segmentation map + # with its corresponding categories + target, _ = (masks * cats[:, None, None]).max(dim=0) + # discard overlapping instances + target[masks.sum(0) > 1] = 255 + else: + target = torch.zeros((h, w), dtype=torch.uint8) + target = Image.fromarray(target.numpy()) + return image, target + + +def _coco_remove_images_without_annotations(dataset, cat_list=None): + def _has_valid_annotation(anno): + # if it's empty, there is no annotation + if len(anno) == 0: + return False + # if more than 1k pixels occupied in the image + return sum(obj["area"] for obj in anno) > 1000 + + assert isinstance(dataset, torchvision.datasets.CocoDetection) + ids = [] + for ds_idx, img_id in enumerate(dataset.ids): + ann_ids = dataset.coco.getAnnIds(imgIds=img_id, iscrowd=None) + anno = dataset.coco.loadAnns(ann_ids) + if cat_list: + anno = [obj for obj in anno if obj["category_id"] in cat_list] + if _has_valid_annotation(anno): + ids.append(ds_idx) + + dataset = torch.utils.data.Subset(dataset, ids) + return dataset + + +def get_coco(root, image_set, transforms): + PATHS = { + "train": ("train2017", os.path.join("annotations", "instances_train2017.json")), + "val": ("val2017", os.path.join("annotations", "instances_val2017.json")), + # "train": ("val2017", os.path.join("annotations", "instances_val2017.json")) + } + CAT_LIST = [0, 5, 2, 16, 9, 44, 6, 3, 17, 62, 21, 67, 18, 19, 4, + 1, 64, 20, 63, 7, 72] + + transforms = Compose([ + FilterAndRemapCocoCategories(CAT_LIST, remap=True), + ConvertCocoPolysToMask(), + transforms + ]) + + img_folder, ann_file = PATHS[image_set] + img_folder = os.path.join(root, img_folder) + ann_file = os.path.join(root, ann_file) + + dataset = torchvision.datasets.CocoDetection(img_folder, ann_file, transforms=transforms) + + if image_set == "train": + dataset = _coco_remove_images_without_annotations(dataset, CAT_LIST) + + return dataset diff --git a/references/segmentation/train.py b/references/segmentation/train.py new file mode 100644 index 00000000000..3f0327e04de --- /dev/null +++ b/references/segmentation/train.py @@ -0,0 +1,219 @@ +import datetime +import os +import time + +import torch +import torch.utils.data +from torch import nn +import torchvision + +from coco_utils import get_coco +import transforms as T +import utils + + +def get_dataset(name, image_set, transform): + def sbd(*args, **kwargs): + return torchvision.datasets.SBDataset(*args, mode='segmentation', **kwargs) + paths = { + "voc": ('/datasets01/VOC/060817/', torchvision.datasets.VOCSegmentation, 21), + "voc_aug": ('/datasets01/SBDD/072318/', sbd, 21), + "coco": ('/datasets01/COCO/022719/', get_coco, 21) + } + p, ds_fn, num_classes = paths[name] + + ds = ds_fn(p, image_set=image_set, transforms=transform) + return ds, num_classes + + +def get_transform(train): + base_size = 520 + crop_size = 480 + + min_size = int((0.5 if train else 1.0) * base_size) + max_size = int((2.0 if train else 1.0) * base_size) + transforms = [] + transforms.append(T.RandomResize(min_size, max_size)) + if train: + transforms.append(T.RandomHorizontalFlip(0.5)) + transforms.append(T.RandomCrop(crop_size)) + transforms.append(T.ToTensor()) + transforms.append(T.Normalize(mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225])) + + return T.Compose(transforms) + + +def criterion(inputs, target): + losses = {} + for name, x in inputs.items(): + losses[name] = nn.functional.cross_entropy(x, target, ignore_index=255) + + if len(losses) == 1: + return losses['out'] + + return losses['out'] + 0.5 * losses['aux'] + + +def evaluate(model, data_loader, device, num_classes): + model.eval() + confmat = utils.ConfusionMatrix(num_classes) + metric_logger = utils.MetricLogger(delimiter=" ") + header = 'Test:' + with torch.no_grad(): + for image, target in metric_logger.log_every(data_loader, 100, header): + image, target = image.to(device), target.to(device) + output = model(image) + output = output['out'] + + confmat.update(target.flatten(), output.argmax(1).flatten()) + + confmat.reduce_from_all_processes() + + return confmat + + +def train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, device, epoch, print_freq): + model.train() + metric_logger = utils.MetricLogger(delimiter=" ") + metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value}')) + header = 'Epoch: [{}]'.format(epoch) + for image, target in metric_logger.log_every(data_loader, print_freq, header): + image, target = image.to(device), target.to(device) + output = model(image) + loss = criterion(output, target) + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + lr_scheduler.step() + + metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"]) + + +def main(args): + if args.output_dir: + utils.mkdir(args.output_dir) + + utils.init_distributed_mode(args) + print(args) + + device = torch.device(args.device) + + dataset, num_classes = get_dataset(args.dataset, "train", get_transform(train=True)) + dataset_test, _ = get_dataset(args.dataset, "val", get_transform(train=False)) + + if args.distributed: + train_sampler = torch.utils.data.distributed.DistributedSampler(dataset) + test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test) + else: + train_sampler = torch.utils.data.RandomSampler(dataset) + test_sampler = torch.utils.data.SequentialSampler(dataset_test) + + data_loader = torch.utils.data.DataLoader( + dataset, batch_size=args.batch_size, + sampler=train_sampler, num_workers=args.workers, + collate_fn=utils.collate_fn, drop_last=True) + + data_loader_test = torch.utils.data.DataLoader( + dataset_test, batch_size=1, + sampler=test_sampler, num_workers=args.workers, + collate_fn=utils.collate_fn) + + model = torchvision.models.segmentation.__dict__[args.model](num_classes=num_classes, aux_loss=args.aux_loss) + model.to(device) + if args.distributed: + model = torch.nn.utils.convert_sync_batchnorm(model) + + if args.resume: + checkpoint = torch.load(args.resume, map_location='cpu') + model.load_state_dict(checkpoint['model']) + + model_without_ddp = model + if args.distributed: + model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) + model_without_ddp = model.module + + if args.test_only: + confmat = evaluate(model, data_loader_test, device=device, num_classes=num_classes) + print(confmat) + return + + params_to_optimize = [ + {"params": [p for p in model_without_ddp.backbone.parameters() if p.requires_grad]}, + {"params": [p for p in model_without_ddp.classifier.parameters() if p.requires_grad]}, + ] + if args.aux_loss: + params = [p for p in model_without_ddp.aux_classifier.parameters() if p.requires_grad] + params_to_optimize.append({"params": params, "lr": args.lr * 10}) + optimizer = torch.optim.SGD( + params_to_optimize, + lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay) + + lr_scheduler = torch.optim.lr_scheduler.LambdaLR( + optimizer, + lambda x: (1 - x / (len(data_loader) * args.epochs)) ** 0.9) + + start_time = time.time() + for epoch in range(args.epochs): + if args.distributed: + train_sampler.set_epoch(epoch) + train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, device, epoch, args.print_freq) + confmat = evaluate(model, data_loader_test, device=device, num_classes=num_classes) + print(confmat) + utils.save_on_master( + { + 'model': model_without_ddp.state_dict(), + 'optimizer': optimizer.state_dict(), + 'epoch': epoch, + 'args': args + }, + os.path.join(args.output_dir, 'model_{}.pth'.format(epoch))) + + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + print('Training time {}'.format(total_time_str)) + + +def parse_args(): + import argparse + parser = argparse.ArgumentParser(description='PyTorch Segmentation Training') + + parser.add_argument('--dataset', default='voc', help='dataset') + parser.add_argument('--model', default='fcn_resnet101', help='model') + parser.add_argument('--aux-loss', action='store_true', help='auxiliar loss') + parser.add_argument('--device', default='cuda', help='device') + parser.add_argument('-b', '--batch-size', default=8, type=int) + parser.add_argument('--epochs', default=30, type=int, metavar='N', + help='number of total epochs to run') + + parser.add_argument('-j', '--workers', default=16, type=int, metavar='N', + help='number of data loading workers (default: 16)') + parser.add_argument('--lr', default=0.01, type=float, help='initial learning rate') + parser.add_argument('--momentum', default=0.9, type=float, metavar='M', + help='momentum') + parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float, + metavar='W', help='weight decay (default: 1e-4)', + dest='weight_decay') + parser.add_argument('--print-freq', default=10, type=int, help='print frequency') + parser.add_argument('--output-dir', default='.', help='path where to save') + parser.add_argument('--resume', default='', help='resume from checkpoint') + parser.add_argument( + "--test-only", + dest="test_only", + help="Only test the model", + action="store_true", + ) + # distributed training parameters + parser.add_argument('--world-size', default=1, type=int, + help='number of distributed processes') + parser.add_argument('--dist-url', default='env://', help='url used to set up distributed training') + + args = parser.parse_args() + return args + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/references/segmentation/transforms.py b/references/segmentation/transforms.py new file mode 100644 index 00000000000..bce4bfbe639 --- /dev/null +++ b/references/segmentation/transforms.py @@ -0,0 +1,92 @@ +import numpy as np +from PIL import Image +import random + +import torch +from torchvision import transforms as T +from torchvision.transforms import functional as F + + +def pad_if_smaller(img, size, fill=0): + min_size = min(img.size) + if min_size < size: + ow, oh = img.size + padh = size - oh if oh < size else 0 + padw = size - ow if ow < size else 0 + img = F.pad(img, (0, 0, padw, padh), fill=fill) + return img + + +class Compose(object): + def __init__(self, transforms): + self.transforms = transforms + + def __call__(self, image, target): + for t in self.transforms: + image, target = t(image, target) + return image, target + + +class RandomResize(object): + def __init__(self, min_size, max_size=None): + self.min_size = min_size + if max_size is None: + max_size = min_size + self.max_size = max_size + + def __call__(self, image, target): + size = random.randint(self.min_size, self.max_size) + image = F.resize(image, size) + target = F.resize(target, size, interpolation=Image.NEAREST) + return image, target + + +class RandomHorizontalFlip(object): + def __init__(self, flip_prob): + self.flip_prob = flip_prob + + def __call__(self, image, target): + if random.random() < self.flip_prob: + image = F.hflip(image) + target = F.hflip(target) + return image, target + + +class RandomCrop(object): + def __init__(self, size): + self.size = size + + def __call__(self, image, target): + image = pad_if_smaller(image, self.size) + target = pad_if_smaller(target, self.size, fill=255) + crop_params = T.RandomCrop.get_params(image, (self.size, self.size)) + image = F.crop(image, *crop_params) + target = F.crop(target, *crop_params) + return image, target + + +class CenterCrop(object): + def __init__(self, size): + self.size = size + + def __call__(self, image, target): + image = F.center_crop(image, self.size) + target = F.center_crop(target, self.size) + return image, target + + +class ToTensor(object): + def __call__(self, image, target): + image = F.to_tensor(image) + target = torch.as_tensor(np.asarray(target), dtype=torch.int64) + return image, target + + +class Normalize(object): + def __init__(self, mean, std): + self.mean = mean + self.std = std + + def __call__(self, image, target): + image = F.normalize(image, mean=self.mean, std=self.std) + return image, target diff --git a/references/segmentation/utils.py b/references/segmentation/utils.py new file mode 100644 index 00000000000..8431cb47172 --- /dev/null +++ b/references/segmentation/utils.py @@ -0,0 +1,284 @@ +from __future__ import print_function +from collections import defaultdict, deque +import datetime +import math +import time +import torch +import torch.distributed as dist + +import errno +import os + + +class SmoothedValue(object): + """Track a series of values and provide access to smoothed values over a + window or the global series average. + """ + + def __init__(self, window_size=20, fmt=None): + if fmt is None: + fmt = "{median:.4f} ({global_avg:.4f})" + self.deque = deque(maxlen=window_size) + self.total = 0.0 + self.count = 0 + self.fmt = fmt + + def update(self, value, n=1): + self.deque.append(value) + self.count += n + self.total += value * n + + def synchronize_between_processes(self): + """ + Warning: does not synchronize the deque! + """ + if not is_dist_avail_and_initialized(): + return + t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') + dist.barrier() + dist.all_reduce(t) + t = t.tolist() + self.count = int(t[0]) + self.total = t[1] + + @property + def median(self): + d = torch.tensor(list(self.deque)) + return d.median().item() + + @property + def avg(self): + d = torch.tensor(list(self.deque), dtype=torch.float32) + return d.mean().item() + + @property + def global_avg(self): + return self.total / self.count + + @property + def max(self): + return max(self.deque) + + @property + def value(self): + return self.deque[-1] + + def __str__(self): + return self.fmt.format( + median=self.median, + avg=self.avg, + global_avg=self.global_avg, + max=self.max, + value=self.value) + + +class ConfusionMatrix(object): + def __init__(self, num_classes): + self.num_classes = num_classes + self.mat = None + + def update(self, a, b): + n = self.num_classes + if self.mat is None: + self.mat = torch.zeros((n, n), dtype=torch.int64, device=a.device) + with torch.no_grad(): + k = (a >= 0) & (a < n) + inds = n * a[k].to(torch.int64) + b[k] + self.mat += torch.bincount(inds, minlength=n**2).reshape(n, n) + + def reset(self): + self.mat.zero_() + + def compute(self): + h = self.mat.float() + acc_global = torch.diag(h).sum() / h.sum() + acc = torch.diag(h) / h.sum(1) + iu = torch.diag(h) / (h.sum(1) + h.sum(0) - torch.diag(h)) + return acc_global, acc, iu + + def reduce_from_all_processes(self): + if not torch.distributed.is_available(): + return + if not torch.distributed.is_initialized(): + return + torch.distributed.barrier() + torch.distributed.all_reduce(self.mat) + + def __str__(self): + acc_global, acc, iu = self.compute() + return ( + 'global correct: {:.1f}\n' + 'average row correct: {}\n' + 'IoU: {}\n' + 'mean IoU: {:.1f}').format( + acc_global.item() * 100, + ['{:.1f}'.format(i) for i in (acc * 100).tolist()], + ['{:.1f}'.format(i) for i in (iu * 100).tolist()], + iu.mean().item() * 100) + + +class MetricLogger(object): + def __init__(self, delimiter="\t"): + self.meters = defaultdict(SmoothedValue) + self.delimiter = delimiter + + def update(self, **kwargs): + for k, v in kwargs.items(): + if isinstance(v, torch.Tensor): + v = v.item() + assert isinstance(v, (float, int)) + self.meters[k].update(v) + + def __getattr__(self, attr): + if attr in self.meters: + return self.meters[attr] + if attr in self.__dict__: + return self.__dict__[attr] + raise AttributeError("'{}' object has no attribute '{}'".format( + type(self).__name__, attr)) + + def __str__(self): + loss_str = [] + for name, meter in self.meters.items(): + loss_str.append( + "{}: {}".format(name, str(meter)) + ) + return self.delimiter.join(loss_str) + + def synchronize_between_processes(self): + for meter in self.meters.values(): + meter.synchronize_between_processes() + + def add_meter(self, name, meter): + self.meters[name] = meter + + def log_every(self, iterable, print_freq, header=None): + i = 0 + if not header: + header = '' + start_time = time.time() + end = time.time() + iter_time = SmoothedValue(fmt='{avg:.4f}') + data_time = SmoothedValue(fmt='{avg:.4f}') + space_fmt = ':' + str(len(str(len(iterable)))) + 'd' + log_msg = self.delimiter.join([ + header, + '[{0' + space_fmt + '}/{1}]', + 'eta: {eta}', + '{meters}', + 'time: {time}', + 'data: {data}', + 'max mem: {memory:.0f}' + ]) + MB = 1024.0 * 1024.0 + for obj in iterable: + data_time.update(time.time() - end) + yield obj + iter_time.update(time.time() - end) + if i % print_freq == 0: + eta_seconds = iter_time.global_avg * (len(iterable) - i) + eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) + print(log_msg.format( + i, len(iterable), eta=eta_string, + meters=str(self), + time=str(iter_time), data=str(data_time), + memory=torch.cuda.max_memory_allocated() / MB)) + i += 1 + end = time.time() + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + print('{} Total time: {}'.format(header, total_time_str)) + + +def cat_list(images, fill_value=0): + max_size = tuple(max(s) for s in zip(*[img.shape for img in images])) + batch_shape = (len(images),) + max_size + batched_imgs = images[0].new(*batch_shape).fill_(fill_value) + for img, pad_img in zip(images, batched_imgs): + pad_img[..., :img.shape[-2], :img.shape[-1]].copy_(img) + return batched_imgs + + +def collate_fn(batch): + images, targets = list(zip(*batch)) + batched_imgs = cat_list(images, fill_value=0) + batched_targets = cat_list(targets, fill_value=255) + return batched_imgs, batched_targets + + +def mkdir(path): + try: + os.makedirs(path) + except OSError as e: + if e.errno != errno.EEXIST: + raise + + +def setup_for_distributed(is_master): + """ + This function disables printing when not in master process + """ + import builtins as __builtin__ + builtin_print = __builtin__.print + + def print(*args, **kwargs): + force = kwargs.pop('force', False) + if is_master or force: + builtin_print(*args, **kwargs) + + __builtin__.print = print + + +def is_dist_avail_and_initialized(): + if not dist.is_available(): + return False + if not dist.is_initialized(): + return False + return True + + +def get_world_size(): + if not is_dist_avail_and_initialized(): + return 1 + return dist.get_world_size() + + +def get_rank(): + if not is_dist_avail_and_initialized(): + return 0 + return dist.get_rank() + + +def is_main_process(): + return get_rank() == 0 + + +def save_on_master(*args, **kwargs): + if is_main_process(): + torch.save(*args, **kwargs) + + +def init_distributed_mode(args): + if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: + args.rank = int(os.environ["RANK"]) + args.world_size = int(os.environ['WORLD_SIZE']) + args.gpu = int(os.environ['LOCAL_RANK']) + elif 'SLURM_PROCID' in os.environ: + args.rank = int(os.environ['SLURM_PROCID']) + args.gpu = args.rank % torch.cuda.device_count() + elif hasattr(args, "rank"): + pass + else: + print('Not using distributed mode') + args.distributed = False + return + + args.distributed = True + + torch.cuda.set_device(args.gpu) + args.dist_backend = 'nccl' + print('| distributed init (rank {}): {}'.format( + args.rank, args.dist_url), flush=True) + torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, + world_size=args.world_size, rank=args.rank) + setup_for_distributed(args.rank == 0) diff --git a/test/test_models.py b/test/test_models.py index 083ec42a6cc..586eeb258e9 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -5,13 +5,18 @@ import unittest -def get_available_models(): +def get_available_classification_models(): # TODO add a registration mechanism to torchvision.models - return [k for k, v in models.__dict__.items() if callable(v) and k[0].lower() == k[0]] + return [k for k, v in models.__dict__.items() if callable(v) and k[0].lower() == k[0] and k[0] != "_"] + + +def get_available_segmentation_models(): + # TODO add a registration mechanism to torchvision.models + return [k for k, v in models.segmentation.__dict__.items() if callable(v) and k[0].lower() == k[0] and k[0] != "_"] class Tester(unittest.TestCase): - def _test_model(self, name, input_shape): + def _test_classification_model(self, name, input_shape): # passing num_class equal to a number other than 1000 helps in making the test # more enforcing in nature model = models.__dict__[name](num_classes=50) @@ -20,6 +25,16 @@ def _test_model(self, name, input_shape): out = model(x) self.assertEqual(out.shape[-1], 50) + def _test_segmentation_model(self, name): + # passing num_class equal to a number other than 1000 helps in making the test + # more enforcing in nature + model = models.segmentation.__dict__[name](num_classes=50, pretrained_backbone=False) + model.eval() + input_shape = (1, 3, 300, 300) + x = torch.rand(input_shape) + out = model(x) + self.assertEqual(tuple(out["out"].shape), (1, 50, 300, 300)) + def _make_sliced_model(self, model, stop_layer): layers = OrderedDict() for name, layer in model.named_children(): @@ -41,14 +56,23 @@ def test_resnet_dilation(self): self.assertEqual(out.shape, (1, 2048, 7 * f, 7 * f)) -for model_name in get_available_models(): +for model_name in get_available_classification_models(): # for-loop bodies don't define scopes, so we have to save the variables # we want to close over in some way def do_test(self, model_name=model_name): input_shape = (1, 3, 224, 224) if model_name in ['inception_v3']: input_shape = (1, 3, 299, 299) - self._test_model(model_name, input_shape) + self._test_classification_model(model_name, input_shape) + + setattr(Tester, "test_" + model_name, do_test) + + +for model_name in get_available_segmentation_models(): + # for-loop bodies don't define scopes, so we have to save the variables + # we want to close over in some way + def do_test(self, model_name=model_name): + self._test_segmentation_model(model_name) setattr(Tester, "test_" + model_name, do_test) diff --git a/torchvision/models/__init__.py b/torchvision/models/__init__.py index 727aed44dfb..fb3d1747165 100644 --- a/torchvision/models/__init__.py +++ b/torchvision/models/__init__.py @@ -7,3 +7,4 @@ from .googlenet import * from .mobilenet import * from .shufflenetv2 import * +from . import segmentation diff --git a/torchvision/models/_utils.py b/torchvision/models/_utils.py new file mode 100644 index 00000000000..227fae46dc0 --- /dev/null +++ b/torchvision/models/_utils.py @@ -0,0 +1,41 @@ +from collections import OrderedDict + +import torch +from torch import nn + + +# TODO should we remove the unused parameters or not? +class IntermediateLayerGetter(nn.ModuleDict): + """ + Module wrapper that returns intermediate layers from a model + + It has a strong assumption that the modules have been registered + into the model in the same order as they are used. + This means that one should **not** reuse the same nn.Module + twice in the forward if you want this to work + """ + def __init__(self, model, return_layers): + if not set(return_layers).issubset([name for name, _ in model.named_children()]): + raise ValueError("return_layers are not present in model") + + orig_return_layers = return_layers + return_layers = {k: v for k, v in return_layers.items()} + layers = OrderedDict() + for name, module in model.named_children(): + layers[name] = module + if name in return_layers: + del return_layers[name] + if not return_layers: + break + + super(IntermediateLayerGetter, self).__init__(layers) + self.return_layers = orig_return_layers + + def forward(self, x): + out = OrderedDict() + for name, module in self.named_children(): + x = module(x) + if name in self.return_layers: + out_name = self.return_layers[name] + out[out_name] = x + return out diff --git a/torchvision/models/deeplabv3.py b/torchvision/models/deeplabv3.py new file mode 100644 index 00000000000..d2ce18e62ed --- /dev/null +++ b/torchvision/models/deeplabv3.py @@ -0,0 +1,137 @@ +from collections import OrderedDict + +import torch +from torch import nn +from torch.nn import functional as F + + +class _SimpleSegmentationModel(nn.Module): + def __init__(self, backbone, classifier, aux_classifier=None): + super(_SimpleSegmentationModel, self).__init__() + self.backbone = backbone + self.classifier = classifier + self.aux_classifier = aux_classifier + + def forward(self, x): + input_shape = x.shape[-2:] + # contract: features is a dict of tensors + features = self.backbone(x) + + result = OrderedDict() + x = features["out"] + x = self.classifier(x) + x = F.interpolate(x, size=input_shape, mode='bilinear', align_corners=False) + result["out"] = x + + if self.aux_classifier is not None: + x = features["aux"] + x = self.aux_classifier(x) + x = F.interpolate(x, size=input_shape, mode='bilinear', align_corners=False) + result["aux"] = x + + return result + + +class FCN(_SimpleSegmentationModel): + pass + + +class DeepLabV3(_SimpleSegmentationModel): + pass + + +class FCNHead(nn.Sequential): + def __init__(self, in_channels, channels): + inter_channels = in_channels // 4 + layers = [ + nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False), + nn.BatchNorm2d(inter_channels), + nn.ReLU(), + nn.Dropout(0.1), + nn.Conv2d(inter_channels, channels, 1) + ] + + super(FCNHead, self).__init__(*layers) + """ + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + """ + + +class DeepLabHead(nn.Sequential): + def __init__(self, in_channels, num_classes): + super(DeepLabHead, self).__init__( + ASPP(in_channels, [12, 24, 36]), + nn.Conv2d(256, 256, 3, padding=1, bias=False), + nn.BatchNorm2d(256), + nn.ReLU(), + nn.Conv2d(256, num_classes, 1) + ) + """ + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + """ + + +class ASPPConv(nn.Sequential): + def __init__(self, in_channels, out_channels, dilation): + modules = [ + nn.Conv2d(in_channels, out_channels, 3, padding=dilation, dilation=dilation, bias=False), + nn.BatchNorm2d(out_channels), + nn.ReLU() + ] + super(ASPPConv, self).__init__(*modules) + + +class ASPPPooling(nn.Sequential): + def __init__(self, in_channels, out_channels): + super(ASPPPooling, self).__init__( + nn.AdaptiveAvgPool2d(1), + nn.Conv2d(in_channels, out_channels, 1, bias=False), + nn.BatchNorm2d(out_channels), + nn.ReLU()) + + def forward(self, x): + size = x.shape[-2:] + x = super(ASPPPooling, self).forward(x) + return F.interpolate(x, size=size, mode='bilinear', align_corners=False) + + +class ASPP(nn.Module): + def __init__(self, in_channels, atrous_rates): + super(ASPP, self).__init__() + out_channels = 256 + modules = [] + modules.append(nn.Sequential( + nn.Conv2d(in_channels, out_channels, 1, bias=False), + nn.BatchNorm2d(out_channels), + nn.ReLU())) + + rate1, rate2, rate3 = tuple(atrous_rates) + modules.append(ASPPConv(in_channels, out_channels, rate1)) + modules.append(ASPPConv(in_channels, out_channels, rate2)) + modules.append(ASPPConv(in_channels, out_channels, rate3)) + modules.append(ASPPPooling(in_channels, out_channels)) + + self.convs = nn.ModuleList(modules) + + self.project = nn.Sequential( + nn.Conv2d(5 * out_channels, out_channels, 1, bias=False), + nn.BatchNorm2d(out_channels), + nn.ReLU(), + nn.Dropout(0.5)) + + def forward(self, x): + res = [] + for conv in self.convs: + res.append(conv(x)) + res = torch.cat(res, dim=1) + return self.project(res) diff --git a/torchvision/models/segmentation.py b/torchvision/models/segmentation.py new file mode 100644 index 00000000000..766df316b8c --- /dev/null +++ b/torchvision/models/segmentation.py @@ -0,0 +1,58 @@ +from ._utils import IntermediateLayerGetter +from . import resnet +from .deeplabv3 import FCN, FCNHead, DeepLabHead, DeepLabV3 + + +def _segm_resnet(name, backbone_name, num_classes, aux, pretrained_backbone=True): + backbone = resnet.__dict__[backbone_name]( + pretrained=pretrained_backbone, + replace_stride_with_dilation=[False, True, True]) + + return_layers = {'layer4': 'out'} + if aux: + return_layers['layer3'] = 'aux' + backbone = IntermediateLayerGetter(backbone, return_layers=return_layers) + + aux_classifier = None + if aux: + inplanes = 1024 + aux_classifier = FCNHead(inplanes, num_classes) + + model_map = { + 'deeplab': (DeepLabHead, DeepLabV3), + 'fcn': (FCNHead, FCN), + } + inplanes = 2048 + classifier = model_map[name][0](inplanes, num_classes) + base_model = model_map[name][1] + + model = base_model(backbone, classifier, aux_classifier) + return model + + +def fcn_resnet50(pretrained=False, num_classes=21, aux_loss=None, **kwargs): + model = _segm_resnet("fcn", "resnet50", num_classes, aux_loss, **kwargs) + if pretrained: + pass + return model + + +def fcn_resnet101(pretrained=False, num_classes=21, aux_loss=None, **kwargs): + model = _segm_resnet("fcn", "resnet101", num_classes, aux_loss, **kwargs) + if pretrained: + pass + return model + + +def deeplabv3_resnet50(pretrained=False, num_classes=21, aux_loss=None, **kwargs): + model = _segm_resnet("deeplab", "resnet50", num_classes, aux_loss, **kwargs) + if pretrained: + pass + return model + + +def deeplabv3_resnet101(pretrained=False, num_classes=21, aux_loss=None, **kwargs): + model = _segm_resnet("deeplab", "resnet101", num_classes, aux_loss, **kwargs) + if pretrained: + pass + return model