Quality-of-Life for Google Colab #3

Draft: wants to merge 13 commits into `main`
11 changes: 11 additions & 0 deletions README.md
@@ -1,3 +1,14 @@
## Changelog

- save output images as JPG,
- automatically resume from the latest `.pkl` file with the command-line argument `--resume=latest`,
- automatically set the resume value of `kimg`,
- automatically set the resume value of the augmentation strength,
- allow the resume value of the augmentation strength to be set **manually**,
- add the `auto_norp` config, which replicates the `auto` config without EMA rampup,
- allow overriding the mapping network depth with the command-line argument `--cfg_map`,
- allow enforcing CIFAR-specific architecture tuning with the command-line argument `--cifar_tune`.

## StyleGAN2-ADA — Official PyTorch implementation

![Teaser image](./docs/stylegan2-ada-teaser-1024x252.png)
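For reference, a hypothetical launch exercising the new flags from the changelog above; the dataset path and `--outdir` value are placeholders, and `--outdir`/`--data`/`--gpus` are assumed to be the repo's existing training options:

```python
# Sketch only: resume-and-train launch using the new command-line arguments.
import subprocess

subprocess.run([
    'python', 'train.py',
    '--outdir=training-runs',      # parent folder that holds the numbered run directories
    '--data=datasets/mydata.zip',  # placeholder dataset archive
    '--gpus=1',
    '--cfg=auto_norp',             # 'auto' behaviour without EMA rampup
    '--cfg_map=2',                 # override the mapping network depth
    '--resume=latest',             # pick up the newest network-*.pkl under --outdir
], check=True)
```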
5 changes: 5 additions & 0 deletions torch_utils/ops/conv2d_gradfix.py
@@ -12,6 +12,7 @@
import warnings
import contextlib
import torch
from pkg_resources import parse_version

# pylint: disable=redefined-builtin
# pylint: disable=arguments-differ
@@ -21,6 +22,7 @@

enabled = False # Enable the custom op by setting this to true.
weight_gradients_disabled = False # Forcefully disable computation of gradients with respect to the weights.
_use_pytorch_1_11_api = parse_version(torch.__version__) >= parse_version('1.11.0a') # Allow prerelease builds of 1.11

@contextlib.contextmanager
def no_weight_gradients():
@@ -48,6 +50,9 @@ def _should_use_custom_op(input):
assert isinstance(input, torch.Tensor)
if (not enabled) or (not torch.backends.cudnn.enabled):
return False
if _use_pytorch_1_11_api:
# The work-around code does not work on PyTorch 1.11.0 and later
return False
if input.device.type != 'cuda':
return False
if any(torch.__version__.startswith(x) for x in ['1.7.', '1.8.', '1.9']):
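A quick standalone check of the `parse_version` gate added above; `pkg_resources` follows PEP 440 ordering, so anchoring the comparison at the `1.11.0a` prerelease admits source/prerelease 1.11 builds as well as every final release from 1.11.0 onward:

```python
from pkg_resources import parse_version

gate = parse_version('1.11.0a')                      # normalizes to the 1.11.0a0 prerelease
assert parse_version('1.10.2') < gate                # older releases keep the custom-op path
assert parse_version('1.11.0') >= gate               # 1.11.0 final triggers the fallback
assert parse_version('1.12.1+cu113') >= gate         # later releases (with local tags) do too
assert parse_version('1.11.0a0+gitabc123') >= gate   # prerelease/source builds of 1.11 as well
print('version gate behaves as expected')
```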
13 changes: 11 additions & 2 deletions torch_utils/ops/grid_sample_gradfix.py
@@ -13,6 +13,7 @@

import warnings
import torch
from pkg_resources import parse_version

# pylint: disable=redefined-builtin
# pylint: disable=arguments-differ
@@ -21,6 +22,8 @@
#----------------------------------------------------------------------------

enabled = False # Enable the custom op by setting this to true.
_use_pytorch_1_11_api = parse_version(torch.__version__) >= parse_version('1.11.0a') # Allow prerelease builds of 1.11
_use_pytorch_1_12_api = parse_version(torch.__version__) >= parse_version('1.12.0a') # Allow prerelease builds of 1.12

#----------------------------------------------------------------------------

@@ -34,7 +37,7 @@ def grid_sample(input, grid):
def _should_use_custom_op():
if not enabled:
return False
if any(torch.__version__.startswith(x) for x in ['1.7.', '1.8.', '1.9']):
if any(torch.__version__.startswith(x) for x in ['1.7.', '1.8.', '1.9', '1.1', '2']):
return True
warnings.warn(f'grid_sample_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.grid_sample().')
return False
@@ -62,7 +65,13 @@ class _GridSample2dBackward(torch.autograd.Function):
@staticmethod
def forward(ctx, grad_output, input, grid):
op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward')
grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False)
if _use_pytorch_1_12_api:
op = op[0]
if _use_pytorch_1_11_api:
output_mask = (ctx.needs_input_grad[1], ctx.needs_input_grad[2])
grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False, output_mask)
else:
grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False)
ctx.save_for_backward(grid)
return grad_input, grad_grid

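As a standalone restatement of the shim above (a sketch mirroring the calls shown in this diff, not an independent implementation), the two version branches look like this:

```python
import torch
from pkg_resources import parse_version

_use_1_11 = parse_version(torch.__version__) >= parse_version('1.11.0a')
_use_1_12 = parse_version(torch.__version__) >= parse_version('1.12.0a')

def grid_sampler_2d_backward(grad_output, input, grid, needs_input_grad=(True, True)):
    op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward')
    if _use_1_12:
        op = op[0]  # 1.12+ returns (op, overload_names); keep only the callable
    if _use_1_11:
        # 1.11+ added an output_mask argument selecting which gradients to compute.
        return op(grad_output, input, grid, 0, 0, False, needs_input_grad)
    return op(grad_output, input, grid, 0, 0, False)
```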
28 changes: 23 additions & 5 deletions train.py
@@ -44,9 +44,11 @@ def setup_training_loop_kwargs(

# Base config.
cfg = None, # Base config: 'auto' (default), 'stylegan2', 'paper256', 'paper512', 'paper1024', 'cifar'
cifar_tune = None, # Enforce CIFAR-specific architecture tuning: <bool>, default = False
gamma = None, # Override R1 gamma: <float>
kimg = None, # Override training duration: <int>
batch = None, # Override batch size: <int>
cfg_map = None, # Override mapping network depth: <int>, default = depends on cfg

# Discriminator augmentation.
aug = None, # Augmentation mode: 'ada' (default), 'noaug', 'fixed'
@@ -153,6 +155,7 @@ def setup_training_loop_kwargs(

cfg_specs = {
'auto': dict(ref_gpus=-1, kimg=25000, mb=-1, mbstd=-1, fmaps=-1, lrate=-1, gamma=-1, ema=-1, ramp=0.05, map=2), # Populated dynamically based on resolution and GPU count.
'auto_norp': dict(ref_gpus=-1, kimg=25000, mb=-1, mbstd=-1, fmaps=-1, lrate=-1, gamma=-1, ema=-1, ramp=None, map=2), # Same as 'auto' but without EMA rampup.
'stylegan2': dict(ref_gpus=8, kimg=25000, mb=32, mbstd=4, fmaps=1, lrate=0.002, gamma=10, ema=10, ramp=None, map=8), # Uses mixed-precision, unlike the original StyleGAN2.
'paper256': dict(ref_gpus=8, kimg=25000, mb=64, mbstd=8, fmaps=0.5, lrate=0.0025, gamma=1, ema=20, ramp=None, map=8),
'paper512': dict(ref_gpus=8, kimg=25000, mb=64, mbstd=8, fmaps=1, lrate=0.0025, gamma=0.5, ema=20, ramp=None, map=8),
@@ -162,7 +165,7 @@

assert cfg in cfg_specs
spec = dnnlib.EasyDict(cfg_specs[cfg])
if cfg == 'auto':
if cfg.startswith('auto'):
desc += f'{gpus:d}'
spec.ref_gpus = gpus
res = args.training_set_kwargs.resolution
@@ -192,7 +195,14 @@
args.ema_kimg = spec.ema
args.ema_rampup = spec.ramp

if cfg == 'cifar':
if cifar_tune is None:
cifar_tune = False
else:
assert isinstance(cifar_tune, bool)
if cifar_tune:
desc += '-tuning'

if cifar_tune or cfg == 'cifar':
args.loss_kwargs.pl_weight = 0 # disable path length regularization
args.loss_kwargs.style_mixing_prob = 0 # disable style mixing
args.D_kwargs.architecture = 'orig' # disable residual skip connections
@@ -219,6 +229,12 @@
args.batch_size = batch
args.batch_gpu = batch // gpus

if cfg_map is not None:
assert isinstance(cfg_map, int)
if not cfg_map >= 1:
raise UserError('--cfg_map must be at least 1')
args.G_kwargs.mapping_kwargs.num_layers = cfg_map

# ---------------------------------------------------
# Discriminator augmentation: aug, p, target, augpipe
# ---------------------------------------------------
@@ -244,8 +260,8 @@

if p is not None:
assert isinstance(p, float)
if aug != 'fixed':
raise UserError('--p can only be specified with --aug=fixed')
if resume != 'latest' and aug != 'fixed':
raise UserError('--p can only be specified with --resume=latest or --aug=fixed')
if not 0 <= p <= 1:
raise UserError('--p must be between 0 and 1')
desc += f'-p{p:g}'
@@ -413,10 +429,12 @@ def convert(self, value, param, ctx):
@click.option('--mirror', help='Enable dataset x-flips [default: false]', type=bool, metavar='BOOL')

# Base config.
@click.option('--cfg', help='Base config [default: auto]', type=click.Choice(['auto', 'stylegan2', 'paper256', 'paper512', 'paper1024', 'cifar']))
@click.option('--cfg', help='Base config [default: auto]', type=click.Choice(['auto', 'auto_norp', 'stylegan2', 'paper256', 'paper512', 'paper1024', 'cifar']))
@click.option('--cifar_tune', help='Enforce CIFAR-specific architecture tuning [default: false]', type=bool, metavar='BOOL')
@click.option('--gamma', help='Override R1 gamma', type=float)
@click.option('--kimg', help='Override training duration', type=int, metavar='INT')
@click.option('--batch', help='Override batch size', type=int, metavar='INT')
@click.option('--cfg_map', help='Override mapping network depth', type=int, metavar='INT')

# Discriminator augmentation.
@click.option('--aug', help='Augmentation mode [default: ada]', type=click.Choice(['noaug', 'ada', 'fixed']))
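A minimal sketch of where the `--cfg_map` override ends up, with `dnnlib.EasyDict` stubbed by `SimpleNamespace` so the snippet runs on its own (the value 8 is illustrative):

```python
from types import SimpleNamespace

# Stand-in for the nested dnnlib.EasyDict structure built in setup_training_loop_kwargs().
args = SimpleNamespace(G_kwargs=SimpleNamespace(mapping_kwargs=SimpleNamespace(num_layers=2)))

cfg_map = 8  # e.g. --cfg_map=8 on the command line
if cfg_map is not None:
    if not isinstance(cfg_map, int) or cfg_map < 1:
        raise ValueError('--cfg_map must be an integer of at least 1')
    # Overrides whatever depth the base config (spec.map) would have chosen.
    args.G_kwargs.mapping_kwargs.num_layers = cfg_map

print(args.G_kwargs.mapping_kwargs.num_layers)  # -> 8
```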
66 changes: 66 additions & 0 deletions training/misc.py
@@ -0,0 +1,66 @@
import glob
import os
import re

from pathlib import Path

def get_parent_dir(run_dir):
out_dir = Path(run_dir).parent

return out_dir

def locate_latest_pkl(out_dir):
all_pickle_names = sorted(glob.glob(os.path.join(out_dir, '0*', 'network-*.pkl')))

try:
latest_pickle_name = all_pickle_names[-1]
except IndexError:
latest_pickle_name = None

return latest_pickle_name

def parse_kimg_from_network_name(network_pickle_name):

if network_pickle_name is not None:
resume_run_id = os.path.basename(os.path.dirname(network_pickle_name))
RE_KIMG = re.compile(r'network-snapshot-(\d+)\.pkl')
try:
kimg = int(RE_KIMG.match(os.path.basename(network_pickle_name)).group(1))
except AttributeError:
kimg = 0.0
else:
kimg = 0.0

return float(kimg)


def parse_augment_p_from_log(network_pickle_name):

if network_pickle_name is not None:
network_folder_name = os.path.dirname(network_pickle_name)
log_file_name = os.path.join(network_folder_name, "log.txt")

try:
with open(log_file_name, "r") as f:
# Tokenize each line starting with the word 'tick'
lines = [
l.strip().split() for l in f.readlines() if l.startswith("tick")
]
except FileNotFoundError:
lines = []

# Extract the last token of each line for which the second to last token is 'augment'
values = [
tokens[-1]
for tokens in lines
if len(tokens) > 1 and tokens[-2] == "augment"
]

if values:
augment_p = float(values[-1])
else:
augment_p = 0.0
else:
augment_p = 0.0

return float(augment_p)
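A hypothetical usage of these helpers, e.g. from a notebook; the run-directory name below is a placeholder:

```python
from training import misc as tmisc

out_dir = tmisc.get_parent_dir('training-runs/00002-mydata-auto1')  # -> Path('training-runs')
latest = tmisc.locate_latest_pkl(out_dir)                           # newest 'network-*.pkl', or None
kimg = tmisc.parse_kimg_from_network_name(latest)                   # e.g. 1200.0, or 0.0 if not found
p = tmisc.parse_augment_p_from_log(latest)                          # last 'augment' value in log.txt, or 0.0
print(latest, kimg, p)
```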
25 changes: 20 additions & 5 deletions training/training_loop.py
@@ -23,6 +23,7 @@

import legacy
from metrics import metric_main
from training import misc as tmisc

#----------------------------------------------------------------------------

@@ -152,6 +153,20 @@ def training_loop(
G_ema = copy.deepcopy(G).eval()

# Resume from existing pickle.
if resume_pkl == 'latest':
out_dir = tmisc.get_parent_dir(run_dir)
resume_pkl = tmisc.locate_latest_pkl(out_dir)

resume_kimg = tmisc.parse_kimg_from_network_name(resume_pkl)
if resume_kimg > 0:
print(f'Resuming from kimg = {resume_kimg}')

if ada_target is not None and augment_p == 0:
# Overwrite augment_p only if the augmentation probability is not fixed by the user
augment_p = tmisc.parse_augment_p_from_log(resume_pkl)
if augment_p > 0:
print(f'Resuming with augment_p = {augment_p}')

if (resume_pkl is not None) and (rank == 0):
print(f'Resuming from "{resume_pkl}"')
with dnnlib.util.open_url(resume_pkl) as f:
@@ -220,11 +235,11 @@
if rank == 0:
print('Exporting sample images...')
grid_size, images, labels = setup_snapshot_image_grid(training_set=training_set)
save_image_grid(images, os.path.join(run_dir, 'reals.png'), drange=[0,255], grid_size=grid_size)
save_image_grid(images, os.path.join(run_dir, 'reals.jpg'), drange=[0,255], grid_size=grid_size)
grid_z = torch.randn([labels.shape[0], G.z_dim], device=device).split(batch_gpu)
grid_c = torch.from_numpy(labels).to(device).split(batch_gpu)
images = torch.cat([G_ema(z=z, c=c, noise_mode='const').cpu() for z, c in zip(grid_z, grid_c)]).numpy()
save_image_grid(images, os.path.join(run_dir, 'fakes_init.png'), drange=[-1,1], grid_size=grid_size)
save_image_grid(images, os.path.join(run_dir, 'fakes_init.jpg'), drange=[-1,1], grid_size=grid_size)

# Initialize logs.
if rank == 0:
@@ -245,14 +260,14 @@
if rank == 0:
print(f'Training for {total_kimg} kimg...')
print()
cur_nimg = 0
cur_nimg = int(resume_kimg * 1000)
cur_tick = 0
tick_start_nimg = cur_nimg
tick_start_time = time.time()
maintenance_time = tick_start_time - start_time
batch_idx = 0
if progress_fn is not None:
progress_fn(0, total_kimg)
progress_fn(int(resume_kimg), total_kimg)
while True:

# Fetch training data.
@@ -347,7 +362,7 @@
# Save image snapshot.
if (rank == 0) and (image_snapshot_ticks is not None) and (done or cur_tick % image_snapshot_ticks == 0):
images = torch.cat([G_ema(z=z, c=c, noise_mode='const').cpu() for z, c in zip(grid_z, grid_c)]).numpy()
save_image_grid(images, os.path.join(run_dir, f'fakes{cur_nimg//1000:06d}.png'), drange=[-1,1], grid_size=grid_size)
save_image_grid(images, os.path.join(run_dir, f'fakes{cur_nimg//1000:06d}.jpg'), drange=[-1,1], grid_size=grid_size)

# Save network snapshot.
snapshot_pkl = None
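A quick check of the resume arithmetic used in these hunks and in the snapshot naming above (the kimg value is illustrative):

```python
resume_kimg = 1200.0                      # as parsed from 'network-snapshot-001200.pkl'
total_kimg = 25000
cur_nimg = int(resume_kimg * 1000)        # the image counter resumes at 1,200,000
print(f'fakes{cur_nimg//1000:06d}.jpg')   # -> fakes001200.jpg, consistent with pre-resume snapshots
print(f'progress: {int(resume_kimg)}/{total_kimg} kimg')
```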