|
| 1 | +import io |
| 2 | +from typing import Any, Callable, Dict, List, Optional, Tuple |
| 3 | + |
| 4 | +import torch |
| 5 | +from torchdata.datapipes.iter import ( |
| 6 | + IterDataPipe, |
| 7 | + Mapper, |
| 8 | + Shuffler, |
| 9 | + CSVParser, |
| 10 | +) |
| 11 | +from torchvision.prototype.datasets.decoder import raw |
| 12 | +from torchvision.prototype.datasets.utils import ( |
| 13 | + Dataset, |
| 14 | + DatasetConfig, |
| 15 | + DatasetInfo, |
| 16 | + HttpResource, |
| 17 | + OnlineResource, |
| 18 | + DatasetType, |
| 19 | +) |
| 20 | +from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE, image_buffer_from_array |
| 21 | + |
| 22 | + |
class SEMEION(Dataset):
    """Prototype dataset definition for the Semeion handwritten digit data.

    Every row of the downloaded file is parsed as space-delimited fields. The
    first 256 fields are turned into a 16x16 image and the remaining non-empty
    fields are read as a one-hot encoded label.
    """

    def _make_info(self) -> DatasetInfo:
        """Return the static metadata (name, type, categories, homepage)."""
        return DatasetInfo(
            "semeion",
            type=DatasetType.RAW,
            categories=10,
            homepage="https://archive.ics.uci.edu/ml/datasets/Semeion+Handwritten+Digit",
        )

    def resources(self, config: DatasetConfig) -> List[OnlineResource]:
        """Return the single data file to download.

        ``config`` is accepted for interface compatibility but is not used here.
        """
        archive = HttpResource(
            "http://archive.ics.uci.edu/ml/machine-learning-databases/semeion/semeion.data",
            sha256="f43228ae3da5ea6a3c95069d53450b86166770e3b719dcc333182128fe08d4b1",
        )
        return [archive]

    def _collate_and_decode_sample(
        self,
        data: Tuple[str, ...],
        *,
        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
    ) -> Dict[str, Any]:
        """Convert one parsed row into a sample dictionary.

        Args:
            data: The string fields of one row. The first 256 are the pixel
                values, everything after that is the one-hot label.
            decoder: ``raw`` to get the pixel tensor directly, any other
                callable to decode an image buffer built from the pixels, or
                ``None`` to get the undecoded image buffer.

        Returns:
            A dictionary with the keys ``image``, ``label`` and ``category``.

        Raises:
            ValueError: If no entry of the one-hot label is set.
        """
        image_data = torch.tensor([float(pixel) for pixel in data[:256]], dtype=torch.uint8).reshape(16, 16)
        # Empty fields are dropped before conversion (presumably they stem from a
        # trailing delimiter at the end of each row -- TODO confirm against the file).
        label_data = [int(label) for label in data[256:] if label]

        if decoder is raw:
            # Add a leading dimension in front of the 16x16 pixel grid.
            image = image_data.unsqueeze(0)
        else:
            image_buffer = image_buffer_from_array(image_data.numpy())
            image = decoder(image_buffer) if decoder else image_buffer  # type: ignore[assignment]

        # A bare ``next()`` would raise ``StopIteration`` for a row without an
        # active label entry. Since this method is driven by a datapipe, that
        # exception could be mistaken for regular exhaustion or be converted into
        # an opaque ``RuntimeError`` by a generator-based caller (PEP 479), so
        # fail with an explicit error instead.
        label = next((idx for idx, one_hot_label in enumerate(label_data) if one_hot_label), None)
        if label is None:
            raise ValueError(f"Sample has no active entry in its one-hot label: {label_data}")
        category = self.info.categories[label]
        return dict(image=image, label=label, category=category)

    def _make_datapipe(
        self,
        resource_dps: List[IterDataPipe],
        *,
        config: DatasetConfig,
        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
    ) -> IterDataPipe[Dict[str, Any]]:
        """Build the sample datapipe: parse rows, shuffle, then collate and decode.

        Only the first entry of ``resource_dps`` is consumed; ``config`` is not
        used here.
        """
        dp = resource_dps[0]
        dp = CSVParser(dp, delimiter=" ")
        dp = Shuffler(dp, buffer_size=INFINITE_BUFFER_SIZE)
        dp = Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(decoder=decoder))
        return dp
0 commit comments