Skip to content

Commit 8ad822d

Browse files
committed
Change for DeviceLocalLayout in JAX 0.6.3.
`DeviceLocalLayout` is becoming `Layout`.
1 parent 1d2865a commit 8ad822d

File tree

2 files changed

+14
-9
lines changed

2 files changed

+14
-9
lines changed

keras_rs/src/layers/embedding/jax/distributed_embedding.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -216,12 +216,14 @@ def _create_sparsecore_distribution(
216216
)
217217
sparsecore_layout = keras.distribution.TensorLayout(axes, device_mesh)
218218
# Custom sparsecore layout with tiling.
219+
LayoutClass = (
220+
jax_layout.Layout
221+
if jax.__version_info__ >= (0, 6, 3)
222+
else jax_layout.DeviceLocalLayout
223+
)
219224
# pylint: disable-next=protected-access
220225
sparsecore_layout._backend_layout = jax_layout.Format(
221-
jax_layout.DeviceLocalLayout(
222-
major_to_minor=(0, 1),
223-
_tiling=((8,),),
224-
),
226+
LayoutClass(major_to_minor=(0, 1), _tiling=((8,),)),
225227
jax.sharding.NamedSharding(
226228
device_mesh.backend_mesh,
227229
jax.sharding.PartitionSpec(

keras_rs/src/layers/embedding/jax/distributed_embedding_test.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -41,11 +41,14 @@ def _create_sparsecore_layout(
4141
)
4242
sparsecore_layout = keras.distribution.TensorLayout(axes, device_mesh)
4343
# Custom sparsecore layout with tiling.
44-
sparsecore_layout._backend_layout = jax_layout.Format( # pylint: disable=protected-access
45-
jax_layout.DeviceLocalLayout(
46-
major_to_minor=(0, 1),
47-
_tiling=((8,),),
48-
),
44+
LayoutClass = (
45+
jax_layout.Layout
46+
if jax.__version_info__ >= (0, 6, 3)
47+
else jax_layout.DeviceLocalLayout
48+
)
49+
# pylint: disable-next=protected-access
50+
sparsecore_layout._backend_layout = jax_layout.Format(
51+
LayoutClass(major_to_minor=(0, 1), _tiling=((8,),)),
4952
jax.sharding.NamedSharding(
5053
device_mesh.backend_mesh, jax.sharding.PartitionSpec(axes)
5154
),

0 commit comments

Comments
 (0)