Enable Checkpoint Conversion from Huggingface to Maxtext #1839
Conversation
Excellent work!
IIUC, mt_hf_mutual_conversion_check checks MaxText vs. HF, while hf_checkpoint_conversion_check checks converted HF vs. original HF? Would you suggest deprecating the latter?
Yes, the latter one is not necessary any more.
# Tokenize for HF
inputs = tokenizer(input_text, return_tensors="pt", padding=True, max_length=config.max_target_length, truncation=True)
actual_seq_len = inputs["input_ids"].shape[1]
# actual_seq_len = 4
nit: shall we remove this commented-out code?
Removed
mt_decoder_positions = mt_decoder_positions_full[:, :actual_seq_len]
# max_logging.log(f"MaxText input shapes: ids={mt_ids.shape}, "
#                 f"decoder_positions={mt_decoder_positions.shape}, "
#                 f"decoder_segment_ids={mt_decoder_segment_ids.shape}")
nit: similar to above
Removed
This is nice, thanks for adding such a feature!
# Get parameter mappings and hooks
model_key = config.model_name
param_map_mt_to_hf = PARAM_MAPPING[model_key](hf_config_obj.to_dict(), config.scan_layers)
hook_fn_map_mt = HOOK_FNS[model_key](hf_config_obj.to_dict(), config.scan_layers, saving_to_hf=False)
hook_fn is the kind of function you customize to convert the parameter array, e.g. reshape/transpose, right?
Yes, mt_to_hf and hf_to_mt use the same param_mapping and hook_fn mapping.
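For illustration, a hook function of this kind might look like the following sketch. The function name, shapes, and conversion direction below are hypothetical, not the actual MaxText implementation:

```python
import numpy as np

def attention_kernel_hook(arr, saving_to_hf):
  """Hypothetical hook: reshape/transpose a (heads, head_dim, embed) MaxText
  kernel into a flattened (embed, heads * head_dim) HF-style layout; pass
  through unchanged in the other direction for brevity."""
  if saving_to_hf:
    heads, head_dim, embed = arr.shape
    # Flatten the two head axes, then put the embedding axis first.
    return arr.reshape(heads * head_dim, embed).T
  return arr

kernel = np.arange(24).reshape(2, 3, 4)  # (heads=2, head_dim=3, embed=4)
flat = attention_kernel_hook(kernel, saving_to_hf=True)  # shape (4, 6)
```

Because the same mapping drives both directions, the real hooks would also implement the inverse reshape/transpose for HF-to-MaxText loading.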
mesh = jax.sharding.Mesh(devices_array, config.mesh_axes)

# Load Hugging Face model, config, and state_dict
max_logging.log(f"Loading Hugging Face model: {model_id}...")
Do we have any mechanism to check the supported model_id? If a user provides an unsupported model, we could raise an error like: "Model {model_id} is not currently supported in MaxText. Supported models are: {list of supported models from ckpt_conversion}."
Added. Also revised to_huggingface.py to make it consistent.
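A minimal sketch of the suggested check. The mapping contents here are illustrative stand-ins for the real HF_IDS in MaxText.utils.ckpt_conversion.utils.utils:

```python
# Illustrative stand-in for the real HF_IDS mapping of supported models.
HF_IDS = {"gemma2-2b": "google/gemma-2-2b"}

def resolve_model_id(model_name):
  """Raise a clear error for unsupported models, as suggested above."""
  if model_name not in HF_IDS:
    raise ValueError(
        f"Model {model_name} is not currently supported in MaxText. "
        f"Supported models are: {sorted(HF_IDS)}."
    )
  return HF_IDS[model_name]
```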
Hi @gagika! I've heard this might be interesting to you for loading/saving HF checkpoints. Would you like to take a look when you get a chance? Thanks a lot for your time!
Thanks Yixuan! Added a few comments
from MaxText.train import save_checkpoint
from MaxText.utils.ckpt_conversion.utils.param_mapping import HOOK_FNS, PARAM_MAPPING
from MaxText.utils.ckpt_conversion.utils.utils import apply_hook_fns, HF_IDS
Could you please add comments explaining what this script does, which parameters are supported, which OS environment variables are needed for it to run, and example(s) of invoking it? (If you'd like, you could point to convert_gemma2_to_mt.sh or copy the relevant command here as well.)
model_id = HF_IDS[config.model_name]
max_utils.print_system_information()
if not config.base_output_directory:
  output_directory = os.path.expanduser("~/.mt_output/")
How about output_directory = os.path.join(os.getcwd(), "mt_output"), so that the output is under the directory where the script is invoked?
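The two defaults under discussion, side by side (a minimal sketch; the surrounding config handling is omitted):

```python
import os

# Default from the PR: a hidden directory under the user's home.
home_default = os.path.expanduser("~/.mt_output/")
# Suggested alternative: a directory under the current working directory,
# so output lands where the script is invoked.
cwd_default = os.path.join(os.getcwd(), "mt_output")
```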
max_logging.log("Starting weight transformation...")
final_mt_weights_numpy_list = []

for path_tuple, abstract_leaf_value in abstract_params_flat:
Could you please add more comments here, perhaps an example of a parameter and how it's mapped.
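As an illustration of what such a comment might document, one mapped parameter could look like this. The path tuple and HF key below are hypothetical, not taken from the real PARAM_MAPPING:

```python
# Hypothetical example of a single entry a param mapping might produce:
# a flattened MaxText parameter path -> the corresponding HF state_dict key.
def example_param_mapping(hf_config, scan_layers):
  del hf_config, scan_layers  # unused in this sketch
  return {
      ("params", "token_embedder", "embedding"): "model.embed_tokens.weight",
  }

mapping = example_param_mapping(hf_config={}, scan_layers=False)
hf_key = mapping[("params", "token_embedder", "embedding")]
```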
It loads the HF checkpoint and a MaxText checkpoint, and:
1. runs a forward pass of a MaxText model and an HF model
2. compares their output logits for a given input
3. compares the predicted token sequences
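The comparison in steps 2-3 can be sketched with random stand-ins for the two models' outputs (a minimal sketch; the real script runs actual forward passes):

```python
import numpy as np

rng = np.random.default_rng(0)
mt_logits = rng.normal(size=(1, 4, 8))  # (batch, seq_len, vocab) from MaxText
hf_logits = mt_logits + 1e-7            # stand-in for the converted HF model

# Step 2: compare output logits within a numerical tolerance.
logits_match = np.allclose(mt_logits, hf_logits, atol=1e-5)
# Step 3: compare the predicted (argmax) token sequences.
tokens_match = bool((mt_logits.argmax(-1) == hf_logits.argmax(-1)).all())
```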
Can you give an example of how to invoke it, what parameters to pass, etc.?
tokenizer_path="${TOKENIZER_PATH}" \
load_parameters_path="${OUTPUT_BASE_DIR}/0/items" \
per_device_batch_size="${PER_DEVICE_BATCH_SIZE}" \
run_name="mt_gemma2_check" \
could you use MODEL_NAME here instead of hardcoding gemma2, also below
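The suggested parameterization could look like this minimal sketch, deriving model-specific strings from one variable set at the top of the script:

```shell
# Derive run_name from a single MODEL_NAME variable instead of
# hardcoding "gemma2" in each field.
MODEL_NAME="gemma2"
RUN_NAME="mt_${MODEL_NAME}_check"
echo "${RUN_NAME}"
```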
echo "--- Starting Comparing Logits and Predicted Tokens ---"
python3 -m "MaxText.tests.mt_hf_mutual_conversion_check" \
  hf_model_id="google/gemma-2-2b" \
could you make this a parameter instead of hardcoding it (also below), so all model-specific parameters are at the top of this script
Description
Enable checkpoint conversion from Hugging Face to MaxText.
Tests
The converted checkpoint is tested with mt_hf_mutual_conversion_check.py, which compares the output logits and the predicted token sequences of the MaxText and HF models for a given input.
Tested on the Gemma2-2b model. A successful conversion example.
Checklist
Before submitting this PR, please make sure (put X in square brackets):