@@ -229,18 +229,63 @@ def keras_to_jte_optimizer(
     # pylint: disable-next=protected-access
     learning_rate = keras_to_jte_learning_rate(optimizer._learning_rate)

-    # SGD or Adagrad
+    # Unsupported keras optimizer general options.
+    if optimizer.clipnorm is not None:
+        raise ValueError("Unsupported optimizer option `clipnorm`.")
+    if optimizer.global_clipnorm is not None:
+        raise ValueError("Unsupported optimizer option `global_clipnorm`.")
+    if optimizer.use_ema:
+        raise ValueError("Unsupported optimizer option `use_ema`.")
+    if optimizer.loss_scale_factor is not None:
+        raise ValueError("Unsupported optimizer option `loss_scale_factor`.")
+
+    # Supported optimizers.
     if isinstance(optimizer, keras.optimizers.SGD):
+        if getattr(optimizer, "nesterov", False):
+            raise ValueError("Unsupported optimizer option `nesterov`.")
+        if getattr(optimizer, "momentum", 0.0) != 0.0:
+            raise ValueError("Unsupported optimizer option `momentum`.")
         return embedding_spec.SGDOptimizerSpec(learning_rate=learning_rate)
     elif isinstance(optimizer, keras.optimizers.Adagrad):
+        if getattr(optimizer, "epsilon", 1e-7) != 1e-7:
+            raise ValueError("Unsupported optimizer option `epsilon`.")
         return embedding_spec.AdagradOptimizerSpec(
             learning_rate=learning_rate,
             initial_accumulator_value=optimizer.initial_accumulator_value,
         )
+    elif isinstance(optimizer, keras.optimizers.Adam):
+        if getattr(optimizer, "amsgrad", False):
+            raise ValueError("Unsupported optimizer option `amsgrad`.")

-    # Default to SGD for now, since other optimizers are still being created,
-    # and we don't want to fail.
-    return embedding_spec.SGDOptimizerSpec(learning_rate=learning_rate)
+        return embedding_spec.AdamOptimizerSpec(
+            learning_rate=learning_rate,
+            beta_1=optimizer.beta_1,
+            beta_2=optimizer.beta_2,
+            epsilon=optimizer.epsilon,
+        )
+    elif isinstance(optimizer, keras.optimizers.Ftrl):
+        if (
+            getattr(optimizer, "l2_shrinkage_regularization_strength", 0.0)
+            != 0.0
+        ):
+            raise ValueError(
+                "Unsupported optimizer option "
+                "`l2_shrinkage_regularization_strength`."
+            )
+
+        return embedding_spec.FTRLOptimizerSpec(
+            learning_rate=learning_rate,
+            learning_rate_power=optimizer.learning_rate_power,
+            l1_regularization_strength=optimizer.l1_regularization_strength,
+            l2_regularization_strength=optimizer.l2_regularization_strength,
+            beta=optimizer.beta,
+            initial_accumulator_value=optimizer.initial_accumulator_value,
+        )
+
+    raise ValueError(
+        f"Unsupported optimizer type {type(optimizer)}. Optimizer must be "
+        f"one of [Adagrad, Adam, Ftrl, SGD]."
+    )


 def jte_to_keras_optimizer(
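As a quick illustration of the stricter keras-to-JTE conversion above, here is a minimal sketch; the import path is a placeholder, not the verified module path in this repo. Supported optimizers map to their `embedding_spec` counterparts, while optimizer options the SparseCore specs cannot express now raise a `ValueError` instead of silently falling back to SGD.

```python
import keras

# Placeholder import: substitute the actual module in this repository that
# defines keras_to_jte_optimizer.
from config_conversion import keras_to_jte_optimizer

# A plain Adam optimizer converts to an embedding_spec.AdamOptimizerSpec.
adam_spec = keras_to_jte_optimizer(keras.optimizers.Adam(learning_rate=1e-3))

# Options the embedding specs cannot express now fail fast.
try:
    keras_to_jte_optimizer(keras.optimizers.Adam(clipnorm=1.0))
except ValueError as err:
    print(err)  # Unsupported optimizer option `clipnorm`.

# Optimizer types outside [Adagrad, Adam, Ftrl, SGD] are rejected explicitly
# rather than being mapped to SGD as before.
try:
    keras_to_jte_optimizer(keras.optimizers.RMSprop())
except ValueError as err:
    print(err)  # Unsupported optimizer type ...
```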
@@ -262,8 +307,33 @@ def jte_to_keras_optimizer(
             learning_rate=learning_rate,
             initial_accumulator_value=optimizer.initial_accumulator_value,
         )
+    elif isinstance(optimizer, embedding_spec.AdamOptimizerSpec):
+        return keras.optimizers.Adam(
+            learning_rate=learning_rate,
+            beta_1=optimizer.beta_1,
+            beta_2=optimizer.beta_2,
+            epsilon=optimizer.epsilon,
+        )
+    elif isinstance(optimizer, embedding_spec.FTRLOptimizerSpec):
+        if getattr(optimizer, "initial_linear_value", 0.0) != 0.0:
+            raise ValueError(
+                "Unsupported optimizer option `initial_linear_value`."
+            )
+        if getattr(optimizer, "multiply_linear_by_learning_rate", False):
+            raise ValueError(
+                "Unsupported optimizer option "
+                "`multiply_linear_by_learning_rate`."
+            )
+        return keras.optimizers.Ftrl(
+            learning_rate=learning_rate,
+            learning_rate_power=optimizer.learning_rate_power,
+            initial_accumulator_value=optimizer.initial_accumulator_value,
+            l1_regularization_strength=optimizer.l1_regularization_strength,
+            l2_regularization_strength=optimizer.l2_regularization_strength,
+            beta=optimizer.beta,
+        )

-    raise ValueError(f"Unknown optimizer spec {optimizer}")
+    raise ValueError(f"Unknown optimizer spec {type(optimizer)}.")


 def _keras_to_jte_table_config(
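And a similarly hedged sketch of the reverse mapping added above. The import paths are assumptions: `embedding_spec` is taken to come from the JAX TPU embedding SparseCore library, and the converter module name is a placeholder. The spec's hyperparameters carry over onto the returned `keras.optimizers.Ftrl`.

```python
import keras

# Assumed import path for the SparseCore embedding specs; verify against the
# dependency this module actually uses.
from jax_tpu_embedding.sparsecore.lib.nn import embedding_spec

# Placeholder import: substitute the actual module in this repository that
# defines jte_to_keras_optimizer.
from config_conversion import jte_to_keras_optimizer

# Hypothetical spec values; the kwargs mirror the fields the converter reads.
ftrl_spec = embedding_spec.FTRLOptimizerSpec(
    learning_rate=0.05,
    learning_rate_power=-0.5,
    l1_regularization_strength=0.0,
    l2_regularization_strength=0.0,
    beta=0.0,
    initial_accumulator_value=0.1,
)

ftrl = jte_to_keras_optimizer(ftrl_spec)
assert isinstance(ftrl, keras.optimizers.Ftrl)
print(ftrl.learning_rate_power)  # -0.5
```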