diff --git a/onnx2pytorch/convert/attribute.py b/onnx2pytorch/convert/attribute.py
index 4cdcd61..c27c16e 100644
--- a/onnx2pytorch/convert/attribute.py
+++ b/onnx2pytorch/convert/attribute.py
@@ -65,6 +65,8 @@ def extract_attributes(node):
                 kwargs["negative_slope"] = extract_attr_values(attr)
             elif node.op_type in ("Elu", "ThresholdedRelu"):
                 kwargs["alpha"] = extract_attr_values(attr)
+            elif node.op_type == "HardSigmoid":
+                kwargs["alpha"] = extract_attr_values(attr)
             else:
                 kwargs["weight_multiplier"] = extract_attr_values(attr)
         elif attr.name == "auto_pad":
@@ -84,7 +86,10 @@ def extract_attributes(node):
             else:
                 kwargs["dim"] = v
         elif attr.name == "beta":
-            kwargs["bias_multiplier"] = extract_attr_values(attr)
+            if node.op_type == "HardSigmoid":
+                kwargs["beta"] = extract_attr_values(attr)
+            else:
+                kwargs["bias_multiplier"] = extract_attr_values(attr)
         elif attr.name == "body":
             kwargs["body"] = extract_attr_values(attr)
         elif attr.name == "ceil_mode":
diff --git a/onnx2pytorch/convert/operations.py b/onnx2pytorch/convert/operations.py
index ea37445..b0106ec 100644
--- a/onnx2pytorch/convert/operations.py
+++ b/onnx2pytorch/convert/operations.py
@@ -19,7 +19,7 @@
 )
 from onnx2pytorch.operations import *
 from onnx2pytorch.operations.base import OperatorWrapper
-from onnx2pytorch.operations import Resize, Upsample
+from onnx2pytorch.operations import Resize, Upsample, Hardsigmoid
 from onnx2pytorch.utils import (
     get_inputs_names,
     get_outputs_names,
@@ -236,6 +236,8 @@ def convert_operations(onnx_graph, opset_version, batch_dim=0, enable_pruning=Tr
             op = Shape()
         elif node.op_type == "Sigmoid":
             op = nn.Sigmoid()
+        elif node.op_type == "HardSigmoid":
+            op = Hardsigmoid(**extract_attributes(node))
         elif node.op_type == "Slice":
             op = Slice(**extract_attributes(node))
         elif node.op_type == "Softmax":
diff --git a/onnx2pytorch/operations/__init__.py b/onnx2pytorch/operations/__init__.py
index ad71cce..92f06e5 100644
--- a/onnx2pytorch/operations/__init__.py
+++ b/onnx2pytorch/operations/__init__.py
@@ -11,6 +11,7 @@
 from .gather import Gather
 from .gathernd import GatherND
 from .globalaveragepool import GlobalAveragePool
+from .hardsigmoid import Hardsigmoid
 from .instancenorm import InstanceNormWrapper
 from .loop import Loop
 from .lstm import LSTMWrapper
diff --git a/onnx2pytorch/operations/hardsigmoid.py b/onnx2pytorch/operations/hardsigmoid.py
new file mode 100644
index 0000000..2cd793a
--- /dev/null
+++ b/onnx2pytorch/operations/hardsigmoid.py
@@ -0,0 +1,25 @@
+import math
+
+import torch
+from torch import nn
+
+
+class Hardsigmoid(nn.Module):
+    def __new__(cls, alpha=0.2, beta=0.5):
+        """
+        If alpha and beta match the default values of torch's Hardsigmoid,
+        return torch's Hardsigmoid. Otherwise, return the custom Hardsigmoid.
+        """
+        # abs_tol=1e-2 also catches float attributes that store a rounded 1 / 6.
+        if math.isclose(alpha, 1 / 6, abs_tol=1e-2) and beta == 0.5:
+            return nn.Hardsigmoid()
+        else:
+            return super().__new__(cls)
+
+    def __init__(self, alpha=0.2, beta=0.5):
+        super().__init__()
+        self.alpha = alpha
+        self.beta = beta
+
+    def forward(self, input):
+        return torch.clip(input * self.alpha + self.beta, 0, 1)
diff --git a/tests/onnx2pytorch/operations/test_hardsigmoid.py b/tests/onnx2pytorch/operations/test_hardsigmoid.py
new file mode 100644
index 0000000..e32ffd7
--- /dev/null
+++ b/tests/onnx2pytorch/operations/test_hardsigmoid.py
@@ -0,0 +1,56 @@
+from unittest.mock import MagicMock
+
+import numpy as np
+import onnx
+import torch
+import pytest
+
+from onnx2pytorch.convert.operations import convert_operations
+from onnx2pytorch.operations import Hardsigmoid
+
+
+@pytest.fixture
+def x():
+    return np.random.randn(3, 4, 5).astype(np.float32)
+
+
+def test_hardsigmoid(x):
+    alpha = 1 / 6
+    beta = 1 / 2
+    op = Hardsigmoid(alpha=alpha, beta=beta)
+    # alpha=1/6 and beta=1/2 match torch's Hardsigmoid, so the torch module is returned
+    assert isinstance(op, torch.nn.Hardsigmoid)
+    y = np.clip(x * alpha + beta, 0, 1)
+    out = op(torch.from_numpy(x))
+    np.testing.assert_allclose(out, torch.from_numpy(y), rtol=1e-6, atol=1e-6)
+
+
+def test_hardsigmoid_with_custom_alpha_and_beta(x):
+    alpha = 0.2
+    beta = 0.5
+    op = Hardsigmoid(alpha=alpha, beta=beta)
+    assert not isinstance(op, torch.nn.Hardsigmoid)
+    y = np.clip(x * alpha + beta, 0, 1)
+    out = op(torch.from_numpy(x))
+    np.testing.assert_allclose(out, torch.from_numpy(y), rtol=1e-6, atol=1e-6)
+
+
+def test_hardsigmoid_conversion():
+    alpha = np.float32(0.2)
+    beta = np.float32(0.5)
+    node = onnx.helper.make_node(
+        "HardSigmoid",
+        inputs=["x"],
+        outputs=["y"],
+        alpha=alpha,
+        beta=beta,
+    )
+
+    graph = MagicMock()
+    graph.initializer = []
+    graph.node = [node]
+    converted_ops = list(convert_operations(graph, 10))
+    op_id, op_name, op = converted_ops[0]
+    assert isinstance(op, Hardsigmoid)
+    assert op.alpha == alpha
+    assert op.beta == beta