Skip to content

Commit 757fda4

Browse files
committed
fix test
1 parent 224c7df commit 757fda4

File tree

2 files changed

+4
-3
lines changed

2 files changed

+4
-3
lines changed

torchprime/sharding/shard_model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,7 @@ def shard_torchax_model_from_config(
186186
"""
187187
import jax
188188
from jax.sharding import NamedSharding, PartitionSpec
189-
from torchax.interop import torch_view
189+
from torchax.interop import jax_view, torch_view
190190

191191
jax_mark_sharding = torch_view(jax.lax.with_sharding_constraint)
192192

@@ -197,7 +197,7 @@ def shard_param(tensor, spec: tuple[str, ...]):
197197
# and models are usually constructed eagerly in torchax.
198198
return torch_view(
199199
jax.make_array_from_callback(
200-
tensor.shape, sharding, lambda slice_index: tensor[slice_index]
200+
tensor.shape, sharding, lambda slice_index: jax_view(tensor[slice_index])
201201
)
202202
)
203203

torchprime/sharding/tests/test_shard_model.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -239,7 +239,8 @@ def test_shard_model_from_config_torchax():
239239
devices = mesh_utils.create_device_mesh((jax.device_count(),))
240240
mesh = Mesh(devices, ("fsdp",))
241241

242-
model = shard_torchax_model_from_config(model, config, mesh)
242+
with torchax.default_env():
243+
model = shard_torchax_model_from_config(model, config, mesh)
243244

244245
# In order to shard activations, corresponding modules are
245246
# wrapped with ShardedModule.

0 commit comments

Comments
 (0)