
FusedAdam as a drop-in replacement for AdamW #1615


Open · wants to merge 23 commits into master
37 changes: 23 additions & 14 deletions apex/optimizers/fused_adam.py
@@ -122,18 +122,11 @@ def step(self, closure=None, grads=None, output_params=None, scale=None, grad_no
bias_correction = 1 if group['bias_correction'] else 0
beta1, beta2 = group['betas']

# assume same step across group now to simplify things
# per parameter step can be easily support by making it tensor, or pass list into kernel
if 'step' in group:
group['step'] += 1 if not self.capturable else (self._dummy_overflow_buf != 1).to(torch.int)
else:
group['step'] = 1 if not self.capturable else torch.tensor([1], dtype=torch.int, device=device)

# create lists for multi-tensor apply
g_16, p_16, m_16, v_16 = [], [], [], []
g_bf, p_bf, m_bf, v_bf = [], [], [], []
g_32, p_32, m_32, v_32 = [], [], [], []

steps = []
for p in group['params']:
if p.grad is None:
continue
@@ -147,6 +140,15 @@ def step(self, closure=None, grads=None, output_params=None, scale=None, grad_no
state['exp_avg'] = torch.zeros_like(p.data)
# Exponential moving average of squared gradient values
state['exp_avg_sq'] = torch.zeros_like(p.data)
# Backward compatibility: a freshly initialized state should not coexist with a legacy group-level `step`
assert 'step' not in group
state['step'] = 1.0 if not self.capturable else torch.tensor(1, dtype=torch.float, device=device)
else:
# Backward compatibility: we used to assume that `step` was the same across group
if 'step' in group:
assert 'step' not in state
state['step'] = group['step'].squeeze(0)
state['step'] += 1.0 if not self.capturable else (self._dummy_overflow_buf != 1).to(torch.float).squeeze(0)

if p.dtype == torch.float16:
g_16.append(p.grad.data)
@@ -165,6 +167,13 @@ def step(self, closure=None, grads=None, output_params=None, scale=None, grad_no
v_32.append(state['exp_avg_sq'])
else:
raise RuntimeError('FusedAdam only support fp16 and fp32.')
steps.append(state['step'])

# Backward compatibility: we used to assume that `step` was the same across group
if 'step' in group:
del group['step']
if self.capturable:
steps = torch.stack(steps)

# If the optimizer is capturable, then if there's a grad scaler it works
# on the GPU + a different multi_tensor_applier should be called
@@ -193,7 +202,7 @@ def step(self, closure=None, grads=None, output_params=None, scale=None, grad_no
beta1,
beta2,
group['eps'],
group['step'],
steps,
self.adam_w_mode,
bias_correction,
group['weight_decay'],
@@ -208,7 +217,7 @@ def step(self, closure=None, grads=None, output_params=None, scale=None, grad_no
beta1,
beta2,
group['eps'],
group['step'],
steps,
self.adam_w_mode,
bias_correction,
group['weight_decay'],
@@ -222,7 +231,7 @@ def step(self, closure=None, grads=None, output_params=None, scale=None, grad_no
beta1,
beta2,
group['eps'],
group['step'],
steps,
self.adam_w_mode,
bias_correction,
group['weight_decay'],
@@ -236,7 +245,7 @@ def step(self, closure=None, grads=None, output_params=None, scale=None, grad_no
beta1,
beta2,
group['eps'],
group['step'],
steps,
self.adam_w_mode,
bias_correction,
group['weight_decay'])
@@ -250,7 +259,7 @@ def step(self, closure=None, grads=None, output_params=None, scale=None, grad_no
beta1,
beta2,
group['eps'],
group['step'],
steps,
self.adam_w_mode,
bias_correction,
group['weight_decay'])
@@ -263,7 +272,7 @@ def step(self, closure=None, grads=None, output_params=None, scale=None, grad_no
beta1,
beta2,
group['eps'],
group['step'],
steps,
self.adam_w_mode,
bias_correction,
group['weight_decay'])
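The Python-side change above moves the step counter out of the param group and into each parameter's state, which is the layout torch.optim.AdamW uses and is what makes the two state dicts interchangeable. As a rough illustration only (not apex's code; the helper name and the standalone `states` dict are made up for the example), the bookkeeping amounts to:

import torch

def collect_per_param_steps(group, states, capturable=False, device="cuda"):
    # Hypothetical sketch of the migration + increment logic in the diff above.
    steps = []
    for p in group["params"]:
        state = states.setdefault(p, {})
        if "step" not in state:
            if "step" in group:
                # Legacy checkpoint: the counter used to live on the group.
                value = float(group["step"])
            else:
                value = 0.0
            state["step"] = value if not capturable else torch.tensor(value, dtype=torch.float, device=device)
        state["step"] += 1
        steps.append(state["step"])
    group.pop("step", None)  # the group no longer owns a step counter
    # The capturable kernel reads a device tensor; the eager kernel takes plain floats.
    return torch.stack(steps) if capturable else steps
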
4 changes: 2 additions & 2 deletions csrc/amp_C_frontend.cpp
@@ -76,7 +76,7 @@ void multi_tensor_adam_cuda(
const float beta1,
const float beta2,
const float epsilon,
const int step,
std::vector<float> steps,
const int mode,
const int bias_correction,
const float weight_decay);
@@ -89,7 +89,7 @@ void multi_tensor_adam_capturable_cuda(
const float beta1,
const float beta2,
const float epsilon,
at::Tensor step,
at::Tensor steps,
const int mode,
const int bias_correction,
const float weight_decay,
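Note the asymmetry in the two declarations: the eager entry point now takes the per-tensor steps as a host-side std::vector<float> (a plain Python list of floats once it crosses the pybind11 boundary), while the capturable entry point keeps them in an at::Tensor so the counters can live on the GPU and be updated inside a CUDA graph. A hedged sketch of what the caller is expected to build for each path (the function name is hypothetical, and it assumes the STL casters that torch/extension.h normally pulls in):

import torch

def build_steps_argument(per_param_steps, capturable, device="cuda"):
    # per_param_steps: one step count per parameter, in parameter order.
    if capturable:
        # Device tensor, read by the kernel as a float* and safe to capture in a graph.
        return torch.as_tensor([float(s) for s in per_param_steps], dtype=torch.float, device=device)
    # Plain floats: pybind11 converts the list to std::vector<float> for the eager binding.
    return [float(s) for s in per_param_steps]
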
56 changes: 31 additions & 25 deletions csrc/multi_tensor_adam.cu
@@ -29,8 +29,8 @@ struct AdamFunctor
TensorListMetadata<4>& tl,
const float beta1,
const float beta2,
const float beta1_correction,
const float beta2_correction,
const float* steps,
const int bias_correction,
const float epsilon,
const float lr,
adamMode_t mode,
@@ -39,13 +39,18 @@
// I'd like this kernel to propagate infs/nans.
// if(*noop_gmem == 1)
// return;
float beta1_correction = 1.0f, beta2_correction = 1.0f;

int tensor_loc = tl.block_to_tensor[blockIdx.x];
const int tensor_loc = tl.block_to_tensor[blockIdx.x];

// potentially use to pass in list of scalar
// int tensor_num = tl.start_tensor_this_launch + tensor_loc;
const int tensor_num = tl.start_tensor_this_launch + tensor_loc;
if (bias_correction == 1) {
beta1_correction = 1 - pow(beta1, steps[tensor_num]);
beta2_correction = 1 - pow(beta2, steps[tensor_num]);
}

int chunk_idx = tl.block_to_chunk[blockIdx.x];
const int chunk_idx = tl.block_to_chunk[blockIdx.x];
int n = tl.sizes[tensor_loc];

T* g = (T*)tl.addresses[0][tensor_loc];
@@ -135,7 +140,7 @@ struct AdamCapturableFunctor
TensorListMetadata<4>& tl,
const float beta1,
const float beta2,
const int* step,
const float* steps,
const int bias_correction,
const float epsilon,
const float* lr,
@@ -147,17 +152,17 @@
return;

float beta1_correction = 1.0f, beta2_correction = 1.0f;
if (bias_correction == 1) {
beta1_correction = 1 - pow(beta1, *step);
beta2_correction = 1 - pow(beta2, *step);
}

int tensor_loc = tl.block_to_tensor[blockIdx.x];
const int tensor_loc = tl.block_to_tensor[blockIdx.x];

// potentially use to pass in list of scalar
// int tensor_num = tl.start_tensor_this_launch + tensor_loc;
const int tensor_num = tl.start_tensor_this_launch + tensor_loc;
if (bias_correction == 1) {
beta1_correction = 1 - pow(beta1, steps[tensor_num]);
beta2_correction = 1 - pow(beta2, steps[tensor_num]);
}

int chunk_idx = tl.block_to_chunk[blockIdx.x];
const int chunk_idx = tl.block_to_chunk[blockIdx.x];
int n = tl.sizes[tensor_loc];

T* g = (T*)tl.addresses[0][tensor_loc];
@@ -247,23 +252,21 @@ void multi_tensor_adam_cuda(
const float beta1,
const float beta2,
const float epsilon,
const int step,
std::vector<float> steps,
const int mode,
const int bias_correction,
const float weight_decay)
{
using namespace at;

// Handle bias correction mode
float bias_correction1 = 1.0f, bias_correction2 = 1.0f;
if (bias_correction == 1) {
bias_correction1 = 1 - std::pow(beta1, step);
bias_correction2 = 1 - std::pow(beta2, step);
}
float* cuda_steps;
cudaMalloc((void**)&cuda_steps, steps.size()*sizeof(float));
cudaMemcpy(cuda_steps, steps.data(), steps.size()*sizeof(float), cudaMemcpyHostToDevice);

// Assume single type across p,g,m1,m2 now
DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
tensor_lists[0][0].scalar_type(), 0, "adam",

multi_tensor_apply<4>(
BLOCK_SIZE,
chunk_size,
@@ -272,13 +275,16 @@
AdamFunctor<scalar_t_0>(),
beta1,
beta2,
bias_correction1,
bias_correction2,
cuda_steps,
bias_correction,
epsilon,
lr,
(adamMode_t) mode,
weight_decay); )
weight_decay
);
)

cudaFree(cuda_steps);
AT_CUDA_CHECK(cudaGetLastError());

}
@@ -291,7 +297,7 @@ void multi_tensor_adam_capturable_cuda(
const float beta1,
const float beta2,
const float epsilon,
at::Tensor step,
at::Tensor steps,
const int mode,
const int bias_correction,
const float weight_decay,
@@ -309,7 +315,7 @@
AdamCapturableFunctor<scalar_t_0>(),
beta1,
beta2,
step.data_ptr<int>(),
steps.data_ptr<float>(),
bias_correction,
epsilon,
lr.data_ptr<float>(),
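The functional change in both kernels is that the bias corrections are no longer precomputed scalars passed from the host: each block looks up the step of the tensor it is processing (steps[tensor_num]) and derives the corrections locally. For reference, the per-tensor math of the decoupled-weight-decay (adam_w_mode) branch looks like this in plain PyTorch terms (a reference sketch, not the CUDA code; the L2 mode instead folds the decay into the gradient):

import torch

def adamw_reference_update(p, g, m, v, step, lr, beta1, beta2, eps, weight_decay, bias_correction=True):
    # Moment updates (exp_avg, exp_avg_sq)
    m.mul_(beta1).add_(g, alpha=1 - beta1)
    v.mul_(beta2).addcmul_(g, g, value=1 - beta2)
    # Bias corrections now depend on this tensor's own step count
    c1 = 1 - beta1 ** step if bias_correction else 1.0
    c2 = 1 - beta2 ** step if bias_correction else 1.0
    denom = (v / c2).sqrt().add_(eps)
    # Decoupled weight decay; equivalent to p -= lr * (m_hat / denom + weight_decay * p)
    p.mul_(1 - lr * weight_decay)
    p.addcdiv_(m / c1, denom, value=-lr)
    return p
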
62 changes: 62 additions & 0 deletions tests/L0/run_optimizers/test_adam.py
@@ -2,6 +2,7 @@
import math
import random
import unittest
from typing import Dict, Optional, List

import torch
import torch.nn.functional as F
@@ -185,6 +186,67 @@ def testNative(self):

self.model_.load_state_dict(copy.deepcopy(self.model.state_dict()))

def testStateDict(self):
params_ = [p for p in self.model_.parameters() if p.requires_grad]
optimizer_ = apex.optimizers.FusedAdam(params_, lr=self.lr, capturable=False)

for i in range(100):
x = torch.rand([32, 1, 28, 28]).cuda().to(memory_format=torch.channels_last)
x_ = x.clone()
gt = torch.rand([32, 10]).cuda()
gt_ = gt.clone()

# Reference
y = self.model(x)
loss = ((gt - y) ** 2).mean()

loss.backward()
self.optimizer.step()

# DUT
y = self.model_(x)
loss_ = ((gt_ - y) ** 2).mean()

loss_.backward()
optimizer_.step()

opt_state_dict = self.optimizer.state_dict()
opt_state_dict_ = optimizer_.state_dict()
assert_is_dict_equal(opt_state_dict, opt_state_dict_)

# Init for next iteration
self.optimizer.zero_grad()
optimizer_.zero_grad()

self.model_.load_state_dict(copy.deepcopy(self.model.state_dict()))

def assert_is_dict_equal(first: Dict, second: Dict, sub_paths: Optional[List[str]] = None, **tensor_testing_assert_close_kwargs):
if sub_paths is None:
sub_paths = []

first_keys = set(first.keys())
second_keys = set(second.keys())
assert first_keys == second_keys, f"Keys don't match in {'.'.join(sub_paths)}.\nCur: {first_keys}\nRef: {second_keys}"

for key in first_keys:
first_elt = first[key]
second_elt = second[key]

if isinstance(first_elt, dict):
assert isinstance(second_elt, dict), f"Object types don't match in {'.'.join(sub_paths + [str(key)])}.\nCur: {first_elt}\nRef: {second_elt}"
assert_is_dict_equal(first_elt, second_elt, sub_paths=sub_paths + [str(key)])
elif isinstance(first_elt, torch.Tensor):
# We accept that devices can be different
second_elt = torch.as_tensor(second_elt, device=first_elt.device)
torch.testing.assert_close(
first_elt,
second_elt,
**tensor_testing_assert_close_kwargs,
msg=lambda msg: f"Tensors at {'.'.join(sub_paths + [str(key)])} don't match.\nCur: {first_elt}\nRef: {second_elt}\n{msg}",
)
else:
assert first_elt == second_elt, f"Objects at key {'.'.join(sub_paths + [str(key)])} don't match.\nCur: {first_elt}\nRef: {second_elt}"

if __name__ == '__main__':
unittest.main()
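Taken together, the new test checks that a FusedAdam state dict matches the reference optimizer's key for key, which is the property the PR title promises. A hedged usage sketch of the intended end state (it assumes this branch is installed; the exact dtype of the step entries may still differ between PyTorch versions):

import torch
import apex

model = torch.nn.Linear(10, 10).cuda()

# Train with the stock optimizer ...
adamw = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-2)
# ... training steps ...
checkpoint = adamw.state_dict()

# ... then resume with the fused one, reusing the same checkpoint layout.
fused = apex.optimizers.FusedAdam(model.parameters(), lr=1e-3, weight_decay=1e-2, adam_w_mode=True)
fused.load_state_dict(checkpoint)  # per-parameter `step` now lines up with AdamW's layout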