1
1
import os
2
2
import numpy as np
3
3
from PIL import Image
4
+ from typing import Any , Callable , List , Optional , Tuple , Union
4
5
5
6
import torch
6
7
from .vision import VisionDataset
@@ -54,26 +55,28 @@ class PhotoTour(VisionDataset):
54
55
'fdd9152f138ea5ef2091746689176414'
55
56
],
56
57
}
57
- mean = {'notredame' : 0.4854 , 'yosemite' : 0.4844 , 'liberty' : 0.4437 ,
58
- 'notredame_harris' : 0.4854 , 'yosemite_harris' : 0.4844 , 'liberty_harris' : 0.4437 }
59
- std = {'notredame' : 0.1864 , 'yosemite' : 0.1818 , 'liberty' : 0.2019 ,
60
- 'notredame_harris' : 0.1864 , 'yosemite_harris' : 0.1818 , 'liberty_harris' : 0.2019 }
58
+ means = {'notredame' : 0.4854 , 'yosemite' : 0.4844 , 'liberty' : 0.4437 ,
59
+ 'notredame_harris' : 0.4854 , 'yosemite_harris' : 0.4844 , 'liberty_harris' : 0.4437 }
60
+ stds = {'notredame' : 0.1864 , 'yosemite' : 0.1818 , 'liberty' : 0.2019 ,
61
+ 'notredame_harris' : 0.1864 , 'yosemite_harris' : 0.1818 , 'liberty_harris' : 0.2019 }
61
62
lens = {'notredame' : 468159 , 'yosemite' : 633587 , 'liberty' : 450092 ,
62
63
'liberty_harris' : 379587 , 'yosemite_harris' : 450912 , 'notredame_harris' : 325295 }
63
64
image_ext = 'bmp'
64
65
info_file = 'info.txt'
65
66
matches_files = 'm50_100000_100000_0.txt'
66
67
67
- def __init__ (self , root , name , train = True , transform = None , download = False ):
68
+ def __init__ (
69
+ self , root : str , name : str , train : bool = True , transform : Optional [Callable ] = None , download : bool = False
70
+ ) -> None :
68
71
super (PhotoTour , self ).__init__ (root , transform = transform )
69
72
self .name = name
70
73
self .data_dir = os .path .join (self .root , name )
71
74
self .data_down = os .path .join (self .root , '{}.zip' .format (name ))
72
75
self .data_file = os .path .join (self .root , '{}.pt' .format (name ))
73
76
74
77
self .train = train
75
- self .mean = self .mean [name ]
76
- self .std = self .std [name ]
78
+ self .mean = self .means [name ]
79
+ self .std = self .stds [name ]
77
80
78
81
if download :
79
82
self .download ()
@@ -85,7 +88,7 @@ def __init__(self, root, name, train=True, transform=None, download=False):
85
88
# load the serialized data
86
89
self .data , self .labels , self .matches = torch .load (self .data_file )
87
90
88
- def __getitem__ (self , index ) :
91
+ def __getitem__ (self , index : int ) -> Union [ torch . Tensor , Tuple [ Any , Any , torch . Tensor ]] :
89
92
"""
90
93
Args:
91
94
index (int): Index
@@ -105,18 +108,18 @@ def __getitem__(self, index):
105
108
data2 = self .transform (data2 )
106
109
return data1 , data2 , m [2 ]
107
110
108
- def __len__ (self ):
111
+ def __len__ (self ) -> int :
109
112
if self .train :
110
113
return self .lens [self .name ]
111
114
return len (self .matches )
112
115
113
- def _check_datafile_exists (self ):
116
+ def _check_datafile_exists (self ) -> bool :
114
117
return os .path .exists (self .data_file )
115
118
116
- def _check_downloaded (self ):
119
+ def _check_downloaded (self ) -> bool :
117
120
return os .path .exists (self .data_dir )
118
121
119
- def download (self ):
122
+ def download (self ) -> None :
120
123
if self ._check_datafile_exists ():
121
124
print ('# Found cached data {}' .format (self .data_file ))
122
125
return
@@ -150,20 +153,20 @@ def download(self):
150
153
with open (self .data_file , 'wb' ) as f :
151
154
torch .save (dataset , f )
152
155
153
- def extra_repr (self ):
156
+ def extra_repr (self ) -> str :
154
157
return "Split: {}" .format ("Train" if self .train is True else "Test" )
155
158
156
159
157
- def read_image_file (data_dir , image_ext , n ) :
160
+ def read_image_file (data_dir : str , image_ext : str , n : int ) -> torch . Tensor :
158
161
"""Return a Tensor containing the patches
159
162
"""
160
163
161
- def PIL2array (_img ) :
164
+ def PIL2array (_img : Image . Image ) -> np . ndarray :
162
165
"""Convert PIL image type to numpy 2D array
163
166
"""
164
167
return np .array (_img .getdata (), dtype = np .uint8 ).reshape (64 , 64 )
165
168
166
- def find_files (_data_dir , _image_ext ) :
169
+ def find_files (_data_dir : str , _image_ext : str ) -> List [ str ] :
167
170
"""Return a list with the file names of the images containing the patches
168
171
"""
169
172
files = []
@@ -185,7 +188,7 @@ def find_files(_data_dir, _image_ext):
185
188
return torch .ByteTensor (np .array (patches [:n ]))
186
189
187
190
188
- def read_info_file (data_dir , info_file ) :
191
+ def read_info_file (data_dir : str , info_file : str ) -> torch . Tensor :
189
192
"""Return a Tensor containing the list of labels
190
193
Read the file and keep only the ID of the 3D point.
191
194
"""
@@ -195,7 +198,7 @@ def read_info_file(data_dir, info_file):
195
198
return torch .LongTensor (labels )
196
199
197
200
198
- def read_matches_files (data_dir , matches_file ) :
201
+ def read_matches_files (data_dir : str , matches_file : str ) -> torch . Tensor :
199
202
"""Return a Tensor containing the ground truth matches
200
203
Read the file and keep only 3D point ID.
201
204
Matches are represented with a 1, non matches with a 0.
0 commit comments