Activation Checkpoint improvement #1645
Conversation
- shall we update all the `parallelize.py` files to depend on this file?
- @xmfan mentioned that there could be silent numerical incorrectness when we compile MoE, which sounds concerning if we always recommend compiling to work with FlexAttention in this PR.
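For context on the first question, here is a minimal sketch of what a per-model `parallelize.py` could look like once it depends on a shared helper; the import path and the `parallelize_model` signature below are assumptions for illustration, not the actual code in this PR.

```python
import torch.nn as nn

# Assumed location of the shared helper after the refactor; the real module
# path is whatever this PR settles on.
from torchtitan.distributed.activation_checkpoint import apply_ac


def parallelize_model(model: nn.Module, ac_config) -> nn.Module:
    # ... apply TP / FSDP / compile wrapping first ...

    # Reuse the shared activation-checkpoint helper instead of a per-model copy.
    if ac_config.mode != "none":
        apply_ac(model, ac_config)
    return model
```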
from torchtitan.tools.logging import logger

# for selective op activation checkpointing
_save_list = {
I think it's preferable if we can create a customized list for each individual model if necessary, in addition to some default save_list.
E.g. MoE and dense models may need different save_lists, and it sounds bad if we just mix everything.
This refactor can happen in a separate PR.
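One possible shape for that refactor, sketched under the assumption that the helper keeps a module-level default set; the concrete op entries and the `get_save_list` helper here are illustrative, not the PR's actual API:

```python
import torch

# Default ops worth saving under selective op activation checkpointing.
# These entries are illustrative; the real _save_list is truncated in the
# excerpt above.
_default_save_list = {
    torch.ops.aten.mm.default,
    torch.ops.aten._scaled_dot_product_flash_attention.default,
}

# Hypothetical per-model overrides, e.g. an MoE model saving different ops
# than a dense model.
_model_save_lists: dict[str, set] = {}


def get_save_list(model_name: str) -> set:
    # Fall back to the shared default when a model has not registered its own
    # list, so MoE and dense models are not forced to share one policy.
    return _model_save_lists.get(model_name, _default_save_list)
```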
I did, but somehow the change was not uploaded. It should be good now.
Yeah, I also don't want to force users to always use torch.compile. But even a hack to enable SAC + FlexAttention requires some discussion. For now, I would keep this suggestion since we mainly focus on performance at this moment. We should make the hack work soon. cc: @drisspg @soulitzer @bdhirsh
sgtm
uh, there are some conflicts, let me fix it
This PR refactors activation checkpointing by moving `apply_ac()` out of the llama3 `parallelize.py` module. Additionally, it introduces a warning about configuration combinations involving SAC, `torch.compile`, and `flex_attention` to inform users of potential issues. This PR depends on pytorch/pytorch#161541. cc: @drisspg @bdhirsh @soulitzer
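The warning could be implemented along these lines; the condition and the config field names (`ac_config.mode`, `use_flex_attn`, `compile_enabled`) are assumptions for illustration, not the PR's actual code:

```python
from torchtitan.tools.logging import logger


def _warn_on_sac_flex_attention(ac_config, use_flex_attn: bool, compile_enabled: bool) -> None:
    # Hypothetical check: selective AC together with FlexAttention is only
    # expected to behave well when torch.compile is enabled, so warn on the
    # risky combination instead of failing silently at runtime.
    if ac_config.mode == "selective" and use_flex_attn and not compile_enabled:
        logger.warning(
            "Selective activation checkpointing with FlexAttention may not work "
            "correctly without torch.compile; consider enabling compile or "
            "falling back to full activation checkpointing."
        )
```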
Force-pushed 9cc3dcf to 42a7cae
@tianyu-l Right now, I think this applies only when you set …
@xmfan so just to confirm: …
Yes on 1 and 2. For 3, there are other things broken with …