Skip to content

Commit

Permalink
[ONNX] Support exporting RoiAlign align=True to ONNX with opset 16 (#…
Browse files Browse the repository at this point in the history
…6685)

* Support exporting RoiAlign align=True to ONNX with opset 16

* lint: ufmt

Co-authored-by: Vasilis Vryniotis <[email protected]>
  • Loading branch information
BowenBao and datumbox authored Oct 4, 2022
1 parent 344ccc0 commit 45f87fa
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 29 deletions.
27 changes: 16 additions & 11 deletions test/test_onnx.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import io
from collections import OrderedDict
from typing import List, Tuple
from typing import List, Optional, Tuple

import pytest
import torch
Expand All @@ -11,7 +11,7 @@
from torchvision.models.detection.roi_heads import RoIHeads
from torchvision.models.detection.rpn import AnchorGenerator, RegionProposalNetwork, RPNHead
from torchvision.models.detection.transform import GeneralizedRCNNTransform
from torchvision.ops._register_onnx_ops import _onnx_opset_version
from torchvision.ops import _register_onnx_ops

# In environments without onnxruntime we prefer to
# invoke all tests in the repo and have this one skipped rather than fail.
Expand All @@ -32,7 +32,11 @@ def run_model(
dynamic_axes=None,
output_names=None,
input_names=None,
opset_version: Optional[int] = None,
):
if opset_version is None:
opset_version = _register_onnx_ops.base_onnx_opset_version

model.eval()

onnx_io = io.BytesIO()
Expand All @@ -46,10 +50,11 @@ def run_model(
torch_onnx_input,
onnx_io,
do_constant_folding=do_constant_folding,
opset_version=_onnx_opset_version,
opset_version=opset_version,
dynamic_axes=dynamic_axes,
input_names=input_names,
output_names=output_names,
verbose=True,
)
# validate the exported model with onnx runtime
for test_inputs in inputs_list:
Expand Down Expand Up @@ -140,39 +145,39 @@ def test_roi_align(self):
model = ops.RoIAlign((5, 5), 1, -1)
self.run_model(model, [(x, single_roi)])

@pytest.mark.skip(reason="ROIAlign with aligned=True is not supported in ONNX, but will be supported in opset 16.")
def test_roi_align_aligned(self):
supported_onnx_version = _register_onnx_ops._onnx_opset_version_16
x = torch.rand(1, 1, 10, 10, dtype=torch.float32)
single_roi = torch.tensor([[0, 1.5, 1.5, 3, 3]], dtype=torch.float32)
model = ops.RoIAlign((5, 5), 1, 2, aligned=True)
self.run_model(model, [(x, single_roi)])
self.run_model(model, [(x, single_roi)], opset_version=supported_onnx_version)

x = torch.rand(1, 1, 10, 10, dtype=torch.float32)
single_roi = torch.tensor([[0, 0.2, 0.3, 4.5, 3.5]], dtype=torch.float32)
model = ops.RoIAlign((5, 5), 0.5, 3, aligned=True)
self.run_model(model, [(x, single_roi)])
self.run_model(model, [(x, single_roi)], opset_version=supported_onnx_version)

x = torch.rand(1, 1, 10, 10, dtype=torch.float32)
single_roi = torch.tensor([[0, 0.2, 0.3, 4.5, 3.5]], dtype=torch.float32)
model = ops.RoIAlign((5, 5), 1.8, 2, aligned=True)
self.run_model(model, [(x, single_roi)])
self.run_model(model, [(x, single_roi)], opset_version=supported_onnx_version)

x = torch.rand(1, 1, 10, 10, dtype=torch.float32)
single_roi = torch.tensor([[0, 0.2, 0.3, 4.5, 3.5]], dtype=torch.float32)
model = ops.RoIAlign((2, 2), 2.5, 0, aligned=True)
self.run_model(model, [(x, single_roi)])
self.run_model(model, [(x, single_roi)], opset_version=supported_onnx_version)

x = torch.rand(1, 1, 10, 10, dtype=torch.float32)
single_roi = torch.tensor([[0, 0.2, 0.3, 4.5, 3.5]], dtype=torch.float32)
model = ops.RoIAlign((2, 2), 2.5, -1, aligned=True)
self.run_model(model, [(x, single_roi)])
self.run_model(model, [(x, single_roi)], opset_version=supported_onnx_version)

@pytest.mark.skip(reason="Issue in exporting ROIAlign with aligned = True for malformed boxes")
def test_roi_align_malformed_boxes(self):
supported_onnx_version = _register_onnx_ops._onnx_opset_version_16
x = torch.randn(1, 1, 10, 10, dtype=torch.float32)
single_roi = torch.tensor([[0, 2, 0.3, 1.5, 1.5]], dtype=torch.float32)
model = ops.RoIAlign((5, 5), 1, 1, aligned=True)
self.run_model(model, [(x, single_roi)])
self.run_model(model, [(x, single_roi)], opset_version=supported_onnx_version)

def test_roi_pool(self):
x = torch.rand(1, 1, 10, 10, dtype=torch.float32)
Expand Down
63 changes: 45 additions & 18 deletions torchvision/ops/_register_onnx_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@

import torch

_onnx_opset_version = 11
_onnx_opset_version_11 = 11
_onnx_opset_version_16 = 16
base_onnx_opset_version = _onnx_opset_version_11


def _register_custom_op():
Expand All @@ -20,32 +22,56 @@ def symbolic_multi_label_nms(g, boxes, scores, iou_threshold):
nms_out = g.op("NonMaxSuppression", boxes, scores, max_output_per_class, iou_threshold)
return squeeze(g, select(g, nms_out, 1, g.op("Constant", value_t=torch.tensor([2], dtype=torch.long))), 1)

@parse_args("v", "v", "f", "i", "i", "i", "i")
def roi_align(g, input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned):
batch_indices = _cast_Long(
def _process_batch_indices_for_roi_align(g, rois):
return _cast_Long(
g, squeeze(g, select(g, rois, 1, g.op("Constant", value_t=torch.tensor([0], dtype=torch.long))), 1), False
)
rois = select(g, rois, 1, g.op("Constant", value_t=torch.tensor([1, 2, 3, 4], dtype=torch.long)))
# TODO: Remove this warning after ONNX opset 16 is supported.
if aligned:
warnings.warn(
"ROIAlign with aligned=True is not supported in ONNX, but will be supported in opset 16. "
"The workaround is that the user need apply the patch "
"https://github.com/microsoft/onnxruntime/pull/8564 "
"and build ONNXRuntime from source."
)

# ONNX doesn't support negative sampling_ratio
def _process_rois_for_roi_align(g, rois):
return select(g, rois, 1, g.op("Constant", value_t=torch.tensor([1, 2, 3, 4], dtype=torch.long)))

def _process_sampling_ratio_for_roi_align(g, sampling_ratio: int):
if sampling_ratio < 0:
warnings.warn(
"ONNX doesn't support negative sampling ratio, therefore is set to 0 in order to be exported."
"ONNX export for RoIAlign with a non-zero sampling_ratio is not supported. "
"The model will be exported with a sampling_ratio of 0."
)
sampling_ratio = 0
return sampling_ratio

@parse_args("v", "v", "f", "i", "i", "i", "i")
def roi_align_opset11(g, input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned):
batch_indices = _process_batch_indices_for_roi_align(g, rois)
rois = _process_rois_for_roi_align(g, rois)
if aligned:
warnings.warn(
"ROIAlign with aligned=True is not supported in ONNX, but is supported in opset 16. "
"Please export with opset 16 or higher to use aligned=False."
)
sampling_ratio = _process_sampling_ratio_for_roi_align(g, sampling_ratio)
return g.op(
"RoiAlign",
input,
rois,
batch_indices,
spatial_scale_f=spatial_scale,
output_height_i=pooled_height,
output_width_i=pooled_width,
sampling_ratio_i=sampling_ratio,
)

@parse_args("v", "v", "f", "i", "i", "i", "i")
def roi_align_opset16(g, input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned):
batch_indices = _process_batch_indices_for_roi_align(g, rois)
rois = _process_rois_for_roi_align(g, rois)
coordinate_transformation_mode = "half_pixel" if aligned else "output_half_pixel"
sampling_ratio = _process_sampling_ratio_for_roi_align(g, sampling_ratio)
return g.op(
"RoiAlign",
input,
rois,
batch_indices,
coordinate_transformation_mode_s=coordinate_transformation_mode,
spatial_scale_f=spatial_scale,
output_height_i=pooled_height,
output_width_i=pooled_width,
Expand All @@ -61,6 +87,7 @@ def roi_pool(g, input, rois, spatial_scale, pooled_height, pooled_width):

from torch.onnx import register_custom_op_symbolic

register_custom_op_symbolic("torchvision::nms", symbolic_multi_label_nms, _onnx_opset_version)
register_custom_op_symbolic("torchvision::roi_align", roi_align, _onnx_opset_version)
register_custom_op_symbolic("torchvision::roi_pool", roi_pool, _onnx_opset_version)
register_custom_op_symbolic("torchvision::nms", symbolic_multi_label_nms, _onnx_opset_version_11)
register_custom_op_symbolic("torchvision::roi_align", roi_align_opset11, _onnx_opset_version_11)
register_custom_op_symbolic("torchvision::roi_align", roi_align_opset16, _onnx_opset_version_16)
register_custom_op_symbolic("torchvision::roi_pool", roi_pool, _onnx_opset_version_11)

0 comments on commit 45f87fa

Please sign in to comment.