When using the JAX backend (JAX 0.4.28), I'm encountering an "Array has been deleted" error during training when a convolutional layer is placed inside a RematScope. The deleted array is the layer's convolutional kernel (its weights).

This issue seems to be specific to convolutional layers, as using a dense layer within the RematScope works without errors. The training function is already jit-compiled.

JAX version: 0.4.28

Link to failed invocation: https://btx.cloud.google.com/invocations/d2e33ea2-b22c-44ad-8eb8-ca105455c926/targets/keras%2Fgithub%2Fubuntu%2Fgpu%2Fjax%2Fpresubmit/log