24 changes: 15 additions & 9 deletions tests/unit_tests/test_activation_checkpoint.py
@@ -11,7 +11,7 @@
from torch.utils.flop_counter import FlopCounterMode

from torchtitan.config.job_config import ActivationCheckpoint as ACConfig
from torchtitan.models.llama3.infra.parallelize import apply_ac
from torchtitan.distributed.activation_checkpoint import apply_ac


class ToyModule(nn.Module):
@@ -67,7 +67,7 @@ def get_bw_flops(model_fn):
selective_ac_option="op",
per_op_sac_force_recompute_mm_shapes_by_fqns=[], # Empty list
)
apply_ac(model_selective_ac, ac_config_no_force)
apply_ac(model_selective_ac, ac_config_no_force, False, False)
flops_selective_ac = get_bw_flops(model_selective_ac)

# 3. Per-op SAC with force recompute "moe.router.gate"
@@ -78,7 +78,7 @@ def get_bw_flops(model_fn):
selective_ac_option="op",
per_op_sac_force_recompute_mm_shapes_by_fqns=["moe.router.gate"],
)
apply_ac(model_with_force_first, ac_config_with_force_first)
apply_ac(model_with_force_first, ac_config_with_force_first, False, False)
flops_with_force_first = get_bw_flops(model_with_force_first)

# 4. Per-op SAC with force recompute "output"
@@ -88,15 +88,15 @@ def get_bw_flops(model_fn):
selective_ac_option="op",
per_op_sac_force_recompute_mm_shapes_by_fqns=["output"],
)
apply_ac(model_with_force_last, ac_config_with_force_last)
apply_ac(model_with_force_last, ac_config_with_force_last, False, False)
flops_with_force_last = get_bw_flops(model_with_force_last)

# 5. Full AC
model_with_full_ac = ToyModule()
ac_config_full_ac = ACConfig(
mode="full",
)
apply_ac(model_with_full_ac, ac_config_full_ac)
apply_ac(model_with_full_ac, ac_config_full_ac, False, False)
flops_full_ac = get_bw_flops(model_with_full_ac)

self.assertEqual(flops_no_ac, 8.0)
@@ -133,7 +133,7 @@ def get_act_mem(model_fn):
selective_ac_option="op",
per_op_sac_force_recompute_mm_shapes_by_fqns=[], # Empty list
)
apply_ac(model_selective_ac, ac_config_no_force)
apply_ac(model_selective_ac, ac_config_no_force, False, False)
mem_selective_ac = get_act_mem(model_selective_ac)

# 3. Per-op SAC with force recompute "moe.router.gate"
@@ -144,7 +144,7 @@ def get_act_mem(model_fn):
selective_ac_option="op",
per_op_sac_force_recompute_mm_shapes_by_fqns=["moe.router.gate"],
)
apply_ac(model_with_force_first, ac_config_with_force_first)
apply_ac(model_with_force_first, ac_config_with_force_first, False, False)
mem_with_force_first = get_act_mem(model_with_force_first)

# 4. Per-op SAC with force recompute "output"
@@ -154,15 +154,15 @@ def get_act_mem(model_fn):
selective_ac_option="op",
per_op_sac_force_recompute_mm_shapes_by_fqns=["output"],
)
apply_ac(model_with_force_last, ac_config_with_force_last)
apply_ac(model_with_force_last, ac_config_with_force_last, False, False)
mem_with_force_last = get_act_mem(model_with_force_last)

# 5. Full AC
model_with_full_ac = ToyModule().cuda()
ac_config_full_ac = ACConfig(
mode="full",
)
apply_ac(model_with_full_ac, ac_config_full_ac)
apply_ac(model_with_full_ac, ac_config_full_ac, False, False)
mem_full_ac = get_act_mem(model_with_full_ac)

self.assertEqual(mem_no_ac, 2.0)
@@ -186,6 +186,8 @@ def test_correctness(self):
selective_ac_option="op",
per_op_sac_force_recompute_mm_shapes_by_fqns=[],
),
False,
False,
)
model_force_first = ToyModule()
model_force_first.load_state_dict(model_no_ac.state_dict())
@@ -196,6 +198,8 @@ def test_correctness(self):
selective_ac_option="op",
per_op_sac_force_recompute_mm_shapes_by_fqns=["moe.router.gate"],
),
False,
False,
)

model_force_last = ToyModule()
@@ -207,6 +211,8 @@ def test_correctness(self):
selective_ac_option="op",
per_op_sac_force_recompute_mm_shapes_by_fqns=["output"],
),
False,
False,
)

def run_fwd_bwd(model, batch):
162 changes: 162 additions & 0 deletions torchtitan/distributed/activation_checkpoint.py
@@ -0,0 +1,162 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# This file provides the utility functions to apply activation checkpointing to the model.
# Technically, this is not a part of distributed, but the distributed module is the best place to put it.

from collections import defaultdict

import torch
import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
checkpoint_wrapper as ptd_checkpoint_wrapper,
)

from torchtitan.config.job_config import ActivationCheckpoint as ACConfig
from torchtitan.tools.logging import logger

Reviewer comment (Contributor): I think it's preferable if we can create customized lists for each individual model if necessary, in addition to some default save_list. E.g., MoE and dense models may need different save_lists, and it sounds bad if we just mix everything. This refactor can happen in a separate PR.

# for selective op activation checkpointing
_save_list = {
torch.ops.aten.mm.default,
torch.ops.aten._scaled_dot_product_efficient_attention.default,
torch.ops.aten._scaled_dot_product_flash_attention.default,
torch.ops._c10d_functional.reduce_scatter_tensor.default,
# for low precision training, it's useful to always save
# the result of max, since the absolute maximum is
# used to compute the scaling factor for quantization.
torch.ops.aten.max.default,
torch._higher_order_ops.flex_attention,
}


def _apply_ac_to_transformer_block(
module: nn.Module, ac_config: ACConfig, *, base_fqn: str | None = None
):
valid_ac_modes = ("full", "selective")
if ac_config.mode not in valid_ac_modes:
raise ValueError(
f"Invalid AC mode: {ac_config.mode}. Valid modes: {valid_ac_modes}"
)

if ac_config.mode == "full":
return ptd_checkpoint_wrapper(module, preserve_rng_state=False)

assert ac_config.mode == "selective", f"{ac_config.mode}"
use_op_sac = ac_config.selective_ac_option == "op"
use_layer_sac = ac_config.selective_ac_option.isdigit()
if not use_op_sac and not use_layer_sac:
raise ValueError(
f"Invalid selective AC option: {ac_config.selective_ac_option}. "
f"Valid options: 'op' or a positive int representing layer frequency"
)
if use_op_sac:
from torch.utils.checkpoint import (
CheckpointPolicy,
create_selective_checkpoint_contexts,
)

mm_recompute_shapes = set()
if len(ac_config.per_op_sac_force_recompute_mm_shapes_by_fqns) > 0:
for module_fqn, submod in module.named_modules():
fqn = module_fqn
if base_fqn is not None:
fqn = f"{base_fqn}.{module_fqn}"
if not any(
filter_fqn in fqn
for filter_fqn in ac_config.per_op_sac_force_recompute_mm_shapes_by_fqns
):
continue
if not isinstance(submod, nn.Linear):
raise ValueError(
"per_op_sac_force_recompute_mm_shapes_by_fqns expected to match "
f"a nn.Linear, but got: {submod}"
)
out_f, in_f = submod.weight.shape
mm_recompute_shapes.add((in_f, out_f))
logger.debug(
f"Selective op AC force recomputing mms with rhs shapes {mm_recompute_shapes}"
)
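# Descriptive note: nn.Linear computes x @ W.T, so when it lowers to aten.mm
# the rhs seen by the policy has shape (in_features, out_features); the weight
# shape is therefore recorded transposed here so it can be matched against
# args[1].shape in the custom policy below.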

def _get_custom_policy(meta):
def _custom_policy(ctx, func, *args, **kwargs):
mode = "recompute" if ctx.is_recompute else "forward"
mm_count_key = f"{mode}_mm_count"
if func == torch.ops.aten.mm.default:
if args[1].shape in mm_recompute_shapes:
return CheckpointPolicy.PREFER_RECOMPUTE
meta[mm_count_key] += 1
# Saves output of all compute ops, except every second mm
to_save = func in _save_list and not (
func == torch.ops.aten.mm.default and meta[mm_count_key] % 2 == 0
)
return (
CheckpointPolicy.MUST_SAVE
if to_save
else CheckpointPolicy.PREFER_RECOMPUTE
)

return _custom_policy
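# Illustrative behavior of the policy above (example numbers assumed, not from
# this PR): ops in _save_list are saved, except that only every other mm is
# saved -- with four mms in a block's forward, mm #1 and #3 are saved while
# mm #2 and #4 are recomputed; mms whose rhs shape is in mm_recompute_shapes
# are always recomputed, regardless of the count.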

def selective_checkpointing_context_fn():
meta = defaultdict(int)
return create_selective_checkpoint_contexts(_get_custom_policy(meta))

return ptd_checkpoint_wrapper(
module,
context_fn=selective_checkpointing_context_fn,
preserve_rng_state=False,
)
elif use_layer_sac:
# Checkpoint every `ac_freq` of the modules passed to this function
ac_freq = int(ac_config.selective_ac_option)
ptd_checkpoint_wrapper.__dict__.setdefault("_count", 0)
ptd_checkpoint_wrapper._count += 1
if not ac_freq or ptd_checkpoint_wrapper._count % ac_freq == 0:
return ptd_checkpoint_wrapper(module, preserve_rng_state=False)
else:
return module
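# Example (illustrative): with selective_ac_option="2", _count increments once
# per block, so the 2nd, 4th, 6th, ... blocks passed in are wrapped with full
# AC and the remaining blocks are returned unwrapped.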


def apply_ac(
model: nn.Module,
ac_config: ACConfig,
model_compile_enabled: bool,
use_flex_attn: bool,
):
"""Apply activation checkpointing to the model.

Note that SAC, FlexAttention, and model compilation have some conflicts.
We explicitly ask the user to pass these configs so that we can warn if there are conflicts.
"""

if use_flex_attn and not model_compile_enabled:
logger.warning(
"!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!\n"
"POTENTIAL PERFORMANCE ISSUE DETECTED:\n"
"Flex attention requires compilation for optimal performance and will be\n"
"compiled automatically regardless of config.compile settings. However,\n"
"Selective Activation Checkpointing (SAC) requires compilation to be applied\n"
"at the outermost level (e.g., compile(SAC(model))). Othewise the compilation\n"
"will be ignored."
"\n"
"Without enabling config.compile, the apply order will be:\n"
"SAC(compile(flex_attention)). The compilation of flex_attention will be\n"
"skipped, which results in poor performance.\n"
"\n"
"For best results, enable config.compile to ensure proper compilation order:\n"
"compile(SAC(compile(flex_attention)))\n"
"\n"
"The innermost torch.compile will be ignored, but the outermost will take\n"
"effect and provide optimal performance.\n"
"!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!\n"
)
for layer_id, transformer_block in model.layers.named_children():
transformer_block = _apply_ac_to_transformer_block(
transformer_block, ac_config, base_fqn=f"layers.{layer_id}"
)
model.layers.register_module(layer_id, transformer_block)

logger.info(f"Applied {ac_config.mode} activation checkpointing to the model")
22 changes: 12 additions & 10 deletions torchtitan/experiments/llama4/infra/parallelize.py
@@ -20,7 +20,7 @@
)
from torchtitan.config import JobConfig, TORCH_DTYPE_MAP
from torchtitan.distributed import ParallelDims

from torchtitan.distributed.activation_checkpoint import apply_ac
from torchtitan.distributed.expert_parallel import (
ExpertParallel,
ExpertTensorParallel,
@@ -29,8 +29,7 @@
TensorParallel,
)
from torchtitan.distributed.tensor_parallel import maybe_enable_async_tp

from torchtitan.models.llama3.infra.parallelize import apply_ac, apply_ddp
from torchtitan.models.llama3.infra.parallelize import apply_ddp
from torchtitan.tools.logging import logger


@@ -57,10 +56,8 @@ def parallelize_llama(
({parallel_dims.tp}) and 2 * CP degree ({parallel_dims.cp}).
"""

if (
job_config.parallelism.context_parallel_degree > 1
and model.model_args.use_flex_attn
):
use_flex_attn = getattr(model.model_args, "use_flex_attn", False)
if job_config.parallelism.context_parallel_degree > 1 and use_flex_attn:
raise NotImplementedError("CP support for FlexAttention is still in progress.")

if parallel_dims.tp_enabled:
@@ -98,12 +95,17 @@
etp_enabled=parallel_dims.etp_enabled,
)

if job_config.activation_checkpoint.mode != "none":
apply_ac(model, job_config.activation_checkpoint)

model_compile_enabled = (
job_config.compile.enable and "model" in job_config.compile.components
)
if job_config.activation_checkpoint.mode != "none":
apply_ac(
model,
job_config.activation_checkpoint,
model_compile_enabled,
use_flex_attn,
)

# turn on per-TransformerBlock compile after AC wrapping and before FSDP
if model_compile_enabled:
# NOTE: needed for torch.compile to work with dynamic shapes in token-choice MoE
15 changes: 9 additions & 6 deletions torchtitan/experiments/qwen3/infra/parallelize.py
@@ -22,9 +22,9 @@

from torchtitan.config import JobConfig, TORCH_DTYPE_MAP
from torchtitan.distributed import ParallelDims
from torchtitan.distributed.activation_checkpoint import apply_ac
from torchtitan.distributed.expert_parallel import NoParallel
from torchtitan.models.llama3.infra.parallelize import (
apply_ac,
apply_compile,
apply_ddp,
apply_fsdp,
@@ -46,10 +46,8 @@ def parallelize_qwen3(
({parallel_dims.tp}) and 2 * CP degree ({parallel_dims.cp}).
"""

if (
job_config.parallelism.context_parallel_degree > 1
and model.model_args.use_flex_attn
):
use_flex_attn = getattr(model.model_args, "use_flex_attn", False)
if job_config.parallelism.context_parallel_degree > 1 and use_flex_attn:
raise NotImplementedError("CP support for FlexAttention is still in progress.")

model_compile_enabled = (
@@ -82,7 +80,12 @@
)

if job_config.activation_checkpoint.mode != "none":
apply_ac(model, job_config.activation_checkpoint)
apply_ac(
model,
job_config.activation_checkpoint,
model_compile_enabled,
use_flex_attn,
)

# turn on per-TransformerBlock compile after AC wrapping and before FSDP
if model_compile_enabled:
14 changes: 12 additions & 2 deletions torchtitan/experiments/simple_fsdp/parallelize.py
@@ -9,8 +9,9 @@

from torchtitan.config import JobConfig, TORCH_DTYPE_MAP
from torchtitan.distributed import ParallelDims
from torchtitan.distributed.activation_checkpoint import apply_ac
from torchtitan.distributed.tensor_parallel import maybe_enable_async_tp
from torchtitan.models.llama3.infra.parallelize import apply_ac, apply_tp
from torchtitan.models.llama3.infra.parallelize import apply_tp
from torchtitan.tools.logging import logger

from .simple_fsdp import data_parallel, MixedPrecisionPolicy
@@ -60,7 +61,16 @@ def parallelize_llama(
maybe_enable_async_tp(job_config, tp_mesh)

if job_config.activation_checkpoint.mode != "none":
apply_ac(model, job_config.activation_checkpoint)
use_flex_attn = getattr(model.model_args, "use_flex_attn", False)
model_compile_enabled = (
job_config.compile.enable and "model" in job_config.compile.components
)
apply_ac(
model,
job_config.activation_checkpoint,
model_compile_enabled,
use_flex_attn,
)

# apply data parallel
if (