diff --git a/keras_rs/src/layers/embedding/base_distributed_embedding.py b/keras_rs/src/layers/embedding/base_distributed_embedding.py
index 8d29cd7..c30993f 100644
--- a/keras_rs/src/layers/embedding/base_distributed_embedding.py
+++ b/keras_rs/src/layers/embedding/base_distributed_embedding.py
@@ -174,12 +174,9 @@ class DistributedEmbedding(keras.layers.Layer):
     supported on all backends and accelerators:
 
     - `keras.optimizers.Adagrad`
-    - `keras.optimizers.SGD`
-
-    The following are additionally available when using the TensorFlow backend:
-
     - `keras.optimizers.Adam`
     - `keras.optimizers.Ftrl`
+    - `keras.optimizers.SGD`
 
     Also, not all parameters of the optimizers are supported (e.g. the
     `nesterov` option of `SGD`). An error is raised when an unsupported
diff --git a/keras_rs/src/layers/embedding/jax/config_conversion.py b/keras_rs/src/layers/embedding/jax/config_conversion.py
index 25a3ef8..1d11908 100644
--- a/keras_rs/src/layers/embedding/jax/config_conversion.py
+++ b/keras_rs/src/layers/embedding/jax/config_conversion.py
@@ -229,18 +229,63 @@ def keras_to_jte_optimizer(
     # pylint: disable-next=protected-access
     learning_rate = keras_to_jte_learning_rate(optimizer._learning_rate)
 
-    # SGD or Adagrad
+    # Unsupported keras optimizer general options.
+    if optimizer.clipnorm is not None:
+        raise ValueError("Unsupported optimizer option `clipnorm`.")
+    if optimizer.global_clipnorm is not None:
+        raise ValueError("Unsupported optimizer option `global_clipnorm`.")
+    if optimizer.use_ema:
+        raise ValueError("Unsupported optimizer option `use_ema`.")
+    if optimizer.loss_scale_factor is not None:
+        raise ValueError("Unsupported optimizer option `loss_scale_factor`.")
+
+    # Supported optimizers.
     if isinstance(optimizer, keras.optimizers.SGD):
+        if getattr(optimizer, "nesterov", False):
+            raise ValueError("Unsupported optimizer option `nesterov`.")
+        if getattr(optimizer, "momentum", 0.0) != 0.0:
+            raise ValueError("Unsupported optimizer option `momentum`.")
         return embedding_spec.SGDOptimizerSpec(learning_rate=learning_rate)
     elif isinstance(optimizer, keras.optimizers.Adagrad):
+        if getattr(optimizer, "epsilon", 1e-7) != 1e-7:
+            raise ValueError("Unsupported optimizer option `epsilon`.")
         return embedding_spec.AdagradOptimizerSpec(
             learning_rate=learning_rate,
             initial_accumulator_value=optimizer.initial_accumulator_value,
         )
+    elif isinstance(optimizer, keras.optimizers.Adam):
+        if getattr(optimizer, "amsgrad", False):
+            raise ValueError("Unsupported optimizer option `amsgrad`.")
 
-    # Default to SGD for now, since other optimizers are still being created,
-    # and we don't want to fail.
-    return embedding_spec.SGDOptimizerSpec(learning_rate=learning_rate)
+        return embedding_spec.AdamOptimizerSpec(
+            learning_rate=learning_rate,
+            beta_1=optimizer.beta_1,
+            beta_2=optimizer.beta_2,
+            epsilon=optimizer.epsilon,
+        )
+    elif isinstance(optimizer, keras.optimizers.Ftrl):
+        if (
+            getattr(optimizer, "l2_shrinkage_regularization_strength", 0.0)
+            != 0.0
+        ):
+            raise ValueError(
+                "Unsupported optimizer option "
+                "`l2_shrinkage_regularization_strength`."
+            )
+
+        return embedding_spec.FTRLOptimizerSpec(
+            learning_rate=learning_rate,
+            learning_rate_power=optimizer.learning_rate_power,
+            l1_regularization_strength=optimizer.l1_regularization_strength,
+            l2_regularization_strength=optimizer.l2_regularization_strength,
+            beta=optimizer.beta,
+            initial_accumulator_value=optimizer.initial_accumulator_value,
+        )
+
+    raise ValueError(
+        f"Unsupported optimizer type {type(optimizer)}. Optimizer must be "
+        f"one of [Adagrad, Adam, Ftrl, SGD]."
+    )
 
 
 def jte_to_keras_optimizer(
@@ -262,8 +307,33 @@ def jte_to_keras_optimizer(
         learning_rate=learning_rate,
         initial_accumulator_value=optimizer.initial_accumulator_value,
     )
+    elif isinstance(optimizer, embedding_spec.AdamOptimizerSpec):
+        return keras.optimizers.Adam(
+            learning_rate=learning_rate,
+            beta_1=optimizer.beta_1,
+            beta_2=optimizer.beta_2,
+            epsilon=optimizer.epsilon,
+        )
+    elif isinstance(optimizer, embedding_spec.FTRLOptimizerSpec):
+        if getattr(optimizer, "initial_linear_value", 0.0) != 0.0:
+            raise ValueError(
+                "Unsupported optimizer option `initial_linear_value`."
+            )
+        if getattr(optimizer, "multiply_linear_by_learning_rate", False):
+            raise ValueError(
+                "Unsupported optimizer option "
+                "`multiply_linear_by_learning_rate`."
+            )
+        return keras.optimizers.Ftrl(
+            learning_rate=learning_rate,
+            learning_rate_power=optimizer.learning_rate_power,
+            initial_accumulator_value=optimizer.initial_accumulator_value,
+            l1_regularization_strength=optimizer.l1_regularization_strength,
+            l2_regularization_strength=optimizer.l2_regularization_strength,
+            beta=optimizer.beta,
+        )
 
-    raise ValueError(f"Unknown optimizer spec {optimizer}")
+    raise ValueError(f"Unknown optimizer spec {type(optimizer)}.")
 
 
 def _keras_to_jte_table_config(
diff --git a/keras_rs/src/layers/embedding/jax/config_conversion_test.py b/keras_rs/src/layers/embedding/jax/config_conversion_test.py
index 7b4c987..166e339 100644
--- a/keras_rs/src/layers/embedding/jax/config_conversion_test.py
+++ b/keras_rs/src/layers/embedding/jax/config_conversion_test.py
@@ -239,6 +239,13 @@ def test_initializer_conversion(
         ),
     ),
     ("Adagrad", lambda: keras.optimizers.Adagrad(learning_rate=0.02)),
+    ("Adam", lambda: keras.optimizers.Adam(learning_rate=0.03)),
+    (
+        "Ftrl",
+        lambda: keras.optimizers.Ftrl(
+            learning_rate=0.05,
+        ),
+    ),
     ("string", lambda: "adagrad"),
 )
 def test_optimizer_conversion(
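
Illustrative usage (not part of the diff): a minimal sketch of how the new Adam path is expected to behave, assuming the module layout shown above and a constant learning rate. The round trip through keras_to_jte_optimizer and jte_to_keras_optimizer and the rejection of `amsgrad` are the behaviors the change introduces; exact import paths may differ in a given checkout.

# Minimal sketch, assuming the module path from the diff above.
import keras
from keras_rs.src.layers.embedding.jax import config_conversion

# A supported optimizer converts to a JAX TPU embedding spec and back.
adam = keras.optimizers.Adam(learning_rate=0.03, beta_1=0.9, beta_2=0.999)
spec = config_conversion.keras_to_jte_optimizer(adam)      # AdamOptimizerSpec
restored = config_conversion.jte_to_keras_optimizer(spec)  # keras.optimizers.Adam

# Unsupported options now raise instead of silently falling back to SGD.
try:
    config_conversion.keras_to_jte_optimizer(keras.optimizers.Adam(amsgrad=True))
except ValueError as err:
    print(err)  # Unsupported optimizer option `amsgrad`.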