diff --git a/examples/datasets/blender.py b/examples/datasets/blender.py
new file mode 100644
index 000000000..7654c678b
--- /dev/null
+++ b/examples/datasets/blender.py
@@ -0,0 +1,72 @@
+from dataclasses import dataclass
+import json
+from pathlib import Path
+from typing import Any, Dict, Literal
+
+import imageio.v2 as imageio
+import numpy as np
+import torch
+
+
+@dataclass
+class Dataset:
+    """A simple dataset class for synthetic blender data."""
+
+    data_dir: str
+    """The path to the blender scene, consisting of renders and transforms.json"""
+    split: Literal["train", "test", "val"] = "train"
+    """Which split to use."""
+
+    def __post_init__(self):
+        self.data_dir = Path(self.data_dir)
+        transforms_path = self.data_dir / f"transforms_{self.split}.json"
+        with transforms_path.open("r") as transforms_handle:
+            transforms = json.load(transforms_handle)
+        image_ids = []
+        cam_to_worlds = []
+        images = []
+        for frame in transforms["frames"]:
+            image_id = frame["file_path"].replace("./", "")
+            image_ids.append(image_id)
+            file_path = self.data_dir / f"{image_id}.png"
+            images.append(imageio.imread(file_path))
+
+            c2w = torch.tensor(frame["transform_matrix"])
+            # Convert from OpenGL to OpenCV coordinate system
+            c2w[0:3, 1:3] *= -1
+            cam_to_worlds.append(c2w)
+
+        self.image_ids = image_ids
+        self.cam_to_worlds = cam_to_worlds
+        self.images = images
+
+        # all renders have the same intrinsics
+        # see also
+        # https://github.com/nerfstudio-project/nerfstudio/blob/main/nerfstudio/data/dataparsers/blender_dataparser.py
+        image_height, image_width = self.images[0].shape[:2]
+        cx = image_width / 2.0
+        cy = image_height / 2.0
+        fl = 0.5 * image_width / np.tan(0.5 * transforms["camera_angle_x"])
+        self.intrinsics = torch.tensor(
+            [[fl, 0, cx], [0, fl, cy], [0, 0, 1]], dtype=torch.float32
+        )
+        self.image_height = image_height
+        self.image_width = image_width
+
+        # compute scene scale (as is done in the colmap parser)
+        camera_locations = np.stack(self.cam_to_worlds, axis=0)[:, :3, 3]
+        scene_center = np.mean(camera_locations, axis=0)
+        dists = np.linalg.norm(camera_locations - scene_center, axis=1)
+        self.scene_scale = np.max(dists)
+
+    def __len__(self):
+        return len(self.image_ids)
+
+    def __getitem__(self, item: int) -> Dict[str, Any]:
+        data = dict(
+            K=self.intrinsics,
+            camtoworld=self.cam_to_worlds[item],
+            image=torch.from_numpy(self.images[item]).float(),
+            image_id=item,
+        )
+        return data
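Note: a minimal, hedged sketch of how the new Dataset above can be exercised on its own; the scene path is only an example, and any Blender-synthetic scene with transforms_{split}.json plus PNG renders should work. The intrinsics follow the usual pinhole relation fl = 0.5 * W / tan(0.5 * camera_angle_x), matching the matrix built in __post_init__ above.

    from datasets.blender import Dataset

    # hypothetical scene path, adjust to the local data layout
    trainset = Dataset("data/nerf_synthetic/lego", split="train")
    print(len(trainset))               # number of frames listed in transforms_train.json
    sample = trainset[0]
    print(sample["K"])                 # shared 3x3 intrinsics built from camera_angle_x
    print(sample["camtoworld"].shape)  # [4, 4] camera-to-world in OpenCV convention
    print(sample["image"].shape)       # [H, W, 4] float tensor in [0, 255] for RGBA renders
    print(trainset.scene_scale)        # max camera distance from the mean camera location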
diff --git a/examples/simple_trainer.py b/examples/simple_trainer.py
index ca9271e81..01fec1dc3
--- a/examples/simple_trainer.py
+++ b/examples/simple_trainer.py
@@ -15,7 +15,7 @@
 import tyro
 import viser
 import yaml
-from datasets.colmap import Dataset, Parser
+from datasets.colmap import Parser
 from datasets.traj import (
     generate_interpolated_path,
     generate_ellipse_path_z,
@@ -55,8 +55,10 @@ class Config:
     # Render trajectory path
     render_traj_path: str = "interp"

-    # Path to the Mip-NeRF 360 dataset
+    # Path to the dataset
     data_dir: str = "data/360_v2/garden"
+    # Type of the dataset (e.g. COLMAP or Blender)
+    data_type: Literal["colmap", "blender"] = "colmap"
     # Downsample factor for the dataset
     data_factor: int = 4
     # Directory to save results
@@ -92,7 +94,7 @@ class Config:
     ply_steps: List[int] = field(default_factory=lambda: [7_000, 30_000])

     # Initialization strategy
-    init_type: str = "sfm"
+    init_type: Literal["sfm", "random"] = "sfm"
     # Initial number of GSs. Ignored if using sfm
     init_num_pts: int = 100_000
     # Initial extent of GSs as a multiple of the camera extent. Ignored if using sfm
@@ -126,8 +128,10 @@ class Config:
     # Anti-aliasing in rasterization. Might slightly hurt quantitative metrics.
     antialiased: bool = False

-    # Use random background for training to discourage transparency
+    # Use random background for training to encourage alpha consistency w/ source
     random_bkgd: bool = False
+    # Fixed background color to use w/ transparent source images for evaluation (and training if random_bkgd is False)
+    bkgd_color: List[int] = field(default_factory=lambda: [255, 255, 255])

     # Opacity regularization
     opacity_reg: float = 0.0
@@ -191,7 +195,7 @@ def adjust_steps(self, factor: float):


 def create_splats_with_optimizers(
-    parser: Parser,
+    parser: Optional[Parser],
     init_type: str = "sfm",
     init_num_pts: int = 100_000,
     init_extent: float = 3.0,
@@ -306,22 +310,33 @@ def __init__(
         # Tensorboard
         self.writer = SummaryWriter(log_dir=f"{cfg.result_dir}/tb")

-        # Load data: Training data should contain initial points and colors.
-        self.parser = Parser(
-            data_dir=cfg.data_dir,
-            factor=cfg.data_factor,
-            normalize=cfg.normalize_world_space,
-            test_every=cfg.test_every,
-        )
-        self.trainset = Dataset(
-            self.parser,
-            split="train",
-            patch_size=cfg.patch_size,
-            load_depths=cfg.depth_loss,
-        )
-        self.valset = Dataset(self.parser, split="val")
-        self.scene_scale = self.parser.scene_scale * 1.1 * cfg.global_scale
-        print("Scene scale:", self.scene_scale)
+        if cfg.data_type == "colmap":
+            from datasets.colmap import Dataset
+
+            # Load data: Training data should contain initial points and colors.
+            self.parser = Parser(
+                data_dir=cfg.data_dir,
+                factor=cfg.data_factor,
+                normalize=cfg.normalize_world_space,
+                test_every=cfg.test_every,
+            )
+            self.trainset = Dataset(
+                self.parser,
+                split="train",
+                patch_size=cfg.patch_size,
+                load_depths=cfg.depth_loss,
+            )
+            self.valset = Dataset(self.parser, split="val")
+            self.scene_scale = self.parser.scene_scale * 1.1 * cfg.global_scale
+            print("Scene scale:", self.scene_scale)
+        elif cfg.data_type == "blender":
+            from datasets.blender import Dataset
+
+            self.parser = None
+            self.trainset = Dataset(cfg.data_dir, split="train")
+            # using `test` over `val` for evaluation - following same convention as in https://nerfbaselines.github.io/
+            self.valset = Dataset(cfg.data_dir, split="test")
+            self.scene_scale = self.trainset.scene_scale * 1.1 * cfg.global_scale

         # Model
         feature_dim = 32 if cfg.app_opt else None
@@ -448,6 +463,10 @@ def __init__(
                 mode="training",
             )

+        self.fixed_bkgd = (
+            torch.tensor(cfg.bkgd_color, device=self.device)[None, :] / 255.0
+        )
+
     def rasterize_splats(
         self,
         camtoworlds: Tensor,
@@ -576,7 +595,7 @@ def train(self):

             camtoworlds = camtoworlds_gt = data["camtoworld"].to(device)  # [1, 4, 4]
             Ks = data["K"].to(device)  # [1, 3, 3]
-            pixels = data["image"].to(device) / 255.0  # [1, H, W, 3]
+            pixels = data["image"].to(device) / 255.0  # [1, H, W, 3 or 4]
             num_train_rays_per_step = (
                 pixels.shape[0] * pixels.shape[1] * pixels.shape[2]
             )
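Note: the training hunk below composites both the ground truth and the render onto the same background before the losses are computed. A minimal sketch of that convention, assuming a straight-alpha (non-premultiplied) ground truth and a premultiplied render; the tensor names are hypothetical stand-ins for pixels, colors and alphas:

    import torch

    gt_rgba = torch.rand(1, 8, 8, 4)       # stand-in for `pixels` when the source has alpha
    render_rgb = torch.rand(1, 8, 8, 3)    # stand-in for `colors` (already color * alpha)
    render_alpha = torch.rand(1, 8, 8, 1)  # stand-in for `alphas`
    bkgd = torch.rand(1, 3)                # random (or fixed) background, broadcast over H and W

    # straight alpha: blend the GT RGB with the background explicitly
    gt_composited = gt_rgba[..., :3] * gt_rgba[..., 3:] + bkgd * (1 - gt_rgba[..., 3:])
    # premultiplied alpha: the render already carries color * alpha, so only the background term is added
    render_composited = render_rgb + bkgd * (1.0 - render_alpha)
    # the photometric losses then compare gt_composited against render_composited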
@@ -624,8 +643,35 @@
                 grid_xy = torch.stack([grid_x, grid_y], dim=-1).unsqueeze(0)
                 colors = slice(self.bil_grids, grid_xy, colors, image_ids)["rgb"]

+            gt_has_alpha = pixels.shape[-1] == 4
+
             if cfg.random_bkgd:
                 bkgd = torch.rand(1, 3, device=device)
+            else:
+                bkgd = self.fixed_bkgd
+
+            if gt_has_alpha:
+                # we will apply the same background to both gt and render
+                # this encourages consistency between the gt image alpha
+                # and the render alpha (without adding any new losses)
+                # this works best with random_bkgd = True, as the transparent pixels
+                # in the source image will be a random color each iteration that the
+                # splat can only match by also being transparent
+                gt_alpha = pixels[..., [-1]]
+                gt_rgb = pixels[..., :3]
+
+                # per NeRFStudio logic - we assume the source image is not premultiplied
+                # see https://github.com/nerfstudio-project/nerfstudio/blob/main/nerfstudio/models/splatfacto.py#L627
+                pixels = gt_rgb * gt_alpha + bkgd * (1 - gt_alpha)
+
+                # likewise - we consider the render RGB to be premultiplied
+                # this follows existing logic in this script and also
+                # https://github.com/nerfstudio-project/nerfstudio/blob/main/nerfstudio/models/splatfacto.py#L583
+                colors = colors + bkgd * (1.0 - alphas)
+            elif cfg.random_bkgd:
+                # random_bkgd is True but the source doesn't have alpha
+                # so we only apply the random background to the render colors
+                # this preserves the prior behavior ("to discourage transparency")
                 colors = colors + bkgd * (1.0 - alphas)

             self.cfg.strategy.step_pre_backward(
@@ -871,7 +917,7 @@ def eval(self, step: int, stage: str = "val"):

             torch.cuda.synchronize()
             tic = time.time()
-            colors, _, _ = self.rasterize_splats(
+            colors, alphas, _ = self.rasterize_splats(
                 camtoworlds=camtoworlds,
                 Ks=Ks,
                 width=width,
@@ -884,8 +930,39 @@
             torch.cuda.synchronize()
             ellipse_time += time.time() - tic

-            colors = torch.clamp(colors, 0.0, 1.0)
-            canvas_list = [pixels, colors]
+            gt_has_alpha = pixels.shape[-1] == 4
+            if gt_has_alpha:
+                # We want to assess metrics with images composed on the fixed background
+                # but also prepare copies to write out with the alpha channel appended (i.e. as RGBA)
+                bkgd = self.fixed_bkgd
+                gt_alpha = pixels[..., [-1]]
+                gt_rgb = pixels[..., :3]
+                # per NeRFStudio logic - we assume the source image is not premultiplied
+                # see https://github.com/nerfstudio-project/nerfstudio/blob/main/nerfstudio/models/splatfacto.py#L627
+                eval_pixels = gt_rgb * gt_alpha + bkgd * (1 - gt_alpha)
+                # likewise - we consider the render RGB to be premultiplied
+                # this follows existing logic in this script and also
+                # https://github.com/nerfstudio-project/nerfstudio/blob/main/nerfstudio/models/splatfacto.py#L583
+                eval_colors = colors + bkgd * (1.0 - alphas)
+                image_pixels = pixels
+                image_colors = torch.cat([colors, alphas], dim=-1)
+                # Clamp to [0.0, 1.0]
+                # Image pixels don't need it because they are read directly from file
+                eval_colors = torch.clamp(eval_colors, 0.0, 1.0)
+                eval_pixels = torch.clamp(eval_pixels, 0.0, 1.0)
+                image_colors = torch.clamp(image_colors, 0.0, 1.0)
+            else:
+                # Because the source image does not have an alpha channel, we
+                # will simply use the same values for metrics and saving images
+                image_pixels = pixels
+                eval_pixels = pixels
+                # Clamp to [0.0, 1.0]
+                # Only colors need clamping, as the pixels are read directly from file
+                colors = torch.clamp(colors, 0.0, 1.0)
+                image_colors = colors
+                eval_colors = colors
+
+            canvas_list = [image_pixels, image_colors]

             if world_rank == 0:
                 # write images
@@ -896,8 +973,8 @@
                     canvas,
                 )

-                pixels_p = pixels.permute(0, 3, 1, 2)  # [1, 3, H, W]
-                colors_p = colors.permute(0, 3, 1, 2)  # [1, 3, H, W]
+                pixels_p = eval_pixels.permute(0, 3, 1, 2)  # [1, 3, H, W]
+                colors_p = eval_colors.permute(0, 3, 1, 2)  # [1, 3, H, W]
                 metrics["psnr"].append(self.psnr(colors_p, pixels_p))
                 metrics["ssim"].append(self.ssim(colors_p, pixels_p))
                 metrics["lpips"].append(self.lpips(colors_p, pixels_p))
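Note: a small sketch of the split introduced by the evaluation hunks above, with hypothetical tensors: metrics are computed on copies composited over the fixed background, while the canvases written to disk keep the raw RGBA render.

    import torch

    colors = torch.rand(1, 8, 8, 3)          # stand-in for the premultiplied render RGB
    alphas = torch.rand(1, 8, 8, 1)          # stand-in for the render alpha
    bkgd = torch.tensor([[1.0, 1.0, 1.0]])   # white, i.e. bkgd_color=[255, 255, 255] / 255

    # copy used for PSNR/SSIM/LPIPS, composited the same way as the ground truth
    eval_colors = torch.clamp(colors + bkgd * (1.0 - alphas), 0.0, 1.0)
    # copy written to disk, alpha appended so the saved image keeps transparency
    image_colors = torch.clamp(torch.cat([colors, alphas], dim=-1), 0.0, 1.0)  # [1, H, W, 4]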
@@ -936,7 +1013,21 @@ def render_traj(self, step: int):
         cfg = self.cfg
         device = self.device

-        camtoworlds_all = self.parser.camtoworlds[5:-5]
+        if self.parser is not None:
+            camtoworlds_all = self.parser.camtoworlds[5:-5]
+            K = (
+                torch.from_numpy(list(self.parser.Ks_dict.values())[0])
+                .float()
+                .to(device)
+            )
+            width, height = list(self.parser.imsize_dict.values())[0]
+        else:
+            camtoworlds_all = np.stack(
+                [elem.cpu().numpy() for elem in self.trainset.cam_to_worlds], axis=0
+            )
+            K = self.trainset.intrinsics.to(device)
+            width = self.trainset.image_width
+            height = self.trainset.image_height
         if cfg.render_traj_path == "interp":
             camtoworlds_all = generate_interpolated_path(
                 camtoworlds_all, 1
@@ -968,8 +1059,6 @@
         )  # [N, 4, 4]
         camtoworlds_all = torch.from_numpy(camtoworlds_all).float().to(device)

-        K = torch.from_numpy(list(self.parser.Ks_dict.values())[0]).float().to(device)
-        width, height = list(self.parser.imsize_dict.values())[0]

         # save to video
         video_dir = f"{cfg.result_dir}/videos"
@@ -1045,6 +1134,23 @@ def main(local_rank: int, world_rank, world_size: int, cfg: Config):
         cfg.disable_viewer = True
         if world_rank == 0:
             print("Viewer is disabled in distributed training.")
+    if cfg.data_type == "blender":
+        print("Overriding config settings for blender data")
+        print("Forcing init type to random")
+        cfg.init_type = "random"
+        print("Setting near and far planes to the recommended values for Blender data (2 and 6, respectively)")
+        # Taken from nerfbaselines setting
+        cfg.near_plane = 2.0
+        cfg.far_plane = 6.0
+        print("Setting init_extent to 0.5")
+        # Taken from nerfbaselines setting
+        cfg.init_extent = 0.5
+
+        if cfg.render_traj_path == "spiral":
+            print(
+                "Spiral render trajectory is not supported for blender data, setting to interp instead"
+            )
+            cfg.render_traj_path = "interp"

     runner = Runner(local_rank, world_rank, world_size, cfg)
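Note: a hedged sketch of how the new options fit together; constructing Config directly is only for illustration (the script normally builds it from the command line via tyro), and the scene path is an example.

    from simple_trainer import Config

    # hypothetical direct construction; the field names come from the diff above
    cfg = Config(
        data_type="blender",                  # selects datasets.blender.Dataset instead of the COLMAP parser
        data_dir="data/nerf_synthetic/lego",  # example path, adjust to the local layout
        random_bkgd=True,                     # recommended with RGBA sources, per the train() comments
        bkgd_color=[255, 255, 255],           # fixed background for eval (and train if random_bkgd is False)
    )
    # main() then forces init_type="random", near_plane=2.0, far_plane=6.0, init_extent=0.5,
    # and swaps a "spiral" render trajectory for "interp" when data_type == "blender".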