diff --git a/torchax/dev-requirements.txt b/torchax/dev-requirements.txt
index ac917249759..c23caf17ad9 100644
--- a/torchax/dev-requirements.txt
+++ b/torchax/dev-requirements.txt
@@ -3,3 +3,4 @@ torch==2.6.0 ; sys_platform == 'darwin'  # macOS
 torch==2.6.0+cpu; sys_platform != 'darwin'  # Non-macOS (CPU-only), like on TPU
 yapf==0.40.2  # N.B.: keep in sync with `infra/ansible/config/pip.yaml`, `.github/workflows/lintercheck.yml`
 flax==0.10.6
+ninja==1.11.1.4
diff --git a/torchax/test/test_autocast.py b/torchax/test/test_autocast.py
new file mode 100644
index 00000000000..140ac272fc9
--- /dev/null
+++ b/torchax/test/test_autocast.py
@@ -0,0 +1,43 @@
+import unittest
+import jax
+import jax.numpy as jnp
+import torchax
+from torchax import interop
+import torch
+
+
+
+class AutocastTest(unittest.TestCase):
+
+  def setUp(self):
+    self.env = torchax.default_env()
+
+
+  def test_auto_cast_ir(self):
+    with self.env:
+      with torch.autocast('jax', dtype=torch.bfloat16):
+        a = jax.ShapeDtypeStruct((2, 2), jnp.float32)
+        b = jax.ShapeDtypeStruct((2, 2), jnp.float32)
+        ir_text = jax.jit(interop.jax_view(torch.matmul)).lower(a, b).as_text()
+      self.assertIn('tensor<2x2xbf16>', ir_text)
+
+  def test_auto_cast_matmul(self):
+    with self.env:
+      a = torch.randn(2, 2, device='jax')
+      b = torch.randn(2, 2, device='jax')
+      with torch.autocast('jax', dtype=torch.bfloat16):
+        c = a @ b
+
+      self.assertEqual(c.dtype, torch.bfloat16)
+
+      with torch.autocast('cpu', dtype=torch.bfloat16):
+        c_cpu = a.cpu() @ b.cpu()
+
+      self.assertTrue(torch.allclose(c.cpu(), c_cpu))
+
+
+
+if __name__ == '__main__':
+  unittest.main()
+
+
diff --git a/torchax/test/test_ops.py b/torchax/test/test_ops.py
index 7f0e1569633..034e2b3bf80 100644
--- a/torchax/test/test_ops.py
+++ b/torchax/test/test_ops.py
@@ -174,7 +174,7 @@ def run_export_and_compare(testcase,
 # Sort related ops should ignore index;
 # For example: sort( [1, 0, 0]) -> [0, 0, 1]
 # the correct index can be [1, 2, 0] or [2, 1, 0]
-should_ignore_indexes = {"topk", "mode", "kthvalue"}
+should_ignore_indexes = {"topk", "mode", "kthvalue", "linalg.solve_ex"}
 
 
 class TestOpInfo(TestCase):
diff --git a/torchax/torchax/__init__.py b/torchax/torchax/__init__.py
index 36a49f80572..6c03efedc40 100644
--- a/torchax/torchax/__init__.py
+++ b/torchax/torchax/__init__.py
@@ -80,16 +80,27 @@ def disable_temporarily():
 
 torch.utils.rename_privateuse1_backend('jax')
 unsupported_dtype = [torch.quint8]
-torch.utils.generate_methods_for_privateuse1_backend(
-    for_tensor=True,
-    for_module=True,
-    for_storage=True,
-    unsupported_dtype=unsupported_dtype)
+# torch.utils.generate_methods_for_privateuse1_backend(
+#     for_tensor=True,
+#     for_module=True,
+#     for_storage=True,
+#     unsupported_dtype=unsupported_dtype)
 
 import jax
 import torchax.device_module
-
-torch._register_device_module('jax', torchax.device_module)
+import torch.utils.cpp_extension
+
+torch._register_device_module('privateuseone', torchax.device_module)
+
+module = torch.utils.cpp_extension.load(
+    name="custom_device_extension",
+    sources=[
+        os.path.join(os.path.dirname(__file__), "cpp/registration.cpp"),
+    ],
+    extra_include_paths=["cpp_extensions"],
+    extra_cflags=["-g"],
+    verbose=True,
+)
 
 
 def enable_accuracy_mode():
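Note: the torch.utils.cpp_extension.load(...) call added above JIT-compiles cpp/registration.cpp the first time torchax is imported, and that build goes through ninja, which is why ninja is added to dev-requirements.txt. A minimal pre-flight check (a sketch, not part of the patch) is:

    import torch.utils.cpp_extension
    torch.utils.cpp_extension.verify_ninja_availability()  # raises RuntimeError if ninja cannot be found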
diff --git a/torchax/torchax/cpp/registration.cpp b/torchax/torchax/cpp/registration.cpp
new file mode 100644
index 00000000000..ded25ab1f35
--- /dev/null
+++ b/torchax/torchax/cpp/registration.cpp
@@ -0,0 +1,27 @@
+// Minimal headers needed for the device helper, the device guard registration,
+// and the pybind module below.
+#include <c10/core/Device.h>
+#include <c10/core/impl/DeviceGuardImplInterface.h>
+#include <c10/core/impl/FakeGuardImpl.h>
+
+#include <torch/extension.h>
+
+// This basic implementation doesn't bother dealing with different device indices
+// (e.g. custom_device:0 vs. custom_device:1).
+// We could do that by letting the user pass in a device index in our exposed device function.
+// Note that if you do that, you'll also need to register a device guard to core.
+// See `c10/core/impl/DeviceGuardImplInterface.h:C10_REGISTER_GUARD_IMPL`.
+c10::Device get_custom_device(int idx) {
+  return c10::Device(c10::DeviceType::PrivateUse1, idx);
+}
+
+C10_REGISTER_GUARD_IMPL(PrivateUse1, c10::impl::FakeGuardImpl<c10::DeviceType::PrivateUse1>);
+
+
+// Here, we're exposing a custom device object that corresponds to our custom backend.
+// We do this using pybind: exposing an "extension_name.custom_device()" function in python,
+// that's implemented in C++.
+// The implementation in this file maps directly to the `PrivateUse1` device type.
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  m.def("custom_device", &get_custom_device, "get custom device object");
+}
\ No newline at end of file
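For reference, a minimal sketch of how the pybind helper can be exercised once the extension builds (assuming torchax keeps the loaded module in the torchax.module attribute, as in the __init__.py hunk above):

    import torchax  # importing torchax triggers the cpp_extension.load(...) call above

    dev = torchax.module.custom_device(0)  # pybind converts the c10::Device to a torch.device
    print(dev)  # expected to display as the PrivateUse1 device, i.e. 'jax:0' after rename_privateuse1_backend('jax')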
diff --git a/torchax/torchax/device_module.py b/torchax/torchax/device_module.py
index 20fceaf06b4..afdbba21cbd 100644
--- a/torchax/torchax/device_module.py
+++ b/torchax/torchax/device_module.py
@@ -24,3 +24,8 @@ def is_available():
 
 def current_device():
   return 0
+
+
+import torch
+def get_amp_supported_dtype():
+  return [torch.float16, torch.bfloat16]
diff --git a/torchax/torchax/ops/autocast_policy.py b/torchax/torchax/ops/autocast_policy.py
new file mode 100644
index 00000000000..2736169bc57
--- /dev/null
+++ b/torchax/torchax/ops/autocast_policy.py
@@ -0,0 +1,210 @@
+import torch
+import torch._C
+from torch.utils import _pytree as pytree
+
+def call_with_next_key(op, args, kwargs):
+  return op(*args, **kwargs)
+
+target_precision = torch.bfloat16
+
+def lower_precision_fp(op):
+  def inner(*args, **kwargs):
+    target_precision = torch.get_autocast_dtype('privateuseone')
+    autocast_keyset = torch._C.DispatchKeySet(torch._C.DispatchKey.AutocastPrivateUse1)
+    with torch._C._ExcludeDispatchKeyGuard(autocast_keyset):
+      is_float_tensor = lambda a: isinstance(a, torch.Tensor) and a.is_floating_point()
+      args, kwargs = pytree.tree_map_only(
+          is_float_tensor,
+          lambda x: x.to(target_precision),
+          (args, kwargs))
+      return op(*args, **kwargs)
+  return inner
+
+
+lib = torch.library.Library('aten', 'FRAGMENT')
+my_lib = torch.library.Library('_', 'IMPL', 'AutocastPrivateUse1')
+my_lib.fallback(torch.library.fallthrough_kernel)
+
+
+for op in [torch.ops.aten.conv1d.default,
+           torch.ops.aten.conv1d.padding,
+           torch.ops.aten.conv2d.default,
+           torch.ops.aten.conv2d.padding,
+           torch.ops.aten.conv3d.default,
+           torch.ops.aten.bmm.default,
+           torch.ops.aten.mm.default,
+           torch.ops.aten.baddbmm.default,
+           torch.ops.aten.addmm.default,
+           torch.ops.aten.addbmm.default,
+           torch.ops.aten.linear.default,
+           torch.ops.aten.matmul.default,
+           torch.ops.aten.conv_tbc.default,
+           torch.ops.aten.conv_transpose1d.default,
+           torch.ops.aten.conv_transpose2d.input,
+           torch.ops.aten.conv_transpose3d.input,
+           torch.ops.aten.prelu.default,
+           torch.ops.aten.relu.default,
+           torch.ops.aten.max_pool2d.default,
+           torch.ops.aten.einsum.default,
+          ]:
+  lib.impl(op.name(), lower_precision_fp(op), "AutocastPrivateUse1", with_keyset=False)
+
+# https://github.com/pytorch/xla/blob/20899c7258680a36cd3bec1c820e8a52c16a4bbf/torch_xla/csrc/autocast_mode.cpp#L29
+# enum class CastPolicy : uint8_t {
+#   lower_precision_fp = 0,  // Cast all inputs to lower_precision_fp before
+#                            // running the op. Currently, lower_precision_fp is
+#                            // fp16 for AutocastCUDA, and is defined by user
+#                            // (default bf16) for AutocastCPU or other device.
+#   fp32,  // Cast all inputs to at::kFloat before running the op.
+#   fp32_set_opt_dtype,  // Treats functions (like softmax) that
+#                        //   1. we'd like to run in fp32 and
+#                        //   2. have a std::optional arg that controls
+#                        //      the output type.
+#                        // fp32_set_opt_dtype wrappers' policy is: if the output
+#                        // type is already set, don't touch it, otherwise, set
+#                        // it to at::kFloat.
+#   fp32_append_dtype,  // Treats functions (like norm) that
+#                       //   1. we'd like to run in fp32 and
+#                       //   2. have some overloads that accept an output type and
+#                       //      other overloads that don't.
+#                       // fp32_append_dtype wrappers wrap the overloads that don't
+#                       // have an output dtype.
+#                       // The wrapper policy is: append at::kFloat to the args,
+#                       // and redispatch to the type-aware overload.
+#   promote,  // Run in the widest dtype among several args.
+# };
+# TORCH_LIBRARY_IMPL(aten, AutocastXLA, m) {
+#   // lower_precision_fp cast policy
+#   KERNEL_XLA(conv1d, lower_precision_fp)
+#   KERNEL_XLA2(conv1d, padding, lower_precision_fp)
+#   KERNEL_XLA(conv2d, lower_precision_fp)
+#   KERNEL_XLA2(conv2d, padding, lower_precision_fp)
+#   KERNEL_XLA(conv3d, lower_precision_fp)
+#   KERNEL_XLA2(conv3d, padding, lower_precision_fp)
+#   KERNEL_XLA(bmm, lower_precision_fp)
+#   KERNEL_XLA(mm, lower_precision_fp)
+#   KERNEL_XLA(baddbmm, lower_precision_fp)
+#   KERNEL_XLA(addmm, lower_precision_fp)
+#   KERNEL_XLA(addbmm, lower_precision_fp)
+#   KERNEL_XLA(linear, lower_precision_fp)
+#   KERNEL_XLA(matmul, lower_precision_fp)
+#   KERNEL_XLA(conv_tbc, lower_precision_fp)
+#   KERNEL_XLA(conv_transpose1d, lower_precision_fp)
+#   KERNEL_XLA2(conv_transpose2d, input, lower_precision_fp)
+#   KERNEL_XLA2(conv_transpose3d, input, lower_precision_fp)
+#   KERNEL_XLA(prelu, lower_precision_fp)
+#   KERNEL_XLA(relu, lower_precision_fp)
+#   KERNEL_XLA(max_pool2d, lower_precision_fp)
+#   KERNEL_XLA(einsum, lower_precision_fp)
+#   // Disable `scaled_dot_product_attention` for now since it causes
+#   // undefined symbol with official torch whl.
+#   // KERNEL_XLA(scaled_dot_product_attention, lower_precision_fp)
+
+#   // fp32 cast policy
+#   // Commented out ops are included in the AutoCastCPU Policy,
+#   // but not lowered. Enable if op is lowered.
+#   KERNEL_XLA(batch_norm, fp32)
+#   KERNEL_XLA(_softmax, fp32)
+#   KERNEL_XLA2(softmax, int, fp32)
+#   KERNEL_XLA2(softmax, Dimname, fp32)
+#   KERNEL_XLA2(log_softmax, int, fp32)
+#   KERNEL_XLA2(log_softmax, Dimname, fp32)
+#   KERNEL_XLA(binary_cross_entropy, fp32)
+#   // KERNEL_XLA(grid_sampler, fp32)
+#   // KERNEL_XLA(polar, fp32)
+#   KERNEL_XLA2(pow, Tensor_Scalar, fp32)
+#   KERNEL_XLA(prod, fp32)
+#   KERNEL_XLA2(prod, dim_int, fp32)
+#   KERNEL_XLA2(prod, dim_Dimname, fp32)
+#   // KERNEL_XLA(quantile, fp32)
+#   // KERNEL_XLA2(quantile, scalar, fp32)
+#   // KERNEL_XLA(nanquantile, fp32)
+#   // KERNEL_XLA2(nanquantile, scalar, fp32)
+#   // KERNEL_XLA(stft, fp32)
+#   // KERNEL_XLA2(stft, center, fp32)
+#   KERNEL_XLA(cdist, fp32)
+#   // KERNEL_XLA(grid_sampler_2d, fp32)
+#   // KERNEL_XLA(grid_sampler_3d, fp32)
+#   KERNEL_XLA(trace, fp32)
+#   // KERNEL_XLA(view_as_complex, fp32)
+#   KERNEL_XLA(cholesky, fp32)
+#   KERNEL_XLA(cholesky_inverse, fp32)
+#   KERNEL_XLA(cholesky_solve, fp32)
+#   KERNEL_XLA(inverse, fp32)
+#   // KERNEL_XLA(lu_solve, fp32)
+#   // KERNEL_XLA(orgqr, fp32)
+#   // KERNEL_XLA(ormqr, fp32)
+#   // KERNEL_XLA(pinverse, fp32)
+#   KERNEL_XLA(reflection_pad1d, fp32)
+#   KERNEL_XLA(reflection_pad2d, fp32)
+#   KERNEL_XLA(replication_pad1d, fp32)
+#   KERNEL_XLA(replication_pad2d, fp32)
+#   KERNEL_XLA(replication_pad3d, fp32)
+#   KERNEL_XLA(mse_loss, fp32)
+#   KERNEL_XLA(cosine_embedding_loss, fp32)
+#   KERNEL_XLA(nll_loss, fp32)
+#   KERNEL_XLA(nll_loss2d, fp32)
+#   KERNEL_XLA(hinge_embedding_loss, fp32)
+#   // KERNEL_XLA(poisson_nll_loss, fp32)
+#   KERNEL_XLA(smooth_l1_loss, fp32)
+#   KERNEL_XLA(cross_entropy_loss, fp32)
+#   KERNEL_XLA(l1_loss, fp32)
+#   // KERNEL_XLA(huber_loss, fp32)
+#   KERNEL_XLA(margin_ranking_loss, fp32)
+#   KERNEL_XLA(soft_margin_loss, fp32)
+#   KERNEL_XLA(triplet_margin_loss, fp32)
+#   KERNEL_XLA(multi_margin_loss, fp32)
+#   KERNEL_XLA2(ctc_loss, IntList, fp32)
+#   KERNEL_XLA2(ctc_loss, Tensor, fp32)
+#   KERNEL_XLA(kl_div, fp32)
+#   KERNEL_XLA(multilabel_margin_loss, fp32)
+#   KERNEL_XLA(binary_cross_entropy_with_logits, fp32)
+#   // KERNEL_XLA(fft_fft, fp32)
+#   // KERNEL_XLA(fft_ifft, fp32)
+#   // KERNEL_XLA(fft_fft2, fp32)
+#   // KERNEL_XLA(fft_ifft2, fp32)
+#   // KERNEL_XLA(fft_fftn, fp32)
+#   // KERNEL_XLA(fft_ifftn, fp32)
+#   // KERNEL_XLA(fft_rfft, fp32)
+#   // KERNEL_XLA(fft_irfft, fp32)
+#   // KERNEL_XLA(fft_rfft2, fp32)
+#   // KERNEL_XLA(fft_irfft2, fp32)
+#   // KERNEL_XLA(fft_rfftn, fp32)
+#   // KERNEL_XLA(fft_irfftn, fp32)
+#   // KERNEL_XLA(fft_hfft, fp32)
+#   // KERNEL_XLA(fft_ihfft, fp32)
+#   // KERNEL_XLA(linalg_cond, fp32)
+#   // KERNEL_XLA2(linalg_cond, p_str, fp32)
+#   // KERNEL_XLA(linalg_matrix_rank, fp32)
+#   // KERNEL_XLA2(linalg_matrix_rank, tol_tensor, fp32)
+#   // KERNEL_XLA2(linalg_matrix_rank, atol_rtol_tensor, fp32)
+#   // KERNEL_XLA2(linalg_matrix_rank, atol_rtol_float, fp32)
+#   // KERNEL_XLA(linalg_solve, fp32)
+#   // KERNEL_XLA(linalg_cholesky, fp32)
+#   // KERNEL_XLA(linalg_svdvals, fp32)
+#   // KERNEL_XLA(linalg_eigvals, fp32)
+#   // KERNEL_XLA(linalg_eigvalsh, fp32)
+#   // KERNEL_XLA(linalg_inv, fp32)
+#   // KERNEL_XLA(linalg_householder_product, fp32)
+#   // KERNEL_XLA(linalg_tensorinv, fp32)
+#   // KERNEL_XLA(linalg_tensorsolve, fp32)
+#   // KERNEL_XLA(fake_quantize_per_tensor_affine, fp32)
+#   // KERNEL_XLA(geqrf, fp32)
+#   // KERNEL_XLA(_lu_with_info, fp32)
+#   KERNEL_XLA(qr, fp32)
+#   KERNEL_XLA(svd, fp32)
+#   KERNEL_XLA(triangular_solve, fp32)
+#   KERNEL_XLA(multilabel_margin_loss_forward, fp32)
+#   // KERNEL_XLA(linalg_qr, fp32)
+#   // KERNEL_XLA(linalg_cholesky_ex, fp32)
+#   KERNEL_XLA(linalg_svd, fp32)
+#   // KERNEL_XLA(linalg_eig, fp32)
+#   // KERNEL_XLA(linalg_eigh, fp32)
+#   // KERNEL_XLA(linalg_lstsq, fp32)
+#   KERNEL_XLA(linalg_inv_ex, fp32)
+
+#   // promote
+#   KERNEL_XLA(stack, promote)
+#   KERNEL_XLA(cat, promote)
+#   KERNEL_XLA(index_copy, promote)
+#   KERNEL_XLA2(index_copy, dimname, promote)
\ No newline at end of file
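The patch only wires up the lower_precision_fp policy; the fp32 rows in the quoted table are not registered. For illustration only, a sketch (following the same dispatch-key pattern as lower_precision_fp in autocast_policy.py above, but not part of the patch) of how an fp32 cast policy could be added:

    import torch
    from torch.utils import _pytree as pytree

    # Inside autocast_policy.py the existing `lib` fragment would be reused; a fresh
    # FRAGMENT is created here only so the sketch is self-contained.
    lib = torch.library.Library('aten', 'FRAGMENT')

    def cast_to_fp32(op):
      def inner(*args, **kwargs):
        autocast_keyset = torch._C.DispatchKeySet(torch._C.DispatchKey.AutocastPrivateUse1)
        # Drop the autocast key so the call below redispatches to the regular kernel.
        with torch._C._ExcludeDispatchKeyGuard(autocast_keyset):
          is_float_tensor = lambda a: isinstance(a, torch.Tensor) and a.is_floating_point()
          args, kwargs = pytree.tree_map_only(
              is_float_tensor, lambda x: x.to(torch.float32), (args, kwargs))
          return op(*args, **kwargs)
      return inner

    # Register for ops the quoted XLA table runs in fp32, e.g. mse_loss:
    op = torch.ops.aten.mse_loss.default
    lib.impl(op.name(), cast_to_fp32(op), "AutocastPrivateUse1", with_keyset=False)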
diff --git a/torchax/torchax/ops/jaten.py b/torchax/torchax/ops/jaten.py
index d73b11ada85..986ebfe2243 100644
--- a/torchax/torchax/ops/jaten.py
+++ b/torchax/torchax/ops/jaten.py
@@ -5440,6 +5440,25 @@ def kthvalue(input, k, dim=None, keepdim=False, *, out=None):
                                      dimension, keepdim)
   return values, indices
 
+# @op(torch.ops.aten.logit)
+# def _aten_logit(self: jax.Array, eps: float | None = None) -> jax.Array:
+#   """
+#   Computes the logit function of the input tensor.
+
+#   logit(p) = log(p / (1 - p))
+
+#   Args:
+#     self: Input tensor.
+#     eps: A small value to clip the input tensor to avoid log(0) or division by zero.
+#          If None, no clipping is performed.
+
+#   Returns:
+#     A tensor with the logit of each element of the input.
+#   """
+#   if eps is not None:
+#     self = jnp.clip(self, eps, 1.0 - eps)
+#   return jnp.log(self / (1.0 - self))
+
 
 @op(torch.ops.aten.take)
 def _aten_take(self, index):
diff --git a/torchax/torchax/tensor.py b/torchax/torchax/tensor.py
index 6d58199aafe..038e2aaea98 100644
--- a/torchax/torchax/tensor.py
+++ b/torchax/torchax/tensor.py
@@ -59,7 +59,7 @@ def __new__(cls, elem, env):
         cls,
         shape,
         dtype=dtype,
-        device="meta",
+        device="privateuseone:0",
         requires_grad=False,
     )
 
@@ -134,9 +134,9 @@ def dtype(self):
   def dim(self):
     return self.ndim
 
-  @property
-  def device(self):
-    return torch.device("jax:0")
+  # @property
+  # def device(self):
+  #   return torch.device("jax:0")
 
   @property
   def jax_device(self):
@@ -347,7 +347,7 @@ def get_as_jax_device(self, device: Any):
       return None  # fallback to torch
 
   def load_ops(self):
-    from torchax.ops import jaten, jtorch, jc10d, jtorchvision_nms
+    from torchax.ops import jaten, jtorch, jc10d, jtorchvision_nms, autocast_policy
 
     for k, v in itertools.chain(ops_registry.all_aten_ops.items(),
                                 ops_registry.all_torch_functions.items()):
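Taken together, the expected end-to-end behavior is a sketch mirroring test_auto_cast_matmul above (not an additional test in the patch; the exact device repr depends on the renamed PrivateUse1 backend):

    import torch
    import torchax

    env = torchax.default_env()
    with env, torch.autocast('jax', dtype=torch.bfloat16):
      a = torch.randn(2, 2, device='jax')
      b = torch.randn(2, 2, device='jax')
      c = a @ b
    print(c.device)  # expected: a 'jax' (PrivateUse1) device, per the tensor.py change above
    print(c.dtype)   # expected: torch.bfloat16, per test_auto_cast_matmul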