Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement normal consistency loss. #273

Open
wants to merge 77 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
77 commits
Select commit Hold shift + click to select a range
d31aa2b
use inria cuda to train
jefequien Jul 2, 2024
6a58dc3
canvas list
jefequien Jul 2, 2024
1b6e094
2dgs
jefequien Jul 3, 2024
55b39bb
lambda_dist
jefequien Jul 3, 2024
f4ea0e0
clean up
jefequien Jul 3, 2024
87748f4
format
jefequien Jul 3, 2024
7ee6c20
3m garden first
jefequien Jul 3, 2024
143af51
2dgs_mcmc_sfm
jefequien Jul 4, 2024
6670314
train 3dgs without normal loss
jefequien Jul 4, 2024
286867b
Merge branch 'main' into jeff/inria
jefequien Jul 4, 2024
a8fd7a9
3dgs normal working
jefequien Jul 4, 2024
5870b25
baseline no gradient
jefequien Jul 5, 2024
a8adfe5
normal backprob
jefequien Jul 5, 2024
08637cb
cleanup
jefequien Jul 5, 2024
ea6ce1a
cleanup
jefequien Jul 5, 2024
f7f0b9c
cleanup
jefequien Jul 5, 2024
468c451
cmocean dense
jefequien Jul 5, 2024
8cfa0fc
cmo ice
jefequien Jul 5, 2024
20cd888
voltage
jefequien Jul 5, 2024
503ac93
edit bash
jefequien Jul 5, 2024
23f9984
clean up benchmark script
jefequien Jul 5, 2024
02e00fb
remove dist_loss
jefequien Jul 5, 2024
96e8d23
depth must be last
jefequien Jul 5, 2024
f336b95
colors normal depth
jefequien Jul 5, 2024
54f137e
reduce diff
jefequien Jul 8, 2024
50749de
cleanup
jefequien Jul 8, 2024
bcac804
remove distloss
jefequien Jul 8, 2024
5752261
benchmark script
jefequien Jul 8, 2024
dadaf75
refactor
jefequien Jul 8, 2024
6800602
compile bug
jefequien Jul 8, 2024
98573e6
Merge branch 'main' into jeff/normal_consistency
jefequien Jul 9, 2024
6d156d2
remove benchmark script
jefequien Jul 9, 2024
9c06a24
support packed and sparse
jefequien Jul 15, 2024
7294be4
refactor
jefequien Jul 15, 2024
0a9ce04
bugfix
jefequien Jul 15, 2024
1bd9e91
utils/
jefequien Jul 15, 2024
d7a2916
rasterization backend
jefequien Jul 15, 2024
0274d1b
v_rotmat
jefequien Jul 15, 2024
6840727
render traj
jefequien Jul 15, 2024
5fdc934
point inward during ellipse
jefequien Jul 15, 2024
e3dc5e5
add tests for fwd and bwd pass
jefequien Jul 16, 2024
467ffe6
all but one test passing
jefequien Jul 16, 2024
8e158a2
weird bug
jefequien Jul 16, 2024
e5c9b64
merge
jefequien Jul 16, 2024
08b4631
test passes but test suite does not pass
jefequien Jul 16, 2024
c7e2fd4
count radii
jefequien Jul 16, 2024
ee40bb2
change sel to pass tests
jefequien Jul 16, 2024
6f21cb7
__sel
jefequien Jul 16, 2024
d2f89ab
cleanup
jefequien Jul 16, 2024
12723b1
simplify utils
jefequien Jul 17, 2024
433b313
merge
jefequien Jul 19, 2024
fadb7b4
util
jefequien Jul 19, 2024
f40b879
test rasterization
jefequien Jul 19, 2024
d766117
benchmark script
jefequien Jul 19, 2024
8e8d681
__init__
jefequien Jul 19, 2024
90c62af
merge
jefequien Aug 29, 2024
24eb1b5
fix normal consistency
jefequien Aug 31, 2024
a86ef37
ellipse
jefequien Aug 31, 2024
f0b93f0
cleanup
jefequien Aug 31, 2024
a014c89
uncomment
jefequien Aug 31, 2024
35b21c4
canvas list
jefequien Aug 31, 2024
9aaf257
Merge branch 'main' into jeff/normal_consistency
jefequien Aug 31, 2024
1afd41a
merge with traj
jefequien Aug 31, 2024
f072f17
cleanup
jefequien Aug 31, 2024
1fa189c
fix tests
jefequien Aug 31, 2024
b9ef876
remove 2dgs inria
jefequien Aug 31, 2024
c08a23e
script
jefequien Sep 1, 2024
1535db1
merge
jefequien Sep 3, 2024
6346132
merge
jefequien Sep 13, 2024
bfd78bc
fix merge
jefequien Sep 13, 2024
0a1dc09
fix utils
jefequien Sep 13, 2024
5b5a7c3
reduce diff test_basic
jefequien Sep 13, 2024
9c4186a
tests not passing
jefequien Sep 13, 2024
0c5e3ed
all tests passed
jefequien Sep 13, 2024
541990a
merge
jefequien Sep 22, 2024
38f2532
summarize stats
jefequien Sep 24, 2024
72d38d6
merge
jefequien Sep 27, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/benchmarks/compression/mcmc.sh
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ done
if command -v zip &> /dev/null
then
echo "Zipping results"
python benchmarks/compression/summarize_stats.py --results_dir $RESULT_DIR
python benchmarks/summarize_stats.py --results_dir $RESULT_DIR --scenes $SCENE_LIST --stage compress
else
echo "zip command not found, skipping zipping"
fi
2 changes: 1 addition & 1 deletion examples/benchmarks/compression/mcmc_tt.sh
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ done
if command -v zip &> /dev/null
then
echo "Zipping results"
python benchmarks/compression/summarize_stats.py --results_dir $RESULT_DIR --scenes $SCENE_LIST
python benchmarks/summarize_stats.py --results_dir $RESULT_DIR --scenes $SCENE_LIST --stage compress
else
echo "zip command not found, skipping zipping"
fi
17 changes: 17 additions & 0 deletions examples/benchmarks/normal/2dgs_dtu.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
#!/bin/bash
# Benchmark 2DGS on the 15 DTU evaluation scans, then summarize val-stage stats.
# Expects DTU data under data/DTU/<scan>/; writes per-scene results and a
# val_summary.json under $RESULT_DIR.

SCENE_DIR="data/DTU"
SCENE_LIST="scan24 scan37 scan40 scan55 scan63 scan65 scan69 scan83 scan97 scan105 scan106 scan110 scan114 scan118 scan122"

RESULT_DIR="results/benchmark_dtu_2dgs"

# $SCENE_LIST is intentionally unquoted: word splitting yields one iteration per scan.
for SCENE in $SCENE_LIST;
do
    echo "Running $SCENE"

    # train and eval (DTU images are used at full resolution: --data_factor 1)
    CUDA_VISIBLE_DEVICES=0 python simple_trainer_2dgs.py --disable_viewer --data_factor 1 \
        --data_dir "$SCENE_DIR/$SCENE/" \
        --result_dir "$RESULT_DIR/$SCENE/"
done

echo "Summarizing results"
# $SCENE_LIST stays unquoted so each scan name is passed as a separate --scenes argument.
python benchmarks/summarize_stats.py --results_dir "$RESULT_DIR" --scenes $SCENE_LIST --stage val
28 changes: 28 additions & 0 deletions examples/benchmarks/normal/mcmc_dtu.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
#!/bin/bash
# Benchmark MCMC Gaussian splatting with the normal consistency loss on the
# 15 DTU evaluation scans, then summarize val-stage stats.
# Pick a Gaussian budget by switching which RESULT_DIR/CAP_MAX pair is active.

SCENE_DIR="data/DTU"
SCENE_LIST="scan24 scan37 scan40 scan55 scan63 scan65 scan69 scan83 scan97 scan105 scan106 scan110 scan114 scan118 scan122"
RENDER_TRAJ_PATH="ellipse"

RESULT_DIR="results/benchmark_dtu_mcmc_0.25M_normal"
CAP_MAX=250000

# RESULT_DIR="results/benchmark_dtu_mcmc_0.5M_normal"
# CAP_MAX=500000

# RESULT_DIR="results/benchmark_dtu_mcmc_1M_normal"
# CAP_MAX=1000000

# $SCENE_LIST is intentionally unquoted: word splitting yields one iteration per scan.
for SCENE in $SCENE_LIST;
do
    echo "Running $SCENE"

    # train and eval (DTU images are used at full resolution: --data_factor 1)
    CUDA_VISIBLE_DEVICES=0 python simple_trainer.py mcmc --disable_viewer --data_factor 1 \
        --strategy.cap-max "$CAP_MAX" \
        --normal_consistency_loss \
        --render_traj_path "$RENDER_TRAJ_PATH" \
        --data_dir "$SCENE_DIR/$SCENE/" \
        --result_dir "$RESULT_DIR/$SCENE/"
done

echo "Summarizing results"
# $SCENE_LIST stays unquoted so each scan name is passed as a separate --scenes argument.
python benchmarks/summarize_stats.py --results_dir "$RESULT_DIR" --scenes $SCENE_LIST --stage val
25 changes: 25 additions & 0 deletions examples/benchmarks/normal/mcmc_normal.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#!/bin/bash
# Benchmark MCMC Gaussian splatting with the normal consistency loss on the
# Mip-NeRF 360 scenes, then summarize val-stage stats.

SCENE_DIR="data/360_v2"
SCENE_LIST="garden bicycle stump bonsai counter kitchen room treehill flowers"

RESULT_DIR="results/benchmark_normal"
RENDER_TRAJ_PATH="ellipse"

# $SCENE_LIST is intentionally unquoted: word splitting yields one iteration per scene.
for SCENE in $SCENE_LIST;
do
    # Indoor scenes are trained at a 2x downsample; outdoor scenes at 4x
    # (standard Mip-NeRF 360 evaluation protocol).
    case "$SCENE" in
        bonsai|counter|kitchen|room) DATA_FACTOR=2 ;;
        *) DATA_FACTOR=4 ;;
    esac

    echo "Running $SCENE"

    CUDA_VISIBLE_DEVICES=0 python simple_trainer.py mcmc --disable_viewer --data_factor "$DATA_FACTOR" \
        --normal_consistency_loss \
        --render_traj_path "$RENDER_TRAJ_PATH" \
        --data_dir "$SCENE_DIR/$SCENE/" \
        --result_dir "$RESULT_DIR/$SCENE/"
done

echo "Summarizing results"
# $SCENE_LIST stays unquoted so each scene name is passed as a separate --scenes argument.
python benchmarks/summarize_stats.py --results_dir "$RESULT_DIR" --scenes $SCENE_LIST --stage val
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,8 @@
import tyro


def main(results_dir: str, scenes: List[str]):
print("scenes:", scenes)
stage = "compress"

summary = defaultdict(list)
def main(results_dir: str, scenes: List[str], stage: str = "val"):
stats_all = defaultdict(list)
for scene in scenes:
scene_dir = os.path.join(results_dir, scene)

Expand All @@ -25,15 +22,20 @@ def main(results_dir: str, scenes: List[str]):
f"stat -c%s {zip_path}", shell=True, capture_output=True
)
size = int(out.stdout)
summary["size"].append(size)
stats_all["size"].append(size)

with open(os.path.join(scene_dir, f"stats/{stage}_step29999.json"), "r") as f:
stats = json.load(f)
for k, v in stats.items():
summary[k].append(v)
stats_all[k].append(v)

summary = {"scenes": scenes}
for k, v in stats_all.items():
summary[k] = np.mean(v)
print(summary)

for k, v in summary.items():
print(k, np.mean(v))
with open(os.path.join(results_dir, f"{stage}_summary.json"), "w") as f:
json.dump(summary, f, indent=2)


if __name__ == "__main__":
Expand Down
69 changes: 56 additions & 13 deletions examples/simple_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,13 @@ class Config:
# Weight for depth loss
depth_lambda: float = 1e-2

# Enable normal consistency loss. (experimental)
normal_consistency_loss: bool = False
# Weight for normal consistency loss
normal_consistency_lambda: float = 0.05
# Start applying normal consistency loss after this iteration
normal_consistency_start_iter: int = 7000

# Dump information to tensorboard every this steps
tb_every: int = 100
# Save training images to tensorboard
Expand Down Expand Up @@ -273,6 +280,12 @@ def __init__(
self.world_size = world_size
self.device = f"cuda:{local_rank}"

self.render_mode = "RGB"
if cfg.depth_loss:
self.render_mode = "RGB+ED"
if cfg.normal_consistency_loss:
self.render_mode = "RGB+ED+N"

# Where to dump results.
os.makedirs(cfg.result_dir, exist_ok=True)

Expand Down Expand Up @@ -587,13 +600,10 @@ def train(self):
near_plane=cfg.near_plane,
far_plane=cfg.far_plane,
image_ids=image_ids,
render_mode="RGB+ED" if cfg.depth_loss else "RGB",
render_mode=self.render_mode,
masks=masks,
)
if renders.shape[-1] == 4:
colors, depths = renders[..., 0:3], renders[..., 3:4]
else:
colors, depths = renders, None
colors = renders[..., :3]

if cfg.use_bilateral_grid:
grid_y, grid_x = torch.meshgrid(
Expand Down Expand Up @@ -623,6 +633,7 @@ def train(self):
)
loss = l1loss * (1.0 - cfg.ssim_lambda) + ssimloss * cfg.ssim_lambda
if cfg.depth_loss:
depths = renders[..., -1:]
# query depths from depth map
points = torch.stack(
[
Expand All @@ -641,6 +652,14 @@ def train(self):
disp_gt = 1.0 / depths_gt # [1, M]
depthloss = F.l1_loss(disp, disp_gt) * self.scene_scale
loss += depthloss * cfg.depth_lambda
if cfg.normal_consistency_loss:
normals_rend = info["normals_rend"]
normals_surf = info["normals_surf"]
normalconsistencyloss = (
1 - (normals_rend * normals_surf).sum(dim=-1)
).mean()
if step > cfg.normal_consistency_start_iter:
loss += normalconsistencyloss * cfg.normal_consistency_lambda
if cfg.use_bilateral_grid:
tvloss = 10 * total_variation_loss(self.bil_grids.grids)
loss += tvloss
Expand Down Expand Up @@ -687,6 +706,12 @@ def train(self):
self.writer.add_scalar("train/mem", mem, step)
if cfg.depth_loss:
self.writer.add_scalar("train/depthloss", depthloss.item(), step)
if cfg.normal_consistency_loss:
self.writer.add_scalar(
"train/normalconsistencyloss",
normalconsistencyloss.item(),
step,
)
if cfg.use_bilateral_grid:
self.writer.add_scalar("train/tvloss", tvloss.item(), step)
if cfg.tb_save_image:
Expand Down Expand Up @@ -819,21 +844,31 @@ def eval(self, step: int, stage: str = "val"):

torch.cuda.synchronize()
tic = time.time()
colors, _, _ = self.rasterize_splats(
renders, alphas, info = self.rasterize_splats(
camtoworlds=camtoworlds,
Ks=Ks,
width=width,
height=height,
sh_degree=cfg.sh_degree,
near_plane=cfg.near_plane,
far_plane=cfg.far_plane,
render_mode=self.render_mode,
masks=masks,
) # [1, H, W, 3]
torch.cuda.synchronize()
ellipse_time += time.time() - tic

colors = torch.clamp(colors, 0.0, 1.0)
colors = torch.clamp(renders[..., 0:3], 0.0, 1.0)
canvas_list = [pixels, colors]
if cfg.depth_loss:
depths = renders[..., -1:]
depths = (depths - depths.min()) / (depths.max() - depths.min())
canvas_list.append(depths)
if cfg.normal_consistency_loss:
normals_rend = info["normals_rend"]
normals_surf = info["normals_surf"]
canvas_list.extend([normals_rend * 0.5 + 0.5])
canvas_list.extend([normals_surf * 0.5 + 0.5])

if world_rank == 0:
# write images
Expand Down Expand Up @@ -927,20 +962,28 @@ def render_traj(self, step: int):
camtoworlds = camtoworlds_all[i : i + 1]
Ks = K[None]

renders, _, _ = self.rasterize_splats(
renders, alphas, info = self.rasterize_splats(
camtoworlds=camtoworlds,
Ks=Ks,
width=width,
height=height,
sh_degree=cfg.sh_degree,
near_plane=cfg.near_plane,
far_plane=cfg.far_plane,
render_mode="RGB+ED",
render_mode=self.render_mode,
) # [1, H, W, 4]
colors = torch.clamp(renders[..., 0:3], 0.0, 1.0) # [1, H, W, 3]
depths = renders[..., 3:4] # [1, H, W, 1]
depths = (depths - depths.min()) / (depths.max() - depths.min())
canvas_list = [colors, depths.repeat(1, 1, 1, 3)]

colors = torch.clamp(renders[..., 0:3], 0.0, 1.0)
canvas_list = [colors]
if cfg.depth_loss:
depths = renders[..., -1:]
depths = (depths - depths.min()) / (depths.max() - depths.min())
canvas_list.append(depths)
if cfg.normal_consistency_loss:
normals_rend = info["normals_rend"]
normals_surf = info["normals_surf"]
canvas_list.extend([normals_rend * 0.5 + 0.5])
canvas_list.extend([normals_surf * 0.5 + 0.5])

# write images
canvas = torch.cat(canvas_list, dim=2).squeeze(0).cpu().numpy()
Expand Down
9 changes: 7 additions & 2 deletions gsplat/cuda/_torch_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,8 @@ def _world_to_cam(

def _fully_fused_projection(
means: Tensor, # [N, 3]
covars: Tensor, # [N, 3, 3]
quats: Tensor,
scales: Tensor,
viewmats: Tensor, # [C, 4, 4]
Ks: Tensor, # [C, 3, 3]
width: int,
Expand All @@ -267,6 +268,10 @@ def _fully_fused_projection(
This is a minimal implementation of fully fused version, which has more
arguments. Not all arguments are supported.
"""
covars, _ = _quat_scale_to_covar_preci(quats, scales, triu=False) # [N, 3, 3]
normals = _quat_to_rotmat(quats)[..., 2] # [N, 3]
normals = normals.repeat(viewmats.shape[0], 1, 1) # [C, N, 3]

means_c, covars_c = _world_to_cam(means, covars, viewmats)

if camera_model == "ortho":
Expand Down Expand Up @@ -324,7 +329,7 @@ def _fully_fused_projection(
radius[~inside] = 0.0

radii = radius.int()
return radii, means2d, depths, conics, compensations
return radii, means2d, depths, normals, conics, compensations


@torch.no_grad()
Expand Down
23 changes: 19 additions & 4 deletions gsplat/cuda/_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -780,7 +780,7 @@ def forward(
)

# "covars" and {"quats", "scales"} are mutually exclusive
radii, means2d, depths, conics, compensations = _make_lazy_cuda_func(
radii, means2d, depths, normals, conics, compensations = _make_lazy_cuda_func(
"fully_fused_projection_fwd"
)(
means,
Expand Down Expand Up @@ -808,10 +808,12 @@ def forward(
ctx.eps2d = eps2d
ctx.camera_model_type = camera_model_type

return radii, means2d, depths, conics, compensations
return radii, means2d, depths, normals, conics, compensations

@staticmethod
def backward(ctx, v_radii, v_means2d, v_depths, v_conics, v_compensations):
def backward(
ctx, v_radii, v_means2d, v_depths, v_normals, v_conics, v_compensations
):
(
means,
covars,
Expand Down Expand Up @@ -847,6 +849,7 @@ def backward(ctx, v_radii, v_means2d, v_depths, v_conics, v_compensations):
compensations,
v_means2d.contiguous(),
v_depths.contiguous(),
v_normals.contiguous(),
v_conics.contiguous(),
v_compensations,
ctx.needs_input_grad[4], # viewmats_requires_grad
Expand Down Expand Up @@ -1043,6 +1046,7 @@ def forward(
radii,
means2d,
depths,
normals,
conics,
compensations,
) = _make_lazy_cuda_func("fully_fused_projection_packed_fwd")(
Expand Down Expand Up @@ -1081,7 +1085,16 @@ def forward(
ctx.sparse_grad = sparse_grad
ctx.camera_model_type = camera_model_type

return camera_ids, gaussian_ids, radii, means2d, depths, conics, compensations
return (
camera_ids,
gaussian_ids,
radii,
means2d,
depths,
normals,
conics,
compensations,
)

@staticmethod
def backward(
Expand All @@ -1091,6 +1104,7 @@ def backward(
v_radii,
v_means2d,
v_depths,
v_normals,
v_conics,
v_compensations,
):
Expand Down Expand Up @@ -1133,6 +1147,7 @@ def backward(
compensations,
v_means2d.contiguous(),
v_depths.contiguous(),
v_normals.contiguous(),
v_conics.contiguous(),
v_compensations,
ctx.needs_input_grad[4], # viewmats_requires_grad
Expand Down
Loading
Loading