Add blender dataset + alpha training support #573


Open · wants to merge 1 commit into main
72 changes: 72 additions & 0 deletions examples/datasets/blender.py
@@ -0,0 +1,72 @@
from dataclasses import dataclass
import json
from pathlib import Path
from typing import Any, Dict, Literal

import imageio.v2 as imageio
import numpy as np
import torch


@dataclass
class Dataset:
    """A simple dataset class for synthetic Blender data."""

    data_dir: str
    """The path to the Blender scene, consisting of renders and transforms.json"""
    split: Literal["train", "test", "val"] = "train"
    """Which split to use."""

    def __post_init__(self):
        self.data_dir = Path(self.data_dir)
        transforms_path = self.data_dir / f"transforms_{self.split}.json"
        with transforms_path.open("r") as transforms_handle:
            transforms = json.load(transforms_handle)
        image_ids = []
        cam_to_worlds = []
        images = []
        for frame in transforms["frames"]:
            image_id = frame["file_path"].replace("./", "")
            image_ids.append(image_id)
            file_path = self.data_dir / f"{image_id}.png"
            images.append(imageio.imread(file_path))

            c2w = torch.tensor(frame["transform_matrix"])
            # Convert from OpenGL to OpenCV coordinate system
            c2w[0:3, 1:3] *= -1
            cam_to_worlds.append(c2w)

        self.image_ids = image_ids
        self.cam_to_worlds = cam_to_worlds
        self.images = images

        # all renders have the same intrinsics; see also
        # https://github.com/nerfstudio-project/nerfstudio/blob/main/nerfstudio/data/dataparsers/blender_dataparser.py
        image_height, image_width = self.images[0].shape[:2]
        cx = image_width / 2.0
        cy = image_height / 2.0
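        # focal length from the horizontal field of view:
        # fx = 0.5 * W / tan(0.5 * camera_angle_x)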
        fl = 0.5 * image_width / np.tan(0.5 * transforms["camera_angle_x"])
        self.intrinsics = torch.tensor(
            [[fl, 0, cx], [0, fl, cy], [0, 0, 1]], dtype=torch.float32
        )
        self.image_height = image_height
        self.image_width = image_width

        # compute scene scale (as is done in the colmap parser)
        camera_locations = np.stack(self.cam_to_worlds, axis=0)[:, :3, 3]
        scene_center = np.mean(camera_locations, axis=0)
        dists = np.linalg.norm(camera_locations - scene_center, axis=1)
        self.scene_scale = np.max(dists)

    def __len__(self):
        return len(self.image_ids)

    def __getitem__(self, item: int) -> Dict[str, Any]:
        data = dict(
            K=self.intrinsics,
            camtoworld=self.cam_to_worlds[item],
            image=torch.from_numpy(self.images[item]).float(),
            image_id=item,
        )
        return data
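For orientation, a minimal sketch of how this dataset might be consumed; the DataLoader wrapping mirrors how simple_trainer.py iterates its datasets, and the scene path is illustrative:

# Hypothetical usage of the new Blender dataset (not part of this diff).
from torch.utils.data import DataLoader

from datasets.blender import Dataset

trainset = Dataset(data_dir="data/nerf_synthetic/lego", split="train")
trainloader = DataLoader(trainset, batch_size=1, shuffle=True)
for data in trainloader:
    K = data["K"]  # [1, 3, 3] shared pinhole intrinsics
    camtoworld = data["camtoworld"]  # [1, 4, 4] OpenCV-convention pose
    image = data["image"]  # [1, H, W, 4] RGBA values in [0, 255]
    break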
166 changes: 136 additions & 30 deletions examples/simple_trainer.py
@@ -15,7 +15,7 @@
import tyro
import viser
import yaml
-from datasets.colmap import Dataset, Parser
+from datasets.colmap import Parser
from datasets.traj import (
generate_interpolated_path,
generate_ellipse_path_z,
@@ -55,8 +55,10 @@ class Config:
# Render trajectory path
render_traj_path: str = "interp"

-# Path to the Mip-NeRF 360 dataset
+# Path to the dataset
data_dir: str = "data/360_v2/garden"
# Type of the dataset (e.g. COLMAP or Blender)
data_type: Literal["colmap", "blender"] = "colmap"
# Downsample factor for the dataset
data_factor: int = 4
# Directory to save results
@@ -92,7 +94,7 @@ class Config:
ply_steps: List[int] = field(default_factory=lambda: [7_000, 30_000])

# Initialization strategy
init_type: str = "sfm"
init_type: Literal["sfm", "random"] = "sfm"
# Initial number of GSs. Ignored if using sfm
init_num_pts: int = 100_000
# Initial extent of GSs as a multiple of the camera extent. Ignored if using sfm
@@ -126,8 +128,10 @@ class Config:
# Anti-aliasing in rasterization. Might slightly hurt quantitative metrics.
antialiased: bool = False

-# Use random background for training to discourage transparency
+# Use random background for training to encourage alpha consistency with the source
random_bkgd: bool = False
+# Fixed background color to use with transparent source images for evaluation (and training if random_bkgd is False)
+bkgd_color: List[int] = field(default_factory=lambda: [255, 255, 255])

# Opacity regularization
opacity_reg: float = 0.0
@@ -191,7 +195,7 @@ def adjust_steps(self, factor: float):


def create_splats_with_optimizers(
-parser: Parser,
+parser: Optional[Parser],
init_type: str = "sfm",
init_num_pts: int = 100_000,
init_extent: float = 3.0,
@@ -306,22 +310,33 @@ def __init__(
# Tensorboard
self.writer = SummaryWriter(log_dir=f"{cfg.result_dir}/tb")

-# Load data: Training data should contain initial points and colors.
-self.parser = Parser(
-    data_dir=cfg.data_dir,
-    factor=cfg.data_factor,
-    normalize=cfg.normalize_world_space,
-    test_every=cfg.test_every,
-)
-self.trainset = Dataset(
-    self.parser,
-    split="train",
-    patch_size=cfg.patch_size,
-    load_depths=cfg.depth_loss,
-)
-self.valset = Dataset(self.parser, split="val")
-self.scene_scale = self.parser.scene_scale * 1.1 * cfg.global_scale
-print("Scene scale:", self.scene_scale)
+if cfg.data_type == "colmap":
+    from datasets.colmap import Dataset
+
+    # Load data: Training data should contain initial points and colors.
+    self.parser = Parser(
+        data_dir=cfg.data_dir,
+        factor=cfg.data_factor,
+        normalize=cfg.normalize_world_space,
+        test_every=cfg.test_every,
+    )
+    self.trainset = Dataset(
+        self.parser,
+        split="train",
+        patch_size=cfg.patch_size,
+        load_depths=cfg.depth_loss,
+    )
+    self.valset = Dataset(self.parser, split="val")
+    self.scene_scale = self.parser.scene_scale * 1.1 * cfg.global_scale
+    print("Scene scale:", self.scene_scale)
+elif cfg.data_type == "blender":
+    from datasets.blender import Dataset
+
+    self.parser = None
+    self.trainset = Dataset(cfg.data_dir, split="train")
+    # use the `test` split rather than `val` for evaluation, following
+    # the convention of https://nerfbaselines.github.io/
+    self.valset = Dataset(cfg.data_dir, split="test")
+    self.scene_scale = self.trainset.scene_scale * 1.1 * cfg.global_scale

# Model
feature_dim = 32 if cfg.app_opt else None
@@ -448,6 +463,10 @@ def __init__(
mode="training",
)

# fixed background color as a [1, 3] tensor with values in [0, 1]
self.fixed_bkgd = (
    torch.tensor(cfg.bkgd_color, device=self.device)[None, :] / 255.0
)

def rasterize_splats(
self,
camtoworlds: Tensor,
@@ -576,7 +595,7 @@ def train(self):

camtoworlds = camtoworlds_gt = data["camtoworld"].to(device) # [1, 4, 4]
Ks = data["K"].to(device) # [1, 3, 3]
-pixels = data["image"].to(device) / 255.0  # [1, H, W, 3]
+pixels = data["image"].to(device) / 255.0  # [1, H, W, 3 or 4]
num_train_rays_per_step = (
pixels.shape[0] * pixels.shape[1] * pixels.shape[2]
)
@@ -624,8 +643,35 @@
grid_xy = torch.stack([grid_x, grid_y], dim=-1).unsqueeze(0)
colors = slice(self.bil_grids, grid_xy, colors, image_ids)["rgb"]

gt_has_alpha = pixels.shape[-1] == 4

if cfg.random_bkgd:
    bkgd = torch.rand(1, 3, device=device)
else:
    bkgd = self.fixed_bkgd

if gt_has_alpha:
    # Apply the same background to both the GT and the render. This
    # encourages consistency between the GT alpha and the rendered
    # alpha without adding any new losses. It works best with
    # random_bkgd=True: transparent pixels in the source image take a
    # random color each iteration, which the splats can only match by
    # also being transparent.
    gt_alpha = pixels[..., [-1]]
    gt_rgb = pixels[..., :3]

    # per NeRFStudio logic - we assume the source image is not premultiplied
    # see https://github.com/nerfstudio-project/nerfstudio/blob/main/nerfstudio/models/splatfacto.py#L627
    pixels = gt_rgb * gt_alpha + bkgd * (1 - gt_alpha)

    # likewise - we consider the render RGB to be premultiplied
    # this follows existing logic in this script and also
    # https://github.com/nerfstudio-project/nerfstudio/blob/main/nerfstudio/models/splatfacto.py#L583
    colors = colors + bkgd * (1.0 - alphas)
elif cfg.random_bkgd:
    # random_bkgd is True but the source has no alpha channel, so we
    # only apply the random background to the rendered colors; this
    # preserves the prior behavior ("to discourage transparency")
    colors = colors + bkgd * (1.0 - alphas)

self.cfg.strategy.step_pre_backward(
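As a sanity check on the two compositing conventions used above, a standalone sketch (not part of the diff): a fully transparent GT pixel and an empty render both reduce to the background color, so under a random background the loss is minimized exactly when the rendered alpha matches the GT alpha.

# Illustrative check; shapes follow the [1, H, W, C] layout used in train().
import torch

bkgd = torch.rand(1, 3)
# fully transparent GT pixel, straight (non-premultiplied) alpha
gt_rgb, gt_alpha = torch.rand(1, 4, 4, 3), torch.zeros(1, 4, 4, 1)
gt_composited = gt_rgb * gt_alpha + bkgd * (1 - gt_alpha)
# empty render, premultiplied alpha (rgb already carries the alpha factor)
render_rgb, render_alpha = torch.zeros(1, 4, 4, 3), torch.zeros(1, 4, 4, 1)
render_composited = render_rgb + bkgd * (1 - render_alpha)
assert torch.allclose(gt_composited, render_composited)  # both equal bkgd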
@@ -871,7 +917,7 @@ def eval(self, step: int, stage: str = "val"):

torch.cuda.synchronize()
tic = time.time()
-colors, _, _ = self.rasterize_splats(
+colors, alphas, _ = self.rasterize_splats(
camtoworlds=camtoworlds,
Ks=Ks,
width=width,
@@ -884,8 +930,39 @@
torch.cuda.synchronize()
ellipse_time += time.time() - tic

-colors = torch.clamp(colors, 0.0, 1.0)
-canvas_list = [pixels, colors]
+gt_has_alpha = pixels.shape[-1] == 4
+if gt_has_alpha:
+    # Compute metrics on images composited over the fixed background,
+    # but also prepare RGBA copies to write out with the alpha channel kept
+    bkgd = self.fixed_bkgd
+    gt_alpha = pixels[..., [-1]]
+    gt_rgb = pixels[..., :3]
+    # per NeRFStudio logic - we assume the source image is not premultiplied
+    # see https://github.com/nerfstudio-project/nerfstudio/blob/main/nerfstudio/models/splatfacto.py#L627
+    eval_pixels = gt_rgb * gt_alpha + bkgd * (1 - gt_alpha)
+    # likewise - we consider the render RGB to be premultiplied
+    # this follows existing logic in this script and also
+    # https://github.com/nerfstudio-project/nerfstudio/blob/main/nerfstudio/models/splatfacto.py#L583
+    eval_colors = colors + bkgd * (1.0 - alphas)
+    image_pixels = pixels
+    image_colors = torch.cat([colors, alphas], dim=-1)
+    # Clamp to [0.0, 1.0]; image_pixels needs no clamping because it is
+    # read directly from file
+    eval_colors = torch.clamp(eval_colors, 0.0, 1.0)
+    eval_pixels = torch.clamp(eval_pixels, 0.0, 1.0)
+    image_colors = torch.clamp(image_colors, 0.0, 1.0)
+else:
+    # The source image has no alpha channel, so the same values are used
+    # for metrics and for saving images
+    image_pixels = pixels
+    eval_pixels = pixels
+    # Only the rendered colors need clamping, as the pixels are read
+    # directly from file
+    colors = torch.clamp(colors, 0.0, 1.0)
+    image_colors = colors
+    eval_colors = colors
+
+canvas_list = [image_pixels, image_colors]
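Note that when the GT has alpha, the reported PSNR/SSIM/LPIPS depend on the chosen bkgd_color, since both the GT and the render are composited over it before comparison; the white default ([255, 255, 255]) matches the common evaluation convention for the Blender synthetic scenes.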

if world_rank == 0:
# write images
@@ -896,8 +973,8 @@
canvas,
)

-pixels_p = pixels.permute(0, 3, 1, 2)  # [1, 3, H, W]
-colors_p = colors.permute(0, 3, 1, 2)  # [1, 3, H, W]
+pixels_p = eval_pixels.permute(0, 3, 1, 2)  # [1, 3, H, W]
+colors_p = eval_colors.permute(0, 3, 1, 2)  # [1, 3, H, W]
metrics["psnr"].append(self.psnr(colors_p, pixels_p))
metrics["ssim"].append(self.ssim(colors_p, pixels_p))
metrics["lpips"].append(self.lpips(colors_p, pixels_p))
@@ -936,7 +1013,21 @@ def render_traj(self, step: int):
cfg = self.cfg
device = self.device

-camtoworlds_all = self.parser.camtoworlds[5:-5]
+if self.parser is not None:
+    camtoworlds_all = self.parser.camtoworlds[5:-5]
+    K = (
+        torch.from_numpy(list(self.parser.Ks_dict.values())[0])
+        .float()
+        .to(device)
+    )
+    width, height = list(self.parser.imsize_dict.values())[0]
+else:
+    camtoworlds_all = np.stack(
+        [elem.cpu().numpy() for elem in self.trainset.cam_to_worlds], axis=0
+    )
+    K = self.trainset.intrinsics.to(device)
+    width = self.trainset.image_width
+    height = self.trainset.image_height
if cfg.render_traj_path == "interp":
camtoworlds_all = generate_interpolated_path(
camtoworlds_all, 1
@@ -968,8 +1059,6 @@ def render_traj(self, step: int):
) # [N, 4, 4]

camtoworlds_all = torch.from_numpy(camtoworlds_all).float().to(device)
-K = torch.from_numpy(list(self.parser.Ks_dict.values())[0]).float().to(device)
-width, height = list(self.parser.imsize_dict.values())[0]

# save to video
video_dir = f"{cfg.result_dir}/videos"
@@ -1045,6 +1134,23 @@ def main(local_rank: int, world_rank, world_size: int, cfg: Config):
cfg.disable_viewer = True
if world_rank == 0:
print("Viewer is disabled in distributed training.")
if cfg.data_type == "blender":
print("Setting / overriding config settings for blender data")
print("Forcing init type to random")
cfg.init_type = "random"
print("Setting near and far to Blender data recommended settings (2 and 6, respectively)")
# Taken from nerfbaselines setting
cfg.near_plane = 2.0
cfg.far_plane = 6.0
print("Setting init_extent to 0.5")
# Taken from nerfbaselines setting
cfg.init_extent = 0.5

if cfg.render_traj_path == "spiral":
print(
"Spiral render trajectory is not supported for blender data, setting to interp instead"
)
cfg.render_traj_path = "interp"

runner = Runner(local_rank, world_rank, world_size, cfg)

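To round out the picture, a hedged sketch of how the new options compose programmatically; the field names are taken from the Config diff above, the scene path is illustrative, and all other fields are assumed to keep their defaults:

# Hypothetical configuration for a Blender scene with alpha-aware training.
cfg = Config(
    data_type="blender",
    data_dir="data/nerf_synthetic/lego",
    random_bkgd=True,  # transparent GT pixels get a random color each step
    result_dir="results/lego",
)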