[WIP] [SmolLM3] Add Backbone and CausalLM #2327
@@ -0,0 +1,169 @@
import keras
from keras import ops

from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.layers.modeling.reversible_embedding import (
    ReversibleEmbedding,
)
from keras_hub.src.models.backbone import Backbone
from keras_hub.src.models.smollm3.smollm3_layers import SmolLM3DecoderLayer
from keras_hub.src.models.smollm3.smollm3_layers import SmolLM3RotaryEmbedding


@keras_hub_export(
    [
        "keras_hub.models.SmolLM3Backbone",
        "keras_hub.models.SmolLMBackbone",
    ]
)
class SmolLM3Backbone(Backbone):
    """The SmolLM3 Transformer core architecture with hyperparameters.

    This network implements a Transformer-based decoder network, SmolLM3,
    and includes the embedding lookups and transformer layers.

    The default constructor gives a fully customizable, randomly initialized
    SmolLM3 model with any number of layers, heads, and embedding
    dimensions. To load preset architectures and weights, use the
    `from_preset` constructor.

    Args:

    Examples:

    ```python
    input_data = {
        "token_ids": np.ones(shape=(1, 12), dtype="int32"),
        "padding_mask": np.array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]),
    }

    # Pretrained SmolLM3 decoder.
    model = keras_hub.models.SmolLM3Backbone.from_preset("...")
    model(input_data)

    # Randomly initialized SmolLM3 decoder with custom config.
    model = keras_hub.models.SmolLM3Backbone(
        ...
    )
    model(input_data)
    ```
    """
Comment on lines +20 to +53: The class docstring is incomplete. Please fill it out, following the Style Guide.
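A possible shape for the missing `Args` section, inferred from the constructor signature below; the wording is an illustrative sketch, not content from this PR:

```python
# Hypothetical sketch of the missing `Args` docstring section, inferred
# from the constructor parameters; descriptions are illustrative only.
"""
Args:
    vocabulary_size: int. The size of the token vocabulary.
    hidden_dim: int. The size of the transformer hidden state.
    intermediate_dim: int. The output dimension of the first Dense layer
        in the MLP block of each transformer layer.
    num_layers: int. The number of transformer layers.
    num_attention_heads: int. The number of attention heads per layer.
    num_key_value_heads: int. The number of key/value heads used for
        grouped-query attention.
    attention_bias: bool. Whether to use a bias term in the attention
        projections.
    attention_dropout: float. Dropout rate applied to attention scores.
    rope_layer_enabled_list: list of bool. Per-layer flags controlling
        whether rotary position embeddings are applied.
    layer_types: list of str. Per-layer attention type identifiers.
    mlp_bias: bool. Whether to use a bias term in the MLP projections.
    layer_norm_epsilon: float. Epsilon for RMS layer normalization.
    max_position_embeddings: int. The maximum supported sequence length.
    rope_theta: float. The base frequency for rotary embeddings.
    partial_rotary_factor: float. The fraction of head dimensions that
        receive rotary embeddings.
"""
```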
    def __init__(
        self,
        vocabulary_size,
        hidden_dim,
        intermediate_dim,
        num_layers,
        num_attention_heads,
        num_key_value_heads,
        attention_bias,
        attention_dropout,
        rope_layer_enabled_list,
        layer_types,
        mlp_bias,
        layer_norm_epsilon,
Comment: Usually some of these terms (like the epsilons and rope theta) have a consistent value across all the presets we care about, and we give them defaults here. Not super important; it just makes life easier for anyone building a custom small version of the architecture. See the sketch below.
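A fragment sketching the suggestion above; these exact defaults are hypothetical, not taken from this PR or from the SmolLM3 presets:

```python
# Illustrative only: give commonly shared hyperparameters defaults in the
# signature. These exact values are hypothetical, not preset values.
attention_bias=False,
attention_dropout=0.0,
layer_norm_epsilon=1e-6,
rope_theta=10000.0,
```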
        max_position_embeddings,
        rope_theta,
        partial_rotary_factor,
        **kwargs,
    ):
        # === Layers ===
        self.token_embedding = ReversibleEmbedding(
            input_dim=vocabulary_size,
            output_dim=hidden_dim,
            name="token_embedding",
        )
        self.transformer_layers = []
        for i in range(num_layers):
            layer = SmolLM3DecoderLayer(
                hidden_size=hidden_dim,
                num_attention_heads=num_attention_heads,
                num_key_value_heads=num_key_value_heads,
                attention_bias=attention_bias,
                attention_dropout=attention_dropout,
                rope_layer_enabled_list=rope_layer_enabled_list,
                layer_types=layer_types,
                layer_idx=i,
                intermediate_size=intermediate_dim,
                mlp_bias=mlp_bias,
                layer_norm_epsilon=layer_norm_epsilon,
                name=f"transformer_layer_{i}",
            )
            self.transformer_layers.append(layer)

        self.norm = keras.layers.RMSNormalization(
            epsilon=layer_norm_epsilon,
            name="sequence_output_layernorm",
        )
        # === Functional Model ===
        token_id_input = keras.Input(
            shape=(None,), dtype="int32", name="token_ids"
        )
        padding_mask_input = keras.Input(
            shape=(None,), dtype="int32", name="padding_mask"
        )

        x = self.token_embedding(token_id_input)
        for decoder_layer in self.transformer_layers:
            # NOTE: `kwargs` are reserved for the `Backbone` constructor
            # below, so they are not forwarded to the decoder layers.
            x = decoder_layer(
                x,
                decoder_padding_mask=padding_mask_input,
            )

        sequence_output = self.norm(x)
        super().__init__(
            inputs={
                "token_ids": token_id_input,
                "padding_mask": padding_mask_input,
            },
            outputs=sequence_output,
            **kwargs,
        )
        # === Config ===
        self.vocabulary_size = vocabulary_size
        self.hidden_dim = hidden_dim
        self.intermediate_dim = intermediate_dim
        self.num_layers = num_layers
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        self.rope_layer_enabled_list = rope_layer_enabled_list
        self.layer_types = layer_types
        self.mlp_bias = mlp_bias
        self.layer_norm_epsilon = layer_norm_epsilon
        self.max_position_embeddings = max_position_embeddings
        self.rope_theta = rope_theta
        self.partial_rotary_factor = partial_rotary_factor

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "vocabulary_size": self.vocabulary_size,
                "hidden_dim": self.hidden_dim,
                "intermediate_dim": self.intermediate_dim,
                "num_layers": self.num_layers,
                "num_attention_heads": self.num_attention_heads,
                "num_key_value_heads": self.num_key_value_heads,
                "attention_bias": self.attention_bias,
                "attention_dropout": self.attention_dropout,
                "rope_layer_enabled_list": self.rope_layer_enabled_list,
                "layer_types": self.layer_types,
                "mlp_bias": self.mlp_bias,
                "layer_norm_epsilon": self.layer_norm_epsilon,
                "max_position_embeddings": self.max_position_embeddings,
                "rope_theta": self.rope_theta,
                "partial_rotary_factor": self.partial_rotary_factor,
            }
        )
        return config
Comment: Please remove the commented-out debug print statement.
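For quick reference, a minimal smoke test of the backbone as defined in this diff. All hyperparameter values below are hypothetical (chosen small for speed), not taken from any SmolLM3 preset:

```python
import numpy as np

import keras_hub

# Illustrative only: a tiny, randomly initialized SmolLM3Backbone.
# Every hyperparameter value below is hypothetical, not a preset value.
model = keras_hub.models.SmolLM3Backbone(
    vocabulary_size=1000,
    hidden_dim=64,
    intermediate_dim=128,
    num_layers=2,
    num_attention_heads=4,
    num_key_value_heads=2,
    attention_bias=False,
    attention_dropout=0.0,
    rope_layer_enabled_list=[True, True],
    layer_types=["attention", "attention"],  # hypothetical type names
    mlp_bias=False,
    layer_norm_epsilon=1e-6,
    max_position_embeddings=512,
    rope_theta=10000.0,
    partial_rotary_factor=1.0,
)
input_data = {
    "token_ids": np.ones((1, 12), dtype="int32"),
    "padding_mask": np.ones((1, 12), dtype="int32"),
}
hidden_states = model(input_data)  # shape: (1, 12, 64)

# Because `get_config` mirrors the constructor, the backbone round-trips
# through config-based serialization.
restored = keras_hub.models.SmolLM3Backbone.from_config(model.get_config())
```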