Commit 7009be2

Author: maxtext authors (committed)
Add paged attention into DeepSeek v3
This change integrates paged attention into the DeepSeek v3 model by:
1. Adding a paged attention op to the Attention layer.
2. Updating the DeepSeek v3 model so that paged attention can be used in inference mode.
3. Making sure page_state can be passed into the paged_attention kernel in the scan scenario.
4. Adding pagedattn_head_dim_alignment to align the QKV head dim to the nearest multiple of 128, which is required for running on TPUs.

PiperOrigin-RevId: 771582588
1 parent 1300eeb commit 7009be2

File tree

5 files changed (+107, -14 lines)

MaxText/configs/base.yml

Lines changed: 3 additions & 0 deletions
@@ -692,6 +692,9 @@ pagedattn_num_pages: 64 # total number of pages to allocate
 pagedattn_tokens_per_page: 32 # number of tokens each page can hold
 pagedattn_pages_per_compute_block: 4 # number of pages processed together in pallas kernels
 pagedattn_max_pages_per_group: -1 # defaults to number of pages needed to reach max_target_length
+# Alignment of head_dim to the nearest multiple of this value, set to 0 to disable alignment. On
+# TPUs, the head_dim is padded to the nearest multiple of 128.
+pagedattn_head_dim_alignment: 128


 # Chunked Prefill Parameters
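Note: the new pagedattn_head_dim_alignment option rounds the paged-attention head dim up to the nearest multiple of its value. A minimal sketch of that round-up arithmetic (the align_head_dim helper and the 192-wide example are illustrative, not part of the commit):

def align_head_dim(head_dim: int, alignment: int) -> int:
  """Round head_dim up to the nearest multiple of alignment; 0 or less disables alignment."""
  if alignment <= 0:
    return head_dim
  return (head_dim + alignment - 1) // alignment * alignment

# With the default alignment of 128, a hypothetical 192-wide head is padded to 256,
# while an already-aligned 128-wide head is left unchanged.
assert align_head_dim(192, 128) == 256
assert align_head_dim(128, 128) == 128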

MaxText/inference/paged_attention.py

Lines changed: 19 additions & 2 deletions
@@ -86,6 +86,22 @@ def init_or_get_kv_pages(self, model_mode: str):
       value_pages_var.value = nn.with_logical_constraint(value_pages_var.value, self.kv_pages_axis_names)
     return key_pages_var, value_pages_var

+  def pad_qkv(self, *qkv):
+    """Pad input to kv_head_dim_size"""
+    def pad_to_kv_head_dim_size(x):
+      if x.shape[-1] != self.kv_head_dim_size:
+        return jnp.pad(
+            x,
+            ((0, 0), (0, 0), (0, 0), (0, self.kv_head_dim_size - x.shape[-1])),
+            mode="constant",
+            constant_values=0.0,
+        )
+      else:
+        return x
+
+    # Align Q, K, V to the same head dim. This is required by the kernel.
+    return tuple(pad_to_kv_head_dim_size(x) for x in qkv)
+
   def paged_dot_product_attention_with_max_and_sum(self, query, key, value):
     """paged dot product attention with max & sum"""
     b, t, n, d = query.shape
@@ -180,8 +196,8 @@ def paged_attention_v1_decode(
     page_state: page_manager.PageState,
 ) -> Array:
   """Apply Paged Attention v1 in decode only."""
-  kv_pages_pspec = nn.logical_to_mesh_axes(("paged_kv_heads", None, None, None))
-  q_pspec = nn.logical_to_mesh_axes((None, None, "paged_kv_heads", None))
+  kv_pages_pspec = nn.logical_to_mesh_axes((None, None, None, None))
+  q_pspec = nn.logical_to_mesh_axes((None, None, None, None))

   @functools.partial(
       shard_map,
@@ -237,6 +253,7 @@ def __call__(
         are None for autoregressive mode (handled by paged_attention kernel)
     """
     key_pages_var, value_pages_var = self.init_or_get_kv_pages(model_mode)
+    query, key, value = self.pad_qkv(query, key, value)

     # update kv pages and call page attention kernel
     if model_mode == MODEL_MODE_PREFILL:
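The new pad_qkv method zero-pads the trailing head dimension of Q, K, and V up to kv_head_dim_size so all three match what the kernel expects. A small standalone sketch of that behavior (the pad_last_dim helper and the shapes are illustrative, not copied from the module):

import jax.numpy as jnp

def pad_last_dim(x, target):
  """Zero-pad the last (head) dimension of a [batch, seq, heads, head_dim] array up to target."""
  if x.shape[-1] == target:
    return x
  return jnp.pad(x, ((0, 0), (0, 0), (0, 0), (0, target - x.shape[-1])))

q = jnp.ones((1, 8, 16, 192))  # hypothetical query with a 192-wide head dim
k = jnp.ones((1, 8, 16, 192))
v = jnp.ones((1, 8, 16, 128))  # value head dim differs, as in MLA
q, k, v = (pad_last_dim(x, 256) for x in (q, k, v))
print(q.shape, k.shape, v.shape)  # all (1, 8, 16, 256)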

MaxText/layers/attentions.py

Lines changed: 33 additions & 2 deletions
@@ -1821,6 +1821,30 @@ def setup(self):
       mscale = 0.1 * self.mscale * jnp.log(self.rope_factor) + 1.0
       self.softmax_scale = self.softmax_scale * mscale * mscale

+    # Setup paged attention op
+    if self.config.attention == "paged":
+      # Set head_dim to the max of qk_head_dim and v_head_dim. The current paged
+      # attention kernel requires the head_dim to be the same for q, k, v.
+      head_dim = max(self.qk_head_dim, self.v_head_dim)
+      # Align head_dim to the pagedattn_head_dim_alignment if specified.
+      if self.config.pagedattn_head_dim_alignment > 0:
+        alignment = self.config.pagedattn_head_dim_alignment
+        head_dim = (head_dim + alignment - 1) // alignment * alignment
+      self.ds_paged_attention_op = paged_attention.PagedAttentionOp(
+          mesh=self.mesh,
+          num_pages=self.config.pagedattn_num_pages,
+          tokens_per_page=self.config.pagedattn_tokens_per_page,
+          max_pages_per_slot=(self.config.max_target_length + self.config.pagedattn_tokens_per_page - 1)
+          // self.config.pagedattn_tokens_per_page,
+          max_pages_per_prefill=(self.config.max_prefill_predict_length + self.config.pagedattn_tokens_per_page - 1)
+          // self.config.pagedattn_tokens_per_page,
+          pages_per_compute_block=self.config.pagedattn_pages_per_compute_block,
+          num_kv_heads=self.num_kv_heads,
+          kv_head_dim_size=head_dim,
+          dtype=self.dtype,
+          attn_logits_soft_cap=self.attn_logits_soft_cap,
+      )
+
   def mla_query_projection(self, inputs_q: Array, inputs_positions: Array, model_mode) -> Array:
     """Query projection for MLA, e.g. includes LoRA if q_lora_rank > 0."""
     if self.q_lora_rank == 0:
@@ -1907,7 +1931,7 @@ def mla_kv_projection(self, inputs: Array, inputs_positions: Array, decoder_segm

     key, value = self.mla_get_key_value(low_rank_main, key_rope, model_mode)
     cached_values = [None, None]
-    if model_mode != MODEL_MODE_TRAIN:
+    if self.config.attention != "paged" and model_mode != MODEL_MODE_TRAIN:
       if self.config.mla_naive_kvcache:
         cached_values = self.update_kv_caches(key, value, decoder_segment_ids, model_mode, previous_chunk)
       else:
@@ -1956,7 +1980,14 @@ def __call__(
     key = checkpoint_name(key, "key_proj")
     value = checkpoint_name(value, "value_proj")

-    out = self.attention_op(query, key, value, decoder_segment_ids, model_mode, cached_values)
+    if self.config.attention == "paged" and model_mode != MODEL_MODE_TRAIN:
+      unnormalized_out, _, exp_sum = self.ds_paged_attention_op(
+          query, key, value, decoder_segment_ids, model_mode, previous_chunk, slot=slot, page_state=page_state
+      )
+      unnormalized_out = unnormalized_out[..., :self.v_head_dim]
+      out = unnormalized_out / (exp_sum + 1e-9) if exp_sum is not None else unnormalized_out
+    else:
+      out = self.attention_op(query, key, value, decoder_segment_ids, model_mode, cached_values)
     out = nn.with_logical_constraint(out, self.out_axis_names)
     out = self.out_projection(inputs_q.shape[-1], out)
     return out
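In decode mode the paged-attention op returns an unnormalized weighted sum together with a sum of attention-weight exponentials, so the caller slices the padded head dim back to v_head_dim and divides by exp_sum. A back-of-the-envelope sketch of that post-processing (the shapes below are illustrative assumptions, not the op's documented contract):

import jax.numpy as jnp

# Hypothetical decode-step output: [batch, 1, heads, padded_head_dim] plus a
# per-query softmax denominator of shape [batch, 1, heads, 1].
unnormalized_out = jnp.ones((1, 1, 16, 256))
exp_sum = jnp.full((1, 1, 16, 1), 4.0)
v_head_dim = 128

out = unnormalized_out[..., :v_head_dim]  # drop the zero-padded tail of the head dim
out = out / (exp_sum + 1e-9)              # normalize by the softmax denominator
print(out.shape)                          # (1, 1, 16, 128)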

MaxText/layers/deepseek.py

Lines changed: 44 additions & 9 deletions
@@ -34,13 +34,25 @@
 from MaxText.layers import moe
 from MaxText.layers import quantizations
 from MaxText.layers.quantizations import AqtQuantization as Quant
+from MaxText.inference import page_manager

 # -----------------------------------------
 # The Decoder Layer for DeepSeek v3
 # -----------------------------------------


-def self_attention_with_norm(inputs, cfg, mesh, quant, decoder_segment_ids, decoder_positions, deterministic, model_mode):
+def self_attention_with_norm(
+    inputs,
+    cfg,
+    mesh,
+    quant,
+    decoder_segment_ids,
+    decoder_positions,
+    deterministic,
+    model_mode,
+    previous_chunk = None,
+    page_state: Optional[page_manager.PageState] = None,
+    slot: Optional[int] = None):
   """self-attention with normalization"""
   # Normalization
   lnx_rms = models.RMSNorm(
@@ -86,6 +98,9 @@ def self_attention_with_norm(inputs, cfg, mesh, quant, decoder_segment_ids, deco
       decoder_segment_ids=decoder_segment_ids,
       deterministic=deterministic,
       model_mode=model_mode,
+      previous_chunk=previous_chunk,
+      page_state=page_state,
+      slot=slot,
   )

   attention_lnx = nn.with_logical_constraint(
@@ -139,16 +154,26 @@ def __call__(
       decoder_positions,
       deterministic,
       model_mode,
-      previous_chunk=None,
-      page_state=None,
-      slot=None,
+      previous_chunk = None,
+      page_state: Optional[page_manager.PageState] = None,
+      slot: Optional[int] = None,
   ):
     cfg = self.config
     inputs = nn.with_logical_constraint(inputs, ("activation_batch", "activation_norm_length", "activation_embed"))
     inputs = checkpoint_name(inputs, "decoder_layer_input")

     hidden_states, intermediate_inputs = self_attention_with_norm(
-        inputs, cfg, self.mesh, self.quant, decoder_segment_ids, decoder_positions, deterministic, model_mode
+        inputs,
+        cfg,
+        self.mesh,
+        self.quant,
+        decoder_segment_ids,
+        decoder_positions,
+        deterministic,
+        model_mode,
+        previous_chunk,
+        page_state,
+        slot,
     )
     mlp_lnx = linears.MlpBlock(
         intermediate_dim=cfg.mlp_dim,
@@ -189,16 +214,26 @@ def __call__(
       decoder_positions,
       deterministic,
       model_mode,
-      previous_chunk=None,
-      page_state=None,
-      slot=None,
+      previous_chunk = None,
+      page_state: Optional[page_manager.PageState] = None,
+      slot: Optional[int] = None,
   ):
     cfg = self.config
     inputs = nn.with_logical_constraint(inputs, ("activation_batch", "activation_norm_length", "activation_embed"))
     inputs = checkpoint_name(inputs, "decoder_layer_input")

     hidden_states, intermediate_inputs = self_attention_with_norm(
-        inputs, self.config, self.mesh, self.quant, decoder_segment_ids, decoder_positions, deterministic, model_mode
+        inputs,
+        self.config,
+        self.mesh,
+        self.quant,
+        decoder_segment_ids,
+        decoder_positions,
+        deterministic,
+        model_mode,
+        previous_chunk,
+        page_state,
+        slot,
     )

     # NOTE: the naming mismatch here is to ensure reverse compatibility with existing checkpoints.

MaxText/layers/models.py

Lines changed: 8 additions & 1 deletion
@@ -552,15 +552,22 @@ def __call__(
     if cfg.scan_layers:
       if cfg.decoder_block == DecoderBlockType.DEEPSEEK:
         assert len(RemattedBlockLayers) == 2, "Scanned layers must have a length of 2 using deepseek."
+        layer_call_kwargs = {
+            "page_state": page_state,
+            "previous_chunk": previous_chunk,
+            "slot": slot,
+        }
         dense_layer = RemattedBlockLayers[0]
-        moe_layer = RemattedBlockLayers[1]
+        dense_layer.__call__ = functools.partial(dense_layer.__call__, **layer_call_kwargs)
         y, _ = self.scan_decoder_layers(cfg, dense_layer, cfg.first_num_dense_layers, "dense_layers", mesh)(
             y,
             decoder_segment_ids,
             decoder_positions,
             deterministic,
             model_mode,
         )
+        moe_layer = RemattedBlockLayers[1]
+        moe_layer.__call__ = functools.partial(moe_layer.__call__, **layer_call_kwargs)
         num_moe_layers = cfg.num_decoder_layers - cfg.first_num_dense_layers
         y, _ = self.scan_decoder_layers(cfg, moe_layer, num_moe_layers, "moe_layers", mesh)(
             y,
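Because the scanned layer stack only forwards positional inputs, page_state, previous_chunk, and slot are bound onto each layer's __call__ with functools.partial before scanning. A toy sketch of that pattern (the Layer class and loop below are stand-ins, not the MaxText scan machinery):

import functools

class Layer:
  """Stand-in for a decoder layer whose __call__ accepts optional paging kwargs."""

  def __call__(self, x, *, page_state=None, slot=None):
    # A real layer would run attention here; this just reports what it received.
    return x, (page_state, slot)

layer = Layer()
# Bind the paging arguments once so every scanned step sees them even though
# the scan wrapper only passes positional inputs.
layer.__call__ = functools.partial(layer.__call__, page_state="page_state_0", slot=3)

for step in range(2):  # stand-in for iterating the scanned layers
  y, seen = layer.__call__(step)
  print(y, seen)  # 0 ('page_state_0', 3) then 1 ('page_state_0', 3)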
