Add hardsigmoid with tests. #67

Merged (1 commit, Sep 15, 2024)
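Background: the ONNX HardSigmoid operator computes y = max(0, min(1, alpha * x + beta)), with spec defaults alpha = 0.2 and beta = 0.5. PyTorch's built-in nn.Hardsigmoid is the fixed special case alpha = 1/6, beta = 0.5, which is why the new module in this PR dispatches to it when the attributes are close to those values. A minimal equivalence check in plain NumPy/PyTorch (illustrative snippet, not part of the diff):

import numpy as np
import torch

x = np.random.randn(3, 4, 5).astype(np.float32)
alpha, beta = 1 / 6, 0.5
# ONNX HardSigmoid definition: clip(alpha * x + beta, 0, 1)
onnx_style = np.clip(x * alpha + beta, 0, 1)
# torch's hardsigmoid uses the fixed coefficients 1/6 and 1/2
torch_style = torch.nn.functional.hardsigmoid(torch.from_numpy(x)).numpy()
np.testing.assert_allclose(onnx_style, torch_style, rtol=1e-6, atol=1e-6)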
7 changes: 6 additions & 1 deletion onnx2pytorch/convert/attribute.py
@@ -65,6 +65,8 @@ def extract_attributes(node):
                 kwargs["negative_slope"] = extract_attr_values(attr)
             elif node.op_type in ("Elu", "ThresholdedRelu"):
                 kwargs["alpha"] = extract_attr_values(attr)
+            elif node.op_type == "HardSigmoid":
+                kwargs["alpha"] = extract_attr_values(attr)
             else:
                 kwargs["weight_multiplier"] = extract_attr_values(attr)
         elif attr.name == "auto_pad":
@@ -84,7 +86,10 @@ def extract_attributes(node):
             else:
                 kwargs["dim"] = v
         elif attr.name == "beta":
-            kwargs["bias_multiplier"] = extract_attr_values(attr)
+            if node.op_type == "HardSigmoid":
+                kwargs["beta"] = extract_attr_values(attr)
+            else:
+                kwargs["bias_multiplier"] = extract_attr_values(attr)
         elif attr.name == "body":
             kwargs["body"] = extract_attr_values(attr)
         elif attr.name == "ceil_mode":
4 changes: 3 additions & 1 deletion onnx2pytorch/convert/operations.py
@@ -19,7 +19,7 @@
 )
 from onnx2pytorch.operations import *
 from onnx2pytorch.operations.base import OperatorWrapper
-from onnx2pytorch.operations import Resize, Upsample
+from onnx2pytorch.operations import Resize, Upsample, Hardsigmoid
 from onnx2pytorch.utils import (
     get_inputs_names,
     get_outputs_names,
@@ -236,6 +236,8 @@ def convert_operations(onnx_graph, opset_version, batch_dim=0, enable_pruning=Tr
             op = Shape()
         elif node.op_type == "Sigmoid":
             op = nn.Sigmoid()
+        elif node.op_type == "HardSigmoid":
+            op = Hardsigmoid(**extract_attributes(node))
         elif node.op_type == "Slice":
             op = Slice(**extract_attributes(node))
         elif node.op_type == "Softmax":
1 change: 1 addition & 0 deletions onnx2pytorch/operations/__init__.py
@@ -11,6 +11,7 @@
 from .gather import Gather
 from .gathernd import GatherND
 from .globalaveragepool import GlobalAveragePool
+from .hardsigmoid import Hardsigmoid
 from .instancenorm import InstanceNormWrapper
 from .loop import Loop
 from .lstm import LSTMWrapper
24 changes: 24 additions & 0 deletions onnx2pytorch/operations/hardsigmoid.py
@@ -0,0 +1,24 @@
import math

import torch
from torch import nn


class Hardsigmoid(nn.Module):
    def __new__(cls, alpha=0.2, beta=0.5):
        """
        If alpha and beta are close to the values torch's Hardsigmoid uses,
        return torch's Hardsigmoid. Else, return a custom Hardsigmoid.
        """
        if math.isclose(alpha, 1 / 6, abs_tol=1e-2) and beta == 0.5:
            return nn.Hardsigmoid()
        else:
            return super().__new__(cls)

    def __init__(self, alpha=0.2, beta=0.5):
        super().__init__()
        self.alpha = alpha
        self.beta = beta

    def forward(self, input):
        return torch.clip(input * self.alpha + self.beta, 0, 1)
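For reference, a short usage sketch of the module above (assuming onnx2pytorch is importable; the printed values are random, so the output is only illustrative):

import torch
from onnx2pytorch.operations import Hardsigmoid

x = torch.randn(4)
op_onnx_defaults = Hardsigmoid()                    # alpha=0.2, beta=0.5 -> custom clip-based module
op_torch_like = Hardsigmoid(alpha=1 / 6, beta=0.5)  # close to torch's coefficients -> returns torch.nn.Hardsigmoid
print(op_onnx_defaults(x))
print(op_torch_like(x))
print(torch.nn.Hardsigmoid()(x))  # should equal op_torch_like(x)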
57 changes: 57 additions & 0 deletions tests/onnx2pytorch/operations/test_hardsigmoid.py
@@ -0,0 +1,57 @@
from unittest.mock import MagicMock

import numpy as np
import onnx
import torch
import pytest

from onnx2pytorch.convert.operations import convert_operations
from onnx2pytorch.operations import Hardsigmoid


@pytest.fixture
def x():
    return np.random.randn(3, 4, 5).astype(np.float32)


def test_hardsigmoid(x):
    alpha = 1 / 6
    beta = 1 / 2
    op = Hardsigmoid(alpha=alpha, beta=beta)
    # For pytorch's default values it should use torch's Hardsigmoid
    assert isinstance(op, torch.nn.Hardsigmoid)
    x = np.random.randn(3, 4, 5).astype(np.float32)
    y = np.clip(x * alpha + beta, 0, 1)
    out = op(torch.from_numpy(x))
    np.testing.assert_allclose(out, torch.from_numpy(y), rtol=1e-6, atol=1e-6)


def test_hardsigmoid_with_custom_alpha_and_beta(x):
    alpha = 0.2
    beta = 0.5
    op = Hardsigmoid(alpha=alpha, beta=beta)
    assert not isinstance(op, torch.nn.Hardsigmoid)
    y = np.clip(x * alpha + beta, 0, 1)
    out = op(torch.from_numpy(x))
    np.testing.assert_allclose(out, torch.from_numpy(y), rtol=1e-6, atol=1e-6)


def test_hardsigmoid_conversion():
    alpha = np.float32(0.2)
    beta = np.float32(0.5)
    node = onnx.helper.make_node(
        "HardSigmoid",
        inputs=["x"],
        outputs=["y"],
        alpha=alpha,
        beta=beta,
    )

    graph = MagicMock()
    graph.initializers = []
    graph.node = [node]
    converted_ops = list(convert_operations(graph, 10))
    op_id, op_name, op = converted_ops[0]
    assert isinstance(op, Hardsigmoid)
    assert op.alpha == alpha
    assert op.beta == beta
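Assuming a development checkout with pytest installed, the new tests can be run with, for example, "pytest tests/onnx2pytorch/operations/test_hardsigmoid.py -q".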