In the JAX distribution lib, use AbstractMesh instead of Mesh, since it doesn't result in JIT cache misses when the devices change. It may also simplify the distribution API.
Comment From: vedantag17
Is anyone working on this issue?
Comment From: hertschuh
@vedantag17
> Is anyone working on this issue?
Not yet, no.
Comment From: vedantag17
Oh, I'm a newbie in Keras, but here's my understanding: The current function builds a concrete device array. In contrast, AbstractMesh requires an immutable tuple of (axis_name, axis_size) pairs. Is this correct?
Comment From: hertschuh
@vedantag17, yes, that is correct.
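
To illustrate the distinction confirmed above, here is a minimal pure-Python sketch. It does not call JAX at all. The helper names `abstract_shape` and `compile_for` are hypothetical, and `lru_cache` is only a stand-in for a compilation cache, not how JAX implements its JIT cache. The idea it shows: a description made only of (axis_name, axis_size) pairs is immutable and hashable, so two different device sets with the same layout produce the same cache key.

```python
from functools import lru_cache


def abstract_shape(axis_names, grid_shape):
    """Hypothetical helper: build an immutable tuple of (axis_name, axis_size) pairs."""
    if len(axis_names) != len(grid_shape):
        raise ValueError("need exactly one axis name per grid dimension")
    return tuple(zip(axis_names, grid_shape))


@lru_cache(maxsize=None)
def compile_for(shape_tuple):
    # Stand-in for an expensive compilation step keyed only on the mesh layout.
    return f"compiled for {dict(shape_tuple)}"


# Two hypothetical device sets that share the same 2x4 layout.
# Only the layout is recorded, so the concrete devices never enter the key.
shape_a = abstract_shape(("batch", "model"), (2, 4))
shape_b = abstract_shape(("batch", "model"), (2, 4))

compile_for(shape_a)  # first call: cache miss
compile_for(shape_b)  # same key: cache hit
print(compile_for.cache_info())
```

A key built from a concrete device array would differ whenever the devices differ, which is the cache-miss behavior the issue description wants to avoid.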
Comment From: vedantag17
@hertschuh, can you review the PR? There are some failing checks.