From b935c0279bbcfe3e8de7bc3ea83a9e349caf7826 Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Fri, 10 May 2019 19:43:37 -0500 Subject: [PATCH 1/5] Add ssTEM dataset --- docs/source/datasets.rst | 6 ++ torchvision/datasets/__init__.py | 1 + torchvision/datasets/sstem.py | 134 +++++++++++++++++++++++++++++++ 3 files changed, 141 insertions(+) create mode 100644 torchvision/datasets/sstem.py diff --git a/docs/source/datasets.rst b/docs/source/datasets.rst index 88260d4c018..82fff063edb 100644 --- a/docs/source/datasets.rst +++ b/docs/source/datasets.rst @@ -188,3 +188,9 @@ SBD .. autoclass:: SBDataset :members: __getitem__ :special-members: + +ssTEM +~~~~~ + +.. note :: + This requires `skimage` to be installed diff --git a/torchvision/datasets/__init__.py b/torchvision/datasets/__init__.py index b3ce206f330..c2ec39cdc0b 100644 --- a/torchvision/datasets/__init__.py +++ b/torchvision/datasets/__init__.py @@ -18,6 +18,7 @@ from .celeba import CelebA from .sbd import SBDataset from .vision import VisionDataset +from .sstem import ssTEM __all__ = ('LSUN', 'LSUNClass', 'ImageFolder', 'DatasetFolder', 'FakeData', diff --git a/torchvision/datasets/sstem.py b/torchvision/datasets/sstem.py new file mode 100644 index 00000000000..d99dad6fc23 --- /dev/null +++ b/torchvision/datasets/sstem.py @@ -0,0 +1,134 @@ +from __future__ import print_function + +import os + +from PIL import Image +import skimage +import torch.utils.data as data + +from .utils import download_url, check_integrity + + +class ssTEM(data.Dataset): + """Dataset for `ISBI Challenge: Segmentation of neuronal structures + in EM stacks `_. + + Args: + root (string): Root directory where dataset exists or will be saved + to if download is set to True. + train (bool, optional): If True, creates dataset from training set, + otherwise creates from test set. + transform (callable, optional): A function/transform that takes in a + PIL image and returns a transformed version. + target_transform (callable, optional): A function/transform that takes + in the target and transforms it. + download (bool, optional): If true, downloads the dataset from the + internet and puts it in root directory. If dataset is already + downloaded, it is not downloaded again. + """ + base_url = 'http://brainiac2.mit.edu/isbi_challenge/sites/default/files/' + + meta = { + 'train': { + 'data': { + 'filename': 'train-volume.tif', + 'md5': '465461edbe0254630c4ec5577f1e7764' + }, + 'labels': { + 'filename': 'train-labels.tif', + 'md5': '657fe6b728c6dd0152e295c6d800001d' + } + }, + 'test': { + 'data': { + 'filename': 'test-volume.tif', + 'md5': '9767660d7abe4e0ecbdd0061a16058ad' + } + } + } + + def __init__(self, root, train=True, transform=None, target_transform=None, + download=False): + self.root = os.path.expanduser(root) + self.train = train + self.transform = transform + self.target_transform = target_transform + + if download: + self.download() + + if not self._check_integrity(): + raise RuntimeError('Dataset not found or corrupted.' + + ' You can use download=True to download it') + + if train: + self.data = skimage.io.imread(os.path.join( + self.root, self.meta['train']['data']['filename'])) + self.labels = skimage.io.imread(os.path.join( + self.root, self.meta['train']['labels']['filename'])) + else: + self.data = skimage.io.imread(os.path.join( + self.root, self.meta['test']['data']['filename'])) + + def __getitem__(self, index): + """ + Args: + index (int): Index + Returns: + tuple: (image, target) where target is index of the target class. + """ + img = self.data[index] + img = Image.fromarray(img) + + if self.transform is not None: + img = self.transform(img) + + if not self.train: + return img + + target = self.labels[index] + target = Image.fromarray(img) + + if self.target_transform is not None: + target = self.target_transform(target) + + return img, target + + def __len__(self): + return len(self.data) + + def _check_integrity(self): + if self.train: + dataset = 'train' + records = ['data', 'labels'] + else: + dataset = 'test' + records = ['data'] + + for record in records: + filename = self.meta[dataset][record]['filename'] + fpath = os.path.join(self.root, filename) + md5 = self.meta[dataset][record]['md5'] + + if not check_integrity(fpath, md5): + return False + + return True + + def download(self): + if self._check_integrity(): + print('Files already downloaded and verified') + return + + if self.train: + dataset = 'train' + records = ['data', 'labels'] + else: + dataset = 'test' + records = ['data'] + + for record in records: + filename = self.meta[dataset][record]['filename'] + md5 = self.meta[dataset][record]['md5'] + + download_url(self.base_url + filename, self.root, filename, md5) From a8fc5ec5f7a17afff63c389062114b24cdb259aa Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Fri, 10 May 2019 19:48:53 -0500 Subject: [PATCH 2/5] Documentation fix --- docs/source/datasets.rst | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/docs/source/datasets.rst b/docs/source/datasets.rst index 82fff063edb..71fe2f8b269 100644 --- a/docs/source/datasets.rst +++ b/docs/source/datasets.rst @@ -194,3 +194,7 @@ ssTEM .. note :: This requires `skimage` to be installed + +.. autoclass:: ssTEM + :members: __getitem__ + :special-members: From 625f5467c70302486efb749ea1f28818dabbd081 Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Sat, 11 May 2019 16:05:20 -0500 Subject: [PATCH 3/5] Changes requested during review --- torchvision/datasets/sstem.py | 35 +++++++++++++++-------------------- 1 file changed, 15 insertions(+), 20 deletions(-) diff --git a/torchvision/datasets/sstem.py b/torchvision/datasets/sstem.py index d99dad6fc23..91d59890387 100644 --- a/torchvision/datasets/sstem.py +++ b/torchvision/datasets/sstem.py @@ -3,13 +3,13 @@ import os from PIL import Image -import skimage import torch.utils.data as data from .utils import download_url, check_integrity +from .vision import VisionDataset -class ssTEM(data.Dataset): +class ssTEM(VisionDataset): """Dataset for `ISBI Challenge: Segmentation of neuronal structures in EM stacks `_. @@ -18,10 +18,8 @@ class ssTEM(data.Dataset): to if download is set to True. train (bool, optional): If True, creates dataset from training set, otherwise creates from test set. - transform (callable, optional): A function/transform that takes in a - PIL image and returns a transformed version. - target_transform (callable, optional): A function/transform that takes - in the target and transforms it. + transforms (callable, optional): A function/transform that takes in + two PIL images and returns transformed versions. download (bool, optional): If true, downloads the dataset from the internet and puts it in root directory. If dataset is already downloaded, it is not downloaded again. @@ -47,12 +45,13 @@ class ssTEM(data.Dataset): } } - def __init__(self, root, train=True, transform=None, target_transform=None, - download=False): + def __init__(self, root, train=True, transforms=None, download=False): + # Lazy import + import skimage.io + self.root = os.path.expanduser(root) self.train = train - self.transform = transform - self.target_transform = target_transform + self.transforms = transforms if download: self.download() @@ -80,17 +79,13 @@ def __getitem__(self, index): img = self.data[index] img = Image.fromarray(img) - if self.transform is not None: - img = self.transform(img) - - if not self.train: - return img - - target = self.labels[index] - target = Image.fromarray(img) + target = None + if self.train: + target = self.labels[index] + target = Image.fromarray(img) - if self.target_transform is not None: - target = self.target_transform(target) + if self.transforms is not None: + img, target = self.transforms(img, target) return img, target From bd1dce26434468771eeaa670fd927bb5772e68a5 Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Sun, 12 May 2019 08:21:37 -0500 Subject: [PATCH 4/5] Call __init__ of super class --- torchvision/datasets/sstem.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torchvision/datasets/sstem.py b/torchvision/datasets/sstem.py index 91d59890387..8cc11af5ba3 100644 --- a/torchvision/datasets/sstem.py +++ b/torchvision/datasets/sstem.py @@ -49,6 +49,8 @@ def __init__(self, root, train=True, transforms=None, download=False): # Lazy import import skimage.io + super(ssTEM, self).__init__(root, transforms) + self.root = os.path.expanduser(root) self.train = train self.transforms = transforms From b36b410ec4dddaf70fba325686b1dd9c4f379dba Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Sun, 12 May 2019 18:21:24 -0500 Subject: [PATCH 5/5] Remove lines obsoleted by super class init --- torchvision/datasets/sstem.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/torchvision/datasets/sstem.py b/torchvision/datasets/sstem.py index 8cc11af5ba3..83281dd1ab7 100644 --- a/torchvision/datasets/sstem.py +++ b/torchvision/datasets/sstem.py @@ -51,9 +51,7 @@ def __init__(self, root, train=True, transforms=None, download=False): super(ssTEM, self).__init__(root, transforms) - self.root = os.path.expanduser(root) self.train = train - self.transforms = transforms if download: self.download()