Activation Checkpoint improvement #1645
Merged
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# This file provides the utility functions to apply activation checkpointing to the model.
# Technically, this is not part of distributed, but the distributed module is the best place to put it.

from collections import defaultdict

import torch
import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper as ptd_checkpoint_wrapper,
)

from torchtitan.config.job_config import ActivationCheckpoint as ACConfig
from torchtitan.tools.logging import logger

# for selective op activation checkpointing
_save_list = {
    torch.ops.aten.mm.default,
    torch.ops.aten._scaled_dot_product_efficient_attention.default,
    torch.ops.aten._scaled_dot_product_flash_attention.default,
    torch.ops._c10d_functional.reduce_scatter_tensor.default,
    # for low precision training, it's useful to always save
    # the result of max, since the absolute maximum is
    # used to compute the scaling factor for quantization.
    torch.ops.aten.max.default,
    torch._higher_order_ops.flex_attention,
}


def _apply_ac_to_transformer_block(
    module: nn.Module, ac_config: ACConfig, *, base_fqn: str | None = None
):
    valid_ac_modes = ("full", "selective")
    if ac_config.mode not in valid_ac_modes:
        raise ValueError(
            f"Invalid AC mode: {ac_config.mode}. Valid modes: {valid_ac_modes}"
        )

    if ac_config.mode == "full":
        return ptd_checkpoint_wrapper(module, preserve_rng_state=False)

    assert ac_config.mode == "selective", f"{ac_config.mode}"
    use_op_sac = ac_config.selective_ac_option == "op"
    use_layer_sac = ac_config.selective_ac_option.isdigit()
    if not use_op_sac and not use_layer_sac:
        raise ValueError(
            f"Invalid selective AC option: {ac_config.selective_ac_option}. "
            f"Valid options: 'op' or a positive int representing layer frequency"
        )
    if use_op_sac:
        from torch.utils.checkpoint import (
            CheckpointPolicy,
            create_selective_checkpoint_contexts,
        )

        mm_recompute_shapes = set()
        if len(ac_config.per_op_sac_force_recompute_mm_shapes_by_fqns) > 0:
            for module_fqn, submod in module.named_modules():
                fqn = module_fqn
                if base_fqn is not None:
                    fqn = f"{base_fqn}.{module_fqn}"
                if not any(
                    filter_fqn in fqn
                    for filter_fqn in ac_config.per_op_sac_force_recompute_mm_shapes_by_fqns
                ):
                    continue
                if not isinstance(submod, nn.Linear):
                    raise ValueError(
                        "per_op_sac_force_recompute_mm_shapes_by_fqns expected to match "
                        f"a nn.Linear, but got: {submod}"
                    )
                out_f, in_f = submod.weight.shape
                mm_recompute_shapes.add((in_f, out_f))
            logger.debug(
                f"Selective op AC force recomputing mms with rhs shapes {mm_recompute_shapes}"
            )

        def _get_custom_policy(meta):
            def _custom_policy(ctx, func, *args, **kwargs):
                mode = "recompute" if ctx.is_recompute else "forward"
                mm_count_key = f"{mode}_mm_count"
                if func == torch.ops.aten.mm.default:
                    if args[1].shape in mm_recompute_shapes:
                        return CheckpointPolicy.PREFER_RECOMPUTE
                    meta[mm_count_key] += 1
                # Save the output of ops in _save_list, except every second mm
                to_save = func in _save_list and not (
                    func == torch.ops.aten.mm.default and meta[mm_count_key] % 2 == 0
                )
                return (
                    CheckpointPolicy.MUST_SAVE
                    if to_save
                    else CheckpointPolicy.PREFER_RECOMPUTE
                )

            return _custom_policy

        def selective_checkpointing_context_fn():
            meta = defaultdict(int)
            return create_selective_checkpoint_contexts(_get_custom_policy(meta))

        return ptd_checkpoint_wrapper(
            module,
            context_fn=selective_checkpointing_context_fn,
            preserve_rng_state=False,
        )
    elif use_layer_sac:
        # Checkpoint every `ac_freq`-th of the modules passed to this function
        ac_freq = int(ac_config.selective_ac_option)
        ptd_checkpoint_wrapper.__dict__.setdefault("_count", 0)
        ptd_checkpoint_wrapper._count += 1
        if not ac_freq or ptd_checkpoint_wrapper._count % ac_freq == 0:
            return ptd_checkpoint_wrapper(module, preserve_rng_state=False)
        else:
            return module


def apply_ac(
    model: nn.Module,
    ac_config: ACConfig,
    model_compile_enabled: bool,
    use_flex_attn: bool,
):
    """Apply activation checkpointing to the model.

    Note that SAC, Flex Attention and model compilation have some conflicts.
    We explicitly ask the user to pass these configs to warn if there are conflicts.
    """

    if use_flex_attn and not model_compile_enabled:
        logger.warning(
            "!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!\n"
            "POTENTIAL PERFORMANCE ISSUE DETECTED:\n"
            "Flex attention requires compilation for optimal performance and will be\n"
            "compiled automatically regardless of config.compile settings. However,\n"
"Selective Activation Checkpointing (SAC) requires compilation to be applied\n" | ||
"at the outermost level (e.g., compile(SAC(model))). Othewise the compilation\n" | ||
"will be ignored." | ||
"\n" | ||
"Without enabling config.compile, the apply order will be:\n" | ||
"SAC(compile(flex_attention)). The compilation of flex_attention will be\n" | ||
"skipped, which results in poor performance.\n" | ||
"\n" | ||
"For best results, enable config.compile to ensure proper compilation order:\n" | ||
"compile(SAC(compile(flex_attention)))\n" | ||
"\n" | ||
"The innermost torch.compile will be ignored, but the outermost will take\n" | ||
"effect and provide optimal performance.\n" | ||
"!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!\n" | ||
) | ||
    for layer_id, transformer_block in model.layers.named_children():
        transformer_block = _apply_ac_to_transformer_block(
            transformer_block, ac_config, base_fqn=f"layers.{layer_id}"
        )
        model.layers.register_module(layer_id, transformer_block)

    logger.info(f"Applied {ac_config.mode} activation checkpointing to the model")
I think it's preferable if we can create a customized list for each individual model if necessary, in addition to some default save_list.
E.g., MoE and dense models may need different save_lists, and it sounds bad if we just mix everything.
This refactor can happen in a separate PR.
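One possible shape for that follow-up, sketched under assumptions rather than taken from this PR or the reviewer's comment: keep the shared default save list and let each model flavor contribute extra ops. The registry, the flavor names, and the MoE op shown are hypothetical.

# Hypothetical sketch of per-model save lists (not part of this PR).
import torch

_default_save_list = {
    torch.ops.aten.mm.default,
    torch.ops.aten._scaled_dot_product_efficient_attention.default,
    torch.ops.aten._scaled_dot_product_flash_attention.default,
    torch.ops._c10d_functional.reduce_scatter_tensor.default,
    torch.ops.aten.max.default,
}

# Per-model extensions; the op chosen for "moe" is illustrative only.
_model_save_lists = {
    "dense": set(),
    "moe": {torch.ops._c10d_functional.all_to_all_single.default},
}


def get_save_list(model_flavor: str) -> set:
    # Union of the shared defaults and whatever the model flavor registers.
    return _default_save_list | _model_save_lists.get(model_flavor, set())

The SAC policy would then consult the returned set instead of the module-level _save_list, so dense and MoE models no longer share one mixed list.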