Integrate Multi-Token Prediction (MTP) Training objective #1837

Open
Wants to merge 5 commits into base: parambole/mtp_refactor
8 changes: 8 additions & 0 deletions MaxText/configs/base.yml
@@ -133,6 +133,14 @@ cast_logits_to_fp32: True # whether to cast the logits to fp32. The higher preci
float32_qk_product: False # in dot_product attention, whether to cast to fp32 the inputs to qk product
float32_logits: False # in dot_product attention, whether to cast to fp32 the inputs to softmax

# Multi-Token Prediction Configs
# The number of auxiliary prediction layers to use for MTP.
# Set to 0 to disable the feature.
mtp_num_layers: 0
# The scaling factor (lambda) for the MTP auxiliary loss. The final loss is:
# main_loss + mtp_loss_scaling_factor * avg_mtp_loss
mtp_loss_scaling_factor: 0.1

# mixture of experts (moe)
num_experts: 1
num_experts_per_tok: 1
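To make the mtp_loss_scaling_factor comment above concrete, a tiny worked example of the documented combination (the numbers are illustrative only, not taken from any run):

main_loss = 2.0                  # standard next-token loss
avg_mtp_loss = 2.3               # average of the auxiliary MTP losses
mtp_loss_scaling_factor = 0.1
total_loss = main_loss + mtp_loss_scaling_factor * avg_mtp_loss  # 2.0 + 0.1 * 2.3 = 2.23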
49 changes: 47 additions & 2 deletions MaxText/layers/models.py
@@ -25,10 +25,13 @@

from MaxText.common_types import DecoderBlockType, Config, MODEL_MODE_TRAIN, MODEL_MODE_AUTOREGRESSIVE, DECODING_ACTIVE_SEQUENCE_INDICATOR
from MaxText.inference import page_manager
from MaxText import maxtext_utils
from MaxText import multimodal_utils
from MaxText.layers.blocks import Decoder, VisionEncoder
from MaxText.layers.embeddings import Embed
from MaxText.layers.quantizations import AqtQuantization as Quant
from MaxText.layers.multi_token_prediction import MultiTokenPredictionBlock


# ------------------------------------------------------------------------------
# The network: Transformer Definitions
@@ -59,14 +62,25 @@ def setup(self):
name="token_embedder",
config=cfg,
)

self.vision_encoder = VisionEncoder(config=cfg, mesh=mesh) if cfg.use_multimodal else None
self.decoder = Decoder(config=cfg, shared_embedding=self.shared_embedding, mesh=mesh, quant=self.quant)
# If MTP is enabled via config, set up the MTP block.
if self.config.mtp_num_layers > 0:
# Get the list of layer blueprints for the current model.
layer_types = maxtext_utils.get_decoder_layers(self.config)
# For MTP, we use the primary (usually dense) transformer block blueprint
# to ensure architectural consistency. By convention, this is the first in the list.
mtp_layer = layer_types[0]
Collaborator: For DeepSeek, it's mixed layers ([deepseek.DeepSeekDenseLayer, deepseek.DeepSeekMoELayer]). Could you confirm whether this is the dense or the MoE layer? It is an MoE layer, if I recall correctly.

self.mtp_block = MultiTokenPredictionBlock(
config=self.config, mesh=self.mesh, name="mtp_block", transformer_layer_module=mtp_layer, decoder=self.decoder
)

def __call__(
self,
decoder_input_tokens: jnp.ndarray,
decoder_positions: jnp.ndarray,
decoder_target_tokens: Optional[jnp.ndarray] = None,
decoder_target_mask: Optional[jnp.ndarray] = None,
decoder_segment_ids=None,
encoder_images: Optional[jnp.ndarray] = None,
enable_dropout=True,
@@ -99,7 +113,7 @@ def __call__(
if self.config.decoder_block == DecoderBlockType.GEMMA3:
bidirectional_mask = decoder_input_tokens == multimodal_utils.GEMMA_TOKEN_PLACEHOLDER

logits, _ = self.decoder(
logits, hidden_state = self.decoder(
decoder_input_tokens=decoder_input_tokens,
decoder_positions=decoder_positions,
decoder_segment_ids=decoder_segment_ids,
@@ -111,4 +125,35 @@
bidirectional_mask=bidirectional_mask,
image_embeddings=image_embeddings,
)

# If we are initializing the model AND MTP is enabled, we must create
# dummy target tensors. This allows Flax to trace the MTPBlock and create
# all its necessary parameters, without requiring the main training pipeline
# to be aware of this initialization detail.
if self.is_initializing() and self.config.mtp_num_layers > 0:
if decoder_target_tokens is None:
dummy_shape = decoder_input_tokens.shape
decoder_target_tokens = jnp.ones(dummy_shape, dtype=jnp.int32)
decoder_target_mask = jnp.ones(dummy_shape, dtype=jnp.int32)

# The Multi-Token Prediction (MTP) block functions as a "side-car" to the main
# model, active only during training. It computes an auxiliary loss based on
# predicting multiple future tokens, as described in the DeepSeek-V3 paper.
# To ensure architectural consistency, it uses two key components from the parent Transformer:
# 1. The same `DecoderLayer` blueprint for its internal transformer blocks.
# 2. The `shared_embedding` for both embedding future tokens and for its final
# logit projection.
# Its only effect is to "sow" these losses; it does not alter the primary logits output.
if self.config.mtp_num_layers > 0 and model_mode == MODEL_MODE_TRAIN:
Collaborator: Shall we add an assertion in pyconfig.py that mtp_num_layers > 0 is not supported for inference/serving?
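A minimal sketch of what that guard could look like, assuming it sits alongside the other config checks in pyconfig.py; the raw_keys dict and the using_inference key are hypothetical names, not this PR's code:

def validate_mtp_config(raw_keys):
  # Hypothetical check: MTP only sows a training-time auxiliary loss, and the block
  # is skipped outside MODEL_MODE_TRAIN, so serving jobs should not enable it.
  if raw_keys.get("mtp_num_layers", 0) > 0 and raw_keys.get("using_inference", False):
    raise ValueError(
        "mtp_num_layers > 0 is a training-only objective and is not supported "
        "for inference/serving."
    )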

self.mtp_block(
main_hidden_state=hidden_state,
input_ids=decoder_input_tokens,
target_ids=decoder_target_tokens,
target_mask=decoder_target_mask,
position_ids=decoder_positions,
decoder_segment_ids=decoder_segment_ids,
deterministic=not enable_dropout,
model_mode=model_mode,
)

return logits
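To illustrate why the dummy-target branch above is needed, a rough sketch of an abstract-initialization call that passes only inputs; the model construction, batch_size, and seq_len are assumptions and not part of this change:

# Sketch: initialization traces __call__ without real targets, so the dummy
# tensors created above are what allow Flax to build the MTP block's parameters.
rng = jax.random.PRNGKey(0)
dummy_tokens = jnp.ones((batch_size, seq_len), dtype=jnp.int32)
dummy_positions = jnp.broadcast_to(jnp.arange(seq_len, dtype=jnp.int32), (batch_size, seq_len))
variables = model.init(
    {"params": rng, "dropout": rng},
    decoder_input_tokens=dummy_tokens,
    decoder_positions=dummy_positions,
    enable_dropout=False,
)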
78 changes: 77 additions & 1 deletion MaxText/layers/multi_token_prediction.py
@@ -18,15 +18,18 @@

from typing import Optional, Type

import jax
import jax.numpy as jnp
from jax.sharding import Mesh

from flax import linen as nn

from MaxText.common_types import Config, MODEL_MODE_TRAIN
from MaxText.layers.attentions import dense_general
from MaxText.layers.blocks import DecoderLayer
from MaxText.layers.blocks import DecoderLayer, Decoder
from MaxText.layers.normalizations import RMSNorm
from MaxText import max_utils
from MaxText import maxtext_utils


class MultiTokenPredictionLayer(nn.Module):
@@ -136,3 +139,76 @@ def __call__(
# Shape: [B, S, H]
# --- Return Processed Hidden State ---
return next_hidden_state


class MultiTokenPredictionBlock(nn.Module):
"""Orchestrates the MTP process by running a sequence of MTP layers."""

config: Config
mesh: Mesh
transformer_layer_module: Type[DecoderLayer]
decoder: Decoder

@nn.compact
def __call__(
self,
main_hidden_state,
input_ids,
target_ids,
target_mask,
position_ids,
decoder_segment_ids,
deterministic,
model_mode: str = MODEL_MODE_TRAIN,
):
cfg = self.config
# The initial hidden state for the MTP chain is the raw output from the main model.
mtp_hidden_state = main_hidden_state

# These variables are updated sequentially in each loop iteration,
# moving the prediction window one token to the right each time.
rolled_input_ids = input_ids
rolled_target_ids = target_ids
rolled_target_mask = target_mask
rolled_position_id = position_ids

# The loop starts at k=1 to match the k-th future-token indexing used in the DeepSeek-V3 paper.
for k in range(1, cfg.mtp_num_layers + 1):
# Sequentially roll all tensors to prepare data for predicting the k-th future token.
rolled_input_ids = maxtext_utils.roll_and_mask(rolled_input_ids)
rolled_target_ids = maxtext_utils.roll_and_mask(rolled_target_ids)
rolled_target_mask = maxtext_utils.roll_and_mask(rolled_target_mask)
rolled_position_id = maxtext_utils.roll_and_mask(rolled_position_id)

# Embed the k-th future input tokens using the shared embedding module
target_token_embedding = self.decoder._apply_embedding(rolled_input_ids, rolled_position_id, deterministic)

# Instantiate and apply the MTP layer for this step
mtp_layer = MultiTokenPredictionLayer(
config=cfg,
mesh=self.mesh,
layer_number=k,
name=f"mtp_layer_{k}",
transformer_layer_module=self.transformer_layer_module,
)

next_mtp_hidden_state = mtp_layer(
mtp_hidden_state, target_token_embedding, position_ids, decoder_segment_ids, deterministic, model_mode
)

# Project to logits using the shared embedding transpose
mtp_logits = self.decoder._apply_output_head(next_mtp_hidden_state, deterministic, model_mode)

# Calculate cross-entropy loss for this specific layer's prediction
mtp_xent, _ = max_utils.cross_entropy_with_logits(mtp_logits, jax.nn.one_hot(rolled_target_ids, cfg.vocab_size), 0.0)
mtp_xent_masked = mtp_xent * rolled_target_mask

# This condition ensures loss is only computed during training runs (`.apply`),
# and not during model initialization (`.init()`).
if not self.is_initializing():
# "Sow" the loss values into the 'mtp_losses' collection for the
self.sow("mtp_losses", "losses", jnp.sum(mtp_xent_masked))
self.sow("mtp_losses", "weights", jnp.sum(rolled_target_mask))

# The output of this layer is the input for the next, maintaining the causal chain.
mtp_hidden_state = next_mtp_hidden_state
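For context, a rough sketch of how a training step could read the sown values back and apply the combination documented in base.yml; the apply-call shape, the main_loss/params/dropout_rng names, and the "mtp_block" nesting of the collection are assumptions rather than this PR's train-step code:

# Sketch only: mark the "mtp_losses" collection mutable so the sown sums can be
# retrieved after the forward pass.
logits, aux_vars = model.apply(
    {"params": params},
    decoder_input_tokens=inputs,
    decoder_positions=positions,
    decoder_target_tokens=targets,
    decoder_target_mask=target_mask,
    enable_dropout=True,
    rngs={"dropout": dropout_rng},
    mutable=["mtp_losses"],
)
# Each sow() call appends to a tuple, so stack across MTP layers, then normalize by
# the total target weight to get an average per-token MTP loss.
mtp_vars = aux_vars["mtp_losses"]["mtp_block"]
avg_mtp_loss = jnp.sum(jnp.stack(mtp_vars["losses"])) / jnp.sum(jnp.stack(mtp_vars["weights"]))
total_loss = main_loss + config.mtp_loss_scaling_factor * avg_mtp_loss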
21 changes: 21 additions & 0 deletions MaxText/maxtext_utils.py
@@ -1017,6 +1017,27 @@ def schedule(step):
return optax.join_schedules(pieces, boundaries)


def roll_and_mask(x: jnp.ndarray, shift: int = -1) -> jnp.ndarray:
"""
Performs a leftward roll on the sequence axis (axis=1) and masks the
newly created invalid positions at the end of the sequence.
Assumes input `x` has a batch dimension at axis 0 and sequence at axis 1.

Args:
x: The input array of shape [batch, seq_len, ...].
shift: The roll amount; use a negative value (default -1) to shift left by `abs(shift)` positions.

Returns:
The rolled array of the same shape as x.
"""
# If shift is 0, it's a no-op. Return the original array.
if shift == 0:
return x

# Roll left along the sequence axis, then use an indexed update to set the last
# `abs(shift)` elements of the sequence to zero, since they wrapped around and are invalid.
return jnp.roll(x, shift, axis=1).at[:, shift:, ...].set(0)


def get_formatted_sharding_annotations(params, mesh=None):
"""
Generates a readable string report of sharding annotations for all parameters.
59 changes: 59 additions & 0 deletions MaxText/tests/maxtext_utils_test.py
@@ -268,6 +268,65 @@ def multiple_rules(self):
self.assertEqual(transformed_rules, expected_transform)


class TestRollAndMask(unittest.TestCase):
"""Test class for utility functions supporting Roll and Mask."""

def test_mtp_roll_and_mask_shapes(self):
"""
Validates that roll_and_mask works correctly on the specific tensor shapes
that will be passed during training. The primary use case involves tensors
with a [batch, sequence_length] shape.
"""
batch_size = 4
seq_len = 8
# Create a dummy input tensor that mimics `input_ids` or `target_ids`.
# The values are sequential for easy validation.
# Shape: [4, 8]
input_tensor = jnp.arange(batch_size * seq_len, dtype=jnp.int32).reshape((batch_size, seq_len))

# --- Test Case 1: Default left shift by 1 ---
# This is the most common operation inside the MTP loop.
rolled_by_1 = maxtext_utils.roll_and_mask(input_tensor, shift=-1)

# Manually construct the expected output using jnp
expected_1 = jnp.array(
[
[1, 2, 3, 4, 5, 6, 7, 0], # First row rolled left, last element masked
[9, 10, 11, 12, 13, 14, 15, 0], # Second row rolled left
[17, 18, 19, 20, 21, 22, 23, 0],
[25, 26, 27, 28, 29, 30, 31, 0],
],
dtype=jnp.int32,
)

self.assertEqual(rolled_by_1.shape, (batch_size, seq_len), "Shape should be preserved after rolling.")
self.assertTrue(jnp.array_equal(rolled_by_1, expected_1), "Array content is incorrect after shift by -1.")

# --- Test Case 2: Larger left shift by 3 ---
# This simulates a later step in a hypothetical MTP loop.
rolled_by_3 = maxtext_utils.roll_and_mask(input_tensor, shift=-3)

# Manually construct the expected output using jnp
expected_3 = jnp.array(
[
[3, 4, 5, 6, 7, 0, 0, 0], # First row rolled left by 3, last 3 masked
[11, 12, 13, 14, 15, 0, 0, 0],
[19, 20, 21, 22, 23, 0, 0, 0],
[27, 28, 29, 30, 31, 0, 0, 0],
],
dtype=jnp.int32,
)
self.assertEqual(rolled_by_3.shape, (batch_size, seq_len), "Shape should be preserved after rolling.")
self.assertTrue(jnp.array_equal(rolled_by_3, expected_3), "Array content is incorrect after shift by -3.")

# --- Test Case 3: Shift of 0 (edge case) ---
# This should result in no change to the tensor.
rolled_by_0 = maxtext_utils.roll_and_mask(input_tensor, shift=0)
self.assertTrue(jnp.array_equal(rolled_by_0, input_tensor), "A shift of 0 should be a no-op.")


class TestAssertParamsSufficientlySharded(unittest.TestCase):
"""
Test suite for the sharding assertion utility function 'assert_params_sufficiently_sharded'.