Skip to content

Commit

Permalink
Fix unexpected type for intercept in PoissonRegressor and GammaRegres…
Browse files Browse the repository at this point in the history
…sor (#1070)

* Fix unexpected type for intercept in PoissonRegressor and GammaRegressor

Signed-off-by: Xavier Dupre <[email protected]>

* fix changelogs

Signed-off-by: Xavier Dupre <[email protected]>

---------

Signed-off-by: Xavier Dupre <[email protected]>
  • Loading branch information
xadupre authored Apr 2, 2024
1 parent 36eaa4d commit 59938a2
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 5 deletions.
2 changes: 2 additions & 0 deletions CHANGELOGS.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

## 1.17.0 (development)

* Fix unexpected type for intercept in PoissonRegressor and GammaRegressor
[#1070](https://github.com/onnx/sklearn-onnx/pull/1070)
* Add support for scikti-learn 1.4.0,
[#1058](https://github.com/onnx/sklearn-onnx/pull/1058),
fixes issues [Many examples in the gallery are showing "broken"](https://github.com/onnx/sklearn-onnx/pull/1057),
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ exclude = [
# Same as Black.
line-length = 88

[tool.ruff.mccabe]
[tool.ruff.lint.mccabe]
max-complexity = 10

[tool.ruff.per-file-ignores]
[tool.ruff.lint.per-file-ignores]
"skl2onnx/algebra/onnx_ops.py" = ["F821"]
2 changes: 1 addition & 1 deletion skl2onnx/operator_converters/gamma_regressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def convert_sklearn_gamma_regressor(

intercept = (
op.intercept_.astype(dtype)
if len(op.intercept_.shape) > 0
if isinstance(op.intercept_, np.ndarray) and len(op.intercept_.shape) > 0
else np.array([op.intercept_], dtype=dtype)
)
eta = OnnxAdd(
Expand Down
2 changes: 1 addition & 1 deletion skl2onnx/operator_converters/linear_regressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ def convert_sklearn_poisson_regressor(

intercept = (
op.intercept_.astype(dtype)
if len(op.intercept_.shape) > 0
if isinstance(op.intercept_, np.ndarray) and len(op.intercept_.shape) > 0
else np.array([op.intercept_], dtype=dtype)
)
eta = OnnxAdd(
Expand Down
34 changes: 33 additions & 1 deletion tests/test_sklearn_gamma_regressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import numpy as np

try:
from sklearn.linear_model import GammaRegressor
from sklearn.linear_model import GammaRegressor, PoissonRegressor
except ImportError:
GammaRegressor = None
from onnxruntime import __version__ as ort_version
Expand Down Expand Up @@ -90,6 +90,38 @@ def test_gamma_regressor_double(self):
basename="SklearnGammaRegressor",
)

@unittest.skipIf(GammaRegressor is None, reason="scikit-learn<1.0")
def test_poisson_without_intercept(self):
# Poisson
model = PoissonRegressor(fit_intercept=False)
X = np.array([[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 3.0]])
y = np.array([19.0, 26.0, 33.0, 30.0])
model.fit(X, y)

model_onnx = convert_sklearn(
model,
"scikit-learn Poisson Regressor without Intercept",
[("input", FloatTensorType([None, X.shape[1]]))],
)

self.assertIsNotNone(model_onnx is not None)

@unittest.skipIf(GammaRegressor is None, reason="scikit-learn<1.0")
def test_gamma_without_intercept(self):
# Gamma
model = GammaRegressor(fit_intercept=False)
X = np.array([[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 3.0]])
y = np.array([19.0, 26.0, 33.0, 30.0])
model.fit(X, y)

model_onnx = convert_sklearn(
model,
"scikit-learn Gamma Regressor without Intercept",
[("input", FloatTensorType([None, X.shape[1]]))],
)

self.assertIsNotNone(model_onnx is not None)


if __name__ == "__main__":
unittest.main(verbosity=3)

0 comments on commit 59938a2

Please sign in to comment.