When the Keras backend argmin function is applied to an input array containing a subnormal float32 value, Keras consistently returns the index of 0.0 as the minimum, even though a smaller value exists in the array: -1.401298464324817e-45, the negative of the smallest positive float32 subnormal. PyTorch and Chainer correctly return the index of the subnormal value, whereas Keras, TensorFlow, and JAX return the index of 0.0.

Expected Behavior:

Keras's argmin function should return the index of the smallest value, which is the subnormal -1.401298464324817e-45 at index 2. Instead, Keras returns index 0, the position of the first 0.0.
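To make the input concrete, the sketch below classifies the values with NumPy's float32 limits. It assumes NumPy 1.22 or newer, where np.finfo exposes smallest_subnormal. Note that 1.1754943508222875e-38 (index 1) is the smallest *normal* float32 (2**-126), so the only subnormal in the array is the value at index 2.

```python
import numpy as np

fi = np.finfo(np.float32)  # assumes NumPy >= 1.22 for smallest_subnormal
values = np.array(
    [0.0, 1.1754943508222875e-38, -1.401298464324817e-45, 0.0, 459367.0],
    dtype=np.float32,
)

# Index 1 is the smallest positive normal float32 (2**-126), not a subnormal.
is_smallest_normal = values[1] == fi.tiny
# Index 2 is the negative of the smallest positive subnormal float32 (2**-149).
is_neg_smallest_subnormal = values[2] == -fi.smallest_subnormal

# A value is subnormal when it is nonzero and its magnitude is below fi.tiny.
subnormal_mask = (values != 0) & (np.abs(values) < fi.tiny)

print(subnormal_mask.tolist())  # [False, False, True, False, False]
print(np.argmin(values))        # 2: NumPy on CPU keeps the subnormal
```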

Reproduction Code:

import torch
import tensorflow as tf
import numpy as np
from chainer import functions as F
import jax.numpy as jnp
import tensorflow.keras.backend as K

# Input data
input_data = [
    0.0,
    1.1754943508222875e-38,
    -1.401298464324817e-45,
    0.0,
    459367.0
]

# Test PyTorch
def test_pytorch_argmin(input_data):
    tensor = torch.tensor(input_data, dtype=torch.float32)
    result = torch.argmin(tensor).item()
    print(f"PyTorch argmin result: {result}")
    return result

# Test TensorFlow
def test_tensorflow_argmin(input_data):
    tensor = tf.constant(input_data, dtype=tf.float32)
    result = tf.argmin(tensor).numpy()
    print(f"TensorFlow argmin result: {result}")
    return result

# Test Keras using backend
def test_keras_argmin(input_data):
    tensor = K.constant(input_data, dtype=tf.float32)
    result = K.argmin(tensor, axis=-1).numpy()
    print(f"Keras argmin result: {result}")
    return result

# Test Chainer
def test_chainer_argmin(input_data):
    tensor = np.array(input_data, dtype=np.float32)
    result = F.argmin(tensor).data
    print(f"Chainer argmin result: {result}")
    return result

# Test JAX
def test_jax_argmin(input_data):
    tensor = jnp.array(input_data, dtype=jnp.float32)
    result = jnp.argmin(tensor).item()
    print(f"JAX argmin result: {result}")
    return result

if __name__ == "__main__":
    pytorch_result = test_pytorch_argmin(input_data)
    tensorflow_result = test_tensorflow_argmin(input_data)
    keras_result = test_keras_argmin(input_data)
    chainer_result = test_chainer_argmin(input_data)
    jax_result = test_jax_argmin(input_data)

    print("\nSummary of results:")
    print(f"PyTorch argmin: {pytorch_result}")
    print(f"TensorFlow argmin: {tensorflow_result}")
    print(f"Keras argmin: {keras_result}")
    print(f"Chainer argmin: {chainer_result}")
    print(f"JAX argmin: {jax_result}")

Output:

Summary of results:
PyTorch argmin: 2
TensorFlow argmin: 0
Keras argmin: 0
Chainer argmin: 2
JAX argmin: 0
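The index 0 reported by TensorFlow and JAX is what one would expect if the subnormal were flushed to a signed zero before the comparison: -0.0 compares equal to 0.0, and argmin returns the first minimal index. The following NumPy emulation illustrates this; flush_subnormals_to_zero is a hypothetical helper that only approximates flush-to-zero/denormals-are-zero behavior and is not TensorFlow or JAX code.

```python
import numpy as np

def flush_subnormals_to_zero(x):
    """Hypothetical helper emulating flush-to-zero on a NumPy array.

    Each subnormal is replaced by a zero with the same sign; zeros and
    normal values are left unchanged.
    """
    tiny = np.finfo(x.dtype).tiny  # smallest positive normal value
    is_subnormal = (x != 0) & (np.abs(x) < tiny)
    return np.where(is_subnormal, np.copysign(np.zeros_like(x), x), x)

values = np.array(
    [0.0, 1.1754943508222875e-38, -1.401298464324817e-45, 0.0, 459367.0],
    dtype=np.float32,
)
flushed = flush_subnormals_to_zero(values)

print(np.argmin(values))   # 2: the subnormal is the true minimum
print(flushed[2])          # -0.0: the subnormal became a negative zero
print(np.argmin(flushed))  # 0: -0.0 == 0.0, so the first zero wins the tie
```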

Comment From: sachinprasadhs

I checked the result with different backends using Keras 3: the torch backend returns 2, whereas the TensorFlow backend returns 0. The results are not consistent across backends.

import os
os.environ["KERAS_BACKEND"] = "torch"
import numpy as np
import keras

# Input data
input_data = [
    0.0,
    1.1754943508222875e-38,
    -1.401298464324817e-45,
    0.0,
    459367.0
]

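def test_keras_argmin(input_data):
    # convert_to_numpy is backend-agnostic (e.g. a CUDA torch tensor rejects .numpy()).
    result = keras.ops.convert_to_numpy(keras.ops.argmin(input_data, axis=-1))
    print(f"Keras argmin result: {result}")
    return result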

test_keras_argmin(input_data)

Comment From: jeffcarp

Thanks for reporting this bug as well as #20350. All backend ops should have the same numerics, which should mirror numpy's behavior.
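If backend ops should mirror NumPy, a consistency check can use NumPy as the reference. The sketch below is hypothetical: matches_numpy_argmin is not part of Keras, and ftz_argmin is only a stand-in that treats subnormal inputs as zero to produce a failing case.

```python
import numpy as np

def matches_numpy_argmin(backend_argmin, data):
    """Hypothetical check: does backend_argmin agree with NumPy's argmin?

    backend_argmin is any callable that takes a float32 NumPy array and
    returns an index. Returns (expected, actual, agrees).
    """
    x = np.asarray(data, dtype=np.float32)
    expected = int(np.argmin(x))
    actual = int(backend_argmin(x))
    return expected, actual, expected == actual

def ftz_argmin(x):
    """Stand-in for a backend that treats subnormal inputs as zero."""
    tiny = np.finfo(x.dtype).tiny
    return np.argmin(np.where(np.abs(x) < tiny, np.zeros_like(x), x))

input_data = [0.0, 1.1754943508222875e-38, -1.401298464324817e-45, 0.0, 459367.0]

print(matches_numpy_argmin(np.argmin, input_data))   # (2, 2, True)
print(matches_numpy_argmin(ftz_argmin, input_data))  # (2, 0, False)
```

To check a real backend, one could pass something like lambda x: keras.ops.convert_to_numpy(keras.ops.argmin(x)), with KERAS_BACKEND set before keras is imported; that wiring is an assumption and is not exercised here.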

Comment From: jeffcarp

This can be reproduced just by switching between np and jnp arrays:

print('jnp', np.argmin(jnp.array([0, -1.1e-45], dtype=jnp.float32)))
print('np ', np.argmin(np.array([0, -1.1e-45], dtype=np.float32)))
jnp 0
np  1

My hunch is that something in the JAX/TF ecosystem is flushing subnormals to zero too early.

[edit] This is not limited to argmin/argmax either. The issue seems to be with how tensors are handled in TF/JAX.

print('jnp', jnp.array([0, -1.1e-45], dtype=jnp.float32).sum())
print('np', np.array([0, -1.1e-45], dtype=np.float32).sum())
print('tf', tf.math.reduce_sum(tf.constant([0, -1.1e-45], dtype=tf.float32)).numpy())
print('pt', torch.sum(torch.tensor([0, -1.1e-45], dtype=torch.float32)).item())
jnp 0.0
np -1e-45
tf 0.0
pt -1.401298464324817e-45
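Two details in this output can be checked with NumPy alone. First, -1.1e-45 is not representable in float32 and rounds to -2**-149, which is why PyTorch prints -1.401298464324817e-45. Second, the 0.0 printed for jnp and tf is consistent with the subnormal being flushed to -0.0 before the sum, because under IEEE 754 round-to-nearest 0.0 + (-0.0) is +0.0. The flushing step below is an emulation under that assumption, not the actual TF/JAX code path.

```python
import numpy as np

x = np.array([0, -1.1e-45], dtype=np.float32)

# -1.1e-45 rounds to the nearest float32, which is -2**-149.
print(x[1] == -np.float32(2.0**-149))  # True

# Without flushing, the sum keeps the subnormal.
print(x.sum())  # -1e-45

# Emulated flushing: the subnormal becomes -0.0, and 0.0 + (-0.0) == +0.0.
tiny = np.finfo(np.float32).tiny
flushed = np.where(np.abs(x) < tiny, np.copysign(np.float32(0), x), x)
total = flushed.sum()
print(total, np.signbit(total))  # 0.0 False
```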

Comment From: harshaljanjani

Will raise a PR in a couple of days. Currently solving a similar problem for ops.argmax().
Edit: It turns out the two problems are similar but not identical. The issue is definitely caused by the value -1.401298464324817e-45 being flushed to -0.0 for performance.
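At the bit level, -1.401298464324817e-45 is the float32 pattern 0x80000001: sign bit set, exponent field zero, fraction 1. Flushing clears the fraction but keeps the sign bit, giving 0x80000000, which is -0.0. The standard-library sketch below decodes the binary32 fields; the helper names are illustrative.

```python
import struct

def float32_bits(x: float) -> int:
    """Round x to IEEE 754 binary32 and return its bit pattern as an unsigned int."""
    return struct.unpack("<I", struct.pack("<f", x))[0]

def classify_float32(x: float) -> str:
    """Classify x (after rounding to float32) from its exponent and fraction fields."""
    bits = float32_bits(x)
    exponent = (bits >> 23) & 0xFF  # 8 biased exponent bits
    fraction = bits & 0x7FFFFF      # 23 fraction bits
    if exponent == 0:
        return "zero" if fraction == 0 else "subnormal"
    if exponent == 0xFF:
        return "inf" if fraction == 0 else "nan"
    return "normal"

for value in (-1.401298464324817e-45, -0.0, 1.1754943508222875e-38):
    print(f"{value!r} bits=0x{float32_bits(value):08X} {classify_float32(value)}")
```

The loop reports the subnormal as 0x80000001, the flushed result -0.0 as 0x80000000, and 1.1754943508222875e-38 as the normal value 0x00800000.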

Comment From: harshaljanjani

Status of the issue after the PR:
Compatible backends: NumPy, JAX, PyTorch
Yet to be compatible: TensorFlow
Note to contributors: When I apply a custom mask of a similar nature in TensorFlow, the problem still seems to persist. The subnormal value gets flushed to zero anyway.
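The thread does not say how #20821 handles TensorFlow, so the following is not that fix. It sketches one alternative idea in NumPy: compare integer keys derived from the float32 bit patterns instead of the floats themselves, since integer comparisons are not subject to flush-to-zero. Porting it to TensorFlow (for example with tf.bitcast) and assuming the subnormal survives until the bitcast are both untested assumptions. The helper name is hypothetical.

```python
import numpy as np

def argmin_via_sortable_bits(x):
    """Hypothetical argmin over float32 values computed on integer keys.

    Each float32 bit pattern is mapped to a uint32 whose unsigned order
    matches the float order: negative values have all bits inverted,
    non-negative values get the sign bit set. Caveats: NaNs are not
    handled, and -0.0 sorts strictly below +0.0, unlike IEEE comparison.
    """
    bits = np.asarray(x, dtype=np.float32).view(np.uint32)
    negative = (bits >> np.uint32(31)) == 1
    keys = np.where(negative, ~bits, bits | np.uint32(0x80000000))
    return int(np.argmin(keys))

input_data = [0.0, 1.1754943508222875e-38, -1.401298464324817e-45, 0.0, 459367.0]
print(argmin_via_sortable_bits(input_data))  # 2
```

Because of the signed-zero caveat, this key order can disagree with NumPy when an array's minimum is a tie between -0.0 and +0.0.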

Comment From: harshaljanjani

@sachinprasadhs I think this issue can be put to rest, as #20821 has also fixed the TensorFlow backend, thanks!

Comment From: github-actions[bot]

This issue is stale because it has been open for 14 days with no activity. It will be closed if no further activity occurs. Thank you.