diff --git a/docs/source/features/stablehlo.md b/docs/source/features/stablehlo.md index ec0d0682d627..aedabf66a75e 100644 --- a/docs/source/features/stablehlo.md +++ b/docs/source/features/stablehlo.md @@ -41,8 +41,8 @@ for more details on how to use the MLIR code generated from it. from torch.export import export import torchvision import torch -import torch_xla2 as tx -import torch_xla2.export +import torchax as tx +import torchax.export import jax import jax.numpy as jnp @@ -111,10 +111,10 @@ import unittest import torch import torch.nn.functional as F from torch.library import Library, impl, impl_abstract -import torch_xla2 -import torch_xla2.export -from torch_xla2.ops import jaten -from torch_xla2.ops import jlibrary +import torchax +import torchax.export +from torchax.ops import jaten +from torchax.ops import jlibrary # Create a `mylib` library which has a basic SDPA op. @@ -163,7 +163,7 @@ class LibraryTest(unittest.TestCase): def setUp(self): torch.manual_seed(0) - torch_xla2.default_env().config.use_torch_native_for_cpu_tensor = False + torchax.default_env().config.use_torch_native_for_cpu_tensor = False def test_basic_sdpa_library(self): @@ -179,7 +179,7 @@ class LibraryTest(unittest.TestCase): args = (arg, arg, arg, ) exported = torch.export.export(model, args=args) - stablehlo = torch_xla2.export.exported_program_to_stablehlo(exported) + stablehlo = torchax.export.exported_program_to_stablehlo(exported) module_str = str(stablehlo.mlir_module()) ## TODO Update this machinery from producing function calls to producing