Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ONNX] Support exporting RoiAlign align=True to ONNX with opset 16 #6685

Merged
merged 3 commits into from
Oct 4, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 16 additions & 11 deletions test/test_onnx.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import io
from collections import OrderedDict
from typing import List, Tuple
from typing import List, Optional, Tuple

import pytest
import torch
Expand All @@ -11,7 +11,7 @@
from torchvision.models.detection.roi_heads import RoIHeads
from torchvision.models.detection.rpn import AnchorGenerator, RegionProposalNetwork, RPNHead
from torchvision.models.detection.transform import GeneralizedRCNNTransform
from torchvision.ops._register_onnx_ops import _onnx_opset_version
from torchvision.ops import _register_onnx_ops

# In environments without onnxruntime we prefer to
# invoke all tests in the repo and have this one skipped rather than fail.
Expand All @@ -32,7 +32,11 @@ def run_model(
dynamic_axes=None,
output_names=None,
input_names=None,
opset_version: Optional[int] = None,
):
if opset_version is None:
opset_version = _register_onnx_ops.base_onnx_opset_version

model.eval()

onnx_io = io.BytesIO()
Expand All @@ -46,10 +50,11 @@ def run_model(
torch_onnx_input,
onnx_io,
do_constant_folding=do_constant_folding,
opset_version=_onnx_opset_version,
opset_version=opset_version,
dynamic_axes=dynamic_axes,
input_names=input_names,
output_names=output_names,
verbose=True,
)
# validate the exported model with onnx runtime
for test_inputs in inputs_list:
Expand Down Expand Up @@ -140,39 +145,39 @@ def test_roi_align(self):
model = ops.RoIAlign((5, 5), 1, -1)
self.run_model(model, [(x, single_roi)])

@pytest.mark.skip(reason="ROIAlign with aligned=True is not supported in ONNX, but will be supported in opset 16.")
def test_roi_align_aligned(self):
supported_onnx_version = _register_onnx_ops._onnx_opset_version_16
x = torch.rand(1, 1, 10, 10, dtype=torch.float32)
single_roi = torch.tensor([[0, 1.5, 1.5, 3, 3]], dtype=torch.float32)
model = ops.RoIAlign((5, 5), 1, 2, aligned=True)
self.run_model(model, [(x, single_roi)])
self.run_model(model, [(x, single_roi)], opset_version=supported_onnx_version)

x = torch.rand(1, 1, 10, 10, dtype=torch.float32)
single_roi = torch.tensor([[0, 0.2, 0.3, 4.5, 3.5]], dtype=torch.float32)
model = ops.RoIAlign((5, 5), 0.5, 3, aligned=True)
self.run_model(model, [(x, single_roi)])
self.run_model(model, [(x, single_roi)], opset_version=supported_onnx_version)

x = torch.rand(1, 1, 10, 10, dtype=torch.float32)
single_roi = torch.tensor([[0, 0.2, 0.3, 4.5, 3.5]], dtype=torch.float32)
model = ops.RoIAlign((5, 5), 1.8, 2, aligned=True)
self.run_model(model, [(x, single_roi)])
self.run_model(model, [(x, single_roi)], opset_version=supported_onnx_version)

x = torch.rand(1, 1, 10, 10, dtype=torch.float32)
single_roi = torch.tensor([[0, 0.2, 0.3, 4.5, 3.5]], dtype=torch.float32)
model = ops.RoIAlign((2, 2), 2.5, 0, aligned=True)
self.run_model(model, [(x, single_roi)])
self.run_model(model, [(x, single_roi)], opset_version=supported_onnx_version)

x = torch.rand(1, 1, 10, 10, dtype=torch.float32)
single_roi = torch.tensor([[0, 0.2, 0.3, 4.5, 3.5]], dtype=torch.float32)
model = ops.RoIAlign((2, 2), 2.5, -1, aligned=True)
self.run_model(model, [(x, single_roi)])
self.run_model(model, [(x, single_roi)], opset_version=supported_onnx_version)

@pytest.mark.skip(reason="Issue in exporting ROIAlign with aligned = True for malformed boxes")
def test_roi_align_malformed_boxes(self):
supported_onnx_version = _register_onnx_ops._onnx_opset_version_16
x = torch.randn(1, 1, 10, 10, dtype=torch.float32)
single_roi = torch.tensor([[0, 2, 0.3, 1.5, 1.5]], dtype=torch.float32)
model = ops.RoIAlign((5, 5), 1, 1, aligned=True)
self.run_model(model, [(x, single_roi)])
self.run_model(model, [(x, single_roi)], opset_version=supported_onnx_version)

def test_roi_pool(self):
x = torch.rand(1, 1, 10, 10, dtype=torch.float32)
Expand Down
63 changes: 45 additions & 18 deletions torchvision/ops/_register_onnx_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@

import torch

_onnx_opset_version = 11
_onnx_opset_version_11 = 11
_onnx_opset_version_16 = 16
base_onnx_opset_version = _onnx_opset_version_11


def _register_custom_op():
Expand All @@ -20,32 +22,56 @@ def symbolic_multi_label_nms(g, boxes, scores, iou_threshold):
nms_out = g.op("NonMaxSuppression", boxes, scores, max_output_per_class, iou_threshold)
return squeeze(g, select(g, nms_out, 1, g.op("Constant", value_t=torch.tensor([2], dtype=torch.long))), 1)

@parse_args("v", "v", "f", "i", "i", "i", "i")
def roi_align(g, input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned):
batch_indices = _cast_Long(
def _process_batch_indices_for_roi_align(g, rois):
return _cast_Long(
g, squeeze(g, select(g, rois, 1, g.op("Constant", value_t=torch.tensor([0], dtype=torch.long))), 1), False
)
rois = select(g, rois, 1, g.op("Constant", value_t=torch.tensor([1, 2, 3, 4], dtype=torch.long)))
# TODO: Remove this warning after ONNX opset 16 is supported.
if aligned:
warnings.warn(
"ROIAlign with aligned=True is not supported in ONNX, but will be supported in opset 16. "
"The workaround is that the user need apply the patch "
"https://github.com/microsoft/onnxruntime/pull/8564 "
"and build ONNXRuntime from source."
)

# ONNX doesn't support negative sampling_ratio
def _process_rois_for_roi_align(g, rois):
return select(g, rois, 1, g.op("Constant", value_t=torch.tensor([1, 2, 3, 4], dtype=torch.long)))

def _process_sampling_ratio_for_roi_align(g, sampling_ratio: int):
if sampling_ratio < 0:
warnings.warn(
"ONNX doesn't support negative sampling ratio, therefore is set to 0 in order to be exported."
"ONNX export for RoIAlign with a non-zero sampling_ratio is not supported. "
"The model will be exported with a sampling_ratio of 0."
)
sampling_ratio = 0
return sampling_ratio

@parse_args("v", "v", "f", "i", "i", "i", "i")
def roi_align_opset11(g, input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned):
batch_indices = _process_batch_indices_for_roi_align(g, rois)
rois = _process_rois_for_roi_align(g, rois)
if aligned:
warnings.warn(
"ROIAlign with aligned=True is not supported in ONNX, but is supported in opset 16. "
"Please export with opset 16 or higher to use aligned=False."
)
sampling_ratio = _process_sampling_ratio_for_roi_align(g, sampling_ratio)
return g.op(
"RoiAlign",
input,
rois,
batch_indices,
spatial_scale_f=spatial_scale,
output_height_i=pooled_height,
output_width_i=pooled_width,
sampling_ratio_i=sampling_ratio,
)

@parse_args("v", "v", "f", "i", "i", "i", "i")
def roi_align_opset16(g, input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned):
batch_indices = _process_batch_indices_for_roi_align(g, rois)
rois = _process_rois_for_roi_align(g, rois)
coordinate_transformation_mode = "half_pixel" if aligned else "output_half_pixel"
sampling_ratio = _process_sampling_ratio_for_roi_align(g, sampling_ratio)
return g.op(
"RoiAlign",
input,
rois,
batch_indices,
coordinate_transformation_mode_s=coordinate_transformation_mode,
spatial_scale_f=spatial_scale,
output_height_i=pooled_height,
output_width_i=pooled_width,
Expand All @@ -61,6 +87,7 @@ def roi_pool(g, input, rois, spatial_scale, pooled_height, pooled_width):

from torch.onnx import register_custom_op_symbolic

register_custom_op_symbolic("torchvision::nms", symbolic_multi_label_nms, _onnx_opset_version)
register_custom_op_symbolic("torchvision::roi_align", roi_align, _onnx_opset_version)
register_custom_op_symbolic("torchvision::roi_pool", roi_pool, _onnx_opset_version)
register_custom_op_symbolic("torchvision::nms", symbolic_multi_label_nms, _onnx_opset_version_11)
register_custom_op_symbolic("torchvision::roi_align", roi_align_opset11, _onnx_opset_version_11)
register_custom_op_symbolic("torchvision::roi_align", roi_align_opset16, _onnx_opset_version_16)
register_custom_op_symbolic("torchvision::roi_pool", roi_pool, _onnx_opset_version_11)