-
Notifications
You must be signed in to change notification settings - Fork 361
Integrate Multi-Token Prediction (MTP) Training objective #1837
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
parambole
wants to merge
5
commits into
parambole/mtp_refactor
Choose a base branch
from
parambole/maxtext_mtp_training_obective
base: parambole/mtp_refactor
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
5 commits
Select commit
Hold shift + click to select a range
9a846bd
Integrate Multi-Token Prediction (MTP) Training objective
parambole 600339b
Revert Outputhead logic
parambole e048cc6
Refactoring the code so that MTP uses shared Embedding and OuputHead
parambole c32325e
Adding the missing new line
parambole cd41461
fixing lint
parambole File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -25,10 +25,13 @@ | |
|
||
from MaxText.common_types import DecoderBlockType, Config, MODEL_MODE_TRAIN, MODEL_MODE_AUTOREGRESSIVE, DECODING_ACTIVE_SEQUENCE_INDICATOR | ||
from MaxText.inference import page_manager | ||
from MaxText import maxtext_utils | ||
from MaxText import multimodal_utils | ||
from MaxText.layers.blocks import Decoder, VisionEncoder | ||
from MaxText.layers.embeddings import Embed | ||
from MaxText.layers.quantizations import AqtQuantization as Quant | ||
from MaxText.layers.multi_token_prediction import MultiTokenPredictionBlock | ||
|
||
|
||
# ------------------------------------------------------------------------------ | ||
# The network: Transformer Definitions | ||
|
@@ -59,14 +62,25 @@ def setup(self): | |
name="token_embedder", | ||
config=cfg, | ||
) | ||
|
||
self.vision_encoder = VisionEncoder(config=cfg, mesh=mesh) if cfg.use_multimodal else None | ||
self.decoder = Decoder(config=cfg, shared_embedding=self.shared_embedding, mesh=mesh, quant=self.quant) | ||
# If MTP is enabled via config, set up the MTP block. | ||
if self.config.mtp_num_layers > 0: | ||
# Get the list of layer blueprints for the current model. | ||
layer_types = maxtext_utils.get_decoder_layers(self.config) | ||
# For MTP, we use the primary (usually dense) transformer block blueprint | ||
# to ensure architectural consistency. By convention, this is the first in the list. | ||
mtp_layer = layer_types[0] | ||
self.mtp_block = MultiTokenPredictionBlock( | ||
config=self.config, mesh=self.mesh, name="mtp_block", transformer_layer_module=mtp_layer, decoder=self.decoder | ||
) | ||
|
||
def __call__( | ||
self, | ||
decoder_input_tokens: jnp.ndarray, | ||
decoder_positions: jnp.ndarray, | ||
decoder_target_tokens: Optional[jnp.ndarray] = None, | ||
decoder_target_mask: Optional[jnp.ndarray] = None, | ||
decoder_segment_ids=None, | ||
encoder_images: Optional[jnp.ndarray] = None, | ||
enable_dropout=True, | ||
|
@@ -99,7 +113,7 @@ def __call__( | |
if self.config.decoder_block == DecoderBlockType.GEMMA3: | ||
bidirectional_mask = decoder_input_tokens == multimodal_utils.GEMMA_TOKEN_PLACEHOLDER | ||
|
||
logits, _ = self.decoder( | ||
logits, hidden_state = self.decoder( | ||
decoder_input_tokens=decoder_input_tokens, | ||
decoder_positions=decoder_positions, | ||
decoder_segment_ids=decoder_segment_ids, | ||
|
@@ -111,4 +125,35 @@ def __call__( | |
bidirectional_mask=bidirectional_mask, | ||
image_embeddings=image_embeddings, | ||
) | ||
|
||
# If we are initializing the model AND MTP is enabled, we must create | ||
# dummy target tensors. This allows Flax to trace the MTPBlock and create | ||
# all its necessary parameters, without requiring the main training pipeline | ||
# to be aware of this initialization detail. | ||
if self.is_initializing() and self.config.mtp_num_layers > 0: | ||
if decoder_target_tokens is None: | ||
dummy_shape = decoder_input_tokens.shape | ||
decoder_target_tokens = jnp.ones(dummy_shape, dtype=jnp.int32) | ||
decoder_target_mask = jnp.ones(dummy_shape, dtype=jnp.int32) | ||
|
||
# The Multi-Token Prediction (MTP) block functions as a "side-car" to the main | ||
# model, active only during training. It computes an auxiliary loss based on | ||
# predicting multiple future tokens, as described in the DeepSeek-V3 paper. | ||
# To ensure architectural consistency, it uses two key components from the parent Transformer: | ||
# 1. The same `DecoderLayer` blueprint for its internal transformer blocks. | ||
# 2. The `shared_embedding` for both embedding future tokens and for its final | ||
# logit projection. | ||
# Its only effect is to "sow" these losses; it does not alter the primary logits output. | ||
if self.config.mtp_num_layers > 0 and model_mode == MODEL_MODE_TRAIN: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Shall we add assertion in pyconfig.py that mtp_num_layers>1 for inference/serving is not supported? |
||
self.mtp_block( | ||
main_hidden_state=hidden_state, | ||
input_ids=decoder_input_tokens, | ||
target_ids=decoder_target_tokens, | ||
target_mask=decoder_target_mask, | ||
position_ids=decoder_positions, | ||
decoder_segment_ids=decoder_segment_ids, | ||
deterministic=not enable_dropout, | ||
model_mode=model_mode, | ||
) | ||
|
||
return logits |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For DeepSeek, it's mixed layers (
[deepseek.DeepSeekDenseLayer, deepseek.DeepSeekMoELayer]
. Could you confirm if this is dense or moe layer? It is a moe layer if I recall correctly.