Add paged attention into DeepSeek v3 #1843

Merged
merged 1 commit on Jun 23, 2025

3 changes: 3 additions & 0 deletions MaxText/configs/base.yml
@@ -692,6 +692,9 @@ pagedattn_num_pages: 64 # total number of pages to allocate
pagedattn_tokens_per_page: 32 # number of tokens each page can hold
pagedattn_pages_per_compute_block: 4 # number of pages processed together in pallas kernels
pagedattn_max_pages_per_group: -1 # defaults to number of pages needed to reach max_target_length
# Align head_dim to the nearest multiple of this value; set to 0 to disable alignment.
# On TPUs, head_dim is padded to the nearest multiple of 128.
pagedattn_head_dim_alignment: 128


# Chunked Prefill Parameters
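For reference, the alignment option maps to a simple round-up to the next multiple, mirroring the expression this PR adds in MaxText/layers/attentions.py below. A minimal standalone sketch, where the 192/128 head dims are illustrative MLA-style values rather than anything read from the config:

def align_head_dim(head_dim: int, alignment: int) -> int:
  """Round head_dim up to the nearest multiple of alignment; 0 or less disables alignment."""
  if alignment <= 0:
    return head_dim
  return (head_dim + alignment - 1) // alignment * alignment

# Illustrative values only: MLA-style qk_head_dim=192 and v_head_dim=128 round up to a
# shared head dim of 256 under the default pagedattn_head_dim_alignment of 128.
assert align_head_dim(max(192, 128), 128) == 256
assert align_head_dim(128, 128) == 128
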
17 changes: 17 additions & 0 deletions MaxText/inference/paged_attention.py
@@ -86,6 +86,22 @@ def init_or_get_kv_pages(self, model_mode: str):
value_pages_var.value = nn.with_logical_constraint(value_pages_var.value, self.kv_pages_axis_names)
return key_pages_var, value_pages_var

  def pad_qkv(self, *qkv):
    """Pad input to kv_head_dim_size"""
    def pad_to_kv_head_dim_size(x):
      if x.shape[-1] != self.kv_head_dim_size:
        return jnp.pad(
            x,
            ((0, 0), (0, 0), (0, 0), (0, self.kv_head_dim_size - x.shape[-1])),
            mode="constant",
            constant_values=0.0,
        )
      else:
        return x

    # Align Q, K, V to the same head dim. This is required by the kernel.
    return tuple(pad_to_kv_head_dim_size(x) for x in qkv)

def paged_dot_product_attention_with_max_and_sum(self, query, key, value):
"""paged dot product attention with max & sum"""
b, t, n, d = query.shape
@@ -237,6 +253,7 @@ def __call__(
are None for autoregressive mode (handled by paged_attention kernel)
"""
key_pages_var, value_pages_var = self.init_or_get_kv_pages(model_mode)
query, key, value = self.pad_qkv(query, key, value)

# update kv pages and call page attention kernel
if model_mode == MODEL_MODE_PREFILL:
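The new pad_qkv helper zero-pads the last (head) dimension of query, key, and value up to kv_head_dim_size so all three match what the paged attention kernel expects. A standalone sketch of that behavior (not the class method itself; the shapes and the 256 target head dim are toy assumptions):

import jax.numpy as jnp

def pad_to_head_dim(x, kv_head_dim_size):
  """Zero-pad the last (head) dimension of x up to kv_head_dim_size."""
  if x.shape[-1] == kv_head_dim_size:
    return x
  pad = kv_head_dim_size - x.shape[-1]
  return jnp.pad(x, ((0, 0), (0, 0), (0, 0), (0, pad)), mode="constant", constant_values=0.0)

# Toy [batch, seq, heads, head_dim] tensors; 256 stands in for an aligned kv_head_dim_size.
q = jnp.zeros((1, 8, 16, 192))  # e.g. an MLA query/key head dim
v = jnp.zeros((1, 8, 16, 128))  # e.g. an MLA value head dim
assert pad_to_head_dim(q, 256).shape == (1, 8, 16, 256)
assert pad_to_head_dim(v, 256).shape == (1, 8, 16, 256)
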
35 changes: 33 additions & 2 deletions MaxText/layers/attentions.py
@@ -1821,6 +1821,30 @@ def setup(self):
mscale = 0.1 * self.mscale * jnp.log(self.rope_factor) + 1.0
self.softmax_scale = self.softmax_scale * mscale * mscale

    # Setup paged attention op
    if self.config.attention == "paged":
      # Set head_dim to the max of qk_head_dim and v_head_dim. The current paged
      # attention kernel requires the head_dim to be the same for q, k, v.
      head_dim = max(self.qk_head_dim, self.v_head_dim)
      # Align head_dim to the pagedattn_head_dim_alignment if specified.
      if self.config.pagedattn_head_dim_alignment > 0:
        alignment = self.config.pagedattn_head_dim_alignment
        head_dim = (head_dim + alignment - 1) // alignment * alignment
      self.ds_paged_attention_op = paged_attention.PagedAttentionOp(
          mesh=self.mesh,
          num_pages=self.config.pagedattn_num_pages,
          tokens_per_page=self.config.pagedattn_tokens_per_page,
          max_pages_per_slot=(self.config.max_target_length + self.config.pagedattn_tokens_per_page - 1)
          // self.config.pagedattn_tokens_per_page,
          max_pages_per_prefill=(self.config.max_prefill_predict_length + self.config.pagedattn_tokens_per_page - 1)
          // self.config.pagedattn_tokens_per_page,
          pages_per_compute_block=self.config.pagedattn_pages_per_compute_block,
          num_kv_heads=self.num_kv_heads,
          kv_head_dim_size=head_dim,
          dtype=self.dtype,
          attn_logits_soft_cap=self.attn_logits_soft_cap,
      )

def mla_query_projection(self, inputs_q: Array, inputs_positions: Array, model_mode) -> Array:
"""Query projection for MLA, e.g. includes LoRA if q_lora_rank > 0."""
if self.q_lora_rank == 0:
@@ -1907,7 +1931,7 @@ def mla_kv_projection(self, inputs: Array, inputs_positions: Array, decoder_segm

key, value = self.mla_get_key_value(low_rank_main, key_rope, model_mode)
cached_values = [None, None]
if model_mode != MODEL_MODE_TRAIN:
if self.config.attention != "paged" and model_mode != MODEL_MODE_TRAIN:
if self.config.mla_naive_kvcache:
cached_values = self.update_kv_caches(key, value, decoder_segment_ids, model_mode, previous_chunk)
else:
@@ -1956,7 +1980,14 @@ def __call__(
key = checkpoint_name(key, "key_proj")
value = checkpoint_name(value, "value_proj")

out = self.attention_op(query, key, value, decoder_segment_ids, model_mode, cached_values)
if self.config.attention == "paged" and model_mode != MODEL_MODE_TRAIN:
unnormalized_out, _, exp_sum = self.ds_paged_attention_op(
query, key, value, decoder_segment_ids, model_mode, previous_chunk, slot=slot, page_state=page_state
)
unnormalized_out = unnormalized_out[..., :self.v_head_dim]
out = unnormalized_out / (exp_sum + 1e-9) if exp_sum is not None else unnormalized_out
else:
out = self.attention_op(query, key, value, decoder_segment_ids, model_mode, cached_values)
out = nn.with_logical_constraint(out, self.out_axis_names)
out = self.out_projection(inputs_q.shape[-1], out)
return out
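In the paged branch of MLA __call__ above, ds_paged_attention_op returns an unnormalized attention output along with (among other values) exp_sum, the softmax denominator, so the caller slices the padded head dim back to v_head_dim and then normalizes. A toy sketch of that post-processing; all shapes, including the trailing-1 exp_sum shape, are illustrative assumptions:

import jax.numpy as jnp

v_head_dim = 128                              # illustrative MLA value head dim
unnormalized_out = jnp.ones((1, 1, 16, 256))  # [batch, seq, heads, padded_head_dim]
exp_sum = 4.0 * jnp.ones((1, 1, 16, 1))       # softmax denominator from the kernel

out = unnormalized_out[..., :v_head_dim]      # drop the zero-padded tail of the head dim
out = out / (exp_sum + 1e-9)                  # normalize; every element becomes ~0.25 here
assert out.shape == (1, 1, 16, v_head_dim)
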
53 changes: 44 additions & 9 deletions MaxText/layers/deepseek.py
@@ -34,13 +34,25 @@
from MaxText.layers import moe
from MaxText.layers import quantizations
from MaxText.layers.quantizations import AqtQuantization as Quant
from MaxText.inference import page_manager

# -----------------------------------------
# The Decoder Layer for DeepSeek v3
# -----------------------------------------


def self_attention_with_norm(inputs, cfg, mesh, quant, decoder_segment_ids, decoder_positions, deterministic, model_mode):
def self_attention_with_norm(
inputs,
cfg,
mesh,
quant,
decoder_segment_ids,
decoder_positions,
deterministic,
model_mode,
previous_chunk = None,
page_state: Optional[page_manager.PageState] = None,
slot: Optional[int] = None):
"""self-attention with normalization"""
# Normalization
lnx_rms = models.RMSNorm(
@@ -86,6 +98,9 @@ def self_attention_with_norm(inputs, cfg, mesh, quant, decoder_segment_ids, deco
decoder_segment_ids=decoder_segment_ids,
deterministic=deterministic,
model_mode=model_mode,
previous_chunk=previous_chunk,
page_state=page_state,
slot=slot,
)

attention_lnx = nn.with_logical_constraint(
@@ -139,16 +154,26 @@ def __call__(
decoder_positions,
deterministic,
model_mode,
previous_chunk=None,
page_state=None,
slot=None,
previous_chunk = None,
page_state: Optional[page_manager.PageState] = None,
slot: Optional[int] = None,
):
cfg = self.config
inputs = nn.with_logical_constraint(inputs, ("activation_batch", "activation_norm_length", "activation_embed"))
inputs = checkpoint_name(inputs, "decoder_layer_input")

hidden_states, intermediate_inputs = self_attention_with_norm(
inputs, cfg, self.mesh, self.quant, decoder_segment_ids, decoder_positions, deterministic, model_mode
inputs,
cfg,
self.mesh,
self.quant,
decoder_segment_ids,
decoder_positions,
deterministic,
model_mode,
previous_chunk,
page_state,
slot,
)
mlp_lnx = linears.MlpBlock(
intermediate_dim=cfg.mlp_dim,
@@ -189,16 +214,26 @@ def __call__(
decoder_positions,
deterministic,
model_mode,
previous_chunk=None,
page_state=None,
slot=None,
previous_chunk = None,
page_state: Optional[page_manager.PageState] = None,
slot: Optional[int] = None,
):
cfg = self.config
inputs = nn.with_logical_constraint(inputs, ("activation_batch", "activation_norm_length", "activation_embed"))
inputs = checkpoint_name(inputs, "decoder_layer_input")

hidden_states, intermediate_inputs = self_attention_with_norm(
inputs, self.config, self.mesh, self.quant, decoder_segment_ids, decoder_positions, deterministic, model_mode
inputs,
self.config,
self.mesh,
self.quant,
decoder_segment_ids,
decoder_positions,
deterministic,
model_mode,
previous_chunk,
page_state,
slot,
)

# NOTE: the naming mismatch here is to ensure reverse compatibility with existing checkpoints.
9 changes: 8 additions & 1 deletion MaxText/layers/models.py
@@ -552,15 +552,22 @@ def __call__(
if cfg.scan_layers:
if cfg.decoder_block == DecoderBlockType.DEEPSEEK:
assert len(RemattedBlockLayers) == 2, "Scanned layers must have a length of 2 using deepseek."
layer_call_kwargs = {
"page_state": page_state,
"previous_chunk": previous_chunk,
"slot": slot,
}
dense_layer = RemattedBlockLayers[0]
moe_layer = RemattedBlockLayers[1]
dense_layer.__call__ = functools.partial(dense_layer.__call__, **layer_call_kwargs)
y, _ = self.scan_decoder_layers(cfg, dense_layer, cfg.first_num_dense_layers, "dense_layers", mesh)(
y,
decoder_segment_ids,
decoder_positions,
deterministic,
model_mode,
)
moe_layer = RemattedBlockLayers[1]
moe_layer.__call__ = functools.partial(moe_layer.__call__, **layer_call_kwargs)
num_moe_layers = cfg.num_decoder_layers - cfg.first_num_dense_layers
y, _ = self.scan_decoder_layers(cfg, moe_layer, num_moe_layers, "moe_layers", mesh)(
y,
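The functools.partial trick above binds page_state, previous_chunk, and slot onto each layer's __call__ before the layers are scanned, since the scanned call site only forwards the standard positional arguments. A minimal standalone sketch of that pattern (the function names below are stand-ins, not MaxText APIs):

import functools

def layer_call(x, deterministic, model_mode, previous_chunk=None, page_state=None, slot=None):
  # Stand-in for a decoder layer __call__ that accepts the new keyword arguments.
  return (x, model_mode, page_state, slot)

def scan_like(fn, x):
  # Stand-in for scan_decoder_layers: it forwards only the positional arguments.
  return fn(x, False, "autoregressive")

bound = functools.partial(layer_call, page_state="page_state", slot=3)
assert scan_like(bound, "tokens") == ("tokens", "autoregressive", "page_state", 3)
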