From ca53b90349ddfa2f66daaf0d77c6877cbffa2581 Mon Sep 17 00:00:00 2001
From: Chien-Chin Huang
Date: Wed, 27 Aug 2025 11:15:46 -0700
Subject: [PATCH 1/5] Activation Checkpoint improvement

This PR refactors activation checkpointing by moving `apply_ac()` out of the
llama3 `parallelize.py` module and into `torchtitan/distributed/activation_checkpoint.py`.
It also introduces a warning about configuration combinations involving SAC,
`torch.compile`, and `flex_attention` to inform users of potential issues.
---
 .../distributed/activation_checkpoint.py      | 160 ++++++++++++++++++
 1 file changed, 160 insertions(+)
 create mode 100644 torchtitan/distributed/activation_checkpoint.py

diff --git a/torchtitan/distributed/activation_checkpoint.py b/torchtitan/distributed/activation_checkpoint.py
new file mode 100644
index 000000000..cdc387e2a
--- /dev/null
+++ b/torchtitan/distributed/activation_checkpoint.py
@@ -0,0 +1,160 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# This file provides the util functions to apply activation checkpointing to the model.
+# Technically, this is not a part of distributed, but the distributed module is the best place to put it.
+
+from collections import defaultdict
+
+import torch
+import torch.nn as nn
+from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
+    checkpoint_wrapper as ptd_checkpoint_wrapper,
+)
+
+from torchtitan.config.job_config import ActivationCheckpoint as ACConfig
+from torchtitan.tools.logging import logger
+
+# for selective op activation checkpointing
+_save_list = {
+    torch.ops.aten.mm.default,
+    torch.ops.aten._scaled_dot_product_efficient_attention.default,
+    torch.ops.aten._scaled_dot_product_flash_attention.default,
+    torch.ops._c10d_functional.reduce_scatter_tensor.default,
+    # for low precision training, it's useful to always save
+    # the result of max, since the absolute maximum is
+    # used to compute the scaling factor for quantization.
+    torch.ops.aten.max.default,
+    torch._higher_order_ops.flex_attention,
+}
+
+
+def _apply_ac_to_transformer_block(
+    module: nn.Module, ac_config: ACConfig, *, base_fqn: str | None = None
+):
+    valid_ac_modes = ("full", "selective")
+    if ac_config.mode not in valid_ac_modes:
+        raise ValueError(
+            f"Invalid AC mode: {ac_config.mode}. Valid modes: {valid_ac_modes}"
+        )
+
+    if ac_config.mode == "full":
+        return ptd_checkpoint_wrapper(module, preserve_rng_state=False)
+
+    assert ac_config.mode == "selective", f"{ac_config.mode}"
+    use_op_sac = ac_config.selective_ac_option == "op"
+    use_layer_sac = ac_config.selective_ac_option.isdigit()
+    if not use_op_sac and not use_layer_sac:
+        raise ValueError(
+            f"Invalid selective AC option: {ac_config.selective_ac_option}. "
+            f"Valid options: 'op' or a positive int representing layer frequency"
+        )
+    if use_op_sac:
+        from torch.utils.checkpoint import (
+            CheckpointPolicy,
+            create_selective_checkpoint_contexts,
+        )
+
+        mm_recompute_shapes = set()
+        if len(ac_config.per_op_sac_force_recompute_mm_shapes_by_fqns) > 0:
+            for module_fqn, submod in module.named_modules():
+                fqn = module_fqn
+                if base_fqn is not None:
+                    fqn = f"{base_fqn}.{module_fqn}"
+                if not any(
+                    filter_fqn in fqn
+                    for filter_fqn in ac_config.per_op_sac_force_recompute_mm_shapes_by_fqns
+                ):
+                    continue
+                if not isinstance(submod, nn.Linear):
+                    raise ValueError(
+                        "per_op_sac_force_recompute_mm_shapes_by_fqns expected to match "
+                        f"a nn.Linear, but got: {submod}"
+                    )
+                out_f, in_f = submod.weight.shape
+                mm_recompute_shapes.add((in_f, out_f))
+            logger.debug(
+                f"Selective op AC force recomputing mms with rhs shapes {mm_recompute_shapes}"
+            )
+
+        def _get_custom_policy(meta):
+            def _custom_policy(ctx, func, *args, **kwargs):
+                mode = "recompute" if ctx.is_recompute else "forward"
+                mm_count_key = f"{mode}_mm_count"
+                if func == torch.ops.aten.mm.default:
+                    if args[1].shape in mm_recompute_shapes:
+                        return CheckpointPolicy.PREFER_RECOMPUTE
+                    meta[mm_count_key] += 1
+                # Saves output of all compute ops, except every second mm
+                to_save = func in _save_list and not (
+                    func == torch.ops.aten.mm.default and meta[mm_count_key] % 2 == 0
+                )
+                return (
+                    CheckpointPolicy.MUST_SAVE
+                    if to_save
+                    else CheckpointPolicy.PREFER_RECOMPUTE
+                )
+
+            return _custom_policy
+
+        def selective_checkpointing_context_fn():
+            meta = defaultdict(int)
+            return create_selective_checkpoint_contexts(_get_custom_policy(meta))
+
+        return ptd_checkpoint_wrapper(
+            module,
+            context_fn=selective_checkpointing_context_fn,
+            preserve_rng_state=False,
+        )
+    elif use_layer_sac:
+        # Checkpoint every `ac_freq` of the modules passed to this function
+        ac_freq = int(ac_config.selective_ac_option)
+        ptd_checkpoint_wrapper.__dict__.setdefault("_count", 0)
+        ptd_checkpoint_wrapper._count += 1
+        if not ac_freq or ptd_checkpoint_wrapper._count % ac_freq == 0:
+            return ptd_checkpoint_wrapper(module, preserve_rng_state=False)
+        else:
+            return module
+
+
+def apply_ac(
+    model: nn.Module,
+    ac_config: ACConfig,
+    model_compile_enabled: bool,
+    use_flex_attn: bool,
+):
+    """Apply activation checkpointing to the model.
+
+    Note that SAC, Flex Attention and model compilation have some conflicts.
+    We explicitly ask the user to pass these configs to warn if there are conflicts.
+    """
+
+    if use_flex_attn and not model_compile_enabled:
+        logger.warning(
+            "!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!\n"
+            "POTENTIAL PERFORMANCE ISSUE DETECTED:\n"
+            "Flex attention requires compilation for optimal performance and will be\n"
+            "compiled automatically regardless of config.compile settings. However,\n"
+            "Selective Activation Checkpointing (SAC) requires compilation to be applied\n"
+            "at the outermost level (e.g., compile(SAC(model))).\n"
+            "\n"
+            "Without enabling config.compile, the apply order will be:\n"
+            "SAC(compile(flex_attention)), which results in poor performance.\n"
+            "\n"
+            "For best results, enable config.compile to ensure proper compilation order:\n"
+            "compile(SAC(compile(flex_attention)))\n"
+            "\n"
+            "The innermost torch.compile will be ignored, but the outermost will take\n"
+            "effect and provide optimal performance.\n"
+            "!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!\n"
+        )
+    for layer_id, transformer_block in model.layers.named_children():
+        transformer_block = _apply_ac_to_transformer_block(
+            transformer_block, ac_config, base_fqn=f"layers.{layer_id}"
+        )
+        model.layers.register_module(layer_id, transformer_block)
+
+    logger.info(f"Applied {ac_config.mode} activation checkpointing to the model")

From a6e2743eb6b7bd05756b46a6cd38ec658307f482 Mon Sep 17 00:00:00 2001
From: Chien-Chin Huang
Date: Wed, 27 Aug 2025 14:45:28 -0700
Subject: [PATCH 2/5] misc

---
 .../distributed/activation_checkpoint.py      |   6 +-
 .../experiments/llama4/infra/parallelize.py   |  13 +-
 .../experiments/qwen3/infra/parallelize.py    |  10 +-
 .../experiments/simple_fsdp/parallelize.py    |  11 +-
 .../models/deepseek_v3/infra/parallelize.py   |  19 ++-
 torchtitan/models/llama3/infra/parallelize.py | 134 ++----------------
 6 files changed, 53 insertions(+), 140 deletions(-)

diff --git a/torchtitan/distributed/activation_checkpoint.py b/torchtitan/distributed/activation_checkpoint.py
index cdc387e2a..b4c28a90b 100644
--- a/torchtitan/distributed/activation_checkpoint.py
+++ b/torchtitan/distributed/activation_checkpoint.py
@@ -139,10 +139,12 @@ def apply_ac(
             "Flex attention requires compilation for optimal performance and will be\n"
             "compiled automatically regardless of config.compile settings. However,\n"
             "Selective Activation Checkpointing (SAC) requires compilation to be applied\n"
-            "at the outermost level (e.g., compile(SAC(model))).\n"
+            "at the outermost level (e.g., compile(SAC(model))). Otherwise, the compilation\n"
+            "will be ignored.\n"
             "\n"
             "Without enabling config.compile, the apply order will be:\n"
-            "SAC(compile(flex_attention)), which results in poor performance.\n"
+            "SAC(compile(flex_attention)). 
The compilation of flex_attention will be\n" + "skipped, which results in poor performance.\n" "\n" "For best results, enable config.compile to ensure proper compilation order:\n" "compile(SAC(compile(flex_attention)))\n" diff --git a/torchtitan/experiments/llama4/infra/parallelize.py b/torchtitan/experiments/llama4/infra/parallelize.py index 3e4dd43f7..45e998201 100644 --- a/torchtitan/experiments/llama4/infra/parallelize.py +++ b/torchtitan/experiments/llama4/infra/parallelize.py @@ -21,6 +21,8 @@ from torchtitan.config import JobConfig, TORCH_DTYPE_MAP from torchtitan.distributed import ParallelDims +from torchtitan.distributed.activation_checkpoint import apply_ac + from torchtitan.distributed.expert_parallel import ( ExpertParallel, ExpertTensorParallel, @@ -29,8 +31,7 @@ TensorParallel, ) from torchtitan.distributed.tensor_parallel import maybe_enable_async_tp - -from torchtitan.models.llama3.infra.parallelize import apply_ac, apply_ddp +from torchtitan.models.llama3.infra.parallelize import apply_ddp from torchtitan.tools.logging import logger @@ -99,7 +100,13 @@ def parallelize_llama( ) if job_config.activation_checkpoint.mode != "none": - apply_ac(model, job_config.activation_checkpoint) + use_flex_attn = getattr(model.model_args, "use_flex_attn", False) + apply_ac( + model, + job_config.activation_checkpoint, + model_compile_enabled, + use_flex_attn, + ) model_compile_enabled = ( job_config.compile.enable and "model" in job_config.compile.components diff --git a/torchtitan/experiments/qwen3/infra/parallelize.py b/torchtitan/experiments/qwen3/infra/parallelize.py index 8f7cb06ef..ad60912d8 100644 --- a/torchtitan/experiments/qwen3/infra/parallelize.py +++ b/torchtitan/experiments/qwen3/infra/parallelize.py @@ -22,9 +22,9 @@ from torchtitan.config import JobConfig, TORCH_DTYPE_MAP from torchtitan.distributed import ParallelDims +from torchtitan.distributed.activation_checkpoint import apply_ac from torchtitan.distributed.expert_parallel import NoParallel from torchtitan.models.llama3.infra.parallelize import ( - apply_ac, apply_compile, apply_ddp, apply_fsdp, @@ -82,7 +82,13 @@ def parallelize_qwen3( ) if job_config.activation_checkpoint.mode != "none": - apply_ac(model, job_config.activation_checkpoint) + use_flex_attn = getattr(model.model_args, "use_flex_attn", False) + apply_ac( + model, + job_config.activation_checkpoint, + model_compile_enabled, + use_flex_attn, + ) # turn on per-TransformerBlock compile after AC wrapping and before FSDP if model_compile_enabled: diff --git a/torchtitan/experiments/simple_fsdp/parallelize.py b/torchtitan/experiments/simple_fsdp/parallelize.py index 5feffdabb..04a477634 100644 --- a/torchtitan/experiments/simple_fsdp/parallelize.py +++ b/torchtitan/experiments/simple_fsdp/parallelize.py @@ -10,7 +10,8 @@ from torchtitan.config import JobConfig, TORCH_DTYPE_MAP from torchtitan.distributed import ParallelDims from torchtitan.distributed.tensor_parallel import maybe_enable_async_tp -from torchtitan.models.llama3.infra.parallelize import apply_ac, apply_tp +from torchtitan.distributed.activation_checkpoint import apply_ac +from torchtitan.models.llama3.infra.parallelize import apply_tp from torchtitan.tools.logging import logger from .simple_fsdp import data_parallel, MixedPrecisionPolicy @@ -60,7 +61,13 @@ def parallelize_llama( maybe_enable_async_tp(job_config, tp_mesh) if job_config.activation_checkpoint.mode != "none": - apply_ac(model, job_config.activation_checkpoint) + use_flex_attn = getattr(model.model_args, "use_flex_attn", False) + 
apply_ac( + model, + job_config.activation_checkpoint, + model_compile_enabled, + use_flex_attn, + ) # apply data parallel if ( diff --git a/torchtitan/models/deepseek_v3/infra/parallelize.py b/torchtitan/models/deepseek_v3/infra/parallelize.py index c77250d0f..6fc2c0112 100644 --- a/torchtitan/models/deepseek_v3/infra/parallelize.py +++ b/torchtitan/models/deepseek_v3/infra/parallelize.py @@ -18,6 +18,7 @@ from torchtitan.config import JobConfig, TORCH_DTYPE_MAP from torchtitan.distributed import ParallelDims +from torchtitan.distributed.activation_checkpoint import apply_ac from torchtitan.distributed.expert_parallel import NoParallel from torchtitan.distributed.tensor_parallel import maybe_enable_async_tp from torchtitan.experiments.llama4.infra.parallelize import ( @@ -25,7 +26,7 @@ apply_fsdp, apply_moe_ep_tp, ) -from torchtitan.models.llama3.infra.parallelize import apply_ac, apply_ddp +from torchtitan.models.llama3.infra.parallelize import apply_ddp from torchtitan.tools.logging import logger @@ -46,9 +47,9 @@ def parallelize_deepseekv3( ({parallel_dims.tp}) and 2 * CP degree ({parallel_dims.cp}). """ + use_flex_attn = getattr(model.model_args, "use_flex_attn", False) if ( - job_config.parallelism.context_parallel_degree > 1 - and model.model_args.use_flex_attn + job_config.parallelism.context_parallel_degree > 1 and use_flex_attn ): raise NotImplementedError("CP support for FlexAttention is still in progress.") @@ -89,12 +90,18 @@ def parallelize_deepseekv3( etp_enabled=parallel_dims.etp_enabled, ) - if job_config.activation_checkpoint.mode != "none": - apply_ac(model, job_config.activation_checkpoint) - model_compile_enabled = ( job_config.compile.enable and "model" in job_config.compile.components ) + + if job_config.activation_checkpoint.mode != "none": + apply_ac( + model, + job_config.activation_checkpoint, + model_compile_enabled, + use_flex_attn, + ) + if model_compile_enabled: # NOTE: needed for torch.compile to work with dynamic shapes in token-choice MoE torch._dynamo.config.capture_scalar_outputs = True diff --git a/torchtitan/models/llama3/infra/parallelize.py b/torchtitan/models/llama3/infra/parallelize.py index 7d0b5de92..1307302f4 100644 --- a/torchtitan/models/llama3/infra/parallelize.py +++ b/torchtitan/models/llama3/infra/parallelize.py @@ -7,15 +7,9 @@ # This file applies the PT-D parallelisms (except pipeline parallelism) and various # training techniques (e.g. activation checkpointing and compile) to the Llama model. -from collections import defaultdict -from typing import Optional - import torch import torch.nn as nn from torch.distributed._composable.replicate import replicate -from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( - checkpoint_wrapper as ptd_checkpoint_wrapper, -) from torch.distributed.device_mesh import DeviceMesh from torch.distributed.fsdp import CPUOffloadPolicy, fully_shard, MixedPrecisionPolicy @@ -29,8 +23,8 @@ ) from torchtitan.config import JobConfig, TORCH_DTYPE_MAP -from torchtitan.config.job_config import ActivationCheckpoint as ACConfig from torchtitan.distributed import ParallelDims +from torchtitan.distributed.activation_checkpoint import apply_ac from torchtitan.distributed.tensor_parallel import maybe_enable_async_tp from torchtitan.tools.logging import logger @@ -58,10 +52,8 @@ def parallelize_llama( ({parallel_dims.tp}) and 2 * CP degree ({parallel_dims.cp}). 
""" - if ( - job_config.parallelism.context_parallel_degree > 1 - and model.model_args.use_flex_attn - ): + use_flex_attn = getattr(model.model_args, "use_flex_attn", False) + if job_config.parallelism.context_parallel_degree > 1 and use_flex_attn: raise NotImplementedError("CP support for FlexAttention is still in progress.") if parallel_dims.tp_enabled: @@ -85,7 +77,12 @@ def parallelize_llama( maybe_enable_async_tp(job_config, world_mesh["tp"]) if job_config.activation_checkpoint.mode != "none": - apply_ac(model, job_config.activation_checkpoint) + apply_ac( + model, + job_config.activation_checkpoint, + model_compile_enabled, + use_flex_attn, + ) model_compile_enabled = ( job_config.compile.enable and "model" in job_config.compile.components @@ -221,119 +218,6 @@ def apply_tp( ) -# for selective op activation checkpointing -_save_list = { - torch.ops.aten.mm.default, - torch.ops.aten._scaled_dot_product_efficient_attention.default, - torch.ops.aten._scaled_dot_product_flash_attention.default, - torch.ops._c10d_functional.reduce_scatter_tensor.default, - # for low precision training, it's useful to always save - # the result of max, since the absolute maximum is - # used to compute the scaling factor for quantization. - torch.ops.aten.max.default, - torch._higher_order_ops.flex_attention, -} - - -def _apply_ac_to_transformer_block( - module: nn.Module, ac_config: ACConfig, *, base_fqn: Optional[str] = None -): - valid_ac_modes = ("full", "selective") - if ac_config.mode not in valid_ac_modes: - raise ValueError( - f"Invalid AC mode: {ac_config.mode}. Valid modes: {valid_ac_modes}" - ) - - if ac_config.mode == "full": - return ptd_checkpoint_wrapper(module, preserve_rng_state=False) - - assert ac_config.mode == "selective", f"{ac_config.mode}" - use_op_sac = ac_config.selective_ac_option == "op" - use_layer_sac = ac_config.selective_ac_option.isdigit() - if not use_op_sac and not use_layer_sac: - raise ValueError( - f"Invalid selective AC option: {ac_config.selective_ac_option}. 
" - f"Valid options: 'op' or a positive int representing layer frequency" - ) - if use_op_sac: - from torch.utils.checkpoint import ( - CheckpointPolicy, - create_selective_checkpoint_contexts, - ) - - mm_recompute_shapes = set() - if len(ac_config.per_op_sac_force_recompute_mm_shapes_by_fqns) > 0: - for module_fqn, submod in module.named_modules(): - fqn = module_fqn - if base_fqn is not None: - fqn = f"{base_fqn}.{module_fqn}" - if not any( - filter_fqn in fqn - for filter_fqn in ac_config.per_op_sac_force_recompute_mm_shapes_by_fqns - ): - continue - if not isinstance(submod, nn.Linear): - raise ValueError( - "per_op_sac_force_recompute_mm_shapes_by_fqns expected to match " - f"a nn.Linear, but got: {submod}" - ) - out_f, in_f = submod.weight.shape - mm_recompute_shapes.add((in_f, out_f)) - logger.debug( - f"Selective op AC force recomputing mms with rhs shapes {mm_recompute_shapes}" - ) - - def _get_custom_policy(meta): - def _custom_policy(ctx, func, *args, **kwargs): - mode = "recompute" if ctx.is_recompute else "forward" - mm_count_key = f"{mode}_mm_count" - if func == torch.ops.aten.mm.default: - if args[1].shape in mm_recompute_shapes: - return CheckpointPolicy.PREFER_RECOMPUTE - meta[mm_count_key] += 1 - # Saves output of all compute ops, except every second mm - to_save = func in _save_list and not ( - func == torch.ops.aten.mm.default and meta[mm_count_key] % 2 == 0 - ) - return ( - CheckpointPolicy.MUST_SAVE - if to_save - else CheckpointPolicy.PREFER_RECOMPUTE - ) - - return _custom_policy - - def selective_checkpointing_context_fn(): - meta = defaultdict(int) - return create_selective_checkpoint_contexts(_get_custom_policy(meta)) - - return ptd_checkpoint_wrapper( - module, - context_fn=selective_checkpointing_context_fn, - preserve_rng_state=False, - ) - elif use_layer_sac: - # Checkpoint every `ac_freq` of the modules passed to this function - ac_freq = int(ac_config.selective_ac_option) - ptd_checkpoint_wrapper.__dict__.setdefault("_count", 0) - ptd_checkpoint_wrapper._count += 1 - if not ac_freq or ptd_checkpoint_wrapper._count % ac_freq == 0: - return ptd_checkpoint_wrapper(module, preserve_rng_state=False) - else: - return module - - -def apply_ac(model: nn.Module, ac_config: ACConfig): - """Apply activation checkpointing to the model.""" - for layer_id, transformer_block in model.layers.named_children(): - transformer_block = _apply_ac_to_transformer_block( - transformer_block, ac_config, base_fqn=f"layers.{layer_id}" - ) - model.layers.register_module(layer_id, transformer_block) - - logger.info(f"Applied {ac_config.mode} activation checkpointing to the model") - - def apply_compile(model: nn.Module): """ Apply torch.compile to each TransformerBlock, which makes compilation efficient due to From 42a7cae9e0f9999c93e971aece3e2dcf860739d3 Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Wed, 27 Aug 2025 14:49:00 -0700 Subject: [PATCH 3/5] misc --- torchtitan/experiments/llama4/infra/parallelize.py | 7 ++----- torchtitan/experiments/qwen3/infra/parallelize.py | 7 ++----- 2 files changed, 4 insertions(+), 10 deletions(-) diff --git a/torchtitan/experiments/llama4/infra/parallelize.py b/torchtitan/experiments/llama4/infra/parallelize.py index 45e998201..d3a73704b 100644 --- a/torchtitan/experiments/llama4/infra/parallelize.py +++ b/torchtitan/experiments/llama4/infra/parallelize.py @@ -58,10 +58,8 @@ def parallelize_llama( ({parallel_dims.tp}) and 2 * CP degree ({parallel_dims.cp}). 
""" - if ( - job_config.parallelism.context_parallel_degree > 1 - and model.model_args.use_flex_attn - ): + use_flex_attn = getattr(model.model_args, "use_flex_attn", False) + if job_config.parallelism.context_parallel_degree > 1 and use_flex_attn: raise NotImplementedError("CP support for FlexAttention is still in progress.") if parallel_dims.tp_enabled: @@ -100,7 +98,6 @@ def parallelize_llama( ) if job_config.activation_checkpoint.mode != "none": - use_flex_attn = getattr(model.model_args, "use_flex_attn", False) apply_ac( model, job_config.activation_checkpoint, diff --git a/torchtitan/experiments/qwen3/infra/parallelize.py b/torchtitan/experiments/qwen3/infra/parallelize.py index ad60912d8..05278a2ec 100644 --- a/torchtitan/experiments/qwen3/infra/parallelize.py +++ b/torchtitan/experiments/qwen3/infra/parallelize.py @@ -46,10 +46,8 @@ def parallelize_qwen3( ({parallel_dims.tp}) and 2 * CP degree ({parallel_dims.cp}). """ - if ( - job_config.parallelism.context_parallel_degree > 1 - and model.model_args.use_flex_attn - ): + use_flex_attn = getattr(model.model_args, "use_flex_attn", False) + if job_config.parallelism.context_parallel_degree > 1 and use_flex_attn: raise NotImplementedError("CP support for FlexAttention is still in progress.") model_compile_enabled = ( @@ -82,7 +80,6 @@ def parallelize_qwen3( ) if job_config.activation_checkpoint.mode != "none": - use_flex_attn = getattr(model.model_args, "use_flex_attn", False) apply_ac( model, job_config.activation_checkpoint, From 259ade47c468bc58de3b842fa72f5a8ffb862521 Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Wed, 27 Aug 2025 14:55:08 -0700 Subject: [PATCH 4/5] misc --- torchtitan/experiments/llama4/infra/parallelize.py | 8 +++----- torchtitan/experiments/simple_fsdp/parallelize.py | 5 ++++- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/torchtitan/experiments/llama4/infra/parallelize.py b/torchtitan/experiments/llama4/infra/parallelize.py index d3a73704b..4ea2798b2 100644 --- a/torchtitan/experiments/llama4/infra/parallelize.py +++ b/torchtitan/experiments/llama4/infra/parallelize.py @@ -20,9 +20,7 @@ ) from torchtitan.config import JobConfig, TORCH_DTYPE_MAP from torchtitan.distributed import ParallelDims - from torchtitan.distributed.activation_checkpoint import apply_ac - from torchtitan.distributed.expert_parallel import ( ExpertParallel, ExpertTensorParallel, @@ -97,6 +95,9 @@ def parallelize_llama( etp_enabled=parallel_dims.etp_enabled, ) + model_compile_enabled = ( + job_config.compile.enable and "model" in job_config.compile.components + ) if job_config.activation_checkpoint.mode != "none": apply_ac( model, @@ -105,9 +106,6 @@ def parallelize_llama( use_flex_attn, ) - model_compile_enabled = ( - job_config.compile.enable and "model" in job_config.compile.components - ) # turn on per-TransformerBlock compile after AC wrapping and before FSDP if model_compile_enabled: # NOTE: needed for torch.compile to work with dynamic shapes in token-choice MoE diff --git a/torchtitan/experiments/simple_fsdp/parallelize.py b/torchtitan/experiments/simple_fsdp/parallelize.py index 04a477634..206a4d7bd 100644 --- a/torchtitan/experiments/simple_fsdp/parallelize.py +++ b/torchtitan/experiments/simple_fsdp/parallelize.py @@ -9,8 +9,8 @@ from torchtitan.config import JobConfig, TORCH_DTYPE_MAP from torchtitan.distributed import ParallelDims -from torchtitan.distributed.tensor_parallel import maybe_enable_async_tp from torchtitan.distributed.activation_checkpoint import apply_ac +from 
torchtitan.distributed.tensor_parallel import maybe_enable_async_tp from torchtitan.models.llama3.infra.parallelize import apply_tp from torchtitan.tools.logging import logger @@ -62,6 +62,9 @@ def parallelize_llama( if job_config.activation_checkpoint.mode != "none": use_flex_attn = getattr(model.model_args, "use_flex_attn", False) + model_compile_enabled = ( + job_config.compile.enable and "model" in job_config.compile.components + ) apply_ac( model, job_config.activation_checkpoint, From 752d09034a4f35d504460fe82e272c508f275f32 Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Wed, 27 Aug 2025 18:18:47 -0700 Subject: [PATCH 5/5] Fix tests --- .../unit_tests/test_activation_checkpoint.py | 24 ++++++++++++------- .../models/deepseek_v3/infra/parallelize.py | 4 +--- torchtitan/models/llama3/infra/parallelize.py | 7 +++--- 3 files changed, 20 insertions(+), 15 deletions(-) diff --git a/tests/unit_tests/test_activation_checkpoint.py b/tests/unit_tests/test_activation_checkpoint.py index a4dbc21a5..3803da421 100644 --- a/tests/unit_tests/test_activation_checkpoint.py +++ b/tests/unit_tests/test_activation_checkpoint.py @@ -11,7 +11,7 @@ from torch.utils.flop_counter import FlopCounterMode from torchtitan.config.job_config import ActivationCheckpoint as ACConfig -from torchtitan.models.llama3.infra.parallelize import apply_ac +from torchtitan.distributed.activation_checkpoint import apply_ac class ToyModule(nn.Module): @@ -67,7 +67,7 @@ def get_bw_flops(model_fn): selective_ac_option="op", per_op_sac_force_recompute_mm_shapes_by_fqns=[], # Empty list ) - apply_ac(model_selective_ac, ac_config_no_force) + apply_ac(model_selective_ac, ac_config_no_force, False, False) flops_selective_ac = get_bw_flops(model_selective_ac) # 3. Per-op SAC with force recompute "moe.router.gate" @@ -78,7 +78,7 @@ def get_bw_flops(model_fn): selective_ac_option="op", per_op_sac_force_recompute_mm_shapes_by_fqns=["moe.router.gate"], ) - apply_ac(model_with_force_first, ac_config_with_force_first) + apply_ac(model_with_force_first, ac_config_with_force_first, False, False) flops_with_force_first = get_bw_flops(model_with_force_first) # 4. Per-op SAC with force recompute "output" @@ -88,7 +88,7 @@ def get_bw_flops(model_fn): selective_ac_option="op", per_op_sac_force_recompute_mm_shapes_by_fqns=["output"], ) - apply_ac(model_with_force_last, ac_config_with_force_last) + apply_ac(model_with_force_last, ac_config_with_force_last, False, False) flops_with_force_last = get_bw_flops(model_with_force_last) # 5. Full AC @@ -96,7 +96,7 @@ def get_bw_flops(model_fn): ac_config_full_ac = ACConfig( mode="full", ) - apply_ac(model_with_full_ac, ac_config_full_ac) + apply_ac(model_with_full_ac, ac_config_full_ac, False, False) flops_full_ac = get_bw_flops(model_with_full_ac) self.assertEqual(flops_no_ac, 8.0) @@ -133,7 +133,7 @@ def get_act_mem(model_fn): selective_ac_option="op", per_op_sac_force_recompute_mm_shapes_by_fqns=[], # Empty list ) - apply_ac(model_selective_ac, ac_config_no_force) + apply_ac(model_selective_ac, ac_config_no_force, False, False) mem_selective_ac = get_act_mem(model_selective_ac) # 3. Per-op SAC with force recompute "moe.router.gate" @@ -144,7 +144,7 @@ def get_act_mem(model_fn): selective_ac_option="op", per_op_sac_force_recompute_mm_shapes_by_fqns=["moe.router.gate"], ) - apply_ac(model_with_force_first, ac_config_with_force_first) + apply_ac(model_with_force_first, ac_config_with_force_first, False, False) mem_with_force_first = get_act_mem(model_with_force_first) # 4. 
Per-op SAC with force recompute "output" @@ -154,7 +154,7 @@ def get_act_mem(model_fn): selective_ac_option="op", per_op_sac_force_recompute_mm_shapes_by_fqns=["output"], ) - apply_ac(model_with_force_last, ac_config_with_force_last) + apply_ac(model_with_force_last, ac_config_with_force_last, False, False) mem_with_force_last = get_act_mem(model_with_force_last) # 5. Full AC @@ -162,7 +162,7 @@ def get_act_mem(model_fn): ac_config_full_ac = ACConfig( mode="full", ) - apply_ac(model_with_full_ac, ac_config_full_ac) + apply_ac(model_with_full_ac, ac_config_full_ac, False, False) mem_full_ac = get_act_mem(model_with_full_ac) self.assertEqual(mem_no_ac, 2.0) @@ -186,6 +186,8 @@ def test_correctness(self): selective_ac_option="op", per_op_sac_force_recompute_mm_shapes_by_fqns=[], ), + False, + False, ) model_force_first = ToyModule() model_force_first.load_state_dict(model_no_ac.state_dict()) @@ -196,6 +198,8 @@ def test_correctness(self): selective_ac_option="op", per_op_sac_force_recompute_mm_shapes_by_fqns=["moe.router.gate"], ), + False, + False, ) model_force_last = ToyModule() @@ -207,6 +211,8 @@ def test_correctness(self): selective_ac_option="op", per_op_sac_force_recompute_mm_shapes_by_fqns=["output"], ), + False, + False, ) def run_fwd_bwd(model, batch): diff --git a/torchtitan/models/deepseek_v3/infra/parallelize.py b/torchtitan/models/deepseek_v3/infra/parallelize.py index 6fc2c0112..e3407dac7 100644 --- a/torchtitan/models/deepseek_v3/infra/parallelize.py +++ b/torchtitan/models/deepseek_v3/infra/parallelize.py @@ -48,9 +48,7 @@ def parallelize_deepseekv3( """ use_flex_attn = getattr(model.model_args, "use_flex_attn", False) - if ( - job_config.parallelism.context_parallel_degree > 1 and use_flex_attn - ): + if job_config.parallelism.context_parallel_degree > 1 and use_flex_attn: raise NotImplementedError("CP support for FlexAttention is still in progress.") if parallel_dims.tp_enabled: diff --git a/torchtitan/models/llama3/infra/parallelize.py b/torchtitan/models/llama3/infra/parallelize.py index 1307302f4..f8e1295a8 100644 --- a/torchtitan/models/llama3/infra/parallelize.py +++ b/torchtitan/models/llama3/infra/parallelize.py @@ -76,6 +76,10 @@ def parallelize_llama( ) maybe_enable_async_tp(job_config, world_mesh["tp"]) + model_compile_enabled = ( + job_config.compile.enable and "model" in job_config.compile.components + ) + if job_config.activation_checkpoint.mode != "none": apply_ac( model, @@ -84,9 +88,6 @@ def parallelize_llama( use_flex_attn, ) - model_compile_enabled = ( - job_config.compile.enable and "model" in job_config.compile.components - ) # turn on per-TransformerBlock compile after AC wrapping and before FSDP if model_compile_enabled: apply_compile(model)