Skip to content

Commit 5293f47

Browse files
committed
added detection_collate docstring and refactored AnnotationTransform
1 parent 68c1da5 commit 5293f47

File tree

1 file changed

+32
-11
lines changed

1 file changed

+32
-11
lines changed

torchvision/datasets/voc.py

Lines changed: 32 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -84,14 +84,17 @@ class AnnotationTransform(object):
8484
(default: alphabetic indexing of VOC's 20 classes)
8585
keep_difficult (bool, optional): keep difficult instances or not
8686
(default: False)
87+
channels (int): number of channels
88+
height (int): height
89+
width (int): width
8790
"""
8891

8992
def __init__(self, class_to_ind=None, keep_difficult=False):
9093
self.class_to_ind = class_to_ind or dict(
9194
zip(VOC_CLASSES, range(len(VOC_CLASSES))))
9295
self.keep_difficult = keep_difficult
9396

94-
def __call__(self, target):
97+
def __call__(self, target, channels, height, width):
9598
"""
9699
Arguments:
97100
target (annotation) : the target annotation to be made usable
@@ -108,13 +111,18 @@ def __call__(self, target):
108111
bbox = obj.find('bndbox')
109112

110113
# [xmin, ymin, xmax, ymax]
111-
bndbox = [int(bb.text) - 1 for bb in bbox]
114+
bndbox = []
115+
for i, cur_bb in enumerate(bbox):
116+
bb_sz = int(cur_bb.text) - 1
117+
bb_sz = bb_sz/width if i%2 == 0 else bb_sz/height # scale x-coords by width, y-coords by height
118+
bndbox.append(bb_sz)
119+
112120
label_ind = self.class_to_ind[name]
113121
bndbox.append(label_ind)
114122
res += [bndbox] # [xmin, ymin, xmax, ymax, ind]
115123

116124
return res # [[xmin, ymin, xmax, ymax, ind], ... ]
117-
# torch.Tensor(res)
125+
118126

119127
class VOCDetection(data.Dataset):
120128
"""VOC Detection Dataset Object
@@ -169,8 +177,7 @@ def __len__(self):
169177
return len(self.ids)
170178

171179
def show(self, index, subparts=False):
172-
'''Shows an image with its ground truth boxes overlaid
173-
optionally
180+
'''Shows an image with its ground truth boxes overlaid (subpart boxes optional)
174181
175182
Note: not using self.__getitem__(), as any transformations passed in
176183
could mess up this functionality.
@@ -179,18 +186,19 @@ def show(self, index, subparts=False):
179186
index (int): index of img to show
180187
subparts (bool, optional): whether or not to display subpart
181188
bboxes of ground truths
189+
(default: False)
182190
'''
183191
img_id = self.ids[index]
184192
target = ET.parse(self._annopath % img_id).getroot()
185193
img = Image.open(self._imgpath % img_id).convert('RGB')
186194
draw = ImageDraw.Draw(img)
187195
i = 0
188196
bndboxs = []
189-
classes = dict()
197+
classes = dict() # maps class name to a class number
190198
for obj in target.iter('object'):
191199
bbox = obj.find('bndbox')
192200
name = obj.find('name').text.lower().strip()
193-
if not name in classes:
201+
if name not in classes:
194202
classes[name] = i
195203
i += 1
196204
bndboxs.append((name, [int(bb.text) - 1 for bb in bbox]))
@@ -199,7 +207,7 @@ def show(self, index, subparts=False):
199207
name = part.find('name').text.lower().strip()
200208
bbox = part.find('bndbox')
201209
bndboxs.append((name, [int(bb.text) - 1 for bb in bbox]))
202-
if not name in classes:
210+
if name not in classes:
203211
classes[name] = i
204212
i += 1
205213
for name, bndbox in bndboxs:
@@ -209,14 +217,27 @@ def show(self, index, subparts=False):
209217
img.show()
210218
return img
211219

220+
212221
def detection_collate(batch):
222+
"""Custom collate fn for dealing with batches of images that have a different
223+
number of associated object annotations (bounding boxes).
224+
225+
Arguments:
226+
batch: (tuple) A tuple of tensor images and lists of annotations
227+
228+
Return:
229+
A tuple containing:
230+
1) (tensor) batch of images stacked on their 0 dim
231+
2) (list of tensors) annotations for a given image are stacked on 0 dim
232+
"""
213233
targets = []
214234
imgs = []
215-
for i,sample in enumerate(batch):
216-
for j, tup in enumerate(sample):
235+
for _, sample in enumerate(batch):
236+
for _, tup in enumerate(sample):
217237
if torch.is_tensor(tup):
218238
imgs.append(tup)
219239
elif isinstance(tup, type([])):
220-
targets.append([torch.Tensor(x) for x in tup])
240+
annos = [torch.Tensor(a) for a in tup]
241+
targets.append(torch.stack(annos, 0))
221242

222243
return (torch.stack(imgs, 0), targets)

0 commit comments

Comments
 (0)