From 906cfc836e3eb572ad532c03c1df804c52b5e992 Mon Sep 17 00:00:00 2001 From: Zhicheng Yan Date: Tue, 27 Oct 2020 22:37:09 -0700 Subject: [PATCH 1/4] change FLOPS compute in profiler Differential Revision: D21949491 fbshipit-source-id: fc85f28421169ae7617e781592773802e4dd5c62 --- classy_vision/generic/profiler.py | 232 +++++++++++++----------------- 1 file changed, 98 insertions(+), 134 deletions(-) diff --git a/classy_vision/generic/profiler.py b/classy_vision/generic/profiler.py index 020f67bcd5..da2d7e6a6c 100644 --- a/classy_vision/generic/profiler.py +++ b/classy_vision/generic/profiler.py @@ -86,7 +86,24 @@ def get_shape(x: Union[Tuple, List, Dict]) -> Union[Tuple, List, Dict]: return x.size() -def _layer_flops(layer: nn.Module, x: Any, y: Any) -> int: +def _get_batchsize_per_replica(x: Union[Tuple, List, Dict]) -> int: + """ + Some layer may take tuple/list/dict/list[dict] as input in forward function. We + recursively dive into the tuple/list until we meet a tensor and infer the batch size + """ + while isinstance(x, (list, tuple)): + assert len(x) > 0, "input x of tuple/list type must have at least one element" + x = x[0] + + if isinstance(x, (dict,)): + # index zero is always equal to batch size. select an arbitrary key. + key_list = list(x.keys()) + x = x[key_list[0]] + + return x.size()[0] + + +def _layer_flops(layer: nn.Module, x: Any, y: Any, verbose: bool = False) -> int: """ Computes the number of FLOPs required for a single layer. @@ -146,6 +163,36 @@ def flops(self, x): / layer.groups ) + # 3D convolution + elif layer_type in ["Conv3d"]: + out_t = int( + (x.size()[2] + 2 * layer.padding[0] - layer.kernel_size[0]) + // layer.stride[0] + + 1 + ) + out_h = int( + (x.size()[3] + 2 * layer.padding[1] - layer.kernel_size[1]) + // layer.stride[1] + + 1 + ) + out_w = int( + (x.size()[4] + 2 * layer.padding[2] - layer.kernel_size[2]) + // layer.stride[2] + + 1 + ) + flops = ( + batchsize_per_replica + * layer.in_channels + * layer.out_channels + * layer.kernel_size[0] + * layer.kernel_size[1] + * layer.kernel_size[2] + * out_t + * out_h + * out_w + / layer.groups + ) + # learned group convolution: elif layer_type in ["LearnedGroupConv"]: conv = layer.conv @@ -170,51 +217,36 @@ def flops(self, x): ) flops = count1 + count2 - # non-linearities: + # non-linearities are not considered in MAC counting elif layer_type in ["ReLU", "ReLU6", "Tanh", "Sigmoid", "Softmax"]: - flops = x.numel() - - # 2D pooling layers: - elif layer_type in ["AvgPool2d", "MaxPool2d"]: - in_h = x.size()[2] - in_w = x.size()[3] - if isinstance(layer.kernel_size, int): - layer.kernel_size = (layer.kernel_size, layer.kernel_size) - kernel_ops = layer.kernel_size[0] * layer.kernel_size[1] - out_h = 1 + int( - (in_h + 2 * layer.padding - layer.kernel_size[0]) / layer.stride - ) - out_w = 1 + int( - (in_w + 2 * layer.padding - layer.kernel_size[1]) / layer.stride + flops = 0 + + elif layer_type in [ + "MaxPool1d", + "MaxPool2d", + "MaxPool3d", + "AdaptiveMaxPool1d", + "AdaptiveMaxPool2d", + "AdaptiveMaxPool3d", + ]: + flops = 0 + + elif layer_type in ["AvgPool1d", "AvgPool2d", "AvgPool3d"]: + kernel_ops = 1 + flops = kernel_ops * y.numel() + + elif layer_type in ["AdaptiveAvgPool1d", "AdaptiveAvgPool2d", "AdaptiveAvgPool3d"]: + assert isinstance(layer.output_size, (list, tuple)) + kernel = torch.Tensor(list(x.shape[2:])) // torch.Tensor( + [list(layer.output_size)] ) - flops = x.size()[0] * x.size()[1] * out_w * out_h * kernel_ops - - # adaptive avg pool2d - # This is approximate and works only for downsampling 
without padding - # based on aten/src/ATen/native/AdaptiveAveragePooling.cpp - elif layer_type in ["AdaptiveAvgPool2d"]: - in_h = x.size()[2] - in_w = x.size()[3] - if isinstance(layer.output_size, int): - out_h, out_w = layer.output_size, layer.output_size - elif len(layer.output_size) == 1: - out_h, out_w = layer.output_size[0], layer.output_size[0] - else: - out_h, out_w = layer.output_size - if out_h > in_h or out_w > in_w: - raise ClassyProfilerNotImplementedError(layer) - batchsize_per_replica = x.size()[0] - num_channels = x.size()[1] - kh = in_h - out_h + 1 - kw = in_w - out_w + 1 - kernel_ops = kh * kw - flops = batchsize_per_replica * num_channels * out_h * out_w * kernel_ops + kernel_ops = torch.prod(kernel) + flops = kernel_ops * y.numel() # linear layer: elif layer_type in ["Linear"]: weight_ops = layer.weight.numel() - bias_ops = layer.bias.numel() if layer.bias is not None else 0 - flops = x.size()[0] * (weight_ops + bias_ops) + flops = x.size()[0] * weight_ops # batch normalization / layer normalization: elif layer_type in [ @@ -224,94 +256,12 @@ def flops(self, x): "SyncBatchNorm", "LayerNorm", ]: - flops = 2 * x.numel() - - # 3D convolution - elif layer_type in ["Conv3d"]: - out_t = int( - (x.size()[2] + 2 * layer.padding[0] - layer.kernel_size[0]) - // layer.stride[0] - + 1 - ) - out_h = int( - (x.size()[3] + 2 * layer.padding[1] - layer.kernel_size[1]) - // layer.stride[1] - + 1 - ) - out_w = int( - (x.size()[4] + 2 * layer.padding[2] - layer.kernel_size[2]) - // layer.stride[2] - + 1 - ) - flops = ( - batchsize_per_replica - * layer.in_channels - * layer.out_channels - * layer.kernel_size[0] - * layer.kernel_size[1] - * layer.kernel_size[2] - * out_t - * out_h - * out_w - / layer.groups - ) - - # 3D pooling layers - elif layer_type in ["AvgPool3d", "MaxPool3d"]: - in_t = x.size()[2] - in_h = x.size()[3] - in_w = x.size()[4] - if isinstance(layer.kernel_size, int): - layer.kernel_size = ( - layer.kernel_size, - layer.kernel_size, - layer.kernel_size, - ) - if isinstance(layer.padding, int): - layer.padding = (layer.padding, layer.padding, layer.padding) - if isinstance(layer.stride, int): - layer.stride = (layer.stride, layer.stride, layer.stride) - kernel_ops = layer.kernel_size[0] * layer.kernel_size[1] * layer.kernel_size[2] - out_t = 1 + int( - (in_t + 2 * layer.padding[0] - layer.kernel_size[0]) / layer.stride[0] - ) - out_h = 1 + int( - (in_h + 2 * layer.padding[1] - layer.kernel_size[1]) / layer.stride[1] - ) - out_w = 1 + int( - (in_w + 2 * layer.padding[2] - layer.kernel_size[2]) / layer.stride[2] - ) - flops = batchsize_per_replica * x.size()[1] * out_t * out_h * out_w * kernel_ops - - # adaptive avg pool3d - # This is approximate and works only for downsampling without padding - # based on aten/src/ATen/native/AdaptiveAveragePooling3d.cpp - elif layer_type in ["AdaptiveAvgPool3d"]: - in_t = x.size()[2] - in_h = x.size()[3] - in_w = x.size()[4] - out_t = layer.output_size[0] - out_h = layer.output_size[1] - out_w = layer.output_size[2] - if out_t > in_t or out_h > in_h or out_w > in_w: - raise ClassyProfilerNotImplementedError(layer) - batchsize_per_replica = x.size()[0] - num_channels = x.size()[1] - kt = in_t - out_t + 1 - kh = in_h - out_h + 1 - kw = in_w - out_w + 1 - kernel_ops = kt * kh * kw - flops = ( - batchsize_per_replica * num_channels * out_t * out_w * out_h * kernel_ops - ) + # batchnorm can be merged into conv op. 
Thus, count 0 FLOPS + flops = 0 # dropout layer elif layer_type in ["Dropout"]: - # At test time, we do not drop values but scale the feature map by the - # dropout ratio - flops = 1 - for dim_size in x.size(): - flops *= dim_size + flops = 0 elif layer_type == "Identity": flops = 0 @@ -335,11 +285,14 @@ def flops(self, x): f"params(M): {count_params(layer) / 1e6}", f"flops(M): {int(flops) / 1e6}", ] - logging.debug("\t".join(message)) + if verbose: + logging.info("\t".join(message)) return flops -def _layer_activations(layer: nn.Module, x: Any, out: Any) -> int: +def _layer_activations( + layer: nn.Module, x: Any, out: Any, verbose: bool = False +) -> int: """ Computes the number of activations produced by a single layer. @@ -360,7 +313,8 @@ def activations(self, x, out): return 0 message = [f"module: {typestr}", f"activations: {activations}"] - logging.debug("\t".join(message)) + if verbose: + logging.info("\t".join(message)) return activations @@ -386,17 +340,19 @@ def summarize_profiler_info(prof: torch.autograd.profiler.profile) -> str: class ComplexityComputer: - def __init__(self, compute_fn: Callable, count_unique: bool): + def __init__(self, compute_fn: Callable, count_unique: bool, verbose: bool = False): self.compute_fn = compute_fn self.count_unique = count_unique self.count = 0 + self.verbose = verbose self.seen_modules = set() def compute(self, layer: nn.Module, x: Any, out: Any, module_name: str): if self.count_unique and module_name in self.seen_modules: return - logging.debug(f"module name: {module_name}") - self.count += self.compute_fn(layer, x, out) + if self.verbose: + logging.info(f"module name: {module_name}") + self.count += self.compute_fn(layer, x, out, self.verbose) self.seen_modules.add(module_name) def reset(self): @@ -482,6 +438,7 @@ def compute_complexity( input_key: Optional[Union[str, List[str]]] = None, patch_attr: str = None, compute_unique: bool = False, + verbose: bool = False, ) -> int: """ Compute the complexity of a forward pass. @@ -501,7 +458,7 @@ def compute_complexity( else: input = get_model_dummy_input(model, input_shape, input_key) - complexity_computer = ComplexityComputer(compute_fn, compute_unique) + complexity_computer = ComplexityComputer(compute_fn, compute_unique, verbose) # measure FLOPs: modify_forward(model, complexity_computer, patch_attr=patch_attr) @@ -519,12 +476,13 @@ def compute_flops( model: nn.Module, input_shape: Tuple[int] = (3, 224, 224), input_key: Optional[Union[str, List[str]]] = None, + verbose: bool = False, ) -> int: """ Compute the number of FLOPs needed for a forward pass. """ return compute_complexity( - model, _layer_flops, input_shape, input_key, patch_attr="flops" + model, _layer_flops, input_shape, input_key, patch_attr="flops", verbose=verbose ) @@ -532,12 +490,18 @@ def compute_activations( model: nn.Module, input_shape: Tuple[int] = (3, 224, 224), input_key: Optional[Union[str, List[str]]] = None, + verbose: bool = False, ) -> int: """ Compute the number of activations created in a forward pass. 
""" return compute_complexity( - model, _layer_activations, input_shape, input_key, patch_attr="activations" + model, + _layer_activations, + input_shape, + input_key, + patch_attr="activations", + verbose=verbose, ) From a745189f4b1d82d159f6f8edb742e15b72862a1f Mon Sep 17 00:00:00 2001 From: Zhicheng Yan Date: Tue, 27 Oct 2020 22:37:09 -0700 Subject: [PATCH 2/4] reduce memory use of PreciseBatchNormHook Differential Revision: D23191690 fbshipit-source-id: e3ff067ed472f1b9e36b598a921a2d4930cddd5c --- .../hooks/precise_batch_norm_hook.py | 48 +++++++++---------- 1 file changed, 22 insertions(+), 26 deletions(-) diff --git a/classy_vision/hooks/precise_batch_norm_hook.py b/classy_vision/hooks/precise_batch_norm_hook.py index de08fde558..1304d88e42 100644 --- a/classy_vision/hooks/precise_batch_norm_hook.py +++ b/classy_vision/hooks/precise_batch_norm_hook.py @@ -4,21 +4,18 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -import torch -from classy_vision.generic.util import ( - get_batchsize_per_replica, - recursive_copy_to_device, - recursive_copy_to_gpu, -) +import logging + +from classy_vision.generic.util import get_batchsize_per_replica, recursive_copy_to_gpu from classy_vision.hooks import ClassyHook, register_hook from fvcore.nn.precise_bn import update_bn_stats -def _get_iterator(cache, use_gpu): - for elem in cache: +def _get_iterator(data_iter, use_gpu): + for elem in data_iter: if use_gpu: elem = recursive_copy_to_gpu(elem, non_blocking=True) - yield elem + yield elem["input"] @register_hook("precise_bn") @@ -32,9 +29,7 @@ class PreciseBatchNormHook(ClassyHook): fvcore/nn/precise_bn.py>`_ for more information. """ - on_start = ClassyHook._noop on_phase_start = ClassyHook._noop - on_step = ClassyHook._noop on_end = ClassyHook._noop def __init__(self, num_samples: int) -> None: @@ -49,31 +44,32 @@ def __init__(self, num_samples: int) -> None: if num_samples <= 0: raise ValueError("num_samples has to be a positive integer") self.num_samples = num_samples - self.cache = [] - self.current_samples = 0 + self.batch_size = None @classmethod def from_config(cls, config): return cls(config["num_samples"]) - def on_phase_start(self, task) -> None: - self.cache = [] - self.current_samples = 0 + def on_start(self, task) -> None: + logging.info("Use precise BatchNorm hook") def on_step(self, task) -> None: - if not task.train or self.current_samples >= self.num_samples: + if not task.train or self.batch_size is not None: return - input = recursive_copy_to_device( - task.last_batch.sample["input"], - non_blocking=True, - device=torch.device("cpu"), - ) - self.cache.append(input) - self.current_samples += get_batchsize_per_replica(input) + + self.batch_size = get_batchsize_per_replica(task.last_batch.sample["input"]) def on_phase_end(self, task) -> None: if not task.train: return - iterator = _get_iterator(self.cache, task.use_gpu) - num_batches = len(self.cache) + + num_batches = (self.num_samples + self.batch_size - 1) // self.batch_size + + task.build_dataloaders_for_current_phase() + task.create_data_iterators() + if num_batches > len(task.data_iterator): + num_batches = len(task.data_iterator) + logging.info(f"Reduce no. 
of samples to {num_batches * self.batch_size}") + + iterator = _get_iterator(task.data_iterator, task.use_gpu) update_bn_stats(task.base_model, iterator, num_batches) From 317ef61b968ffc05b136a3722c53e0b1d0e53439 Mon Sep 17 00:00:00 2001 From: Zhicheng Yan Date: Tue, 27 Oct 2020 22:37:09 -0700 Subject: [PATCH 3/4] misc changes Differential Revision: D23261222 fbshipit-source-id: 46643b527cceb66bcc4f88af905615b61cafe036 --- classy_train.py | 2 +- .../dataset/dataloader_limit_wrapper.py | 2 +- classy_vision/generic/profiler.py | 11 +++---- .../hooks/loss_lr_meter_logging_hook.py | 8 +++++ classy_vision/hooks/model_complexity_hook.py | 10 +++---- classy_vision/tasks/fine_tuning_task.py | 30 +++++++++++-------- 6 files changed, 38 insertions(+), 25 deletions(-) diff --git a/classy_train.py b/classy_train.py index e693220057..7b942496d0 100755 --- a/classy_train.py +++ b/classy_train.py @@ -117,7 +117,7 @@ def main(args, config): def configure_hooks(args, config): - hooks = [LossLrMeterLoggingHook(args.log_freq), ModelComplexityHook()] + hooks = [LossLrMeterLoggingHook(args.log_freq), ModelComplexityHook(verbose=True)] # Make a folder to store checkpoints and tensorboard logging outputs suffix = datetime.now().isoformat() diff --git a/classy_vision/dataset/dataloader_limit_wrapper.py b/classy_vision/dataset/dataloader_limit_wrapper.py index 6e6e347636..787ebe5339 100644 --- a/classy_vision/dataset/dataloader_limit_wrapper.py +++ b/classy_vision/dataset/dataloader_limit_wrapper.py @@ -58,7 +58,7 @@ def __next__(self) -> Any: if self.wrap_around: # create a new iterator to load data from the beginning logging.info( - f"Wrapping around after {self._count} calls. Limit: {self.limit}" + f"Wrapping around after {self._count - 1} calls. Limit: {self.limit}" ) try: self._iter = iter(self.dataloader) diff --git a/classy_vision/generic/profiler.py b/classy_vision/generic/profiler.py index da2d7e6a6c..712978f3bb 100644 --- a/classy_vision/generic/profiler.py +++ b/classy_vision/generic/profiler.py @@ -246,7 +246,8 @@ def flops(self, x): # linear layer: elif layer_type in ["Linear"]: weight_ops = layer.weight.numel() - flops = x.size()[0] * weight_ops + bias_ops = layer.bias.numel() if layer.bias is not None else 0 + flops = x.size()[0] * (weight_ops + bias_ops) # batch normalization / layer normalization: elif layer_type in [ @@ -287,7 +288,7 @@ def flops(self, x): ] if verbose: logging.info("\t".join(message)) - return flops + return int(flops) def _layer_activations( @@ -315,7 +316,7 @@ def activations(self, x, out): message = [f"module: {typestr}", f"activations: {activations}"] if verbose: logging.info("\t".join(message)) - return activations + return int(activations) def summarize_profiler_info(prof: torch.autograd.profiler.profile) -> str: @@ -350,9 +351,9 @@ def __init__(self, compute_fn: Callable, count_unique: bool, verbose: bool = Fal def compute(self, layer: nn.Module, x: Any, out: Any, module_name: str): if self.count_unique and module_name in self.seen_modules: return - if self.verbose: - logging.info(f"module name: {module_name}") self.count += self.compute_fn(layer, x, out, self.verbose) + if self.verbose: + logging.info(f"module name: {module_name}, count {self.count}") self.seen_modules.add(module_name) def reset(self): diff --git a/classy_vision/hooks/loss_lr_meter_logging_hook.py b/classy_vision/hooks/loss_lr_meter_logging_hook.py index 99078efb27..155bb30ebb 100644 --- a/classy_vision/hooks/loss_lr_meter_logging_hook.py +++ b/classy_vision/hooks/loss_lr_meter_logging_hook.py @@ -7,6 
+7,7 @@ import logging from typing import Optional +import torch from classy_vision.generic.distributed_util import get_rank from classy_vision.hooks import register_hook from classy_vision.hooks.classy_hook import ClassyHook @@ -49,6 +50,13 @@ def on_phase_end(self, task) -> None: # for meters to not provide a sync function. self._log_loss_lr_meters(task, prefix="Synced meters: ", log_batches=True) + logging.info( + f"max memory allocated(MB) {torch.cuda.max_memory_allocated() // 1e6}" + ) + logging.info( + f"max memory reserved(MB) {torch.cuda.max_memory_reserved() // 1e6}" + ) + def on_step(self, task) -> None: """ Log the LR every log_freq batches, if log_freq is not None. diff --git a/classy_vision/hooks/model_complexity_hook.py b/classy_vision/hooks/model_complexity_hook.py index 2d950e229a..0fc4b234c3 100644 --- a/classy_vision/hooks/model_complexity_hook.py +++ b/classy_vision/hooks/model_complexity_hook.py @@ -27,11 +27,12 @@ class ModelComplexityHook(ClassyHook): on_phase_end = ClassyHook._noop on_end = ClassyHook._noop - def __init__(self) -> None: + def __init__(self, verbose=False) -> None: super().__init__() self.num_flops = None self.num_activations = None self.num_parameters = None + self.verbose = verbose def on_start(self, task) -> None: """Measure number of parameters, FLOPs and activations.""" @@ -48,15 +49,13 @@ def on_start(self, task) -> None: input_key=task.base_model.input_key if hasattr(task.base_model, "input_key") else None, + verbose=self.verbose, ) if self.num_flops is None: logging.info("FLOPs for forward pass: skipped.") self.num_flops = 0 else: - logging.info( - "FLOPs for forward pass: %d MFLOPs" - % (float(self.num_flops) / 1e6) - ) + logging.info(f"FLOPs for forward pass: {self.num_flops} FLOPs") except ClassyProfilerNotImplementedError as e: logging.warning(f"Could not compute FLOPs for model forward pass: {e}") try: @@ -66,6 +65,7 @@ def on_start(self, task) -> None: input_key=task.base_model.input_key if hasattr(task.base_model, "input_key") else None, + verbose=self.verbose, ) logging.info(f"Number of activations in model: {self.num_activations}") except ClassyProfilerNotImplementedError as e: diff --git a/classy_vision/tasks/fine_tuning_task.py b/classy_vision/tasks/fine_tuning_task.py index bd30e29c7d..9f20ce107a 100644 --- a/classy_vision/tasks/fine_tuning_task.py +++ b/classy_vision/tasks/fine_tuning_task.py @@ -4,6 +4,7 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +import logging from typing import Any, Dict from classy_vision.generic.util import ( @@ -95,19 +96,22 @@ def prepare(self) -> None: self.pretrained_checkpoint_path ) - assert ( - self.pretrained_checkpoint_dict is not None - ), "Need a pretrained checkpoint for fine tuning" - - state_load_success = update_classy_model( - self.base_model, - self.pretrained_checkpoint_dict["classy_state_dict"]["base_model"], - self.reset_heads, - self.pretrained_checkpoint_load_strict, - ) - assert ( - state_load_success - ), "Update classy state from pretrained checkpoint was unsuccessful." 
+        if self.pretrained_checkpoint_dict is None:
+            logging.warn("a pretrained checkpoint is not provided")
+        else:
+            assert (
+                self.pretrained_checkpoint_dict is not None
+            ), "Need a pretrained checkpoint for fine tuning"
+
+            state_load_success = update_classy_model(
+                self.base_model,
+                self.pretrained_checkpoint_dict["classy_state_dict"]["base_model"],
+                self.reset_heads,
+                self.pretrained_checkpoint_load_strict,
+            )
+            assert (
+                state_load_success
+            ), "Update classy state from pretrained checkpoint was unsuccessful."
 
         if self.freeze_trunk:
             # do not track gradients for all the parameters in the model except

From 1fc561ff6a71b0ba2b6200d1dd912728574547ca Mon Sep 17 00:00:00 2001
From: Zhicheng Yan
Date: Tue, 27 Oct 2020 22:38:04 -0700
Subject: [PATCH 4/4] add extra work during phase advance in classification task

Summary:
Allow a child class of the classification task to do extra work during phase
advance. This avoids duplicating code. For example, in
`classification_task_with_dict_output.py` of D24528158, the extra work is to
refresh the weighted-sampled dataset.

Differential Revision: D24586321

fbshipit-source-id: 4e420aad3fe8df4596688073938f5d66a1a6174e
---
 classy_vision/tasks/classification_task.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/classy_vision/tasks/classification_task.py b/classy_vision/tasks/classification_task.py
index 443ce61dcb..3e10d76071 100644
--- a/classy_vision/tasks/classification_task.py
+++ b/classy_vision/tasks/classification_task.py
@@ -953,6 +953,9 @@ def synchronize_losses(self):
         synchronized_losses_tensor = all_reduce_mean(losses_tensor)
         self.losses = synchronized_losses_tensor.tolist()
 
+    def extra_work_advance_phase(self):
+        pass
+
     def advance_phase(self):
         """Performs bookkeeping / task updates between phases
 
@@ -975,6 +978,7 @@ def advance_phase(self):
         if self.train:
             self.train_phase_idx += 1
 
+        self.extra_work_advance_phase()
         # Re-build dataloader & re-create iterator anytime membership changes.
         self.build_dataloaders_for_current_phase()
         self.create_data_iterators()
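
Example usage of the new hook (a minimal sketch, not part of the patches above):
a child task can override extra_work_advance_phase() to refresh a
weighted-sampled dataset between phases, as described in the PATCH 4/4 summary.
ClassificationTask, register_task, self.datasets, self.train, and
self.train_phase_idx come from classy_vision; the task name and the resample()
method below are hypothetical stand-ins for the dataset refresh in D24528158.

from classy_vision.tasks import ClassificationTask, register_task


@register_task("classification_with_resampling")
class ResamplingClassificationTask(ClassificationTask):
    """Refreshes a weighted-sampled train dataset each time a phase advances."""

    def extra_work_advance_phase(self):
        # advance_phase() calls this hook right before the dataloaders are
        # rebuilt, so a refreshed dataset is picked up by the next phase.
        train_dataset = self.datasets.get("train")
        if self.train and hasattr(train_dataset, "resample"):
            # Hypothetical method: draw a new weighted sample of the data.
            train_dataset.resample(seed=self.train_phase_idx)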