@register_mock
def sbu(info, root, config):
    """Build a fake SBU Captioned Photos archive under ``root``.

    Creates ``dataset/`` holding the two official text files (photo URLs and
    their captions, one per line, line i of each file belonging together),
    tars the folder into ``SBUCaptionedPhotoDataset.tar.gz``, and returns the
    number of fake samples.
    """
    num_samples = 10

    dataset_folder = pathlib.Path(root) / "dataset"
    dataset_folder.mkdir(parents=True)

    # Fabricate matching URL / caption lines up front; the placeholder size is
    # random so repeated mock builds do not all look identical.
    urls = [f"https://via.placeholder.com/{random.randint(100, 1000)}.jpg" for _ in range(num_samples)]
    captions = [f"Caption {i} for {url}" for i, url in enumerate(urls)]

    with open(dataset_folder / "SBU_captioned_photo_dataset_urls.txt", "w") as fh:
        fh.write("\n".join(urls))
    with open(dataset_folder / "SBU_captioned_photo_dataset_captions.txt", "w") as fh:
        fh.write("\n".join(captions))

    make_tar(root, "SBUCaptionedPhotoDataset.tar.gz", dataset_folder, compression="gz")
    return num_samples
@parametrize_dataset_mocks(DATASET_MOCKS["sbu"])
class TestSBU:
    def test_sample_content(self, test_home, dataset_mock, config):
        """Every SBU sample must carry an encoded image and a string caption."""
        dataset_mock.prepare(test_home, config)

        dataset = datasets.load(dataset_mock.name, **config)

        for sample in dataset:
            # Keys must exist before we can inspect their types.
            assert "image" in sample and "caption" in sample

            assert isinstance(sample["image"], EncodedImage)
            assert isinstance(sample["caption"], str)
folder / "dataset" + image_folder = data_folder / "images" + image_folder.mkdir() + broken_urls = [] + with open(data_folder / "SBU_captioned_photo_dataset_urls.txt") as fh: + urls = fh.read().splitlines() + + # TODO: Use workers to download images + for url in tqdm(urls): + try: + # TODO: suppress print statements within HttpResource.download() + HttpResource(url).download(image_folder) + except Exception: + broken_urls.append(url) + + if broken_urls: + broken_urls_file = folder.parent / "broken_urls.txt" + warnings.warn( + f"Failed to download {len(broken_urls)} ({len(broken_urls) / len(urls):.2%}) images. " + f"They are logged in {broken_urls_file}." + ) + with open(broken_urls_file, "w") as fh: + fh.write("\n".join(broken_urls) + "\n") + + return folder + + def resources(self, config: DatasetConfig) -> List[OnlineResource]: + return [ + HttpResource( + "http://www.cs.virginia.edu/~vicente/sbucaptions/SBUCaptionedPhotoDataset.tar.gz", + sha256="2bf37d5e1c9e1c6eae7d5103030d58a7f2117fc5e8c6aa9620f0df165acebf09", + preprocess=self._preprocess, + ) + ] + + def _classify_files(self, data: Tuple[str, Any]) -> Optional[int]: + path = pathlib.Path(data[0]) + if path.parent.name == "images": + return 0 + elif path.name == "SBU_captioned_photo_dataset_urls.txt": + return 1 + elif path.name == "SBU_captioned_photo_dataset_captions.txt": + return 2 + else: + return None + + def _make_datapipe( + self, + resource_dps: List[IterDataPipe], + *, + config: DatasetConfig, + ) -> IterDataPipe[Dict[str, Any]]: + + images_dp, urls_dp, captions_dp = Demultiplexer( + resource_dps[0], 3, self._classify_files, drop_none=True, buffer_size=INFINITE_BUFFER_SIZE + ) + + images_dp = hint_shuffling(images_dp) + images_dp = hint_sharding(images_dp) + + urls_dp = LineReader(urls_dp, decode=True, return_path=False) + captions_dp = LineReader(captions_dp, decode=True, return_path=False) + anns_dp = Zipper(urls_dp, captions_dp) + + dp = IterKeyZipper(images_dp, anns_dp, path_accessor("name"), 
def _prepare_sample(self, data: Tuple[Tuple[str, BinaryIO], Tuple[str, str]]) -> Dict[str, Any]:
    """Turn a matched (image file, (url, caption)) pair into the sample dict.

    The URL half of the annotation is discarded; the caption is stripped of
    surrounding whitespace and the image buffer is wrapped as an EncodedImage.
    """
    image_data, ann_data = data
    path, buffer = image_data
    _, caption = ann_data
    return {
        "path": path,
        "image": EncodedImage.from_file(buffer),
        "caption": caption.strip(),
    }