Skip to content

Commit 3ac864d

Browse files
authored
[ONNX] Fix model export for images w/ no detection (#2126)
* Fixing nms on boxes when no detection * test * Fix for scale_factor computation * remove newline * Fix for mask_rcnn dynamic axes * Clean up * Update transform.py * Fix for torchscript * Fix scripting errors * Fix annotation * Fix lint * Fix annotation * Fix for interpolate scripting * Fix for scripting * refactoring * refactor the code * Fix annotation * Fixed annotations * Added test for resize * lint * format * bump ORT * ort-nightly version * Going to ort 1.1.0 * remove version * install typing-extension * Export model for images with no detection * Upgrade ort nightly * update ORT * Update test_onnx.py * updated tests * Updated tests * merge * Update transforms.py * Update cityscapes.py * Update celeba.py * Update caltech.py * Update pkg_helpers.bash * Clean up * Clean up for dynamic split * Remove extra casts * flake8
1 parent 14af9de commit 3ac864d

File tree

7 files changed

+46
-32
lines changed

7 files changed

+46
-32
lines changed

.travis.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ before_install:
2929
- |
3030
if [[ $TRAVIS_PYTHON_VERSION == 3.6 ]]; then
3131
pip install -q --user typing-extensions==3.6.6
32-
pip install -q --user -i https://test.pypi.org/simple/ ort-nightly==1.2.0.dev202004141
32+
pip install -q --user -i https://test.pypi.org/simple/ ort-nightly==1.2.0.dev202005021
3333
fi
3434
- conda install av -c conda-forge
3535

test/test_onnx.py

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -346,11 +346,17 @@ def get_test_images(self):
346346

347347
def test_faster_rcnn(self):
348348
images, test_images = self.get_test_images()
349-
349+
dummy_image = [torch.ones(3, 100, 100) * 0.3]
350350
model = models.detection.faster_rcnn.fasterrcnn_resnet50_fpn(pretrained=True, min_size=200, max_size=300)
351351
model.eval()
352352
model(images)
353-
self.run_model(model, [(images,), (test_images,)], input_names=["images_tensors"],
353+
# Test exported model on images of different size, or dummy input
354+
self.run_model(model, [(images,), (test_images,), (dummy_image,)], input_names=["images_tensors"],
355+
output_names=["outputs"],
356+
dynamic_axes={"images_tensors": [0, 1, 2, 3], "outputs": [0, 1, 2, 3]},
357+
tolerate_small_mismatch=True)
358+
# Test exported model for an image with no detections on other images
359+
self.run_model(model, [(dummy_image,), (images,)], input_names=["images_tensors"],
354360
output_names=["outputs"],
355361
dynamic_axes={"images_tensors": [0, 1, 2, 3], "outputs": [0, 1, 2, 3]},
356362
tolerate_small_mismatch=True)
@@ -391,16 +397,25 @@ def test_paste_mask_in_image(self):
391397

392398
def test_mask_rcnn(self):
393399
images, test_images = self.get_test_images()
394-
400+
dummy_image = [torch.ones(3, 100, 320) * 0.3]
395401
model = models.detection.mask_rcnn.maskrcnn_resnet50_fpn(pretrained=True, min_size=200, max_size=300)
396402
model.eval()
397403
model(images)
398-
self.run_model(model, [(images,), (test_images,)],
404+
# Test exported model on images of different size, or dummy input
405+
self.run_model(model, [(images,), (test_images,), (dummy_image,)],
399406
input_names=["images_tensors"],
400407
output_names=["boxes", "labels", "scores"],
401408
dynamic_axes={"images_tensors": [0, 1, 2, 3], "boxes": [0, 1], "labels": [0],
402409
"scores": [0], "masks": [0, 1, 2, 3]},
403410
tolerate_small_mismatch=True)
411+
# TODO: enable this test once dynamic model export is fixed
412+
# Test exported model for an image with no detections on other images
413+
# self.run_model(model, [(images,),(test_images,)],
414+
# input_names=["images_tensors"],
415+
# output_names=["boxes", "labels", "scores"],
416+
# dynamic_axes={"images_tensors": [0, 1, 2, 3], "boxes": [0, 1], "labels": [0],
417+
# "scores": [0], "masks": [0, 1, 2, 3]},
418+
# tolerate_small_mismatch=True)
404419

405420
# Verify that heatmaps_to_keypoints behaves the same in tracing.
406421
# This test also compares both heatmaps_to_keypoints and _onnx_heatmaps_to_keypoints
@@ -445,6 +460,10 @@ def forward(self, images):
445460
return output[0]['boxes'], output[0]['labels'], output[0]['scores'], output[0]['keypoints']
446461

447462
images, test_images = self.get_test_images()
463+
# TODO:
464+
# Enable test for dummy_image (no detection) once the issue in
465+
# _onnx_heatmaps_to_keypoints_loop for empty heatmaps is fixed
466+
# dummy_images = [torch.ones(3, 100, 100) * 0.3]
448467
model = KeyPointRCNN()
449468
model.eval()
450469
model(images)
@@ -453,6 +472,13 @@ def forward(self, images):
453472
output_names=["outputs1", "outputs2", "outputs3", "outputs4"],
454473
dynamic_axes={"images_tensors": [0, 1, 2, 3]},
455474
tolerate_small_mismatch=True)
475+
# TODO: enable this test once dynamic model export is fixed
476+
# Test exported model for an image with no detections on other images
477+
# self.run_model(model, [(dummy_images,), (test_images,)],
478+
# input_names=["images_tensors"],
479+
# output_names=["outputs1", "outputs2", "outputs3", "outputs4"],
480+
# dynamic_axes={"images_tensors": [0, 1, 2, 3]},
481+
# tolerate_small_mismatch=True)
456482

457483

458484
if __name__ == '__main__':

torchvision/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
import warnings
22

3+
from .extension import _HAS_OPS
4+
35
from torchvision import models
46
from torchvision import datasets
57
from torchvision import ops
68
from torchvision import transforms
79
from torchvision import utils
810
from torchvision import io
911

10-
from .extension import _HAS_OPS
1112
import torch
1213

1314
try:

torchvision/models/detection/roi_heads.py

Lines changed: 3 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -75,19 +75,13 @@ def maskrcnn_inference(x, labels):
7575

7676
# select masks corresponding to the predicted classes
7777
num_masks = x.shape[0]
78-
boxes_per_image = [len(l) for l in labels]
78+
boxes_per_image = [l.shape[0] for l in labels]
7979
labels = torch.cat(labels)
8080
index = torch.arange(num_masks, device=labels.device)
8181
mask_prob = mask_prob[index, labels][:, None]
82+
mask_prob = mask_prob.split(boxes_per_image, dim=0)
8283

83-
if len(boxes_per_image) == 1:
84-
# TODO : remove when dynamic split supported in ONNX
85-
# and remove assignment to mask_prob_list, just assign to mask_prob
86-
mask_prob_list = [mask_prob]
87-
else:
88-
mask_prob_list = mask_prob.split(boxes_per_image, dim=0)
89-
90-
return mask_prob_list
84+
return mask_prob
9185

9286

9387
def project_masks_on_boxes(gt_masks, boxes, matched_idxs, M):
@@ -318,12 +312,6 @@ def keypointrcnn_inference(x, boxes):
318312
kp_scores = []
319313

320314
boxes_per_image = [box.size(0) for box in boxes]
321-
322-
if len(boxes_per_image) == 1:
323-
# TODO : remove when dynamic split supported in ONNX
324-
kp_prob, scores = heatmaps_to_keypoints(x, boxes[0])
325-
return [kp_prob], [scores]
326-
327315
x2 = x.split(boxes_per_image, dim=0)
328316

329317
for xx, bb in zip(x2, boxes):

torchvision/models/detection/rpn.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,9 @@ def _onnx_get_num_anchors_and_pre_nms_top_n(ob, orig_pre_nms_top_n):
1717
# type: (Tensor, int) -> Tuple[int, int]
1818
from torch.onnx import operators
1919
num_anchors = operators.shape_as_tensor(ob)[1].unsqueeze(0)
20-
# TODO : remove cast to IntTensor/num_anchors.dtype when
21-
# ONNX Runtime version is updated with ReduceMin int64 support
2220
pre_nms_top_n = torch.min(torch.cat(
2321
(torch.tensor([orig_pre_nms_top_n], dtype=num_anchors.dtype),
24-
num_anchors), 0).to(torch.int32)).to(num_anchors.dtype)
22+
num_anchors), 0))
2523

2624
return num_anchors, pre_nms_top_n
2725

torchvision/models/detection/transform.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,7 @@ def _resize_image_and_masks_onnx(image, self_min_size, self_max_size, target):
1717
im_shape = operators.shape_as_tensor(image)[-2:]
1818
min_size = torch.min(im_shape).to(dtype=torch.float32)
1919
max_size = torch.max(im_shape).to(dtype=torch.float32)
20-
scale_factor = self_min_size / min_size
21-
if max_size * scale_factor > self_max_size:
22-
scale_factor = self_max_size / max_size
20+
scale_factor = torch.min(self_min_size / min_size, self_max_size / max_size)
2321

2422
image = torch.nn.functional.interpolate(
2523
image[None], scale_factor=scale_factor, mode='bilinear',

torchvision/ops/boxes.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import torchvision
55

66

7+
@torch.jit.script
78
def nms(boxes, scores, iou_threshold):
89
# type: (Tensor, Tensor, float)
910
"""
@@ -40,6 +41,7 @@ def nms(boxes, scores, iou_threshold):
4041
return torch.ops.torchvision.nms(boxes, scores, iou_threshold)
4142

4243

44+
@torch.jit.script
4345
def batched_nms(boxes, scores, idxs, iou_threshold):
4446
# type: (Tensor, Tensor, Tensor, float)
4547
"""
@@ -74,11 +76,12 @@ def batched_nms(boxes, scores, idxs, iou_threshold):
7476
# we add an offset to all the boxes. The offset is dependent
7577
# only on the class idx, and is large enough so that boxes
7678
# from different classes do not overlap
77-
max_coordinate = boxes.max()
78-
offsets = idxs.to(boxes) * (max_coordinate + 1)
79-
boxes_for_nms = boxes + offsets[:, None]
80-
keep = nms(boxes_for_nms, scores, iou_threshold)
81-
return keep
79+
else:
80+
max_coordinate = boxes.max()
81+
offsets = idxs.to(boxes) * (max_coordinate + torch.tensor(1).to(boxes))
82+
boxes_for_nms = boxes + offsets[:, None]
83+
keep = nms(boxes_for_nms, scores, iou_threshold)
84+
return keep
8285

8386

8487
def remove_small_boxes(boxes, min_size):

0 commit comments

Comments
 (0)