When using the Keras backend `argmin` function on an input array containing subnormal float values, Keras consistently returns the index of `0.0` as the minimum, even though a smaller subnormal value (`-1.401298464324817e-45`) exists in the array. Other deep learning frameworks such as PyTorch and Chainer correctly return the index of the subnormal value, but Keras (and TensorFlow) return index `0`.
Expected Behavior:
Keras's `argmin` should return the index of the smallest value, i.e., the subnormal float (`-1.401298464324817e-45`) at index 2. Instead, Keras returns the index of the first `0.0` (index 0).
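For reference, plain NumPy run on the same data returns the expected index (a minimal check; the result assumes a typical CPU build of NumPy, which does not flush subnormals to zero):

```python
import numpy as np

data = np.array(
    [0.0, 1.1754943508222875e-38, -1.401298464324817e-45, 0.0, 459367.0],
    dtype=np.float32,
)
print(np.argmin(data))  # expected: 2
```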
Reproduction Code:
```python
import torch
import tensorflow as tf
import numpy as np
from chainer import functions as F
import jax.numpy as jnp
import tensorflow.keras.backend as K

# Input data
input_data = [
    0.0,
    1.1754943508222875e-38,
    -1.401298464324817e-45,
    0.0,
    459367.0,
]

# Test PyTorch
def test_pytorch_argmin(input_data):
    tensor = torch.tensor(input_data, dtype=torch.float32)
    result = torch.argmin(tensor).item()
    print(f"PyTorch argmin result: {result}")
    return result

# Test TensorFlow
def test_tensorflow_argmin(input_data):
    tensor = tf.constant(input_data, dtype=tf.float32)
    result = tf.argmin(tensor).numpy()
    print(f"TensorFlow argmin result: {result}")
    return result

# Test Keras using the backend API
def test_keras_argmin(input_data):
    tensor = K.constant(input_data, dtype=tf.float32)
    result = K.argmin(tensor, axis=-1).numpy()
    print(f"Keras argmin result: {result}")
    return result

# Test Chainer
def test_chainer_argmin(input_data):
    tensor = np.array(input_data, dtype=np.float32)
    result = F.argmin(tensor).data
    print(f"Chainer argmin result: {result}")
    return result

# Test JAX
def test_jax_argmin(input_data):
    tensor = jnp.array(input_data, dtype=jnp.float32)
    result = jnp.argmin(tensor).item()
    print(f"JAX argmin result: {result}")
    return result

if __name__ == "__main__":
    pytorch_result = test_pytorch_argmin(input_data)
    tensorflow_result = test_tensorflow_argmin(input_data)
    keras_result = test_keras_argmin(input_data)
    chainer_result = test_chainer_argmin(input_data)
    jax_result = test_jax_argmin(input_data)

    print("\nSummary of results:")
    print(f"PyTorch argmin: {pytorch_result}")
    print(f"TensorFlow argmin: {tensorflow_result}")
    print(f"Keras argmin: {keras_result}")
    print(f"Chainer argmin: {chainer_result}")
    print(f"JAX argmin: {jax_result}")
```
Output:
```text
Summary of results:
PyTorch argmin: 2
TensorFlow argmin: 0
Keras argmin: 0
Chainer argmin: 2
JAX argmin: 0
```
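For context, `-1.401298464324817e-45` is the smallest-magnitude negative float32 subnormal, so it is representable and strictly less than `0.0`. A quick check, assuming a CPU NumPy build without flush-to-zero:

```python
import numpy as np

# The smallest-magnitude negative float32 subnormal is -2**-149.
v = np.float32(-1.401298464324817e-45)
print(hex(int(v.view(np.uint32))))  # 0x80000001: sign bit set, mantissa = 1
print(v < np.float32(0.0))          # True when subnormals are not flushed
```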
Comment From: sachinprasadhs
I checked the result with different backends using Keras 3: the torch backend returns 2, whereas TensorFlow returns 0. The results are not consistent across backends.
```python
import os
os.environ["KERAS_BACKEND"] = "torch"

import numpy as np
import keras

# Input data
input_data = [
    0.0,
    1.1754943508222875e-38,
    -1.401298464324817e-45,
    0.0,
    459367.0,
]

def test_keras_argmin(input_data):
    result = keras.ops.argmin(input_data, axis=-1).numpy()
    print(f"Keras argmin result: {result}")
    return result

test_keras_argmin(input_data)
```
Comment From: jeffcarp
Thanks for reporting this bug as well as #20350. All backend ops should have the same numerics, which should mirror numpy's behavior.
Comment From: jeffcarp
This can be reproduced just by switching between `np` and `jnp` arrays:
```python
print('jnp', np.argmin(jnp.array([0, -1.1e-45], dtype=jnp.float32)))
print('np ', np.argmin(np.array([0, -1.1e-45], dtype=np.float32)))
```
```text
jnp 0
np 1
```
My hunch is that something in the JAX/TF ecosystem is flushing subnormals to zero too early.
[edit] This is not limited to argmin/argmax either. The issue seems to be with how tensors are handled in TF/JAX:
```python
print('jnp', jnp.array([0, -1.1e-45], dtype=jnp.float32).sum())
print('np', np.array([0, -1.1e-45], dtype=np.float32).sum())
print('tf', tf.math.reduce_sum(tf.constant([0, -1.1e-45], dtype=tf.float32)).numpy())
print('pt', torch.sum(torch.tensor([0, -1.1e-45], dtype=torch.float32)).item())
```
```text
jnp 0.0
np -1e-45
tf 0.0
pt -1.401298464324817e-45
```
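One way to narrow down where the flush happens is to round-trip the constant by itself, before any reduction runs (a diagnostic sketch; what it prints depends on the TF build and hardware FTZ settings):

```python
import numpy as np
import tensorflow as tf

# If the round-tripped value is already 0.0, the subnormal is lost when the
# constant is materialized, not inside the argmin/sum kernels themselves.
x = tf.constant([-1.401298464324817e-45], dtype=tf.float32)
print("tf round-trip:", x.numpy()[0])
print("np reference :", np.float32(-1.401298464324817e-45))
```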
Comment From: harshaljanjani
Will raise a PR in a couple of days. Currently solving a similar problem for `ops.argmax()`.
Edit: Turns out the problem shared similarities but wasn't exactly the same. The issue is definitely due to the value `-1.401298464324817e-45` being flushed to `-0` for efficient operations.
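The flush also explains why the reported index is 0 rather than 2: once the subnormal becomes `-0.0`, it compares equal to `0.0`, so argmin keeps the first occurrence of the minimum. A quick NumPy illustration of the post-flush array:

```python
import numpy as np

# What the input looks like after the subnormal is flushed to -0.0.
flushed = np.array([0.0, 1.1754943508222875e-38, -0.0, 0.0, 459367.0],
                   dtype=np.float32)
# -0.0 == 0.0 under IEEE 754, so the first minimum (index 0) wins.
print(np.argmin(flushed))  # 0
```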
Comment From: harshaljanjani
Status of the issue after the PR:
- Compatible backends: NumPy, JAX, PyTorch
- Yet to be compatible: TensorFlow

Note to contributors: When I apply a custom mask of a similar nature in TensorFlow, the problem still persists; the subnormal value gets flushed to zero anyway.
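For contributors exploring mask-style workarounds: one backend-agnostic alternative is to compare bit patterns instead of float values, which sidesteps flush-to-zero entirely. A sketch in NumPy (not the fix from the PR; `argmin_via_bits` is a name invented here):

```python
import numpy as np

def argmin_via_bits(x):
    # Map each float32 bit pattern to an unsigned key that sorts in the
    # same order as the float values: flip all bits for negatives, flip
    # only the sign bit for non-negatives. Integer comparisons never
    # flush subnormals, so their ordering is preserved.
    b = np.asarray(x, dtype=np.float32).view(np.uint32)
    key = np.where(b & np.uint32(0x80000000), ~b, b | np.uint32(0x80000000))
    return int(np.argmin(key))

print(argmin_via_bits([0.0, 1.1754943508222875e-38,
                       -1.401298464324817e-45, 0.0, 459367.0]))  # prints 2
```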
Comment From: harshaljanjani
@sachinprasadhs I think this issue can be put to rest, as #20821 has also fixed the TensorFlow backend, thanks!