
[WIP] Add autocast in torchax #9361

Draft · qihqi wants to merge 3 commits into master
Conversation

@qihqi (Collaborator) commented Jun 16, 2025

How torch.autocast works:

When you do

with torch.autocast(device):
  result = torch.mm(a, b)  # some math

torch will do the following:

  1. Check the device module (in our case, device_module.py, registered here: https://github.com/pytorch/xla/blob/master/torchax/torchax/__init__.py#L92) to find out which dtypes the device supports for autocast.
  2. Switch the dispatch key for the underlying math. For example, the CPU device maps to AutocastCPU and CUDA maps to AutocastCUDA. We use the PrivateUse1 dispatch key, so our ops dispatch through AutocastPrivateUse1. Each autocast key then calls the ops registered to it; for example, the AutocastCPU registrations are here: https://github.com/pytorch/pytorch/blob/05faba40287cf7d8734da96cb2e904f39710bf29/aten/src/ATen/autocast_mode.cpp#L322
    a. The line linked above registers a fallback, meaning every op sent to AutocastCPU automatically falls through to the next dispatch key (presumably CPU); in other words, unless otherwise specified, the autocast key is a no-op. This line is what prevents "op not registered" errors. The lines from 329 onward then define the ops that do have additional autocast behavior; those registrations are generated by a C++ macro plus templates. (A Python sketch of the same fallback idea is shown right after this list.)
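
For reference, here is a minimal Python sketch of that fallback registration as it would look for our AutocastPrivateUse1 key. This is an illustration only, not the exact code in this PR, and it assumes torch.library's fallback/fallthrough support applies to this key:

import torch

# Minimal sketch (hypothetical, not the PR's code): make every op a no-op under
# AutocastPrivateUse1 unless an explicit autocast rule is registered for it,
# mirroring what autocast_mode.cpp does for AutocastCPU in C++.
_fallback_lib = torch.library.Library("_", "IMPL")
_fallback_lib.fallback(torch.library.fallthrough_kernel, dispatch_key="AutocastPrivateUse1")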

What we need to do to support autocast:

  1. Add this function to the device module:
def get_amp_supported_dtype():
  return [torch.float16, torch.bfloat16]

This tells torch that we do support autocast and which dtypes it may cast to. At this point the code will run, but autocast will not actually cast anything yet.

  2. Change the tensor base device to 'privateuseone' instead of meta here: https://github.com/pytorch/xla/blob/master/torchax/torchax/tensor.py#L62, because with the meta device we never get dispatched through AutocastPrivateUse1. At this point, running under autocast fails with errors about ops not being registered for AutocastPrivateUse1.
  3. Register those ops, which we can do in Python using the torch.library API. The exact incantation is in the autocast_policies.py file. In that file we still need to reimplement the autocast logic ourselves, mainly downcasting the inputs before calling certain ops (a simplified sketch follows this list).
  4. Fix up the errors introduced by changing the device from meta to privateuseone in step 2. Many decompositions call into C++, and C++ registrations get checked there. The device guard is one of those (note: filed "Ability to set device guard in Python" pytorch#156052 to make that possible from Python). This can be fixed by adding a C++ file that performs the guard registration; we don't actually need the features the device guard provides, we just need to register one so PyTorch doesn't complain.
  5. There is one more test failure (on autograd from j2t_autograd); likely more C++ registration is needed.
  6. Move the building of the C++ file from runtime to build time by adjusting pyproject.toml / setup.py, etc.
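
The downcasting mentioned in step 3 is the core of the autocast logic: under autocast, certain ops should receive lower-precision inputs. Below is a hypothetical, simplified illustration of that downcast-then-call idea; the names here are made up for illustration and are not the actual contents of autocast_policies.py:

import torch

# Hypothetical sketch of the "lower precision" autocast policy: cast floating-point
# tensor inputs down to the autocast dtype, then call the op. The real per-op
# registrations in autocast_policies.py implement this same idea.
def downcast_then_call(op, autocast_dtype, *args):
  def cast(x):
    # Only floating-point tensors are downcast; everything else passes through.
    if isinstance(x, torch.Tensor) and x.is_floating_point():
      return x.to(autocast_dtype)
    return x
  return op(*(cast(a) for a in args))

# For example, matmul typically runs in the lower precision under autocast:
# out = downcast_then_call(torch.mm, torch.bfloat16, a, b)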

@qihqi qihqi changed the title Add autocast [WIP] Add autocast Jun 16, 2025
@qihqi qihqi marked this pull request as draft June 16, 2025 14:01
@qihqi qihqi changed the title [WIP] Add autocast [WIP] Add autocast in torchax Jun 16, 2025
qihqi added a commit that referenced this pull request Jun 16, 2025
This PR implements the 3 autocast policies that we use
and wires them into the Environment.

Wiring it through torch infrastructure so that torch.autocast also works is WIP in #9361