Scikit-Learn API Wrappers don't work with Model input #20876

Closed
swbedoya opened this issue Feb 7, 2025 · 3 comments


swbedoya commented Feb 7, 2025

When I use the new Scikit-Learn API wrappers with a compiled `Model` as input, the wrapper does not work: fitting fails with an error saying that the underlying model isn't compiled. The following code reproduces the issue. It is adapted from the example in the `SKLearnClassifier` documentation so that it passes in a `Model` instance rather than a callable. I also had to fix a couple of bugs in that example to get it to run, and those fixes are marked with `# FIX` comments:

```python
import numpy as np
from keras.layers import Dense, Input
from keras.models import Model  # FIX: the docs example imported Model from keras.src.layers


def dynamic_model(X, y, loss, layers=[10]):
    # Creates a basic MLP model dynamically choosing the input and
    # output shapes.
    n_features_in = X.shape[1]
    inp = Input(shape=(n_features_in,))

    hidden = inp
    for layer_size in layers:
        hidden = Dense(layer_size, activation="relu")(hidden)

    # One softmax unit per class: the wrapper passes one-hot targets, but a
    # direct call with 1-D labels needs the number of distinct classes.
    n_outputs = y.shape[1] if len(y.shape) > 1 else len(np.unique(y))
    out = [Dense(n_outputs, activation="softmax")(hidden)]
    model = Model(inp, out)
    model.compile(loss=loss, optimizer="rmsprop")

    return model


from sklearn.datasets import make_classification
from keras.wrappers import SKLearnClassifier

X, y = make_classification(n_samples=1000, n_features=10, n_classes=2)  # FIX: n_classes 3 -> 2
est = SKLearnClassifier(
    # pass in a compiled Model instance instead of a callable
    model=dynamic_model(X, y, loss="categorical_crossentropy", layers=[20, 20, 20])
)

est.fit(X, y, epochs=5)
```

The error is raised by the final `est.fit(X, y, epochs=5)` call, and the traceback is reproduced below. I believe the cause is that `self._get_model()` clones the model by default, and `clone_model()` does not recompile the clone. As a result, the model that reaches `_check_model()` is uncompiled.

```
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-24-c9dcff454e13> in <cell line: 0>()
     27 )
     28 
---> 29 est.fit(X, y, epochs=5)

/usr/local/lib/python3.11/dist-packages/keras/src/wrappers/sklearn_wrapper.py in fit(self, X, y, **kwargs)
    162         y = self._process_target(y, reset=True)
    163         model = self._get_model(X, y)
--> 164         _check_model(model)
    165 
    166         fit_kwargs = self.fit_kwargs or {}

/usr/local/lib/python3.11/dist-packages/keras/src/wrappers/utils.py in _check_model(model)
     25     # compile model if user gave us an un-compiled model
     26     if not model.compiled or not model.loss or not model.optimizer:
---> 27         raise RuntimeError(
     28             "Given model needs to be compiled, and have a loss and an "
     29             "optimizer."

RuntimeError: Given model needs to be compiled, and have a loss and an optimizer.
```

@sonali-kumari1 (Contributor)

Hi @swbedoya -

Thanks for reporting this issue. A similar issue has already been reported, and a PR has been raised for it. You are passing a compiled model instance to `SKLearnClassifier`, but you should instead pass the `dynamic_model` function itself and supply `loss`, `layers`, and any other required parameters through the `model_kwargs` dictionary, like this:

```python
est = SKLearnClassifier(
    model=dynamic_model,
    model_kwargs={
        "loss": "categorical_crossentropy",
        "layers": [20, 20, 20],
    },
)
```

This issue is stale because it has been open for 14 days with no activity. It will be closed if no further activity occurs. Thank you.

github-actions bot added the stale label Feb 27, 2025

This issue was closed because it has been inactive for 28 days. Please reopen if you'd like to work on this further.
