diff --git a/keras_hub/api/models/__init__.py b/keras_hub/api/models/__init__.py index 9ad3aeb204..4ddcbc4521 100644 --- a/keras_hub/api/models/__init__.py +++ b/keras_hub/api/models/__init__.py @@ -576,6 +576,30 @@ from keras_hub.src.models.siglip.siglip_vision_encoder import ( SigLIPVisionEncoder as SigLIPVisionEncoder, ) +from keras_hub.src.models.smollm3.smollm3_backbone import ( + SmolLM3Backbone as SmolLM3Backbone, +) +from keras_hub.src.models.smollm3.smollm3_backbone import ( + SmolLM3Backbone as SmolLMBackbone, +) +from keras_hub.src.models.smollm3.smollm3_causal_lm import ( + SmolLM3CausalLM as SmolLM3CausalLM, +) +from keras_hub.src.models.smollm3.smollm3_causal_lm import ( + SmolLM3CausalLM as SmolLMCausalLM, +) +from keras_hub.src.models.smollm3.smollm3_causal_lm_preprocessor import ( + SmolLM3CausalLMPreprocessor as SmolLM3CausalLMPreprocessor, +) +from keras_hub.src.models.smollm3.smollm3_causal_lm_preprocessor import ( + SmolLM3CausalLMPreprocessor as SmolLMCausalLMPreprocessor, +) +from keras_hub.src.models.smollm3.smollm3_tokenizer import ( + SmolLM3Tokenizer as SmolLM3Tokenizer, +) +from keras_hub.src.models.smollm3.smollm3_tokenizer import ( + SmolLM3Tokenizer as SmolLMTokenizer, +) from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_backbone import ( StableDiffusion3Backbone as StableDiffusion3Backbone, ) diff --git a/keras_hub/api/tokenizers/__init__.py b/keras_hub/api/tokenizers/__init__.py index 082078184f..49b4eeab99 100644 --- a/keras_hub/api/tokenizers/__init__.py +++ b/keras_hub/api/tokenizers/__init__.py @@ -86,6 +86,12 @@ from keras_hub.src.models.siglip.siglip_tokenizer import ( SigLIPTokenizer as SigLIPTokenizer, ) +from keras_hub.src.models.smollm3.smollm3_tokenizer import ( + SmolLM3Tokenizer as SmolLM3Tokenizer, +) +from keras_hub.src.models.smollm3.smollm3_tokenizer import ( + SmolLM3Tokenizer as SmolLMTokenizer, +) from keras_hub.src.models.t5.t5_tokenizer import T5Tokenizer as T5Tokenizer from keras_hub.src.models.whisper.whisper_tokenizer import ( WhisperTokenizer as WhisperTokenizer, diff --git a/keras_hub/src/models/qwen3/qwen3_causal_lm.py b/keras_hub/src/models/qwen3/qwen3_causal_lm.py index f2d7b10b16..5d0cb60a58 100644 --- a/keras_hub/src/models/qwen3/qwen3_causal_lm.py +++ b/keras_hub/src/models/qwen3/qwen3_causal_lm.py @@ -193,6 +193,7 @@ def call_with_cache( self_attention_cache=current_cache, self_attention_cache_update_index=cache_update_index, ) + #print(next_cache.shape) updated_cache.append(next_cache) cache = ops.stack(updated_cache, axis=1) hidden_states = x = self.backbone.layer_norm(x) diff --git a/keras_hub/src/models/smollm3/smollm3_backbone.py b/keras_hub/src/models/smollm3/smollm3_backbone.py new file mode 100644 index 0000000000..34de272091 --- /dev/null +++ b/keras_hub/src/models/smollm3/smollm3_backbone.py @@ -0,0 +1,169 @@ +import keras +from keras import ops + +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.layers.modeling.reversible_embedding import ( + ReversibleEmbedding, +) +from keras_hub.src.models.backbone import Backbone +from keras_hub.src.models.smollm3.smollm3_layers import SmolLM3DecoderLayer +from keras_hub.src.models.smollm3.smollm3_layers import SmolLM3RotaryEmbedding + + +@keras_hub_export( + [ + "keras_hub.models.SmolLM3Backbone", + "keras_hub.models.SmolLMBackbone", + ] +) +class SmolLM3Backbone(Backbone): + """ + The SmolLM Transformer core architecture with hyperparameters. 
+ + This network implements a Transformer-based decoder network, + SmolLM3, as described in the SmolLM3 model architecture. + It includes the embedding lookups and transformer layers. + + The default constructor gives a fully customizable, randomly initialized + SmolLM3 model with any number of layers, heads, and embedding + dimensions. To load preset architectures and weights, use the `from_preset` + constructor. + + Args: + + + Examples: + + ```python + input_data = { + "token_ids": np.ones(shape=(1, 12), dtype="int32"), + "padding_mask": np.array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]), + } + + # Pretrained SmolLM decoder. + model = keras_hub.models.SmolLM3Backbone.from_preset("...") + model(input_data) + + # Randomly initialized SmolLM3 decoder with custom config. + model = keras_hub.models.SmolLM3Backbone( + ... + ) + model(input_data) + ``` + """ + + def __init__( + self, + vocabulary_size, + hidden_dim, + intermediate_dim, + num_layers, + num_attention_heads, + num_key_value_heads, + attention_bias, + attention_dropout, + rope_layer_enabled_list, + layer_types, + mlp_bias, + layer_norm_epsilon, + max_position_embeddings, + rope_theta, + partial_rotary_factor, + **kwargs, + ): + # === Layers === + self.token_embedding = ReversibleEmbedding( + input_dim=vocabulary_size, + output_dim=hidden_dim, + name="token_embedding", + ) + self.transformer_layers = [] + for i in range(num_layers): + layer = SmolLM3DecoderLayer( + hidden_size=hidden_dim, + num_attention_heads=num_attention_heads, + num_key_value_heads=num_key_value_heads, + attention_bias=attention_bias, + attention_dropout=attention_dropout, + rope_layer_enabled_list=rope_layer_enabled_list, + layer_types=layer_types, + layer_idx=i, + intermediate_size=intermediate_dim, + mlp_bias=mlp_bias, + layer_norm_epsilon=layer_norm_epsilon, + name=f"transformer_layer_{i}", + ) + self.transformer_layers.append(layer) + + self.norm = keras.layers.RMSNormalization( + epsilon=layer_norm_epsilon, + name="sequence_output_layernorm", + ) + + # === Functional Model === + token_id_input = keras.Input( + shape=(None,), dtype="int32", name="token_ids" + ) + + padding_mask_input = keras.Input( + shape=(None,), dtype="int32", name="padding_mask" + ) + + x = self.token_embedding(token_id_input) + + for decoder_layer in self.transformer_layers: + x = decoder_layer( + x, + decoder_padding_mask=padding_mask_input, + **kwargs, + ) + + sequence_output = self.norm(x) + super().__init__( + inputs={ + "token_ids": token_id_input, + "padding_mask": padding_mask_input, + }, + outputs=sequence_output, + **kwargs, + ) + + # === Config === + self.vocabulary_size = vocabulary_size + self.hidden_dim = hidden_dim + self.intermediate_dim = intermediate_dim + self.num_layers = num_layers + self.num_attention_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + self.rope_layer_enabled_list = rope_layer_enabled_list + self.layer_types = layer_types + self.mlp_bias = mlp_bias + self.layer_norm_epsilon = layer_norm_epsilon + self.max_position_embeddings = max_position_embeddings + self.rope_theta = rope_theta + self.partial_rotary_factor = partial_rotary_factor + + def get_config(self): + config = super().get_config() + config.update( + { + "vocabulary_size": self.vocabulary_size, + "hidden_dim": self.hidden_dim, + "intermediate_dim": self.intermediate_dim, + "num_layers": self.num_layers, + "num_attention_heads": self.num_attention_heads, + "num_key_value_heads": 
self.num_key_value_heads, + "attention_bias": self.attention_bias, + "attention_dropout": self.attention_dropout, + "rope_layer_enabled_list": self.rope_layer_enabled_list, + "layer_types": self.layer_types, + "mlp_bias": self.mlp_bias, + "layer_norm_epsilon": self.layer_norm_epsilon, + "max_position_embeddings": self.max_position_embeddings, + "rope_theta": self.rope_theta, + "partial_rotary_factor": self.partial_rotary_factor, + } + ) + return config diff --git a/keras_hub/src/models/smollm3/smollm3_causal_lm.py b/keras_hub/src/models/smollm3/smollm3_causal_lm.py new file mode 100644 index 0000000000..4fbeec477f --- /dev/null +++ b/keras_hub/src/models/smollm3/smollm3_causal_lm.py @@ -0,0 +1,315 @@ +import keras +from keras import ops + +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.causal_lm import CausalLM +from keras_hub.src.models.smollm3.smollm3_backbone import SmolLM3Backbone +from keras_hub.src.models.smollm3.smollm3_causal_lm_preprocessor import ( + SmolLM3CausalLMPreprocessor, +) +from keras_hub.src.utils.tensor_utils import any_equal + + +@keras_hub_export( + [ + "keras_hub.models.SmolLM3CausalLM", + "keras_hub.models.SmolLMCausalLM", + ] +) +class SmolLM3CausalLM(CausalLM): + backbone_cls = SmolLM3Backbone + preprocessor_cls = SmolLM3CausalLMPreprocessor + + def __init__(self, backbone, preprocessor=None, **kwargs): + # === Layers === + self.backbone = backbone + self.preprocessor = preprocessor + + # === Functional Model === + # This must be "backbone.input" i.e. the full input structure, + # rather than "backbone.inputs" which is the flattened list of inputs. + inputs = backbone.input + hidden_states = backbone(inputs) + outputs = backbone.token_embedding(hidden_states, reverse=True) + super().__init__( + inputs=inputs, + outputs=outputs, + **kwargs, + ) + + def call_with_cache( + self, + token_ids, + cache, + cache_update_index, + ): + """Forward pass of `SmolLM3CausalLM` with cache. + + `call_with_cache` adds an additional forward pass for the model for + autoregressive inference. Unlike calling the model directly, this method + allows caching previous key/value Tensors in multi-head attention layer, + and avoids recomputing the outputs of seen tokens. + + Args: + token_ids: a dense int Tensor with shape `(batch_size, seq_len)`. + For prefill, `seq_len` is the prompt length. For generation, + `seq_len` is typically 1. + cache: a dense float Tensor, the cache of key and value. + Shape: (batch_size, num_layers, 2, max_seq_len, num_key_value_heads, head_dim) + cache_update_index: int, or int Tensor. The index of current inputs + in the whole sequence. + training: Boolean, whether the call is during training or inference. + attention_mask: Optional attention mask. + + Returns: + A (logits, hidden_states, cache) tuple. Where `logits` is the + language model logits for the input token_ids, `hidden_states` is + the final hidden representation of the input tokens, and `cache` is + the decoding cache. + """ + x = self.backbone.token_embedding(token_ids) + + # Each decoder layer has a cache; we update them separately. + updated_cache = [] + + for i in range(self.backbone.num_layers): + current_cache = cache[:, i, ...] 
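+            # `current_cache` is this layer's key/value cache with shape
+            # `(batch_size, 2, max_length, num_key_value_heads, head_dim)`.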
+ x, next_cache = self.backbone.transformer_layers[i]( + x, + self_attention_cache=current_cache, + self_attention_cache_update_index=cache_update_index, + ) + updated_cache.append(next_cache) + cache = ops.stack(updated_cache, axis=1) + hidden_states = x = self.backbone.norm(x) + logits = self.backbone.token_embedding(x, reverse=True) + return logits, hidden_states, cache + + def _build_cache(self, token_ids): + """Build an empty cache for use with `call_with_cache()`.""" + batch_size = ops.shape(token_ids)[0] + max_length = ops.shape(token_ids)[1] + num_layers = self.backbone.num_layers + num_key_value_heads = self.backbone.num_key_value_heads + head_dim = self.backbone.hidden_dim // self.backbone.num_attention_heads + shape = [ + batch_size, + num_layers, + 2, + max_length, + num_key_value_heads, + head_dim, + ] + cache = ops.zeros(shape, dtype=self.compute_dtype) + index = ops.convert_to_tensor(0, dtype="int32") + # Seed the cache. + _, hidden_states, cache = self.call_with_cache(token_ids, cache, index) + return hidden_states, cache + + def generate_step( + self, + inputs, + stop_token_ids=None, + ): + """A compilable generation function for a single batch of inputs. + + This function represents the inner, XLA-compilable, generation function + for a single batch of inputs. Inputs should have the same structure as + model inputs, a dictionary with keys `"token_ids"` and `"padding_mask"`. + + Args: + inputs: A dictionary with two keys `"token_ids"` and + `"padding_mask"` and batched tensor values. + stop_token_ids: Tuple of id's of the end token to stop on. If all + sequences have produced a new stop token, generation + will stop. + """ + token_ids, padding_mask = inputs["token_ids"], inputs["padding_mask"] + + hidden_states, cache = self._build_cache(token_ids) + # Compute the lengths of all user inputted tokens ids. + row_lengths = ops.sum(ops.cast(padding_mask, "int32"), axis=-1) + # Start at the first index that has no user inputted id. + index = ops.min(row_lengths) + + def next(prompt, cache, index): + # The cache index is the index of our previous token. + cache_update_index = index - 1 + batch_size = ops.shape(prompt)[0] + prompt = ops.slice(prompt, [0, cache_update_index], [batch_size, 1]) + logits, hidden_states, cache = self.call_with_cache( + prompt, + cache, + cache_update_index, + ) + return ( + ops.squeeze(logits, axis=1), + ops.squeeze(hidden_states, axis=1), + cache, + ) + + token_ids = self.sampler( + next=next, + prompt=token_ids, + cache=cache, + index=index, + mask=padding_mask, + stop_token_ids=stop_token_ids, + hidden_states=hidden_states, + model=self, + ) + + # Compute an output padding mask with the token ids we updated. + if stop_token_ids is not None: + # Build a mask of stop token locations not in the original + # prompt (not in locations where `padding_mask` is True). + end_locations = any_equal( + token_ids, stop_token_ids, ops.logical_not(padding_mask) + ) + end_locations = ops.cast(end_locations, "int32") + # Use cumsum to get ones in all locations after end_locations. + cumsum = ops.cast(ops.cumsum(end_locations, axis=-1), "int32") + overflow = cumsum - end_locations + # Our padding mask is the inverse of these overflow locations. + padding_mask = ops.logical_not(ops.cast(overflow, "bool")) + else: + # Without early stopping, all locations will have been updated. 
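+            # `generate()` later uses this mask to strip padding before
+            # detokenizing the generated ids.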
+ padding_mask = ops.ones_like(token_ids, dtype="bool") + return { + "token_ids": token_ids, + "padding_mask": padding_mask, + } + + def score( + self, + token_ids, + padding_mask=None, + scoring_mode="logits", + layer_intercept_fn=None, + target_ids=None, + ): + """Score a generation represented by the provided token ids. + + Args: + token_ids: A [batch_size, num_tokens] tensor containing tokens + to score. Typically, this tensor captures the output from a call + to `SmolLM3CausalLM.generate()`, i.e., tokens for both the input + text and the model-generated text. + padding_mask: A [batch_size, num_tokens] tensor indicating the + tokens that should be preserved during generation. This is an + artifact required by the `SmolLM3Backbone` and isn't influential + on the computation of this function. If omitted, this function + uses `keras.ops.ones()` to create a tensor of the appropriate + shape. + scoring_mode: The type of scores to return, either "logits" or + "loss", both will be per input token. + layer_intercept_fn: An optional function for augmenting activations + with additional computation, for example, as part of + interpretability research. This function will be passed the + activations as its first parameter and a numeric index + associated with that backbone layer. _This index _is not_ an + index into `self.backbone.layers`_. The index -1 accompanies the + embeddings returned by calling `self.backbone.token_embedding()` + on `token_ids` in the forward direction. All subsequent indexes + will be 0-based indices for the activations returned by each of + the Transformers layers in the backbone. This function must + return a [batch_size, num_tokens, hidden_dims] tensor + that can be passed as an input to the next layer in the model. + target_ids: An [batch_size, num_tokens] tensor containing the + predicted tokens against which the loss should be computed. If a + span of tokens is provided (sequential truthy values along + axis=1 in the tensor), the loss will be computed as the + aggregate across those tokens. + + Raises: + ValueError: If an unsupported scoring_mode is provided, or if the + target_ids are not provided when using ScoringMode.LOSS. + + Returns: + The per-token scores as a tensor of size + [batch_size, num_tokens, vocab_size] in "logits" mode, or + [batch_size, num_tokens] in "loss" mode. + + Example: + + Compute gradients between embeddings and loss scores with TensorFlow: + ```python + smol_lm = keras_hub.models.SmolLM3CausalLM.from_preset("...") + generations = smol_lm.generate( + ["This is a", "Where are you"], + max_length=30 + ) + preprocessed = smol_lm.preprocessor.generate_preprocess(generations) + generation_ids = preprocessed["token_ids"] + padding_mask = preprocessed["padding_mask"] + target_ids = keras.ops.roll(generation_ids, shift=-1, axis=1) + + embeddings = None + with tf.GradientTape(watch_accessed_variables=True) as tape: + def layer_intercept_fn(x, i): + if i == -1: + nonlocal embeddings, tape + embeddings = x + tape.watch(embeddings) + return x + + losses = smol_lm.score( + token_ids=generation_ids, + padding_mask=padding_mask, + scoring_mode="loss", + layer_intercept_fn=layer_intercept_fn, + target_ids=target_ids, + ) + + grads = tape.gradient(losses, embeddings) + ``` + """ + if scoring_mode not in ("logits", "loss"): + raise ValueError( + "Unsupported scoring_mode. Must be one of 'logits' or 'loss'." + ) + + if scoring_mode == "loss" and target_ids is None: + raise ValueError( + "Cannot compute loss without targets. 
Please provide target "
+                "token ids via the target_ids parameter."
+            )
+
+        batch_shape = ops.shape(token_ids)[:2]
+        assert len(batch_shape) == 2
+
+        if padding_mask is None:
+            padding_mask = ops.ones(shape=batch_shape)
+
+        if layer_intercept_fn is None:
+
+            def default_layer_intercept_fn(x, unused_i):
+                return x
+
+            layer_intercept_fn = default_layer_intercept_fn
+
+        token_embeddings = self.backbone.token_embedding(token_ids)
+        x = layer_intercept_fn(token_embeddings, -1)
+
+        for i, transformer_layer in enumerate(self.backbone.transformer_layers):
+            # Rotary position embeddings are computed inside each attention
+            # layer, so only the padding mask needs to be forwarded here.
+            x = transformer_layer(x, decoder_padding_mask=padding_mask)
+            x = layer_intercept_fn(x, i)
+
+        x = self.backbone.norm(x)
+        logits = self.backbone.token_embedding(x, reverse=True)
+
+        if scoring_mode == "logits":
+            return logits
+
+        per_token_loss_fn = keras.losses.SparseCategoricalCrossentropy(
+            from_logits=True, reduction="none"
+        )
+        per_token_loss = per_token_loss_fn(target_ids, logits)
+        return per_token_loss
diff --git a/keras_hub/src/models/smollm3/smollm3_causal_lm_preprocessor.py b/keras_hub/src/models/smollm3/smollm3_causal_lm_preprocessor.py
new file mode 100644
index 0000000000..432519829f
--- /dev/null
+++ b/keras_hub/src/models/smollm3/smollm3_causal_lm_preprocessor.py
@@ -0,0 +1,84 @@
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.models.causal_lm_preprocessor import CausalLMPreprocessor
+from keras_hub.src.models.smollm3.smollm3_backbone import SmolLM3Backbone
+from keras_hub.src.models.smollm3.smollm3_tokenizer import SmolLM3Tokenizer
+
+
+@keras_hub_export(
+    [
+        "keras_hub.models.SmolLMCausalLMPreprocessor",
+        "keras_hub.models.SmolLM3CausalLMPreprocessor",
+    ]
+)
+class SmolLM3CausalLMPreprocessor(CausalLMPreprocessor):
+    """SmolLM3 Causal LM preprocessor.
+
+    This preprocessing layer is meant for use with
+    `keras_hub.models.SmolLM3CausalLM`. By default, it will take in batches of
+    strings, and return outputs in a `(x, y, sample_weight)` format, where the
+    `y` label is the next token id in the `x` sequence.
+
+    For use with generation, the layer also exposes two methods
+    `generate_preprocess()` and `generate_postprocess()`. When this preprocessor
+    is attached to a `keras_hub.models.SmolLM3CausalLM` instance, these methods
+    will be called implicitly in `generate()`. They can also be called
+    standalone (e.g. to precompute preprocessing inputs for generation in a
+    separate process).
+
+    Args:
+        tokenizer: A `keras_hub.models.SmolLM3Tokenizer` instance.
+        sequence_length: The length of the packed inputs.
+        add_start_token: If `True`, the preprocessor will prepend the tokenizer
+            start token to each input sequence. Default is `True`.
+        add_end_token: If `True`, the preprocessor will append the tokenizer
+            end token to each input sequence. Default is `False`.
+
+    Call arguments:
+        x: A string, `tf.Tensor` or list of python strings.
+        y: Label data. Should always be `None` as the layer generates labels.
+        sample_weight: Label weights. Should always be `None` as the layer
+            generates label weights.
+        sequence_length: Pass to override the configured `sequence_length` of
+            the layer.
+
+    Examples:
+    ```python
+    # Load the preprocessor from a preset.
+    preprocessor = keras_hub.models.SmolLM3CausalLMPreprocessor.from_preset(
+        "..."
+    )
+
+    # Tokenize and pack a single sentence.
+ sentence = tf.constant("...") + preprocessor(sentence) + # Same output. + preprocessor("...") + + # Tokenize a batch of sentences. + sentences = tf.constant(["...", "..."]) + preprocessor(sentences) + # Same output. + preprocessor(["...", "..."]) + + # Map a dataset to preprocess a single sentence. + features = tf.constant( + [ + "...", + "...", + ] + ) + labels = tf.constant([1, 0]) + ds = tf.data.Dataset.from_tensor_slices((features, labels)) + ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE) + + # Map a dataset to preprocess unlabled sentences. + ds = tf.data.Dataset.from_tensor_slices(features) + ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE) + ``` + """ + + backbone_cls = SmolLM3Backbone + tokenizer_cls = SmolLM3Tokenizer + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) diff --git a/keras_hub/src/models/smollm3/smollm3_layers.py b/keras_hub/src/models/smollm3/smollm3_layers.py new file mode 100644 index 0000000000..811aaa9dac --- /dev/null +++ b/keras_hub/src/models/smollm3/smollm3_layers.py @@ -0,0 +1,699 @@ +from keras import activations +from keras import initializers +from keras import layers +from keras import ops + +from keras_hub.src.layers.modeling.transformer_layer_utils import ( + compute_causal_mask, +) +from keras_hub.src.layers.modeling.transformer_layer_utils import ( + merge_padding_and_attention_mask, +) +from keras_hub.src.models.smollm3.smollm3_utils import rope_init +from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding +import math + + +class SmolLM3Attention(layers.Layer): + """ + Multi-head attention layer for SmolLM3 model. + + Args: + hidden_size: The hidden size of the attention layer. + num_attention_heads: The number of attention heads. + num_key_value_heads: The number of key-value heads. + attention_bias: Whether to use bias in attention projections. + attention_dropout: Dropout rate for attention weights. + rope_layer_enabled_list: List indicating if RoPE is enabled for each layer. + layer_types: List of layer types. + layer_idx: Index of the current layer. 
+ """ + + def __init__( + self, + hidden_size: int, + num_attention_heads: int, + num_key_value_heads: int, + attention_bias: bool, + attention_dropout: float, + rope_layer_enabled_list: list[bool], + layer_types: list[str], + layer_idx: int, + **kwargs, + ): + super().__init__(**kwargs) + + self.hidden_size = hidden_size + self.num_attention_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + self.rope_layer_enabled_list = rope_layer_enabled_list + self.layer_types = layer_types + self._dot_product_equation = "bquh,bkuh->buqk" + self._combine_equation = "buqk,bkuh->bquh" + + self.head_dim = hidden_size // self.num_attention_heads + self._inv_norm_factor = 1.0 / math.sqrt(self.head_dim) + + self.rotary_embedding = RotaryEmbedding( + max_wavelength=5000000.0, + ) + + self.layer_idx = layer_idx + + self.head_dim = self.hidden_size // self.num_attention_heads + self.num_key_value_groups = ( + self.num_attention_heads // self.num_key_value_heads + ) + self.scaling = self.head_dim**-0.5 + self.is_causal = True + + self.q_proj = layers.Dense( + self.num_attention_heads * self.head_dim, + use_bias=self.attention_bias, + name="q_proj", + ) + self.k_proj = layers.Dense( + self.num_key_value_heads * self.head_dim, + use_bias=self.attention_bias, + name="k_proj", + ) + self.v_proj = layers.Dense( + self.num_key_value_heads * self.head_dim, + use_bias=self.attention_bias, + name="v_proj", + ) + self.o_proj = layers.EinsumDense( + equation="bquh,uhm->bqm", + output_shape=(None, self.hidden_size), + name="attention_output", + ) + self.o_proj.build((None, None, self.num_attention_heads, self.head_dim)) + + self.use_rope = ( + self.rope_layer_enabled_list[self.layer_idx] + if self.layer_idx < len(self.rope_layer_enabled_list) + else True + ) # Default to True if index out of bounds + + self._softmax = layers.Softmax( + axis=-1, + dtype="float32", + name="attention_softmax", + ) + + def build(self, input_shape): + """ + Builds the internal Dense layers. + Args: + input_shape: A list/tuple of shapes for the inputs: + [hidden_states_shape, position_embeddings_shape_tuple, attention_mask_shape] + - hidden_states_shape: (batch_size, seq_len, hidden_size) + """ + # The input shape to the Dense layers (q_proj, k_proj, v_proj, o_proj) + # is the same as the hidden_states input to SmolLM3Attention. + hidden_states_shape = input_shape[0] + self.q_proj.build(hidden_states_shape) + self.k_proj.build(hidden_states_shape) + self.v_proj.build(hidden_states_shape) + super().build(input_shape) + + def call( + self, + hidden_states, + training=False, + attention_mask=None, + **kwargs, + ): + """ + Forward pass for SmolLM3Attention. + + Args: + hidden_states: Input tensor of shape (batch_size, seq_len, hidden_size). + position_embeddings: Tuple of (cos, sin) tensors for RoPE. + attention_mask: Attention mask tensor. + training: Whether the layer is in training mode. 
+ """ + self.training = training + self_attention_cache = kwargs.get("self_attention_cache", None) + self_attention_cache_update_index = kwargs.get( + "self_attention_cache_update_index", None + ) + start_index = ( + self_attention_cache_update_index if self_attention_cache_update_index is not None else 0 + ) + + input_shape = ops.shape(hidden_states)[:-1] + hidden_shape = (*input_shape, self.num_attention_heads, self.head_dim) + + query = ops.reshape(self.q_proj(hidden_states), hidden_shape) + + def _compute_kv_values(x_input): + kv_hidden_shape = ( + *input_shape, + self.num_key_value_heads, + self.head_dim, + ) + + key = ops.reshape(self.k_proj(x_input), kv_hidden_shape) + value = ops.reshape( + self.v_proj(x_input), kv_hidden_shape + ) + + return key, value + + if self_attention_cache is not None: + key_cache = self_attention_cache[:, 0, ...] + value_cache = self_attention_cache[:, 1, ...] + + if self_attention_cache_update_index is None: + key = key_cache + value = value_cache + else: + key_update, value_update = _compute_kv_values(hidden_states) + start = [0, self_attention_cache_update_index, 0, 0] + key = ops.slice_update(key_cache, start, key_update) + value = ops.slice_update( + value_cache, start, value_update + ) + self_attention_cache = ops.stack( + (key, value), axis=1 + ) + else: + if self_attention_cache_update_index is not None: + raise ValueError( + "`self_attention_cache_update_index` should not be set if `self_attention_cache` is " + f"`None`. Received: self_attention_cache={self_attention_cache}, " + f"self_attention_cache_update_index={self_attention_cache_update_index}" + ) + key, value = _compute_kv_values(hidden_states) + + if self.use_rope: + query = self.rotary_embedding(query, start_index=start_index) + key = self.rotary_embedding(key, start_index=start_index) + + print('pre', key.shape, value.shape) + key = ops.repeat(key, repeats=self.num_key_value_groups, axis=2) + value = ops.repeat(value, repeats=self.num_key_value_groups, axis=2) + print('post', key.shape, value.shape) + + attn_output = self._compute_attention( + query, + key, + value, + attention_mask, + cache_update_index=self_attention_cache_update_index, + ) + + attn_output = self.o_proj(attn_output) + + if self_attention_cache is not None: + return attn_output, self_attention_cache + + return attn_output + + def compute_output_shape(self, input_shape): + """ + Computes the output shape of the layer. + + Args: + input_shape: A list/tuple of shapes for the inputs: + [hidden_states_shape, position_embeddings_shape_tuple, attention_mask_shape] + - hidden_states_shape: (batch_size, seq_len, hidden_size) + - position_embeddings_shape_tuple: (cos_shape, sin_shape) where cos_shape/sin_shape is (batch_size, seq_len, head_dim) + - attention_mask_shape: (batch_size, 1, seq_len, seq_len) + + Returns: + A list of output shapes: [output_attn_output_shape, output_attn_weights_shape] + """ + hidden_states_shape = input_shape[0] + + batch_size = hidden_states_shape[0] + seq_len = hidden_states_shape[1] + + output_attn_output_shape = (batch_size, seq_len, self.hidden_size) + + output_attn_weights_shape = ( + batch_size, + self.num_attention_heads, + seq_len, + seq_len, + ) + + return [output_attn_output_shape, output_attn_weights_shape] + + + + def _masked_softmax(self, attention_scores, attention_mask=None): + """Applies softmax with optional masking. + + Args: + attention_scores: Attention score tensor. + attention_mask: Optional mask tensor. + + Returns: + Masked softmax attention weights. 
+ """ + if attention_mask is not None: + return self._softmax( + attention_scores, attention_mask[:, None, :, :] + ) + return self._softmax(attention_scores) + + def _compute_attention( + self, query, key, value, attention_mask=None, cache_update_index=None + ): + """Computes attention using query, key, and value tensors. + + Uses Flash Attention when available for better performance. + + Args: + query: Query tensor. + key: Key tensor. + value: Value tensor. + attention_mask: Optional mask tensor. + cache_update_index: Index for sliding window computation. + + Returns: + attention_output: Output tensor after applying attention. + """ + attention_scores = ops.einsum(self._dot_product_equation, query, key) + + attention_scores = ops.multiply( + attention_scores, + ops.cast(self._inv_norm_factor, self.compute_dtype), + ) + attention_scores = self._masked_softmax( + attention_scores, attention_mask + ) + attention_scores = ops.cast(attention_scores, self.compute_dtype) + attention_output = ops.einsum( + self._combine_equation, attention_scores, value + ) + + return attention_output + + +class SmolLM3MLP(layers.Layer): + """ + Multi-layer perceptron (MLP) block for SmolLM3 model. + + Args: + hidden_size: The hidden size of the MLP. + intermediate_size: The intermediate size of the MLP. + mlp_bias: Whether to use bias in MLP dense layers. + """ + + def __init__( + self, hidden_size: int, intermediate_size: int, mlp_bias: bool, **kwargs + ): + super().__init__(**kwargs) + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.mlp_bias = mlp_bias + + self.gate_proj = layers.Dense( + self.intermediate_size, use_bias=self.mlp_bias, name="gate_proj" + ) + self.up_proj = layers.Dense( + self.intermediate_size, use_bias=self.mlp_bias, name="up_proj" + ) + self.down_proj = layers.Dense( + self.hidden_size, use_bias=self.mlp_bias, name="down_proj" + ) + + def build(self, input_shape): + """ + Builds the internal Dense layers. + Args: + input_shape: The shape of the input to this layer + (batch_size, seq_len, hidden_size). + """ + self.gate_proj.build(input_shape) + self.up_proj.build(input_shape) + # The down_proj takes intermediate_output, which has shape + # (batch_size, seq_len, intermediate_size) + down_proj_input_shape = ( + input_shape[0], + input_shape[1], + self.intermediate_size, + ) + self.down_proj.build(down_proj_input_shape) + super().build(input_shape) + + def call(self, x): + """ + Forward pass for SmolLM3MLP. + + Args: + x: Input tensor of shape (batch_size, seq_len, hidden_size). + """ + gate_output = activations.silu(self.gate_proj(x)) + up_output = self.up_proj(x) + intermediate_output = gate_output * up_output + down_proj_output = self.down_proj(intermediate_output) + return down_proj_output + + def compute_output_shape(self, input_shape): + """ + Computes the output shape of the layer. + + Args: + input_shape: The input shape (batch_size, seq_len, hidden_size). + + Returns: + The output shape, which is the same as the input shape: + (batch_size, seq_len, hidden_size). + """ + return input_shape + + +class SmolLM3DecoderLayer(layers.Layer): + """ + Decoder layer for SmolLM3 model, combining self-attention and MLP. + + Args: + hidden_size: The hidden size of the layer. + num_attention_heads: The number of attention heads. + num_key_value_heads: The number of key-value heads. + attention_bias: Whether to use bias in attention projections. + attention_dropout: Dropout rate for attention weights. 
+ rope_layer_enabled_list: List indicating if RoPE is enabled for each layer. + layer_types: List of layer types. + layer_idx: Index of the current layer. + intermediate_size: The intermediate size of the MLP. + mlp_bias: Whether to use bias in MLP dense layers. + layer_norm_epsilon: Epsilon for RMSNormalization. + """ + + def __init__( + self, + hidden_size: int, + num_attention_heads: int, + num_key_value_heads: int, + attention_bias: bool, + attention_dropout: float, + rope_layer_enabled_list: list[bool], + layer_types: list[str], + layer_idx: int, + intermediate_size: int, + mlp_bias: bool, + layer_norm_epsilon: float, + **kwargs, + ): + super().__init__(**kwargs) + self.hidden_size = hidden_size + self.layer_idx = layer_idx + + self.self_attn = SmolLM3Attention( + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + num_key_value_heads=num_key_value_heads, + attention_bias=attention_bias, + attention_dropout=attention_dropout, + rope_layer_enabled_list=rope_layer_enabled_list, + layer_types=layer_types, + layer_idx=layer_idx, + name="self_attn", + ) + + self.mlp = SmolLM3MLP( + hidden_size=hidden_size, + intermediate_size=intermediate_size, + mlp_bias=mlp_bias, + name="mlp", + ) + + self.input_layernorm = layers.RMSNormalization( + epsilon=layer_norm_epsilon, axis=-1, name="input_layernorm" + ) + self.post_attention_layernorm = layers.RMSNormalization( + epsilon=layer_norm_epsilon, axis=-1, name="post_attention_layernorm" + ) + + self.attention_type = layer_types[layer_idx] + + def _compute_self_attention_mask( + self, + decoder_sequence, + decoder_padding_mask, + decoder_attention_mask, + self_attention_cache, + self_attention_cache_update_index, + ): + decoder_mask = merge_padding_and_attention_mask( + decoder_sequence, decoder_padding_mask, decoder_attention_mask + ) + batch_size = ops.shape(decoder_sequence)[0] + input_length = output_length = ops.shape(decoder_sequence)[1] + # We need to handle a rectangular causal mask when doing cached + # decoding. For generative inference, `decoder_sequence` will + # generally be length 1, and `cache` will be the full generation length. + if self_attention_cache is not None: + input_length = ops.shape(self_attention_cache)[2] + + cache_update_index = ( + 0 + if self_attention_cache_update_index is None + else self_attention_cache_update_index + ) + + causal_mask = compute_causal_mask( + batch_size, input_length, output_length, cache_update_index + ) + + return ( + ops.minimum(decoder_mask, causal_mask) + if decoder_mask is not None + else causal_mask + ) + + def build(self, input_shape): + """ + Builds the sub-layers based on the input shape. + + Args: + input_shape: The input shape to the decoder layer + (batch_size, seq_len, hidden_size). 
+ """ + # input_shape for SmolLM3DecoderLayer: (batch_size, seq_len, hidden_size) + batch_size = input_shape[0] + seq_len = input_shape[1] + + head_dim = self.self_attn.head_dim + pos_emb_shape = (batch_size, seq_len, head_dim) + + attn_mask_shape = (batch_size, 1, seq_len, seq_len) + + # Pass the correct input shape to self_attn's build method + # The input_shape for self_attn.build is a list: + # [hidden_states_shape, (pos_emb_shape, pos_emb_shape), attn_mask_shape] + self.self_attn.build( + [input_shape, (pos_emb_shape, pos_emb_shape), attn_mask_shape] + ) + + self.mlp.build(input_shape) + self.input_layernorm.build(input_shape) + self.post_attention_layernorm.build(input_shape) + + super().build(input_shape) + + def call( + self, + hidden_states, + training=False, + decoder_padding_mask=None, + decoder_attention_mask=None, + **kwargs, + ): + """ + Forward pass for SmolLM3DecoderLayer. + + Args: + hidden_states: Input tensor of shape (batch_size, seq_len, hidden_size). + position_embeddings: Optional tuple of (cos, sin) tensors for RoPE. + training: Whether the layer is in training mode. + """ + self_attention_cache = kwargs.get("self_attention_cache", None) + self_attention_cache_update_index = kwargs.get( + "self_attention_cache_update_index", None + ) + + self_attention_mask = self._compute_self_attention_mask( + decoder_sequence=hidden_states, + decoder_padding_mask=decoder_padding_mask, + decoder_attention_mask=decoder_attention_mask, + self_attention_cache=self_attention_cache, + self_attention_cache_update_index=self_attention_cache_update_index, + ) + + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + x = self.self_attn( + hidden_states=hidden_states, + training=training, + attention_mask=self_attention_mask, + **kwargs, + ) + + if isinstance(x, tuple): + attn_output, self_attention_cache = x + else: + attn_output = x + + hidden_states = ops.add(residual, attn_output) + + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = ops.add(residual, hidden_states) + + if self_attention_cache is not None: + return hidden_states, self_attention_cache + else: + return hidden_states + + def compute_output_shape(self, input_shape): + """ + Computes the output shape of the layer. + + Args: + input_shape: The input shape (batch_size, seq_len, hidden_size). + + Returns: + The output shape, which is the same as the input shape: + (batch_size, seq_len, hidden_size). + """ + return input_shape + + +class SmolLM3RotaryEmbedding(layers.Layer): + """ + Rotary Position Embedding (RoPE) layer for SmolLM3 model. + + Args: + hidden_size: The hidden size of the model. + num_attention_heads: The number of attention heads. + max_position_embeddings: The maximum sequence length for position embeddings. + rope_theta: The theta value for RoPE. + partial_rotary_factor: The factor for partial rotary embedding. 
+ """ + + def __init__( + self, + hidden_size: int, + num_attention_heads: int, + max_position_embeddings: int, + rope_theta: float, + partial_rotary_factor: float, + **kwargs, + ): + super().__init__(**kwargs) + self.hidden_size = hidden_size + self.num_attention_heads = num_attention_heads + self.max_position_embeddings = max_position_embeddings + self.rope_theta = rope_theta + self.partial_rotary_factor = partial_rotary_factor + + self.head_dim = self.hidden_size // self.num_attention_heads + + inv_freq_tensor, self.attention_scaling = rope_init( + self.rope_theta, self.partial_rotary_factor, self.head_dim + ) + + self.inv_freq = self.add_weight( + name="inv_freq", + shape=ops.shape(inv_freq_tensor), + dtype=inv_freq_tensor.dtype, + initializer=initializers.Constant( + ops.convert_to_numpy(inv_freq_tensor) + ), + trainable=False, # This weight is not trained + ) + self.original_inv_freq = self.inv_freq + + def build(self, input_shape): + """ + Builds the layer. For SmolLM3RotaryEmbedding, this mainly ensures + that the parent layer's build is called. + Args: + input_shape: A list/tuple of shapes for the inputs: + [x_shape, position_ids_shape] + - x_shape: (batch_size, ..., head_dim) + - position_ids_shape: (batch_size, seq_len) + """ + # No internal layers to explicitly build here, as inv_freq is added in __init__ + super().build(input_shape) + + def call( + self, + x, + start_index=0, + ): + """ + Forward pass for SmolLM3RotaryEmbedding. + + Args: + x: Input tensor, typically query or key states. + Shape can vary, but the last dimension is head_dim. + position_ids: Tensor of position IDs of shape (batch_size, seq_len). + """ + inv_freq_expanded = ops.expand_dims( + ops.expand_dims(self.inv_freq, axis=0), axis=-1 + ) + + batch_size = ops.shape(x)[0] + seq_len = ops.shape(x)[1] + positions = ops.arange(seq_len, dtype="float32") + positions = positions + ops.cast(start_index, dtype="float32") + + inv_freq_expanded = ops.broadcast_to( + inv_freq_expanded, (batch_size, ops.shape(self.inv_freq)[0], 1) + ) + + position_ids_expanded = ops.expand_dims(positions, axis=1).T + + freqs = ops.matmul( + ops.cast(inv_freq_expanded, "float32"), + ops.cast(position_ids_expanded, "float32"), + ) + + freqs = ops.transpose(freqs, axes=(0, 2, 1)) + + emb = ops.concatenate((freqs, freqs), axis=-1) + + cos = ops.cos(emb) * self.attention_scaling + sin = ops.sin(emb) * self.attention_scaling + + return ops.cast(cos, x.dtype), ops.cast(sin, x.dtype) + + def compute_output_shape(self, input_shape): + """ + Computes the output shape of the layer. + + Args: + input_shape: A list/tuple of shapes for the inputs: + [x_shape, position_ids_shape] + - x_shape: (batch_size, ..., head_dim) + - position_ids_shape: (batch_size, seq_len) + + Returns: + A list of output shapes for (cos, sin): + [(batch_size, seq_len, head_dim), (batch_size, seq_len, head_dim)] + """ + if input_shape[1] is not None and len(input_shape[1]) >= 2: + batch_size = input_shape[1][0] + seq_len = input_shape[1][1] + else: + # Fallback if position_ids_shape is None or malformed. + # In this case, the batch_size and seq_len are unknown. 
+ batch_size = None + seq_len = None + + # The output cos and sin have shape (batch_size, seq_len, head_dim) + output_shape = (batch_size, seq_len, self.head_dim) + + return [output_shape, output_shape] diff --git a/keras_hub/src/models/smollm3/smollm3_tokenizer.py b/keras_hub/src/models/smollm3/smollm3_tokenizer.py new file mode 100644 index 0000000000..c1df7c5eb4 --- /dev/null +++ b/keras_hub/src/models/smollm3/smollm3_tokenizer.py @@ -0,0 +1,60 @@ +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.smollm3.smollm3_backbone import SmolLM3Backbone +from keras_hub.src.tokenizers.byte_pair_tokenizer import BytePairTokenizer + + +@keras_hub_export( + [ + "keras_hub.tokenizers.SmolLM3Tokenizer", + "keras_hub.tokenizers.SmolLMTokenizer", + "keras_hub.models.SmolLM3Tokenizer", + "keras_hub.models.SmolLMTokenizer", + ] +) +class SmolLM3Tokenizer(BytePairTokenizer): + """Tokenizer for SmolLM3 models. + + This tokenizer implements byte-pair encoding (BPE) for SmolLM3 models, + handling special tokens like BOS (beginning of sequence) and EOS (end of + sequence). + + Args: + vocabulary: Dictionary mapping tokens to token IDs, or path to + vocabulary file. + merges: List of BPE merges, or path to merges file. + bos_token: Beginning of sequence token. Defaults to None. + eos_token: End of sequence token. Defaults to "<|endoftext|>". + misc_special_tokens: Set of additional special tokens. Defaults to + empty set. + """ + + backbone_cls = SmolLM3Backbone + + def __init__( + self, + vocabulary=None, + merges=None, + **kwargs, + ): + # Add EOS token + eos_token = "<|end_of_text|>" + self._add_special_token(eos_token, "end_token") + + bos_token = "<|begin_of_text|>" + self._add_special_token(bos_token, "bos_token") + + start_think_token = "" + self._add_special_token(start_think_token, "start_think_token") + + end_think_token = "" + self._add_special_token(end_think_token, "end_think_token") + + self.start_token_id = None + self.start_token = None + self.pad_token_id = 0 + + super().__init__( + vocabulary=vocabulary, + merges=merges, + **kwargs, + ) diff --git a/keras_hub/src/models/smollm3/smollm3_utils.py b/keras_hub/src/models/smollm3/smollm3_utils.py new file mode 100644 index 0000000000..8fb057f363 --- /dev/null +++ b/keras_hub/src/models/smollm3/smollm3_utils.py @@ -0,0 +1,47 @@ +from keras import layers +from keras import ops +from keras import random + + +def rotate_half(x): + x1 = x[..., : ops.shape(x)[-1] // 2] + x2 = x[..., ops.shape(x)[-1] // 2 :] + return ops.concatenate((-x2, x1), axis=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, expansion_axis=1): + cos = ops.expand_dims(cos, expansion_axis) + sin = ops.expand_dims(sin, expansion_axis) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def apply_rotary_pos_single(tensor, cos, sin, expansion_axis=1): + cos = ops.expand_dims(cos, expansion_axis) + sin = ops.expand_dims(sin, expansion_axis) + tensor_embed = (tensor * cos) + (rotate_half(tensor) * sin) + return tensor_embed + + +def repeat_kv(hidden_states, n_rep): + batch, num_key_value_heads, slen, head_dim = ops.shape(hidden_states) + if n_rep == 1: + return hidden_states + hidden_states = ops.expand_dims(hidden_states, axis=2) + target_shape = (batch, num_key_value_heads, n_rep, slen, head_dim) + hidden_states = ops.broadcast_to(hidden_states, target_shape) + return ops.reshape( + hidden_states, [batch, num_key_value_heads * n_rep, slen, head_dim] + ) + + +def rope_init(rope_theta: 
float, partial_rotary_factor: float, head_dim: int): + base = rope_theta + dim = int(head_dim * partial_rotary_factor) + + inv_freq = 1.0 / ( + ops.power(base, ops.arange(0, dim, 2, dtype="float32") / dim) + ) + attention_scaling = 1.0 + return inv_freq, attention_scaling diff --git a/keras_hub/src/utils/transformers/convert_smollm3.py b/keras_hub/src/utils/transformers/convert_smollm3.py new file mode 100644 index 0000000000..23ab7c9210 --- /dev/null +++ b/keras_hub/src/utils/transformers/convert_smollm3.py @@ -0,0 +1,136 @@ +import numpy as np + +from keras_hub.src.models.smollm3.smollm3_backbone import SmolLM3Backbone +from keras_hub.src.utils.preset_utils import load_json + +backbone_cls = SmolLM3Backbone + + +def convert_backbone_config(transformers_config): + return { + "vocabulary_size": transformers_config["vocab_size"], + "hidden_dim": transformers_config["hidden_size"], + "num_layers": transformers_config["num_hidden_layers"], + "num_attention_heads": transformers_config["num_attention_heads"], + "num_key_value_heads": transformers_config["num_key_value_heads"], + "intermediate_dim": transformers_config["intermediate_size"], + "layer_norm_epsilon": transformers_config[ + "rms_norm_eps" + ], # Using rms_norm_eps as layer_norm_epsilon + "max_position_embeddings": transformers_config[ + "max_position_embeddings" + ], + "rope_theta": transformers_config["rope_theta"], + # partial_rotary_factor is not explicitly in config.json + # but is inherited from the default value in the `_compute_default_rope_parameters()` + # function + "partial_rotary_factor": 1.0, + "attention_bias": transformers_config["attention_bias"], + "attention_dropout": transformers_config["attention_dropout"], + "rope_layer_enabled_list": transformers_config["no_rope_layers"], + "layer_types": transformers_config["layer_types"], + "mlp_bias": transformers_config["mlp_bias"] + } + + +def convert_weights(backbone, loader, transformers_config): + loader.port_weight( + keras_variable=backbone.get_layer("token_embedding").embeddings, + hf_weight_key="model.embed_tokens.weight", + ) + + def transpose_and_reshape(x, shape): + return np.reshape(np.transpose(x), shape) + + for i in range(backbone.num_layers): + decoder_layer = backbone.get_layer(f"transformer_layer_{i}") + + # Input layernorm + loader.port_weight( + keras_variable=decoder_layer.input_layernorm.scale, + hf_weight_key=f"model.layers.{i}.input_layernorm.weight", + ) + + # Attention layers + ## Query + loader.port_weight( + keras_variable=decoder_layer.self_attn.q_proj.kernel, + hf_weight_key=f"model.layers.{i}.self_attn.q_proj.weight", + hook_fn=transpose_and_reshape, + ) + ## Key + loader.port_weight( + keras_variable=decoder_layer.self_attn.k_proj.kernel, + hf_weight_key=f"model.layers.{i}.self_attn.k_proj.weight", + hook_fn=transpose_and_reshape, + ) + ## Value + loader.port_weight( + keras_variable=decoder_layer.self_attn.v_proj.kernel, + hf_weight_key=f"model.layers.{i}.self_attn.v_proj.weight", + hook_fn=transpose_and_reshape, + ) + ## Output + loader.port_weight( + keras_variable=decoder_layer.self_attn.o_proj.kernel, + hf_weight_key=f"model.layers.{i}.self_attn.o_proj.weight", + hook_fn=transpose_and_reshape, + ) + + # MLP layers + loader.port_weight( + keras_variable=decoder_layer.mlp.up_proj.kernel, + hf_weight_key=f"model.layers.{i}.mlp.up_proj.weight", + # rearrange_patterns="b a -> a b", + hook_fn=lambda hf_tensor, _: np.transpose(hf_tensor, axes=(1, 0)), + ) + loader.port_weight( + keras_variable=decoder_layer.mlp.down_proj.kernel, + 
hf_weight_key=f"model.layers.{i}.mlp.down_proj.weight", + # rearrange_patterns="b a -> a b", + hook_fn=lambda hf_tensor, _: np.transpose(hf_tensor, axes=(1, 0)), + ) + loader.port_weight( + keras_variable=decoder_layer.mlp.gate_proj.kernel, + hf_weight_key=f"model.layers.{i}.mlp.gate_proj.weight", + # rearrange_patterns="b a -> a b", + hook_fn=lambda hf_tensor, _: np.transpose(hf_tensor, axes=(1, 0)), + ) + + # Feedforward layernorm + loader.port_weight( + keras_variable=decoder_layer.post_attention_layernorm.scale, + hf_weight_key=f"model.layers.{i}.post_attention_layernorm.weight", + ) + + # Final normalization layer + loader.port_weight( + keras_variable=backbone.get_layer("sequence_output_layernorm").scale, + hf_weight_key="model.norm.weight", + ) + + backbone.training = False + + return backbone + + +def convert_tokenizer(cls, preset, **kwargs): + tokenizer_config = load_json(preset, "tokenizer.json") + vocab = tokenizer_config["model"]["vocab"] + merges = tokenizer_config["model"]["merges"] + merges = [" ".join(item) for item in merges] + + # Load all special tokens with the exception of "reserved" ones. + special_tokens = set() + for token in tokenizer_config["added_tokens"]: + if not token["content"].startswith("<|reserved_special_token_"): + vocab[token["content"]] = token["id"] + special_tokens.add(token["content"]) + + kwargs.update( + { + "unsplittable_tokens": list(special_tokens), + } + ) + + return cls(vocabulary=vocab, merges=merges, **kwargs) diff --git a/keras_hub/src/utils/transformers/preset_loader.py b/keras_hub/src/utils/transformers/preset_loader.py index fe49a9b269..526922505d 100644 --- a/keras_hub/src/utils/transformers/preset_loader.py +++ b/keras_hub/src/utils/transformers/preset_loader.py @@ -17,6 +17,7 @@ from keras_hub.src.utils.transformers import convert_qwen from keras_hub.src.utils.transformers import convert_qwen3 from keras_hub.src.utils.transformers import convert_qwen_moe +from keras_hub.src.utils.transformers import convert_smollm3 from keras_hub.src.utils.transformers import convert_vit from keras_hub.src.utils.transformers.safetensor_utils import SafetensorLoader @@ -56,6 +57,8 @@ def __init__(self, preset, config): self.converter = convert_qwen_moe elif model_type == "qwen3": self.converter = convert_qwen3 + elif model_type == "smollm3": + self.converter = convert_smollm3 else: raise ValueError( "KerasHub has no converter for huggingface/transformers models "