When the new scikit-learn API wrappers are given a compiled `Model` instance as input, fitting fails with an error saying that the underlying model is not compiled. The following code reproduces the issue. It is adapted from the example in the `SKLearnClassifier` documentation so that it passes in a `Model` instance rather than a callable. I also had to fix a couple of bugs in that example to get it to run, and both fixes are marked in the code:
```python
from keras.src.layers import Dense, Input
from keras.src.models.model import Model  # FIX: previously imported from keras.src.layers


def dynamic_model(X, y, loss, layers=[10]):
    # Creates a basic MLP model dynamically choosing the input and
    # output shapes.
    n_features_in = X.shape[1]
    inp = Input(shape=(n_features_in,))

    hidden = inp
    for layer_size in layers:
        hidden = Dense(layer_size, activation="relu")(hidden)

    n_outputs = y.shape[1] if len(y.shape) > 1 else 1
    out = [Dense(n_outputs, activation="softmax")(hidden)]
    model = Model(inp, out)
    model.compile(loss=loss, optimizer="rmsprop")

    return model


from sklearn.datasets import make_classification
from keras.wrappers import SKLearnClassifier

X, y = make_classification(n_samples=1000, n_features=10, n_classes=2)  # FIX: n_classes 3 -> 2
est = SKLearnClassifier(
    # pass in a compiled Model instance instead of a callable
    model=dynamic_model(X, y, loss="categorical_crossentropy", layers=[20, 20, 20])
)

est.fit(X, y, epochs=5)
```
The error is raised by the last line, when fitting the model, and is reproduced below. I believe it happens because `self._get_model()` clones the model by default, and `clone_model()` does not recompile the clone.
```
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-24-c9dcff454e13> in <cell line: 0>()
     27 )
     28 
---> 29 est.fit(X, y, epochs=5)

/usr/local/lib/python3.11/dist-packages/keras/src/wrappers/sklearn_wrapper.py in fit(self, X, y, **kwargs)
    162         y = self._process_target(y, reset=True)
    163         model = self._get_model(X, y)
--> 164         _check_model(model)
    165 
    166         fit_kwargs = self.fit_kwargs or {}

/usr/local/lib/python3.11/dist-packages/keras/src/wrappers/utils.py in _check_model(model)
     25     # compile model if user gave us an un-compiled model
     26     if not model.compiled or not model.loss or not model.optimizer:
---> 27         raise RuntimeError(
     28             "Given model needs to be compiled, and have a loss and an "
     29             "optimizer."

RuntimeError: Given model needs to be compiled, and have a loss and an optimizer.
```
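To make the suspected mechanism concrete, here is a toy simulation in plain Python. `ToyModel`, `toy_clone_model`, `get_model_current` and `get_model_recompiled` are hypothetical stand-ins, not Keras APIs; only the condition inside `check_model` is copied from the `_check_model` frame in the traceback. The sketch assumes that cloning copies the architecture but none of the compile state, and shows that re-applying the original compile arguments to the clone would let the check pass.

```python
class ToyModel:
    """Hypothetical stand-in for a Keras Model; it only tracks compile state."""

    def __init__(self, layer_sizes):
        self.layer_sizes = list(layer_sizes)
        self.compiled = False
        self.loss = None
        self.optimizer = None

    def compile(self, loss, optimizer):
        self.loss = loss
        self.optimizer = optimizer
        self.compiled = True

    def get_compile_config(self):
        # Hypothetical: returns the arguments compile() was called with.
        if not self.compiled:
            return None
        return {"loss": self.loss, "optimizer": self.optimizer}


def toy_clone_model(model):
    """Assumed clone_model() behaviour: copy the architecture, drop compile state."""
    return ToyModel(model.layer_sizes)


def check_model(model):
    """Same condition as _check_model in keras/src/wrappers/utils.py."""
    if not model.compiled or not model.loss or not model.optimizer:
        raise RuntimeError(
            "Given model needs to be compiled, and have a loss and an optimizer."
        )


def get_model_current(model):
    # Behaviour described in this issue: clone and return the clone as-is.
    return toy_clone_model(model)


def get_model_recompiled(model):
    # Possible fix (an assumption, not the actual Keras code):
    # clone, then re-apply the original model's compile arguments.
    clone = toy_clone_model(model)
    config = model.get_compile_config()
    if config is not None:
        clone.compile(**config)
    return clone


if __name__ == "__main__":
    original = ToyModel([20, 20, 20])
    original.compile(loss="categorical_crossentropy", optimizer="rmsprop")

    try:
        check_model(get_model_current(original))
    except RuntimeError as err:
        print("current behaviour:", err)

    check_model(get_model_recompiled(original))  # no exception raised
    print("clone + recompile: check passes")
```

In real Keras 3, models expose `get_compile_config()` and `compile_from_config()`, which might be one way for the wrapper to recompile a clone; whether that round-trips every loss and optimizer is untested here. On the user side, passing a callable such as `dynamic_model` (as the documentation example does) should sidestep the clone entirely, since the wrapper then builds and compiles a fresh model itself.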