From 0308931ea1934c2559c184083b8862d629acf76c Mon Sep 17 00:00:00 2001 From: inuex35 Date: Sat, 5 Oct 2024 09:52:53 +0900 Subject: [PATCH 01/14] Add spherical render --- examples/datasets/opensfm.py | 304 +++++ examples/simple_trainer.py | 2 +- examples/simple_trainer_opensfm.py | 1094 +++++++++++++++++ gsplat/cuda/_torch_impl.py | 54 +- gsplat/cuda/_wrapper.py | 10 +- gsplat/cuda/csrc/bindings.h | 1 + gsplat/cuda/csrc/ext.cpp | 1 + .../cuda/csrc/fully_fused_projection_bwd.cu | 16 + .../cuda/csrc/fully_fused_projection_fwd.cu | 14 + .../csrc/fully_fused_projection_packed_bwd.cu | 16 + .../csrc/fully_fused_projection_packed_fwd.cu | 14 + gsplat/cuda/csrc/proj_bwd.cu | 16 + gsplat/cuda/csrc/proj_fwd.cu | 3 + gsplat/cuda/csrc/utils.cuh | 138 +++ gsplat/rendering.py | 2 +- 15 files changed, 1677 insertions(+), 8 deletions(-) create mode 100644 examples/datasets/opensfm.py create mode 100644 examples/simple_trainer_opensfm.py diff --git a/examples/datasets/opensfm.py b/examples/datasets/opensfm.py new file mode 100644 index 000000000..b67e5efea --- /dev/null +++ b/examples/datasets/opensfm.py @@ -0,0 +1,304 @@ +import os +import numpy as np +import collections +import math +from pyproj import Proj +from typing import Dict, List, Any, Optional +import torch +from .normalize import ( + align_principle_axes, + similarity_from_cameras, + transform_cameras, + transform_points, +) + +CameraModel = collections.namedtuple( + "CameraModel", ["model_id", "model_name", "num_params"]) +Camera = collections.namedtuple( + "Camera", ["id", "model", "width", "height", "params", "panorama"]) +BaseImage = collections.namedtuple( + "Image", ["id", "qvec", "tvec", "camera_id", "name", "xys", "point3D_ids", "diff_ref"]) +Point3D = collections.namedtuple( + "Point3D", ["id", "xyz", "rgb", "error", "image_ids", "point2D_idxs"]) +CAMERA_MODELS = { + CameraModel(model_id=0, model_name="SIMPLE_PINHOLE", num_params=3), + CameraModel(model_id=1, model_name="PINHOLE", num_params=4), + CameraModel(model_id=2, model_name="SIMPLE_RADIAL", num_params=4), + CameraModel(model_id=3, model_name="RADIAL", num_params=5), + CameraModel(model_id=4, model_name="OPENCV", num_params=8), + CameraModel(model_id=5, model_name="OPENCV_FISHEYE", num_params=8), + CameraModel(model_id=6, model_name="FULL_OPENCV", num_params=12), + CameraModel(model_id=7, model_name="FOV", num_params=5), + CameraModel(model_id=8, model_name="SIMPLE_RADIAL_FISHEYE", num_params=4), + CameraModel(model_id=9, model_name="RADIAL_FISHEYE", num_params=5), + CameraModel(model_id=10, model_name="THIN_PRISM_FISHEYE", num_params=12) +} +CAMERA_MODEL_IDS = dict([(camera_model.model_id, camera_model) + for camera_model in CAMERA_MODELS]) +CAMERA_MODEL_NAMES = dict([(camera_model.model_name, camera_model) + for camera_model in CAMERA_MODELS]) + +def qvec2rotmat(qvec): + return np.array([ + [1 - 2 * qvec[2]**2 - 2 * qvec[3]**2, + 2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3], + 2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2]], + [2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3], + 1 - 2 * qvec[1]**2 - 2 * qvec[3]**2, + 2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1]], + [2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2], + 2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1], + 1 - 2 * qvec[1]**2 - 2 * qvec[2]**2]]) + +def rotmat2qvec(R): + Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = R.flat + K = np.array([ + [Rxx - Ryy - Rzz, 0, 0, 0], + [Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0], + [Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0], + [Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz]]) / 3.0 + eigvals, 
eigvecs = np.linalg.eigh(K) + qvec = eigvecs[[3, 0, 1, 2], np.argmax(eigvals)] + if qvec[0] < 0: + qvec *= -1 + return qvec + +def angle_axis_to_quaternion(angle_axis: np.ndarray): + angle = np.linalg.norm(angle_axis) + + x = angle_axis[0] / angle + y = angle_axis[1] / angle + z = angle_axis[2] / angle + + qw = math.cos(angle / 2.0) + qx = x * math.sqrt(1 - qw * qw) + qy = y * math.sqrt(1 - qw * qw) + qz = z * math.sqrt(1 - qw * qw) + + return np.array([qw, qx, qy, qz]) + +def angle_axis_and_angle_to_quaternion(angle, axis): + half_angle = angle / 2.0 + sin_half_angle = math.sin(half_angle) + return np.array([ + math.cos(half_angle), + axis[0] * sin_half_angle, + axis[1] * sin_half_angle, + axis[2] * sin_half_angle + ]) + +def quaternion_multiply(q1, q2): + w1, x1, y1, z1 = q1[0], q1[1], q1[2], q1[3] + w2, x2, y2, z2 = q2[0], q2[1], q2[2], q2[3] + + w = w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2 + x = w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2 + y = w1 * y2 + y1 * w2 + z1 * x2 - x1 * z2 + z = w1 * z2 + z1 * w2 + x1 * y2 - y1 * x2 + + return np.array([w, x, y, z]) + +class Image(BaseImage): + def qvec2rotmat(self): + return qvec2rotmat(self.qvec) + +def _get_rel_paths(path_dir: str) -> List[str]: + """Recursively get relative paths of files in a directory.""" + paths = [] + for dp, dn, fn in os.walk(path_dir): + for f in fn: + paths.append(os.path.relpath(os.path.join(dp, f), path_dir)) + return paths + +class Parser: + """Parser for Opensfm data formatted similarly to the COLMAP parser.""" + + def __init__( + self, + reconstructions: List[Dict], + factor: int = 1, + normalize: bool = False, + ): + self.factor = factor + self.normalize = normalize + + # Extract data from reconstructions. + self._parse_reconstructions(reconstructions) + + def _parse_reconstructions(self, reconstructions: List[Dict]): + """Parse reconstructions data to extract camera information and extrinsics.""" + self.cameras, self.images = read_opensfm(reconstructions) + + # Extract extrinsic matrices in world-to-camera format. + w2c_mats = [] + camera_ids = [] + Ks_dict = dict() + params_dict = dict() + imsize_dict = dict() + bottom = np.array([0, 0, 0, 1]).reshape(1, 4) + + for img in self.images.values(): + # Extract rotation and translation vectors. + rot = img.qvec2rotmat() + trans = img.tvec.reshape(3, 1) + w2c = np.concatenate([np.concatenate([rot, trans], 1), bottom], axis=0) + w2c_mats.append(w2c) + + # Camera intrinsics + cam = self.cameras[img.camera_id] + fx, fy, cx, cy = cam.params[0], cam.params[0], cam.params[1], cam.params[2] + K = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]]) + K[:2, :] /= self.factor + Ks_dict[img.camera_id] = K + + # Distortion parameters + params_dict[img.camera_id] = cam.params[3:] + imsize_dict[img.camera_id] = (cam.width // self.factor, cam.height // self.factor) + camera_ids.append(img.camera_id) + + w2c_mats = np.stack(w2c_mats, axis=0) + + # Convert extrinsics to camera-to-world. + camtoworlds = np.linalg.inv(w2c_mats) + + # Normalize the world space if needed. + if self.normalize: + T1 = similarity_from_cameras(camtoworlds) + camtoworlds = transform_cameras(T1, camtoworlds) + + points = np.array([img.diff_ref for img in self.images.values()]) + T2 = align_principle_axes(points) + camtoworlds = transform_cameras(T2, camtoworlds) + transform = T2 @ T1 + else: + transform = np.eye(4) + + # Set instance variables. 
+ self.camtoworlds = camtoworlds # np.ndarray, (num_images, 4, 4) + self.camera_ids = camera_ids # List[int], (num_images,) + self.Ks_dict = Ks_dict # Dict of camera_id -> K + self.params_dict = params_dict # Dict of camera_id -> params + self.imsize_dict = imsize_dict # Dict of camera_id -> (width, height) + self.transform = transform # np.ndarray, (4, 4) + + +class Dataset: + """A simple dataset class for OpensfmLoaderParser.""" + + def __init__( + self, + parser: Parser, + split: str = "train", + patch_size: Optional[int] = None, + load_depths: bool = False, + ): + self.parser = parser + self.split = split + self.patch_size = patch_size + self.load_depths = load_depths + indices = np.arange(len(self.parser.images)) + if split == "train": + self.indices = indices[indices % 8 != 0] + else: + self.indices = indices[indices % 8 == 0] + + def __len__(self): + return len(self.indices) + + def __getitem__(self, item: int) -> Dict[str, Any]: + index = self.indices[item] + img = self.parser.images[index] + camera_id = img.camera_id + K = self.parser.Ks_dict[camera_id].copy() # undistorted K + camtoworld = self.parser.camtoworlds[index] + + # Load image (dummy implementation, replace with actual image loading logic if available) + width, height = self.parser.imsize_dict[camera_id] + image = np.zeros((height, width, 3), dtype=np.uint8) # Placeholder for actual image loading + + data = { + "K": torch.from_numpy(K).float(), + "camtoworld": torch.from_numpy(camtoworld).float(), + "image": torch.from_numpy(image).float(), + "image_id": item, # the index of the image in the dataset + } + + if self.load_depths: + # Load depth data (dummy implementation, replace with actual depth loading logic if available) + depths = np.zeros((height, width), dtype=np.float32) # Placeholder for actual depth loading + data["depths"] = torch.from_numpy(depths).float() + + return data + + +def read_opensfm(reconstructions): + """Extracts camera and image information from OpenSfM reconstructions.""" + images = {} + i = 0 + reference_lat_0 = reconstructions[0]["reference_lla"]["latitude"] + reference_lon_0 = reconstructions[0]["reference_lla"]["longitude"] + reference_alt_0 = reconstructions[0]["reference_lla"]["altitude"] + e2u_zone = int(divmod(reference_lon_0, 6)[0]) + 31 + e2u_conv = Proj(proj='utm', zone=e2u_zone, ellps='WGS84') + reference_x_0, reference_y_0 = e2u_conv(reference_lon_0, reference_lat_0) + if reference_lat_0 < 0: + reference_y_0 += 10000000 + + cameras = {} + camera_names = {} + cam_id = 1 + + for reconstruction in reconstructions: + # Parse cameras. 
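+        # Each entry of reconstruction["cameras"] is keyed by the camera name and holds
+        # "projection_type", "width" and "height" (plus "focal", "k1", "k2" for perspective
+        # cameras). Spherical/equirectangular cameras get the reserved id 0, a dummy focal
+        # derived from the image width, and panorama=True; perspective cameras get a
+        # SIMPLE_PINHOLE-style params layout [f, cx, cy, k1, k2].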
+ for i, camera in enumerate(reconstruction["cameras"]): + camera_name = camera + camera_info = reconstruction["cameras"][camera] + if camera_info['projection_type'] in ['spherical', 'equirectangular']: + camera_id = 0 + model = "SPHERICAL" + width = reconstruction["cameras"][camera]["width"] + height = reconstruction["cameras"][camera]["height"] + f = width / 4 / 2 + params = np.array([f, width, height]) + cameras[camera_id] = Camera(id=camera_id, model=model, width=width, height=height, params=params, panorama=True) + camera_names[camera_name] = camera_id + elif reconstruction["cameras"][camera]['projection_type'] == "perspective": + model = "SIMPLE_PINHOLE" + width = reconstruction["cameras"][camera]["width"] + height = reconstruction["cameras"][camera]["height"] + f = reconstruction["cameras"][camera]["focal"] * width + k1 = reconstruction["cameras"][camera]["k1"] + k2 = reconstruction["cameras"][camera]["k2"] + params = np.array([f, width / 2, width / 2, k1, k2]) + camera_id = cam_id + cameras[camera_id] = Camera(id=camera_id, model=model, width=width, height=height, params=params, panorama=False) + camera_names[camera_name] = camera_id + cam_id += 1 + + # Parse images. + reference_lat = reconstruction["reference_lla"]["latitude"] + reference_lon = reconstruction["reference_lla"]["longitude"] + reference_alt = reconstruction["reference_lla"]["altitude"] + reference_x, reference_y = e2u_conv(reference_lon, reference_lat) + if reference_lat < 0: + reference_y += 10000000 + + for shot in reconstruction["shots"]: + translation = reconstruction["shots"][shot]["translation"] + rotation = reconstruction["shots"][shot]["rotation"] + qvec = angle_axis_to_quaternion(rotation) + diff_ref_x = reference_x - reference_x_0 + diff_ref_y = reference_y - reference_y_0 + diff_ref_alt = reference_alt - reference_alt_0 + tvec = np.array([translation[0], translation[1], translation[2]]) + diff_ref = np.array([diff_ref_x, diff_ref_y, diff_ref_alt]) + camera_name = reconstruction["shots"][shot]["camera"] + camera_id = camera_names.get(camera_name, 0) + image_id = i + image_name = shot + xys = np.array([0, 0]) + point3D_ids = np.array([0, 0]) + images[image_id] = Image(id=image_id, qvec=qvec, tvec=tvec, camera_id=camera_id, name=image_name, xys=xys, point3D_ids=point3D_ids, diff_ref=diff_ref) + i += 1 + + return cameras, images \ No newline at end of file diff --git a/examples/simple_trainer.py b/examples/simple_trainer.py index 93e70002f..6700deea6 100644 --- a/examples/simple_trainer.py +++ b/examples/simple_trainer.py @@ -69,7 +69,7 @@ class Config: # Normalize the world space normalize_world_space: bool = True # Camera model - camera_model: Literal["pinhole", "ortho", "fisheye"] = "pinhole" + camera_model: Literal["pinhole", "ortho", "fisheye", "spherical"] = "pinhole" # Port for the viewer server port: int = 8080 diff --git a/examples/simple_trainer_opensfm.py b/examples/simple_trainer_opensfm.py new file mode 100644 index 000000000..3d2e20368 --- /dev/null +++ b/examples/simple_trainer_opensfm.py @@ -0,0 +1,1094 @@ +import json +import math +import os +import time +from dataclasses import dataclass, field +from collections import defaultdict +from typing import Dict, List, Optional, Tuple, Union + +import imageio +import nerfview +import numpy as np +import torch +import torch.nn.functional as F +import tqdm +import tyro +import viser +import yaml +from datasets.opensfm import Dataset, Parser +from datasets.traj import ( + generate_interpolated_path, + generate_ellipse_path_z, + generate_spiral_path, +) 
+from torch import Tensor +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter +from torchmetrics.image import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure +from fused_ssim import fused_ssim +from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity +from typing_extensions import Literal, assert_never +from utils import AppearanceOptModule, CameraOptModule, knn, rgb_to_sh, set_random_seed +from lib_bilagrid import ( + BilateralGrid, + slice, + color_correct, + total_variation_loss, +) + +from gsplat.compression import PngCompression +from gsplat.distributed import cli +from gsplat.rendering import rasterization +from gsplat.strategy import DefaultStrategy, MCMCStrategy +from gsplat.optimizers import SelectiveAdam + + +@dataclass +class Config: + # Disable viewer + disable_viewer: bool = False + # Path to the .pt files. If provide, it will skip training and run evaluation only. + ckpt: Optional[List[str]] = None + # Name of compression strategy to use + compression: Optional[Literal["png"]] = None + # Render trajectory path + render_traj_path: str = "interp" + + # Path to the Mip-NeRF 360 dataset + data_dir: str = "sample/" + # Downsample factor for the dataset + data_factor: int = 4 + # Directory to save results + result_dir: str = "results/" + # Every N images there is a test image + test_every: int = 8 + # Random crop size for training (experimental) + patch_size: Optional[int] = None + # A global scaler that applies to the scene size related parameters + global_scale: float = 1.0 + # Normalize the world space + normalize_world_space: bool = True + # Camera model + camera_model: Literal["pinhole", "ortho", "fisheye", "spherical"] = "spherical" + + # Port for the viewer server + port: int = 8080 + + # Batch size for training. Learning rates are scaled automatically + batch_size: int = 1 + # A global factor to scale the number of training steps + steps_scaler: float = 1.0 + + # Number of training steps + max_steps: int = 30_000 + # Steps to evaluate the model + eval_steps: List[int] = field(default_factory=lambda: [7_000, 30_000]) + # Steps to save the model + save_steps: List[int] = field(default_factory=lambda: [7_000, 30_000]) + + # Initialization strategy + init_type: str = "sfm" + # Initial number of GSs. Ignored if using sfm + init_num_pts: int = 100_000 + # Initial extent of GSs as a multiple of the camera extent. Ignored if using sfm + init_extent: float = 3.0 + # Degree of spherical harmonics + sh_degree: int = 3 + # Turn on another SH degree every this steps + sh_degree_interval: int = 1000 + # Initial opacity of GS + init_opa: float = 0.1 + # Initial scale of GS + init_scale: float = 1.0 + # Weight for SSIM loss + ssim_lambda: float = 0.2 + + # Near plane clipping distance + near_plane: float = 0.01 + # Far plane clipping distance + far_plane: float = 1e10 + + # Strategy for GS densification + strategy: Union[DefaultStrategy, MCMCStrategy] = field( + default_factory=DefaultStrategy + ) + # Use packed mode for rasterization, this leads to less memory usage but slightly slower. + packed: bool = False + # Use sparse gradients for optimization. (experimental) + sparse_grad: bool = False + # Use visible adam from Taming 3DGS. (experimental) + visible_adam: bool = False + # Anti-aliasing in rasterization. Might slightly hurt quantitative metrics. 
+ antialiased: bool = False + + # Use random background for training to discourage transparency + random_bkgd: bool = False + + # Opacity regularization + opacity_reg: float = 0.0 + # Scale regularization + scale_reg: float = 0.0 + + # Enable camera optimization. + pose_opt: bool = False + # Learning rate for camera optimization + pose_opt_lr: float = 1e-5 + # Regularization for camera optimization as weight decay + pose_opt_reg: float = 1e-6 + # Add noise to camera extrinsics. This is only to test the camera pose optimization. + pose_noise: float = 0.0 + + # Enable appearance optimization. (experimental) + app_opt: bool = False + # Appearance embedding dimension + app_embed_dim: int = 16 + # Learning rate for appearance optimization + app_opt_lr: float = 1e-3 + # Regularization for appearance optimization as weight decay + app_opt_reg: float = 1e-6 + + # Enable bilateral grid. (experimental) + use_bilateral_grid: bool = False + # Shape of the bilateral grid (X, Y, W) + bilateral_grid_shape: Tuple[int, int, int] = (16, 16, 8) + + # Enable depth loss. (experimental) + depth_loss: bool = False + # Weight for depth loss + depth_lambda: float = 1e-2 + + # Dump information to tensorboard every this steps + tb_every: int = 100 + # Save training images to tensorboard + tb_save_image: bool = False + + lpips_net: Literal["vgg", "alex"] = "alex" + + def adjust_steps(self, factor: float): + self.eval_steps = [int(i * factor) for i in self.eval_steps] + self.save_steps = [int(i * factor) for i in self.save_steps] + self.max_steps = int(self.max_steps * factor) + self.sh_degree_interval = int(self.sh_degree_interval * factor) + + strategy = self.strategy + if isinstance(strategy, DefaultStrategy): + strategy.refine_start_iter = int(strategy.refine_start_iter * factor) + strategy.refine_stop_iter = int(strategy.refine_stop_iter * factor) + strategy.reset_every = int(strategy.reset_every * factor) + strategy.refine_every = int(strategy.refine_every * factor) + elif isinstance(strategy, MCMCStrategy): + strategy.refine_start_iter = int(strategy.refine_start_iter * factor) + strategy.refine_stop_iter = int(strategy.refine_stop_iter * factor) + strategy.refine_every = int(strategy.refine_every * factor) + else: + assert_never(strategy) + + +def create_splats_with_optimizers( + parser: Parser, + init_type: str = "sfm", + init_num_pts: int = 100_000, + init_extent: float = 3.0, + init_opacity: float = 0.1, + init_scale: float = 1.0, + scene_scale: float = 1.0, + sh_degree: int = 3, + sparse_grad: bool = False, + visible_adam: bool = False, + batch_size: int = 1, + feature_dim: Optional[int] = None, + device: str = "cuda", + world_rank: int = 0, + world_size: int = 1, +) -> Tuple[torch.nn.ParameterDict, Dict[str, torch.optim.Optimizer]]: + if init_type == "sfm": + points = torch.from_numpy(parser.points).float() + rgbs = torch.from_numpy(parser.points_rgb / 255.0).float() + elif init_type == "random": + points = init_extent * scene_scale * (torch.rand((init_num_pts, 3)) * 2 - 1) + rgbs = torch.rand((init_num_pts, 3)) + else: + raise ValueError("Please specify a correct init_type: sfm or random") + + # Initialize the GS size to be the average dist of the 3 nearest neighbors + dist2_avg = (knn(points, 4)[:, 1:] ** 2).mean(dim=-1) # [N,] + dist_avg = torch.sqrt(dist2_avg) + scales = torch.log(dist_avg * init_scale).unsqueeze(-1).repeat(1, 3) # [N, 3] + + # Distribute the GSs to different ranks (also works for single rank) + points = points[world_rank::world_size] + rgbs = rgbs[world_rank::world_size] + scales = 
scales[world_rank::world_size] + + N = points.shape[0] + quats = torch.rand((N, 4)) # [N, 4] + opacities = torch.logit(torch.full((N,), init_opacity)) # [N,] + + params = [ + # name, value, lr + ("means", torch.nn.Parameter(points), 1.6e-4 * scene_scale), + ("scales", torch.nn.Parameter(scales), 5e-3), + ("quats", torch.nn.Parameter(quats), 1e-3), + ("opacities", torch.nn.Parameter(opacities), 5e-2), + ] + + if feature_dim is None: + # color is SH coefficients. + colors = torch.zeros((N, (sh_degree + 1) ** 2, 3)) # [N, K, 3] + colors[:, 0, :] = rgb_to_sh(rgbs) + params.append(("sh0", torch.nn.Parameter(colors[:, :1, :]), 2.5e-3)) + params.append(("shN", torch.nn.Parameter(colors[:, 1:, :]), 2.5e-3 / 20)) + else: + # features will be used for appearance and view-dependent shading + features = torch.rand(N, feature_dim) # [N, feature_dim] + params.append(("features", torch.nn.Parameter(features), 2.5e-3)) + colors = torch.logit(rgbs) # [N, 3] + params.append(("colors", torch.nn.Parameter(colors), 2.5e-3)) + + splats = torch.nn.ParameterDict({n: v for n, v, _ in params}).to(device) + # Scale learning rate based on batch size, reference: + # https://www.cs.princeton.edu/~smalladi/blog/2024/01/22/SDEs-ScalingRules/ + # Note that this would not make the training exactly equivalent, see + # https://arxiv.org/pdf/2402.18824v1 + BS = batch_size * world_size + optimizer_class = None + if sparse_grad: + optimizer_class = torch.optim.SparseAdam + elif visible_adam: + optimizer_class = SelectiveAdam + else: + optimizer_class = torch.optim.Adam + optimizers = { + name: optimizer_class( + [{"params": splats[name], "lr": lr * math.sqrt(BS), "name": name}], + eps=1e-15 / math.sqrt(BS), + # TODO: check betas logic when BS is larger than 10 betas[0] will be zero. + betas=(1 - BS * (1 - 0.9), 1 - BS * (1 - 0.999)), + ) + for name, _, lr in params + } + return splats, optimizers + + +class Runner: + """Engine for training and testing.""" + + def __init__( + self, local_rank: int, world_rank, world_size: int, cfg: Config + ) -> None: + set_random_seed(42 + local_rank) + + self.cfg = cfg + self.world_rank = world_rank + self.local_rank = local_rank + self.world_size = world_size + self.device = f"cuda:{local_rank}" + + # Where to dump results. + os.makedirs(cfg.result_dir, exist_ok=True) + + # Setup output directories. + self.ckpt_dir = f"{cfg.result_dir}/ckpts" + os.makedirs(self.ckpt_dir, exist_ok=True) + self.stats_dir = f"{cfg.result_dir}/stats" + os.makedirs(self.stats_dir, exist_ok=True) + self.render_dir = f"{cfg.result_dir}/renders" + os.makedirs(self.render_dir, exist_ok=True) + + # Tensorboard + self.writer = SummaryWriter(log_dir=f"{cfg.result_dir}/tb") + + # Load data: Training data should contain initial points and colors. 
+ self.parser = Parser( + data_dir=cfg.data_dir, + factor=cfg.data_factor, + normalize=cfg.normalize_world_space, + test_every=cfg.test_every, + ) + self.trainset = Dataset( + self.parser, + split="train", + patch_size=cfg.patch_size, + load_depths=cfg.depth_loss, + ) + self.valset = Dataset(self.parser, split="val") + self.scene_scale = self.parser.scene_scale * 1.1 * cfg.global_scale + print("Scene scale:", self.scene_scale) + + # Model + feature_dim = 32 if cfg.app_opt else None + self.splats, self.optimizers = create_splats_with_optimizers( + self.parser, + init_type=cfg.init_type, + init_num_pts=cfg.init_num_pts, + init_extent=cfg.init_extent, + init_opacity=cfg.init_opa, + init_scale=cfg.init_scale, + scene_scale=self.scene_scale, + sh_degree=cfg.sh_degree, + sparse_grad=cfg.sparse_grad, + visible_adam=cfg.visible_adam, + batch_size=cfg.batch_size, + feature_dim=feature_dim, + device=self.device, + world_rank=world_rank, + world_size=world_size, + ) + print("Model initialized. Number of GS:", len(self.splats["means"])) + + # Densification Strategy + self.cfg.strategy.check_sanity(self.splats, self.optimizers) + + if isinstance(self.cfg.strategy, DefaultStrategy): + self.strategy_state = self.cfg.strategy.initialize_state( + scene_scale=self.scene_scale + ) + elif isinstance(self.cfg.strategy, MCMCStrategy): + self.strategy_state = self.cfg.strategy.initialize_state() + else: + assert_never(self.cfg.strategy) + + # Compression Strategy + self.compression_method = None + if cfg.compression is not None: + if cfg.compression == "png": + self.compression_method = PngCompression() + else: + raise ValueError(f"Unknown compression strategy: {cfg.compression}") + + self.pose_optimizers = [] + if cfg.pose_opt: + self.pose_adjust = CameraOptModule(len(self.trainset)).to(self.device) + self.pose_adjust.zero_init() + self.pose_optimizers = [ + torch.optim.Adam( + self.pose_adjust.parameters(), + lr=cfg.pose_opt_lr * math.sqrt(cfg.batch_size), + weight_decay=cfg.pose_opt_reg, + ) + ] + if world_size > 1: + self.pose_adjust = DDP(self.pose_adjust) + + if cfg.pose_noise > 0.0: + self.pose_perturb = CameraOptModule(len(self.trainset)).to(self.device) + self.pose_perturb.random_init(cfg.pose_noise) + if world_size > 1: + self.pose_perturb = DDP(self.pose_perturb) + + self.app_optimizers = [] + if cfg.app_opt: + assert feature_dim is not None + self.app_module = AppearanceOptModule( + len(self.trainset), feature_dim, cfg.app_embed_dim, cfg.sh_degree + ).to(self.device) + # initialize the last layer to be zero so that the initial output is zero. + torch.nn.init.zeros_(self.app_module.color_head[-1].weight) + torch.nn.init.zeros_(self.app_module.color_head[-1].bias) + self.app_optimizers = [ + torch.optim.Adam( + self.app_module.embeds.parameters(), + lr=cfg.app_opt_lr * math.sqrt(cfg.batch_size) * 10.0, + weight_decay=cfg.app_opt_reg, + ), + torch.optim.Adam( + self.app_module.color_head.parameters(), + lr=cfg.app_opt_lr * math.sqrt(cfg.batch_size), + ), + ] + if world_size > 1: + self.app_module = DDP(self.app_module) + + self.bil_grid_optimizers = [] + if cfg.use_bilateral_grid: + self.bil_grids = BilateralGrid( + len(self.trainset), + grid_X=cfg.bilateral_grid_shape[0], + grid_Y=cfg.bilateral_grid_shape[1], + grid_W=cfg.bilateral_grid_shape[2], + ).to(self.device) + self.bil_grid_optimizers = [ + torch.optim.Adam( + self.bil_grids.parameters(), + lr=2e-3 * math.sqrt(cfg.batch_size), + eps=1e-15, + ), + ] + + # Losses & Metrics. 
+ self.ssim = StructuralSimilarityIndexMeasure(data_range=1.0).to(self.device) + self.psnr = PeakSignalNoiseRatio(data_range=1.0).to(self.device) + + if cfg.lpips_net == "alex": + self.lpips = LearnedPerceptualImagePatchSimilarity( + net_type="alex", normalize=True + ).to(self.device) + elif cfg.lpips_net == "vgg": + # The 3DGS official repo uses lpips vgg, which is equivalent with the following: + self.lpips = LearnedPerceptualImagePatchSimilarity( + net_type="vgg", normalize=False + ).to(self.device) + else: + raise ValueError(f"Unknown LPIPS network: {cfg.lpips_net}") + + # Viewer + if not self.cfg.disable_viewer: + self.server = viser.ViserServer(port=cfg.port, verbose=False) + self.viewer = nerfview.Viewer( + server=self.server, + render_fn=self._viewer_render_fn, + mode="training", + ) + + def rasterize_splats( + self, + camtoworlds: Tensor, + Ks: Tensor, + width: int, + height: int, + masks: Optional[Tensor] = None, + **kwargs, + ) -> Tuple[Tensor, Tensor, Dict]: + means = self.splats["means"] # [N, 3] + # quats = F.normalize(self.splats["quats"], dim=-1) # [N, 4] + # rasterization does normalization internally + quats = self.splats["quats"] # [N, 4] + scales = torch.exp(self.splats["scales"]) # [N, 3] + opacities = torch.sigmoid(self.splats["opacities"]) # [N,] + + image_ids = kwargs.pop("image_ids", None) + if self.cfg.app_opt: + colors = self.app_module( + features=self.splats["features"], + embed_ids=image_ids, + dirs=means[None, :, :] - camtoworlds[:, None, :3, 3], + sh_degree=kwargs.pop("sh_degree", self.cfg.sh_degree), + ) + colors = colors + self.splats["colors"] + colors = torch.sigmoid(colors) + else: + colors = torch.cat([self.splats["sh0"], self.splats["shN"]], 1) # [N, K, 3] + + rasterize_mode = "antialiased" if self.cfg.antialiased else "classic" + render_colors, render_alphas, info = rasterization( + means=means, + quats=quats, + scales=scales, + opacities=opacities, + colors=colors, + viewmats=torch.linalg.inv(camtoworlds), # [C, 4, 4] + Ks=Ks, # [C, 3, 3] + width=width, + height=height, + packed=self.cfg.packed, + absgrad=( + self.cfg.strategy.absgrad + if isinstance(self.cfg.strategy, DefaultStrategy) + else False + ), + sparse_grad=self.cfg.sparse_grad, + rasterize_mode=rasterize_mode, + distributed=self.world_size > 1, + camera_model=self.cfg.camera_model, + **kwargs, + ) + if masks is not None: + render_colors[~masks] = 0 + return render_colors, render_alphas, info + + def train(self): + cfg = self.cfg + device = self.device + world_rank = self.world_rank + world_size = self.world_size + + # Dump cfg. + if world_rank == 0: + with open(f"{cfg.result_dir}/cfg.yml", "w") as f: + yaml.dump(vars(cfg), f) + + max_steps = cfg.max_steps + init_step = 0 + + schedulers = [ + # means has a learning rate schedule, that end at 0.01 of the initial value + torch.optim.lr_scheduler.ExponentialLR( + self.optimizers["means"], gamma=0.01 ** (1.0 / max_steps) + ), + ] + if cfg.pose_opt: + # pose optimization has a learning rate schedule + schedulers.append( + torch.optim.lr_scheduler.ExponentialLR( + self.pose_optimizers[0], gamma=0.01 ** (1.0 / max_steps) + ) + ) + if cfg.use_bilateral_grid: + # bilateral grid has a learning rate schedule. Linear warmup for 1000 steps. 
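+            # ChainedScheduler steps both members every iteration, so their factors
+            # multiply: the LR warms up from 1% of its base value over the first 1000
+            # steps while the exponential term decays it to ~1% of base by max_steps.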
+ schedulers.append( + torch.optim.lr_scheduler.ChainedScheduler( + [ + torch.optim.lr_scheduler.LinearLR( + self.bil_grid_optimizers[0], + start_factor=0.01, + total_iters=1000, + ), + torch.optim.lr_scheduler.ExponentialLR( + self.bil_grid_optimizers[0], gamma=0.01 ** (1.0 / max_steps) + ), + ] + ) + ) + + trainloader = torch.utils.data.DataLoader( + self.trainset, + batch_size=cfg.batch_size, + shuffle=True, + num_workers=4, + persistent_workers=True, + pin_memory=True, + ) + trainloader_iter = iter(trainloader) + + # Training loop. + global_tic = time.time() + pbar = tqdm.tqdm(range(init_step, max_steps)) + for step in pbar: + if not cfg.disable_viewer: + while self.viewer.state.status == "paused": + time.sleep(0.01) + self.viewer.lock.acquire() + tic = time.time() + + try: + data = next(trainloader_iter) + except StopIteration: + trainloader_iter = iter(trainloader) + data = next(trainloader_iter) + + camtoworlds = camtoworlds_gt = data["camtoworld"].to(device) # [1, 4, 4] + Ks = data["K"].to(device) # [1, 3, 3] + pixels = data["image"].to(device) / 255.0 # [1, H, W, 3] + num_train_rays_per_step = ( + pixels.shape[0] * pixels.shape[1] * pixels.shape[2] + ) + image_ids = data["image_id"].to(device) + masks = data["mask"].to(device) if "mask" in data else None # [1, H, W] + if cfg.depth_loss: + points = data["points"].to(device) # [1, M, 2] + depths_gt = data["depths"].to(device) # [1, M] + + height, width = pixels.shape[1:3] + + if cfg.pose_noise: + camtoworlds = self.pose_perturb(camtoworlds, image_ids) + + if cfg.pose_opt: + camtoworlds = self.pose_adjust(camtoworlds, image_ids) + + # sh schedule + sh_degree_to_use = min(step // cfg.sh_degree_interval, cfg.sh_degree) + + # forward + renders, alphas, info = self.rasterize_splats( + camtoworlds=camtoworlds, + Ks=Ks, + width=width, + height=height, + sh_degree=sh_degree_to_use, + near_plane=cfg.near_plane, + far_plane=cfg.far_plane, + image_ids=image_ids, + render_mode="RGB+ED" if cfg.depth_loss else "RGB", + masks=masks, + ) + if renders.shape[-1] == 4: + colors, depths = renders[..., 0:3], renders[..., 3:4] + else: + colors, depths = renders, None + + if cfg.use_bilateral_grid: + grid_y, grid_x = torch.meshgrid( + (torch.arange(height, device=self.device) + 0.5) / height, + (torch.arange(width, device=self.device) + 0.5) / width, + indexing="ij", + ) + grid_xy = torch.stack([grid_x, grid_y], dim=-1).unsqueeze(0) + colors = slice(self.bil_grids, grid_xy, colors, image_ids)["rgb"] + + if cfg.random_bkgd: + bkgd = torch.rand(1, 3, device=device) + colors = colors + bkgd * (1.0 - alphas) + + self.cfg.strategy.step_pre_backward( + params=self.splats, + optimizers=self.optimizers, + state=self.strategy_state, + step=step, + info=info, + ) + + # loss + l1loss = F.l1_loss(colors, pixels) + ssimloss = 1.0 - fused_ssim( + colors.permute(0, 3, 1, 2), pixels.permute(0, 3, 1, 2), padding="valid" + ) + loss = l1loss * (1.0 - cfg.ssim_lambda) + ssimloss * cfg.ssim_lambda + if cfg.depth_loss: + # query depths from depth map + points = torch.stack( + [ + points[:, :, 0] / (width - 1) * 2 - 1, + points[:, :, 1] / (height - 1) * 2 - 1, + ], + dim=-1, + ) # normalize to [-1, 1] + grid = points.unsqueeze(2) # [1, M, 1, 2] + depths = F.grid_sample( + depths.permute(0, 3, 1, 2), grid, align_corners=True + ) # [1, 1, M, 1] + depths = depths.squeeze(3).squeeze(1) # [1, M] + # calculate loss in disparity space + disp = torch.where(depths > 0.0, 1.0 / depths, torch.zeros_like(depths)) + disp_gt = 1.0 / depths_gt # [1, M] + depthloss = F.l1_loss(disp, disp_gt) * 
self.scene_scale + loss += depthloss * cfg.depth_lambda + if cfg.use_bilateral_grid: + tvloss = 10 * total_variation_loss(self.bil_grids.grids) + loss += tvloss + + # regularizations + if cfg.opacity_reg > 0.0: + loss = ( + loss + + cfg.opacity_reg + * torch.abs(torch.sigmoid(self.splats["opacities"])).mean() + ) + if cfg.scale_reg > 0.0: + loss = ( + loss + + cfg.scale_reg * torch.abs(torch.exp(self.splats["scales"])).mean() + ) + + loss.backward() + + desc = f"loss={loss.item():.3f}| " f"sh degree={sh_degree_to_use}| " + if cfg.depth_loss: + desc += f"depth loss={depthloss.item():.6f}| " + if cfg.pose_opt and cfg.pose_noise: + # monitor the pose error if we inject noise + pose_err = F.l1_loss(camtoworlds_gt, camtoworlds) + desc += f"pose err={pose_err.item():.6f}| " + pbar.set_description(desc) + + # write images (gt and render) + # if world_rank == 0 and step % 800 == 0: + # canvas = torch.cat([pixels, colors], dim=2).detach().cpu().numpy() + # canvas = canvas.reshape(-1, *canvas.shape[2:]) + # imageio.imwrite( + # f"{self.render_dir}/train_rank{self.world_rank}.png", + # (canvas * 255).astype(np.uint8), + # ) + + if world_rank == 0 and cfg.tb_every > 0 and step % cfg.tb_every == 0: + mem = torch.cuda.max_memory_allocated() / 1024**3 + self.writer.add_scalar("train/loss", loss.item(), step) + self.writer.add_scalar("train/l1loss", l1loss.item(), step) + self.writer.add_scalar("train/ssimloss", ssimloss.item(), step) + self.writer.add_scalar("train/num_GS", len(self.splats["means"]), step) + self.writer.add_scalar("train/mem", mem, step) + if cfg.depth_loss: + self.writer.add_scalar("train/depthloss", depthloss.item(), step) + if cfg.use_bilateral_grid: + self.writer.add_scalar("train/tvloss", tvloss.item(), step) + if cfg.tb_save_image: + canvas = torch.cat([pixels, colors], dim=2).detach().cpu().numpy() + canvas = canvas.reshape(-1, *canvas.shape[2:]) + self.writer.add_image("train/render", canvas, step) + self.writer.flush() + + # save checkpoint before updating the model + if step in [i - 1 for i in cfg.save_steps] or step == max_steps - 1: + mem = torch.cuda.max_memory_allocated() / 1024**3 + stats = { + "mem": mem, + "ellipse_time": time.time() - global_tic, + "num_GS": len(self.splats["means"]), + } + print("Step: ", step, stats) + with open( + f"{self.stats_dir}/train_step{step:04d}_rank{self.world_rank}.json", + "w", + ) as f: + json.dump(stats, f) + data = {"step": step, "splats": self.splats.state_dict()} + if cfg.pose_opt: + if world_size > 1: + data["pose_adjust"] = self.pose_adjust.module.state_dict() + else: + data["pose_adjust"] = self.pose_adjust.state_dict() + if cfg.app_opt: + if world_size > 1: + data["app_module"] = self.app_module.module.state_dict() + else: + data["app_module"] = self.app_module.state_dict() + torch.save( + data, f"{self.ckpt_dir}/ckpt_{step}_rank{self.world_rank}.pt" + ) + + # Turn Gradients into Sparse Tensor before running optimizer + if cfg.sparse_grad: + assert cfg.packed, "Sparse gradients only work with packed mode." + gaussian_ids = info["gaussian_ids"] + for k in self.splats.keys(): + grad = self.splats[k].grad + if grad is None or grad.is_sparse: + continue + self.splats[k].grad = torch.sparse_coo_tensor( + indices=gaussian_ids[None], # [1, nnz] + values=grad[gaussian_ids], # [nnz, ...] + size=self.splats[k].size(), # [N, ...] 
+ is_coalesced=len(Ks) == 1, + ) + + if cfg.visible_adam: + gaussian_cnt = self.splats.means.shape[0] + if cfg.packed: + visibility_mask = torch.zeros_like( + self.splats["opacities"], dtype=bool + ) + visibility_mask.scatter_(0, info["gaussian_ids"], 1) + else: + visibility_mask = (info["radii"] > 0).any(0) + + # optimize + for optimizer in self.optimizers.values(): + if cfg.visible_adam: + optimizer.step(visibility_mask) + else: + optimizer.step() + optimizer.zero_grad(set_to_none=True) + for optimizer in self.pose_optimizers: + optimizer.step() + optimizer.zero_grad(set_to_none=True) + for optimizer in self.app_optimizers: + optimizer.step() + optimizer.zero_grad(set_to_none=True) + for optimizer in self.bil_grid_optimizers: + optimizer.step() + optimizer.zero_grad(set_to_none=True) + for scheduler in schedulers: + scheduler.step() + + # Run post-backward steps after backward and optimizer + if isinstance(self.cfg.strategy, DefaultStrategy): + self.cfg.strategy.step_post_backward( + params=self.splats, + optimizers=self.optimizers, + state=self.strategy_state, + step=step, + info=info, + packed=cfg.packed, + ) + elif isinstance(self.cfg.strategy, MCMCStrategy): + self.cfg.strategy.step_post_backward( + params=self.splats, + optimizers=self.optimizers, + state=self.strategy_state, + step=step, + info=info, + lr=schedulers[0].get_last_lr()[0], + ) + else: + assert_never(self.cfg.strategy) + + # eval the full set + if step in [i - 1 for i in cfg.eval_steps]: + self.eval(step) + self.render_traj(step) + + # run compression + if cfg.compression is not None and step in [i - 1 for i in cfg.eval_steps]: + self.run_compression(step=step) + + if not cfg.disable_viewer: + self.viewer.lock.release() + num_train_steps_per_sec = 1.0 / (time.time() - tic) + num_train_rays_per_sec = ( + num_train_rays_per_step * num_train_steps_per_sec + ) + # Update the viewer state. + self.viewer.state.num_train_rays_per_sec = num_train_rays_per_sec + # Update the scene. 
+ self.viewer.update(step, num_train_rays_per_step) + + @torch.no_grad() + def eval(self, step: int, stage: str = "val"): + """Entry for evaluation.""" + print("Running evaluation...") + cfg = self.cfg + device = self.device + world_rank = self.world_rank + world_size = self.world_size + + valloader = torch.utils.data.DataLoader( + self.valset, batch_size=1, shuffle=False, num_workers=1 + ) + ellipse_time = 0 + metrics = defaultdict(list) + for i, data in enumerate(valloader): + camtoworlds = data["camtoworld"].to(device) + Ks = data["K"].to(device) + pixels = data["image"].to(device) / 255.0 + masks = data["mask"].to(device) if "mask" in data else None + height, width = pixels.shape[1:3] + + torch.cuda.synchronize() + tic = time.time() + colors, _, _ = self.rasterize_splats( + camtoworlds=camtoworlds, + Ks=Ks, + width=width, + height=height, + sh_degree=cfg.sh_degree, + near_plane=cfg.near_plane, + far_plane=cfg.far_plane, + masks=masks, + ) # [1, H, W, 3] + torch.cuda.synchronize() + ellipse_time += time.time() - tic + + colors = torch.clamp(colors, 0.0, 1.0) + canvas_list = [pixels, colors] + + if world_rank == 0: + # write images + canvas = torch.cat(canvas_list, dim=2).squeeze(0).cpu().numpy() + canvas = (canvas * 255).astype(np.uint8) + imageio.imwrite( + f"{self.render_dir}/{stage}_step{step}_{i:04d}.png", + canvas, + ) + + pixels_p = pixels.permute(0, 3, 1, 2) # [1, 3, H, W] + colors_p = colors.permute(0, 3, 1, 2) # [1, 3, H, W] + metrics["psnr"].append(self.psnr(colors_p, pixels_p)) + metrics["ssim"].append(self.ssim(colors_p, pixels_p)) + metrics["lpips"].append(self.lpips(colors_p, pixels_p)) + if cfg.use_bilateral_grid: + cc_colors = color_correct(colors, pixels) + cc_colors_p = cc_colors.permute(0, 3, 1, 2) # [1, 3, H, W] + metrics["cc_psnr"].append(self.psnr(cc_colors_p, pixels_p)) + + if world_rank == 0: + ellipse_time /= len(valloader) + + stats = {k: torch.stack(v).mean().item() for k, v in metrics.items()} + stats.update( + { + "ellipse_time": ellipse_time, + "num_GS": len(self.splats["means"]), + } + ) + print( + f"PSNR: {stats['psnr']:.3f}, SSIM: {stats['ssim']:.4f}, LPIPS: {stats['lpips']:.3f} " + f"Time: {stats['ellipse_time']:.3f}s/image " + f"Number of GS: {stats['num_GS']}" + ) + # save stats as json + with open(f"{self.stats_dir}/{stage}_step{step:04d}.json", "w") as f: + json.dump(stats, f) + # save stats to tensorboard + for k, v in stats.items(): + self.writer.add_scalar(f"{stage}/{k}", v, step) + self.writer.flush() + + @torch.no_grad() + def render_traj(self, step: int): + """Entry for trajectory rendering.""" + print("Running trajectory rendering...") + cfg = self.cfg + device = self.device + + camtoworlds_all = self.parser.camtoworlds[5:-5] + if cfg.render_traj_path == "interp": + camtoworlds_all = generate_interpolated_path( + camtoworlds_all, 1 + ) # [N, 3, 4] + elif cfg.render_traj_path == "ellipse": + height = camtoworlds_all[:, 2, 3].mean() + camtoworlds_all = generate_ellipse_path_z( + camtoworlds_all, height=height + ) # [N, 3, 4] + elif cfg.render_traj_path == "spiral": + camtoworlds_all = generate_spiral_path( + camtoworlds_all, + bounds=self.parser.bounds * self.scene_scale, + spiral_scale_r=self.parser.extconf["spiral_radius_scale"], + ) + else: + raise ValueError( + f"Render trajectory type not supported: {cfg.render_traj_path}" + ) + + camtoworlds_all = np.concatenate( + [ + camtoworlds_all, + np.repeat( + np.array([[[0.0, 0.0, 0.0, 1.0]]]), len(camtoworlds_all), axis=0 + ), + ], + axis=1, + ) # [N, 4, 4] + + camtoworlds_all = 
torch.from_numpy(camtoworlds_all).float().to(device) + K = torch.from_numpy(list(self.parser.Ks_dict.values())[0]).float().to(device) + width, height = list(self.parser.imsize_dict.values())[0] + + # save to video + video_dir = f"{cfg.result_dir}/videos" + os.makedirs(video_dir, exist_ok=True) + writer = imageio.get_writer(f"{video_dir}/traj_{step}.mp4", fps=30) + for i in tqdm.trange(len(camtoworlds_all), desc="Rendering trajectory"): + camtoworlds = camtoworlds_all[i : i + 1] + Ks = K[None] + + renders, _, _ = self.rasterize_splats( + camtoworlds=camtoworlds, + Ks=Ks, + width=width, + height=height, + sh_degree=cfg.sh_degree, + near_plane=cfg.near_plane, + far_plane=cfg.far_plane, + render_mode="RGB+ED", + ) # [1, H, W, 4] + colors = torch.clamp(renders[..., 0:3], 0.0, 1.0) # [1, H, W, 3] + depths = renders[..., 3:4] # [1, H, W, 1] + depths = (depths - depths.min()) / (depths.max() - depths.min()) + canvas_list = [colors, depths.repeat(1, 1, 1, 3)] + + # write images + canvas = torch.cat(canvas_list, dim=2).squeeze(0).cpu().numpy() + canvas = (canvas * 255).astype(np.uint8) + writer.append_data(canvas) + writer.close() + print(f"Video saved to {video_dir}/traj_{step}.mp4") + + @torch.no_grad() + def run_compression(self, step: int): + """Entry for running compression.""" + print("Running compression...") + world_rank = self.world_rank + + compress_dir = f"{cfg.result_dir}/compression/rank{world_rank}" + os.makedirs(compress_dir, exist_ok=True) + + self.compression_method.compress(compress_dir, self.splats) + + # evaluate compression + splats_c = self.compression_method.decompress(compress_dir) + for k in splats_c.keys(): + self.splats[k].data = splats_c[k].to(self.device) + self.eval(step=step, stage="compress") + + @torch.no_grad() + def _viewer_render_fn( + self, camera_state: nerfview.CameraState, img_wh: Tuple[int, int] + ): + """Callable function for the viewer.""" + W, H = img_wh + c2w = camera_state.c2w + K = camera_state.get_K(img_wh) + c2w = torch.from_numpy(c2w).float().to(self.device) + K = torch.from_numpy(K).float().to(self.device) + + render_colors, _, _ = self.rasterize_splats( + camtoworlds=c2w[None], + Ks=K[None], + width=W, + height=H, + sh_degree=self.cfg.sh_degree, # active all SH degrees + radius_clip=3.0, # skip GSs that have small image radius (in pixels) + ) # [1, H, W, 3] + return render_colors[0].cpu().numpy() + + +def main(local_rank: int, world_rank, world_size: int, cfg: Config): + if world_size > 1 and not cfg.disable_viewer: + cfg.disable_viewer = True + if world_rank == 0: + print("Viewer is disabled in distributed training.") + + runner = Runner(local_rank, world_rank, world_size, cfg) + + if cfg.ckpt is not None: + # run eval only + ckpts = [ + torch.load(file, map_location=runner.device, weights_only=True) + for file in cfg.ckpt + ] + for k in runner.splats.keys(): + runner.splats[k].data = torch.cat([ckpt["splats"][k] for ckpt in ckpts]) + step = ckpts[0]["step"] + runner.eval(step=step) + runner.render_traj(step=step) + if cfg.compression is not None: + runner.run_compression(step=step) + else: + runner.train() + + if not cfg.disable_viewer: + print("Viewer running... Ctrl+C to exit.") + time.sleep(1000000) + + +if __name__ == "__main__": + """ + Usage: + + ```bash + # Single GPU training + CUDA_VISIBLE_DEVICES=0 python simple_trainer.py default + + # Distributed training on 4 GPUs: Effectively 4x batch size so run 4x less steps. 
+ CUDA_VISIBLE_DEVICES=0,1,2,3 python simple_trainer.py default --steps_scaler 0.25 + + """ + + # Config objects we can choose between. + # Each is a tuple of (CLI description, config object). + configs = { + "default": ( + "Gaussian splatting training using densification heuristics from the original paper.", + Config( + strategy=DefaultStrategy(verbose=True), + ), + ), + "mcmc": ( + "Gaussian splatting training using densification from the paper '3D Gaussian Splatting as Markov Chain Monte Carlo'.", + Config( + init_opa=0.5, + init_scale=0.1, + opacity_reg=0.01, + scale_reg=0.01, + strategy=MCMCStrategy(verbose=True), + ), + ), + } + cfg = tyro.extras.overridable_config_cli(configs) + cfg.adjust_steps(cfg.steps_scaler) + + # try import extra dependencies + if cfg.compression == "png": + try: + import plas + import torchpq + except: + raise ImportError( + "To use PNG compression, you need to install " + "torchpq (instruction at https://github.com/DeMoriarty/TorchPQ?tab=readme-ov-file#install) " + "and plas (via 'pip install git+https://github.com/fraunhoferhhi/PLAS.git') " + ) + + cli(main, cfg, verbose=True) diff --git a/gsplat/cuda/_torch_impl.py b/gsplat/cuda/_torch_impl.py index 892c6a66f..25d229524 100644 --- a/gsplat/cuda/_torch_impl.py +++ b/gsplat/cuda/_torch_impl.py @@ -221,6 +221,56 @@ def _ortho_proj( ) # [C, N, 2] return means2d, cov2d # [C, N, 2], [C, N, 2, 2] +def _spherical_proj( + means: Tensor, # [C, N, 3] + covars: Tensor, # [C, N, 3, 3] + Ks: Tensor, # [C, 3, 3] + width: int, + height: int, +) -> Tuple[Tensor, Tensor]: + """PyTorch implementation of spherical projection for 3D Gaussians. + + Args: + means: Gaussian means in camera coordinate system. [C, N, 3]. + covars: Gaussian covariances in camera coordinate system. [C, N, 3, 3]. + Ks: Camera intrinsics. [C, 3, 3]. + width: Image width. + height: Image height. + + Returns: + A tuple: + + - **means2d**: Projected means. [C, N, 2]. + - **cov2d**: Projected covariances. [C, N, 2, 2]. 
+ """ + C, N, _ = means.shape + + tx, ty, tz = torch.unbind(means, dim=-1) # [C, N] + tr = torch.sqrt(tx**2 + ty**2 + tz**2) + + longitude = torch.atan2(tx, tz) + latitude = torch.atan2(ty, torch.sqrt(tx**2 + tz**2)) + + normalized_latitude = latitude / (torch.pi / 2.0) + normalized_longitude = longitude / torch.pi + + means2d = torch.stack([normalized_longitude, normalized_latitude], dim=-1) + + O = torch.zeros((C, N), device=means.device, dtype=means.dtype) + J = torch.stack( + [ + width / (2 * torch.pi) * tz / (tx**2 + tz**2), + O, + -width / (2 * torch.pi) * tx / (tx**2 + tz**2), + -height / torch.pi * (tx * ty) / (tr**2 * torch.sqrt(tx**2 + tz**2)), + height / torch.pi * torch.sqrt(tx**2 + tz**2) / tr**2, + -height / torch.pi * (tz * ty) / (tr**2 * torch.sqrt(tx**2 + tz**2)), + ], + dim=-1, + ).reshape(C, N, 2, 3) + + cov2d = torch.einsum("...ij,...jk,...kl->...il", J, covars, J.transpose(-1, -2)) + return means2d, cov2d # [C, N, 2], [C, N, 2, 2] def _world_to_cam( means: Tensor, # [N, 3] @@ -258,7 +308,7 @@ def _fully_fused_projection( near_plane: float = 0.01, far_plane: float = 1e10, calc_compensations: bool = False, - camera_model: Literal["pinhole", "ortho", "fisheye"] = "pinhole", + camera_model: Literal["pinhole", "ortho", "fisheye", "spherical"] = "pinhole", ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Optional[Tensor]]: """PyTorch implementation of `gsplat.cuda._wrapper.fully_fused_projection()` @@ -275,6 +325,8 @@ def _fully_fused_projection( means2d, covars2d = _fisheye_proj(means_c, covars_c, Ks, width, height) elif camera_model == "pinhole": means2d, covars2d = _persp_proj(means_c, covars_c, Ks, width, height) + elif camera_model == "spherical": + means2d, covars2d = _spherical_proj(means_c, covars_c, Ks, width, height) else: assert_never(camera_model) diff --git a/gsplat/cuda/_wrapper.py b/gsplat/cuda/_wrapper.py index 5a1065af0..fdd52520c 100644 --- a/gsplat/cuda/_wrapper.py +++ b/gsplat/cuda/_wrapper.py @@ -144,7 +144,7 @@ def proj( Ks: Tensor, # [C, 3, 3] width: int, height: int, - camera_model: Literal["pinhole", "ortho", "fisheye"] = "pinhole", + camera_model: Literal["pinhole", "ortho", "fisheye", "spherical"] = "pinhole", ) -> Tuple[Tensor, Tensor]: """Projection of Gaussians (perspective or orthographic). @@ -216,7 +216,7 @@ def fully_fused_projection( packed: bool = False, sparse_grad: bool = False, calc_compensations: bool = False, - camera_model: Literal["pinhole", "ortho", "fisheye"] = "pinhole", + camera_model: Literal["pinhole", "ortho", "fisheye", "spherical"] = "pinhole", ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]: """Projects Gaussians to 2D. 
@@ -695,7 +695,7 @@ def forward( Ks: Tensor, # [C, 3, 3] width: int, height: int, - camera_model: Literal["pinhole", "ortho", "fisheye"] = "pinhole", + camera_model: Literal["pinhole", "ortho", "fisheye", "spherical"] = "pinhole", ) -> Tuple[Tensor, Tensor]: camera_model_type = _make_lazy_cuda_obj( f"CameraModelType.{camera_model.upper()}" @@ -791,7 +791,7 @@ def forward( far_plane: float, radius_clip: float, calc_compensations: bool, - camera_model: Literal["pinhole", "ortho", "fisheye"] = "pinhole", + camera_model: Literal["pinhole", "ortho", "fisheye", "spherical"] = "pinhole", ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]: camera_model_type = _make_lazy_cuda_obj( f"CameraModelType.{camera_model.upper()}" @@ -1048,7 +1048,7 @@ def forward( radius_clip: float, sparse_grad: bool, calc_compensations: bool, - camera_model: Literal["pinhole", "ortho", "fisheye"] = "pinhole", + camera_model: Literal["pinhole", "ortho", "fisheye", "spherical"] = "pinhole", ) -> Tuple[Tensor, Tensor, Tensor, Tensor]: camera_model_type = _make_lazy_cuda_obj( f"CameraModelType.{camera_model.upper()}" diff --git a/gsplat/cuda/csrc/bindings.h b/gsplat/cuda/csrc/bindings.h index 66f7a2567..3b12ae83d 100644 --- a/gsplat/cuda/csrc/bindings.h +++ b/gsplat/cuda/csrc/bindings.h @@ -36,6 +36,7 @@ enum CameraModelType PINHOLE = 0, ORTHO = 1, FISHEYE = 2, + SPHERICAL = 3 }; std::tuple quat_scale_to_covar_preci_fwd_tensor( diff --git a/gsplat/cuda/csrc/ext.cpp b/gsplat/cuda/csrc/ext.cpp index c97ee7fda..4d6feab23 100644 --- a/gsplat/cuda/csrc/ext.cpp +++ b/gsplat/cuda/csrc/ext.cpp @@ -5,6 +5,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { .value("PINHOLE", gsplat::CameraModelType::PINHOLE) .value("ORTHO", gsplat::CameraModelType::ORTHO) .value("FISHEYE", gsplat::CameraModelType::FISHEYE) + .value("SPHERICAL", gsplat::CameraModelType::SPHERICAL) .export_values(); m.def("compute_sh_fwd", &gsplat::compute_sh_fwd_tensor); diff --git a/gsplat/cuda/csrc/fully_fused_projection_bwd.cu b/gsplat/cuda/csrc/fully_fused_projection_bwd.cu index b5757ff40..b4cc3489c 100644 --- a/gsplat/cuda/csrc/fully_fused_projection_bwd.cu +++ b/gsplat/cuda/csrc/fully_fused_projection_bwd.cu @@ -177,6 +177,22 @@ __global__ void fully_fused_projection_bwd_kernel( v_covar_c ); break; + case CameraModelType::SPHERICAL: // spherical projection + spherical_proj_vjp( + mean_c, + covar_c, + fx, + fy, + cx, + cy, + image_width, + image_height, + v_covar2d, + glm::make_vec2(v_means2d), + v_mean_c, + v_covar_c + ); + break; } // add contribution from v_depths diff --git a/gsplat/cuda/csrc/fully_fused_projection_fwd.cu b/gsplat/cuda/csrc/fully_fused_projection_fwd.cu index c651e803d..85502a9f2 100644 --- a/gsplat/cuda/csrc/fully_fused_projection_fwd.cu +++ b/gsplat/cuda/csrc/fully_fused_projection_fwd.cu @@ -148,6 +148,20 @@ __global__ void fully_fused_projection_fwd_kernel( mean2d ); break; + case CameraModelType::SPHERICAL: // spherical projection + spherical_proj( + mean_c, + covar_c, + Ks[0], + Ks[4], + Ks[2], + Ks[5], + image_width, + image_height, + covar2d, + mean2d + ); + break; } T compensation; diff --git a/gsplat/cuda/csrc/fully_fused_projection_packed_bwd.cu b/gsplat/cuda/csrc/fully_fused_projection_packed_bwd.cu index e5a0172fe..b1704d9d0 100644 --- a/gsplat/cuda/csrc/fully_fused_projection_packed_bwd.cu +++ b/gsplat/cuda/csrc/fully_fused_projection_packed_bwd.cu @@ -178,6 +178,22 @@ __global__ void fully_fused_projection_packed_bwd_kernel( v_covar_c ); break; + case CameraModelType::SPHERICAL: // spherical projection + spherical_proj_vjp( + 
mean_c, + covar_c, + fx, + fy, + cx, + cy, + image_width, + image_height, + v_covar2d, + glm::make_vec2(v_means2d), + v_mean_c, + v_covar_c + ); + break; } // add contribution from v_depths diff --git a/gsplat/cuda/csrc/fully_fused_projection_packed_fwd.cu b/gsplat/cuda/csrc/fully_fused_projection_packed_fwd.cu index 4d8609f05..05e7f5c5f 100644 --- a/gsplat/cuda/csrc/fully_fused_projection_packed_fwd.cu +++ b/gsplat/cuda/csrc/fully_fused_projection_packed_fwd.cu @@ -163,6 +163,20 @@ __global__ void fully_fused_projection_packed_fwd_kernel( mean2d ); break; + case CameraModelType::SPHERICAL: // spherical projection + spherical_proj( + mean_c, + covar_c, + Ks[0], + Ks[4], + Ks[2], + Ks[5], + image_width, + image_height, + covar2d, + mean2d + ); + break; } det = add_blur(eps2d, covar2d, compensation); diff --git a/gsplat/cuda/csrc/proj_bwd.cu b/gsplat/cuda/csrc/proj_bwd.cu index 66557f679..d0bae5e8e 100644 --- a/gsplat/cuda/csrc/proj_bwd.cu +++ b/gsplat/cuda/csrc/proj_bwd.cu @@ -109,6 +109,22 @@ __global__ void proj_bwd_kernel( v_covar ); break; + case CameraModelType::SPHERICAL: // spherical projection + spherical_proj_vjp( + mean, + covar, + fx, + fy, + cx, + cy, + width, + height, + glm::transpose(v_covar2d), + v_mean2d, + v_mean, + v_covar + ); + break; } // write to outputs: glm is column-major but we want row-major diff --git a/gsplat/cuda/csrc/proj_fwd.cu b/gsplat/cuda/csrc/proj_fwd.cu index 861f60479..a7e44be93 100644 --- a/gsplat/cuda/csrc/proj_fwd.cu +++ b/gsplat/cuda/csrc/proj_fwd.cu @@ -63,6 +63,9 @@ __global__ void proj_fwd_kernel( case CameraModelType::FISHEYE: // fisheye projection fisheye_proj(mean, covar, fx, fy, cx, cy, width, height, covar2d, mean2d); break; + case CameraModelType::SPHERICAL: // spherical projection + spherical_proj(mean, covar, fx, fy, cx, cy, width, height, covar2d, mean2d); + break; } // write to outputs: glm is column-major but we want row-major diff --git a/gsplat/cuda/csrc/utils.cuh b/gsplat/cuda/csrc/utils.cuh index d6879ea6c..50b4cbf45 100644 --- a/gsplat/cuda/csrc/utils.cuh +++ b/gsplat/cuda/csrc/utils.cuh @@ -7,6 +7,7 @@ #include #define FILTER_INV_SQUARE 2.0f +#define M_PI 3.14159265358979323846 namespace gsplat { @@ -515,6 +516,143 @@ inline __device__ void fisheye_proj_vjp( v_mean3d.z += dL_dtz_raw; } +template +inline __device__ void spherical_proj( + // inputs + const vec3 mean3d, + const mat3 cov3d, + const T fx, + const T fy, + const T cx, + const T cy, + const uint32_t width, + const uint32_t height, + // outputs + mat2 &cov2d, + vec2 &mean2d +) { + T x = mean3d[0], y = mean3d[1], z = mean3d[2]; + + T r = sqrt(x * x + y * y + z * z); + + T longitude = atan2(x, z); + T latitude = atan2(y, sqrt(x * x + z * z)); + + T normalized_latitude = latitude / (M_PI / 2.0); + T normalized_longitude = longitude / M_PI; + + mean2d = vec2(normalized_longitude, normalized_latitude); + + // mat3x2 is 3 columns x 2 rows. 
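+    // glm fills the matrix column by column: each consecutive pair of scalars below is
+    // one column of the 2x3 Jacobian, i.e. (d mean2d.x / d t_i, d mean2d.y / d t_i)
+    // for t_i in {x, y, z}.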
+ mat3x2 J = mat3x2( + width / (2.f * M_PI) * z / (x * x + y * y), // 1st column + 0.f, + -width / (2.f * M_PI) * x / (x * x + z * z), + -height / M_PI * (x * y) / (r * r + sqrt(x * x + z * z)), // 2nd column + height / M_PI * (x * y) * sqrt(x * x + z * z) / (r * r), + -height / M_PI * (z * y) / (r * r + sqrt(x * x + z * z)) + ); + + cov2d = J * cov3d * glm::transpose(J); +} + +template +inline __device__ void spherical_proj_vjp( + // fwd inputs + const vec3 mean3d, + const mat3 cov3d, + const T fx, + const T fy, + const T cx, + const T cy, + const uint32_t width, + const uint32_t height, + // grad outputs + const mat2 v_cov2d, + const vec2 v_mean2d, + // grad inputs + vec3 &v_mean3d, + mat3 &v_cov3d +) { + T x = mean3d[0]; + T y = mean3d[1]; + T z = mean3d[2]; + + T r = sqrt(x * x + y * y + z * z); + T xz_norm = sqrt(x * x + z * z + 1e-8f); + + T longitude = atan2(x, z); + T latitude = atan2(y, xz_norm); + + T normalized_longitude = longitude / M_PI; + T normalized_latitude = latitude / (M_PI / 2.0); + + // mat3x2 is 3 columns x 2 rows. + mat3x2 J = mat3x2( + width / (2.f * M_PI) * z / (x * x + y * y), // 1st column + 0.f, + -width / (2.f * M_PI) * x / (x * x + z * z), + -height / M_PI * (x * y) / (r * r + sqrt(x * x + z * z)), // 2nd column + height / M_PI * (x * y) * sqrt(x * x + z * z) / (r * r), + -height / M_PI * (z * y) / (r * r + sqrt(x * x + z * z)) + ); + + v_cov3d += glm::transpose(J) * v_cov2d * J; + + // df/dx = d(normalized_longitude) / dx + // df/dy = d(normalized_latitude) / dy + // df/dz = d(normalized_longitude) / dz + d(normalized_latitude) / dz + T inv_r = 1.0 / r; + T inv_xz_norm = 1.0 / xz_norm; + + v_mean3d += vec3( + width / (2.f * M_PI) * z / (x * x + z * z) * v_mean2d[0], + height / M_PI * (x * y) / (r * r + xz_norm) * v_mean2d[1], + -width / (2.f * M_PI) * x / (x * x + z * z) * v_mean2d[0] + + -height / M_PI * (z * y) / (r * r + xz_norm) * v_mean2d[1] + ); + + // df/dx = d(J) / dx + // df/dy = d(J) / dy + // df/dz = d(J) / dz + mat3x2 v_J = v_cov2d * J * glm::transpose(cov3d) + + glm::transpose(v_cov2d) * J * cov3d; + + T dJ_dx00 = width / (2.f * M_PI) * (z * z - x * x) / ((x * x + z * z) * (x * x + z * z)); + T dJ_dx01 = 0.f; + T dJ_dx02 = -width / (2.f * M_PI) * (2 * x * z) / ((x * x + z * z) * (x * x + z * z)); + T dJ_dx10 = height / M_PI * (y * (r * r + xz_norm) - x * y * xz_norm) / (r * r * (r * r + xz_norm) * (r * r + xz_norm)); + T dJ_dx11 = height / M_PI * (x * y * xz_norm) / (r * r * (r * r + xz_norm)); + T dJ_dx12 = -height / M_PI * (y * (r * r + xz_norm) - x * y * xz_norm) / (r * r * (r * r + xz_norm) * (r * r + xz_norm)); + + T dJ_dy00 = 0.f; + T dJ_dy01 = height / M_PI * (x * y * xz_norm) / (r * r * (r * r + xz_norm)); + T dJ_dy02 = -height / M_PI * (z * y * xz_norm) / (r * r * (r * r + xz_norm)); + T dJ_dy10 = height / M_PI * (x * y * xz_norm) / (r * r * (r * r + xz_norm)); + T dJ_dy11 = height / M_PI * (y * y * xz_norm) / (r * r * (r * r + xz_norm)); + T dJ_dy12 = -height / M_PI * (z * y * xz_norm) / (r * r * (r * r + xz_norm)); + + T dJ_dz00 = -width / (2.f * M_PI) * (2 * x * z) / ((x * x + z * z) * (x * x + z * z)); + T dJ_dz01 = 0.f; + T dJ_dz02 = width / (2.f * M_PI) * (x * x - z * z) / ((x * x + z * z) * (x * x + z * z)); + T dJ_dz10 = -height / M_PI * (z * y * xz_norm) / (r * r * (r * r + xz_norm)); + T dJ_dz11 = height / M_PI * (x * y * xz_norm) / (r * r * (r * r + xz_norm)); + T dJ_dz12 = -height / M_PI * (z * y * xz_norm) / (r * r * (r * r + xz_norm)); + + T dL_dtx_raw = dJ_dx00 * v_J[0][0] + dJ_dx01 * v_J[1][0] + + dJ_dx02 * v_J[2][0] + 
dJ_dx10 * v_J[0][1] + + dJ_dx11 * v_J[1][1] + dJ_dx12 * v_J[2][1]; + T dL_dty_raw = dJ_dy00 * v_J[0][0] + dJ_dy01 * v_J[1][0] + + dJ_dy02 * v_J[2][0] + dJ_dy10 * v_J[0][1] + + dJ_dy11 * v_J[1][1] + dJ_dy12 * v_J[2][1]; + T dL_dtz_raw = dJ_dz00 * v_J[0][0] + dJ_dz01 * v_J[1][0] + + dJ_dz02 * v_J[2][0] + dJ_dz10 * v_J[0][1] + + dJ_dz11 * v_J[1][1] + dJ_dz12 * v_J[2][1]; + v_mean3d.x += dL_dtx_raw; + v_mean3d.y += dL_dty_raw; + v_mean3d.z += dL_dtz_raw; +} + template inline __device__ void pos_world_to_cam( // [R, t] is the world-to-camera transformation diff --git a/gsplat/rendering.py b/gsplat/rendering.py index 78da64abf..81d105977 100644 --- a/gsplat/rendering.py +++ b/gsplat/rendering.py @@ -49,7 +49,7 @@ def rasterization( rasterize_mode: Literal["classic", "antialiased"] = "classic", channel_chunk: int = 32, distributed: bool = False, - camera_model: Literal["pinhole", "ortho", "fisheye"] = "pinhole", + camera_model: Literal["pinhole", "ortho", "fisheye", "spherical"] = "pinhole", covars: Optional[Tensor] = None, ) -> Tuple[Tensor, Tensor, Dict]: """Rasterize a set of 3D Gaussians (N) to a batch of image planes (C). From b59b3b0ee77f4309ab759da6c5125bcc8b92ca11 Mon Sep 17 00:00:00 2001 From: inuex35 Date: Sat, 5 Oct 2024 18:34:33 +0900 Subject: [PATCH 02/14] Training started, not working well yet --- examples/datasets/opensfm.py | 51 +++++++++++++++++++++++++++--- examples/simple_trainer_opensfm.py | 20 +++++++----- 2 files changed, 59 insertions(+), 12 deletions(-) diff --git a/examples/datasets/opensfm.py b/examples/datasets/opensfm.py index b67e5efea..650ab450b 100644 --- a/examples/datasets/opensfm.py +++ b/examples/datasets/opensfm.py @@ -126,9 +126,10 @@ def __init__( self._parse_reconstructions(reconstructions) def _parse_reconstructions(self, reconstructions: List[Dict]): - """Parse reconstructions data to extract camera information and extrinsics.""" + """Parse reconstructions data to extract camera information, extrinsics, and 3D points.""" self.cameras, self.images = read_opensfm(reconstructions) - + self.points3D, self.colors, self.errors = read_opensfm_points3D(reconstructions) + # Extract extrinsic matrices in world-to-camera format. 
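+        # Note: OpenSfM shots store the pose in world-to-camera form (X_cam = R @ X_world + t),
+        # with the rotation given as angle-axis and converted to qvec above, which is why the
+        # 4x4 w2c matrices are stacked directly here and only inverted further down to obtain
+        # camera-to-world matrices.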
w2c_mats = [] camera_ids = [] @@ -180,6 +181,9 @@ def _parse_reconstructions(self, reconstructions: List[Dict]): self.params_dict = params_dict # Dict of camera_id -> params self.imsize_dict = imsize_dict # Dict of camera_id -> (width, height) self.transform = transform # np.ndarray, (4, 4) + self.points = self.points3D # np.ndarray, (num_points, 3) + self.points_rgb = self.colors # np.ndarray, (num_points, 3) + self.points_err = self.errors # np.ndarray, (num_points, 1) class Dataset: @@ -300,5 +304,44 @@ def read_opensfm(reconstructions): point3D_ids = np.array([0, 0]) images[image_id] = Image(id=image_id, qvec=qvec, tvec=tvec, camera_id=camera_id, name=image_name, xys=xys, point3D_ids=point3D_ids, diff_ref=diff_ref) i += 1 - - return cameras, images \ No newline at end of file + print("Number of cameras: ", len(cameras)) + print("Number of images: ", len(images)) + return cameras, images + +def read_opensfm_points3D(reconstructions): + xyzs = None + rgbs = None + errors = None + num_points = 0 + for reconstruction in reconstructions: + num_points = num_points + len(reconstruction["points"]) + + xyzs = np.empty((num_points, 3)) + rgbs = np.empty((num_points, 3)) + errors = np.empty((num_points, 1)) + count = 0 + reference_lat_0 = reconstructions[0]["reference_lla"]["latitude"] + reference_lon_0 = reconstructions[0]["reference_lla"]["longitude"] + reference_alt_0 = reconstructions[0]["reference_lla"]["altitude"] + e2u_zone=int(divmod(reference_lon_0, 6)[0])+31 + e2u_conv=Proj(proj='utm', zone=e2u_zone, ellps='WGS84') + reference_x_0, reference_y_0 = e2u_conv(reference_lon_0, reference_lat_0) + if reference_lat_0<0: + reference_y=reference_y+10000000 + for reconstruction in reconstructions: + reference_lat = reconstruction["reference_lla"]["latitude"] + reference_lon = reconstruction["reference_lla"]["longitude"] + reference_alt = reconstruction["reference_lla"]["altitude"] + reference_x, reference_y = e2u_conv(reference_lon, reference_lat) + for i in (reconstruction["points"]): + color = (reconstruction["points"][i]["color"]) + coordinates = (reconstruction["points"][i]["coordinates"]) + xyz = np.array([coordinates[0] + reference_x - reference_x_0, coordinates[1] + reference_y - reference_y_0, coordinates[2] - reference_alt + reference_alt_0]) + rgb = np.array([color[0], color[1], color[2]]) + error = np.array(0) + xyzs[count] = xyz + rgbs[count] = rgb + errors[count] = error + count += 1 + print("Number of points: ", num_points) + return xyzs, rgbs, errors \ No newline at end of file diff --git a/examples/simple_trainer_opensfm.py b/examples/simple_trainer_opensfm.py index 3d2e20368..03432d569 100644 --- a/examples/simple_trainer_opensfm.py +++ b/examples/simple_trainer_opensfm.py @@ -5,6 +5,7 @@ from dataclasses import dataclass, field from collections import defaultdict from typing import Dict, List, Optional, Tuple, Union +import json import imageio import nerfview @@ -55,7 +56,7 @@ class Config: render_traj_path: str = "interp" # Path to the Mip-NeRF 360 dataset - data_dir: str = "sample/" + data_dir: str = "/home/ubuntu/360-gaussian-splatting/kungsgatanparken" # Downsample factor for the dataset data_factor: int = 4 # Directory to save results @@ -106,7 +107,7 @@ class Config: # Near plane clipping distance near_plane: float = 0.01 # Far plane clipping distance - far_plane: float = 1e10 + far_plane: float = 100 # Strategy for GS densification strategy: Union[DefaultStrategy, MCMCStrategy] = field( @@ -183,7 +184,12 @@ def adjust_steps(self, factor: float): else: 
assert_never(strategy) - +def load_reconstructions(data_dir): + reconstructions_file = os.path.join(data_dir, 'reconstruction.json') + with open(reconstructions_file, 'r') as f: + reconstructions = json.load(f) + return reconstructions + def create_splats_with_optimizers( parser: Parser, init_type: str = "sfm", @@ -300,10 +306,9 @@ def __init__( # Load data: Training data should contain initial points and colors. self.parser = Parser( - data_dir=cfg.data_dir, + reconstructions=load_reconstructions(cfg.data_dir), factor=cfg.data_factor, - normalize=cfg.normalize_world_space, - test_every=cfg.test_every, + normalize=cfg.normalize_world_space ) self.trainset = Dataset( self.parser, @@ -312,9 +317,8 @@ def __init__( load_depths=cfg.depth_loss, ) self.valset = Dataset(self.parser, split="val") - self.scene_scale = self.parser.scene_scale * 1.1 * cfg.global_scale + self.scene_scale = 1.0#self.parser.scene_scale * 1.1 * cfg.global_scale print("Scene scale:", self.scene_scale) - # Model feature_dim = 32 if cfg.app_opt else None self.splats, self.optimizers = create_splats_with_optimizers( From 17ab331c1ca30ef0b2ebe734397d8ed459c4837e Mon Sep 17 00:00:00 2001 From: inuex35 Date: Sat, 5 Oct 2024 20:06:30 +0900 Subject: [PATCH 03/14] Image loading fixed --- examples/datasets/opensfm.py | 52 +++++++++++++++++++++++------- examples/simple_trainer_opensfm.py | 9 +----- 2 files changed, 42 insertions(+), 19 deletions(-) diff --git a/examples/datasets/opensfm.py b/examples/datasets/opensfm.py index 650ab450b..aeebe2695 100644 --- a/examples/datasets/opensfm.py +++ b/examples/datasets/opensfm.py @@ -2,7 +2,9 @@ import numpy as np import collections import math +import json from pyproj import Proj +import imageio from typing import Dict, List, Any, Optional import torch from .normalize import ( @@ -115,14 +117,20 @@ class Parser: def __init__( self, - reconstructions: List[Dict], + data_dir: str, factor: int = 1, normalize: bool = False, + test_every: int = 8, ): + self.data_dir = data_dir self.factor = factor self.normalize = normalize - + self.test_every = test_every + # Extract data from reconstructions. 
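+        # reconstruction.json (the OpenSfM output loaded below) is expected to be a list of
+        # reconstructions; each entry provides "cameras" (intrinsics per camera model),
+        # "shots" (per-image rotation/translation and camera name), "points" (3D points
+        # with colors) and "reference_lla" (the lat/lon/alt origin of the local frame).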
+ + reconstructions = self.load_reconstructions(data_dir) + self._parse_reconstructions(reconstructions) def _parse_reconstructions(self, reconstructions: List[Dict]): @@ -185,6 +193,11 @@ def _parse_reconstructions(self, reconstructions: List[Dict]): self.points_rgb = self.colors # np.ndarray, (num_points, 3) self.points_err = self.errors # np.ndarray, (num_points, 1) + def load_reconstructions(self, data_dir): + reconstructions_file = os.path.join(data_dir, 'reconstruction.json') + with open(reconstructions_file, 'r') as f: + reconstructions = json.load(f) + return reconstructions class Dataset: """A simple dataset class for OpensfmLoaderParser.""" @@ -200,11 +213,11 @@ def __init__( self.split = split self.patch_size = patch_size self.load_depths = load_depths - indices = np.arange(len(self.parser.images)) + indices = np.arange(len(self.parser.images)) # Use images from parser if split == "train": - self.indices = indices[indices % 8 != 0] + self.indices = indices[indices % self.parser.test_every != 0] else: - self.indices = indices[indices % 8 == 0] + self.indices = indices[indices % self.parser.test_every == 0] def __len__(self): return len(self.indices) @@ -217,8 +230,28 @@ def __getitem__(self, item: int) -> Dict[str, Any]: camtoworld = self.parser.camtoworlds[index] # Load image (dummy implementation, replace with actual image loading logic if available) - width, height = self.parser.imsize_dict[camera_id] - image = np.zeros((height, width, 3), dtype=np.uint8) # Placeholder for actual image loading + image_path = os.path.join(self.parser.data_dir, "images/" + img.name) # Update with actual image path + image = imageio.imread(image_path)[..., :3] + + # Undistort if necessary + params = self.parser.params_dict[camera_id] + if len(params) > 0: + mapx, mapy = ( + self.parser.mapx_dict[camera_id], + self.parser.mapy_dict[camera_id], + ) + image = cv2.remap(image, mapx, mapy, cv2.INTER_LINEAR) + x, y, w, h = self.parser.roi_undist_dict[camera_id] + image = image[y : y + h, x : x + w] + + if self.patch_size is not None: + # Random crop + h, w = image.shape[:2] + x = np.random.randint(0, max(w - self.patch_size, 1)) + y = np.random.randint(0, max(h - self.patch_size, 1)) + image = image[y : y + self.patch_size, x : x + self.patch_size] + K[0, 2] -= x + K[1, 2] -= y data = { "K": torch.from_numpy(K).float(), @@ -229,7 +262,7 @@ def __getitem__(self, item: int) -> Dict[str, Any]: if self.load_depths: # Load depth data (dummy implementation, replace with actual depth loading logic if available) - depths = np.zeros((height, width), dtype=np.float32) # Placeholder for actual depth loading + depths = np.zeros((image.shape[0], image.shape[1]), dtype=np.float32) # Placeholder for actual depth loading data["depths"] = torch.from_numpy(depths).float() return data @@ -304,8 +337,6 @@ def read_opensfm(reconstructions): point3D_ids = np.array([0, 0]) images[image_id] = Image(id=image_id, qvec=qvec, tvec=tvec, camera_id=camera_id, name=image_name, xys=xys, point3D_ids=point3D_ids, diff_ref=diff_ref) i += 1 - print("Number of cameras: ", len(cameras)) - print("Number of images: ", len(images)) return cameras, images def read_opensfm_points3D(reconstructions): @@ -343,5 +374,4 @@ def read_opensfm_points3D(reconstructions): rgbs[count] = rgb errors[count] = error count += 1 - print("Number of points: ", num_points) return xyzs, rgbs, errors \ No newline at end of file diff --git a/examples/simple_trainer_opensfm.py b/examples/simple_trainer_opensfm.py index 03432d569..01310ba99 100644 --- 
a/examples/simple_trainer_opensfm.py +++ b/examples/simple_trainer_opensfm.py @@ -5,7 +5,6 @@ from dataclasses import dataclass, field from collections import defaultdict from typing import Dict, List, Optional, Tuple, Union -import json import imageio import nerfview @@ -184,12 +183,6 @@ def adjust_steps(self, factor: float): else: assert_never(strategy) -def load_reconstructions(data_dir): - reconstructions_file = os.path.join(data_dir, 'reconstruction.json') - with open(reconstructions_file, 'r') as f: - reconstructions = json.load(f) - return reconstructions - def create_splats_with_optimizers( parser: Parser, init_type: str = "sfm", @@ -306,7 +299,7 @@ def __init__( # Load data: Training data should contain initial points and colors. self.parser = Parser( - reconstructions=load_reconstructions(cfg.data_dir), + data_dir=cfg.data_dir, factor=cfg.data_factor, normalize=cfg.normalize_world_space ) From f912806d125007172681bd805632fba518a52a51 Mon Sep 17 00:00:00 2001 From: inuex35 Date: Sun, 6 Oct 2024 20:50:30 +0900 Subject: [PATCH 04/14] fixed opensfm.py --- examples/datasets/opensfm.py | 134 ++++++++++++++++++++++++++++++----- 1 file changed, 118 insertions(+), 16 deletions(-) diff --git a/examples/datasets/opensfm.py b/examples/datasets/opensfm.py index aeebe2695..7af334ba7 100644 --- a/examples/datasets/opensfm.py +++ b/examples/datasets/opensfm.py @@ -1,9 +1,13 @@ import os import numpy as np import collections +from typing_extensions import assert_never + import math import json from pyproj import Proj + +import cv2 import imageio from typing import Dict, List, Any, Optional import torch @@ -137,13 +141,18 @@ def _parse_reconstructions(self, reconstructions: List[Dict]): """Parse reconstructions data to extract camera information, extrinsics, and 3D points.""" self.cameras, self.images = read_opensfm(reconstructions) self.points3D, self.colors, self.errors = read_opensfm_points3D(reconstructions) - + self.points3D = self.points3D.astype(np.float32) + self.colors = self.colors.astype(np.uint8) + self.errors = self.errors.astype(np.float32) + # Extract extrinsic matrices in world-to-camera format. 
w2c_mats = [] camera_ids = [] Ks_dict = dict() params_dict = dict() imsize_dict = dict() + mask_dict = dict() + point_indices = {img.name: img.point3D_ids for img in self.images.values()} bottom = np.array([0, 0, 0, 1]).reshape(1, 4) for img in self.images.values(): @@ -153,18 +162,36 @@ def _parse_reconstructions(self, reconstructions: List[Dict]): w2c = np.concatenate([np.concatenate([rot, trans], 1), bottom], axis=0) w2c_mats.append(w2c) + # support different camera intrinsics + camera_id = img.camera_id + camera_ids.append(camera_id) + # Camera intrinsics cam = self.cameras[img.camera_id] - fx, fy, cx, cy = cam.params[0], cam.params[0], cam.params[1], cam.params[2] - K = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]]) - K[:2, :] /= self.factor - Ks_dict[img.camera_id] = K - - # Distortion parameters - params_dict[img.camera_id] = cam.params[3:] - imsize_dict[img.camera_id] = (cam.width // self.factor, cam.height // self.factor) - camera_ids.append(img.camera_id) - + type_ = cam.model + if type_ == 0 or type_ == "SIMPLE_PINHOLE": + params = np.empty(0, dtype=np.float32) + camtype = "perspective" + fx, fy, cx, cy = cam.params[0], cam.params[0], cam.params[1], cam.params[2] + K = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]]) + K[:2, :] /= self.factor + Ks_dict[img.camera_id] = K + + # Distortion parameters + params_dict[img.camera_id] = np.append(cam.params[3:5], np.array([0, 0])) + imsize_dict[img.camera_id] = (cam.width // self.factor, cam.height // self.factor) + mask_dict[camera_id] = None + camera_ids.append(img.camera_id) + elif type_ == 5 or type_ == "SPHERICAL": + params = np.empty(0, dtype=np.float32) + camtype = "spherical" + K = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]]) + Ks_dict[img.camera_id] = K + params_dict[img.camera_id] = params + imsize_dict[img.camera_id] = (cam.width, cam.height) + mask_dict[camera_id] = None + camera_ids.append(img.camera_id) + w2c_mats = np.stack(w2c_mats, axis=0) # Convert extrinsics to camera-to-world. @@ -183,15 +210,93 @@ def _parse_reconstructions(self, reconstructions: List[Dict]): transform = np.eye(4) # Set instance variables. 
+ self.camtoworlds = camtoworlds # np.ndarray, (num_images, 4, 4) self.camera_ids = camera_ids # List[int], (num_images,) self.Ks_dict = Ks_dict # Dict of camera_id -> K self.params_dict = params_dict # Dict of camera_id -> params self.imsize_dict = imsize_dict # Dict of camera_id -> (width, height) - self.transform = transform # np.ndarray, (4, 4) + self.mask_dict = mask_dict # Dict of camera_id -> mask self.points = self.points3D # np.ndarray, (num_points, 3) self.points_rgb = self.colors # np.ndarray, (num_points, 3) self.points_err = self.errors # np.ndarray, (num_points, 1) + self.point_indices = point_indices # Dict[str, np.ndarray], image_name -> [M,] + self.transform = transform # np.ndarray, (4, 4) + + # undistortion + self.mapx_dict = dict() + self.mapy_dict = dict() + self.roi_undist_dict = dict() + for camera_id in self.params_dict.keys(): + params = self.params_dict[camera_id] + if len(params) == 0: + continue # no distortion + assert camera_id in self.Ks_dict, f"Missing K for camera {camera_id}" + assert ( + camera_id in self.params_dict + ), f"Missing params for camera {camera_id}" + K = self.Ks_dict[camera_id] + width, height = self.imsize_dict[camera_id] + + if camtype == "perspective": + K_undist, roi_undist = cv2.getOptimalNewCameraMatrix( + K, params, (width, height), 0 + ) + mapx, mapy = cv2.initUndistortRectifyMap( + K, params, None, K_undist, (width, height), cv2.CV_32FC1 + ) + mask = None + elif camtype == "fisheye": + fx = K[0, 0] + fy = K[1, 1] + cx = K[0, 2] + cy = K[1, 2] + grid_x, grid_y = np.meshgrid( + np.arange(width, dtype=np.float32), + np.arange(height, dtype=np.float32), + indexing="xy", + ) + x1 = (grid_x - cx) / fx + y1 = (grid_y - cy) / fy + theta = np.sqrt(x1**2 + y1**2) + r = ( + 1.0 + + params[0] * theta**2 + + params[1] * theta**4 + + params[2] * theta**6 + + params[3] * theta**8 + ) + mapx = fx * x1 * r + width // 2 + mapy = fy * y1 * r + height // 2 + + # Use mask to define ROI + mask = np.logical_and( + np.logical_and(mapx > 0, mapy > 0), + np.logical_and(mapx < width - 1, mapy < height - 1), + ) + y_indices, x_indices = np.nonzero(mask) + y_min, y_max = y_indices.min(), y_indices.max() + 1 + x_min, x_max = x_indices.min(), x_indices.max() + 1 + mask = mask[y_min:y_max, x_min:x_max] + K_undist = K.copy() + K_undist[0, 2] -= x_min + K_undist[1, 2] -= y_min + roi_undist = [x_min, y_min, x_max - x_min, y_max - y_min] + else: + assert_never(camtype) + + self.mapx_dict[camera_id] = mapx + self.mapy_dict[camera_id] = mapy + self.Ks_dict[camera_id] = K_undist + self.roi_undist_dict[camera_id] = roi_undist + self.imsize_dict[camera_id] = (roi_undist[2], roi_undist[3]) + self.mask_dict[camera_id] = mask + + # size of the scene measured by cameras + camera_locations = camtoworlds[:, :3, 3] + scene_center = np.mean(camera_locations, axis=0) + dists = np.linalg.norm(camera_locations - scene_center, axis=1) + self.scene_scale = np.max(dists) def load_reconstructions(self, data_dir): reconstructions_file = os.path.join(data_dir, 'reconstruction.json') @@ -229,7 +334,6 @@ def __getitem__(self, item: int) -> Dict[str, Any]: K = self.parser.Ks_dict[camera_id].copy() # undistorted K camtoworld = self.parser.camtoworlds[index] - # Load image (dummy implementation, replace with actual image loading logic if available) image_path = os.path.join(self.parser.data_dir, "images/" + img.name) # Update with actual image path image = imageio.imread(image_path)[..., :3] @@ -267,7 +371,6 @@ def __getitem__(self, item: int) -> Dict[str, Any]: return data - def 
read_opensfm(reconstructions): """Extracts camera and image information from OpenSfM reconstructions.""" images = {} @@ -306,13 +409,12 @@ def read_opensfm(reconstructions): f = reconstruction["cameras"][camera]["focal"] * width k1 = reconstruction["cameras"][camera]["k1"] k2 = reconstruction["cameras"][camera]["k2"] - params = np.array([f, width / 2, width / 2, k1, k2]) + params = np.array([f, width / 2, height / 2, k1, k2]) camera_id = cam_id cameras[camera_id] = Camera(id=camera_id, model=model, width=width, height=height, params=params, panorama=False) camera_names[camera_name] = camera_id cam_id += 1 - # Parse images. reference_lat = reconstruction["reference_lla"]["latitude"] reference_lon = reconstruction["reference_lla"]["longitude"] reference_alt = reconstruction["reference_lla"]["altitude"] From 0716996a4a690017193784a0ce41765210126a61 Mon Sep 17 00:00:00 2001 From: inuex35 Date: Sun, 6 Oct 2024 21:17:25 +0900 Subject: [PATCH 05/14] fixed normalize --- examples/datasets/opensfm.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/examples/datasets/opensfm.py b/examples/datasets/opensfm.py index 7af334ba7..8820b8899 100644 --- a/examples/datasets/opensfm.py +++ b/examples/datasets/opensfm.py @@ -141,7 +141,7 @@ def _parse_reconstructions(self, reconstructions: List[Dict]): """Parse reconstructions data to extract camera information, extrinsics, and 3D points.""" self.cameras, self.images = read_opensfm(reconstructions) self.points3D, self.colors, self.errors = read_opensfm_points3D(reconstructions) - self.points3D = self.points3D.astype(np.float32) + points = self.points3D.astype(np.float32) self.colors = self.colors.astype(np.uint8) self.errors = self.errors.astype(np.float32) @@ -201,10 +201,12 @@ def _parse_reconstructions(self, reconstructions: List[Dict]): if self.normalize: T1 = similarity_from_cameras(camtoworlds) camtoworlds = transform_cameras(T1, camtoworlds) + points = transform_points(T1, points) - points = np.array([img.diff_ref for img in self.images.values()]) T2 = align_principle_axes(points) camtoworlds = transform_cameras(T2, camtoworlds) + points = transform_points(T2, points) + transform = T2 @ T1 else: transform = np.eye(4) @@ -217,7 +219,7 @@ def _parse_reconstructions(self, reconstructions: List[Dict]): self.params_dict = params_dict # Dict of camera_id -> params self.imsize_dict = imsize_dict # Dict of camera_id -> (width, height) self.mask_dict = mask_dict # Dict of camera_id -> mask - self.points = self.points3D # np.ndarray, (num_points, 3) + self.points = points # np.ndarray, (num_points, 3) self.points_rgb = self.colors # np.ndarray, (num_points, 3) self.points_err = self.errors # np.ndarray, (num_points, 1) self.point_indices = point_indices # Dict[str, np.ndarray], image_name -> [M,] From dfbc86a0452f9b1da0bc998d35d01aba9bb1efd4 Mon Sep 17 00:00:00 2001 From: inuex35 Date: Mon, 7 Oct 2024 12:42:39 +0900 Subject: [PATCH 06/14] For image data_factor --- examples/datasets/opensfm.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/examples/datasets/opensfm.py b/examples/datasets/opensfm.py index 8820b8899..c4a03674a 100644 --- a/examples/datasets/opensfm.py +++ b/examples/datasets/opensfm.py @@ -339,6 +339,11 @@ def __getitem__(self, item: int) -> Dict[str, Any]: image_path = os.path.join(self.parser.data_dir, "images/" + img.name) # Update with actual image path image = imageio.imread(image_path)[..., :3] + # Resize image according to the factor + if self.parser.factor > 1: + new_size = (image.shape[1] // 
self.parser.factor, image.shape[0] // self.parser.factor) + image = cv2.resize(image, new_size, interpolation=cv2.INTER_AREA) + # Undistort if necessary params = self.parser.params_dict[camera_id] if len(params) > 0: From 9c9efec91340debd9cf35409d1f67c3b22ae8493 Mon Sep 17 00:00:00 2001 From: inuex35 Date: Tue, 8 Oct 2024 00:46:53 +0900 Subject: [PATCH 07/14] Fixed mean2d --- gsplat/cuda/_torch_impl.py | 13 +++++++------ gsplat/cuda/csrc/utils.cuh | 20 ++++++++++---------- 2 files changed, 17 insertions(+), 16 deletions(-) diff --git a/gsplat/cuda/_torch_impl.py b/gsplat/cuda/_torch_impl.py index 25d229524..7f68648a2 100644 --- a/gsplat/cuda/_torch_impl.py +++ b/gsplat/cuda/_torch_impl.py @@ -221,6 +221,7 @@ def _ortho_proj( ) # [C, N, 2] return means2d, cov2d # [C, N, 2], [C, N, 2, 2] + def _spherical_proj( means: Tensor, # [C, N, 3] covars: Tensor, # [C, N, 3, 3] @@ -254,17 +255,17 @@ def _spherical_proj( normalized_latitude = latitude / (torch.pi / 2.0) normalized_longitude = longitude / torch.pi - means2d = torch.stack([normalized_longitude, normalized_latitude], dim=-1) + means2d = torch.stack([(normalized_longitude * width + width / 2) / tr, (normalized_latitude * height + height / 2) / tr], dim=-1) O = torch.zeros((C, N), device=means.device, dtype=means.dtype) J = torch.stack( [ - width / (2 * torch.pi) * tz / (tx**2 + tz**2), + width / tr / (2 * torch.pi) * tz / (tx**2 + tz**2), O, - -width / (2 * torch.pi) * tx / (tx**2 + tz**2), - -height / torch.pi * (tx * ty) / (tr**2 * torch.sqrt(tx**2 + tz**2)), - height / torch.pi * torch.sqrt(tx**2 + tz**2) / tr**2, - -height / torch.pi * (tz * ty) / (tr**2 * torch.sqrt(tx**2 + tz**2)), + -width / tr / (2 * torch.pi) * tx / (tx**2 + tz**2), + -height / tr / torch.pi * (tx * ty) / (tr**2 * torch.sqrt(tx**2 + tz**2)), + height / tr / torch.pi * torch.sqrt(tx**2 + tz**2) / tr**2, + -height / tr / torch.pi * (tz * ty) / (tr**2 * torch.sqrt(tx**2 + tz**2)), ], dim=-1, ).reshape(C, N, 2, 3) diff --git a/gsplat/cuda/csrc/utils.cuh b/gsplat/cuda/csrc/utils.cuh index 50b4cbf45..044a4709d 100644 --- a/gsplat/cuda/csrc/utils.cuh +++ b/gsplat/cuda/csrc/utils.cuh @@ -541,16 +541,16 @@ inline __device__ void spherical_proj( T normalized_latitude = latitude / (M_PI / 2.0); T normalized_longitude = longitude / M_PI; - mean2d = vec2(normalized_longitude, normalized_latitude); - + mean2d = vec2(normalized_longitude * width + width / 2, normalized_latitude * height + height / 2); + // mat3x2 is 3 columns x 2 rows. mat3x2 J = mat3x2( - width / (2.f * M_PI) * z / (x * x + y * y), // 1st column + width / (2.f * M_PI) * z / (x * x + z * z), + -height / M_PI * (x * y) / (r * r * sqrt(x * x + z * z)), // 1st column 0.f, + height / M_PI * sqrt(x * x + z * z) / (r * r), // 2st column -width / (2.f * M_PI) * x / (x * x + z * z), - -height / M_PI * (x * y) / (r * r + sqrt(x * x + z * z)), // 2nd column - height / M_PI * (x * y) * sqrt(x * x + z * z) / (r * r), - -height / M_PI * (z * y) / (r * r + sqrt(x * x + z * z)) + -height / M_PI * (z * y) / (r * r + sqrt(x * x + z * z)) // 1st column ); cov2d = J * cov3d * glm::transpose(J); @@ -589,12 +589,12 @@ inline __device__ void spherical_proj_vjp( // mat3x2 is 3 columns x 2 rows. 
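+    // Chain-rule sketch: the forward pass computes cov2d = J * cov3d * J^T, so the
+    // covariance gradient pulls back as v_cov3d += J^T * v_cov2d * J (see below);
+    // the mean gradient additionally picks up terms from the dependence of J itself
+    // on the mean, handled by the dJ_d* terms further down.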
mat3x2 J = mat3x2( - width / (2.f * M_PI) * z / (x * x + y * y), // 1st column + width / (2.f * M_PI) * z / (x * x + z * z), + -height / M_PI * (x * y) / (r * r * sqrt(x * x + z * z)), // 1st column 0.f, + height / M_PI * sqrt(x * x + z * z) / (r * r), // 2st column -width / (2.f * M_PI) * x / (x * x + z * z), - -height / M_PI * (x * y) / (r * r + sqrt(x * x + z * z)), // 2nd column - height / M_PI * (x * y) * sqrt(x * x + z * z) / (r * r), - -height / M_PI * (z * y) / (r * r + sqrt(x * x + z * z)) + -height / M_PI * (z * y) / (r * r + sqrt(x * x + z * z)) // 1st column ); v_cov3d += glm::transpose(J) * v_cov2d * J; From eb24ed49f451140de8c94bd1aa4f8681f755b1b4 Mon Sep 17 00:00:00 2001 From: inuex35 Date: Thu, 10 Oct 2024 00:00:46 +0900 Subject: [PATCH 08/14] Data loading looks okay and need to check forward first. --- examples/datasets/opensfm.py | 29 ++++---- examples/simple_trainer_opensfm.py | 10 +-- gsplat/cuda/_torch_impl.py | 10 +-- .../cuda/csrc/fully_fused_projection_fwd.cu | 16 +++-- .../csrc/fully_fused_projection_packed_fwd.cu | 12 +++- gsplat/cuda/csrc/utils.cuh | 66 ++++--------------- 6 files changed, 60 insertions(+), 83 deletions(-) diff --git a/examples/datasets/opensfm.py b/examples/datasets/opensfm.py index c4a03674a..c458cfb20 100644 --- a/examples/datasets/opensfm.py +++ b/examples/datasets/opensfm.py @@ -147,6 +147,8 @@ def _parse_reconstructions(self, reconstructions: List[Dict]): # Extract extrinsic matrices in world-to-camera format. w2c_mats = [] + image_names = [] + image_paths = [] camera_ids = [] Ks_dict = dict() params_dict = dict() @@ -161,6 +163,8 @@ def _parse_reconstructions(self, reconstructions: List[Dict]): trans = img.tvec.reshape(3, 1) w2c = np.concatenate([np.concatenate([rot, trans], 1), bottom], axis=0) w2c_mats.append(w2c) + image_names.append(img.name) + image_paths.append(os.path.join(self.data_dir + "/images/", img.name)) # support different camera intrinsics camera_id = img.camera_id @@ -176,24 +180,20 @@ def _parse_reconstructions(self, reconstructions: List[Dict]): K = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]]) K[:2, :] /= self.factor Ks_dict[img.camera_id] = K - - # Distortion parameters params_dict[img.camera_id] = np.append(cam.params[3:5], np.array([0, 0])) imsize_dict[img.camera_id] = (cam.width // self.factor, cam.height // self.factor) mask_dict[camera_id] = None - camera_ids.append(img.camera_id) elif type_ == 5 or type_ == "SPHERICAL": params = np.empty(0, dtype=np.float32) camtype = "spherical" - K = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]]) + K = np.array([[cam.width // 8, 0, 0], [0, cam.height // 4, 0], [cam.width // 2, cam.height // 2, 1]]) Ks_dict[img.camera_id] = K params_dict[img.camera_id] = params - imsize_dict[img.camera_id] = (cam.width, cam.height) + imsize_dict[img.camera_id] = (cam.width // self.factor, cam.height // self.factor) mask_dict[camera_id] = None - camera_ids.append(img.camera_id) w2c_mats = np.stack(w2c_mats, axis=0) - + # Convert extrinsics to camera-to-world. camtoworlds = np.linalg.inv(w2c_mats) @@ -214,6 +214,9 @@ def _parse_reconstructions(self, reconstructions: List[Dict]): # Set instance variables. 
self.camtoworlds = camtoworlds # np.ndarray, (num_images, 4, 4) + self.image_names = image_names # List[str], (num_images,) + + self.image_paths = image_paths # List[str], (num_images,) self.camera_ids = camera_ids # List[int], (num_images,) self.Ks_dict = Ks_dict # Dict of camera_id -> K self.params_dict = params_dict # Dict of camera_id -> params @@ -295,7 +298,7 @@ def _parse_reconstructions(self, reconstructions: List[Dict]): self.mask_dict[camera_id] = mask # size of the scene measured by cameras - camera_locations = camtoworlds[:, :3, 3] + camera_locations = np.array(camtoworlds)[:, :3, 3] scene_center = np.mean(camera_locations, axis=0) dists = np.linalg.norm(camera_locations - scene_center, axis=1) self.scene_scale = np.max(dists) @@ -332,12 +335,12 @@ def __len__(self): def __getitem__(self, item: int) -> Dict[str, Any]: index = self.indices[item] img = self.parser.images[index] + img_name = self.parser.images[index].name camera_id = img.camera_id K = self.parser.Ks_dict[camera_id].copy() # undistorted K camtoworld = self.parser.camtoworlds[index] - image_path = os.path.join(self.parser.data_dir, "images/" + img.name) # Update with actual image path - image = imageio.imread(image_path)[..., :3] + image = imageio.imread(self.parser.image_paths[index])[..., :3] # Resize image according to the factor if self.parser.factor > 1: @@ -405,8 +408,7 @@ def read_opensfm(reconstructions): model = "SPHERICAL" width = reconstruction["cameras"][camera]["width"] height = reconstruction["cameras"][camera]["height"] - f = width / 4 / 2 - params = np.array([f, width, height]) + params = np.array([0]) cameras[camera_id] = Camera(id=camera_id, model=model, width=width, height=height, params=params, panorama=True) camera_names[camera_name] = camera_id elif reconstruction["cameras"][camera]['projection_type'] == "perspective": @@ -427,8 +429,7 @@ def read_opensfm(reconstructions): reference_alt = reconstruction["reference_lla"]["altitude"] reference_x, reference_y = e2u_conv(reference_lon, reference_lat) if reference_lat < 0: - reference_y += 10000000 - + reference_y += 10000000 for shot in reconstruction["shots"]: translation = reconstruction["shots"][shot]["translation"] rotation = reconstruction["shots"][shot]["rotation"] diff --git a/examples/simple_trainer_opensfm.py b/examples/simple_trainer_opensfm.py index 01310ba99..b476cb955 100644 --- a/examples/simple_trainer_opensfm.py +++ b/examples/simple_trainer_opensfm.py @@ -55,7 +55,7 @@ class Config: render_traj_path: str = "interp" # Path to the Mip-NeRF 360 dataset - data_dir: str = "/home/ubuntu/360-gaussian-splatting/kungsgatanparken" + data_dir: str = "/home/ubuntu/360-gaussian-splatting/sample" # Downsample factor for the dataset data_factor: int = 4 # Directory to save results @@ -67,7 +67,7 @@ class Config: # A global scaler that applies to the scene size related parameters global_scale: float = 1.0 # Normalize the world space - normalize_world_space: bool = True + normalize_world_space: bool = False # Camera model camera_model: Literal["pinhole", "ortho", "fisheye", "spherical"] = "spherical" @@ -106,7 +106,7 @@ class Config: # Near plane clipping distance near_plane: float = 0.01 # Far plane clipping distance - far_plane: float = 100 + far_plane: float = 1e8 # Strategy for GS densification strategy: Union[DefaultStrategy, MCMCStrategy] = field( @@ -310,7 +310,7 @@ def __init__( load_depths=cfg.depth_loss, ) self.valset = Dataset(self.parser, split="val") - self.scene_scale = 1.0#self.parser.scene_scale * 1.1 * cfg.global_scale + 
self.scene_scale = self.parser.scene_scale * 1.1 * cfg.global_scale print("Scene scale:", self.scene_scale) # Model feature_dim = 32 if cfg.app_opt else None @@ -683,7 +683,7 @@ def train(self): # canvas = torch.cat([pixels, colors], dim=2).detach().cpu().numpy() # canvas = canvas.reshape(-1, *canvas.shape[2:]) # imageio.imwrite( - # f"{self.render_dir}/train_rank{self.world_rank}.png", + # f"{self.render_dir}/train_rank{self.world_rank}_{step}.png", # (canvas * 255).astype(np.uint8), # ) diff --git a/gsplat/cuda/_torch_impl.py b/gsplat/cuda/_torch_impl.py index 7f68648a2..bd32133c2 100644 --- a/gsplat/cuda/_torch_impl.py +++ b/gsplat/cuda/_torch_impl.py @@ -260,12 +260,12 @@ def _spherical_proj( O = torch.zeros((C, N), device=means.device, dtype=means.dtype) J = torch.stack( [ - width / tr / (2 * torch.pi) * tz / (tx**2 + tz**2), + width / (2 * torch.pi) * tz / (tx**2 + tz**2), O, - -width / tr / (2 * torch.pi) * tx / (tx**2 + tz**2), - -height / tr / torch.pi * (tx * ty) / (tr**2 * torch.sqrt(tx**2 + tz**2)), - height / tr / torch.pi * torch.sqrt(tx**2 + tz**2) / tr**2, - -height / tr / torch.pi * (tz * ty) / (tr**2 * torch.sqrt(tx**2 + tz**2)), + -width / (2 * torch.pi) * tx / (tx**2 + tz**2), + -height / torch.pi * (tx * ty) / (tr**2 * torch.sqrt(tx**2 + tz**2)), + height / torch.pi * torch.sqrt(tx**2 + tz**2) / tr**2, + -height / torch.pi * (tz * ty) / (tr**2 * torch.sqrt(tx**2 + tz**2)), ], dim=-1, ).reshape(C, N, 2, 3) diff --git a/gsplat/cuda/csrc/fully_fused_projection_fwd.cu b/gsplat/cuda/csrc/fully_fused_projection_fwd.cu index 85502a9f2..36b502243 100644 --- a/gsplat/cuda/csrc/fully_fused_projection_fwd.cu +++ b/gsplat/cuda/csrc/fully_fused_projection_fwd.cu @@ -70,11 +70,19 @@ __global__ void fully_fused_projection_fwd_kernel( // transform Gaussian center to camera space vec3 mean_c; pos_world_to_cam(R, t, glm::make_vec3(means), mean_c); - if (mean_c.z < near_plane || mean_c.z > far_plane) { - radii[idx] = 0; - return; + if(camera_model != CameraModelType::SPHERICAL){ + if (mean_c.z < near_plane || mean_c.z > far_plane) { + radii[idx] = 0; + return; + } + } + else{ + float r = sqrt(mean_c.x * mean_c.x + mean_c.y * mean_c.y + mean_c.z * mean_c.z); + if (r < near_plane || r > far_plane) { + radii[idx] = 0; + return; + } } - // transform Gaussian covariance to camera space mat3 covar; if (covars != nullptr) { diff --git a/gsplat/cuda/csrc/fully_fused_projection_packed_fwd.cu b/gsplat/cuda/csrc/fully_fused_projection_packed_fwd.cu index 05e7f5c5f..8e3a2f8aa 100644 --- a/gsplat/cuda/csrc/fully_fused_projection_packed_fwd.cu +++ b/gsplat/cuda/csrc/fully_fused_projection_packed_fwd.cu @@ -80,8 +80,16 @@ __global__ void fully_fused_projection_packed_fwd_kernel( // transform Gaussian center to camera space pos_world_to_cam(R, t, glm::make_vec3(means), mean_c); - if (mean_c.z < near_plane || mean_c.z > far_plane) { - valid = false; + if(camera_model != CameraModelType::SPHERICAL){ + if (mean_c.z < near_plane || mean_c.z > far_plane) { + valid = false; + } + } + else{ + float r = sqrt(mean_c.x * mean_c.x + mean_c.y * mean_c.y + mean_c.z * mean_c.z); + if (r < near_plane || r > far_plane) { + valid = false; + } } } diff --git a/gsplat/cuda/csrc/utils.cuh b/gsplat/cuda/csrc/utils.cuh index 044a4709d..2f720b5f1 100644 --- a/gsplat/cuda/csrc/utils.cuh +++ b/gsplat/cuda/csrc/utils.cuh @@ -536,21 +536,21 @@ inline __device__ void spherical_proj( T r = sqrt(x * x + y * y + z * z); T longitude = atan2(x, z); - T latitude = atan2(y, sqrt(x * x + z * z)); + T latitude = asin(y / r); T 
normalized_latitude = latitude / (M_PI / 2.0); T normalized_longitude = longitude / M_PI; - mean2d = vec2(normalized_longitude * width + width / 2, normalized_latitude * height + height / 2); + mean2d = vec2((normalized_longitude + 1) * width / 2, (normalized_latitude + 1) * height / 2); // mat3x2 is 3 columns x 2 rows. mat3x2 J = mat3x2( - width / (2.f * M_PI) * z / (x * x + z * z), - -height / M_PI * (x * y) / (r * r * sqrt(x * x + z * z)), // 1st column + z / (x * x + z * z), + -(x * y) / (r * r * sqrt(x * x + z * z)), // 1st column 0.f, - height / M_PI * sqrt(x * x + z * z) / (r * r), // 2st column - -width / (2.f * M_PI) * x / (x * x + z * z), - -height / M_PI * (z * y) / (r * r + sqrt(x * x + z * z)) // 1st column + sqrt(x * x + z * z) / (r * r), // 2st column + - x / (x * x + z * z), + -(z * y) / (r * r + sqrt(x * x + z * z)) // 1st column ); cov2d = J * cov3d * glm::transpose(J); @@ -582,19 +582,19 @@ inline __device__ void spherical_proj_vjp( T xz_norm = sqrt(x * x + z * z + 1e-8f); T longitude = atan2(x, z); - T latitude = atan2(y, xz_norm); + T latitude = asin(y / r); T normalized_longitude = longitude / M_PI; T normalized_latitude = latitude / (M_PI / 2.0); // mat3x2 is 3 columns x 2 rows. mat3x2 J = mat3x2( - width / (2.f * M_PI) * z / (x * x + z * z), - -height / M_PI * (x * y) / (r * r * sqrt(x * x + z * z)), // 1st column + z / (x * x + z * z), + -(x * y) / (r * r * sqrt(x * x + z * z)), // 1st column 0.f, - height / M_PI * sqrt(x * x + z * z) / (r * r), // 2st column - -width / (2.f * M_PI) * x / (x * x + z * z), - -height / M_PI * (z * y) / (r * r + sqrt(x * x + z * z)) // 1st column + sqrt(x * x + z * z) / (r * r), // 2st column + - x / (x * x + z * z), + -(z * y) / (r * r + sqrt(x * x + z * z)) // 1st column ); v_cov3d += glm::transpose(J) * v_cov2d * J; @@ -611,46 +611,6 @@ inline __device__ void spherical_proj_vjp( -width / (2.f * M_PI) * x / (x * x + z * z) * v_mean2d[0] + -height / M_PI * (z * y) / (r * r + xz_norm) * v_mean2d[1] ); - - // df/dx = d(J) / dx - // df/dy = d(J) / dy - // df/dz = d(J) / dz - mat3x2 v_J = v_cov2d * J * glm::transpose(cov3d) + - glm::transpose(v_cov2d) * J * cov3d; - - T dJ_dx00 = width / (2.f * M_PI) * (z * z - x * x) / ((x * x + z * z) * (x * x + z * z)); - T dJ_dx01 = 0.f; - T dJ_dx02 = -width / (2.f * M_PI) * (2 * x * z) / ((x * x + z * z) * (x * x + z * z)); - T dJ_dx10 = height / M_PI * (y * (r * r + xz_norm) - x * y * xz_norm) / (r * r * (r * r + xz_norm) * (r * r + xz_norm)); - T dJ_dx11 = height / M_PI * (x * y * xz_norm) / (r * r * (r * r + xz_norm)); - T dJ_dx12 = -height / M_PI * (y * (r * r + xz_norm) - x * y * xz_norm) / (r * r * (r * r + xz_norm) * (r * r + xz_norm)); - - T dJ_dy00 = 0.f; - T dJ_dy01 = height / M_PI * (x * y * xz_norm) / (r * r * (r * r + xz_norm)); - T dJ_dy02 = -height / M_PI * (z * y * xz_norm) / (r * r * (r * r + xz_norm)); - T dJ_dy10 = height / M_PI * (x * y * xz_norm) / (r * r * (r * r + xz_norm)); - T dJ_dy11 = height / M_PI * (y * y * xz_norm) / (r * r * (r * r + xz_norm)); - T dJ_dy12 = -height / M_PI * (z * y * xz_norm) / (r * r * (r * r + xz_norm)); - - T dJ_dz00 = -width / (2.f * M_PI) * (2 * x * z) / ((x * x + z * z) * (x * x + z * z)); - T dJ_dz01 = 0.f; - T dJ_dz02 = width / (2.f * M_PI) * (x * x - z * z) / ((x * x + z * z) * (x * x + z * z)); - T dJ_dz10 = -height / M_PI * (z * y * xz_norm) / (r * r * (r * r + xz_norm)); - T dJ_dz11 = height / M_PI * (x * y * xz_norm) / (r * r * (r * r + xz_norm)); - T dJ_dz12 = -height / M_PI * (z * y * xz_norm) / (r * r * (r * r + xz_norm)); - - T 
dL_dtx_raw = dJ_dx00 * v_J[0][0] + dJ_dx01 * v_J[1][0] + - dJ_dx02 * v_J[2][0] + dJ_dx10 * v_J[0][1] + - dJ_dx11 * v_J[1][1] + dJ_dx12 * v_J[2][1]; - T dL_dty_raw = dJ_dy00 * v_J[0][0] + dJ_dy01 * v_J[1][0] + - dJ_dy02 * v_J[2][0] + dJ_dy10 * v_J[0][1] + - dJ_dy11 * v_J[1][1] + dJ_dy12 * v_J[2][1]; - T dL_dtz_raw = dJ_dz00 * v_J[0][0] + dJ_dz01 * v_J[1][0] + - dJ_dz02 * v_J[2][0] + dJ_dz10 * v_J[0][1] + - dJ_dz11 * v_J[1][1] + dJ_dz12 * v_J[2][1]; - v_mean3d.x += dL_dtx_raw; - v_mean3d.y += dL_dty_raw; - v_mean3d.z += dL_dtz_raw; } template From 0f02ab9fe96ecaa3946ebf57cebabd1e04b20294 Mon Sep 17 00:00:00 2001 From: inuex35 Date: Fri, 11 Oct 2024 07:07:23 +0900 Subject: [PATCH 09/14] Forward pass is working. --- gsplat/cuda/_torch_impl.py | 12 +++++------ .../cuda/csrc/fully_fused_projection_fwd.cu | 21 +++++++++++-------- .../csrc/fully_fused_projection_packed_fwd.cu | 8 ++++--- 3 files changed, 23 insertions(+), 18 deletions(-) diff --git a/gsplat/cuda/_torch_impl.py b/gsplat/cuda/_torch_impl.py index bd32133c2..cac623039 100644 --- a/gsplat/cuda/_torch_impl.py +++ b/gsplat/cuda/_torch_impl.py @@ -255,17 +255,17 @@ def _spherical_proj( normalized_latitude = latitude / (torch.pi / 2.0) normalized_longitude = longitude / torch.pi - means2d = torch.stack([(normalized_longitude * width + width / 2) / tr, (normalized_latitude * height + height / 2) / tr], dim=-1) + means2d = torch.stack([(normalized_longitude + 1) * width / 2, (normalized_latitude + 1) * height / 2], dim=-1) O = torch.zeros((C, N), device=means.device, dtype=means.dtype) J = torch.stack( [ - width / (2 * torch.pi) * tz / (tx**2 + tz**2), + tz / (tx**2 + tz**2), + -(tx * ty) / (tr**2 * torch.sqrt(tx**2 + tz**2)), O, - -width / (2 * torch.pi) * tx / (tx**2 + tz**2), - -height / torch.pi * (tx * ty) / (tr**2 * torch.sqrt(tx**2 + tz**2)), - height / torch.pi * torch.sqrt(tx**2 + tz**2) / tr**2, - -height / torch.pi * (tz * ty) / (tr**2 * torch.sqrt(tx**2 + tz**2)), + torch.sqrt(tx**2 + tz**2) / (tr**2), + -tx / (tx**2 + tz**2), + -(tz * ty) / (tr**2 * torch.sqrt(tx**2 + tz**2)), ], dim=-1, ).reshape(C, N, 2, 3) diff --git a/gsplat/cuda/csrc/fully_fused_projection_fwd.cu b/gsplat/cuda/csrc/fully_fused_projection_fwd.cu index 36b502243..7909dc829 100644 --- a/gsplat/cuda/csrc/fully_fused_projection_fwd.cu +++ b/gsplat/cuda/csrc/fully_fused_projection_fwd.cu @@ -174,10 +174,6 @@ __global__ void fully_fused_projection_fwd_kernel( T compensation; T det = add_blur(eps2d, covar2d, compensation); - if (det <= 0.f) { - radii[idx] = 0; - return; - } // compute the inverse of the 2d covariance mat2 covar2d_inv; @@ -196,17 +192,24 @@ __global__ void fully_fused_projection_fwd_kernel( } // mask out gaussians outside the image region - if (mean2d.x + radius <= 0 || mean2d.x - radius >= image_width || - mean2d.y + radius <= 0 || mean2d.y - radius >= image_height) { - radii[idx] = 0; - return; + if(camera_model != CameraModelType::SPHERICAL) + { + if (mean2d.x + radius <= 0 || mean2d.x - radius >= image_width || + mean2d.y + radius <= 0 || mean2d.y - radius >= image_height) { + radii[idx] = 0; + return; + } } // write to outputs radii[idx] = (int32_t)radius; means2d[idx * 2] = mean2d.x; means2d[idx * 2 + 1] = mean2d.y; - depths[idx] = mean_c.z; + if (camera_model == CameraModelType::SPHERICAL) { + depths[idx] = sqrt(mean_c.x * mean_c.x + mean_c.y * mean_c.y + mean_c.z * mean_c.z); + } else { + depths[idx] = mean_c.z; + } conics[idx * 3] = covar2d_inv[0][0]; conics[idx * 3 + 1] = covar2d_inv[0][1]; conics[idx * 3 + 2] = covar2d_inv[1][1]; diff 
--git a/gsplat/cuda/csrc/fully_fused_projection_packed_fwd.cu b/gsplat/cuda/csrc/fully_fused_projection_packed_fwd.cu index 8e3a2f8aa..8f0abb4da 100644 --- a/gsplat/cuda/csrc/fully_fused_projection_packed_fwd.cu +++ b/gsplat/cuda/csrc/fully_fused_projection_packed_fwd.cu @@ -210,9 +210,11 @@ __global__ void fully_fused_projection_packed_fwd_kernel( } // mask out gaussians outside the image region - if (mean2d.x + radius <= 0 || mean2d.x - radius >= image_width || - mean2d.y + radius <= 0 || mean2d.y - radius >= image_height) { - valid = false; + if(camera_model != CameraModelType::SPHERICAL){ + if (mean2d.x + radius <= 0 || mean2d.x - radius >= image_width || + mean2d.y + radius <= 0 || mean2d.y - radius >= image_height) { + valid = false; + } } } From b169c703c4aa04250c12f5526c019dcf1170448c Mon Sep 17 00:00:00 2001 From: inuex35 Date: Sat, 12 Oct 2024 13:55:11 +0900 Subject: [PATCH 10/14] spherical camera training and rendering is working now. --- examples/datasets/opensfm.py | 5 ++-- gsplat/cuda/csrc/utils.cuh | 57 +++++++++++++++--------------------- 2 files changed, 25 insertions(+), 37 deletions(-) diff --git a/examples/datasets/opensfm.py b/examples/datasets/opensfm.py index c458cfb20..fb5afbc8e 100644 --- a/examples/datasets/opensfm.py +++ b/examples/datasets/opensfm.py @@ -430,7 +430,7 @@ def read_opensfm(reconstructions): reference_x, reference_y = e2u_conv(reference_lon, reference_lat) if reference_lat < 0: reference_y += 10000000 - for shot in reconstruction["shots"]: + for j, shot in enumerate(reconstruction["shots"]): translation = reconstruction["shots"][shot]["translation"] rotation = reconstruction["shots"][shot]["rotation"] qvec = angle_axis_to_quaternion(rotation) @@ -441,12 +441,11 @@ def read_opensfm(reconstructions): diff_ref = np.array([diff_ref_x, diff_ref_y, diff_ref_alt]) camera_name = reconstruction["shots"][shot]["camera"] camera_id = camera_names.get(camera_name, 0) - image_id = i + image_id = j image_name = shot xys = np.array([0, 0]) point3D_ids = np.array([0, 0]) images[image_id] = Image(id=image_id, qvec=qvec, tvec=tvec, camera_id=camera_id, name=image_name, xys=xys, point3D_ids=point3D_ids, diff_ref=diff_ref) - i += 1 return cameras, images def read_opensfm_points3D(reconstructions): diff --git a/gsplat/cuda/csrc/utils.cuh b/gsplat/cuda/csrc/utils.cuh index 2f720b5f1..1b392a4b3 100644 --- a/gsplat/cuda/csrc/utils.cuh +++ b/gsplat/cuda/csrc/utils.cuh @@ -535,6 +535,8 @@ inline __device__ void spherical_proj( T r = sqrt(x * x + y * y + z * z); + T xz_norm = sqrt(x * x + z * z + 1e-8f); + T longitude = atan2(x, z); T latitude = asin(y / r); @@ -543,14 +545,16 @@ inline __device__ void spherical_proj( mean2d = vec2((normalized_longitude + 1) * width / 2, (normalized_latitude + 1) * height / 2); + T denom_xz = x * x + z * z + 1e-8f; + T denom_r2 = r * r + 1e-8f; // mat3x2 is 3 columns x 2 rows. 
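+    // Sketch of the Jacobian below (equirectangular mapping, matching mean2d above):
+    //   u = (longitude / pi + 1) * width / 2,    longitude = atan2(x, z)
+    //   v = (2 * latitude / pi + 1) * height / 2, latitude = asin(y / r)
+    // giving d(u)/d(x,y,z) = width / (2*pi) * (z, 0, -x) / (x^2 + z^2) and
+    //        d(v)/d(x,y,z) = height / pi * (-x*y, x^2 + z^2, -z*y) / (r^2 * sqrt(x^2 + z^2)).
+    // denom_xz, denom_r2 and xz_norm carry small epsilons to keep these terms finite
+    // near the poles (x = z = 0) and at the origin.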
mat3x2 J = mat3x2( - z / (x * x + z * z), - -(x * y) / (r * r * sqrt(x * x + z * z)), // 1st column + width / (2 * M_PI) * (z / denom_xz), + height / M_PI * (- (x * y) / (denom_r2 * xz_norm)), // 1st column 0.f, - sqrt(x * x + z * z) / (r * r), // 2st column - - x / (x * x + z * z), - -(z * y) / (r * r + sqrt(x * x + z * z)) // 1st column + height / M_PI * (xz_norm / denom_r2), // 2nd column + width / (2 * M_PI) * (-x / denom_xz), + height / M_PI * (- (z * y) / (denom_r2 * xz_norm)) // 3rd column ); cov2d = J * cov3d * glm::transpose(J); @@ -568,51 +572,36 @@ inline __device__ void spherical_proj_vjp( const uint32_t width, const uint32_t height, // grad outputs - const mat2 v_cov2d, - const vec2 v_mean2d, + const glm::mat<2, 2, T> v_cov2d, + const glm::vec<2, T> v_mean2d, // grad inputs - vec3 &v_mean3d, - mat3 &v_cov3d + glm::vec<3, T> &v_mean3d, + glm::mat<3, 3, T> &v_cov3d ) { T x = mean3d[0]; T y = mean3d[1]; T z = mean3d[2]; - T r = sqrt(x * x + y * y + z * z); + T r = sqrt(x * x + y * y + z * z + 1e-8f); T xz_norm = sqrt(x * x + z * z + 1e-8f); - T longitude = atan2(x, z); - T latitude = asin(y / r); - - T normalized_longitude = longitude / M_PI; - T normalized_latitude = latitude / (M_PI / 2.0); - + T denom_xz = x * x + z * z + 1e-8f; + T denom_r2 = r * r + 1e-8f; // mat3x2 is 3 columns x 2 rows. mat3x2 J = mat3x2( - z / (x * x + z * z), - -(x * y) / (r * r * sqrt(x * x + z * z)), // 1st column + width / (2 * M_PI) * (z / denom_xz), + height / M_PI * (- (x * y) / (denom_r2 * xz_norm)), // 1st column 0.f, - sqrt(x * x + z * z) / (r * r), // 2st column - - x / (x * x + z * z), - -(z * y) / (r * r + sqrt(x * x + z * z)) // 1st column + height / M_PI * (xz_norm / denom_r2), // 2nd column + width / (2 * M_PI) * (-x / denom_xz), + height / M_PI * (- (z * y) / (denom_r2 * xz_norm)) // 3rd column ); + v_mean3d += glm::transpose(J) * v_mean2d; v_cov3d += glm::transpose(J) * v_cov2d * J; - - // df/dx = d(normalized_longitude) / dx - // df/dy = d(normalized_latitude) / dy - // df/dz = d(normalized_longitude) / dz + d(normalized_latitude) / dz - T inv_r = 1.0 / r; - T inv_xz_norm = 1.0 / xz_norm; - - v_mean3d += vec3( - width / (2.f * M_PI) * z / (x * x + z * z) * v_mean2d[0], - height / M_PI * (x * y) / (r * r + xz_norm) * v_mean2d[1], - -width / (2.f * M_PI) * x / (x * x + z * z) * v_mean2d[0] + - -height / M_PI * (z * y) / (r * r + xz_norm) * v_mean2d[1] - ); } + template inline __device__ void pos_world_to_cam( // [R, t] is the world-to-camera transformation From 61d98a6895c6b1088249b79fdc0cb5971a3d5818 Mon Sep 17 00:00:00 2001 From: inuex35 Date: Sat, 12 Oct 2024 14:17:09 +0900 Subject: [PATCH 11/14] update some param --- examples/simple_trainer_opensfm.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/examples/simple_trainer_opensfm.py b/examples/simple_trainer_opensfm.py index b476cb955..c1be18fcd 100644 --- a/examples/simple_trainer_opensfm.py +++ b/examples/simple_trainer_opensfm.py @@ -49,13 +49,14 @@ class Config: disable_viewer: bool = False # Path to the .pt files. If provide, it will skip training and run evaluation only. 
ckpt: Optional[List[str]] = None + # Name of compression strategy to use compression: Optional[Literal["png"]] = None # Render trajectory path render_traj_path: str = "interp" # Path to the Mip-NeRF 360 dataset - data_dir: str = "/home/ubuntu/360-gaussian-splatting/sample" + data_dir: str = "data_dir" # Downsample factor for the dataset data_factor: int = 4 # Directory to save results @@ -67,7 +68,7 @@ class Config: # A global scaler that applies to the scene size related parameters global_scale: float = 1.0 # Normalize the world space - normalize_world_space: bool = False + normalize_world_space: bool = True # Camera model camera_model: Literal["pinhole", "ortho", "fisheye", "spherical"] = "spherical" From c85f4233f972ee73e958088a9aa48c83ae9f1759 Mon Sep 17 00:00:00 2001 From: inuex35 Date: Sat, 12 Oct 2024 14:38:43 +0900 Subject: [PATCH 12/14] Refactored if to be consistent with other files --- gsplat/cuda/csrc/fully_fused_projection_fwd.cu | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/gsplat/cuda/csrc/fully_fused_projection_fwd.cu b/gsplat/cuda/csrc/fully_fused_projection_fwd.cu index 7909dc829..6b75cee60 100644 --- a/gsplat/cuda/csrc/fully_fused_projection_fwd.cu +++ b/gsplat/cuda/csrc/fully_fused_projection_fwd.cu @@ -205,10 +205,11 @@ __global__ void fully_fused_projection_fwd_kernel( radii[idx] = (int32_t)radius; means2d[idx * 2] = mean2d.x; means2d[idx * 2 + 1] = mean2d.y; - if (camera_model == CameraModelType::SPHERICAL) { + if (camera_model != CameraModelType::SPHERICAL) { + depths[idx] = mean_c.z; depths[idx] = sqrt(mean_c.x * mean_c.x + mean_c.y * mean_c.y + mean_c.z * mean_c.z); } else { - depths[idx] = mean_c.z; + depths[idx] = sqrt(mean_c.x * mean_c.x + mean_c.y * mean_c.y + mean_c.z * mean_c.z); } conics[idx * 3] = covar2d_inv[0][0]; conics[idx * 3 + 1] = covar2d_inv[0][1]; From 1474c16bf9461275df2fd441791127ef378487d6 Mon Sep 17 00:00:00 2001 From: inuex35 Date: Sun, 13 Oct 2024 00:52:42 +0900 Subject: [PATCH 13/14] Fix _torch_impl.py --- gsplat/cuda/_torch_impl.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/gsplat/cuda/_torch_impl.py b/gsplat/cuda/_torch_impl.py index cac623039..84630fa46 100644 --- a/gsplat/cuda/_torch_impl.py +++ b/gsplat/cuda/_torch_impl.py @@ -248,24 +248,26 @@ def _spherical_proj( tx, ty, tz = torch.unbind(means, dim=-1) # [C, N] tr = torch.sqrt(tx**2 + ty**2 + tz**2) + xz_norm = torch.sqrt(tx * tx + tz * tz + 1e-8) + denom_xz = tx * tx + tz * tz + 1e-8 + denom_r2 = tr * tr + 1e-8 longitude = torch.atan2(tx, tz) - latitude = torch.atan2(ty, torch.sqrt(tx**2 + tz**2)) + latitude = torch.atan2(ty, xz_norm) normalized_latitude = latitude / (torch.pi / 2.0) normalized_longitude = longitude / torch.pi means2d = torch.stack([(normalized_longitude + 1) * width / 2, (normalized_latitude + 1) * height / 2], dim=-1) - O = torch.zeros((C, N), device=means.device, dtype=means.dtype) J = torch.stack( [ - tz / (tx**2 + tz**2), - -(tx * ty) / (tr**2 * torch.sqrt(tx**2 + tz**2)), + width / (2 * torch.pi) * (tz / denom_xz), + height / torch.pi * -(tx * ty) / (denom_r2 * xz_norm), O, - torch.sqrt(tx**2 + tz**2) / (tr**2), - -tx / (tx**2 + tz**2), - -(tz * ty) / (tr**2 * torch.sqrt(tx**2 + tz**2)), + height / torch.pi * xz_norm / denom_r2, + width / (2 * torch.pi) * -tx / denom_xz, + height / torch.pi * -(tz * ty) / (denom_r2 * xz_norm) ], dim=-1, ).reshape(C, N, 2, 3) From b0e978da67fb4364611c6683c5f4e6e6c1d8d8cb Mon Sep 17 00:00:00 2001 From: inuex35 Date: Sun, 13 Oct 2024 11:57:48 +0900 Subject: [PATCH 
14/14] delete camera params from spherical render --- gsplat/cuda/_torch_impl.py | 4 +--- gsplat/cuda/csrc/fully_fused_projection_bwd.cu | 4 ---- gsplat/cuda/csrc/fully_fused_projection_fwd.cu | 4 ---- gsplat/cuda/csrc/fully_fused_projection_packed_bwd.cu | 4 ---- gsplat/cuda/csrc/fully_fused_projection_packed_fwd.cu | 4 ---- gsplat/cuda/csrc/proj_bwd.cu | 4 ---- gsplat/cuda/csrc/proj_fwd.cu | 2 +- gsplat/cuda/csrc/utils.cuh | 8 -------- 8 files changed, 2 insertions(+), 32 deletions(-) diff --git a/gsplat/cuda/_torch_impl.py b/gsplat/cuda/_torch_impl.py index 84630fa46..dfc5c7251 100644 --- a/gsplat/cuda/_torch_impl.py +++ b/gsplat/cuda/_torch_impl.py @@ -225,7 +225,6 @@ def _ortho_proj( def _spherical_proj( means: Tensor, # [C, N, 3] covars: Tensor, # [C, N, 3, 3] - Ks: Tensor, # [C, 3, 3] width: int, height: int, ) -> Tuple[Tensor, Tensor]: @@ -234,7 +233,6 @@ def _spherical_proj( Args: means: Gaussian means in camera coordinate system. [C, N, 3]. covars: Gaussian covariances in camera coordinate system. [C, N, 3, 3]. - Ks: Camera intrinsics. [C, 3, 3]. width: Image width. height: Image height. @@ -329,7 +327,7 @@ def _fully_fused_projection( elif camera_model == "pinhole": means2d, covars2d = _persp_proj(means_c, covars_c, Ks, width, height) elif camera_model == "spherical": - means2d, covars2d = _spherical_proj(means_c, covars_c, Ks, width, height) + means2d, covars2d = _spherical_proj(means_c, covars_c, width, height) else: assert_never(camera_model) diff --git a/gsplat/cuda/csrc/fully_fused_projection_bwd.cu b/gsplat/cuda/csrc/fully_fused_projection_bwd.cu index b4cc3489c..9265bcf6f 100644 --- a/gsplat/cuda/csrc/fully_fused_projection_bwd.cu +++ b/gsplat/cuda/csrc/fully_fused_projection_bwd.cu @@ -181,10 +181,6 @@ __global__ void fully_fused_projection_bwd_kernel( spherical_proj_vjp( mean_c, covar_c, - fx, - fy, - cx, - cy, image_width, image_height, v_covar2d, diff --git a/gsplat/cuda/csrc/fully_fused_projection_fwd.cu b/gsplat/cuda/csrc/fully_fused_projection_fwd.cu index 6b75cee60..dc0b4579e 100644 --- a/gsplat/cuda/csrc/fully_fused_projection_fwd.cu +++ b/gsplat/cuda/csrc/fully_fused_projection_fwd.cu @@ -160,10 +160,6 @@ __global__ void fully_fused_projection_fwd_kernel( spherical_proj( mean_c, covar_c, - Ks[0], - Ks[4], - Ks[2], - Ks[5], image_width, image_height, covar2d, diff --git a/gsplat/cuda/csrc/fully_fused_projection_packed_bwd.cu b/gsplat/cuda/csrc/fully_fused_projection_packed_bwd.cu index b1704d9d0..fe3c32a77 100644 --- a/gsplat/cuda/csrc/fully_fused_projection_packed_bwd.cu +++ b/gsplat/cuda/csrc/fully_fused_projection_packed_bwd.cu @@ -182,10 +182,6 @@ __global__ void fully_fused_projection_packed_bwd_kernel( spherical_proj_vjp( mean_c, covar_c, - fx, - fy, - cx, - cy, image_width, image_height, v_covar2d, diff --git a/gsplat/cuda/csrc/fully_fused_projection_packed_fwd.cu b/gsplat/cuda/csrc/fully_fused_projection_packed_fwd.cu index 8f0abb4da..5a40b3446 100644 --- a/gsplat/cuda/csrc/fully_fused_projection_packed_fwd.cu +++ b/gsplat/cuda/csrc/fully_fused_projection_packed_fwd.cu @@ -175,10 +175,6 @@ __global__ void fully_fused_projection_packed_fwd_kernel( spherical_proj( mean_c, covar_c, - Ks[0], - Ks[4], - Ks[2], - Ks[5], image_width, image_height, covar2d, diff --git a/gsplat/cuda/csrc/proj_bwd.cu b/gsplat/cuda/csrc/proj_bwd.cu index d0bae5e8e..f08530922 100644 --- a/gsplat/cuda/csrc/proj_bwd.cu +++ b/gsplat/cuda/csrc/proj_bwd.cu @@ -113,10 +113,6 @@ __global__ void proj_bwd_kernel( spherical_proj_vjp( mean, covar, - fx, - fy, - cx, - cy, width, height, 
glm::transpose(v_covar2d), diff --git a/gsplat/cuda/csrc/proj_fwd.cu b/gsplat/cuda/csrc/proj_fwd.cu index a7e44be93..e978c7bbb 100644 --- a/gsplat/cuda/csrc/proj_fwd.cu +++ b/gsplat/cuda/csrc/proj_fwd.cu @@ -64,7 +64,7 @@ __global__ void proj_fwd_kernel( fisheye_proj(mean, covar, fx, fy, cx, cy, width, height, covar2d, mean2d); break; case CameraModelType::SPHERICAL: // spherical projection - spherical_proj(mean, covar, fx, fy, cx, cy, width, height, covar2d, mean2d); + spherical_proj(mean, covar, width, height, covar2d, mean2d); break; } diff --git a/gsplat/cuda/csrc/utils.cuh b/gsplat/cuda/csrc/utils.cuh index 1b392a4b3..a9c294383 100644 --- a/gsplat/cuda/csrc/utils.cuh +++ b/gsplat/cuda/csrc/utils.cuh @@ -521,10 +521,6 @@ inline __device__ void spherical_proj( // inputs const vec3 mean3d, const mat3 cov3d, - const T fx, - const T fy, - const T cx, - const T cy, const uint32_t width, const uint32_t height, // outputs @@ -565,10 +561,6 @@ inline __device__ void spherical_proj_vjp( // fwd inputs const vec3 mean3d, const mat3 cov3d, - const T fx, - const T fy, - const T cx, - const T cy, const uint32_t width, const uint32_t height, // grad outputs