From 268fe5afe646326382d8a818026b0e145692b523 Mon Sep 17 00:00:00 2001
From: Keillion <735699921@qq.com>
Date: Tue, 21 Jun 2022 00:50:51 +0800
Subject: [PATCH] Resolve deprecated warning

---
 optimizer.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/optimizer.py b/optimizer.py
index 6d6dc11..8225ef6 100644
--- a/optimizer.py
+++ b/optimizer.py
@@ -87,12 +87,12 @@ def step(self, closure=None):
         for group in self.optimizer.param_groups:
             for p in group['params']:
                 param_state = self.state[p]
-                p.data.mul_(self.alpha).add_(1.0 - self.alpha, param_state['cached_params'])  # crucial line
+                p.data.mul_(self.alpha).add_(param_state['cached_params'], alpha=1.0 - self.alpha)  # crucial line
                 param_state['cached_params'].copy_(p.data)
                 if self.pullback_momentum == "pullback":
                     internal_momentum = self.optimizer.state[p]["momentum_buffer"]
                     self.optimizer.state[p]["momentum_buffer"] = internal_momentum.mul_(self.alpha).add_(
-                        1.0 - self.alpha, param_state["cached_mom"])
+                        param_state["cached_mom"], alpha=1.0 - self.alpha)
                     param_state["cached_mom"] = self.optimizer.state[p]["momentum_buffer"]
                 elif self.pullback_momentum == "reset":
                     self.optimizer.state[p]["momentum_buffer"] = torch.zeros_like(p.data)