diff --git a/keras_rs/src/layers/embedding/jax/distributed_embedding.py b/keras_rs/src/layers/embedding/jax/distributed_embedding.py index 8780d70..147531f 100644 --- a/keras_rs/src/layers/embedding/jax/distributed_embedding.py +++ b/keras_rs/src/layers/embedding/jax/distributed_embedding.py @@ -216,12 +216,14 @@ def _create_sparsecore_distribution( ) sparsecore_layout = keras.distribution.TensorLayout(axes, device_mesh) # Custom sparsecore layout with tiling. + LayoutClass = ( + jax_layout.Layout + if jax.__version_info__ >= (0, 6, 3) + else jax_layout.DeviceLocalLayout + ) # pylint: disable-next=protected-access sparsecore_layout._backend_layout = jax_layout.Format( - jax_layout.DeviceLocalLayout( - major_to_minor=(0, 1), - _tiling=((8,),), - ), + LayoutClass(major_to_minor=(0, 1), _tiling=((8,),)), # type: ignore jax.sharding.NamedSharding( device_mesh.backend_mesh, jax.sharding.PartitionSpec( diff --git a/keras_rs/src/layers/embedding/jax/distributed_embedding_test.py b/keras_rs/src/layers/embedding/jax/distributed_embedding_test.py index 72ce049..1ec8216 100644 --- a/keras_rs/src/layers/embedding/jax/distributed_embedding_test.py +++ b/keras_rs/src/layers/embedding/jax/distributed_embedding_test.py @@ -41,11 +41,14 @@ def _create_sparsecore_layout( ) sparsecore_layout = keras.distribution.TensorLayout(axes, device_mesh) # Custom sparsecore layout with tiling. - sparsecore_layout._backend_layout = jax_layout.Format( # pylint: disable=protected-access - jax_layout.DeviceLocalLayout( - major_to_minor=(0, 1), - _tiling=((8,),), - ), + LayoutClass = ( + jax_layout.Layout + if jax.__version_info__ >= (0, 6, 3) + else jax_layout.DeviceLocalLayout + ) + # pylint: disable-next=protected-access + sparsecore_layout._backend_layout = jax_layout.Format( + LayoutClass(major_to_minor=(0, 1), _tiling=((8,),)), # type: ignore jax.sharding.NamedSharding( device_mesh.backend_mesh, jax.sharding.PartitionSpec(axes) ),