Problem
The Model.export() API in Keras 3 supports exporting to a TensorFlow SavedModel artifact for inference. When trying to export Gemma 2 and ShieldGemma to TF SavedModel, I ran into two different ValueErrors:
- If no input_signature is provided, a ValueError is thrown related to a structural mismatch between the expected and actual inputs passed to the GemmaCausalLM class; and
- If an input_signature is provided as a list[keras.InputSpec], a ValueError is thrown related to the wrong number of values being passed to a TF function.
However, if you wrap the dict from model.input in a list, as input_signature=[model.input], the export runs to completion.
This is not restricted to Gemma models, as shown in this minimal reproducible example.
Thanks to @mattdangerw for helping to isolate this minimal example.
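For reference, the workaround described above can be sketched against a small dict-input model instead of Gemma. This is only an illustration under assumptions: the helper names build_dict_input_model and workaround_signature are hypothetical and not part of Keras, and the export call is assumed to behave as reported in this issue on the TensorFlow backend.

```python
import keras


def build_dict_input_model():
    """Build a functional model whose inputs are a dict of two scalar tensors."""
    inputs = {
        "foo": keras.Input(shape=()),
        "bar": keras.Input(shape=()),
    }
    outputs = keras.layers.Add()([inputs["foo"], inputs["bar"]])
    return keras.Model(inputs, outputs)


def workaround_signature(model):
    """Hypothetical helper: wrap the dict from model.input in a one-element list."""
    return [model.input]


if __name__ == "__main__":
    model = build_dict_input_model()
    # Workaround reported in this issue: pass the dict wrapped in a list
    # as input_signature instead of relying on the default signature.
    model.export(
        "test/",
        format="tf_saved_model",
        input_signature=workaround_signature(model),
    )
```

The list wrapper makes the signature describe a single positional argument that is itself a dict, which matches how the model is called with its dict of inputs.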
Comment From: mattdangerw
When this is fixed, the following should work (and would make a good unit test on the tf backend).
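import keras

# Functional model that takes a dict of two scalar inputs.
inputs = {
    "foo": keras.Input(shape=()),
    "bar": keras.Input(shape=()),
}
outputs = keras.layers.Add()([inputs["foo"], inputs["bar"]])
model = keras.Model(inputs, outputs)
# Expected to succeed without an explicit input_signature once the bug is fixed.
model.export("test/", format="tf_saved_model")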