diff --git a/classy_train.py b/classy_train.py index e693220057..7b036cbc80 100755 --- a/classy_train.py +++ b/classy_train.py @@ -49,6 +49,7 @@ from classy_vision.generic.util import load_checkpoint, load_json from classy_vision.hooks import ( CheckpointHook, + ExponentialMovingAverageModelHook, LossLrMeterLoggingHook, ModelComplexityHook, ProfilerHook, @@ -152,6 +153,8 @@ def configure_hooks(args, config): hooks.append(ProgressBarHook()) if args.visdom_server != "": hooks.append(VisdomHook(args.visdom_server, args.visdom_port)) + if args.ema_decay > 0: + hooks.append(ExponentialMovingAverageModelHook(args.ema_decay)) return hooks diff --git a/classy_vision/generic/opts.py b/classy_vision/generic/opts.py index e4d9cc5e0c..19d98b2461 100644 --- a/classy_vision/generic/opts.py +++ b/classy_vision/generic/opts.py @@ -121,6 +121,12 @@ def add_generic_args(parser): help="""Distributed backend: either 'none' (for non-distributed runs) or 'ddp' (for distributed runs). Default none.""", ) + parser.add_argument( + "--ema_decay", + default=0, + type=float, + help="""Decay rate of model Exponential Moving Averaging""", + ) return parser