Changes from all commits
76 commits
f2dedc4
add first few utils
DavidLandup0 Jul 14, 2025
1d90715
add eager attention forward
DavidLandup0 Jul 14, 2025
e5a8f33
Add SmolLM3Attention
DavidLandup0 Jul 14, 2025
54191ca
Add SmolLM3MLP
DavidLandup0 Jul 14, 2025
1369733
Add SmolLM3DecoderLayer
DavidLandup0 Jul 14, 2025
2448d80
remove unnecessary comments
DavidLandup0 Jul 14, 2025
598fd74
Add SmolLM3RotaryEmbedding
DavidLandup0 Jul 14, 2025
b9e458d
add most of smollm3backbone
DavidLandup0 Jul 14, 2025
6a53a7d
Fix calls within causal model
DavidLandup0 Jul 16, 2025
81eff73
Move causal mask computation to forward call
DavidLandup0 Jul 16, 2025
b0080f2
Add convert_smollm3.py and update preset loader
DavidLandup0 Jul 16, 2025
d5767c1
Fix causal mask call
DavidLandup0 Jul 16, 2025
186eaf8
Fix conversion weight names
DavidLandup0 Jul 16, 2025
6ab2e5c
remove unnecessary arg
DavidLandup0 Jul 16, 2025
6819fd1
Build all layers
DavidLandup0 Jul 16, 2025
e126938
Remove k and q norms
DavidLandup0 Jul 16, 2025
26511b2
add causal attn mask, a few fixes
DavidLandup0 Jul 16, 2025
d81e831
add softmax op
DavidLandup0 Jul 26, 2025
e07e848
fix build cache shape?
DavidLandup0 Jul 26, 2025
e25fcdd
fix shape positioning in cache update
DavidLandup0 Jul 26, 2025
5a49ed6
Remove position ids as input
DavidLandup0 Jul 26, 2025
89391d9
use sampler's max length
DavidLandup0 Jul 26, 2025
7a9d99c
format
DavidLandup0 Jul 26, 2025
e3067a5
add logs
DavidLandup0 Jul 26, 2025
7622315
switch order or value heads and max length
DavidLandup0 Jul 26, 2025
982a546
oh god please
DavidLandup0 Jul 26, 2025
7319f48
oh god please
DavidLandup0 Jul 26, 2025
3c3d7fb
oh god please
DavidLandup0 Jul 26, 2025
8046d4b
oh god please
DavidLandup0 Jul 26, 2025
2d4a3b5
oh god please
DavidLandup0 Jul 26, 2025
53efb59
god has answered my prayers
DavidLandup0 Jul 26, 2025
c136080
Simplify position ids
DavidLandup0 Aug 5, 2025
7b7ebbb
Simplify position ids
DavidLandup0 Aug 5, 2025
4148384
Use existing rotary embeddings
DavidLandup0 Aug 5, 2025
d9a0f7a
Use existing rotary embeddings
DavidLandup0 Aug 5, 2025
58e87f6
Use existing rotary embeddings
DavidLandup0 Aug 5, 2025
5a6fb27
pass dtype policy
DavidLandup0 Aug 5, 2025
e17dd99
pass dtype policy
DavidLandup0 Aug 5, 2025
4c4e1e0
pass dtype policy
DavidLandup0 Aug 5, 2025
c8b7423
pass dtype policy
DavidLandup0 Aug 5, 2025
8aebfd1
refactor rotary embeddings
DavidLandup0 Aug 5, 2025
2c674dc
refactor rotary embeddings
DavidLandup0 Aug 5, 2025
06472bb
refactor rotary embeddings
DavidLandup0 Aug 5, 2025
f913179
refactor rotary embeddings
DavidLandup0 Aug 5, 2025
a663a5c
refactor rotary embeddings
DavidLandup0 Aug 5, 2025
630cc70
log cache_update_index
DavidLandup0 Aug 5, 2025
de79b8d
rotary embed in loop
DavidLandup0 Aug 5, 2025
fc5974d
log cache_update_index
DavidLandup0 Aug 5, 2025
3575636
rotary embed in loop
DavidLandup0 Aug 5, 2025
c71ea2e
small refactor
DavidLandup0 Aug 16, 2025
bb905f3
add logging
DavidLandup0 Aug 16, 2025
edbb757
more logging
DavidLandup0 Aug 16, 2025
6cf8422
don't reshape unnecessarily in compute_kv
DavidLandup0 Aug 16, 2025
7b193e5
don't reshape unnecessarily in compute_kv
DavidLandup0 Aug 16, 2025
3ede850
don't reshape unnecessarily in compute_kv
DavidLandup0 Aug 16, 2025
327e2bf
don't reshape unnecessarily in compute_kv
DavidLandup0 Aug 16, 2025
a798d35
don't reshape unnecessarily in compute_kv
DavidLandup0 Aug 16, 2025
9b5cd11
don't reshape unnecessarily in compute_kv
DavidLandup0 Aug 16, 2025
308a682
don't reshape unnecessarily in compute_kv
DavidLandup0 Aug 16, 2025
6c65160
don't reshape unnecessarily in compute_kv
DavidLandup0 Aug 16, 2025
a56474c
logging
DavidLandup0 Aug 17, 2025
973b0e5
adjust how rope is applied
DavidLandup0 Aug 17, 2025
b940997
adjust how rope is applied
DavidLandup0 Aug 17, 2025
98daba9
switch to kerashub rotaryembedding
DavidLandup0 Aug 17, 2025
783f8d7
switch to kerashub rotaryembedding
DavidLandup0 Aug 17, 2025
97aea00
remove reshape
DavidLandup0 Aug 17, 2025
c31c889
fix reshape
DavidLandup0 Aug 17, 2025
3aaa3e1
new attention computation
DavidLandup0 Aug 17, 2025
46eed1a
new attention computation
DavidLandup0 Aug 17, 2025
72fabf4
new attention computation
DavidLandup0 Aug 17, 2025
66170cb
new attention computation
DavidLandup0 Aug 17, 2025
ceab147
new attention computation
DavidLandup0 Aug 17, 2025
0ded2ae
new attention computation
DavidLandup0 Aug 17, 2025
93bc8e8
new attention computation
DavidLandup0 Aug 17, 2025
af85773
slight cleanup
DavidLandup0 Aug 17, 2025
f66846b
slight cleanup
DavidLandup0 Aug 17, 2025
24 changes: 24 additions & 0 deletions keras_hub/api/models/__init__.py
@@ -576,6 +576,30 @@
from keras_hub.src.models.siglip.siglip_vision_encoder import (
SigLIPVisionEncoder as SigLIPVisionEncoder,
)
from keras_hub.src.models.smollm3.smollm3_backbone import (
SmolLM3Backbone as SmolLM3Backbone,
)
from keras_hub.src.models.smollm3.smollm3_backbone import (
SmolLM3Backbone as SmolLMBackbone,
)
from keras_hub.src.models.smollm3.smollm3_causal_lm import (
SmolLM3CausalLM as SmolLM3CausalLM,
)
from keras_hub.src.models.smollm3.smollm3_causal_lm import (
SmolLM3CausalLM as SmolLMCausalLM,
)
from keras_hub.src.models.smollm3.smollm3_causal_lm_preprocessor import (
SmolLM3CausalLMPreprocessor as SmolLM3CausalLMPreprocessor,
)
from keras_hub.src.models.smollm3.smollm3_causal_lm_preprocessor import (
SmolLM3CausalLMPreprocessor as SmolLMCausalLMPreprocessor,
)
from keras_hub.src.models.smollm3.smollm3_tokenizer import (
SmolLM3Tokenizer as SmolLM3Tokenizer,
)
from keras_hub.src.models.smollm3.smollm3_tokenizer import (
SmolLM3Tokenizer as SmolLMTokenizer,
)
from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_backbone import (
StableDiffusion3Backbone as StableDiffusion3Backbone,
)
6 changes: 6 additions & 0 deletions keras_hub/api/tokenizers/__init__.py
@@ -86,6 +86,12 @@
from keras_hub.src.models.siglip.siglip_tokenizer import (
SigLIPTokenizer as SigLIPTokenizer,
)
from keras_hub.src.models.smollm3.smollm3_tokenizer import (
SmolLM3Tokenizer as SmolLM3Tokenizer,
)
from keras_hub.src.models.smollm3.smollm3_tokenizer import (
SmolLM3Tokenizer as SmolLMTokenizer,
)
from keras_hub.src.models.t5.t5_tokenizer import T5Tokenizer as T5Tokenizer
from keras_hub.src.models.whisper.whisper_tokenizer import (
WhisperTokenizer as WhisperTokenizer,
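For readers tracking the aliased exports in the two files above, a minimal sketch (assuming this branch is installed; not verified here) of how the paired names are intended to resolve to the same classes:

```python
import keras_hub

# The SmolLM3-prefixed and SmolLM-prefixed names are exported as aliases
# of the same underlying classes in this PR.
assert keras_hub.models.SmolLM3Backbone is keras_hub.models.SmolLMBackbone
assert keras_hub.models.SmolLM3CausalLM is keras_hub.models.SmolLMCausalLM
assert keras_hub.tokenizers.SmolLM3Tokenizer is keras_hub.tokenizers.SmolLMTokenizer
```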
1 change: 1 addition & 0 deletions keras_hub/src/models/qwen3/qwen3_causal_lm.py
@@ -193,6 +193,7 @@ def call_with_cache(
self_attention_cache=current_cache,
self_attention_cache_update_index=cache_update_index,
)
#print(next_cache.shape)

medium

Please remove the commented-out debug print statement.

updated_cache.append(next_cache)
cache = ops.stack(updated_cache, axis=1)
hidden_states = x = self.backbone.layer_norm(x)
169 changes: 169 additions & 0 deletions keras_hub/src/models/smollm3/smollm3_backbone.py
@@ -0,0 +1,169 @@
import keras
from keras import ops

from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.layers.modeling.reversible_embedding import (
ReversibleEmbedding,
)
from keras_hub.src.models.backbone import Backbone
from keras_hub.src.models.smollm3.smollm3_layers import SmolLM3DecoderLayer
from keras_hub.src.models.smollm3.smollm3_layers import SmolLM3RotaryEmbedding

medium

The SmolLM3RotaryEmbedding class is imported but never used in this file. Please remove the unused import to keep the code clean.



@keras_hub_export(
[
"keras_hub.models.SmolLM3Backbone",
"keras_hub.models.SmolLMBackbone",
]
)
class SmolLM3Backbone(Backbone):
"""
The SmolLM Transformer core architecture with hyperparameters.

This network implements a Transformer-based decoder network,
SmolLM3, as described in the SmolLM3 model architecture.
It includes the embedding lookups and transformer layers.

The default constructor gives a fully customizable, randomly initialized
SmolLM3 model with any number of layers, heads, and embedding
dimensions. To load preset architectures and weights, use the `from_preset`
constructor.

Args:


Examples:

```python
input_data = {
"token_ids": np.ones(shape=(1, 12), dtype="int32"),
"padding_mask": np.array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]),
}

# Pretrained SmolLM decoder.
model = keras_hub.models.SmolLM3Backbone.from_preset("...")
model(input_data)

# Randomly initialized SmolLM3 decoder with custom config.
model = keras_hub.models.SmolLM3Backbone(
...
)
model(input_data)
```
"""
Comment on lines +20 to +53

medium

The class docstring is incomplete. Please fill out the Args section with descriptions for all __init__ parameters. The examples for from_preset and the custom config also use placeholders like "...". Providing concrete examples is crucial for users to understand how to use the model. 1

Style Guide References

Footnotes

  1. Docstrings should be comprehensive, including documenting all parameters and providing complete usage examples. (link)
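To make that review point concrete, here is a minimal sketch of the kind of runnable example the docstring could carry. Every config value below is a small, made-up number for illustration only (not a real SmolLM3 preset), and the from_preset name is left out because no preset is registered in this diff:

```python
import numpy as np
import keras_hub

# Randomly initialized SmolLM3 backbone with a tiny, illustrative config.
backbone = keras_hub.models.SmolLM3Backbone(
    vocabulary_size=1000,
    hidden_dim=64,
    intermediate_dim=128,
    num_layers=2,
    num_attention_heads=4,
    num_key_value_heads=2,
    attention_bias=False,
    attention_dropout=0.0,
    rope_layer_enabled_list=[True, True],    # hypothetical: RoPE on both layers
    layer_types=["attention", "attention"],  # hypothetical layer-type labels
    mlp_bias=False,
    layer_norm_epsilon=1e-6,
    max_position_embeddings=128,
    rope_theta=10000.0,
    partial_rotary_factor=1.0,
)
input_data = {
    "token_ids": np.ones((1, 12), dtype="int32"),
    "padding_mask": np.array([[1] * 10 + [0] * 2], dtype="int32"),
}
output = backbone(input_data)  # expected shape: (1, 12, 64)
```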


def __init__(
self,
vocabulary_size,
hidden_dim,
intermediate_dim,
num_layers,
num_attention_heads,
num_key_value_heads,
attention_bias,
attention_dropout,
rope_layer_enabled_list,
layer_types,
mlp_bias,
layer_norm_epsilon,
Member

Usually some of these terms (like the epsilons and rope theta) have a consistent value across all the presets we care about, and we give them defaults here (sketched just below the signature). Not super important, just for people who want an easier time making a custom small version of the arch or something like that.

max_position_embeddings,
rope_theta,
partial_rotary_factor,
**kwargs,
):
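Following up on the member comment above, a hypothetical sketch of how the signature could give defaults to the hyperparameters that are constant across the presets of interest. Every default value shown is an illustrative placeholder, not taken from an actual SmolLM3 preset:

```python
def __init__(
    self,
    vocabulary_size,
    hidden_dim,
    intermediate_dim,
    num_layers,
    num_attention_heads,
    num_key_value_heads,
    attention_bias=False,          # placeholder default
    attention_dropout=0.0,         # placeholder default
    rope_layer_enabled_list=None,  # hypothetical: None could mean RoPE on all layers
    layer_types=None,
    mlp_bias=False,                # placeholder default
    layer_norm_epsilon=1e-6,       # placeholder default
    max_position_embeddings=2048,  # placeholder default
    rope_theta=10000.0,            # placeholder default
    partial_rotary_factor=1.0,     # placeholder default
    **kwargs,
):
    ...
```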
# === Layers ===
self.token_embedding = ReversibleEmbedding(
input_dim=vocabulary_size,
output_dim=hidden_dim,
name="token_embedding",
)
self.transformer_layers = []
for i in range(num_layers):
layer = SmolLM3DecoderLayer(
hidden_size=hidden_dim,
num_attention_heads=num_attention_heads,
num_key_value_heads=num_key_value_heads,
attention_bias=attention_bias,
attention_dropout=attention_dropout,
rope_layer_enabled_list=rope_layer_enabled_list,
layer_types=layer_types,
layer_idx=i,
intermediate_size=intermediate_dim,
mlp_bias=mlp_bias,
layer_norm_epsilon=layer_norm_epsilon,
name=f"transformer_layer_{i}",
)
self.transformer_layers.append(layer)

self.norm = keras.layers.RMSNormalization(
epsilon=layer_norm_epsilon,
name="sequence_output_layernorm",
)

medium

This line contains only whitespace and should be removed.

# === Functional Model ===
token_id_input = keras.Input(
shape=(None,), dtype="int32", name="token_ids"
)

padding_mask_input = keras.Input(
shape=(None,), dtype="int32", name="padding_mask"
)

x = self.token_embedding(token_id_input)

for decoder_layer in self.transformer_layers:
x = decoder_layer(
x,
decoder_padding_mask=padding_mask_input,
**kwargs,
)

sequence_output = self.norm(x)
super().__init__(
inputs={
"token_ids": token_id_input,
"padding_mask": padding_mask_input,
},
outputs=sequence_output,
**kwargs,
)

# === Config ===
self.vocabulary_size = vocabulary_size
self.hidden_dim = hidden_dim
self.intermediate_dim = intermediate_dim
self.num_layers = num_layers
self.num_attention_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
self.rope_layer_enabled_list = rope_layer_enabled_list
self.layer_types = layer_types
self.mlp_bias = mlp_bias
self.layer_norm_epsilon = layer_norm_epsilon
self.max_position_embeddings = max_position_embeddings
self.rope_theta = rope_theta
self.partial_rotary_factor = partial_rotary_factor

def get_config(self):
config = super().get_config()
config.update(
{
"vocabulary_size": self.vocabulary_size,
"hidden_dim": self.hidden_dim,
"intermediate_dim": self.intermediate_dim,
"num_layers": self.num_layers,
"num_attention_heads": self.num_attention_heads,
"num_key_value_heads": self.num_key_value_heads,
"attention_bias": self.attention_bias,
"attention_dropout": self.attention_dropout,
"rope_layer_enabled_list": self.rope_layer_enabled_list,
"layer_types": self.layer_types,
"mlp_bias": self.mlp_bias,
"layer_norm_epsilon": self.layer_norm_epsilon,
"max_position_embeddings": self.max_position_embeddings,
"rope_theta": self.rope_theta,
"partial_rotary_factor": self.partial_rotary_factor,
}
)
return config