Skip to content

Add T5Gemma to KerasHub #2339

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 19 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
071d0df
init: Add initial project structure and files
harshaljanjani Jul 19, 2025
1c9ebbc
nit: Fix code format test; and cool AI-generated reviews
harshaljanjani Jul 19, 2025
1c7dc13
refactor: Cleanup and replace incorrect T5LayerNorm with RMSNormaliza…
harshaljanjani Jul 21, 2025
41910d3
fix: Numerics @ atol=1e-4
harshaljanjani Jul 22, 2025
a8eb53c
refactor: Refactor T5Gemma decoder cache handling
harshaljanjani Jul 23, 2025
95f563b
feat: Add checkpoint conversion script
harshaljanjani Jul 23, 2025
afb9845
nit: Precise compute_output_shape methods; document head_dim
harshaljanjani Jul 24, 2025
5be6438
nit: Propagate dtypes
harshaljanjani Jul 24, 2025
3dbc0b7
bug fix + minor cleanup: Fix head_dim default → head_dim from config
harshaljanjani Jul 24, 2025
291d8f1
perf(jax/tpu): Fused kernel optim for TPU backend + get_config() args
harshaljanjani Jul 25, 2025
524aa37
cleanup: Slight refactor
harshaljanjani Jul 25, 2025
c1af495
Merge branch 'keras-team:master' into t5gemma
harshaljanjani Jul 26, 2025
889e23b
fix: Enable mixed precision and quantization tests
harshaljanjani Jul 30, 2025
32a6912
feat: Add support for asymmetrical presets (only invariants included)
harshaljanjani Jul 30, 2025
050910b
refactor: Address reviews - presets will be handled post D-FINE
harshaljanjani Aug 6, 2025
6b320fa
feat: Support direct loading of Hugging Face checkpoints
harshaljanjani Aug 17, 2025
26db4d1
✅ Yayy: Generate outputs identical, hidden states match within 1e-3
harshaljanjani Aug 21, 2025
87a221d
preset test: Register and test a preset (to be replaced later by the …
harshaljanjani Aug 22, 2025
9c79058
nit: Sharded weights don’t include `model.weights.h5`
harshaljanjani Aug 24, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions keras_hub/api/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -599,6 +599,18 @@
T5Preprocessor as T5Preprocessor,
)
from keras_hub.src.models.t5.t5_tokenizer import T5Tokenizer as T5Tokenizer
from keras_hub.src.models.t5gemma.t5gemma_backbone import (
T5GemmaBackbone as T5GemmaBackbone,
)
from keras_hub.src.models.t5gemma.t5gemma_seq_2_seq_lm import (
T5GemmaSeq2SeqLM as T5GemmaSeq2SeqLM,
)
from keras_hub.src.models.t5gemma.t5gemma_seq_2_seq_lm_preprocessor import (
T5GemmaSeq2SeqLMPreprocessor as T5GemmaSeq2SeqLMPreprocessor,
)
from keras_hub.src.models.t5gemma.t5gemma_tokenizer import (
T5GemmaTokenizer as T5GemmaTokenizer,
)
from keras_hub.src.models.task import Task as Task
from keras_hub.src.models.text_classifier import TextClassifier as Classifier
from keras_hub.src.models.text_classifier import (
Expand Down
3 changes: 3 additions & 0 deletions keras_hub/api/tokenizers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,9 @@
SigLIPTokenizer as SigLIPTokenizer,
)
from keras_hub.src.models.t5.t5_tokenizer import T5Tokenizer as T5Tokenizer
from keras_hub.src.models.t5gemma.t5gemma_tokenizer import (
T5GemmaTokenizer as T5GemmaTokenizer,
)
from keras_hub.src.models.whisper.whisper_tokenizer import (
WhisperTokenizer as WhisperTokenizer,
)
Expand Down
5 changes: 5 additions & 0 deletions keras_hub/src/models/t5gemma/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from keras_hub.src.models.t5gemma.t5gemma_backbone import T5GemmaBackbone
from keras_hub.src.models.t5gemma.t5gemma_presets import backbone_presets
from keras_hub.src.utils.preset_utils import register_presets

register_presets(backbone_presets, T5GemmaBackbone)
Loading
Loading