From 128aa8d8e3e358a05dae4760b450d80201f35dae Mon Sep 17 00:00:00 2001
From: Mannat Singh
Date: Wed, 4 Nov 2020 09:58:47 -0800
Subject: [PATCH] Create function to see if loss has learnable parameters

Summary: Reusable function which will be used in the next diff

Differential Revision: D24729686

fbshipit-source-id: e3c25d78b70c792291ef418f016d5411bab270e7
---
 classy_vision/tasks/classification_task.py | 12 ++++++++----
 1 file changed, 8 insertions(+), 4 deletions(-)

diff --git a/classy_vision/tasks/classification_task.py b/classy_vision/tasks/classification_task.py
index 443ce61dcb..526ae8d58d 100644
--- a/classy_vision/tasks/classification_task.py
+++ b/classy_vision/tasks/classification_task.py
@@ -719,10 +719,7 @@ def init_distributed_data_parallel_model(self):
             broadcast_buffers=broadcast_buffers,
             find_unused_parameters=self.find_unused_parameters,
         )
-        if (
-            isinstance(self.base_loss, ClassyLoss)
-            and self.base_loss.has_learned_parameters()
-        ):
+        if self._loss_has_learnable_params():
             logging.info("Initializing distributed loss")
             self.distributed_loss = init_distributed_data_parallel_model(
                 self.base_loss,
@@ -1014,6 +1011,13 @@ def _broadcast_buffers(self):
         for buffer in buffers:
             broadcast(buffer, 0, group=self.distributed_model.process_group)
 
+    def _loss_has_learnable_params(self):
+        """Returns True if the loss has any learnable parameters"""
+        return (
+            isinstance(self.base_loss, ClassyLoss)
+            and self.base_loss.has_learned_parameters()
+        )
+
     # TODO: Functions below should be better abstracted into the dataloader
     # abstraction
     def get_batchsize_per_replica(self):
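
Note: a minimal sketch of when the new helper fires, assuming a toy ClassyLoss
subclass. The TemperatureScaledLoss class below is hypothetical and not part of
this diff; has_learned_parameters() already exists on ClassyLoss, and this
change only wraps the isinstance + has_learned_parameters() check in a helper.

    import torch
    import torch.nn as nn
    from classy_vision.losses import ClassyLoss

    class TemperatureScaledLoss(ClassyLoss):
        # Hypothetical toy loss that owns one learnable parameter,
        # so the task would wrap it in DistributedDataParallel.
        def __init__(self):
            super().__init__()
            self.temperature = nn.Parameter(torch.ones(1))

        def forward(self, output, target):
            return nn.functional.cross_entropy(output / self.temperature, target)

    loss = TemperatureScaledLoss()
    # The loss has a parameter with requires_grad=True, so the new
    # _loss_has_learnable_params() helper would return True for it.
    assert loss.has_learned_parameters()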