Add config to AC to toggle early-stop and revert A2A autograd.Function workaround #1580
Conversation
Force-pushed from 9219f60 to fa55348
tps is a lot higher with
@xmfan
Jeffrey probably knows better. My guess is that there's exposed CPU overhead from the tracking for early-stop, and it costs more than the recompute it avoids.
Looks good. Please also fix training.seed and verify that two passes (early-stop enabled vs. disabled) give the same loss curves.
Updated the log in the PR description with the seed fixed. There shouldn't be much additional CPU overhead from the early-stop tracking. I suspect the leak case is slower because the CUDA caching allocator needs to allocate new memory every iteration rather than settling into a steady state. I also checked a different config that doesn't leak without early-stop, and with early_stop=True it seemed a little faster.
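For reference, the knob being discussed is PyTorch's early-stop toggle for non-reentrant activation checkpointing. A minimal sketch below; `block` is a stand-in for a checkpointed region, not code from this PR:

```python
import torch
from torch.utils.checkpoint import checkpoint, set_checkpoint_early_stop

def block(x):
    # Stand-in for a checkpointed transformer block.
    return torch.nn.functional.gelu(x @ x.T)

x = torch.randn(64, 64, requires_grad=True)

# Default: early-stop enabled. Recompute during backward halts as soon as
# all activations needed for gradient computation have been rematerialized.
out = checkpoint(block, x, use_reentrant=False)
out.sum().backward()

# early_stop=False: replay the full forward during backward. This is the
# behavior the new AC config would toggle via PyTorch's context manager.
x.grad = None
with set_checkpoint_early_stop(False):
    out = checkpoint(block, x, use_reentrant=False)
    out.sum().backward()
```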
Sounds good to me. I think we can also switch back to the non-workaround pytorch a2a in this PR
https://github.com/pytorch/torchtitan/blob/main/torchtitan/distributed/expert_parallel.py#L26
and remove the workaround A2A.
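For context, a sketch of what switching back might look like, assuming the autograd-aware functional collective imported at the linked line (`all_to_all_single_autograd`); the `token_dispatch` wrapper is illustrative, not this repo's API, and running it requires an initialized process group:

```python
import torch
import torch.distributed as dist
# Autograd-aware functional a2a: gradients flow back through the collective,
# so no custom autograd.Function wrapper is needed.
from torch.distributed._functional_collectives import all_to_all_single_autograd

def token_dispatch(x, output_splits, input_splits, group):
    # Illustrative wrapper: shuffle tokens to their expert ranks; autograd
    # inserts the mirrored a2a in the backward pass automatically.
    return all_to_all_single_autograd(x, output_splits, input_splits, group)
```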
Force-pushed from 4113139 to 8c02e95
Force-pushed from 8c02e95 to 51a77a3
LGTM, thanks for the fix!
Force-pushed from 51a77a3 to 5fcc0f1
Depends on pytorch/pytorch#160781
This PR:
- adds a config to AC to toggle early-stop
- reverts the A2A autograd.Function workaround
More context in #1467 (comment)
Leak reproducible from @tianyu-l's comment:
Without early_stop=False (default, early-stop enabled):
With early_stop=False:
Another workaround, suggested by @xmfan, is to use an SAC policy that saves the A2A; this PR intends to address the full-AC case.
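For comparison with the full-AC approach this PR takes, a minimal sketch of that SAC-policy workaround, assuming PyTorch's `create_selective_checkpoint_contexts`; the a2a op name in the save list is an assumption, not taken from this PR:

```python
import functools
import torch
from torch.utils.checkpoint import (
    CheckpointPolicy,
    checkpoint,
    create_selective_checkpoint_contexts,
)

# Assumed op name for the functional all-to-all: always save its output
# instead of recomputing (and re-issuing) the collective in backward.
_save_list = {torch.ops._c10d_functional.all_to_all_single.default}

def _policy(ctx, op, *args, **kwargs):
    if op in _save_list:
        return CheckpointPolicy.MUST_SAVE
    return CheckpointPolicy.PREFER_RECOMPUTE

def block(x):
    # Stand-in for a checkpointed region that would contain the a2a.
    return torch.nn.functional.relu(x @ x.T)

x = torch.randn(32, 32, requires_grad=True)
context_fn = functools.partial(create_selective_checkpoint_contexts, _policy)
out = checkpoint(block, x, use_reentrant=False, context_fn=context_fn)
```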