Skip to content

Commit 140322f

Browse files
krshrimali, pmeier, datumbox
authored
Port semeion dataset to prototype namespace (#4840)
* Port semeion dataset
* Update torchvision/prototype/datasets/_builtin/semeion.py
  Co-authored-by: Philip Meier <[email protected]>
* explicitly convert the image array to torch.uint8
* explicitly convert the image array to torch.uint8
Co-authored-by: Philip Meier <[email protected]>
Co-authored-by: Vasilis Vryniotis <[email protected]>
1 parent e8ceaaf commit 140322f

File tree

2 files changed

+70
-0
lines changed

2 files changed

+70
-0
lines changed

torchvision/prototype/datasets/_builtin/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,5 @@
55
from .imagenet import ImageNet
66
from .mnist import MNIST, FashionMNIST, KMNIST, EMNIST, QMNIST
77
from .sbd import SBD
8+
from .semeion import SEMEION
89
from .voc import VOC
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
import io
2+
from typing import Any, Callable, Dict, List, Optional, Tuple
3+
4+
import torch
5+
from torchdata.datapipes.iter import (
6+
IterDataPipe,
7+
Mapper,
8+
Shuffler,
9+
CSVParser,
10+
)
11+
from torchvision.prototype.datasets.decoder import raw
12+
from torchvision.prototype.datasets.utils import (
13+
Dataset,
14+
DatasetConfig,
15+
DatasetInfo,
16+
HttpResource,
17+
OnlineResource,
18+
DatasetType,
19+
)
20+
from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE, image_buffer_from_array
21+
22+
23+
class SEMEION(Dataset):
    """Semeion Handwritten Digit dataset exposed as a datapipe of sample dicts.

    Each row of the downloaded file is split on single spaces. The first 256
    fields are read as the pixels of a 16x16 image; the remaining non-empty
    fields are read as a one-hot label vector.
    """

    def _make_info(self) -> DatasetInfo:
        """Return the static metadata (name, type, categories, homepage) of the dataset."""
        return DatasetInfo(
            "semeion",
            type=DatasetType.RAW,
            categories=10,
            homepage="https://archive.ics.uci.edu/ml/datasets/Semeion+Handwritten+Digit",
        )

    def resources(self, config: DatasetConfig) -> List[OnlineResource]:
        """Return the single data file of the dataset, pinned by its SHA-256 checksum.

        Args:
            config: Dataset configuration. It is not consulted here; the same
                resource is returned for every configuration.
        """
        archive = HttpResource(
            "http://archive.ics.uci.edu/ml/machine-learning-databases/semeion/semeion.data",
            sha256="f43228ae3da5ea6a3c95069d53450b86166770e3b719dcc333182128fe08d4b1",
        )
        return [archive]

    def _collate_and_decode_sample(
        self,
        data: Tuple[str, ...],
        *,
        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
    ) -> Dict[str, Any]:
        """Turn one parsed row into a sample dictionary.

        Args:
            data: Fields of one row. The first 256 entries are pixel values,
                everything after that is the one-hot encoded label.
            decoder: ``raw`` to get the pixels back as a ``(1, 16, 16)`` uint8
                tensor, any other callable to decode an image buffer built from
                the pixels, or ``None`` to get that buffer back undecoded.

        Returns:
            A dict with the keys ``image``, ``label`` (index of the first active
            one-hot entry) and ``category`` (``self.info.categories[label]``).

        Raises:
            ValueError: If the one-hot label of the row has no active entry.
        """
        # Pixels are parsed through ``float`` first because they are not plain
        # integer literals -- presumably stored like "1.0000"; TODO confirm.
        image_data = torch.tensor([float(pixel) for pixel in data[:256]], dtype=torch.uint8).reshape(16, 16)
        # Empty fields are skipped -- presumably produced by a trailing
        # delimiter at the end of each row; TODO confirm against the data file.
        label_data = [int(label) for label in data[256:] if label]

        if decoder is raw:
            image = image_data.unsqueeze(0)
        else:
            image_buffer = image_buffer_from_array(image_data.numpy())
            image = decoder(image_buffer) if decoder else image_buffer  # type: ignore[assignment]

        # ``next`` gets an explicit default: a bare ``StopIteration`` escaping
        # from a mapping function either turns into an opaque ``RuntimeError``
        # inside a generator or silently ends the surrounding iteration.
        label = next((idx for idx, one_hot_label in enumerate(label_data) if one_hot_label), None)
        if label is None:
            raise ValueError(f"SEMEION sample has no active entry in its one-hot label: {label_data}")
        category = self.info.categories[label]
        return dict(image=image, label=label, category=category)

    def _make_datapipe(
        self,
        resource_dps: List[IterDataPipe],
        *,
        config: DatasetConfig,
        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
    ) -> IterDataPipe[Dict[str, Any]]:
        """Build the sample datapipe on top of the loaded resource.

        The first resource datapipe is parsed as space-delimited rows, shuffled
        with ``buffer_size=INFINITE_BUFFER_SIZE`` and mapped through
        ``_collate_and_decode_sample``.

        Args:
            resource_dps: Datapipes of the loaded resources; only the first is used.
            config: Dataset configuration. It is not consulted here.
            decoder: Forwarded unchanged to ``_collate_and_decode_sample``.
        """
        dp = resource_dps[0]
        dp = CSVParser(dp, delimiter=" ")
        dp = Shuffler(dp, buffer_size=INFINITE_BUFFER_SIZE)
        dp = Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(decoder=decoder))
        return dp

0 commit comments

Comments
 (0)