[Ring Attention] Add more detailed references #6294

Open · wants to merge 2 commits into base: main

colossalai/shardformer/layer/attn.py (15 changes: 10 additions & 5 deletions)
@@ -406,13 +406,18 @@ def _rescale_out_lse(out, block_out, lse, block_lse):
 class RingAttention(torch.autograd.Function):
     """Implements the Ring Attention from `Ring Attention with Blockwise Transformers for Near-Infinite Context`
     (https://arxiv.org/abs/2310.01889).
-    For load-balancing we adopted the "zigzag" attention scheme from https://github.com/zhuzilin/ring-flash-attention/tree/main
-    For portable integration with more models, we don't follow the spirit of "block-wise FNN" in the original paper,
-    which requires fusing FFN with the Flash Attention kernel/function (see https://arxiv.org/pdf/2305.19370;
-    implemented in Jax and not optimized).
-    We adopt the double ring topology from LoongTrain (https://arxiv.org/pdf/2406.18485) to fully utilize available
+    For load-balancing, we adopt the "zigzag" data-loading scheme from ring-flash-attention.
+    We also adopt the double ring topology from LoongTrain to fully utilize available
     NICs on each node, by computing attention within an inner ring first and then sending all KVs to the next
     ring at once.
+    Our implementation references code from:
+    - ring-flash-attention: https://github.com/zhuzilin/ring-flash-attention/tree/main
+    - Megatron Context Parallel: https://github.com/NVIDIA/TransformerEngine/pull/726
+    References:
+    - Ring Attention with Blockwise Transformers for Near-Infinite Context
+      https://arxiv.org/abs/2310.01889
+    - LoongTrain: Efficient Training of Long-Sequence LLMs with Head-Context Parallelism
+      https://arxiv.org/abs/2406.18485
     """

     # Global cache to avoid recomputation for same-length sequences
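
For context, here is a minimal sketch of the "zigzag" load-balancing split referenced in the docstring above. The helper name `zigzag_split` and its signature are illustrative assumptions, not ColossalAI's actual API. The idea: under a causal mask, later tokens attend to more keys, so a plain contiguous split overloads the last ranks; cutting the sequence into 2 * world_size chunks and giving rank i chunks i and 2 * world_size - 1 - i pairs a cheap chunk with an expensive one.

```python
# Illustrative sketch only; the real dispatch in ring-flash-attention / ColossalAI
# also handles batching, padding, and the matching re-assembly of outputs.
import torch


def zigzag_split(seq: torch.Tensor, rank: int, world_size: int, dim: int = 0) -> torch.Tensor:
    """Return the two sequence chunks assigned to `rank` under the zigzag scheme."""
    chunks = seq.chunk(2 * world_size, dim=dim)
    # Pair an early (cheap) chunk with its mirror-image late (expensive) chunk.
    return torch.cat([chunks[rank], chunks[2 * world_size - 1 - rank]], dim=dim)


# Toy check with 4 ranks and 8 tokens: every rank holds one early and one late
# token, so per-rank causal-attention work is roughly equal.
tokens = torch.arange(8)
for r in range(4):
    print(r, zigzag_split(tokens, r, world_size=4).tolist())
# 0 [0, 7]
# 1 [1, 6]
# 2 [2, 5]
# 3 [3, 4]
```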
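
Likewise, a rough, pure-Python illustration of the double-ring idea described in the docstring. This is a simulation under assumptions, not the shardformer implementation, and `double_ring_schedule` is a hypothetical name: ranks are grouped into inner rings of size `inner_ring_size` (for example, one ring per node), KV shards first rotate within each inner ring, and the handoff to the next ring happens once per outer step rather than at every step.

```python
# Pure-Python sketch of which KV shard each rank processes at each step of a
# double-ring schedule; no communication is performed here.
def double_ring_schedule(world_size: int, inner_ring_size: int):
    """Yield (outer_step, inner_step, rank, kv_owner): the rank whose KV shard
    `rank` attends to at that step."""
    num_rings = world_size // inner_ring_size
    for outer in range(num_rings):            # hop across inner rings (cross-node)
        for inner in range(inner_ring_size):  # rotate inside the inner ring (intra-node)
            for rank in range(world_size):
                ring_id, pos = divmod(rank, inner_ring_size)
                src_ring = (ring_id - outer) % num_rings
                src_pos = (pos - inner) % inner_ring_size
                yield outer, inner, rank, src_ring * inner_ring_size + src_pos


# With 8 ranks split into 2 inner rings of 4, every rank sees all 8 KV shards
# after 2 outer steps x 4 inner steps; cross-ring traffic is batched into one
# handoff per outer step instead of occurring at every inner step.
seen = {}
for outer, inner, rank, owner in double_ring_schedule(8, 4):
    seen.setdefault(rank, set()).add(owner)
assert all(len(owners) == 8 for owners in seen.values())
```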