5 changes: 5 additions & 0 deletions CHANGELOG.md
@@ -1,3 +1,8 @@
## neptune-pytorch 2.0.1

### Changes
- Add optional `param_preproc` and `grad_preproc` arguments to apply custom aggregation or handle NaN/Inf values when logging parameters and gradients.

## neptune-pytorch 2.0.0

### Changes
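For context on the new arguments: each preprocessing callable receives a `torch.Tensor` and returns the value to log in place of the default `.norm()`. A minimal sketch of two such callables (the helper names are illustrative, not part of the package):

```python
import torch

def safe_grad_norm(grad: torch.Tensor) -> torch.Tensor:
    # Replace NaN/Inf entries before reducing, so one bad step does not
    # poison the logged series.
    return torch.nan_to_num(grad, nan=0.0, posinf=0.0, neginf=0.0).norm()

def param_abs_max(param: torch.Tensor) -> torch.Tensor:
    # Custom aggregation: track the largest absolute weight instead of the L2 norm.
    return param.detach().abs().max()
```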
14 changes: 12 additions & 2 deletions src/neptune_pytorch/impl/__init__.py
@@ -22,6 +22,7 @@
from typing import (
    Optional,
    Union,
    Callable,
)

import torch
@@ -98,6 +99,8 @@ def __init__(
        log_gradients: bool = False,
        log_parameters: bool = False,
        log_freq: int = 100,
        param_preproc: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
        grad_preproc: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
    ):
        verify_type("run", run, (Run, Handler))

@@ -119,12 +122,14 @@ def __init__(
        self._gradients_iter_tracker = {}
        self._gradients_hook_handler = {}
        if self.log_gradients:
            self._grad_preproc = grad_preproc
            self._add_hooks_for_grads()

        self.log_parameters = log_parameters
        self._params_iter_tracker = 0
        self._params_hook_handler = None
        if self.log_parameters:
            self._param_preproc = param_preproc
            self._add_hooks_for_params()

        # Log integration version
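A hedged usage sketch of how the new keyword arguments are passed to the constructor, assuming the public class is `NeptuneLogger` and reusing the illustrative `safe_grad_norm`/`param_abs_max` helpers from the sketch above:

```python
import neptune
from neptune_pytorch import NeptuneLogger

run = neptune.init_run()
npt_logger = NeptuneLogger(
    run=run,
    model=model,                  # assumes `model` is an existing torch.nn.Module
    log_gradients=True,
    log_parameters=True,
    log_freq=100,
    grad_preproc=safe_grad_norm,  # illustrative helpers, not part of the package
    param_preproc=param_abs_max,
)
```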
@@ -136,12 +141,16 @@ def __init__(

    def _add_hooks_for_grads(self):
        for name, parameter in self.model.named_parameters():
            if not parameter.requires_grad:
                continue

            self._gradients_iter_tracker[name] = 0

            def hook(grad, name=name):
                self._gradients_iter_tracker[name] += 1
                if self._gradients_iter_tracker[name] % self.log_freq == 0:
                    x = grad.norm() if self._grad_preproc is None else self._grad_preproc(grad)
                    self._namespace_handler["plots"]["gradients"][name].append(x)

            self._gradients_hook_handler[name] = parameter.register_hook(hook)

@@ -175,7 +184,8 @@ def hook(module, inp, output):
            self._params_iter_tracker += 1
            if self._params_iter_tracker % self.log_freq == 0:
                for name, param in module.named_parameters():
                    x = param.norm() if self._param_preproc is None else self._param_preproc(param)
                    self._namespace_handler["plots"]["parameters"][name].append(x)

        self._params_hook_handler = self.model.register_forward_hook(hook)

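To make the hook timing concrete: the per-parameter gradient hooks registered above fire during the backward pass, while the parameter hook is a forward hook, so parameters are sampled every `log_freq` forward calls. A hedged training-loop sketch (the `dataloader`, `criterion`, and `optimizer` names are illustrative):

```python
for x, y in dataloader:
    optimizer.zero_grad()
    output = model(x)            # forward hook runs here and advances the parameter counter
    loss = criterion(output, y)
    loss.backward()              # gradient hooks run here, one per tracked parameter
    optimizer.step()
```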