Skip to content

Safetensors conversion #2290

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 13 commits into
base: master
Choose a base branch
from

Conversation

Bond099
Copy link

@Bond099 Bond099 commented Jun 6, 2025

Description of the change

Reference

Colab Notebook

https://colab.research.google.com/drive/1naqf0sO2J40skndWbVMeQismjL7MuEjd?usp=sharing

Checklist

  • I have added all the necessary unit tests for my change.
  • I have verified that my change does not break existing code and works with all backends (TensorFlow, JAX, and PyTorch).
  • My PR is based on the latest changes of the main branch (if unsure, rebase the code).
  • I have followed the Keras Hub Model contribution guidelines in making these changes.
  • I have followed the Keras Hub API design guidelines in making these changes.
  • I have signed the Contributor License Agreement.

@abheesht17
Copy link
Collaborator

Thanks for the PR, will take a look in a bit :)

Copy link
Member

@mattdangerw mattdangerw left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! Just left some initial comments.

Copy link
Member

@mattdangerw mattdangerw left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's add a unit test that calls this util and tries loading the result with transformers and seeing if it works. OK to add transformers to our ci environment here https://github.com/keras-team/keras-hub/blob/master/requirements-common.txt

import os

import torch
from safetensors.torch import save_file
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does this work on all backends? or do we need to flip between versions depending on the backend? worth testing out

@mattdangerw mattdangerw added the kokoro:force-run Runs Tests on GPU label Jun 19, 2025
@kokoro-team kokoro-team removed the kokoro:force-run Runs Tests on GPU label Jun 19, 2025
Copy link
Member

@mattdangerw mattdangerw left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice! Please address the changes from the earlier PR as well

Copy link
Collaborator

@abheesht17 abheesht17 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, nice work!

return hf_config


def export_to_hf(keras_model, path):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Collaborator

@abheesht17 abheesht17 Jun 19, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, do you think we should refactor some of the common code across models to a separate file? We can then expose that as the API.

So, this is how the directory keras_hub/src/utils/transformers/convert_to_safetensor/ will look like:

  • export.py: this will have the common code. We will expose this as the API. This will also check if we support safetensor conversion for a given passed model yet.
  • gemma.py: this will just have a way to create the weight dictionary for Gemma. Inside export.py, we will call the the weight conversion function specific to a specified model.

Pinging @mattdangerw to confirm if we should do this now or at a later point.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we could land and do the API bit a later point. Though agree it's an important concern. I'm not sure if we want a method like model.save_to_preset() or a function like some_export(model). Any thoughts?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants