1
- from typing import Any , Dict , List
1
+ import pathlib
2
+ from typing import Any , Dict , List , Union
2
3
3
4
import torch
4
5
from torchdata .datapipes .iter import IterDataPipe , LineReader , Mapper , Decompressor
5
- from torchvision .prototype .datasets .utils import Dataset , DatasetInfo , DatasetConfig , OnlineResource , HttpResource
6
+ from torchvision .prototype .datasets .utils import Dataset2 , OnlineResource , HttpResource
6
7
from torchvision .prototype .datasets .utils ._internal import hint_sharding , hint_shuffling
7
8
from torchvision .prototype .features import Image , Label
8
9
10
+ from .._api import register_dataset , register_info
9
11
10
- class USPS (Dataset ):
11
- def _make_info (self ) -> DatasetInfo :
12
- return DatasetInfo (
13
- "usps" ,
14
- homepage = "https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass.html#usps" ,
15
- valid_options = dict (
16
- split = ("train" , "test" ),
17
- ),
18
- categories = 10 ,
19
- )
12
+ NAME = "usps"
13
+
14
+
15
+ @register_info (NAME )
16
+ def _info () -> Dict [str , Any ]:
17
+ return dict (categories = [str (c ) for c in range (10 )])
18
+
19
+
20
+ @register_dataset (NAME )
21
+ class USPS (Dataset2 ):
22
+ """USPS Dataset
23
+ homepage="https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass.html#usps",
24
+ """
25
+
26
+ def __init__ (
27
+ self ,
28
+ root : Union [str , pathlib .Path ],
29
+ * ,
30
+ split : str = "train" ,
31
+ skip_integrity_check : bool = False ,
32
+ ) -> None :
33
+ self ._split = self ._verify_str_arg (split , "split" , {"train" , "test" })
34
+
35
+ self ._categories = _info ()["categories" ]
36
+ super ().__init__ (root , skip_integrity_check = skip_integrity_check )
20
37
21
38
_URL = "https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass"
22
39
@@ -29,26 +46,24 @@ def _make_info(self) -> DatasetInfo:
29
46
),
30
47
}
31
48
32
- def resources (self , config : DatasetConfig ) -> List [OnlineResource ]:
33
- return [USPS ._RESOURCES [config . split ]]
49
+ def _resources (self ) -> List [OnlineResource ]:
50
+ return [USPS ._RESOURCES [self . _split ]]
34
51
35
52
def _prepare_sample (self , line : str ) -> Dict [str , Any ]:
36
53
label , * values = line .strip ().split (" " )
37
54
values = [float (value .split (":" )[1 ]) for value in values ]
38
55
pixels = torch .tensor (values ).add_ (1 ).div_ (2 )
39
56
return dict (
40
57
image = Image (pixels .reshape (16 , 16 )),
41
- label = Label (int (label ) - 1 , categories = self .categories ),
58
+ label = Label (int (label ) - 1 , categories = self ._categories ),
42
59
)
43
60
44
- def _make_datapipe (
45
- self ,
46
- resource_dps : List [IterDataPipe ],
47
- * ,
48
- config : DatasetConfig ,
49
- ) -> IterDataPipe [Dict [str , Any ]]:
61
+ def _datapipe (self , resource_dps : List [IterDataPipe ]) -> IterDataPipe [Dict [str , Any ]]:
50
62
dp = Decompressor (resource_dps [0 ])
51
63
dp = LineReader (dp , decode = True , return_path = False )
52
64
dp = hint_shuffling (dp )
53
65
dp = hint_sharding (dp )
54
66
return Mapper (dp , self ._prepare_sample )
67
+
68
+ def __len__ (self ) -> int :
69
+ return 7_291 if self ._split == "train" else 2_007
0 commit comments