diff --git a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py
index 339d5c71267..d53bbb6f1bb 100644
--- a/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py
+++ b/src/sparseml/pytorch/sparsification/quantization/modifier_quantization.py
@@ -103,6 +103,7 @@ class QuantizationModifier(ScheduledModifier):
     |       freeze_bn_stats_epoch: 3.0
     |       model_fuse_fn_name: 'fuse_module'
     |       strict: True
+    |       verbose: False
 
     :param start_epoch: The epoch to start the modifier at
     :param scheme: Default QuantizationScheme to use when enabling quantization
@@ -133,6 +134,8 @@ class QuantizationModifier(ScheduledModifier):
         scheme_overrides or ignore are not found in a given module. Default True
     :param end_epoch: Disabled, setting to anything other than -1 will raise an
         exception. For compatibility with YAML serialization only.
+    :param verbose: if True, will log detailed information such as number of bits,
+        batch norm freezing, etc. Defaults to False
     """
 
     def __init__(
@@ -148,6 +151,7 @@ def __init__(
         num_calibration_steps: Optional[int] = None,
         strict: bool = True,
         end_epoch: float = -1.0,
+        verbose: bool = False,
     ):
         raise_if_torch_quantization_not_available()
         if end_epoch != -1:
@@ -178,7 +182,7 @@ def __init__(
         self._model_fuse_fn_name = None
 
         self._strict = strict
-
+        self._verbose = verbose
         self._qat_enabled = False
         self._quantization_observer_disabled = False
         self._bn_stats_frozen = False
@@ -348,6 +352,22 @@ def strict(self, value: bool):
         """
         self._strict = value
 
+    @ModifierProp()
+    def verbose(self) -> bool:
+        """
+        :return: if True, will log detailed information such as number of bits,
+            batch norm freezing, etc.
+        """
+        return self._verbose
+
+    @verbose.setter
+    def verbose(self, value: bool):
+        """
+        :param value: if True, will log detailed information such as number of bits,
+            batch norm freezing, etc.
+        """
+        self._verbose = value
+
     def initialize(
         self,
         module: Module,
@@ -455,7 +475,8 @@ def _check_quantization_update(
             module.apply(freeze_bn_stats)
             self._bn_stats_frozen = True
 
-        self._log_quantization(module, epoch, steps_per_epoch)
+        if self._verbose:
+            self._log_quantization(module, epoch, steps_per_epoch)
 
     def _disable_quantization_observer_update_ready(self, epoch: float) -> bool:
         return (
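
For reference, a minimal usage sketch (not part of the diff) showing the new flag; the import path and constructor signature are taken from the file modified above, and start_epoch=0.0 is an illustrative value:

    from sparseml.pytorch.sparsification.quantization.modifier_quantization import (
        QuantizationModifier,
    )

    # With verbose=True, the modifier logs quantization details (number of bits,
    # batch norm freezing, etc.) during updates; the default of False keeps
    # these logs quiet, preserving the previous behavior for existing recipes.
    modifier = QuantizationModifier(start_epoch=0.0, verbose=True)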