Skip to content

Commit 4946827

Browse files
NicolasHug and pmeier
authored
Add support for PCAM dataset (#5203)
* Add support for PCAM dataset
* mypy
* Apply suggestions from code review
* Remove classes and class_to_idx attributes
* Use _decompress

Co-authored-by: Philip Meier <[email protected]>
1 parent 5e56575 commit 4946827

File tree

8 files changed

+160
-2
lines changed

8 files changed

+160
-2
lines changed

.circleci/unittest/linux/scripts/environment.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ dependencies:
99
- libpng
1010
- jpeg
1111
- ca-certificates
12+
- h5py
1213
- pip:
1314
- future
1415
- pillow >=5.3.0, !=8.3.*

.circleci/unittest/windows/scripts/environment.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ dependencies:
99
- libpng
1010
- jpeg
1111
- ca-certificates
12+
- h5py
1213
- pip:
1314
- future
1415
- pillow >=5.3.0, !=8.3.*

docs/source/datasets.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ You can also create your own datasets using the provided :ref:`base classes <bas
6666
MNIST
6767
Omniglot
6868
OxfordIIITPet
69+
PCAM
6970
PhotoTour
7071
Places365
7172
QMNIST

test/datasets_utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ class LazyImporter:
6161
"requests",
6262
"scipy.io",
6363
"scipy.sparse",
64+
"h5py",
6465
)
6566

6667
def __init__(self):

test/test_datasets.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2577,5 +2577,28 @@ def inject_fake_data(self, tmpdir: str, config):
25772577
return num_images_per_split[config["split"]]
25782578

25792579

2580+
class PCAMTestCase(datasets_utils.ImageDatasetTestCase):
    DATASET_CLASS = datasets.PCAM

    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "val", "test"))
    REQUIRED_PACKAGES = ("h5py",)

    def inject_fake_data(self, tmpdir: str, config):
        """Create tiny fake HDF5 image/target files for the requested split and return the image count."""
        root = pathlib.Path(tmpdir) / "pcam"
        root.mkdir()

        split = config["split"]
        num_images = {"train": 2, "test": 3, "val": 4}[split]

        h5py = datasets_utils.lazy_importer.h5py

        # Fake RGB patches: 10x10 instead of the real 96x96 keeps the test light.
        with h5py.File(str(root / datasets.PCAM._FILES[split]["images"][0]), "w") as images:
            images["x"] = np.random.randint(0, 256, size=(num_images, 10, 10, 3), dtype=np.uint8)

        # Binary labels stored in the dataset's [num_images, 1, 1, 1] layout.
        with h5py.File(str(root / datasets.PCAM._FILES[split]["targets"][0]), "w") as targets:
            targets["y"] = np.random.randint(0, 2, size=(num_images, 1, 1, 1), dtype=np.uint8)

        return num_images
2602+
25802603
if __name__ == "__main__":
25812604
unittest.main()

torchvision/datasets/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from .mnist import MNIST, EMNIST, FashionMNIST, KMNIST, QMNIST
2626
from .omniglot import Omniglot
2727
from .oxford_iiit_pet import OxfordIIITPet
28+
from .pcam import PCAM
2829
from .phototour import PhotoTour
2930
from .places365 import Places365
3031
from .sbd import SBDataset

torchvision/datasets/oxford_iiit_pet.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,8 @@ class OxfordIIITPet(VisionDataset):
2727
transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed
2828
version. E.g, ``transforms.RandomCrop``.
2929
target_transform (callable, optional): A function/transform that takes in the target and transforms it.
30-
download (bool, optional): If True, downloads the dataset from the internet and puts it into ``root/dtd``. If
31-
dataset is already downloaded, it is not downloaded again.
30+
download (bool, optional): If True, downloads the dataset from the internet and puts it into
31+
``root/oxford-iiit-pet``. If dataset is already downloaded, it is not downloaded again.
3232
"""
3333

3434
_RESOURCES = (

torchvision/datasets/pcam.py

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
import pathlib
2+
from typing import Any, Callable, Optional, Tuple
3+
4+
from PIL import Image
5+
6+
from .utils import download_file_from_google_drive, _decompress, verify_str_arg
7+
from .vision import VisionDataset
8+
9+
10+
class PCAM(VisionDataset):
    """`PCAM Dataset <https://github.com/basveeling/pcam>`_.

    The PatchCamelyon dataset is a binary classification dataset with 327,680
    color images (96px x 96px), extracted from histopathologic scans of lymph node
    sections. Each image is annotated with a binary label indicating presence of
    metastatic tissue.

    This dataset requires the ``h5py`` package which you can install with ``pip install h5py``.

    Args:
        root (string): Root directory of the dataset.
        split (string, optional): The dataset split, supports ``"train"`` (default), ``"test"`` or ``"val"``.
        transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed
            version. E.g, ``transforms.RandomCrop``.
        target_transform (callable, optional): A function/transform that takes in the target and transforms it.
        download (bool, optional): If True, downloads the dataset from the internet and puts it into ``root/pcam``. If
            dataset is already downloaded, it is not downloaded again.
    """

    # Per split: "images" and "targets" each map to
    # (HDF5 file name, Google Drive ID, md5 hash of the decompressed file).
    _FILES = {
        "train": {
            "images": (
                "camelyonpatch_level_2_split_train_x.h5",  # Data file name
                "1Ka0XfEMiwgCYPdTI-vv6eUElOBnKFKQ2",  # Google Drive ID
                "1571f514728f59376b705fc836ff4b63",  # md5 hash
            ),
            "targets": (
                "camelyonpatch_level_2_split_train_y.h5",
                "1269yhu3pZDP8UYFQs-NYs3FPwuK-nGSG",
                "35c2d7259d906cfc8143347bb8e05be7",
            ),
        },
        "test": {
            "images": (
                "camelyonpatch_level_2_split_test_x.h5",
                "1qV65ZqZvWzuIVthK8eVDhIwrbnsJdbg_",
                "d5b63470df7cfa627aeec8b9dc0c066e",
            ),
            "targets": (
                "camelyonpatch_level_2_split_test_y.h5",
                "17BHrSrwWKjYsOgTMmoqrIjDy6Fa2o_gP",
                "2b85f58b927af9964a4c15b8f7e8f179",
            ),
        },
        "val": {
            "images": (
                "camelyonpatch_level_2_split_valid_x.h5",
                "1hgshYGWK8V-eGRy8LToWJJgDU_rXWVJ3",
                "d8c2d60d490dbd479f8199bdfa0cf6ec",
            ),
            "targets": (
                "camelyonpatch_level_2_split_valid_y.h5",
                "1bH8ZRbhSVAhScTS0p9-ZzGnX91cHT3uO",
                "60a7035772fbdb7f34eb86d4420cf66a",
            ),
        },
    }

    def __init__(
        self,
        root: str,
        split: str = "train",
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = True,
    ) -> None:
        try:
            import h5py  # type: ignore[import]

            self.h5py = h5py
        except ImportError:
            # ``from None`` suppresses the uninformative ImportError traceback
            # in favor of the actionable installation hint.
            raise RuntimeError(
                "h5py is not found. This dataset needs to have h5py installed: please run pip install h5py"
            ) from None

        self._split = verify_str_arg(split, "split", ("train", "test", "val"))

        super().__init__(root, transform=transform, target_transform=target_transform)
        self._base_folder = pathlib.Path(self.root) / "pcam"

        if download:
            self._download()

        if not self._check_exists():
            raise RuntimeError("Dataset not found. You can use download=True to download it")

    def __len__(self) -> int:
        """Return the number of images in the current split."""
        images_file = self._FILES[self._split]["images"][0]
        # Open read-only: older h5py versions default to a writable mode, which
        # fails on read-only mounts and is unsafe with concurrent readers.
        with self.h5py.File(self._base_folder / images_file, "r") as images_data:
            return images_data["x"].shape[0]

    def __getitem__(self, idx: int) -> Tuple[Any, Any]:
        """Return ``(image, target)`` for index ``idx``, with transforms applied.

        NOTE(review): the HDF5 files are re-opened on every access rather than
        kept open — presumably to keep the dataset object picklable for
        multi-process data loading; confirm before changing.
        """
        images_file = self._FILES[self._split]["images"][0]
        with self.h5py.File(self._base_folder / images_file, "r") as images_data:
            image = Image.fromarray(images_data["x"][idx]).convert("RGB")

        targets_file = self._FILES[self._split]["targets"][0]
        with self.h5py.File(self._base_folder / targets_file, "r") as targets_data:
            target = int(targets_data["y"][idx, 0, 0, 0])  # shape is [num_images, 1, 1, 1]

        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            target = self.target_transform(target)

        return image, target

    def _check_exists(self) -> bool:
        """Return True if both HDF5 files for the current split are on disk."""
        images_file = self._FILES[self._split]["images"][0]
        targets_file = self._FILES[self._split]["targets"][0]
        return all(self._base_folder.joinpath(h5_file).exists() for h5_file in (images_file, targets_file))

    def _download(self) -> None:
        """Download the gzipped HDF5 files for the current split and decompress them in place."""
        if self._check_exists():
            return

        for file_name, file_id, md5 in self._FILES[self._split].values():
            archive_name = file_name + ".gz"
            download_file_from_google_drive(file_id, str(self._base_folder), filename=archive_name, md5=md5)
            _decompress(str(self._base_folder / archive_name))

0 commit comments

Comments
 (0)