diff --git a/tests/unit_tests/test_activation_checkpoint.py b/tests/unit_tests/test_activation_checkpoint.py
index a4dbc21a5..3803da421 100644
--- a/tests/unit_tests/test_activation_checkpoint.py
+++ b/tests/unit_tests/test_activation_checkpoint.py
@@ -11,7 +11,7 @@
 from torch.utils.flop_counter import FlopCounterMode
 
 from torchtitan.config.job_config import ActivationCheckpoint as ACConfig
-from torchtitan.models.llama3.infra.parallelize import apply_ac
+from torchtitan.distributed.activation_checkpoint import apply_ac
 
 
 class ToyModule(nn.Module):
@@ -67,7 +67,7 @@ def get_bw_flops(model_fn):
             selective_ac_option="op",
             per_op_sac_force_recompute_mm_shapes_by_fqns=[],  # Empty list
         )
-        apply_ac(model_selective_ac, ac_config_no_force)
+        apply_ac(model_selective_ac, ac_config_no_force, False, False)
         flops_selective_ac = get_bw_flops(model_selective_ac)
 
         # 3. Per-op SAC with force recompute "moe.router.gate"
@@ -78,7 +78,7 @@ def get_bw_flops(model_fn):
             selective_ac_option="op",
             per_op_sac_force_recompute_mm_shapes_by_fqns=["moe.router.gate"],
         )
-        apply_ac(model_with_force_first, ac_config_with_force_first)
+        apply_ac(model_with_force_first, ac_config_with_force_first, False, False)
         flops_with_force_first = get_bw_flops(model_with_force_first)
 
         # 4. Per-op SAC with force recompute "output"
@@ -88,7 +88,7 @@ def get_bw_flops(model_fn):
             selective_ac_option="op",
             per_op_sac_force_recompute_mm_shapes_by_fqns=["output"],
         )
-        apply_ac(model_with_force_last, ac_config_with_force_last)
+        apply_ac(model_with_force_last, ac_config_with_force_last, False, False)
         flops_with_force_last = get_bw_flops(model_with_force_last)
 
         # 5. Full AC
@@ -96,7 +96,7 @@ def get_bw_flops(model_fn):
         ac_config_full_ac = ACConfig(
             mode="full",
         )
-        apply_ac(model_with_full_ac, ac_config_full_ac)
+        apply_ac(model_with_full_ac, ac_config_full_ac, False, False)
         flops_full_ac = get_bw_flops(model_with_full_ac)
 
         self.assertEqual(flops_no_ac, 8.0)
@@ -133,7 +133,7 @@ def get_act_mem(model_fn):
             selective_ac_option="op",
             per_op_sac_force_recompute_mm_shapes_by_fqns=[],  # Empty list
         )
-        apply_ac(model_selective_ac, ac_config_no_force)
+        apply_ac(model_selective_ac, ac_config_no_force, False, False)
         mem_selective_ac = get_act_mem(model_selective_ac)
 
         # 3. Per-op SAC with force recompute "moe.router.gate"
@@ -144,7 +144,7 @@ def get_act_mem(model_fn):
             selective_ac_option="op",
             per_op_sac_force_recompute_mm_shapes_by_fqns=["moe.router.gate"],
         )
-        apply_ac(model_with_force_first, ac_config_with_force_first)
+        apply_ac(model_with_force_first, ac_config_with_force_first, False, False)
         mem_with_force_first = get_act_mem(model_with_force_first)
 
         # 4. Per-op SAC with force recompute "output"
@@ -154,7 +154,7 @@ def get_act_mem(model_fn):
             selective_ac_option="op",
             per_op_sac_force_recompute_mm_shapes_by_fqns=["output"],
         )
-        apply_ac(model_with_force_last, ac_config_with_force_last)
+        apply_ac(model_with_force_last, ac_config_with_force_last, False, False)
         mem_with_force_last = get_act_mem(model_with_force_last)
 
         # 5. Full AC
@@ -162,7 +162,7 @@ def get_act_mem(model_fn):
         ac_config_full_ac = ACConfig(
             mode="full",
         )
-        apply_ac(model_with_full_ac, ac_config_full_ac)
+        apply_ac(model_with_full_ac, ac_config_full_ac, False, False)
         mem_full_ac = get_act_mem(model_with_full_ac)
 
         self.assertEqual(mem_no_ac, 2.0)
@@ -186,6 +186,8 @@ def test_correctness(self):
                 selective_ac_option="op",
                 per_op_sac_force_recompute_mm_shapes_by_fqns=[],
             ),
+            False,
+            False,
         )
         model_force_first = ToyModule()
         model_force_first.load_state_dict(model_no_ac.state_dict())
@@ -196,6 +198,8 @@ def test_correctness(self):
                 selective_ac_option="op",
                 per_op_sac_force_recompute_mm_shapes_by_fqns=["moe.router.gate"],
             ),
+            False,
+            False,
         )
 
         model_force_last = ToyModule()
@@ -207,6 +211,8 @@ def test_correctness(self):
                 selective_ac_option="op",
                 per_op_sac_force_recompute_mm_shapes_by_fqns=["output"],
             ),
+            False,
+            False,
         )
 
         def run_fwd_bwd(model, batch):
diff --git a/torchtitan/distributed/activation_checkpoint.py b/torchtitan/distributed/activation_checkpoint.py
new file mode 100644
index 000000000..b4c28a90b
--- /dev/null
+++ b/torchtitan/distributed/activation_checkpoint.py
@@ -0,0 +1,162 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# This file provides the utility functions to apply activation checkpointing to the model.
+# Technically, this is not a part of distributed, but the distributed module is the best place to put it.
+
+from collections import defaultdict
+
+import torch
+import torch.nn as nn
+from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
+    checkpoint_wrapper as ptd_checkpoint_wrapper,
+)
+
+from torchtitan.config.job_config import ActivationCheckpoint as ACConfig
+from torchtitan.tools.logging import logger
+
+# for selective op activation checkpointing
+_save_list = {
+    torch.ops.aten.mm.default,
+    torch.ops.aten._scaled_dot_product_efficient_attention.default,
+    torch.ops.aten._scaled_dot_product_flash_attention.default,
+    torch.ops._c10d_functional.reduce_scatter_tensor.default,
+    # for low precision training, it's useful to always save
+    # the result of max, since the absolute maximum is
+    # used to compute the scaling factor for quantization.
+    torch.ops.aten.max.default,
+    torch._higher_order_ops.flex_attention,
+}
+
+
+def _apply_ac_to_transformer_block(
+    module: nn.Module, ac_config: ACConfig, *, base_fqn: str | None = None
+):
+    valid_ac_modes = ("full", "selective")
+    if ac_config.mode not in valid_ac_modes:
+        raise ValueError(
+            f"Invalid AC mode: {ac_config.mode}. Valid modes: {valid_ac_modes}"
+        )
+
+    if ac_config.mode == "full":
+        return ptd_checkpoint_wrapper(module, preserve_rng_state=False)
+
+    assert ac_config.mode == "selective", f"{ac_config.mode}"
+    use_op_sac = ac_config.selective_ac_option == "op"
+    use_layer_sac = ac_config.selective_ac_option.isdigit()
+    if not use_op_sac and not use_layer_sac:
+        raise ValueError(
+            f"Invalid selective AC option: {ac_config.selective_ac_option}. "
" + f"Valid options: 'op' or a positive int representing layer frequency" + ) + if use_op_sac: + from torch.utils.checkpoint import ( + CheckpointPolicy, + create_selective_checkpoint_contexts, + ) + + mm_recompute_shapes = set() + if len(ac_config.per_op_sac_force_recompute_mm_shapes_by_fqns) > 0: + for module_fqn, submod in module.named_modules(): + fqn = module_fqn + if base_fqn is not None: + fqn = f"{base_fqn}.{module_fqn}" + if not any( + filter_fqn in fqn + for filter_fqn in ac_config.per_op_sac_force_recompute_mm_shapes_by_fqns + ): + continue + if not isinstance(submod, nn.Linear): + raise ValueError( + "per_op_sac_force_recompute_mm_shapes_by_fqns expected to match " + f"a nn.Linear, but got: {submod}" + ) + out_f, in_f = submod.weight.shape + mm_recompute_shapes.add((in_f, out_f)) + logger.debug( + f"Selective op AC force recomputing mms with rhs shapes {mm_recompute_shapes}" + ) + + def _get_custom_policy(meta): + def _custom_policy(ctx, func, *args, **kwargs): + mode = "recompute" if ctx.is_recompute else "forward" + mm_count_key = f"{mode}_mm_count" + if func == torch.ops.aten.mm.default: + if args[1].shape in mm_recompute_shapes: + return CheckpointPolicy.PREFER_RECOMPUTE + meta[mm_count_key] += 1 + # Saves output of all compute ops, except every second mm + to_save = func in _save_list and not ( + func == torch.ops.aten.mm.default and meta[mm_count_key] % 2 == 0 + ) + return ( + CheckpointPolicy.MUST_SAVE + if to_save + else CheckpointPolicy.PREFER_RECOMPUTE + ) + + return _custom_policy + + def selective_checkpointing_context_fn(): + meta = defaultdict(int) + return create_selective_checkpoint_contexts(_get_custom_policy(meta)) + + return ptd_checkpoint_wrapper( + module, + context_fn=selective_checkpointing_context_fn, + preserve_rng_state=False, + ) + elif use_layer_sac: + # Checkpoint every `ac_freq` of the modules passed to this function + ac_freq = int(ac_config.selective_ac_option) + ptd_checkpoint_wrapper.__dict__.setdefault("_count", 0) + ptd_checkpoint_wrapper._count += 1 + if not ac_freq or ptd_checkpoint_wrapper._count % ac_freq == 0: + return ptd_checkpoint_wrapper(module, preserve_rng_state=False) + else: + return module + + +def apply_ac( + model: nn.Module, + ac_config: ACConfig, + model_compile_enabled: bool, + use_flex_attn: bool, +): + """Apply activation checkpointing to the model. + + Note that SAC, Flex Attention and model compilation have some conflicts. + We explicitly ask the user to pass these configs to warn if there are conflicts. + """ + + if use_flex_attn and not model_compile_enabled: + logger.warning( + "!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!\n" + "POTENTIAL PERFORMANCE ISSUE DETECTED:\n" + "Flex attention requires compilation for optimal performance and will be\n" + "compiled automatically regardless of config.compile settings. However,\n" + "Selective Activation Checkpointing (SAC) requires compilation to be applied\n" + "at the outermost level (e.g., compile(SAC(model))). Othewise the compilation\n" + "will be ignored." + "\n" + "Without enabling config.compile, the apply order will be:\n" + "SAC(compile(flex_attention)). 
+            "skipped, which results in poor performance.\n"
+            "\n"
+            "For best results, enable config.compile to ensure proper compilation order:\n"
+            "compile(SAC(compile(flex_attention)))\n"
+            "\n"
+            "The innermost torch.compile will be ignored, but the outermost will take\n"
+            "effect and provide optimal performance.\n"
+            "!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!\n"
+        )
+
+    for layer_id, transformer_block in model.layers.named_children():
+        transformer_block = _apply_ac_to_transformer_block(
+            transformer_block, ac_config, base_fqn=f"layers.{layer_id}"
+        )
+        model.layers.register_module(layer_id, transformer_block)
+
+    logger.info(f"Applied {ac_config.mode} activation checkpointing to the model")
diff --git a/torchtitan/experiments/llama4/infra/parallelize.py b/torchtitan/experiments/llama4/infra/parallelize.py
index 3e4dd43f7..4ea2798b2 100644
--- a/torchtitan/experiments/llama4/infra/parallelize.py
+++ b/torchtitan/experiments/llama4/infra/parallelize.py
@@ -20,7 +20,7 @@
 )
 from torchtitan.config import JobConfig, TORCH_DTYPE_MAP
 from torchtitan.distributed import ParallelDims
-
+from torchtitan.distributed.activation_checkpoint import apply_ac
 from torchtitan.distributed.expert_parallel import (
     ExpertParallel,
     ExpertTensorParallel,
@@ -29,8 +29,7 @@
     TensorParallel,
 )
 from torchtitan.distributed.tensor_parallel import maybe_enable_async_tp
-
-from torchtitan.models.llama3.infra.parallelize import apply_ac, apply_ddp
+from torchtitan.models.llama3.infra.parallelize import apply_ddp
 from torchtitan.tools.logging import logger
@@ -57,10 +56,8 @@ def parallelize_llama(
     ({parallel_dims.tp}) and 2 * CP degree ({parallel_dims.cp}).
     """
 
-    if (
-        job_config.parallelism.context_parallel_degree > 1
-        and model.model_args.use_flex_attn
-    ):
+    use_flex_attn = getattr(model.model_args, "use_flex_attn", False)
+    if job_config.parallelism.context_parallel_degree > 1 and use_flex_attn:
         raise NotImplementedError("CP support for FlexAttention is still in progress.")
 
     if parallel_dims.tp_enabled:
@@ -98,12 +95,17 @@ def parallelize_llama(
             etp_enabled=parallel_dims.etp_enabled,
         )
 
-    if job_config.activation_checkpoint.mode != "none":
-        apply_ac(model, job_config.activation_checkpoint)
-
     model_compile_enabled = (
         job_config.compile.enable and "model" in job_config.compile.components
     )
+    if job_config.activation_checkpoint.mode != "none":
+        apply_ac(
+            model,
+            job_config.activation_checkpoint,
+            model_compile_enabled,
+            use_flex_attn,
+        )
+
     # turn on per-TransformerBlock compile after AC wrapping and before FSDP
     if model_compile_enabled:
         # NOTE: needed for torch.compile to work with dynamic shapes in token-choice MoE
diff --git a/torchtitan/experiments/qwen3/infra/parallelize.py b/torchtitan/experiments/qwen3/infra/parallelize.py
index 8f7cb06ef..05278a2ec 100644
--- a/torchtitan/experiments/qwen3/infra/parallelize.py
+++ b/torchtitan/experiments/qwen3/infra/parallelize.py
@@ -22,9 +22,9 @@
 
 from torchtitan.config import JobConfig, TORCH_DTYPE_MAP
 from torchtitan.distributed import ParallelDims
+from torchtitan.distributed.activation_checkpoint import apply_ac
 from torchtitan.distributed.expert_parallel import NoParallel
 from torchtitan.models.llama3.infra.parallelize import (
-    apply_ac,
     apply_compile,
     apply_ddp,
     apply_fsdp,
@@ -46,10 +46,8 @@ def parallelize_qwen3(
     ({parallel_dims.tp}) and 2 * CP degree ({parallel_dims.cp}).
""" - if ( - job_config.parallelism.context_parallel_degree > 1 - and model.model_args.use_flex_attn - ): + use_flex_attn = getattr(model.model_args, "use_flex_attn", False) + if job_config.parallelism.context_parallel_degree > 1 and use_flex_attn: raise NotImplementedError("CP support for FlexAttention is still in progress.") model_compile_enabled = ( @@ -82,7 +80,12 @@ def parallelize_qwen3( ) if job_config.activation_checkpoint.mode != "none": - apply_ac(model, job_config.activation_checkpoint) + apply_ac( + model, + job_config.activation_checkpoint, + model_compile_enabled, + use_flex_attn, + ) # turn on per-TransformerBlock compile after AC wrapping and before FSDP if model_compile_enabled: diff --git a/torchtitan/experiments/simple_fsdp/parallelize.py b/torchtitan/experiments/simple_fsdp/parallelize.py index 5feffdabb..206a4d7bd 100644 --- a/torchtitan/experiments/simple_fsdp/parallelize.py +++ b/torchtitan/experiments/simple_fsdp/parallelize.py @@ -9,8 +9,9 @@ from torchtitan.config import JobConfig, TORCH_DTYPE_MAP from torchtitan.distributed import ParallelDims +from torchtitan.distributed.activation_checkpoint import apply_ac from torchtitan.distributed.tensor_parallel import maybe_enable_async_tp -from torchtitan.models.llama3.infra.parallelize import apply_ac, apply_tp +from torchtitan.models.llama3.infra.parallelize import apply_tp from torchtitan.tools.logging import logger from .simple_fsdp import data_parallel, MixedPrecisionPolicy @@ -60,7 +61,16 @@ def parallelize_llama( maybe_enable_async_tp(job_config, tp_mesh) if job_config.activation_checkpoint.mode != "none": - apply_ac(model, job_config.activation_checkpoint) + use_flex_attn = getattr(model.model_args, "use_flex_attn", False) + model_compile_enabled = ( + job_config.compile.enable and "model" in job_config.compile.components + ) + apply_ac( + model, + job_config.activation_checkpoint, + model_compile_enabled, + use_flex_attn, + ) # apply data parallel if ( diff --git a/torchtitan/models/deepseek_v3/infra/parallelize.py b/torchtitan/models/deepseek_v3/infra/parallelize.py index c77250d0f..e3407dac7 100644 --- a/torchtitan/models/deepseek_v3/infra/parallelize.py +++ b/torchtitan/models/deepseek_v3/infra/parallelize.py @@ -18,6 +18,7 @@ from torchtitan.config import JobConfig, TORCH_DTYPE_MAP from torchtitan.distributed import ParallelDims +from torchtitan.distributed.activation_checkpoint import apply_ac from torchtitan.distributed.expert_parallel import NoParallel from torchtitan.distributed.tensor_parallel import maybe_enable_async_tp from torchtitan.experiments.llama4.infra.parallelize import ( @@ -25,7 +26,7 @@ apply_fsdp, apply_moe_ep_tp, ) -from torchtitan.models.llama3.infra.parallelize import apply_ac, apply_ddp +from torchtitan.models.llama3.infra.parallelize import apply_ddp from torchtitan.tools.logging import logger @@ -46,10 +47,8 @@ def parallelize_deepseekv3( ({parallel_dims.tp}) and 2 * CP degree ({parallel_dims.cp}). 
""" - if ( - job_config.parallelism.context_parallel_degree > 1 - and model.model_args.use_flex_attn - ): + use_flex_attn = getattr(model.model_args, "use_flex_attn", False) + if job_config.parallelism.context_parallel_degree > 1 and use_flex_attn: raise NotImplementedError("CP support for FlexAttention is still in progress.") if parallel_dims.tp_enabled: @@ -89,12 +88,18 @@ def parallelize_deepseekv3( etp_enabled=parallel_dims.etp_enabled, ) - if job_config.activation_checkpoint.mode != "none": - apply_ac(model, job_config.activation_checkpoint) - model_compile_enabled = ( job_config.compile.enable and "model" in job_config.compile.components ) + + if job_config.activation_checkpoint.mode != "none": + apply_ac( + model, + job_config.activation_checkpoint, + model_compile_enabled, + use_flex_attn, + ) + if model_compile_enabled: # NOTE: needed for torch.compile to work with dynamic shapes in token-choice MoE torch._dynamo.config.capture_scalar_outputs = True diff --git a/torchtitan/models/llama3/infra/parallelize.py b/torchtitan/models/llama3/infra/parallelize.py index 7d0b5de92..f8e1295a8 100644 --- a/torchtitan/models/llama3/infra/parallelize.py +++ b/torchtitan/models/llama3/infra/parallelize.py @@ -7,15 +7,9 @@ # This file applies the PT-D parallelisms (except pipeline parallelism) and various # training techniques (e.g. activation checkpointing and compile) to the Llama model. -from collections import defaultdict -from typing import Optional - import torch import torch.nn as nn from torch.distributed._composable.replicate import replicate -from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( - checkpoint_wrapper as ptd_checkpoint_wrapper, -) from torch.distributed.device_mesh import DeviceMesh from torch.distributed.fsdp import CPUOffloadPolicy, fully_shard, MixedPrecisionPolicy @@ -29,8 +23,8 @@ ) from torchtitan.config import JobConfig, TORCH_DTYPE_MAP -from torchtitan.config.job_config import ActivationCheckpoint as ACConfig from torchtitan.distributed import ParallelDims +from torchtitan.distributed.activation_checkpoint import apply_ac from torchtitan.distributed.tensor_parallel import maybe_enable_async_tp from torchtitan.tools.logging import logger @@ -58,10 +52,8 @@ def parallelize_llama( ({parallel_dims.tp}) and 2 * CP degree ({parallel_dims.cp}). 
""" - if ( - job_config.parallelism.context_parallel_degree > 1 - and model.model_args.use_flex_attn - ): + use_flex_attn = getattr(model.model_args, "use_flex_attn", False) + if job_config.parallelism.context_parallel_degree > 1 and use_flex_attn: raise NotImplementedError("CP support for FlexAttention is still in progress.") if parallel_dims.tp_enabled: @@ -84,12 +76,18 @@ def parallelize_llama( ) maybe_enable_async_tp(job_config, world_mesh["tp"]) - if job_config.activation_checkpoint.mode != "none": - apply_ac(model, job_config.activation_checkpoint) - model_compile_enabled = ( job_config.compile.enable and "model" in job_config.compile.components ) + + if job_config.activation_checkpoint.mode != "none": + apply_ac( + model, + job_config.activation_checkpoint, + model_compile_enabled, + use_flex_attn, + ) + # turn on per-TransformerBlock compile after AC wrapping and before FSDP if model_compile_enabled: apply_compile(model) @@ -221,119 +219,6 @@ def apply_tp( ) -# for selective op activation checkpointing -_save_list = { - torch.ops.aten.mm.default, - torch.ops.aten._scaled_dot_product_efficient_attention.default, - torch.ops.aten._scaled_dot_product_flash_attention.default, - torch.ops._c10d_functional.reduce_scatter_tensor.default, - # for low precision training, it's useful to always save - # the result of max, since the absolute maximum is - # used to compute the scaling factor for quantization. - torch.ops.aten.max.default, - torch._higher_order_ops.flex_attention, -} - - -def _apply_ac_to_transformer_block( - module: nn.Module, ac_config: ACConfig, *, base_fqn: Optional[str] = None -): - valid_ac_modes = ("full", "selective") - if ac_config.mode not in valid_ac_modes: - raise ValueError( - f"Invalid AC mode: {ac_config.mode}. Valid modes: {valid_ac_modes}" - ) - - if ac_config.mode == "full": - return ptd_checkpoint_wrapper(module, preserve_rng_state=False) - - assert ac_config.mode == "selective", f"{ac_config.mode}" - use_op_sac = ac_config.selective_ac_option == "op" - use_layer_sac = ac_config.selective_ac_option.isdigit() - if not use_op_sac and not use_layer_sac: - raise ValueError( - f"Invalid selective AC option: {ac_config.selective_ac_option}. 
" - f"Valid options: 'op' or a positive int representing layer frequency" - ) - if use_op_sac: - from torch.utils.checkpoint import ( - CheckpointPolicy, - create_selective_checkpoint_contexts, - ) - - mm_recompute_shapes = set() - if len(ac_config.per_op_sac_force_recompute_mm_shapes_by_fqns) > 0: - for module_fqn, submod in module.named_modules(): - fqn = module_fqn - if base_fqn is not None: - fqn = f"{base_fqn}.{module_fqn}" - if not any( - filter_fqn in fqn - for filter_fqn in ac_config.per_op_sac_force_recompute_mm_shapes_by_fqns - ): - continue - if not isinstance(submod, nn.Linear): - raise ValueError( - "per_op_sac_force_recompute_mm_shapes_by_fqns expected to match " - f"a nn.Linear, but got: {submod}" - ) - out_f, in_f = submod.weight.shape - mm_recompute_shapes.add((in_f, out_f)) - logger.debug( - f"Selective op AC force recomputing mms with rhs shapes {mm_recompute_shapes}" - ) - - def _get_custom_policy(meta): - def _custom_policy(ctx, func, *args, **kwargs): - mode = "recompute" if ctx.is_recompute else "forward" - mm_count_key = f"{mode}_mm_count" - if func == torch.ops.aten.mm.default: - if args[1].shape in mm_recompute_shapes: - return CheckpointPolicy.PREFER_RECOMPUTE - meta[mm_count_key] += 1 - # Saves output of all compute ops, except every second mm - to_save = func in _save_list and not ( - func == torch.ops.aten.mm.default and meta[mm_count_key] % 2 == 0 - ) - return ( - CheckpointPolicy.MUST_SAVE - if to_save - else CheckpointPolicy.PREFER_RECOMPUTE - ) - - return _custom_policy - - def selective_checkpointing_context_fn(): - meta = defaultdict(int) - return create_selective_checkpoint_contexts(_get_custom_policy(meta)) - - return ptd_checkpoint_wrapper( - module, - context_fn=selective_checkpointing_context_fn, - preserve_rng_state=False, - ) - elif use_layer_sac: - # Checkpoint every `ac_freq` of the modules passed to this function - ac_freq = int(ac_config.selective_ac_option) - ptd_checkpoint_wrapper.__dict__.setdefault("_count", 0) - ptd_checkpoint_wrapper._count += 1 - if not ac_freq or ptd_checkpoint_wrapper._count % ac_freq == 0: - return ptd_checkpoint_wrapper(module, preserve_rng_state=False) - else: - return module - - -def apply_ac(model: nn.Module, ac_config: ACConfig): - """Apply activation checkpointing to the model.""" - for layer_id, transformer_block in model.layers.named_children(): - transformer_block = _apply_ac_to_transformer_block( - transformer_block, ac_config, base_fqn=f"layers.{layer_id}" - ) - model.layers.register_module(layer_id, transformer_block) - - logger.info(f"Applied {ac_config.mode} activation checkpointing to the model") - - def apply_compile(model: nn.Module): """ Apply torch.compile to each TransformerBlock, which makes compilation efficient due to