Bug description
When using SAC with the option `selective_ac_option=op`, `torch.compile` is not applied to FlexAttention. This happens despite TorchTitan explicitly adding calls to `torch.compile(flex_attention)`. As a result, the following warning is emitted and performance is very low:
```
[rank0]:
[rank0]:SOLUTION: Use torch.compile(flex_attention)(...)
[rank0]:
[rank0]:If you want to debug your score_mod/mask_mod, you can set:
[rank0]:torch.nn.attention.flex_attention._FLEX_ATTENTION_DISABLE_COMPILE_DEBUG = True
[rank0]:
[rank0]:This will allow you to use print statements or breakpoints. Note: This doesn't work with the backwards pass and may produce incorrect results.
```
Steps to Reproduce
- Use the default DeepSeek 16B configuration, or
- use the debug model with the flag `--model.flavor=debugmodel_flex_attn`.
Additional Information
- This issue does not occur when:
  - no Activation Checkpointing (AC) is used,
  - full Activation Checkpointing is used, or
  - SAC is used with `selective_ac_option=2`.
Versions
Nightly