incompatibility of skorch.callbacks.TrainEndCheckpoint with sklearn.compose.TransformedTargetRegressor #459

Closed
gillesdami opened this issue Apr 18, 2019 · 4 comments · Fixed by #463

@gillesdami

skorch version: 0.5.0.post0

Minimal example inspired from the docs:

import numpy as np
from sklearn.compose import TransformedTargetRegressor
from sklearn.datasets import make_classification
from sklearn.preprocessing import StandardScaler
from skorch.callbacks import TrainEndCheckpoint
from skorch import NeuralNetClassifier
from torch import nn

X, y = make_classification(1000, 10, n_informative=5, random_state=0)
X = X.astype(np.float32)
y = y.astype(np.int64)

class MyModule(nn.Sequential):
    def __init__(self, num_units=10):
        super().__init__(
            nn.Linear(10, num_units),
            nn.ReLU(inplace=True),
            nn.Dropout(0.2),
            nn.Linear(num_units, 10),
            nn.Linear(10, 2),
            nn.Softmax(dim=-1)
        )

train_end_cp = TrainEndCheckpoint(dirname='exp1', fn_prefix='train_end_')
net = NeuralNetClassifier(
    MyModule, lr=0.5, callbacks=[train_end_cp]
)

net = TransformedTargetRegressor(net, func=lambda x: x, inverse_func=lambda x: x)

_ = net.fit(X, y)

error returned:

Traceback (most recent call last):
  File "test.py", line 31, in <module>
    _ = net.fit(X, y)
  File "/home/damien/.local/lib/python3.6/site-packages/sklearn/compose/_target.py", line 198, in fit
    self.regressor_ = clone(self.regressor)
  File "/home/damien/.local/lib/python3.6/site-packages/sklearn/base.py", line 64, in clone
    new_object_params[name] = clone(param, safe=False)
  File "/home/damien/.local/lib/python3.6/site-packages/sklearn/base.py", line 52, in clone
    return estimator_type([clone(e, safe=safe) for e in estimator])
  File "/home/damien/.local/lib/python3.6/site-packages/sklearn/base.py", line 52, in <listcomp>
    return estimator_type([clone(e, safe=safe) for e in estimator])
  File "/home/damien/.local/lib/python3.6/site-packages/sklearn/base.py", line 65, in clone
    new_object = klass(**new_object_params)
TypeError: __init__() got an unexpected keyword argument 'monitor'

explanation:

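sklearn clones the regressor, and with it the callbacks, when using TransformedTargetRegressor (see the clone calls in the traceback). However, the callbacks do not respect the BaseEstimator spec: clone passes back parameters that their __init__ does not accept.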
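
A minimal sketch of the failure mechanism, using plain-Python stand-ins rather than the real skorch or sklearn classes. It assumes that the callback's get_params reports every public instance attribute (including monitor, which the parent __init__ sets) and that cloning boils down to calling the class with those parameters, as the last frame of the traceback suggests. The class bodies and naive_clone are hypothetical simplifications.

```python
class Callback:
    """Hypothetical stand-in for a callback base class."""

    def get_params(self):
        # Assumption: every instance attribute not ending in "_" is a parameter.
        return {k: v for k, v in vars(self).items() if not k.endswith("_")}


class Checkpoint(Callback):
    def __init__(self, monitor="some_default", dirname="", fn_prefix=""):
        self.monitor = monitor
        self.dirname = dirname
        self.fn_prefix = fn_prefix


class TrainEndCheckpoint(Checkpoint):
    # monitor is fixed by the subclass, so it is not an __init__ argument.
    def __init__(self, dirname="", fn_prefix="train_end_"):
        super().__init__(monitor=None, dirname=dirname, fn_prefix=fn_prefix)


def naive_clone(obj):
    """Simplified clone: rebuild the object from its reported parameters."""
    return type(obj)(**obj.get_params())


cp = TrainEndCheckpoint(dirname="exp1")
try:
    naive_clone(cp)
    clone_error = None
except TypeError as exc:
    # get_params reported monitor, but __init__ does not accept it.
    clone_error = str(exc)
print(clone_error)
```
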

fix:

The callback classes that inherit from Callback should have **kwargs in their __init__ function to ensure that they accept all parameters of the parent __init__.

May I open a PR?

@BenjaminBossan
Collaborator

Thanks for reporting the bug. I believe the proposed solution would not be optimal; it is better to be explicit about parameters in this case.

There are 3 parameters missing from TrainEndCheckpoint.__init__: target, monitor, and event_name. They are not there because they are not needed, but at the same time they should be there because TrainEndCheckpoint inherits from Checkpoint.
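
One way to spot such gaps is to compare the two __init__ signatures. The sketch below uses hypothetical stand-in classes with placeholder defaults, not the real skorch code; only the parameter names target, monitor, and event_name come from the discussion above, and missing_init_params is an illustrative helper.

```python
import inspect


class Checkpoint:
    # Hypothetical stand-in; the default values are placeholders.
    def __init__(self, target=None, monitor="some_default",
                 event_name="some_event", dirname=""):
        pass


class TrainEndCheckpoint(Checkpoint):
    def __init__(self, dirname="", fn_prefix="train_end_"):
        pass


def missing_init_params(child, parent):
    """Return parent __init__ parameters that the child __init__ lacks."""
    child_params = set(inspect.signature(child.__init__).parameters)
    parent_params = inspect.signature(parent.__init__).parameters
    return [name for name in parent_params if name not in child_params]


missing = missing_init_params(TrainEndCheckpoint, Checkpoint)
print(missing)
```
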

The solution would be either to add these parameters with the new defaults -- but it would not really make sense to have them on TrainEndCheckpoint -- or to move the common functionality out of Checkpoint, so that it can be recycled without inheritance -- but this would require the functions to have a lot of parameters.
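
A rough sketch of the second option, assuming the shared functionality can be expressed as a plain function. All names here (checkpoint_path, target_path, f_params and its default) are hypothetical and chosen for illustration only; the real common code would likely need more parameters, which is the drawback mentioned above.

```python
import os


def checkpoint_path(dirname, fn_prefix, filename):
    """Shared helper: build the target path for one checkpoint file."""
    return os.path.join(dirname, fn_prefix + filename)


class Checkpoint:
    def __init__(self, monitor="some_default", dirname="", fn_prefix="",
                 f_params="params.pt"):
        self.monitor = monitor
        self.dirname = dirname
        self.fn_prefix = fn_prefix
        self.f_params = f_params

    def target_path(self):
        return checkpoint_path(self.dirname, self.fn_prefix, self.f_params)


class TrainEndCheckpoint:
    # No inheritance from Checkpoint, so there is no monitor parameter
    # to account for; the common logic is reused through the helper.
    def __init__(self, dirname="", fn_prefix="train_end_",
                 f_params="params.pt"):
        self.dirname = dirname
        self.fn_prefix = fn_prefix
        self.f_params = f_params

    def target_path(self):
        return checkpoint_path(self.dirname, self.fn_prefix, self.f_params)


path = TrainEndCheckpoint(dirname="exp1").target_path()
print(path)
```
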

I don't have a strong preference. What do others think? @gillesdami If you would like to work on one of the solutions, or a better one, go ahead.

@thomasjpfan
Member

I think adding a **kwargs to TrainEndCheckpoint (with a comment) would be the cleanest fix.
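
For illustration, a minimal sketch of what the **kwargs variant could look like. The classes are simplified stand-ins, not the real skorch implementation, and the clone step is reduced to rebuilding the object from get_params, which is an assumption based on the traceback above.

```python
class Checkpoint:
    # Hypothetical stand-in; the monitor default is a placeholder.
    def __init__(self, monitor="some_default", dirname="", fn_prefix=""):
        self.monitor = monitor
        self.dirname = dirname
        self.fn_prefix = fn_prefix

    def get_params(self):
        return dict(vars(self))


class TrainEndCheckpoint(Checkpoint):
    def __init__(self, dirname="", fn_prefix="train_end_", **kwargs):
        # **kwargs exists only so that cloning, which passes back every
        # parameter reported by get_params (e.g. monitor), does not fail.
        # The extra values are deliberately ignored.
        super().__init__(monitor=None, dirname=dirname, fn_prefix=fn_prefix)


cp = TrainEndCheckpoint(dirname="exp1")
cloned = type(cp)(**cp.get_params())  # simplified clone step
print(cloned.dirname, cloned.monitor)
```
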

@gillesdami
Author

Explicitly adding target, monitor and event_name would confuse everyone. In my opinion, we should add **kwargs or refactor the code to avoid inheritance.
I'd be happy to work on a solution, I'll open a PR once we've chosen the way to go.

@BenjaminBossan
Collaborator

The reason why I would avoid **kwargs if possible is that it can be confusing to the user as well (what happens with those kwargs?) and because they silently swallow typos of correct arguments. Furthermore, sklearn discourages the use of **kwargs during init.
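
A small sketch of the typo problem, using two hypothetical classes (StrictCheckpoint and LenientCheckpoint are made-up names for illustration). The misspelled keyword dirnmae is rejected by the explicit signature but silently dropped by the one with **kwargs.

```python
class StrictCheckpoint:
    def __init__(self, dirname="", fn_prefix="train_end_"):
        self.dirname = dirname
        self.fn_prefix = fn_prefix


class LenientCheckpoint:
    def __init__(self, dirname="", fn_prefix="train_end_", **kwargs):
        # Unknown keyword arguments end up in kwargs and are never used.
        self.dirname = dirname
        self.fn_prefix = fn_prefix


# Typo: "dirnmae" instead of "dirname". No error, dirname keeps its default.
lenient = LenientCheckpoint(dirnmae="exp1")
print(repr(lenient.dirname))

try:
    StrictCheckpoint(dirnmae="exp1")
    strict_error = None
except TypeError as exc:
    # The explicit signature surfaces the typo immediately.
    strict_error = str(exc)
print(strict_error)
```
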

Another solution would be to add the three missing parameters but make a check that their values are what is needed (this check would ideally also work after a set_params). Otherwise, I prefer the refactoring approach, though that is probably the most difficult one to get right.
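
A sketch of the "explicit parameters plus check" idea, under the assumption that the three inherited parameters must keep one fixed value each. The class is a standalone stand-in, and the fixed values (all None here), _fixed_params, and _check_fixed are hypothetical. The check runs both at construction and in set_params, and set_params validates before changing any attribute.

```python
class TrainEndCheckpoint:
    # Hypothetical: the values these inherited parameters must keep.
    _fixed_params = {"target": None, "monitor": None, "event_name": None}

    def __init__(self, target=None, monitor=None, event_name=None,
                 dirname="", fn_prefix="train_end_"):
        self.target = target
        self.monitor = monitor
        self.event_name = event_name
        self.dirname = dirname
        self.fn_prefix = fn_prefix
        self._check_fixed(vars(self))

    def _check_fixed(self, params):
        for name, expected in self._fixed_params.items():
            if name in params and params[name] != expected:
                raise ValueError(
                    f"{name} must stay {expected!r} for TrainEndCheckpoint, "
                    f"got {params[name]!r}"
                )

    def get_params(self):
        return dict(vars(self))

    def set_params(self, **params):
        self._check_fixed(params)  # validate before mutating anything
        for name, value in params.items():
            setattr(self, name, value)
        return self


cp = TrainEndCheckpoint(dirname="exp1")
cloned = TrainEndCheckpoint(**cp.get_params())  # simplified clone step

try:
    cp.set_params(monitor="valid_loss_best")
    set_error = None
except ValueError as exc:
    set_error = str(exc)
print(set_error)
```

Unlike the **kwargs variant, a misspelled argument still raises a TypeError here, at the cost of three parameters in the signature that users cannot meaningfully change.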
