@@ -84,14 +84,17 @@ class AnnotationTransform(object):
84
84
(default: alphabetic indexing of VOC's 20 classes)
85
85
keep_difficult (bool, optional): keep difficult instances or not
86
86
(default: False)
87
+ channels (int): number of channels
88
+ height (int): height
89
+ width (int): width
87
90
"""
88
91
89
92
def __init__(self, class_to_ind=None, keep_difficult=False):
    """Initialize the annotation transform.

    Arguments:
        class_to_ind (dict, optional): mapping from class name to class
            index (default: alphabetic indexing of VOC_CLASSES)
        keep_difficult (bool, optional): keep difficult instances or not
            (default: False)
    """
    # Fall back to the canonical VOC class ordering when no mapping is given.
    default_mapping = {cls_name: idx for idx, cls_name in enumerate(VOC_CLASSES)}
    self.class_to_ind = class_to_ind or default_mapping
    self.keep_difficult = keep_difficult
93
96
94
- def __call__ (self , target ):
97
+ def __call__ (self , target , channels , height , width ):
95
98
"""
96
99
Arguments:
97
100
target (annotation) : the target annotation to be made usable
@@ -108,13 +111,18 @@ def __call__(self, target):
108
111
bbox = obj .find ('bndbox' )
109
112
110
113
# [xmin, ymin, xmax, ymax]
111
- bndbox = [int (bb .text ) - 1 for bb in bbox ]
114
+ bndbox = []
115
+ for i , cur_bb in enumerate (bbox ):
116
+ bb_sz = int (cur_bb .text ) - 1
117
+ bb_sz = bb_sz / width if i % 2 == 0 else bb_sz / height # scale height or width
118
+ bndbox .append (bb_sz )
119
+
112
120
label_ind = self .class_to_ind [name ]
113
121
bndbox .append (label_ind )
114
122
res += [bndbox ] # [xmin, ymin, xmax, ymax, ind]
115
123
116
124
return res # [[xmin, ymin, xmax, ymax, ind], ... ]
117
- # torch.Tensor(res)
125
+
118
126
119
127
class VOCDetection (data .Dataset ):
120
128
"""VOC Detection Dataset Object
@@ -169,8 +177,7 @@ def __len__(self):
169
177
return len (self .ids )
170
178
171
179
def show (self , index , subparts = False ):
172
- '''Shows an image with its ground truth boxes overlaid
173
- optionally
180
+ '''Shows an image with its ground truth boxes overlaid optionally
174
181
175
182
Note: not using self.__getitem__(), as any transformations passed in
176
183
could mess up this functionality.
@@ -179,18 +186,19 @@ def show(self, index, subparts=False):
179
186
index (int): index of img to show
180
187
subparts (bool, optional): whether or not to display subpart
181
188
bboxes of ground truths
189
+ (default: False)
182
190
'''
183
191
img_id = self .ids [index ]
184
192
target = ET .parse (self ._annopath % img_id ).getroot ()
185
193
img = Image .open (self ._imgpath % img_id ).convert ('RGB' )
186
194
draw = ImageDraw .Draw (img )
187
195
i = 0
188
196
bndboxs = []
189
- classes = dict ()
197
+ classes = dict () # maps class name to a class number
190
198
for obj in target .iter ('object' ):
191
199
bbox = obj .find ('bndbox' )
192
200
name = obj .find ('name' ).text .lower ().strip ()
193
- if not name in classes :
201
+ if name not in classes :
194
202
classes [name ] = i
195
203
i += 1
196
204
bndboxs .append ((name , [int (bb .text ) - 1 for bb in bbox ]))
@@ -199,7 +207,7 @@ def show(self, index, subparts=False):
199
207
name = part .find ('name' ).text .lower ().strip ()
200
208
bbox = part .find ('bndbox' )
201
209
bndboxs .append ((name , [int (bb .text ) - 1 for bb in bbox ]))
202
- if not name in classes :
210
+ if name not in classes :
203
211
classes [name ] = i
204
212
i += 1
205
213
for name , bndbox in bndboxs :
@@ -209,14 +217,27 @@ def show(self, index, subparts=False):
209
217
img .show ()
210
218
return img
211
219
220
+
212
221
def detection_collate(batch):
    """Custom collate fn for dealing with batches of images that have a
    different number of associated object annotations (bounding boxes).

    Arguments:
        batch (tuple): a tuple of (tensor image, list-of-annotations) samples,
            where each annotation is a 5-element sequence
            [xmin, ymin, xmax, ymax, class_ind]

    Return:
        A tuple containing:
            1) (tensor) batch of images stacked on their 0 dim
            2) (list of tensors) annotations for a given image, stacked on 0 dim
    """
    targets = []
    imgs = []
    # Each sample mixes a tensor (the image) with a plain list (the
    # annotations); dispatch on type rather than on position.
    for sample in batch:  # enumerate indices were unused — iterate directly
        for tup in sample:
            if torch.is_tensor(tup):
                imgs.append(tup)
            elif isinstance(tup, list):  # idiomatic: `list`, not `type([])`
                annos = [torch.Tensor(a) for a in tup]
                # Stack per-image annotations into one (num_objs, 5) tensor.
                targets.append(torch.stack(annos, 0))

    return (torch.stack(imgs, 0), targets)
0 commit comments