I define an `EnsembleModel` class that is constructed from a list of other Keras models:
```python
from typing import Any, Callable, Dict, Iterable, Text

import keras
import tensorflow as tf


class EnsembleModel(keras.Model):

    def __init__(
        self,
        models: Iterable[keras.Model],
        reduce_fn: Callable = keras.ops.mean,
        **kwargs):
        super().__init__(**kwargs)
        self.models = models
        # self.model0 = models[0]
        # self.model1 = models[1]
        self.reduce_fn = reduce_fn

    @tf.function(input_signature=[input_signature])
    def call(self, inputs: Dict[Text, Any]) -> Any:
        all_outputs = [
            keras.ops.reshape(model(inputs), newshape=(-1,))
            for model in self.models
        ]
        return self.reduce_fn(all_outputs, axis=0)


averaging_model = EnsembleModel(models=[model0, model1])
```
I then wish to export the ensemble model:
```python
averaging_model.export("export/1/", input_signature=[input_signature])
```
But I get an error on the export:
```
AssertionError: Tried to export a function which references an 'untracked' resource. TensorFlow objects (e.g.
tf.Variable) captured by functions must be 'tracked' by assigning them to an attribute of a tracked object or
assigned to an attribute of the main object directly. See the information below:

Function name = b'__inference_signature_wrapper___call___10899653'
Captured Tensor = <ResourceHandle(name="10671455", device="/job:localhost/replica:0/task:0/device:CPU:0", container="localhost", type="tensorflow::lookup::LookupInterface", dtype and shapes : "[ ]")>
Trackable referencing this tensor = <tensorflow.python.ops.lookup_ops.StaticHashTable object at 0x7fd62d126990>
Internal Tensor = Tensor("10899255:0", shape=(), dtype=resource)
```
If I explicitly assign the models to attributes in the constructor:

```python
self.model0 = models[0]
self.model1 = models[1]
```
it works fine (even though I never reference those attributes anywhere else). But I want an `EnsembleModel` instance to support an arbitrary list of models. How can I ensure the models are "tracked" so that the export succeeds?
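The only generalization of that workaround I can see is assigning each model to its own numbered attribute with `setattr`. Here is a minimal, framework-free sketch of that pattern (a plain stub class stands in for `keras.Model`, and the `model_0`, `model_1`, … attribute names are my own invention, not a Keras convention):

```python
class EnsembleSketch:
    """Stub illustrating the attribute-assignment pattern.

    In the real class this would subclass keras.Model, whose
    attribute-based auto-tracking is what the export machinery relies on.
    """

    def __init__(self, models):
        # Assign each sub-model to its own named attribute, mirroring the
        # manual self.model0 / self.model1 workaround for any list length.
        for i, model in enumerate(models):
            setattr(self, f"model_{i}", model)
        self._num_models = len(models)

    @property
    def models(self):
        # Recover the original list from the numbered attributes.
        return [getattr(self, f"model_{i}") for i in range(self._num_models)]


ensemble = EnsembleSketch(["m0", "m1", "m2"])
print(ensemble.models)  # → ['m0', 'm1', 'm2']
```

Whether a `setattr` loop like this is actually sufficient for Keras/TensorFlow to track the sub-models on export is exactly what I am unsure about; it feels like there should be a cleaner, supported way.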