In the JAX distribution lib, use AbstractMesh instead of Mesh, since it doesn't cause JIT cache misses when the devices change. It may also simplify the distribution API.
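To make the cache-miss argument concrete, here is a toy model in plain Python. It assumes that a concrete mesh contributes its physical device ids to the compile-cache key, while an abstract mesh contributes only its axis names and sizes. ConcreteMeshKey, AbstractMeshKey and ToyCompileCache are hypothetical names made up for illustration. This is a simplification of the idea, not how JAX's jit cache is actually implemented.

```python
from dataclasses import dataclass


# Hypothetical stand-ins for illustration only; these are NOT JAX classes.
@dataclass(frozen=True)
class ConcreteMeshKey:
    axis_pairs: tuple  # ((axis_name, axis_size), ...)
    device_ids: tuple  # ids of the physical devices, in mesh order


@dataclass(frozen=True)
class AbstractMeshKey:
    axis_pairs: tuple  # ((axis_name, axis_size), ...) only, no devices


class ToyCompileCache:
    """Counts simulated compilations of a function, keyed on a mesh description."""

    def __init__(self):
        self._cache = {}
        self.compilations = 0

    def get(self, fn_name, mesh_key):
        key = (fn_name, mesh_key)
        if key not in self._cache:
            self.compilations += 1  # cache miss: pretend to compile
            self._cache[key] = f"compiled {fn_name} for {mesh_key}"
        return self._cache[key]


axis_pairs = (("batch", 2), ("model", 4))
devices_a = tuple(range(0, 8))   # first set of 8 devices
devices_b = tuple(range(8, 16))  # same layout, different devices

concrete = ToyCompileCache()
concrete.get("train_step", ConcreteMeshKey(axis_pairs, devices_a))
concrete.get("train_step", ConcreteMeshKey(axis_pairs, devices_b))
print(concrete.compilations)  # 2: changing the devices forces a recompile

abstract = ToyCompileCache()
abstract.get("train_step", AbstractMeshKey(axis_pairs))
abstract.get("train_step", AbstractMeshKey(axis_pairs))
print(abstract.compilations)  # 1: same axis names and sizes hit the cache
```

In this model, the only difference between the two caches is whether device ids are part of the key. That difference is what lets the abstract variant reuse one compiled entry across device sets with the same layout.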

Comment From: vedantag17

Is anyone working on this issue?

Comment From: hertschuh

@vedantag17

> Is anyone working on this issue?

Not yet, no.

Comment From: vedantag17

Oh, I'm new to Keras, but here's my understanding: the current function builds a concrete array of devices, whereas AbstractMesh only needs an immutable tuple of (axis_name, axis_size) pairs. Is that correct?

Comment From: hertschuh

@vedantag17 , yes, that is correct.
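The confirmed distinction can be sketched in plain Python. A concrete mesh is a grid of specific devices, and its abstract counterpart keeps only the (axis_name, axis_size) pairs. The helper mesh_axis_pairs below is hypothetical and is not part of Keras or JAX. Also note that the jax.sharding.AbstractMesh constructor signature has changed across JAX releases: some accept a tuple of pairs like this, and more recent ones appear to take axis sizes and axis names as separate arguments. Check the documentation for the installed version before wiring this in.

```python
def mesh_axis_pairs(shape, axis_names):
    """Hypothetical helper: build ((axis_name, axis_size), ...) from a mesh shape.

    The result is an immutable, hashable tuple that carries no device objects.
    """
    if len(shape) != len(axis_names):
        raise ValueError(
            f"got {len(axis_names)} axis names for a {len(shape)}-D mesh shape"
        )
    return tuple(zip(axis_names, shape))


# Concrete view: a 2x4 grid of specific device ids.
devices = [[0, 1, 2, 3], [4, 5, 6, 7]]
shape = (len(devices), len(devices[0]))
pairs = mesh_axis_pairs(shape, ("batch", "model"))
print(pairs)  # (('batch', 2), ('model', 4))

# A different set of devices with the same layout yields the same pairs.
other_devices = [[8, 9, 10, 11], [12, 13, 14, 15]]
other_shape = (len(other_devices), len(other_devices[0]))
print(mesh_axis_pairs(other_shape, ("batch", "model")) == pairs)  # True
```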

Comment From: vedantag17

@hertschuh, could you review the PR? There are some failing checks.