Keras fails to compute gradients for an autoencoder-esque model on the TensorFlow backend when mixed precision and JIT compilation are enabled. See code here: colab.
This is caused by the `UpSampling2D` layer. When gradients are computed, its dtype is resolved as `float32` instead of `float16`, which causes the `Relu` that comes next to throw a dtype mismatch exception.
The only working workaround I found is explicitly setting the dtype of the `UpSampling2D` layer to `float32`. This inserts a `Cast` node between `relu` and `upsample`, which handles the dtype conversion.
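A minimal sketch of the workaround (the model architecture and layer sizes here are placeholders, not the actual model from the colab): pinning `UpSampling2D` to `float32` makes Keras insert a cast back to the policy's compute dtype before the next layer.

```python
import tensorflow as tf
from tensorflow.keras import layers, mixed_precision

# Enable mixed precision globally, as in the failing setup.
mixed_precision.set_global_policy("mixed_float16")

inputs = layers.Input(shape=(8, 8, 4))
x = layers.Conv2D(4, 3, padding="same", activation="relu")(inputs)
# Workaround: force float32 on the upsampling layer so a Cast node
# is inserted between it and the surrounding float16 layers.
x = layers.UpSampling2D(size=2, dtype="float32")(x)
x = layers.Conv2D(4, 3, padding="same", activation="relu")(x)
model = tf.keras.Model(inputs, x)

model.compile(optimizer="adam", loss="mse", jit_compile=True)
```

With this change, gradient computation under `jit_compile=True` no longer hits the dtype mismatch, at the cost of running the upsampling in `float32`.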
Not sure which project this issue should be submitted to: Keras, TF, or XLA.