Safetensors conversion #2290
base: master
Thanks for the PR, will take a look in a bit :)
Thanks! Just left some initial comments.
Let's add a unit test that calls this util, then tries loading the result with transformers to check that it works. It's OK to add transformers to our CI environment here: https://github.com/keras-team/keras-hub/blob/master/requirements-common.txt
import os

import torch
from safetensors.torch import save_file
does this work on all backends? or do we need to flip between versions depending on the backend? worth testing out
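One possible way to handle the backend question, as a minimal sketch: pick the safetensors submodule from the active Keras backend and fall back to the NumPy one otherwise. The helper names (`safetensors_module_for`, `get_save_file`) and the mapping itself are assumptions for illustration, not code from this PR. The `safetensors.torch`, `safetensors.tensorflow`, `safetensors.flax`, and `safetensors.numpy` modules each provide a `save_file` function.

```python
import importlib
import os

# Assumed mapping from Keras backend name to a safetensors submodule.
# This is an illustration, not the mapping used in the PR.
_SAFETENSORS_MODULES = {
    "torch": "safetensors.torch",
    "tensorflow": "safetensors.tensorflow",
    "jax": "safetensors.flax",
}

# Fallback for any other backend: convert weights to NumPy arrays first.
_FALLBACK_MODULE = "safetensors.numpy"


def safetensors_module_for(backend):
    """Return the dotted name of the safetensors submodule for a backend."""
    return _SAFETENSORS_MODULES.get(backend, _FALLBACK_MODULE)


def get_save_file(backend=None):
    """Lazily import and return the `save_file` function for a backend.

    The import happens here so that only the framework matching the
    active backend needs to be installed.
    """
    if backend is None:
        # KERAS_BACKEND is the environment variable Keras reads; it is
        # used here only to keep the sketch free of a Keras import.
        backend = os.environ.get("KERAS_BACKEND", "tensorflow")
    module = importlib.import_module(safetensors_module_for(backend))
    return module.save_file
```

In real code, `keras.config.backend()` would likely be a more reliable source for the backend name than the environment variable. A simpler alternative is to always convert the weights to NumPy arrays and use `safetensors.numpy.save_file` regardless of backend.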
… into safetensors_conversion merge updated branch
Nice! Please address the changes from the earlier PR as well
keras_hub/src/utils/transformers/export_gemma_to_safetensors_test.py
Thanks, nice work!
    return hf_config


def export_to_hf(keras_model, path):
We should add the API export decorator here, similar to this: https://github.com/keras-team/keras-hub/blob/master/keras_hub/src/models/bloom/bloom_backbone.py#L15-L16
Also, do you think we should refactor some of the common code across models to a separate file? We can then expose that as the API.
So, this is what the directory keras_hub/src/utils/transformers/convert_to_safetensor/ will look like:
export.py: this will have the common code. We will expose this as the API. It will also check whether we support safetensor conversion for the given model yet.
gemma.py: this will just have a way to create the weight dictionary for Gemma. Inside export.py, we will call the weight conversion function specific to the given model.
Pinging @mattdangerw to confirm if we should do this now or at a later point.
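The proposed split could be sketched roughly as follows, assuming a simple registry keyed by model class name. Every name here (`register_converter`, `build_weight_dict`, `gemma_weight_dict`) is hypothetical, and the Gemma converter body is a placeholder rather than the actual weight mapping from this PR.

```python
from typing import Callable, Dict

# --- export.py (common code, exposed as the API) ---

# Hypothetical registry: model class name -> function that builds the
# weight dictionary for that model.
_WEIGHT_CONVERTERS: Dict[str, Callable] = {}


def register_converter(model_name):
    """Decorator that registers a per-model weight conversion function."""

    def decorator(fn):
        _WEIGHT_CONVERTERS[model_name] = fn
        return fn

    return decorator


def build_weight_dict(keras_model):
    """Dispatch to the converter for this model, or fail if unsupported."""
    name = type(keras_model).__name__
    if name not in _WEIGHT_CONVERTERS:
        raise ValueError(
            f"Safetensors export is not supported for {name} yet."
        )
    return _WEIGHT_CONVERTERS[name](keras_model)


# --- gemma.py (model-specific code) ---


@register_converter("GemmaBackbone")
def gemma_weight_dict(backbone):
    # Placeholder: real code would map each Keras variable in `backbone`
    # to its Hugging Face parameter name and return that dictionary.
    return {}
```

With this shape, adding support for another model means adding one file with a registered converter, while the support check and the file writing stay in the common module.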
I think we could land this and do the API bit at a later point, though I agree it's an important concern. I'm not sure if we want a method like model.save_to_preset() or a function like some_export(model). Any thoughts?
Colab Notebook
https://colab.research.google.com/drive/1naqf0sO2J40skndWbVMeQismjL7MuEjd?usp=sharing