From 39a7a2fb6015792b6c32fd7ee3eeac9774f712a4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Michael=20K=C3=B6sel?= Date: Fri, 14 Jun 2019 18:40:04 +0200 Subject: [PATCH] Add some utility functions to Cityscapes --- torchvision/datasets/cityscapes.py | 56 +++++++++++++++++++++++++++++- 1 file changed, 55 insertions(+), 1 deletion(-) diff --git a/torchvision/datasets/cityscapes.py b/torchvision/datasets/cityscapes.py index 4c801577a0d..0d4a7053570 100644 --- a/torchvision/datasets/cityscapes.py +++ b/torchvision/datasets/cityscapes.py @@ -2,9 +2,11 @@ import os from collections import namedtuple -from .vision import VisionDataset +import torch from PIL import Image +from .vision import VisionDataset + class Cityscapes(VisionDataset): """`Cityscapes `_ Dataset. @@ -174,6 +176,58 @@ def __getitem__(self, index): def __len__(self): return len(self.images) + @staticmethod + def convert_id_to_train_id(target): + target_copy = target.clone() + + for cls in Cityscapes.classes: + target_copy[target == cls.id] = cls.train_id + + return target_copy + + @staticmethod + def convert_train_id_to_id(target): + target_copy = target.clone() + + for cls in Cityscapes.classes: + target_copy[target == cls.train_id] = cls.id + + return target_copy + + @staticmethod + def get_class_from_name(name): + for cls in Cityscapes.classes: + if cls.name == name: + return cls + return None + + @staticmethod + def get_class_from_id(id): + for cls in Cityscapes.classes: + if cls.id == id: + return cls + return None + + @staticmethod + def get_class_from_train_id(train_id): + for cls in Cityscapes.classes: + if cls.train_id == train_id: + return cls + return None + + @staticmethod + def get_colormap(): + cmap = torch.zeros([256, 3], dtype=torch.uint8) + + for cls in Cityscapes.classes: + cmap[cls.id, :] = torch.tensor(cls.color) + + return cmap + + @staticmethod + def num_classes(): + return len([cls for cls in Cityscapes.classes if not cls.ignore_in_eval]) + def extra_repr(self): lines = ["Split: {split}", "Mode: {mode}", "Type: {target_type}"] return '\n'.join(lines).format(**self.__dict__)