
Conversation

soulitzer (Contributor) commented Aug 15, 2025

Depends on pytorch/pytorch#160781

This PR:

  • Adds a config to AC to toggle early-stop, defaulting to False
  • Reverts the A2A autograd.Function workaround

More context in #1467 (comment)

The leak is reproducible from @tianyu-l's comment:

CONFIG_FILE=./torchtitan/experiments/llama4/train_configs/debug_model.toml ./run_train.sh --parallelism.expert_parallel_degree=2 --activation_checkpoint.mode=full

Without early-stop=False (i.e., with early-stop enabled, the default):

[rank0]:[titan] 2025-08-18 06:07:18,537 - root - INFO - step:  1  loss:  8.0588  grad_norm:  1.4694  memory:  0.82GiB(1.03%)  tps: 4,900  tflops: 0.37  mfu: 0.12%
[rank0]:[titan] 2025-08-18 06:07:18,538 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:[titan] 2025-08-18 06:07:18,718 - root - INFO - step:  2  loss:  6.6677  grad_norm:  2.1904  memory:  1.35GiB(1.71%)  tps: 91,012  tflops: 6.83  mfu: 2.19%
[rank0]:[titan] 2025-08-18 06:07:18,910 - root - INFO - step:  3  loss:  4.3877  grad_norm:  2.3935  memory:  1.73GiB(2.19%)  tps: 85,603  tflops: 6.43  mfu: 2.06%
[rank0]:[titan] 2025-08-18 06:07:19,076 - root - INFO - step:  4  loss:  4.1728  grad_norm:  3.0683  memory:  2.25GiB(2.84%)  tps: 98,808  tflops: 7.42  mfu: 2.38%
[rank0]:[titan] 2025-08-18 06:07:19,247 - root - INFO - step:  5  loss:  3.5788  grad_norm:  3.1740  memory:  2.64GiB(3.34%)  tps: 96,348  tflops: 7.23  mfu: 2.32%
[rank0]:[titan] 2025-08-18 06:07:19,419 - root - INFO - step:  6  loss:  3.2758  grad_norm:  1.7157  memory:  3.04GiB(3.84%)  tps: 95,197  tflops: 7.15  mfu: 2.29%
[rank0]:[titan] 2025-08-18 06:07:19,581 - root - INFO - step:  7  loss:  3.1672  grad_norm:  1.7939  memory:  3.46GiB(4.37%)  tps: 101,865  tflops: 7.65  mfu: 2.45%
[rank0]:[titan] 2025-08-18 06:07:19,768 - root - INFO - step:  8  loss:  3.0511  grad_norm:  1.2923  memory:  3.88GiB(4.90%)  tps: 87,393  tflops: 6.56  mfu: 2.10%
[rank0]:[titan] 2025-08-18 06:07:19,933 - root - INFO - step:  9  loss:  3.1330  grad_norm:  0.7737  memory:  4.41GiB(5.58%)  tps: 99,667  tflops: 7.48  mfu: 2.40%
[rank0]:[titan] 2025-08-18 06:07:20,126 - root - INFO - step: 10  loss:  2.9731  grad_norm:  0.5428  memory:  4.87GiB(6.15%)  tps: 85,447  tflops: 6.42  mfu: 2.06%

With early-stop=False:

[rank0]:[titan] 2025-08-18 06:09:32,465 - root - INFO - step:  1  loss:  8.0588  grad_norm:  1.4694  memory:  0.61GiB(0.77%)  tps: 4,709  tflops: 0.35  mfu: 0.11%
[rank0]:[titan] 2025-08-18 06:09:32,465 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:[titan] 2025-08-18 06:09:32,643 - root - INFO - step:  2  loss:  6.6677  grad_norm:  2.1904  memory:  0.70GiB(0.89%)  tps: 92,171  tflops: 6.92  mfu: 2.22%
[rank0]:[titan] 2025-08-18 06:09:32,799 - root - INFO - step:  3  loss:  4.3877  grad_norm:  2.3935  memory:  0.70GiB(0.89%)  tps: 105,162  tflops: 7.90  mfu: 2.53%
[rank0]:[titan] 2025-08-18 06:09:32,960 - root - INFO - step:  4  loss:  4.1729  grad_norm:  3.0684  memory:  0.70GiB(0.89%)  tps: 102,732  tflops: 7.71  mfu: 2.47%
[rank0]:[titan] 2025-08-18 06:09:33,117 - root - INFO - step:  5  loss:  3.5789  grad_norm:  3.1744  memory:  0.70GiB(0.89%)  tps: 104,545  tflops: 7.85  mfu: 2.52%
[rank0]:[titan] 2025-08-18 06:09:33,284 - root - INFO - step:  6  loss:  3.2758  grad_norm:  1.7160  memory:  0.70GiB(0.89%)  tps: 98,415  tflops: 7.39  mfu: 2.37%
[rank0]:[titan] 2025-08-18 06:09:33,437 - root - INFO - step:  7  loss:  3.1672  grad_norm:  1.7937  memory:  0.70GiB(0.89%)  tps: 107,744  tflops: 8.09  mfu: 2.59%
[rank0]:[titan] 2025-08-18 06:09:33,594 - root - INFO - step:  8  loss:  3.0512  grad_norm:  1.2927  memory:  0.70GiB(0.89%)  tps: 104,634  tflops: 7.86  mfu: 2.52%
[rank0]:[titan] 2025-08-18 06:09:33,754 - root - INFO - step:  9  loss:  3.1330  grad_norm:  0.7741  memory:  0.70GiB(0.89%)  tps: 102,704  tflops: 7.71  mfu: 2.47%
[rank0]:[titan] 2025-08-18 06:09:33,926 - root - INFO - step: 10  loss:  2.9730  grad_norm:  0.5423  memory:  0.70GiB(0.89%)  tps: 96,116  tflops: 7.22  mfu: 2.31%

Another workaround as suggested by @xmfan is to have a SAC policy to save the A2A. This PR intends to address the full AC case.

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Meta Open Source bot. label Aug 15, 2025
@soulitzer soulitzer force-pushed the disable-early-stop-full-ac branch from 9219f60 to fa55348 Compare August 15, 2025 21:02
@soulitzer soulitzer marked this pull request as draft August 15, 2025 21:15
@soulitzer soulitzer marked this pull request as ready for review August 15, 2025 21:16
xmfan (Member) commented Aug 15, 2025

tps is a lot higher with early-stop=False 🤔. If this is expected, should we default this on for torchtitan?

tianyu-l (Contributor) commented:

@xmfan
Wait, I thought early-stop is an optimization for throughput; how come disabling it makes things faster?

xmfan (Member) commented Aug 15, 2025

Jeffrey probably knows better; my guess is that there's exposed CPU overhead from the tracking for early-stop, and it takes longer than the recompute it avoids.

tianyu-l (Contributor) left a comment:

Looks good. Please also fix training.seed and verify that the two passes (early-stop enabled vs. disabled) give the same loss curves.

soulitzer (Contributor, Author) commented:

Updated the log in the PR description with the seed fixed. There should not be much additional CPU overhead from tracking early-stop. I suspect the leak case is slower because the CUDA caching allocator needs to allocate new memory every iteration rather than being able to settle into a stable state. I also checked with a different config that doesn't leak without early-stop; there, early_stop=True seemed a little faster.

tianyu-l (Contributor) left a comment:

Sounds good to me. I think we can also switch back to the non-workaround PyTorch A2A in this PR:
https://github.com/pytorch/torchtitan/blob/main/torchtitan/distributed/expert_parallel.py#L26
and remove the workaround A2A.

@soulitzer soulitzer force-pushed the disable-early-stop-full-ac branch 3 times, most recently from 4113139 to 8c02e95 Compare August 27, 2025 22:10
@soulitzer soulitzer changed the title Add config to AC to toggle early-stop Add config to AC to toggle early-stop and revert A2A autograd.Function workaround Aug 27, 2025
@pytorch-bot pytorch-bot bot added the ci-no-td label Aug 27, 2025
@soulitzer soulitzer force-pushed the disable-early-stop-full-ac branch from 8c02e95 to 51a77a3 Compare August 27, 2025 22:31
tianyu-l (Contributor) left a comment:

LGTM, thanks for the fix!

@tianyu-l tianyu-l linked an issue Aug 27, 2025 that may be closed by this pull request
@soulitzer soulitzer force-pushed the disable-early-stop-full-ac branch from 51a77a3 to 5fcc0f1 Compare August 29, 2025 12:51
@soulitzer soulitzer merged commit 25413d2 into main Aug 29, 2025
6 of 7 checks passed
@tianyu-l tianyu-l deleted the disable-early-stop-full-ac branch August 29, 2025 19:01
Development

Successfully merging this pull request may close these issues.

possible memory leaking of DP2EP with recompute