Working on a mini-gpt example for @monicadsong. It loads a decently large (a few GB) tf.data.Dataset. Colab:

https://colab.research.google.com/gist/mattdangerw/4f871c46f3eb5af49f828e2aea3bef79/mini-gpt-from-scatch.ipynb

This works as written on the tf and jax backends without issue, but on the torch backend we OOM the GPU in the middle of the first epoch. It looks like a leak or something otherwise non-deterministic, since the OOM hits a variable number of steps into training: a few hundred to a few thousand depending on the run.
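
For context, the repro is shaped roughly like the sketch below. The data, model, and sizes here are placeholders for illustration, not the notebook's actual code (that's all in the Colab above):

```python
import os

os.environ["KERAS_BACKEND"] = "torch"  # select the torch backend before importing keras

import numpy as np
import tensorflow as tf
import keras

# Placeholder sizes; the notebook's real vocab/sequence/batch sizes differ.
VOCAB_SIZE = 5000
SEQ_LEN = 128
BATCH_SIZE = 64

# Stand-in for the real tokenized corpus: next-token prediction pairs.
tokens = np.random.randint(0, VOCAB_SIZE, size=(10_000, SEQ_LEN + 1))
ds = (
    tf.data.Dataset.from_tensor_slices((tokens[:, :-1], tokens[:, 1:]))
    .batch(BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE)
)

# Stand-in model; the notebook builds a small transformer ("mini-gpt").
model = keras.Sequential([
    keras.layers.Embedding(VOCAB_SIZE, 256),
    keras.layers.Dense(VOCAB_SIZE, activation="softmax"),  # per-token probabilities
])
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy")

# Runs to completion on tf/jax; on torch it OOMs partway through the first epoch.
model.fit(ds, epochs=1)
```

Traceback from the point of the OOM: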

/usr/local/lib/python3.11/dist-packages/keras/src/trainers/compile_utils.py in __call__(self, y_true, y_pred, sample_weight)
    689     def __call__(self, y_true, y_pred, sample_weight=None):
    690         with ops.name_scope(self.name):
--> 691             return self.call(y_true, y_pred, sample_weight)
    692 
    693     def call(self, y_true, y_pred, sample_weight=None):

/usr/local/lib/python3.11/dist-packages/keras/src/trainers/compile_utils.py in call(self, y_true, y_pred, sample_weight)
    698             _, loss_fn, loss_weight, _ = self._flat_losses[0]
    699             loss_value = ops.cast(
--> 700                 loss_fn(y_true, y_pred, sample_weight), dtype=self.dtype
    701             )
    702             if loss_weight is not None:

/usr/local/lib/python3.11/dist-packages/keras/src/losses/loss.py in __call__(self, y_true, y_pred, sample_weight)
     65             )
     66 
---> 67             losses = self.call(y_true, y_pred)
     68             out_mask = backend.get_keras_mask(losses)
     69 

/usr/local/lib/python3.11/dist-packages/keras/src/losses/losses.py in call(self, y_true, y_pred)
     31         y_true = tree.map_structure_up_to(y_true, lambda x: x[0], y_true_y_pred)
     32         y_pred = tree.map_structure_up_to(y_pred, lambda x: x[1], y_true_y_pred)
---> 33         return self.fn(y_true, y_pred, **self._fn_kwargs)
     34 
     35     def get_config(self):

/usr/local/lib/python3.11/dist-packages/keras/src/losses/losses.py in sparse_categorical_crossentropy(y_true, y_pred, from_logits, ignore_class, axis)
   2244         )
   2245 
-> 2246     res = ops.sparse_categorical_crossentropy(
   2247         y_true,
   2248         y_pred,

/usr/local/lib/python3.11/dist-packages/keras/src/ops/nn.py in sparse_categorical_crossentropy(target, output, from_logits, axis)
   1961             from_logits=from_logits, axis=axis
   1962         ).symbolic_call(target, output)
-> 1963     return backend.nn.sparse_categorical_crossentropy(
   1964         target, output, from_logits=from_logits, axis=axis
   1965     )

/usr/local/lib/python3.11/dist-packages/keras/src/backend/torch/nn.py in sparse_categorical_crossentropy(target, output, from_logits, axis)
    705         output = torch.clip(output, backend.epsilon(), 1.0 - backend.epsilon())
    706         log_prob = torch.log(output)
--> 707     target = one_hot(target, output.shape[axis], axis=axis)
    708     return -torch.sum(target * log_prob, dim=axis)
    709 

/usr/local/lib/python3.11/dist-packages/keras/src/backend/torch/nn.py in one_hot(x, num_classes, axis, dtype, sparse)
    629     # `where` afterwards.
    630     output = tnn.one_hot(maximum(x, 0), num_classes)
--> 631     output = where(expand_dims(x, axis=-1) >= 0, output, zero)
    632     output = convert_to_tensor(output, dtype=dtype)
    633     dims = output.dim()

/usr/local/lib/python3.11/dist-packages/keras/src/backend/torch/numpy.py in where(condition, x1, x2)
   1529         x1 = convert_to_tensor(x1)
   1530         x2 = convert_to_tensor(x2)
-> 1531         return torch.where(condition, x1, x2)
   1532     else:
   1533         return torch.where(condition)

OutOfMemoryError: CUDA out of memory. Tried to allocate 7.81 GiB. GPU 0 has a total capacity of 39.56 GiB of which 7.34 GiB is free. Process 30879 has 32.21 GiB memory in use. Of the allocated memory 25.56 GiB is allocated by PyTorch, and 6.14 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
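
Two notes on the crash site (my read, not verified): the torch backend's `sparse_categorical_crossentropy` densifies the targets via `one_hot`, so the failing `torch.where` is materializing a tensor of roughly batch size × sequence length × vocab size elements, which is plausibly where the single 7.81 GiB allocation comes from. The allocator setting suggested in the error message can be tried as below, though if this really is a leak it would only delay the OOM rather than fix it:

```python
import os

# Suggested by the error message above; must be set before torch initializes its
# CUDA allocator (i.e. before the first torch/keras import in the runtime).
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
```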

Comment From: mattdangerw

Unclear to me whether this is an issue with the loss computation (which is where this stack trace comes from), or just a leak in the data iterator when going from tf.data -> torch (in which case this line crashes not because it is the source of the leak, but because it happens to be the line that needs a large allocation).
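
One way to tell the two apart (a sketch only; the callback name and logging interval are mine, not from the notebook): log allocated/reserved CUDA memory every N batches. If allocated memory climbs steadily batch over batch, the tf.data -> torch iterator path looks like the culprit; if it stays flat until a single huge allocation in the loss, the loss computation is just spiking above what's free.

```python
import torch
import keras


class CudaMemoryLogger(keras.callbacks.Callback):
    """Hypothetical helper: print CUDA memory stats every `every` train batches."""

    def __init__(self, every=100):
        super().__init__()
        self.every = every

    def on_train_batch_end(self, batch, logs=None):
        if batch % self.every == 0:
            alloc = torch.cuda.memory_allocated() / 2**30
            reserved = torch.cuda.memory_reserved() / 2**30
            print(f"batch {batch}: allocated={alloc:.2f} GiB, reserved={reserved:.2f} GiB")


# Usage with the sketch above: model.fit(ds, epochs=1, callbacks=[CudaMemoryLogger()])
```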