Skip to content

Commit

Permalink
Make skorch work with sklearn 1.6.0, attempt 2 (#1078)
Browse files Browse the repository at this point in the history
Alternative to #1076

As described in that PR, skorch is currently not compatible with sklearn
1.6.0 or above. As per suggestion, instead of implementing
__sklearn_tags__, this PR solves the issue by inheriting from
BaseEstimator.

Related changes:

- It is important to set the correct order when inheriting from
  BaseEstimator and, say, ClassifierMixin (BaseEstimator should come
  last).
- As explained in #1076, using GridSearchCV with y being a torch tensor
  currently fails and two tests had to be adjusted.

Unrelated changes

- Removed unnecessary imports from callbacks/base.py.
  • Loading branch information
BenjaminBossan authored Dec 18, 2024
1 parent ad0259b commit 4f755b9
Show file tree
Hide file tree
Showing 9 changed files with 24 additions and 15 deletions.
3 changes: 3 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added
### Changed

- All neural net classes now inherit from sklearn's [`BaseEstimator`](https://scikit-learn.org/stable/modules/generated/sklearn.base.BaseEstimator.html). This is to support compatibility with sklearn 1.6.0 and above. Classification models additionally inherit from [`ClassifierMixin`](https://scikit-learn.org/stable/modules/generated/sklearn.base.ClassifierMixin.html) and regressors from [`RegressorMixin`](https://scikit-learn.org/stable/modules/generated/sklearn.base.RegressorMixin.html).

### Fixed

- Fix an issue with using `NeuralNetBinaryClassifier` with `torch.compile` (#1058)
Expand Down
3 changes: 0 additions & 3 deletions skorch/callbacks/base.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
""" Basic callback definition. """

import warnings

from sklearn.base import BaseEstimator
from skorch.exceptions import SkorchWarning


__all__ = ['Callback']
Expand Down
4 changes: 2 additions & 2 deletions skorch/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def get_neural_net_clf_doc(doc):


# pylint: disable=missing-docstring
class NeuralNetClassifier(NeuralNet, ClassifierMixin):
class NeuralNetClassifier(ClassifierMixin, NeuralNet):
__doc__ = get_neural_net_clf_doc(NeuralNet.__doc__)

def __init__(
Expand Down Expand Up @@ -258,7 +258,7 @@ def get_neural_net_binary_clf_doc(doc):
return doc


class NeuralNetBinaryClassifier(NeuralNet, ClassifierMixin):
class NeuralNetBinaryClassifier(ClassifierMixin, NeuralNet):
# pylint: disable=missing-docstring
__doc__ = get_neural_net_binary_clf_doc(NeuralNet.__doc__)

Expand Down
2 changes: 1 addition & 1 deletion skorch/hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from skorch.utils import check_is_fitted, params_for


class _HuggingfaceTokenizerBase(BaseEstimator, TransformerMixin):
class _HuggingfaceTokenizerBase(TransformerMixin, BaseEstimator):
"""Base class for yet to train and pretrained tokenizers
Implements the ``vocabulary_`` attribute and the methods
Expand Down
2 changes: 1 addition & 1 deletion skorch/llm/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ def generate_logits(self, *, label_id, **kwargs):
return recorded_logits + recorder.recorded_scores[:]


class _LlmBase(BaseEstimator, ClassifierMixin):
class _LlmBase(ClassifierMixin, BaseEstimator):
"""Base class for LLM models
This class handles a few of the checks, as well as the whole prediction
Expand Down
6 changes: 3 additions & 3 deletions skorch/net.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@


# pylint: disable=too-many-instance-attributes
class NeuralNet:
class NeuralNet(BaseEstimator):
# pylint: disable=anomalous-backslash-in-string
"""NeuralNet base class.
Expand Down Expand Up @@ -1992,7 +1992,7 @@ def _get_params_callbacks(self, deep=True):
return params

def get_params(self, deep=True, **kwargs):
params = BaseEstimator.get_params(self, deep=deep, **kwargs)
params = super().get_params(deep=deep, **kwargs)
# Callback parameters are not returned by .get_params, needs
# special treatment.
params_cb = self._get_params_callbacks(deep=deep)
Expand Down Expand Up @@ -2111,7 +2111,7 @@ def set_params(self, **kwargs):
normal_params[key] = val

self._apply_virtual_params(virtual_params)
BaseEstimator.set_params(self, **normal_params)
super().set_params(**normal_params)

for key, val in special_params.items():
if key.endswith('_'):
Expand Down
5 changes: 3 additions & 2 deletions skorch/probabilistic.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import gpytorch
import numpy as np
import torch
from sklearn.base import ClassifierMixin, RegressorMixin

from skorch.net import NeuralNet
from skorch.dataset import ValidSplit
Expand Down Expand Up @@ -391,7 +392,7 @@ def __getstate__(self):
raise pickle.PicklingError(msg) from exc


class _GPRegressorPredictMixin:
class _GPRegressorPredictMixin(RegressorMixin):
"""Mixin class that provides a predict method for GP regressors."""
def predict(self, X, return_std=False, return_cov=False):
"""Returns the predicted mean and optionally standard deviation.
Expand Down Expand Up @@ -778,7 +779,7 @@ def get_gp_binary_clf_doc(doc):
return doc


class GPBinaryClassifier(GPBase):
class GPBinaryClassifier(ClassifierMixin, GPBase):
__doc__ = get_gp_binary_clf_doc(NeuralNet.__doc__)

def __init__(
Expand Down
2 changes: 1 addition & 1 deletion skorch/regressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def get_neural_net_reg_doc(doc):


# pylint: disable=missing-docstring
class NeuralNetRegressor(NeuralNet, RegressorMixin):
class NeuralNetRegressor(RegressorMixin, NeuralNet):
__doc__ = get_neural_net_reg_doc(NeuralNet.__doc__)

def __init__(
Expand Down
12 changes: 10 additions & 2 deletions skorch/tests/test_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,6 +437,7 @@ def test_grid_search_with_slds_works(
self, slds, y, classifier_module):
from sklearn.model_selection import GridSearchCV
from skorch import NeuralNetClassifier
from skorch.utils import to_numpy

net = NeuralNetClassifier(
classifier_module,
Expand All @@ -450,12 +451,16 @@ def test_grid_search_with_slds_works(
gs = GridSearchCV(
net, params, refit=False, cv=3, scoring='accuracy', error_score='raise'
)
gs.fit(slds, y) # does not raise
# TODO: after sklearn > 1.6 is released, the to_numpy call should no longer be
# required and be removed, see:
# https://github.com/skorch-dev/skorch/pull/1078#discussion_r1887197261
gs.fit(slds, to_numpy(y)) # does not raise

def test_grid_search_with_slds_and_internal_split_works(
self, slds, y, classifier_module):
from sklearn.model_selection import GridSearchCV
from skorch import NeuralNetClassifier
from skorch.utils import to_numpy

net = NeuralNetClassifier(classifier_module)
params = {
Expand All @@ -465,7 +470,10 @@ def test_grid_search_with_slds_and_internal_split_works(
gs = GridSearchCV(
net, params, refit=True, cv=3, scoring='accuracy', error_score='raise'
)
gs.fit(slds, y) # does not raise
# TODO: after sklearn > 1.6 is released, the to_numpy call should no longer be
# required and be removed, see:
# https://github.com/skorch-dev/skorch/pull/1078#discussion_r1887197261
gs.fit(slds, to_numpy(y)) # does not raise

def test_grid_search_with_slds_X_and_slds_y(
self, slds, slds_y, classifier_module):
Expand Down

0 comments on commit 4f755b9

Please sign in to comment.