
Enable Checkpoint Conversion from Huggingface to Maxtext #1839


Open · wants to merge 5 commits into base: main

Conversation

@YixuanWang-99 (Collaborator) commented Jun 16, 2025

Description

Enable checkpoint conversion from Hugging Face to MaxText.

  • Add to_maxtext.py to perform the checkpoint conversion from HF to MaxText.
  • Add convert_gemma2_to_mt.sh to automate the conversion and verification.
  • Add mt_hf_mutual_conversion_check.py to compare the Hugging Face and MaxText checkpoints.
  • Official Gemma2 models are supported.

Tests

The converted checkpoint is tested with mt_hf_mutual_conversion_check.py. It compares:

  1. For given prompts, the top-k predicted tokens and scores for the next token;
  2. The KL divergence of the full logit distributions (a sketch of this check follows below).
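
A minimal sketch of the comparison, assuming plain numpy arrays of next-token logits (the real check in mt_hf_mutual_conversion_check.py extracts these from the two models' forward passes):

import numpy as np
from scipy.special import log_softmax

def compare_next_token(hf_logits, mt_logits, k=5):
  """Compare next-token predictions from the HF and MaxText models.

  Both inputs are 1-D numpy arrays of vocabulary logits taken from the
  last sequence position of each model's output.
  """
  hf_logp = log_softmax(hf_logits)
  mt_logp = log_softmax(mt_logits)
  # Top-k predicted token ids from each model; these should match
  # after a faithful conversion.
  hf_topk = np.argsort(hf_logits)[::-1][:k]
  mt_topk = np.argsort(mt_logits)[::-1][:k]
  # KL(HF || MaxText) over the full vocabulary distribution; a value
  # near zero indicates the conversion preserved the logits.
  kl = float(np.sum(np.exp(hf_logp) * (hf_logp - mt_logp)))
  return hf_topk, mt_topk, kl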

Tested on the Gemma2-2b model, with a successful conversion example.

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed.

@YixuanWang-99 changed the title from "Enable conversion from Huggingface to Maxtext" to "Enable Checkpoint Conversion from Huggingface to Maxtext" on Jun 16, 2025
@hengtaoguo (Collaborator) left a comment:


Excellent work!

Collaborator:

IIUC, mt_hf_mutual_conversion_check checks MaxText vs. HF, while hf_checkpoint_conversion_check checks converted HF vs. original HF? Would you suggest deprecating the latter one?

Collaborator (Author):

Yes, the latter one is no longer necessary.

# Tokenize for HF
inputs = tokenizer(input_text, return_tensors="pt", padding=True, max_length=config.max_target_length, truncation=True)
actual_seq_len = inputs["input_ids"].shape[1]
# actual_seq_len = 4
Collaborator:

nit: shall we remove this commented-out code?

Collaborator (Author):

Removed

mt_decoder_positions = mt_decoder_positions_full[:, :actual_seq_len]
# max_logging.log(f"MaxText input shapes: ids={mt_ids.shape}, "
# f"decoder_positions={mt_decoder_positions.shape}, "
# f"decoder_segment_ids={mt_decoder_segment_ids.shape}")
Collaborator:

nit: similar to above

Collaborator (Author):

Removed

Collaborator:

This is nice, thanks for adding such a feature!

# Get parameter mappings and hooks
model_key = config.model_name
param_map_mt_to_hf = PARAM_MAPPING[model_key](hf_config_obj.to_dict(), config.scan_layers)
hook_fn_map_mt = HOOK_FNS[model_key](hf_config_obj.to_dict(), config.scan_layers, saving_to_hf=False)
Collaborator:

hook_fn is the kind of function you customize to convert the parameter array, such as a reshape or transpose, right?

Collaborator (Author):

Yes, mt_to_hf and hf_to_mt use the same param_mapping and hook_fn mapping.
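
As a concrete illustration (a hypothetical sketch; the actual hooks live in MaxText.utils.ckpt_conversion.utils.param_mapping, and the parameter name below is made up):

import numpy as np

def make_hooks(hf_config: dict, scan_layers: bool, saving_to_hf: bool):
  """Sketch of a HOOK_FNS[model_key] factory: returns a map from
  parameter name to a function that reshapes/transposes the raw
  array into the target framework's layout."""
  del hf_config, scan_layers  # unused in this sketch
  def transpose_kernel(arr: np.ndarray) -> np.ndarray:
    # e.g. HF stores projections as (out, in); MaxText as (in, out).
    return arr.T
  # The same table can serve both directions because a pure transpose
  # is its own inverse; saving_to_hf would matter for asymmetric hooks.
  return {"self_attention.query.kernel": transpose_kernel}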

mesh = jax.sharding.Mesh(devices_array, config.mesh_axes)

# Load Hugging Face model, config, and state_dict
max_logging.log(f"Loading Hugging Face model: {model_id}...")
Collaborator:

Do we have any mechanism to check the supported model_id? If a user provides an unsupported model, we could raise an error like: "Model {model_id} is not currently supported in MaxText. Supported models are: {list of supported models from ckpt_conversion}."

Collaborator (Author):

Added. Also revised to_huggingface.py to make it consistent.
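
A minimal sketch of such a guard, using the HF_IDS mapping this script already imports (the error wording is illustrative):

from MaxText.utils.ckpt_conversion.utils.utils import HF_IDS

def resolve_hf_model_id(model_name: str) -> str:
  """Fail loudly when the requested model has no registered HF id."""
  if model_name not in HF_IDS:
    raise ValueError(
        f"Model {model_name} is not currently supported in MaxText. "
        f"Supported models are: {sorted(HF_IDS)}."
    )
  return HF_IDS[model_name]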

@hengtaoguo (Collaborator) commented:

Hi @gagika! I've heard this might be interesting to you for loading/saving HF checkpoints. Would you like to take a look when you get a chance? Thanks a lot for your time!

@shralex (Collaborator) left a comment:

Thanks Yixuan! Added a few comments

from MaxText.train import save_checkpoint
from MaxText.utils.ckpt_conversion.utils.param_mapping import HOOK_FNS, PARAM_MAPPING
from MaxText.utils.ckpt_conversion.utils.utils import apply_hook_fns, HF_IDS

Collaborator:

Could you please add comments explaining what this script does, which parameters are supported, which environment variables are needed for it to run, and example(s) of invoking it? (If you'd like, you could point to convert_gemma2_to_mt.sh or copy the relevant command here as well.)

model_id = HF_IDS[config.model_name]
max_utils.print_system_information()
if not config.base_output_directory:
  output_directory = os.path.expanduser("~/.mt_output/")
Collaborator:

How about output_directory = os.path.join(os.getcwd(), "mt_output"), so that the output is under the directory where the script is invoked?

max_logging.log("Starting weight transformation...")
final_mt_weights_numpy_list = []

for path_tuple, abstract_leaf_value in abstract_params_flat:
Collaborator:

Could you please add more comments here, perhaps an example of a parameter and how it's mapped.
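
For example, one iteration of that loop might look like this (the parameter names are hypothetical, not copied from PARAM_MAPPING):

# path_tuple identifies one MaxText parameter, e.g.
#   ("params", "token_embedder", "embedding")
# param_map_mt_to_hf maps its flattened name to the HF tensor name, e.g.
#   "model.embed_tokens.weight"
# and the matching hook_fn (if any) reshapes/transposes the HF array
# into the MaxText layout before it is collected into
# final_mt_weights_numpy_list.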

It loads the HF checkpoint and a MaxText checkpoint, and:
1. runs a forward pass of a MaxText model and an HF model
2. compares their output logits for a given input
3. compares the predicted token sequences
Collaborator:

Can you give an example of how to invoke it, what parameters to pass, etc.?
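
A hypothetical invocation, assembled from the flags visible elsewhere in this PR (the exact argument set may differ):

python3 -m "MaxText.tests.mt_hf_mutual_conversion_check" \
  hf_model_id="google/gemma-2-2b" \
  tokenizer_path="${TOKENIZER_PATH}" \
  load_parameters_path="${OUTPUT_BASE_DIR}/0/items" \
  per_device_batch_size="${PER_DEVICE_BATCH_SIZE}" \
  run_name="mt_gemma2_check"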

tokenizer_path="${TOKENIZER_PATH}" \
load_parameters_path="${OUTPUT_BASE_DIR}/0/items" \
per_device_batch_size="${PER_DEVICE_BATCH_SIZE}" \
run_name="mt_gemma2_check" \
Collaborator:

Could you use MODEL_NAME here instead of hardcoding gemma2 (also below)?
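
For example (a sketch, assuming a MODEL_NAME variable defined at the top of the script):

run_name="mt_${MODEL_NAME}_check"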


echo "--- Starting Comparing Logits and Predicted Tokens ---"
python3 -m "MaxText.tests.mt_hf_mutual_conversion_check" \
hf_model_id="google/gemma-2-2b" \
Collaborator:

Could you make this a parameter instead of hardcoding it (also below), so all model-specific parameters are at the top of this script?
