
Commit ed273b1
Add FTRL and Adam optimizers. (#123)
1 parent be49522

File tree: 3 files changed (+83 -9 lines)

keras_rs/src/layers/embedding/base_distributed_embedding.py

Lines changed: 1 addition & 4 deletions
@@ -174,12 +174,9 @@ class DistributedEmbedding(keras.layers.Layer):
     supported on all backends and accelerators:

     - `keras.optimizers.Adagrad`
-    - `keras.optimizers.SGD`
-
-    The following are additionally available when using the TensorFlow backend:
-
     - `keras.optimizers.Adam`
     - `keras.optimizers.Ftrl`
+    - `keras.optimizers.SGD`

     Also, not all parameters of the optimizers are supported (e.g. the
     `nesterov` option of `SGD`). An error is raised when an unsupported

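Note: the four optimizers the updated docstring lists are all part of the standard `keras.optimizers` API. A minimal sketch for reference (hyperparameter values are arbitrary, and options the layer rejects, such as `momentum`/`nesterov` on `SGD`, are left at their defaults):

import keras

# The four optimizers supported on all backends and accelerators after this
# change; the values below are illustrative only.
sgd = keras.optimizers.SGD(learning_rate=0.01)
adagrad = keras.optimizers.Adagrad(learning_rate=0.02, initial_accumulator_value=0.1)
adam = keras.optimizers.Adam(learning_rate=0.001, beta_1=0.9, beta_2=0.999, epsilon=1e-7)
ftrl = keras.optimizers.Ftrl(
    learning_rate=0.05,
    learning_rate_power=-0.5,
    l1_regularization_strength=0.0,
    l2_regularization_strength=0.0,
    beta=0.0,
)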
keras_rs/src/layers/embedding/jax/config_conversion.py

Lines changed: 75 additions & 5 deletions
@@ -229,18 +229,63 @@ def keras_to_jte_optimizer(
     # pylint: disable-next=protected-access
     learning_rate = keras_to_jte_learning_rate(optimizer._learning_rate)

-    # SGD or Adagrad
+    # Unsupported keras optimizer general options.
+    if optimizer.clipnorm is not None:
+        raise ValueError("Unsupported optimizer option `clipnorm`.")
+    if optimizer.global_clipnorm is not None:
+        raise ValueError("Unsupported optimizer option `global_clipnorm`.")
+    if optimizer.use_ema:
+        raise ValueError("Unsupported optimizer option `use_ema`.")
+    if optimizer.loss_scale_factor is not None:
+        raise ValueError("Unsupported optimizer option `loss_scale_factor`.")
+
+    # Supported optimizers.
     if isinstance(optimizer, keras.optimizers.SGD):
+        if getattr(optimizer, "nesterov", False):
+            raise ValueError("Unsupported optimizer option `nesterov`.")
+        if getattr(optimizer, "momentum", 0.0) != 0.0:
+            raise ValueError("Unsupported optimizer option `momentum`.")
         return embedding_spec.SGDOptimizerSpec(learning_rate=learning_rate)
     elif isinstance(optimizer, keras.optimizers.Adagrad):
+        if getattr(optimizer, "epsilon", 1e-7) != 1e-7:
+            raise ValueError("Unsupported optimizer option `epsilon`.")
         return embedding_spec.AdagradOptimizerSpec(
             learning_rate=learning_rate,
             initial_accumulator_value=optimizer.initial_accumulator_value,
         )
+    elif isinstance(optimizer, keras.optimizers.Adam):
+        if getattr(optimizer, "amsgrad", False):
+            raise ValueError("Unsupported optimizer option `amsgrad`.")

-    # Default to SGD for now, since other optimizers are still being created,
-    # and we don't want to fail.
-    return embedding_spec.SGDOptimizerSpec(learning_rate=learning_rate)
+        return embedding_spec.AdamOptimizerSpec(
+            learning_rate=learning_rate,
+            beta_1=optimizer.beta_1,
+            beta_2=optimizer.beta_2,
+            epsilon=optimizer.epsilon,
+        )
+    elif isinstance(optimizer, keras.optimizers.Ftrl):
+        if (
+            getattr(optimizer, "l2_shrinkage_regularization_strength", 0.0)
+            != 0.0
+        ):
+            raise ValueError(
+                "Unsupported optimizer option "
+                "`l2_shrinkage_regularization_strength`."
+            )
+
+        return embedding_spec.FTRLOptimizerSpec(
+            learning_rate=learning_rate,
+            learning_rate_power=optimizer.learning_rate_power,
+            l1_regularization_strength=optimizer.l1_regularization_strength,
+            l2_regularization_strength=optimizer.l2_regularization_strength,
+            beta=optimizer.beta,
+            initial_accumulator_value=optimizer.initial_accumulator_value,
+        )
+
+    raise ValueError(
+        f"Unsupported optimizer type {type(optimizer)}. Optimizer must be "
+        f"one of [Adagrad, Adam, Ftrl, SGD]."
+    )


 def jte_to_keras_optimizer(
@@ -262,8 +307,33 @@ def jte_to_keras_optimizer(
             learning_rate=learning_rate,
             initial_accumulator_value=optimizer.initial_accumulator_value,
         )
+    elif isinstance(optimizer, embedding_spec.AdamOptimizerSpec):
+        return keras.optimizers.Adam(
+            learning_rate=learning_rate,
+            beta_1=optimizer.beta_1,
+            beta_2=optimizer.beta_2,
+            epsilon=optimizer.epsilon,
+        )
+    elif isinstance(optimizer, embedding_spec.FTRLOptimizerSpec):
+        if getattr(optimizer, "initial_linear_value", 0.0) != 0.0:
+            raise ValueError(
+                "Unsupported optimizer option `initial_linear_value`."
+            )
+        if getattr(optimizer, "multiply_linear_by_learning_rate", False):
+            raise ValueError(
+                "Unsupported optimizer option "
+                "`multiply_linear_by_learning_rate`."
+            )
+        return keras.optimizers.Ftrl(
+            learning_rate=learning_rate,
+            learning_rate_power=optimizer.learning_rate_power,
+            initial_accumulator_value=optimizer.initial_accumulator_value,
+            l1_regularization_strength=optimizer.l1_regularization_strength,
+            l2_regularization_strength=optimizer.l2_regularization_strength,
+            beta=optimizer.beta,
+        )

-    raise ValueError(f"Unknown optimizer spec {optimizer}")
+    raise ValueError(f"Unknown optimizer spec {type(optimizer)}.")


 def _keras_to_jte_table_config(

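Note: a minimal sketch of how the two conversion functions might be exercised, assuming they are importable from `keras_rs.src.layers.embedding.jax.config_conversion` (import path inferred from the file path above) and take a single optimizer argument, as the hunks suggest:

import keras
from keras_rs.src.layers.embedding.jax import config_conversion

# Keras -> JAX TPU embedding spec (assumed single-argument call).
adam = keras.optimizers.Adam(learning_rate=0.001)
adam_spec = config_conversion.keras_to_jte_optimizer(adam)

# Unsupported options now raise instead of silently falling back to SGD.
try:
    config_conversion.keras_to_jte_optimizer(
        keras.optimizers.Adam(learning_rate=0.001, amsgrad=True)
    )
except ValueError as err:
    print(err)  # Unsupported optimizer option `amsgrad`.

# Spec -> Keras: rebuilds an equivalent optimizer from the same hyperparameters.
adam_back = config_conversion.jte_to_keras_optimizer(adam_spec)
assert isinstance(adam_back, keras.optimizers.Adam)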
keras_rs/src/layers/embedding/jax/config_conversion_test.py

Lines changed: 7 additions & 0 deletions
@@ -239,6 +239,13 @@ def test_initializer_conversion(
             ),
         ),
         ("Adagrad", lambda: keras.optimizers.Adagrad(learning_rate=0.02)),
+        ("Adam", lambda: keras.optimizers.Adam(learning_rate=0.03)),
+        (
+            "Ftrl",
+            lambda: keras.optimizers.Ftrl(
+                learning_rate=0.05,
+            ),
+        ),
         ("string", lambda: "adagrad"),
     )
     def test_optimizer_conversion(

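Note: the new "Ftrl" test case only passes a learning rate, so it leans on Keras' Ftrl defaults staying within the supported subset (in particular, `l2_shrinkage_regularization_strength` must remain 0.0 for the converter to accept it). A quick, hypothetical spot-check:

import keras

# Defaults as of current Keras; the converter rejects a non-zero
# `l2_shrinkage_regularization_strength`, so this is expected to print 0.0.
ftrl = keras.optimizers.Ftrl(learning_rate=0.05)
print(ftrl.l2_shrinkage_regularization_strength)
print(ftrl.learning_rate_power)  # expected default: -0.5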