diff --git a/MaxText/configs/base.yml b/MaxText/configs/base.yml
index b01a87454..7b58dfb7e 100644
--- a/MaxText/configs/base.yml
+++ b/MaxText/configs/base.yml
@@ -692,6 +692,9 @@
 pagedattn_num_pages: 64 # total number of pages to allocate
 pagedattn_tokens_per_page: 32 # number of tokens each page can hold
 pagedattn_pages_per_compute_block: 4 # number of pages processed together in pallas kernels
 pagedattn_max_pages_per_group: -1 # defaults to number of pages needed to reach max_target_length
+# Alignment of head_dim to the nearest multiple of this value, set to 0 to disable alignment. On
+# TPUs, the head_dim is padded to the nearest multiple of 128.
+pagedattn_head_dim_alignment: 128
 
 # Chunked Prefill Parameters
diff --git a/MaxText/inference/paged_attention.py b/MaxText/inference/paged_attention.py
index a65da6995..70439ea75 100644
--- a/MaxText/inference/paged_attention.py
+++ b/MaxText/inference/paged_attention.py
@@ -86,6 +86,22 @@ def init_or_get_kv_pages(self, model_mode: str):
     value_pages_var.value = nn.with_logical_constraint(value_pages_var.value, self.kv_pages_axis_names)
     return key_pages_var, value_pages_var
 
+  def pad_qkv(self, *qkv):
+    """Pad input to kv_head_dim_size"""
+    def pad_to_kv_head_dim_size(x):
+      if x.shape[-1] != self.kv_head_dim_size:
+        return jnp.pad(
+            x,
+            ((0, 0), (0, 0), (0, 0), (0, self.kv_head_dim_size - x.shape[-1])),
+            mode="constant",
+            constant_values=0.0,
+        )
+      else:
+        return x
+
+    # Align Q, K, V to the same head dim. This is required by the kernel.
+    return tuple(pad_to_kv_head_dim_size(x) for x in qkv)
+
   def paged_dot_product_attention_with_max_and_sum(self, query, key, value):
     """paged dot product attention with max & sum"""
     b, t, n, d = query.shape
@@ -237,6 +253,7 @@ def __call__(
         are None for autoregressive mode (handled by paged_attention kernel)
     """
     key_pages_var, value_pages_var = self.init_or_get_kv_pages(model_mode)
+    query, key, value = self.pad_qkv(query, key, value)
 
     # update kv pages and call page attention kernel
     if model_mode == MODEL_MODE_PREFILL:
diff --git a/MaxText/layers/attentions.py b/MaxText/layers/attentions.py
index ce3d6a516..fbf538ef8 100644
--- a/MaxText/layers/attentions.py
+++ b/MaxText/layers/attentions.py
@@ -1821,6 +1821,30 @@ def setup(self):
       mscale = 0.1 * self.mscale * jnp.log(self.rope_factor) + 1.0
       self.softmax_scale = self.softmax_scale * mscale * mscale
 
+    # Setup paged attention op
+    if self.config.attention == "paged":
+      # Set head_dim to the max of qk_head_dim and v_head_dim. The current paged
+      # attention kernel requires the head_dim to be the same for q, k, v.
+      head_dim = max(self.qk_head_dim, self.v_head_dim)
+      # Align head_dim to the pagedattn_head_dim_alignment if specified.
+      if self.config.pagedattn_head_dim_alignment > 0:
+        alignment = self.config.pagedattn_head_dim_alignment
+        head_dim = (head_dim + alignment - 1) // alignment * alignment
+      self.ds_paged_attention_op = paged_attention.PagedAttentionOp(
+          mesh=self.mesh,
+          num_pages=self.config.pagedattn_num_pages,
+          tokens_per_page=self.config.pagedattn_tokens_per_page,
+          max_pages_per_slot=(self.config.max_target_length + self.config.pagedattn_tokens_per_page - 1)
+          // self.config.pagedattn_tokens_per_page,
+          max_pages_per_prefill=(self.config.max_prefill_predict_length + self.config.pagedattn_tokens_per_page - 1)
+          // self.config.pagedattn_tokens_per_page,
+          pages_per_compute_block=self.config.pagedattn_pages_per_compute_block,
+          num_kv_heads=self.num_kv_heads,
+          kv_head_dim_size=head_dim,
+          dtype=self.dtype,
+          attn_logits_soft_cap=self.attn_logits_soft_cap,
+      )
+
   def mla_query_projection(self, inputs_q: Array, inputs_positions: Array, model_mode) -> Array:
     """Query projection for MLA, e.g. includes LoRA if q_lora_rank > 0."""
     if self.q_lora_rank == 0:
@@ -1907,7 +1931,7 @@ def mla_kv_projection(self, inputs: Array, inputs_positions: Array, decoder_segm
       key, value = self.mla_get_key_value(low_rank_main, key_rope, model_mode)
 
     cached_values = [None, None]
-    if model_mode != MODEL_MODE_TRAIN:
+    if self.config.attention != "paged" and model_mode != MODEL_MODE_TRAIN:
       if self.config.mla_naive_kvcache:
         cached_values = self.update_kv_caches(key, value, decoder_segment_ids, model_mode, previous_chunk)
       else:
@@ -1956,7 +1980,14 @@ def __call__(
     key = checkpoint_name(key, "key_proj")
     value = checkpoint_name(value, "value_proj")
 
-    out = self.attention_op(query, key, value, decoder_segment_ids, model_mode, cached_values)
+    if self.config.attention == "paged" and model_mode != MODEL_MODE_TRAIN:
+      unnormalized_out, _, exp_sum = self.ds_paged_attention_op(
+          query, key, value, decoder_segment_ids, model_mode, previous_chunk, slot=slot, page_state=page_state
+      )
+      unnormalized_out = unnormalized_out[..., :self.v_head_dim]
+      out = unnormalized_out / (exp_sum + 1e-9) if exp_sum is not None else unnormalized_out
+    else:
+      out = self.attention_op(query, key, value, decoder_segment_ids, model_mode, cached_values)
     out = nn.with_logical_constraint(out, self.out_axis_names)
     out = self.out_projection(inputs_q.shape[-1], out)
     return out
diff --git a/MaxText/layers/deepseek.py b/MaxText/layers/deepseek.py
index 61f409dcc..37623b1ce 100644
--- a/MaxText/layers/deepseek.py
+++ b/MaxText/layers/deepseek.py
@@ -34,13 +34,25 @@
 from MaxText.layers import moe
 from MaxText.layers import quantizations
 from MaxText.layers.quantizations import AqtQuantization as Quant
+from MaxText.inference import page_manager
 
 # -----------------------------------------
 # The Decoder Layer for DeepSeek v3
 # -----------------------------------------
 
 
-def self_attention_with_norm(inputs, cfg, mesh, quant, decoder_segment_ids, decoder_positions, deterministic, model_mode):
+def self_attention_with_norm(
+    inputs,
+    cfg,
+    mesh,
+    quant,
+    decoder_segment_ids,
+    decoder_positions,
+    deterministic,
+    model_mode,
+    previous_chunk=None,
+    page_state: Optional[page_manager.PageState] = None,
+    slot: Optional[int] = None):
   """self-attention with normalization"""
   # Normalization
   lnx_rms = models.RMSNorm(
@@ -86,6 +98,9 @@ def self_attention_with_norm(inputs, cfg, mesh, quant, decoder_segment_ids, deco
       decoder_segment_ids=decoder_segment_ids,
       deterministic=deterministic,
       model_mode=model_mode,
+      previous_chunk=previous_chunk,
+      page_state=page_state,
+      slot=slot,
   )
 
   attention_lnx = nn.with_logical_constraint(
@@ -139,16 +154,26 @@ def __call__(
       decoder_positions,
       deterministic,
       model_mode,
-      previous_chunk=None,
-      page_state=None,
-      slot=None,
+      previous_chunk=None,
+      page_state: Optional[page_manager.PageState] = None,
+      slot: Optional[int] = None,
   ):
     cfg = self.config
 
     inputs = nn.with_logical_constraint(inputs, ("activation_batch", "activation_norm_length", "activation_embed"))
    inputs = checkpoint_name(inputs, "decoder_layer_input")
     hidden_states, intermediate_inputs = self_attention_with_norm(
-        inputs, cfg, self.mesh, self.quant, decoder_segment_ids, decoder_positions, deterministic, model_mode
+        inputs,
+        cfg,
+        self.mesh,
+        self.quant,
+        decoder_segment_ids,
+        decoder_positions,
+        deterministic,
+        model_mode,
+        previous_chunk,
+        page_state,
+        slot,
     )
     mlp_lnx = linears.MlpBlock(
         intermediate_dim=cfg.mlp_dim,
@@ -189,16 +214,26 @@ def __call__(
       decoder_positions,
       deterministic,
       model_mode,
-      previous_chunk=None,
-      page_state=None,
-      slot=None,
+      previous_chunk=None,
+      page_state: Optional[page_manager.PageState] = None,
+      slot: Optional[int] = None,
   ):
     cfg = self.config
 
     inputs = nn.with_logical_constraint(inputs, ("activation_batch", "activation_norm_length", "activation_embed"))
     inputs = checkpoint_name(inputs, "decoder_layer_input")
     hidden_states, intermediate_inputs = self_attention_with_norm(
-        inputs, self.config, self.mesh, self.quant, decoder_segment_ids, decoder_positions, deterministic, model_mode
+        inputs,
+        self.config,
+        self.mesh,
+        self.quant,
+        decoder_segment_ids,
+        decoder_positions,
+        deterministic,
+        model_mode,
+        previous_chunk,
+        page_state,
+        slot,
     )
 
     # NOTE: the naming mismatch here is to ensure reverse compatibility with existing checkpoints.
diff --git a/MaxText/layers/models.py b/MaxText/layers/models.py
index 8f813b8a1..e2f320948 100644
--- a/MaxText/layers/models.py
+++ b/MaxText/layers/models.py
@@ -552,8 +552,13 @@ def __call__(
       if cfg.scan_layers:
         if cfg.decoder_block == DecoderBlockType.DEEPSEEK:
           assert len(RemattedBlockLayers) == 2, "Scanned layers must have a length of 2 using deepseek."
+          layer_call_kwargs = {
+              "page_state": page_state,
+              "previous_chunk": previous_chunk,
+              "slot": slot,
+          }
           dense_layer = RemattedBlockLayers[0]
-          moe_layer = RemattedBlockLayers[1]
+          dense_layer.__call__ = functools.partial(dense_layer.__call__, **layer_call_kwargs)
           y, _ = self.scan_decoder_layers(cfg, dense_layer, cfg.first_num_dense_layers, "dense_layers", mesh)(
               y,
               decoder_segment_ids,
@@ -561,6 +566,8 @@ def __call__(
               deterministic,
               model_mode,
           )
+          moe_layer = RemattedBlockLayers[1]
+          moe_layer.__call__ = functools.partial(moe_layer.__call__, **layer_call_kwargs)
           num_moe_layers = cfg.num_decoder_layers - cfg.first_num_dense_layers
           y, _ = self.scan_decoder_layers(cfg, moe_layer, num_moe_layers, "moe_layers", mesh)(
               y,