Safetensors conversion #2290
base: master
keras_hub/src/utils/transformers/export_gemma_to_safetensor.py
@@ -0,0 +1,140 @@
import json
import os
import shutil
import warnings

import torch
from safetensors.torch import save_file


def convert_to_hf_config(keras_config):
    hf_config = {
        "vocab_size": keras_config.vocabulary_size,
        "num_hidden_layers": keras_config.num_layers,
        "num_attention_heads": keras_config.num_query_heads,
        "num_key_value_heads": keras_config.num_key_value_heads,
        "hidden_size": keras_config.hidden_dim,
        "intermediate_size": keras_config.intermediate_dim // 2,
        "head_dim": keras_config.head_dim,
        "max_position_embeddings": 8192,
    }
    return hf_config
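# Note: the halving of `intermediate_size` above is, as I read it, because
# KerasHub's Gemma `intermediate_dim` spans both halves of the gated FFN
# (the gate and up projections), while Hugging Face's `intermediate_size`
# counts a single projection. This is an assumption about the layer layout;
# worth double-checking against the presets.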


def export_to_hf(keras_model, path):
Review comment: We should add the API export decorator here, similar to this: https://github.com/keras-team/keras-hub/blob/master/keras_hub/src/models/bloom/bloom_backbone.py#L15-L16

Review comment: Also, do you think we should refactor some of the common code across models to a separate file? We can then expose that as the API. So, this is how the directory … Pinging @mattdangerw to confirm if we should do this now or at a later point.

Reply: I think we could land this and do the API bit at a later point, though I agree it's an important concern. I'm not sure if we want a method like …
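For reference, a minimal sketch of the suggested decorator, assuming the `keras_hub_export` utility linked above; the symbol path `keras_hub.utils.export_to_hf` is illustrative, not something the thread settled on:

```python
from keras_hub.src.api_export import keras_hub_export


@keras_hub_export("keras_hub.utils.export_to_hf")  # illustrative export path
def export_to_hf(keras_model, path):
    ...
```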
    """This function converts a Keras Gemma model to Hugging Face format by:

    - Extracting and mapping weights from the Keras backbone to safetensors.
    - Saving the configuration as 'config.json'.
    - Saving weights in 'model.safetensors'.
    - Saving tokenizer assets.

    Args:
        keras_model: The Keras Gemma model (e.g., GemmaCausalLM) to convert.
        path: str. Path of the directory to which the safetensors file,
            config and tokenizer will be saved.
    """
    backbone = keras_model.backbone
    hf_config = convert_to_hf_config(backbone)

    weights_dict = {}

    # Map token embedding
    token_embedding = backbone.get_layer("token_embedding").get_weights()[0]
    weights_dict["model.embed_tokens.weight"] = torch.from_numpy(
        token_embedding
    )

    for i in range(backbone.num_layers):
        decoder_layer = backbone.get_layer(f"decoder_block_{i}")

        # Pre-attention normalization
        pre_attn_norm = decoder_layer.pre_attention_norm.get_weights()[0]
        weights_dict[f"model.layers.{i}.input_layernorm.weight"] = (
            torch.from_numpy(pre_attn_norm)
        )

        # Attention query projection
        query_kernel = decoder_layer.attention.query_dense.get_weights()[0]
        query_kernel = (
            torch.from_numpy(query_kernel)
            .permute(1, 0, 2)
            .reshape(-1, backbone.hidden_dim)
            .T
        )
        weights_dict[f"model.layers.{i}.self_attn.q_proj.weight"] = query_kernel

        # Attention key projection
        key_kernel = decoder_layer.attention.key_dense.get_weights()[0][0]
        key_kernel = torch.from_numpy(key_kernel).T
        weights_dict[f"model.layers.{i}.self_attn.k_proj.weight"] = key_kernel

        # Attention value projection
        value_kernel = decoder_layer.attention.value_dense.get_weights()[0][0]
        value_kernel = torch.from_numpy(value_kernel).T
        weights_dict[f"model.layers.{i}.self_attn.v_proj.weight"] = value_kernel

        # Attention output projection
        out_kernel = decoder_layer.attention.output_dense.get_weights()[0]
        out_kernel = (
            torch.from_numpy(out_kernel)
            .permute(2, 0, 1)
            .reshape(backbone.hidden_dim, -1)
        )
        weights_dict[f"model.layers.{i}.self_attn.o_proj.weight"] = out_kernel

        # Post-attention normalization
        post_attn_norm = decoder_layer.pre_ffw_norm.get_weights()[0]
        weights_dict[f"model.layers.{i}.post_attention_layernorm.weight"] = (
            torch.from_numpy(post_attn_norm)
        )

        # MLP gate projection
        gate_kernel = decoder_layer.gating_ffw.get_weights()[0]
        gate_kernel = torch.from_numpy(gate_kernel).T
        weights_dict[f"model.layers.{i}.mlp.gate_proj.weight"] = gate_kernel

        # MLP up projection
        up_kernel = decoder_layer.gating_ffw_2.get_weights()[0]
        up_kernel = torch.from_numpy(up_kernel).T
        weights_dict[f"model.layers.{i}.mlp.up_proj.weight"] = up_kernel

        # MLP down projection
        down_kernel = decoder_layer.ffw_linear.get_weights()[0]
        down_kernel = torch.from_numpy(down_kernel).T
        weights_dict[f"model.layers.{i}.mlp.down_proj.weight"] = down_kernel

    # Map final normalization
    final_norm = backbone.get_layer("final_normalization").get_weights()[0]
    weights_dict["model.norm.weight"] = torch.from_numpy(final_norm)

    # Tie lm_head.weight to embedding weights
    weights_dict["lm_head.weight"] = weights_dict[
        "model.embed_tokens.weight"
    ].clone()

    # Save config
    os.makedirs(path, exist_ok=True)
    config_path = os.path.join(path, "config.json")
    with open(config_path, "w") as f:
        json.dump(hf_config, f)

    # Make tensors contiguous before saving
    weights_dict_contiguous = {
        k: v.contiguous() for k, v in weights_dict.items()
    }

    # Save weights
    weights_path = os.path.join(path, "model.safetensors")
    save_file(weights_dict_contiguous, weights_path)

    # Save tokenizer assets
    keras_model.preprocessor.tokenizer.save_assets(path)

    # Rename vocabulary file
    vocab_spm_path = os.path.join(path, "vocabulary.spm")
    tokenizer_model_path = os.path.join(path, "tokenizer.model")
    if os.path.exists(vocab_spm_path):
        shutil.move(vocab_spm_path, tokenizer_model_path)
    else:
        warnings.warn(
            f"{vocab_spm_path} not found. Tokenizer may not load correctly."
        )
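For context, a minimal usage sketch of the exporter as written, assuming the `gemma_2b_en` preset is available and `./gemma_2b_hf` is a writable output directory (both illustrative):

```python
from transformers import GemmaForCausalLM

from keras_hub.src.models.gemma.gemma_causal_lm import GemmaCausalLM
from keras_hub.src.utils.transformers.export_gemma_to_safetensor import (
    export_to_hf,
)

# Convert a Keras Gemma checkpoint to Hugging Face format on disk.
keras_model = GemmaCausalLM.from_preset("gemma_2b_en")
export_to_hf(keras_model, "./gemma_2b_hf")

# The directory now contains config.json, model.safetensors, and
# tokenizer.model, so transformers can load it directly.
hf_model = GemmaForCausalLM.from_pretrained("./gemma_2b_hf")
```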
keras_hub/src/utils/transformers/export_gemma_to_safetensor_test.py
@@ -0,0 +1,44 @@
import os

import pytest
import torch
from transformers import GemmaForCausalLM
from transformers import GemmaTokenizer

from keras_hub.src.models.gemma.gemma_causal_lm import GemmaCausalLM
from keras_hub.src.tests.test_case import TestCase
from keras_hub.src.utils.transformers.export_gemma_to_safetensor import (
    export_to_hf,
)


class TestGemmaExport(TestCase):
    @pytest.mark.large
    def test_export_to_hf(self):
        # Load Keras model
        keras_model = GemmaCausalLM.from_preset("gemma_2b_en")
        input_text = "All hail RCB"
        max_length = 25

        # Export to Hugging Face format in a temporary directory
        export_path = os.path.join(self.get_temp_dir(), "export_to_hf")
        export_to_hf(keras_model, export_path)

        # Load Hugging Face model and tokenizer
        hf_model = GemmaForCausalLM.from_pretrained(export_path)
        hf_tokenizer = GemmaTokenizer.from_pretrained(export_path)

        # Generate text with Keras model
        keras_output = keras_model.generate(input_text, max_length=max_length)

        # Generate text with Hugging Face model
        hf_inputs = hf_tokenizer(input_text, return_tensors="pt")
        with torch.no_grad():
            hf_outputs = hf_model.generate(
                **hf_inputs, max_length=max_length, do_sample=False
            )
        hf_output_text = hf_tokenizer.decode(
            hf_outputs[0], skip_special_tokens=True
        )

        self.assertEqual(keras_output, hf_output_text)
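A cheaper complementary assertion one could add inside this test (hypothetical, not part of this PR): reload the exported safetensors file and confirm a known tensor round-trips exactly. This fails fast without running generation. Sketch, reusing `keras_model` and `export_path` from the test above:

```python
import numpy as np
from safetensors.torch import load_file

# Hypothetical extra check: the exported embedding matrix should match the
# Keras layer weights exactly after the numpy -> torch -> disk round trip.
exported = load_file(os.path.join(export_path, "model.safetensors"))
keras_embedding = keras_model.backbone.get_layer(
    "token_embedding"
).get_weights()[0]
np.testing.assert_allclose(
    exported["model.embed_tokens.weight"].numpy(), keras_embedding
)
```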
@@ -18,3 +18,4 @@ sentencepiece
tensorflow-datasets
safetensors
pillow
transformers
Review comment: Does this work on all backends? Or do we need to flip between versions depending on the backend? Worth testing out.
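One way to probe that concern, as a sketch and not part of this PR: Keras 3 selects its backend from the `KERAS_BACKEND` environment variable before import, so the export test could be run once per backend (for example, in separate CI jobs). The exporter itself only applies torch ops to numpy arrays pulled out of Keras via `get_weights()`, so in principle the active Keras backend shouldn't matter, but that is an assumption worth verifying:

```python
import os

# Must be set before Keras is imported anywhere in the process.
os.environ["KERAS_BACKEND"] = "jax"  # or "tensorflow", "torch"

import keras  # noqa: E402

assert keras.config.backend() == "jax"
# ...then run the export test under this backend.
```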