Using Keras 2.15 (installed with TensorFlow 2.15), I'm running the sample code from the Keras documentation at https://keras.io/guides/serialization_and_saving/ with a single change: I save the model to an "h5" file instead of a "keras" file.
The sample code produces this output:
numpy: 1.26.4
tensorflow: 2.15.1
keras: 2.15.0
TypeError: Error when deserializing class 'Dense' using config={'name': 'dense', 'trainable': True, 'dtype': 'float32', 'units': 1, 'activation': {'module': 'builtins', 'class_name': 'function', 'config': 'my_package>custom_fn', 'registered_name': 'function'}, 'use_bias': True, 'kernel_initializer': {'module': 'keras.initializers', 'class_name': 'GlorotUniform', 'config': {'seed': None}, 'registered_name': None}, 'bias_initializer': {'module': 'keras.initializers', 'class_name': 'Zeros', 'config': {}, 'registered_name': None}, 'kernel_regularizer': None, 'bias_regularizer': None, 'activity_regularizer': None, 'kernel_constraint': None, 'bias_constraint': None}.
Exception encountered: Unknown activation function: 'function'. Please ensure you are using a `keras.utils.custom_object_scope` and that this object is included in the scope. See https://www.tensorflow.org/guide/keras/save_and_serialize#registering_the_custom_object for details.
Sample code:
import numpy as np
import tensorflow as tf
import keras

print("numpy:", np.__version__)
print("tensorflow:", tf.__version__)
print("keras:", keras.__version__)

keras.saving.get_custom_objects().clear()


@keras.saving.register_keras_serializable(package="MyLayers")
class CustomLayer(keras.layers.Layer):
    def __init__(self, factor):
        super().__init__()
        self.factor = factor

    def call(self, x):
        return x * self.factor

    def get_config(self):
        return {"factor": self.factor}


@keras.saving.register_keras_serializable(package="my_package", name="custom_fn")
def custom_fn(x):
    return x**2


# Create the model.
def get_model():
    inputs = keras.Input(shape=(4,))
    mid = CustomLayer(0.5)(inputs)
    outputs = keras.layers.Dense(1, activation=custom_fn)(mid)
    model = keras.Model(inputs, outputs)
    model.compile(optimizer="rmsprop", loss="mean_squared_error")
    return model


# Train the model.
def train_model(model):
    input = np.random.random((4, 4))
    target = np.random.random((4, 1))
    model.fit(input, target)
    return model

if __name__ == "__main__":
    # This is the only difference from the documentation:
    # when using "keras", loading succeeds.
    file_format = "h5"
    file_name = f"custom_model_reg.{file_format}"
    model = get_model()
    model = train_model(model)
    model.save(file_name)
    # Raises the error above.
    reconstructed_model = keras.models.load_model(file_name)
If I create this model with Keras 2.12, loading succeeds.
Comparing the metadata of this model as saved by 2.12 and by 2.15 shows a clear difference.
Here is the 2.12 metadata:
{
  "class_name": "Dense",
  "config": {
    "name": "dense",
    "trainable": true,
    "dtype": "float32",
    "units": 1,
    "activation": "custom_fn",
    ...
And here is the 2.15 metadata:
  "class_name": "Dense",
  "config": {
    "name": "dense",
    "trainable": true,
    "dtype": "float32",
    "units": 1,
    "activation": {
      "module": "builtins",
      "class_name": "function",
      "config": "custom_fn",
      "registered_name": "function"
    },
    ...
In 2.15, the "activation" entry changed from a string to a dictionary.
Further debugging shows that when loading the "h5" file, execution eventually reaches keras.src.saving.legacy.serialization.class_and_config_for_serialized_keras_object, which resolves the object using only "class_name" and therefore fails, because class_name is "function":
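The shape change can be illustrated in plain Python, with no Keras needed. The sketch below copies the two `activation` entries quoted above and defines a hypothetical helper, `legacy_activation_name`, that collapses the 2.15-style dictionary for a plain function back into a 2.12-style string. The helper and the idea of rewriting a stored config this way are assumptions for illustration only; nothing here confirms that the Keras 2.15 h5 loader would accept a config patched like this.

```python
# Hypothetical sketch: the two serialized shapes of the Dense "activation"
# entry, copied from the metadata above (not produced by Keras here).
activation_2_12 = "custom_fn"
activation_2_15 = {
    "module": "builtins",
    "class_name": "function",
    "config": "custom_fn",
    "registered_name": "function",
}


def legacy_activation_name(entry):
    """Hypothetical helper: return a 2.12-style string for an activation entry.

    Strings are returned unchanged; a 2.15-style dict describing a plain
    function (class_name == "function") is collapsed to its "config" value.
    """
    if isinstance(entry, str):
        return entry
    if isinstance(entry, dict) and entry.get("class_name") == "function":
        return entry["config"]
    raise ValueError(f"Unrecognized activation entry: {entry!r}")


print(legacy_activation_name(activation_2_12))  # custom_fn
print(legacy_activation_name(activation_2_15))  # custom_fn
```

The helper passes the `config` string through unchanged, so the entry shown in the traceback above would become `'my_package>custom_fn'` rather than the bare `custom_fn` seen in the 2.12 metadata.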
class_name = config["class_name"]
cls = object_registration.get_registered_object(
    class_name, custom_objects, module_objects
)
if cls is None:
    raise ValueError(
        f"Unknown {printable_module_name}: '{class_name}'. "
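Why this lookup fails can be modelled with a heavily simplified, hypothetical version of the snippet above. `resolve_by_class_name` and `REGISTRY` are illustrative names, not Keras APIs, and keying the registry by `"package>name"` is an assumption based on the `my_package>custom_fn` string in the traceback. The point is only that a lookup keyed on `class_name` sees `"function"`, while the name that identifies the custom function sits in `config`.

```python
# Hypothetical, heavily simplified model of the legacy lookup quoted above.
# This is NOT Keras code: it only mirrors resolving an object from
# config["class_name"] alone.


def custom_fn(x):
    # Stand-in for the registered activation in the sample code.
    return x**2


# Assumption: custom objects are keyed by "package>name".
REGISTRY = {"my_package>custom_fn": custom_fn}


def resolve_by_class_name(config, custom_objects):
    """Look up an object using only config["class_name"], like the snippet."""
    class_name = config["class_name"]
    obj = custom_objects.get(class_name)
    if obj is None:
        raise ValueError(f"Unknown activation function: '{class_name}'.")
    return obj


# The 2.15-style activation entry reported in the traceback.
entry_2_15 = {
    "module": "builtins",
    "class_name": "function",
    "config": "my_package>custom_fn",
    "registered_name": "function",
}

try:
    resolve_by_class_name(entry_2_15, REGISTRY)
except ValueError as err:
    print(err)  # Unknown activation function: 'function'.

# The identifying name is stored under "config", not "class_name".
print(REGISTRY[entry_2_15["config"]] is custom_fn)  # True
```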
So the question is: is there a way to fix this, or at least work around it?
TensorFlow 2.15 is the highest version available to me.
Comment From: sonali-kumari1
Hi @nchaly,
The error you are encountering indicates that custom_fn is not being recognized during deserialization. To avoid the error when saving and loading your model in .h5 with TensorFlow 2.15.0, you can use model.export(file_name) instead of model.save(file_name). Attaching a gist for your reference. Thanks!
Comment From: github-actions[bot]
This issue is stale because it has been open for 14 days with no activity. It will be closed if no further activity occurs. Thank you.