Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

distortion loss impl. #244

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 15 additions & 2 deletions examples/simple_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@
from typing import Dict, List, Optional, Tuple

import imageio
import nerfview
import numpy as np
import torch
import torch.nn.functional as F
import tqdm
import tyro
import viser
import nerfview
from datasets.colmap import Dataset, Parser
from datasets.traj import generate_interpolated_path
from torch import Tensor
Expand Down Expand Up @@ -141,6 +141,11 @@ class Config:
# Weight for depth loss
depth_lambda: float = 1e-2

# Distortion loss. (experimental)
dist_loss: bool = False
# Weight for distortion loss
dist_lambda: float = 1e-3

# Dump information to tensorboard every this steps
tb_every: int = 100
# Save training images to tensorboard
Expand Down Expand Up @@ -472,7 +477,8 @@ def train(self):
near_plane=cfg.near_plane,
far_plane=cfg.far_plane,
image_ids=image_ids,
render_mode="RGB+ED" if cfg.depth_loss else "RGB",
render_mode="RGB+ED" if (cfg.depth_loss or cfg.dist_loss) else "RGB",
distloss=self.cfg.dist_loss,
)
if renders.shape[-1] == 4:
colors, depths = renders[..., 0:3], renders[..., 3:4]
Expand Down Expand Up @@ -510,12 +516,17 @@ def train(self):
disp_gt = 1.0 / depths_gt # [1, M]
depthloss = F.l1_loss(disp, disp_gt) * self.scene_scale
loss += depthloss * cfg.depth_lambda
if cfg.dist_loss:
distloss = info["render_distloss"].mean()
loss += distloss * cfg.dist_lambda

loss.backward()

desc = f"loss={loss.item():.3f}| " f"sh degree={sh_degree_to_use}| "
if cfg.depth_loss:
desc += f"depth loss={depthloss.item():.6f}| "
if cfg.dist_loss:
desc += f"dist loss={distloss.item():.6f}| "
if cfg.pose_opt and cfg.pose_noise:
# monitor the pose error if we inject noise
pose_err = F.l1_loss(camtoworlds_gt, camtoworlds)
Expand All @@ -533,6 +544,8 @@ def train(self):
self.writer.add_scalar("train/mem", mem, step)
if cfg.depth_loss:
self.writer.add_scalar("train/depthloss", depthloss.item(), step)
if cfg.dist_loss:
self.writer.add_scalar("train/distloss", distloss.item(), step)
if cfg.tb_save_image:
canvas = torch.cat([pixels, colors], dim=2).detach().cpu().numpy()
canvas = canvas.reshape(-1, *canvas.shape[2:])
Expand Down
27 changes: 17 additions & 10 deletions examples/simple_trainer_mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,29 +6,25 @@
from typing import Dict, List, Optional, Tuple

import imageio
import nerfview
import numpy as np
import torch
import torch.nn.functional as F
import tqdm
import tyro
import viser
import nerfview
from datasets.colmap import Dataset, Parser
from datasets.traj import generate_interpolated_path
from simple_trainer import create_splats_with_optimizers
from torch import Tensor
from torch.utils.tensorboard import SummaryWriter
from torchmetrics.image import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
from utils import (
AppearanceOptModule,
CameraOptModule,
set_random_seed,
)
from utils import AppearanceOptModule, CameraOptModule, set_random_seed

from gsplat import quat_scale_to_covar_preci
from gsplat.rendering import rasterization
from gsplat.relocation import compute_relocation
from gsplat.cuda_legacy._torch_impl import scale_rot_to_cov3d
from simple_trainer import create_splats_with_optimizers
from gsplat.rendering import rasterization


@dataclass
Expand Down Expand Up @@ -139,6 +135,11 @@ class Config:
# Weight for depth loss
depth_lambda: float = 1e-2

# Distortion loss. (experimental)
dist_loss: bool = False
# Weight for distortion loss
dist_lambda: float = 1e-3

# Dump information to tensorboard every this steps
tb_every: int = 100
# Save training images to tensorboard
Expand Down Expand Up @@ -394,7 +395,8 @@ def train(self):
near_plane=cfg.near_plane,
far_plane=cfg.far_plane,
image_ids=image_ids,
render_mode="RGB+ED" if cfg.depth_loss else "RGB",
render_mode="RGB+ED" if (cfg.depth_loss or cfg.dist_loss) else "RGB",
distloss=self.cfg.dist_loss,
)
if renders.shape[-1] == 4:
colors, depths = renders[..., 0:3], renders[..., 3:4]
Expand Down Expand Up @@ -432,6 +434,9 @@ def train(self):
disp_gt = 1.0 / depths_gt # [1, M]
depthloss = F.l1_loss(disp, disp_gt) * self.scene_scale
loss += depthloss * cfg.depth_lambda
if cfg.dist_loss:
distloss = info["render_distloss"].mean()
loss += distloss * cfg.dist_lambda

loss = (
loss
Expand All @@ -448,6 +453,8 @@ def train(self):
desc = f"loss={loss.item():.3f}| " f"sh degree={sh_degree_to_use}| "
if cfg.depth_loss:
desc += f"depth loss={depthloss.item():.6f}| "
if cfg.dist_loss:
desc += f"dist loss={distloss.item():.6f}| "
if cfg.pose_opt and cfg.pose_noise:
# monitor the pose error if we inject noise
pose_err = F.l1_loss(camtoworlds_gt, camtoworlds)
Expand Down
42 changes: 35 additions & 7 deletions gsplat/cuda/_torch_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,8 @@ def accumulate(
camera_ids: Tensor, # [M]
image_width: int,
image_height: int,
) -> Tuple[Tensor, Tensor]:
distloss: bool = False,
) -> Tuple[Tensor, Tensor, Tensor]:
"""Alpha compositing of 2D Gaussians in Pure Pytorch.

This function performs alpha compositing for Gaussians based on the pair of indices
Expand All @@ -334,23 +335,30 @@ def accumulate(
means2d: Gaussian means in 2D. [C, N, 2]
conics: Inverse of the 2D Gaussian covariance, Only upper triangle values. [C, N, 3]
opacities: Per-view Gaussian opacities (for example, when antialiasing is
enabled, Gaussian in each view would effectively have different opacity). [C, N]
enabled, Gaussian in each view would effectively have different opacity). [C, N]
colors: Per-view Gaussian colors. Supports N-D features. [C, N, channels]
gaussian_ids: Collection of Gaussian indices to be rasterized. A flattened list of shape [M].
pixel_ids: Collection of pixel indices (row-major) to be rasterized. A flattened list of shape [M].
camera_ids: Collection of camera indices to be rasterized. A flattened list of shape [M].
image_width: Image width.
image_height: Image height.
distloss: If True, render the per-pixel distortion loss map. This requires the depth
information in the last channel of `colors`. Default is False.

Returns:
A tuple:

- **renders**: Accumulated colors. [C, image_height, image_width, channels]
- **alphas**: Accumulated opacities. [C, image_height, image_width, 1]
- **dists**: Distortion loss if `distloss=True` else None. [C, image_height, image_width, 1]
"""

try:
from nerfacc import accumulate_along_rays, render_weight_from_alpha
from nerfacc import (
accumulate_along_rays,
exclusive_sum,
render_weight_from_alpha,
)
except ImportError:
raise ImportError("Please install nerfacc package: pip install nerfacc")

Expand Down Expand Up @@ -386,7 +394,18 @@ def accumulate(
weights, None, ray_indices=indices, n_rays=total_pixels
).reshape(C, image_height, image_width, 1)

return renders, alphas
if distloss:
depths = colors[camera_ids, gaussian_ids, -1]
loss_bi_0 = weights * depths * exclusive_sum(weights, indices=indices)
loss_bi_1 = weights * exclusive_sum(weights * depths, indices=indices)
dists = 2 * (loss_bi_0 - loss_bi_1)
dists = accumulate_along_rays(dists, None, indices, total_pixels).reshape(
C, image_height, image_width, 1
)
else:
dists = None

return renders, alphas, dists


def _rasterize_to_pixels(
Expand All @@ -400,6 +419,7 @@ def _rasterize_to_pixels(
isect_offsets: Tensor, # [C, tile_height, tile_width]
flatten_ids: Tensor, # [n_isects]
backgrounds: Optional[Tensor] = None, # [C, channels]
distloss: bool = False,
batch_per_iter: int = 100,
):
"""Pytorch implementation of `gsplat.cuda._wrapper.rasterize_to_pixels()`.
Expand Down Expand Up @@ -434,6 +454,10 @@ def _rasterize_to_pixels(
(C, image_height, image_width, colors.shape[-1]), device=device
)
render_alphas = torch.zeros((C, image_height, image_width, 1), device=device)
if distloss:
render_distloss = torch.zeros((C, image_height, image_width, 1), device=device)
else:
render_distloss = None

# Split Gaussians into batches and iteratively accumulate the renderings
block_size = tile_size * tile_size
Expand Down Expand Up @@ -464,7 +488,7 @@ def _rasterize_to_pixels(
break

# Accumulate the renderings within this batch of Gaussians.
renders_step, accs_step = accumulate(
renders_step, accs_step, dists_step = accumulate(
means2d,
conics,
opacities,
Expand All @@ -474,17 +498,21 @@ def _rasterize_to_pixels(
camera_ids,
image_width,
image_height,
distloss=distloss,
)
render_colors = render_colors + renders_step * transmittances[..., None]
render_alphas = render_alphas + accs_step * transmittances[..., None]
if distloss:
render_distloss = render_distloss + dists_step * (
transmittances[..., None] ** 2
)

render_alphas = render_alphas
if backgrounds is not None:
render_colors = render_colors + backgrounds[:, None, None, :] * (
1.0 - render_alphas
)

return render_colors, render_alphas
return render_colors, render_alphas, render_distloss


def _eval_sh_bases_fast(basis_dim: int, dirs: Tensor):
Expand Down
30 changes: 25 additions & 5 deletions gsplat/cuda/_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,7 @@ def rasterize_to_pixels(
backgrounds: Optional[Tensor] = None, # [C, channels]
packed: bool = False,
absgrad: bool = False,
distloss: bool = False,
) -> Tuple[Tensor, Tensor]:
"""Rasterizes Gaussians to pixels.

Expand All @@ -397,12 +398,15 @@ def rasterize_to_pixels(
backgrounds: Background colors. [C, channels]. Default: None.
packed: If True, the input tensors are expected to be packed with shape [nnz, ...]. Default: False.
absgrad: If True, the backward pass will compute a `.absgrad` attribute for `means2d`. Default: False.
distloss: If True, the per-pixel distortion loss will be rendered. The depths
are expected to be concatenated as the last channel of `colors`. Default: False.

Returns:
A tuple:

- **Rendered colors**. [C, image_height, image_width, channels]
- **Rendered alphas**. [C, image_height, image_width, 1]
- **Rendered distloss**. [C, image_height, image_width, 1] if `distloss` is True, otherwise None.
"""

C = isect_offsets.size(0)
Expand Down Expand Up @@ -478,7 +482,7 @@ def rasterize_to_pixels(
tile_width * tile_size >= image_width
), f"Assert Failed: {tile_width} * {tile_size} >= {image_width}"

render_colors, render_alphas = _RasterizeToPixels.apply(
render_colors, render_alphas, render_distloss = _RasterizeToPixels.apply(
means2d.contiguous(),
conics.contiguous(),
colors.contiguous(),
Expand All @@ -490,11 +494,12 @@ def rasterize_to_pixels(
isect_offsets.contiguous(),
flatten_ids.contiguous(),
absgrad,
distloss,
)

if padded_channels > 0:
render_colors = render_colors[..., :-padded_channels]
return render_colors, render_alphas
return render_colors, render_alphas, render_distloss


@torch.no_grad()
Expand Down Expand Up @@ -820,8 +825,9 @@ def forward(
isect_offsets: Tensor, # [C, tile_height, tile_width]
flatten_ids: Tensor, # [n_isects]
absgrad: bool,
) -> Tuple[Tensor, Tensor]:
render_colors, render_alphas, last_ids = _make_lazy_cuda_func(
distloss: bool,
) -> Tuple[Tensor, Tensor, Tensor]:
render_colors, render_alphas, render_distloss, last_ids = _make_lazy_cuda_func(
"rasterize_to_pixels_fwd"
)(
means2d,
Expand All @@ -834,6 +840,7 @@ def forward(
tile_size,
isect_offsets,
flatten_ids,
distloss,
)

ctx.save_for_backward(
Expand All @@ -844,23 +851,26 @@ def forward(
backgrounds,
isect_offsets,
flatten_ids,
render_colors,
render_alphas,
last_ids,
)
ctx.width = width
ctx.height = height
ctx.tile_size = tile_size
ctx.absgrad = absgrad
ctx.distloss = distloss

# double to float
render_alphas = render_alphas.float()
return render_colors, render_alphas
return render_colors, render_alphas, render_distloss

@staticmethod
def backward(
ctx,
v_render_colors: Tensor, # [C, H, W, 3]
v_render_alphas: Tensor, # [C, H, W, 1]
v_render_distloss: Optional[Tensor], # [C, H, W, 1]
):
(
means2d,
Expand All @@ -870,13 +880,20 @@ def backward(
backgrounds,
isect_offsets,
flatten_ids,
render_colors,
render_alphas,
last_ids,
) = ctx.saved_tensors
width = ctx.width
height = ctx.height
tile_size = ctx.tile_size
absgrad = ctx.absgrad
distloss = ctx.distloss
if distloss:
assert v_render_distloss is not None, "v_render_distloss should not be None"
v_render_distloss = v_render_distloss.contiguous()
else:
assert v_render_distloss is None, "v_render_distloss should be None"

(
v_means2d_abs,
Expand All @@ -895,10 +912,12 @@ def backward(
tile_size,
isect_offsets,
flatten_ids,
render_colors,
render_alphas,
last_ids,
v_render_colors.contiguous(),
v_render_alphas.contiguous(),
v_render_distloss,
absgrad,
)

Expand All @@ -924,6 +943,7 @@ def backward(
None,
None,
None,
None,
)


Expand Down
Loading
Loading