torch.compile Not Applied to FlexAttention with SAC selective_ac_option=op #1631

@fegin

Bug description
When using SAC with the option selective_ac_option=op, torch.compile is not applied to FlexAttention, even though TorchTitan explicitly calls torch.compile(flex_attention).

As a result, the following warning is emitted and performance degrades severely:

[rank0]:
[rank0]:SOLUTION: Use torch.compile(flex_attention)(...)
[rank0]:
[rank0]:If you want to debug your score_mod/mask_mod, you can set:
[rank0]:torch.nn.attention.flex_attention._FLEX_ATTENTION_DISABLE_COMPILE_DEBUG = True
[rank0]:
[rank0]:This will allow you to use print statements or breakpoints. Note: This doesn't work with the backwards pass and may produce incorrect results.

Steps to Reproduce

  1. Use the default DeepSeek 16B configuration, or
  2. Use the debug model with the flag --model.flavor=debugmodel_flex_attn.

Additional Information

  • This issue does not occur when:
    • No Activation Checkpointing (AC) is used,
    • Full Activation Checkpointing is used,
    • SAC is used with selective_ac_option=2.
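For reference, these AC modes correspond roughly to the following TorchTitan TOML settings (field names per torchtitan's job config; treat this as a sketch, not an exhaustive reference):

```toml
[activation_checkpoint]
# "none" reproduces the no-AC case and "full" the full-AC case; neither
# triggers the bug.
mode = "selective"
# "op" triggers the issue; an integer value such as "2" (checkpoint every
# 2nd layer) does not.
selective_ac_option = "op"
```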

Versions

Nightly
