Should be properly working up to param optimization
MrNeRF committed Sep 24, 2024
1 parent 3423efb commit 6de6760
Showing 3 changed files with 228 additions and 213 deletions.
164 changes: 84 additions & 80 deletions examples/simple_trainer_scaffold.py
@@ -54,7 +54,8 @@ class Config:
render_traj_path: str = "interp"

# Path to the Mip-NeRF 360 dataset
data_dir: str = "/home/paja/data/bike"
#data_dir: str = "/home/paja/.cache/nerfbaselines/datasets/mipnerf360/garden"
data_dir: str = "/home/paja/data/bike_aliked"
# Downsample factor for the dataset
data_factor: int = 1
# Directory to save results
@@ -88,8 +89,6 @@ class Config:
# Initial extent of GSs as a multiple of the camera extent. Ignored if using sfm
init_extent: float = 3.0
# Turn on another SH degree every this steps
sh_degree_interval: int = 1000
# Initial opacity of GS
init_opa: float = 0.1
# Initial scale of GS
init_scale: float = 1.0
Expand All @@ -113,8 +112,6 @@ class Config:
# Use random background for training to discourage transparency
random_bkgd: bool = False

# Opacity regularization
opacity_reg: float = 0.0
# Scale regularization
scale_reg: float = 0.01

@@ -157,7 +154,6 @@ def adjust_steps(self, factor: float):
self.eval_steps = [int(i * factor) for i in self.eval_steps]
self.save_steps = [int(i * factor) for i in self.save_steps]
self.max_steps = int(self.max_steps * factor)
self.sh_degree_interval = int(self.sh_degree_interval * factor)

strategy = self.strategy
strategy.refine_start_iter = int(strategy.refine_start_iter * factor)
@@ -196,7 +192,6 @@ def create_splats_with_optimizers(
# Distribute the GSs to different ranks (also works for single rank)
points = points[world_rank::world_size]
scales = scales[world_rank::world_size]

N = points.shape[0]
quats = torch.rand((N, 4)) # [N, 4]

@@ -209,8 +204,8 @@
("scales", torch.nn.Parameter(scales), 0.007),
("quats", torch.nn.Parameter(quats), 0.002),
("opacities_mlp", strategy.opacities_mlp.parameters(), 0.002),
("features", torch.nn.Parameter(features), 0.0075 * scene_scale),
("offsets", torch.nn.Parameter(offsets), 0.01),
("features", torch.nn.Parameter(features), 0.0075),
("offsets", torch.nn.Parameter(offsets), 0.01 * scene_scale),
("colors_mlp", strategy.colors_mlp.parameters(), 0.008),
("scale_rot_mlp", strategy.scale_rot_mlp.parameters(), 0.004),
]
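
For context, here is a minimal sketch of how such (name, params, lr) triples typically become one Adam optimizer per parameter group; the helper is illustrative, not the repo's exact factory, and the learning rates are copied from the list above. Note that after this change only the offsets learning rate is multiplied by scene_scale, presumably because decoded offsets ultimately displace gaussians in world space, so their step size should track the scene extent, while anchor features are unit-free.

import torch

def build_optimizers(splats: torch.nn.ParameterDict, scene_scale: float):
    # Illustrative only: one Adam optimizer per named parameter group.
    lrs = {
        "scales": 0.007,
        "quats": 0.002,
        "features": 0.0075,             # unit-free, fixed learning rate
        "offsets": 0.01 * scene_scale,  # world-space, tracks scene extent
    }
    return {name: torch.optim.Adam([splats[name]], lr=lr) for name, lr in lrs.items()}

splats = torch.nn.ParameterDict({
    "scales": torch.nn.Parameter(torch.zeros(8, 6)),
    "quats": torch.nn.Parameter(torch.rand(8, 4)),
    "features": torch.nn.Parameter(torch.zeros(8, 32)),
    "offsets": torch.nn.Parameter(torch.zeros(8, 10, 3)),
})
optimizers = build_optimizers(splats, scene_scale=5.0)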
@@ -242,6 +237,7 @@ def __init__(
set_random_seed(42 + local_rank)

self.cfg = cfg
self.cfg.strategy.voxel_size = self.cfg.voxel_size
self.world_rank = world_rank
self.local_rank = local_rank
self.world_size = world_size
@@ -395,51 +391,44 @@ def __init__(
mode="training",
)

def get_visible_anchor_mask(
self,
camtoworlds: Tensor,
Ks: Tensor,
width: int,
height: int,
packed: bool,
rasterize_mode: str,
def get_neural_gaussians(self,
camtoworlds: Tensor,
Ks: Tensor,
width: int,
height: int,
packed: bool,
rasterize_mode: str,
):
anchors = self.splats["anchors"] # [N, 3]
# rasterization does normalization internally
quats = self.splats["quats"] # [N, 4]
scales = torch.exp(self.splats["scales"])[:, :3] # [N, 3]

# As noted in the paper, this mainly speeds up rasterization and has no quality impact
visible_anchor_mask = view_to_visible_anchors(
means=anchors,
quats=quats,
scales=scales,
means=self.splats["anchors"],
quats=self.splats["quats"],
scales=torch.exp(self.splats["scales"])[:, :3],
viewmats=torch.linalg.inv(camtoworlds), # [C, 4, 4]
Ks=Ks, # [C, 3, 3]
width=width,
height=height,
packed=packed,
rasterize_mode=rasterize_mode,
)

return visible_anchor_mask

def get_neural_gaussians(self, cam_pos, visible_anchor_mask=None):

# If no visibility mask is provided, we select all anchors including their offsets
if visible_anchor_mask is None:
visible_anchor_mask = torch.ones(self.splats["anchors"].shape[0], dtype=torch.bool, device=self.device)

selected_features = self.splats["features"][visible_anchor_mask] # [M, c]
selected_anchors = self.splats["anchors"][visible_anchor_mask] # [M, 3]
selected_offsets = self.splats["offsets"][visible_anchor_mask] # [M, k, 3]
selected_scales = torch.exp(self.splats["scales"][visible_anchor_mask]) # [M, 6]

# See formula (5) in Scaffold-GS

cam_pos = camtoworlds[:, :3, 3]
view_dir = selected_anchors - cam_pos # [M, 3]
view_dir_normalized = view_dir / view_dir.norm(dim=1, keepdim=True) # [M, 3]
length = view_dir.norm(dim=1, keepdim=True)
view_dir_normalized = view_dir / length # [M, 3]

view_length = torch.cat([view_dir_normalized, length], dim=1)

# See formula (9) and the appendix for the rest
feature_view_dir = torch.cat([selected_features, view_dir_normalized], dim=1) # [M, c+3]
feature_view_dir = torch.cat([selected_features, view_length], dim=1) # [M, c+4]

k = self.cfg.strategy.n_feat_offsets # Number of offsets per anchor

@@ -462,22 +451,36 @@ def get_neural_gaussians(self, cam_pos, visible_anchor_mask=None):
anchors_repeated = selected_anchors.unsqueeze(1).repeat(1, k, 1).view(-1, 3) # [M*k, 3]

# Apply positive opacity mask
selected_opacity = neural_opacity[neural_selection_mask].squeeze(-1) # [m]
selected_colors = neural_colors[neural_selection_mask] # [m, 3]
selected_scale_rot = neural_scale_rot[neural_selection_mask] # [m, 7]
selected_offsets = selected_offsets[neural_selection_mask] # [m, 3]
scales_repeated = scales_repeated[neural_selection_mask] # [m, 6]
anchors_repeated = anchors_repeated[neural_selection_mask] # [m, 3]
selected_opacity = neural_opacity[neural_selection_mask].squeeze(-1) # [M]
selected_colors = neural_colors[neural_selection_mask] # [M, 3]
selected_scale_rot = neural_scale_rot[neural_selection_mask] # [M, 7]
selected_offsets = selected_offsets[neural_selection_mask] # [M, 3]
scales_repeated = scales_repeated[neural_selection_mask] # [M, 6]
anchors_repeated = anchors_repeated[neural_selection_mask] # [M, 3]

# Compute scales and rotations
scales = scales_repeated[:, 3:] * torch.sigmoid(selected_scale_rot[:, :3]) # [m, 3]
rotation = torch.nn.functional.normalize(selected_scale_rot[:, 3:7]) # [m, 4]
scales = scales_repeated[:, 3:] * torch.sigmoid(selected_scale_rot[:, :3]) # [M, 3]
rotation = torch.nn.functional.normalize(selected_scale_rot[:, 3:7]) # [M, 4]

# Compute offsets and anchors
offsets = selected_offsets * scales_repeated[:, :3] # [m, 3]
means = anchors_repeated + offsets # [m, 3]

return means, selected_colors, selected_opacity, scales, rotation, neural_opacity, neural_selection_mask
offsets = selected_offsets * scales_repeated[:, :3] # [M, 3]
means = anchors_repeated + offsets # [M, 3]

v_a = visible_anchor_mask.unsqueeze(dim=1).repeat([1, self.cfg.strategy.n_feat_offsets]).view(-1)
all_neural_gaussians = torch.zeros_like(v_a, dtype=torch.bool)
all_neural_gaussians[v_a] = neural_selection_mask

info = {
"means": means,
"colors": selected_colors,
"opacities": selected_opacity,
"scales": scales,
"quats": rotation,
"neural_opacities": neural_opacity,
"neural_selection_mask":all_neural_gaussians,
"visible_anchor_mask": visible_anchor_mask,
}
return info
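
For orientation, a compact, self-contained sketch of the decode path get_neural_gaussians implements (formulas (5) and (9) of Scaffold-GS, as the inline comments note): each visible anchor's feature vector is concatenated with the normalized view direction and, new in this commit, the view distance; an MLP then predicts one opacity per offset, and only gaussians with positive opacity survive. The tiny MLP below is a stand-in for strategy.opacities_mlp, and the dimensions are assumed.

import torch

M, c, k = 64, 32, 10                        # visible anchors, feature dim, offsets per anchor
anchors = torch.rand(M, 3)
features = torch.rand(M, c)
cam_pos = torch.zeros(3)

view_dir = anchors - cam_pos                # [M, 3]
length = view_dir.norm(dim=1, keepdim=True)
view_dir_normalized = view_dir / length     # formula (5)
view_length = torch.cat([view_dir_normalized, length], dim=1)  # [M, 4]
feature_view_dir = torch.cat([features, view_length], dim=1)   # [M, c+4]

opacities_mlp = torch.nn.Sequential(torch.nn.Linear(c + 4, k), torch.nn.Tanh())
neural_opacity = opacities_mlp(feature_view_dir).view(-1, 1)   # [M*k, 1]
neural_selection_mask = (neural_opacity > 0.0).view(-1)        # keep positive opacities
print(int(neural_selection_mask.sum()), "of", M * k, "spawned gaussians kept")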

def rasterize_splats(
self,
@@ -486,38 +489,35 @@ def rasterize_splats(
width: int,
height: int,
**kwargs,
) -> Tuple[Tensor, Tensor, Dict, Tensor]:

# We select only the visible anchors for faster inference
visible_anchor_mask = self.get_visible_anchor_mask(camtoworlds=camtoworlds,
Ks=Ks,
width=width,
height=height,
packed=self.cfg.packed,
rasterize_mode = "antialiased" if self.cfg.antialiased else "classic",
)
) -> Tuple[Tensor, Tensor, Dict]:

# Get all the gaussians per voxel spawned from the anchors
means, color_mlp, opacities, scales, quats, neural_opacity, neural_selection_mask = self.get_neural_gaussians(camtoworlds[:, :3, 3], visible_anchor_mask=visible_anchor_mask)
info = self.get_neural_gaussians(camtoworlds=camtoworlds,
Ks=Ks,
width=width,
height=height,
packed=self.cfg.packed,
rasterize_mode = "antialiased" if self.cfg.antialiased else "classic",
)

image_ids = kwargs.pop("image_ids", None)
if self.cfg.app_opt:
colors = self.app_module(
features=self.splats["features"],
embed_ids=image_ids,
dirs=means[None, :, :] - camtoworlds[:, None, :3, 3],
dirs=info["means"][None, :, :] - camtoworlds[:, None, :3, 3],
)
colors = colors + color_mlp
colors = colors + info["colors"]
colors = torch.sigmoid(colors)
else:
colors = color_mlp # [N, K, 3]
colors = info["colors"] # [N, K, 3]

rasterize_mode = "antialiased" if self.cfg.antialiased else "classic"
render_colors, render_alphas, info = rasterization(
means=means,
quats=quats,
scales=scales,
opacities=opacities,
render_colors, render_alphas, raster_info = rasterization(
means=info["means"],
quats=info["quats"],
scales=info["scales"],
opacities=info["opacities"],
colors=colors,
viewmats=torch.linalg.inv(camtoworlds), # [C, 4, 4]
Ks=Ks, # [C, 3, 3]
@@ -530,14 +530,10 @@ def rasterize_splats(
distributed=self.world_size > 1,
**kwargs,
)
info.update(
{
"visible_anchor_mask": visible_anchor_mask,
"neural_selection_mask": neural_selection_mask,
"neural_opacities": neural_opacity,
}
raster_info.update(
info
)
return render_colors, render_alphas, info, scales
return render_colors, render_alphas, raster_info
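
rasterize_splats now hands the camera straight to get_neural_gaussians, which culls anchors outside the view frustum via view_to_visible_anchors before decoding. A rough, simplified stand-in for that visibility test (unlike gsplat's real function, this ignores each gaussian's extent, which view_to_visible_anchors derives from quats and scales):

import torch

def visible_mask(anchors, viewmat, K, width, height, near=0.01):
    # Project anchor centers with a pinhole model; keep points that land
    # inside the image and in front of the camera.
    ones = torch.ones(anchors.shape[0], 1)
    cam = (viewmat @ torch.cat([anchors, ones], dim=1).T).T[:, :3]  # world -> camera
    z = cam[:, 2:3].clamp(min=1e-6)
    uv = (K @ (cam / z).T).T[:, :2]                                 # pixel coordinates
    inside = (uv[:, 0] >= 0) & (uv[:, 0] < width) & (uv[:, 1] >= 0) & (uv[:, 1] < height)
    return inside & (cam[:, 2] > near)

mask = visible_mask(
    anchors=torch.randn(1000, 3) + torch.tensor([0.0, 0.0, 4.0]),
    viewmat=torch.eye(4),  # camera at the origin looking down +z
    K=torch.tensor([[500.0, 0.0, 400.0], [0.0, 500.0, 300.0], [0.0, 0.0, 1.0]]),
    width=800,
    height=600,
)
print(int(mask.sum()), "of 1000 anchors visible")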

def train(self):
cfg = self.cfg
@@ -555,13 +551,16 @@ def train(self):

schedulers = [
torch.optim.lr_scheduler.ExponentialLR(
self.optimizers["offsets"], gamma=0.01 ** (1.0 / max_steps)
self.optimizers["anchors"], gamma=0.001 ** (1.0 / max_steps)
),
torch.optim.lr_scheduler.ExponentialLR(
self.optimizers["offsets"], gamma=(0.01 * self.scene_scale) ** (1.0 / max_steps)
),
torch.optim.lr_scheduler.ExponentialLR(
self.optimizers["opacities_mlp"], gamma=0.002 ** (1.0 / max_steps)
self.optimizers["opacities_mlp"], gamma=0.001 ** (1.0 / max_steps)
),
torch.optim.lr_scheduler.ExponentialLR(
self.optimizers["colors_mlp"], gamma=0.008 ** (1.0 / max_steps)
self.optimizers["colors_mlp"], gamma=0.00625 ** (1.0 / max_steps)
),
]
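
Each gamma encodes a target total decay: with gamma = decay ** (1 / max_steps), the learning rate reaches exactly lr0 * decay at the last step, so anchors and the opacity MLP now decay to 0.1% of their initial rate, and the offsets' total decay factor is itself scaled by the scene extent. A quick check of the rule with illustrative values:

import torch

max_steps, decay, lr0 = 30_000, 0.001, 1e-2
param = torch.nn.Parameter(torch.zeros(1))
opt = torch.optim.Adam([param], lr=lr0)
sched = torch.optim.lr_scheduler.ExponentialLR(opt, gamma=decay ** (1.0 / max_steps))
for _ in range(max_steps):
    opt.step()   # no gradients, so a no-op; keeps the scheduler warning-free
    sched.step()
print(opt.param_groups[0]["lr"])  # ~1e-5 == lr0 * decay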
if cfg.pose_opt:
@@ -601,6 +600,11 @@ def train(self):
# Training loop.
global_tic = time.time()
pbar = tqdm.tqdm(range(init_step, max_steps))

self.cfg.strategy.scale_rot_mlp.train()
self.cfg.strategy.opacities_mlp.train()
self.cfg.strategy.colors_mlp.train()

for step in pbar:
if not cfg.disable_viewer:
while self.viewer.state.status == "paused":
@@ -634,7 +638,7 @@ def train(self):
camtoworlds = self.pose_adjust(camtoworlds, image_ids)

# forward
renders, alphas, info, scales = self.rasterize_splats(
renders, alphas, info = self.rasterize_splats(
camtoworlds=camtoworlds,
Ks=Ks,
width=width,
Expand Down Expand Up @@ -678,7 +682,7 @@ def train(self):
)
loss = l1loss * (1.0 - cfg.ssim_lambda)
loss += ssimloss * cfg.ssim_lambda
loss += scales.prod(dim=1).mean() * cfg.scale_reg
loss += info["scales"].prod(dim=1).mean() * cfg.scale_reg
if cfg.depth_loss:
# query depths from depth map
points = torch.stack(
Expand Down Expand Up @@ -842,7 +846,7 @@ def eval(self, step: int, stage: str = "val"):

torch.cuda.synchronize()
tic = time.time()
colors, _, _, _ = self.rasterize_splats(
colors, _, _ = self.rasterize_splats(
camtoworlds=camtoworlds,
Ks=Ks,
width=width,
Expand Down Expand Up @@ -946,7 +950,7 @@ def render_traj(self, step: int):
camtoworlds = camtoworlds_all[i : i + 1]
Ks = K[None]

renders, _, _, _ = self.rasterize_splats(
renders, _, _ = self.rasterize_splats(
camtoworlds=camtoworlds,
Ks=Ks,
width=width,
Expand Down Expand Up @@ -1003,7 +1007,7 @@ def _viewer_render_fn(
c2w = torch.from_numpy(c2w).float().to(self.device)
K = torch.from_numpy(K).float().to(self.device)

render_colors, _, _, _ = self.rasterize_splats(
render_colors, _, _ = self.rasterize_splats(
camtoworlds=c2w[None],
Ks=K[None],
width=W,
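
Stepping back from this file: the objective assembled in the training loop blends L1 and SSIM terms and adds a scale regularizer in which the product of the three axis scales is proportional to a gaussian's volume, so penalizing its mean discourages oversized splats. A toy recomputation with assumed values:

import torch

l1loss, ssimloss = 0.05, 0.20          # assumed photometric terms
ssim_lambda, scale_reg = 0.2, 0.01     # stand-ins for cfg.ssim_lambda, cfg.scale_reg
scales = torch.rand(1000, 3) * 0.1     # decoded per-gaussian scales, [N, 3]

loss = l1loss * (1.0 - ssim_lambda) + ssimloss * ssim_lambda
loss = loss + scales.prod(dim=1).mean() * scale_reg
print(float(loss))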
49 changes: 46 additions & 3 deletions gsplat/strategy/ops.py
@@ -195,7 +195,14 @@ def remove(
sel = torch.where(~mask)[0]

def param_fn(name: str, p: Tensor) -> Tensor:
return torch.nn.Parameter(p[sel])
if name == "scales":
p = p[sel]
temp = p[:,3:]
temp[temp>0.05] = 0.05
p[:,3:] = temp
return torch.nn.Parameter(p)
else:
return torch.nn.Parameter(p[sel])

def optimizer_fn(key: str, v: Tensor) -> Tensor:
return v[sel]
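
The new branch caps the last three components of each kept scale row at 0.05 while pruning. A slightly more idiomatic sketch of the same operation, assuming (as the diff suggests) 6-dim rows whose second half holds the per-gaussian scaling:

import torch

def prune_and_cap_scales(p: torch.Tensor, sel: torch.Tensor, cap: float = 0.05):
    p = p[sel].clone()                  # keep the selected rows
    p[:, 3:] = p[:, 3:].clamp(max=cap)  # cap the per-gaussian scaling half
    return torch.nn.Parameter(p)

scales = torch.randn(10, 6)
keep = torch.where(torch.rand(10) > 0.3)[0]
print(bool((prune_and_cap_scales(scales, keep)[:, 3:] <= 0.05).all()))  # True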
@@ -403,8 +410,7 @@ def grow_anchors(
log_voxel_size = torch.log(torch.tensor(voxel_size, device=device))
scaling = log_voxel_size.expand(num_new, anchors.size(1) * 2) # [N_new, 6]

rotation = torch.zeros((num_new, 4), device=device)
rotation[:, 0] = 1.0 # Identity quaternion
rotation = torch.ones((num_new, 4), device=device)

# Prepare new features
existing_features = params["features"] # [N_existing, feat_dim]
@@ -457,3 +463,40 @@ def optimizer_fn(key: str, v: Tensor) -> Tensor:
zeros = torch.zeros((num_new * n_feat_offsets, *v.shape[1:]), device=device)
state[k] = torch.cat([v, zeros], dim=0)
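
The optimizer_fn above pads every per-offset optimizer state tensor with zeros so freshly grown gaussians start from clean Adam statistics. A self-contained sketch of that pattern:

import torch

def extend_state(v: torch.Tensor, num_new: int, n_feat_offsets: int) -> torch.Tensor:
    # Append zero rows covering the offsets of the num_new grown anchors.
    zeros = torch.zeros((num_new * n_feat_offsets, *v.shape[1:]), device=v.device)
    return torch.cat([v, zeros], dim=0)

exp_avg = torch.rand(100, 3)  # e.g. Adam state for 10 anchors x 10 offsets
print(extend_state(exp_avg, num_new=2, n_feat_offsets=10).shape)  # torch.Size([120, 3])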

@torch.no_grad()
def remove_anchors(
params: Union[Dict[str, torch.nn.Parameter], torch.nn.ParameterDict],
optimizers: Dict[str, torch.optim.Optimizer],
n_feat_offsets: int,
state: Dict[str, Tensor],
mask: Tensor,
names: Union[List[str], None] = None,
):
"""Inplace remove the Gaussian with the given mask.
Args:
params: A dictionary of parameters.
optimizers: A dictionary of optimizers, each corresponding to a parameter.
n_feat_offsets: Number of feature offsets.
state: A dictionary of extra state tensors.
mask: A boolean mask to remove the Gaussians.
names: A list of parameter names to update. If None, update all. Default: None.
"""
sel = torch.where(~mask)[0]

def param_fn(name: str, p: Tensor) -> Tensor:
return torch.nn.Parameter(p[sel])

def optimizer_fn(key: str, v: Tensor) -> Tensor:
return v[sel]

# update the parameters and the state in the optimizers
_update_param_with_optimizer(param_fn, optimizer_fn, params, optimizers, names)
# update the extra running state
for k, v in state.items():
if isinstance(v, torch.Tensor):
if k in ["anchor_count", "anchor_opacity"]:
state[k] = v[sel]
else:
# Per-offset state stores n_feat_offsets rows per anchor, so expand the
# boolean keep-mask rather than indexing with anchor indices.
offset_sel = (~mask).unsqueeze(dim=1).repeat([1, n_feat_offsets]).view(-1)
state[k] = v[offset_sel]
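
A hypothetical call-site sketch for remove_anchors: build a per-anchor prune mask (here from an assumed running opacity average) and note that per-offset state holds n_feat_offsets rows per anchor, so the anchor mask must be expanded accordingly, mirroring the function body:

import torch

n_anchors, k = 100, 10                      # k = n_feat_offsets
anchor_count = torch.randint(1, 50, (n_anchors,))
anchor_opacity = torch.rand(n_anchors)      # assumed running sums from `state`
prune_mask = (anchor_opacity / anchor_count) < 0.005  # True = remove

sel = torch.where(~prune_mask)[0]                              # kept anchors
offset_sel = (~prune_mask).unsqueeze(1).repeat(1, k).view(-1)  # kept offset rows
per_offset_state = torch.rand(n_anchors * k, 2)
print(sel.shape[0], "anchors kept;", per_offset_state[offset_sel].shape[0], "offset rows kept")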
