Skip to content

[ONNX] Fix dtype for NonMaxSuppression #7056

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Feb 15, 2023
14 changes: 9 additions & 5 deletions torchvision/ops/_register_onnx_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,21 +11,25 @@
def _register_custom_op():
from torch.onnx.symbolic_helper import parse_args
from torch.onnx.symbolic_opset11 import select, squeeze, unsqueeze
from torch.onnx.symbolic_opset9 import _cast_Long

@parse_args("v", "v", "f")
def symbolic_multi_label_nms(g, boxes, scores, iou_threshold):
boxes = unsqueeze(g, boxes, 0)
scores = unsqueeze(g, unsqueeze(g, scores, 0), 0)
max_output_per_class = g.op("Constant", value_t=torch.tensor([sys.maxsize], dtype=torch.long))
iou_threshold = g.op("Constant", value_t=torch.tensor([iou_threshold], dtype=torch.float))
nms_out = g.op("NonMaxSuppression", boxes, scores, max_output_per_class, iou_threshold)
nms_out = g.op(
"NonMaxSuppression",
g.op("Cast", boxes, to_i=torch.onnx.TensorProtoDataType.FLOAT),
g.op("Cast", scores, to_i=torch.onnx.TensorProtoDataType.FLOAT),
max_output_per_class,
iou_threshold,
)
return squeeze(g, select(g, nms_out, 1, g.op("Constant", value_t=torch.tensor([2], dtype=torch.long))), 1)

def _process_batch_indices_for_roi_align(g, rois):
return _cast_Long(
g, squeeze(g, select(g, rois, 1, g.op("Constant", value_t=torch.tensor([0], dtype=torch.long))), 1), False
)
indices = squeeze(g, select(g, rois, 1, g.op("Constant", value_t=torch.tensor([0], dtype=torch.long))), 1)
return g.op("Cast", indices, to_i=torch.onnx.TensorProtoDataType.INT64)

def _process_rois_for_roi_align(g, rois):
return select(g, rois, 1, g.op("Constant", value_t=torch.tensor([1, 2, 3, 4], dtype=torch.long)))
Expand Down