
Conversation

@fegin (Contributor) commented Aug 27, 2025

This PR refactors activation checkpointing by moving apply_ac() out of the llama3 parallelize.py module. Additionally, it introduces a warning about configuration combinations involving SAC, torch.compile, and flex_attention to inform users of potential issues.

This PR depends on pytorch/pytorch#161541

cc @drisspg @bdhirsh @soulitzer
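
For readers skimming the thread, here is a minimal sketch of the shape of the change: a shared apply_ac() that the per-model parallelize.py files can import, plus the new warning. The module layout, function signatures, the ac_config.mode field, and the _apply_ac_to_transformer_block helper are assumptions for illustration, not the exact code landed in this PR.

```python
# Hypothetical sketch of a shared activation-checkpoint helper; names and the
# warning condition are illustrative assumptions, not the merged code.
import torch
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper as ptd_checkpoint_wrapper,
)

from torchtitan.tools.logging import logger


def _apply_ac_to_transformer_block(block: torch.nn.Module, ac_config) -> torch.nn.Module:
    # Full AC shown for brevity; the selective-op path would instead build a
    # checkpoint context_fn from a save-list of ops.
    return ptd_checkpoint_wrapper(block, preserve_rng_state=False)


def apply_ac(model: torch.nn.Module, ac_config, *, model_compile_enabled: bool = False) -> None:
    """Wrap each transformer block with activation checkpointing."""
    # The kind of warning this PR introduces: selective AC (SAC) together with
    # FlexAttention is only expected to behave well under torch.compile.
    if ac_config.mode == "selective" and not model_compile_enabled:
        logger.warning(
            "Selective activation checkpointing with FlexAttention is only "
            "tested with torch.compile; without compile it may be slow or incorrect."
        )

    for layer_id, block in model.layers.named_children():
        model.layers.register_module(
            layer_id, _apply_ac_to_transformer_block(block, ac_config)
        )
```

A per-model parallelize.py would then call this shared apply_ac() instead of defining its own copy.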

@meta-cla bot added the CLA Signed label on Aug 27, 2025
@tianyu-l (Contributor) left a comment

  1. shall we update all the parallelize.py files to depend on this file?
  2. @xmfan mentioned that there could be silent numerical incorrectness when we compile MoE, which sounds concerning if we always recommend compiling to work with FlexAttention in this PR.

The inline review comment below was left on this hunk from the relocated activation-checkpoint code:

```python
from torchtitan.tools.logging import logger

# for selective op activation checkpointing
_save_list = {
```

I think it's preferable to create a customized list for each individual model where necessary, in addition to some default save_list. E.g. MoE and dense models may need different save_lists, and it seems bad to just mix everything together.

This refactor can happen in a separate PR.
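
One possible shape of that split, sketched under assumed names (_default_save_list, get_save_list, and the op choices here are illustrative, not torchtitan's actual lists):

```python
# Hypothetical split of the SAC save-list into a shared default plus optional
# per-model additions; the op choices are illustrative only.
import torch

_default_save_list = {
    torch.ops.aten.mm.default,
    torch.ops.aten._scaled_dot_product_efficient_attention.default,
    torch.ops.aten._scaled_dot_product_flash_attention.default,
    torch.ops._c10d_functional.reduce_scatter_tensor.default,
}


def get_save_list(extra_ops=None):
    """Return the default save-list, optionally extended by a specific model."""
    save_list = set(_default_save_list)
    if extra_ops:
        save_list |= set(extra_ops)
    return save_list


# A dense model could use the defaults, while an MoE model registers the extra
# ops it wants saved (e.g. the all-to-all used for token dispatch).
dense_save_list = get_save_list()
moe_save_list = get_save_list(
    extra_ops={torch.ops._c10d_functional.all_to_all_single.default}
)
```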

@fegin (Contributor, Author) commented Aug 27, 2025

> shall we update all the parallelize.py files to depend on this file?

I did, but somehow the change was not uploaded. It should be good now.

> @xmfan mentioned that there could be silent numerical incorrectness when we compile MoE, which sounds concerning if we always recommend compiling to work with FlexAttention in this PR.

Yes, I also don't want to force users to always use torch.compile. But even a hack to enable SAC + FlexAttention requires some discussion. For now, I would keep this suggestion since we are mainly focused on performance at the moment. We should make the hack work soon.

cc @drisspg @soulitzer @bdhirsh

@tianyu-l (Contributor) left a comment

sgtm

@fegin (Contributor, Author) commented Aug 27, 2025

uh, there are some conflicts, let me fix it

fegin added 3 commits August 27, 2025 14:42
@fegin force-pushed the chienchin/flex_sac branch from 9cc3dcf to 42a7cae on August 27, 2025 at 21:49
@fegin merged commit 8a6c9fe into main on Aug 29, 2025 (8 checks passed)
@fegin deleted the chienchin/flex_sac branch on August 29, 2025 at 04:43
@xmfan (Member) commented Aug 29, 2025

> @xmfan mentioned that there could be silent numerical incorrectness when we compile MoE

@tianyu-l Right now, I think this only applies when you set torch._dynamo.config.capture_scalar_outputs=True. Without it, we have small graphs, and I don't think we have any in-place ops + autograd functions in the same graph. For context, the issue is pytorch/pytorch#161275.
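
For context on the flag: capture_scalar_outputs controls whether Dynamo graph-breaks on scalar reads like Tensor.item(); turning it on is what produces the "bigger graphs" discussed here. A small self-contained illustration (not torchtitan code; the function f is made up):

```python
# Minimal illustration of torch._dynamo.config.capture_scalar_outputs.
import torch
import torch._dynamo


def f(x):
    n = x.sum().item()  # scalar read: normally forces a graph break here
    return x * n


x = torch.arange(4)

# Default (False): the .item() call splits f into several small graphs.
torch.compile(f)(x)

# With the flag on, the scalar is captured as an unbacked symbol and Dynamo
# can keep one larger graph for the whole function.
torch._dynamo.config.capture_scalar_outputs = True
torch._dynamo.reset()
torch.compile(f)(x)
```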

@bdhirsh commented Aug 29, 2025

@xmfan so just to confirm:

  1. capture_scalar_outputs=True gave us bigger graphs (fewer graph breaks).

  2. Those bigger graphs caused us to capture an in-place op + autograd.Function in the same dynamo region, causing the correctness issue linked above.

  3. Now that @soulitzer has removed the autograd.Function for all2all (see #1580, "Add config to AC to toggle early-stop and revert A2A autograd.Function workaround"), I would imagine it's safe to add back capture_scalar_outputs=True in titan if it gives meaningful perf wins, no?

@xmfan (Member) commented Aug 29, 2025

Yes on 1 and 2. For 3, there are other things broken with capture_scalar_outputs=True; it's still broken on main right now: #1649.
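
For readers unfamiliar with the failure mode referenced above: the risky shape is a custom autograd.Function whose output is then mutated in place, with both captured in one compiled region. A stripped-down, hypothetical stand-in for that pattern (not the actual all-to-all code, and not a reproducer for pytorch/pytorch#161275):

```python
# Hypothetical stand-in for the "in-place op + autograd.Function in the same
# graph" pattern discussed in this thread; illustrative only.
import torch


class AllToAllLike(torch.autograd.Function):
    """Placeholder for the all-to-all autograd.Function mentioned above."""

    @staticmethod
    def forward(ctx, x):
        return x.clone()

    @staticmethod
    def backward(ctx, grad_out):
        return grad_out


def moe_like_block(x):
    y = AllToAllLike.apply(x)
    y.mul_(2)  # in-place op on the autograd.Function's output
    return y.sum()


x = torch.randn(8, requires_grad=True)
# With graph breaks the two pieces land in separate regions; in the MoE case,
# larger graphs made it possible for both to be captured together.
torch.compile(moe_like_block)(x).backward()
```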

wwwjn pushed a commit that referenced this pull request Sep 3, 2025