tests/unit_tests/test_activation_checkpoint.py: 6 additions & 4 deletions
@@ -44,13 +44,11 @@ class TestApplyAC(unittest.TestCase):
     def test_flops(self):
         def get_bw_flops(model_fn):
             x = torch.randn(512, 512, requires_grad=True)
-            with torch.utils.checkpoint.set_checkpoint_early_stop(False):
-                out = model_fn(x)
+            out = model_fn(x)
             out.backward()

             x = torch.randn(512, 512, requires_grad=True)
-            with torch.utils.checkpoint.set_checkpoint_early_stop(False):
-                out = model_fn(x)
+            out = model_fn(x)
             with FlopCounterMode(display=False) as mode:
                 out.backward()
             return mode.get_total_flops() / (512**3 * 2)
@@ -66,6 +64,7 @@ def get_bw_flops(model_fn):
             mode="selective",
             selective_ac_option="op",
             per_op_sac_force_recompute_mm_shapes_by_fqns=[],  # Empty list
+            early_stop=False,
         )
         apply_ac(model_selective_ac, ac_config_no_force, False, False)
         flops_selective_ac = get_bw_flops(model_selective_ac)
@@ -77,6 +76,7 @@ def get_bw_flops(model_fn):
             mode="selective",
             selective_ac_option="op",
             per_op_sac_force_recompute_mm_shapes_by_fqns=["moe.router.gate"],
+            early_stop=False,
         )
         apply_ac(model_with_force_first, ac_config_with_force_first, False, False)
         flops_with_force_first = get_bw_flops(model_with_force_first)
@@ -87,6 +87,7 @@ def get_bw_flops(model_fn):
             mode="selective",
             selective_ac_option="op",
             per_op_sac_force_recompute_mm_shapes_by_fqns=["output"],
+            early_stop=False,
         )
         apply_ac(model_with_force_last, ac_config_with_force_last, False, False)
         flops_with_force_last = get_bw_flops(model_with_force_last)
@@ -95,6 +96,7 @@ def get_bw_flops(model_fn):
         model_with_full_ac = ToyModule()
         ac_config_full_ac = ACConfig(
             mode="full",
+            early_stop=False,
         )
        apply_ac(model_with_full_ac, ac_config_full_ac, False, False)
        flops_full_ac = get_bw_flops(model_with_full_ac)
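
Note on the test diff above: disabling early stop used to require wrapping the forward pass in torch.utils.checkpoint.set_checkpoint_early_stop(False); after this change the flag rides along in ACConfig(..., early_stop=False) and apply_ac forwards it to the checkpoint wrapper. The /(512**3 * 2) normalization expresses backward FLOPs in units of one 512x512 matmul. A minimal sketch of that normalization, with a plain nn.Linear standing in for the test's ToyModule (the stand-in module is hypothetical, not the test fixture):

    import torch
    from torch.utils.flop_counter import FlopCounterMode

    # Hypothetical stand-in for ToyModule: one 512x512 matmul per forward.
    model = torch.nn.Linear(512, 512, bias=False)

    x = torch.randn(512, 512, requires_grad=True)
    out = model(x).sum()  # reduce to a scalar so .backward() needs no grad argument
    with FlopCounterMode(display=False) as mode:
        out.backward()

    # One 512x512 @ 512x512 matmul costs 2 * 512**3 FLOPs, so the ratio reads as
    # "how many such matmuls the backward pass ran" (2.0 here: one for grad_input,
    # one for grad_weight; recomputation under activation checkpointing adds more).
    print(mode.get_total_flops() / (512**3 * 2))
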
torchtitan/config/job_config.py: 6 additions & 0 deletions
@@ -546,6 +546,12 @@ class ActivationCheckpoint:
     ANY mm with shape matching (*, in) x (in, out) will be force recomputed.
     """

+    early_stop: bool = False
+    """
+    Whether to stop recomputing early when all activations have already been
+    rematerialized.
+    """
+

 @dataclass
 class Compile:
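
The new knob is just another field on the ActivationCheckpoint dataclass, so it can be set wherever the other AC options are set. A minimal sketch (the import path is taken from the file above; the TOML mapping mentioned in the comment is an assumption based on torchtitan's existing [activation_checkpoint] config section):

    from torchtitan.config.job_config import ActivationCheckpoint

    # Full activation checkpointing with early stop explicitly disabled, as in the
    # unit tests above. Per the dataclass definition, early_stop defaults to False.
    ac_config = ActivationCheckpoint(mode="full", early_stop=False)

    # In a TOML job config this would presumably live under [activation_checkpoint],
    # e.g.  mode = "full"  and  early_stop = false
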
torchtitan/distributed/activation_checkpoint.py: 7 additions & 2 deletions
@@ -42,7 +42,9 @@ def _apply_ac_to_transformer_block(
     )

     if ac_config.mode == "full":
-        return ptd_checkpoint_wrapper(module, preserve_rng_state=False)
+        return ptd_checkpoint_wrapper(
+            module, preserve_rng_state=False, early_stop=ac_config.early_stop
+        )

     assert ac_config.mode == "selective", f"{ac_config.mode}"
     use_op_sac = ac_config.selective_ac_option == "op"
@@ -108,14 +110,17 @@ def selective_checkpointing_context_fn():
             module,
             context_fn=selective_checkpointing_context_fn,
             preserve_rng_state=False,
+            early_stop=ac_config.early_stop,
         )
     elif use_layer_sac:
         # Checkpoint every `ac_freq` of the modules passed to this function
         ac_freq = int(ac_config.selective_ac_option)
         ptd_checkpoint_wrapper.__dict__.setdefault("_count", 0)
         ptd_checkpoint_wrapper._count += 1
         if not ac_freq or ptd_checkpoint_wrapper._count % ac_freq == 0:
-            return ptd_checkpoint_wrapper(module, preserve_rng_state=False)
+            return ptd_checkpoint_wrapper(
+                module, preserve_rng_state=False, early_stop=ac_config.early_stop
+            )
     else:
         return module

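
For context: ptd_checkpoint_wrapper forwards extra keyword arguments to torch.utils.checkpoint.checkpoint, so the new early_stop value ends up controlling the non-reentrant checkpoint's early-stop optimization (stop replaying the forward once every needed activation has been rematerialized). A standalone sketch of the call this file now makes, assuming a PyTorch build whose checkpoint accepts the forwarded early_stop keyword (which this change relies on); the toy block is hypothetical:

    import torch
    import torch.nn as nn
    from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
        checkpoint_wrapper as ptd_checkpoint_wrapper,
    )

    # Hypothetical stand-in for a transformer block.
    block = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 64))

    # Mirrors the "full" AC branch above; the extra kwargs are forwarded to
    # torch.utils.checkpoint.checkpoint by the wrapper.
    wrapped = ptd_checkpoint_wrapper(block, preserve_rng_state=False, early_stop=False)

    x = torch.randn(8, 64, requires_grad=True)
    wrapped(x).sum().backward()  # the forward is recomputed during backward
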
torchtitan/distributed/expert_parallel.py: 1 addition & 30 deletions
@@ -11,6 +11,7 @@
 import torch
 import torch.distributed as dist
 import torch.nn as nn
+from torch.distributed._functional_collectives import all_to_all_single_autograd
 from torch.distributed.tensor import (
     DeviceMesh,
     distribute_module,
@@ -23,36 +24,6 @@
 from torch.distributed.tensor.placement_types import Placement


-# from torch.distributed._functional_collectives import all_to_all_single_autograd
-# TODO: there is memory leak issue with AC + all_to_all_single_autograd
-# This is a temporary fix by @rakkit https://github.com/pytorch/torchtitan/issues/1467
-class _A2A(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, x, out_splits, in_splits, group):
-        T_out = int(sum(out_splits))
-        y = x.new_empty((T_out,) + tuple(x.shape[1:]))  # allocate by output splits
-        dist.all_to_all_single(y, x.contiguous(), out_splits, in_splits, group=group)
-
-        ctx.in_splits = in_splits
-        ctx.out_splits = out_splits
-        ctx.group = group
-        return y
-
-    @staticmethod
-    def backward(ctx, grad_y):
-        # grad wrt input has length sum(in_splits)
-        T_in = int(sum(ctx.in_splits))
-        grad_x = grad_y.new_empty((T_in,) + tuple(grad_y.shape[1:]))
-        dist.all_to_all_single(
-            grad_x, grad_y.contiguous(), ctx.in_splits, ctx.out_splits, group=ctx.group
-        )
-        return grad_x, None, None, None
-
-
-def all_to_all_single_autograd(x, out_splits, in_splits, group):
-    return _A2A.apply(x, out_splits, in_splits, group)
-
-
 TOKEN_GROUP_ALIGN_SIZE_M = 8
 ValidTokenGroupAlignmentSize = Literal[8, 16, 32]

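
The removed _A2A workaround and the restored import are interchangeable at the call sites: both expose all_to_all_single_autograd(x, out_splits, in_splits, group), an all_to_all_single whose backward runs the reverse all-to-all on the incoming gradient (the wrapper existed only because of the AC memory-leak issue referenced in the deleted comment, https://github.com/pytorch/torchtitan/issues/1467). A minimal usage sketch, assuming two ranks launched with torchrun and a backend that supports all_to_all; the split sizes are made up for illustration:

    import torch
    import torch.distributed as dist
    from torch.distributed._functional_collectives import all_to_all_single_autograd

    def demo() -> None:
        # Assumes e.g. `torchrun --nproc_per_node=2 demo.py`; backend choice is an assumption.
        dist.init_process_group(backend="gloo")

        # Each rank holds 4 rows of token features and sends 2 rows to each rank
        # (input splits sum to the local row count, output splits to the received count).
        in_splits = [2, 2]
        out_splits = [2, 2]
        x = torch.randn(4, 8, requires_grad=True)

        # Forward exchanges rows across ranks; backward routes gradients back along
        # the inverse pattern, so autograd flows through the token shuffle.
        y = all_to_all_single_autograd(x, out_splits, in_splits, dist.group.WORLD)
        y.sum().backward()
        assert x.grad is not None and x.grad.shape == x.shape

        dist.destroy_process_group()

    if __name__ == "__main__":
        demo()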