diff --git a/tests/unit_tests/test_activation_checkpoint.py b/tests/unit_tests/test_activation_checkpoint.py
index 3803da421..4f8d0bd5e 100644
--- a/tests/unit_tests/test_activation_checkpoint.py
+++ b/tests/unit_tests/test_activation_checkpoint.py
@@ -44,13 +44,11 @@ class TestApplyAC(unittest.TestCase):
     def test_flops(self):
         def get_bw_flops(model_fn):
             x = torch.randn(512, 512, requires_grad=True)
-            with torch.utils.checkpoint.set_checkpoint_early_stop(False):
-                out = model_fn(x)
+            out = model_fn(x)
             out.backward()
 
             x = torch.randn(512, 512, requires_grad=True)
-            with torch.utils.checkpoint.set_checkpoint_early_stop(False):
-                out = model_fn(x)
+            out = model_fn(x)
             with FlopCounterMode(display=False) as mode:
                 out.backward()
             return mode.get_total_flops() / (512**3 * 2)
@@ -66,6 +64,7 @@ def get_bw_flops(model_fn):
             mode="selective",
             selective_ac_option="op",
             per_op_sac_force_recompute_mm_shapes_by_fqns=[],  # Empty list
+            early_stop=False,
         )
         apply_ac(model_selective_ac, ac_config_no_force, False, False)
         flops_selective_ac = get_bw_flops(model_selective_ac)
@@ -77,6 +76,7 @@ def get_bw_flops(model_fn):
             mode="selective",
             selective_ac_option="op",
             per_op_sac_force_recompute_mm_shapes_by_fqns=["moe.router.gate"],
+            early_stop=False,
         )
         apply_ac(model_with_force_first, ac_config_with_force_first, False, False)
         flops_with_force_first = get_bw_flops(model_with_force_first)
@@ -87,6 +87,7 @@ def get_bw_flops(model_fn):
             mode="selective",
             selective_ac_option="op",
             per_op_sac_force_recompute_mm_shapes_by_fqns=["output"],
+            early_stop=False,
         )
         apply_ac(model_with_force_last, ac_config_with_force_last, False, False)
         flops_with_force_last = get_bw_flops(model_with_force_last)
@@ -95,6 +96,7 @@ def get_bw_flops(model_fn):
         model_with_full_ac = ToyModule()
         ac_config_full_ac = ACConfig(
             mode="full",
+            early_stop=False,
         )
         apply_ac(model_with_full_ac, ac_config_full_ac, False, False)
         flops_full_ac = get_bw_flops(model_with_full_ac)
diff --git a/torchtitan/config/job_config.py b/torchtitan/config/job_config.py
index 14ff6e7e8..6f027570f 100644
--- a/torchtitan/config/job_config.py
+++ b/torchtitan/config/job_config.py
@@ -546,6 +546,12 @@ class ActivationCheckpoint:
     ANY mm with shape matching (*, in) x (in, out) will be force recomputed.
     """
 
+    early_stop: bool = False
+    """
+    Whether to stop recomputing early when all activations have already been
+    rematerialized.
+    """
+
 
 @dataclass
 class Compile:
diff --git a/torchtitan/distributed/activation_checkpoint.py b/torchtitan/distributed/activation_checkpoint.py
index b4c28a90b..401c07f83 100644
--- a/torchtitan/distributed/activation_checkpoint.py
+++ b/torchtitan/distributed/activation_checkpoint.py
@@ -42,7 +42,9 @@ def _apply_ac_to_transformer_block(
         )
 
     if ac_config.mode == "full":
-        return ptd_checkpoint_wrapper(module, preserve_rng_state=False)
+        return ptd_checkpoint_wrapper(
+            module, preserve_rng_state=False, early_stop=ac_config.early_stop
+        )
 
     assert ac_config.mode == "selective", f"{ac_config.mode}"
     use_op_sac = ac_config.selective_ac_option == "op"
@@ -108,6 +110,7 @@ def selective_checkpointing_context_fn():
             module,
             context_fn=selective_checkpointing_context_fn,
             preserve_rng_state=False,
+            early_stop=ac_config.early_stop,
         )
     elif use_layer_sac:
         # Checkpoint every `ac_freq` of the modules passed to this function
@@ -115,7 +118,9 @@ def selective_checkpointing_context_fn():
         ptd_checkpoint_wrapper.__dict__.setdefault("_count", 0)
         ptd_checkpoint_wrapper._count += 1
         if not ac_freq or ptd_checkpoint_wrapper._count % ac_freq == 0:
-            return ptd_checkpoint_wrapper(module, preserve_rng_state=False)
+            return ptd_checkpoint_wrapper(
+                module, preserve_rng_state=False, early_stop=ac_config.early_stop
+            )
         else:
             return module
 
diff --git a/torchtitan/distributed/expert_parallel.py b/torchtitan/distributed/expert_parallel.py
index 384d9e33f..72d663ebb 100644
--- a/torchtitan/distributed/expert_parallel.py
+++ b/torchtitan/distributed/expert_parallel.py
@@ -11,6 +11,7 @@
 import torch
 import torch.distributed as dist
 import torch.nn as nn
+from torch.distributed._functional_collectives import all_to_all_single_autograd
 from torch.distributed.tensor import (
     DeviceMesh,
     distribute_module,
@@ -23,36 +24,6 @@
 from torch.distributed.tensor.placement_types import Placement
 
 
-# from torch.distributed._functional_collectives import all_to_all_single_autograd
-# TODO: there is memory leak issue with AC + all_to_all_single_autograd
-# This is a temporary fix by @rakkit https://github.com/pytorch/torchtitan/issues/1467
-class _A2A(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, x, out_splits, in_splits, group):
-        T_out = int(sum(out_splits))
-        y = x.new_empty((T_out,) + tuple(x.shape[1:]))  # allocate by output splits
-        dist.all_to_all_single(y, x.contiguous(), out_splits, in_splits, group=group)
-
-        ctx.in_splits = in_splits
-        ctx.out_splits = out_splits
-        ctx.group = group
-        return y
-
-    @staticmethod
-    def backward(ctx, grad_y):
-        # grad wrt input has length sum(in_splits)
-        T_in = int(sum(ctx.in_splits))
-        grad_x = grad_y.new_empty((T_in,) + tuple(grad_y.shape[1:]))
-        dist.all_to_all_single(
-            grad_x, grad_y.contiguous(), ctx.in_splits, ctx.out_splits, group=ctx.group
-        )
-        return grad_x, None, None, None
-
-
-def all_to_all_single_autograd(x, out_splits, in_splits, group):
-    return _A2A.apply(x, out_splits, in_splits, group)
-
-
 TOKEN_GROUP_ALIGN_SIZE_M = 8
 ValidTokenGroupAlignmentSize = Literal[8, 16, 32]
 