Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement Fisheye-GS. #398

Merged
merged 41 commits into from
Sep 26, 2024
Merged
Show file tree
Hide file tree
Changes from 30 commits
Commits
Show all changes
41 commits
Select commit Hold shift + click to select a range
27fa353
baseline
jefequien Sep 9, 2024
96089a1
fisheye forward works
jefequien Sep 9, 2024
30b4579
torch implementation of fisheye projection
jefequien Sep 9, 2024
c4651c0
test basic
jefequien Sep 9, 2024
712ae95
close 0.3%
jefequien Sep 10, 2024
01daadb
19 mismatched
jefequien Sep 10, 2024
2de03b9
pass tests
jefequien Sep 10, 2024
9c61743
comment out
jefequien Sep 10, 2024
6347c49
crashing
jefequien Sep 10, 2024
6e56adf
remove dead code
jefequien Sep 10, 2024
cb08427
reduce diff
jefequien Sep 10, 2024
22060d5
video
jefequien Sep 10, 2024
139f7fe
distortion not handled correctly
jefequien Sep 10, 2024
fde16f8
test remap
jefequien Sep 10, 2024
37f40a5
remove hardcoded roi
jefequien Sep 10, 2024
98a7819
cleanup tests
jefequien Sep 10, 2024
67baaae
fix bug
jefequien Sep 10, 2024
e41679d
bug
jefequien Sep 10, 2024
1ed34f0
edit imsize_dict
jefequien Sep 11, 2024
4434f09
format c++
jefequien Sep 11, 2024
3948fc9
T
jefequien Sep 11, 2024
ce98242
use mask
jefequien Sep 11, 2024
66128ca
remove test_remap
jefequien Sep 11, 2024
419c9e1
mask roi
jefequien Sep 11, 2024
4211157
scripts
jefequien Sep 11, 2024
148c218
reduce diff
jefequien Sep 11, 2024
6adbf6d
minor
jefequien Sep 11, 2024
4663665
Merge branch 'main' into jeff/fisheye
jefequien Sep 12, 2024
c2e7ada
weird ortho bug
jefequien Sep 12, 2024
412ea62
Merge branch 'main' into jeff/fisheye
jefequien Sep 13, 2024
a31e65a
vectorize
jefequien Sep 17, 2024
41aa398
ellipse
jefequien Sep 17, 2024
7f2972a
unify python side camera_model
jefequien Sep 17, 2024
593769c
fisheye packed mode
jefequien Sep 17, 2024
b133704
Merge branch 'main' into jeff/fisheye
jefequien Sep 17, 2024
07d2087
cuda enum
jefequien Sep 18, 2024
e737073
use c++ enum
jefequien Sep 18, 2024
fa23297
download dataset
jefequien Sep 18, 2024
2065224
refactor dataset download to download zipnerf
jefequien Sep 18, 2024
6fe028d
use lists
jefequien Sep 18, 2024
e6c19d6
use bilateral grid as default for zipnerf
jefequien Sep 19, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions examples/benchmarks/fisheye/mcmc_zipnerf.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Benchmark script: runs the MCMC trainer with the fisheye camera model on
# each scene listed in SCENE_LIST, reading from SCENE_DIR and writing to
# RESULT_DIR.
SCENE_DIR="data/zipnerf"
SCENE_LIST="berlin london nyc alameda"
DATA_FACTOR=2

RESULT_DIR="results/benchmark_mcmc_2M_zipnerf"
CAP_MAX=2000000

# Alternative configuration with a 4M cap; uncomment both lines to use it.
# RESULT_DIR="results/benchmark_mcmc_4M_zipnerf"
# CAP_MAX=4000000

# SCENE_LIST is deliberately left unquoted here so that it word-splits into
# one loop iteration per scene name.
for SCENE in $SCENE_LIST;
do
    echo "Running $SCENE"

    # train and eval
    # Expansions are quoted so paths containing spaces or glob characters are
    # passed to the trainer as single arguments.
    CUDA_VISIBLE_DEVICES=0 python simple_trainer.py mcmc --disable_viewer --data_factor "$DATA_FACTOR" \
        --strategy.cap-max "$CAP_MAX" \
        --opacity_reg 0.001 \
        --camera_model fisheye \
        --data_dir "$SCENE_DIR/$SCENE/" \
        --result_dir "$RESULT_DIR/$SCENE/"
done
22 changes: 22 additions & 0 deletions examples/benchmarks/fisheye/mcmc_zipnerf_undistort.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Benchmark script: runs the MCMC trainer with the pinhole camera model on
# each scene listed in SCENE_LIST, reading from SCENE_DIR and writing to
# RESULT_DIR.
SCENE_DIR="data/zipnerf_undistort"
SCENE_LIST="berlin london nyc alameda"
DATA_FACTOR=2

RESULT_DIR="results/benchmark_mcmc_2M_zipnerf_undistort"
CAP_MAX=2000000

# Alternative configuration with a 4M cap; uncomment both lines to use it.
# RESULT_DIR="results/benchmark_mcmc_4M_zipnerf_undistort"
# CAP_MAX=4000000

# SCENE_LIST is deliberately left unquoted here so that it word-splits into
# one loop iteration per scene name.
for SCENE in $SCENE_LIST;
do
    echo "Running $SCENE"

    # train and eval
    # Expansions are quoted so paths containing spaces or glob characters are
    # passed to the trainer as single arguments.
    CUDA_VISIBLE_DEVICES=0 python simple_trainer.py mcmc --disable_viewer --data_factor "$DATA_FACTOR" \
        --strategy.cap-max "$CAP_MAX" \
        --opacity_reg 0.001 \
        --camera_model pinhole \
        --data_dir "$SCENE_DIR/$SCENE/" \
        --result_dir "$RESULT_DIR/$SCENE/"
done
2 changes: 1 addition & 1 deletion examples/benchmarks/mcmc.sh
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ do
CUDA_VISIBLE_DEVICES=0 python simple_trainer.py mcmc --eval_steps -1 --disable_viewer --data_factor $DATA_FACTOR \
--strategy.cap-max $CAP_MAX \
--render_traj_path $RENDER_TRAJ_PATH \
--data_dir data/360_v2/$SCENE/ \
--data_dir $SCENE_DIR/$SCENE/ \
--result_dir $RESULT_DIR/$SCENE/

# run eval and render
Expand Down
74 changes: 62 additions & 12 deletions examples/datasets/colmap.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
import json
from typing import Any, Dict, List, Optional
from typing_extensions import assert_never

import cv2
import imageio.v2 as imageio
Expand Down Expand Up @@ -59,6 +60,7 @@ def __init__(
Ks_dict = dict()
params_dict = dict()
imsize_dict = dict() # width, height
mask_dict = dict()
bottom = np.array([0, 0, 0, 1]).reshape(1, 4)
for k in imdata:
im = imdata[k]
Expand Down Expand Up @@ -99,14 +101,12 @@ def __init__(
params = np.array([cam.k1, cam.k2, cam.k3, cam.k4], dtype=np.float32)
camtype = "fisheye"
assert (
camtype == "perspective"
), f"Only support perspective camera model, got {type_}"
camtype == "perspective" or camtype == "fisheye"
), f"Only perspective and fisheye cameras are supported, got {type_}"

params_dict[camera_id] = params

# image size
imsize_dict[camera_id] = (cam.width // factor, cam.height // factor)

mask_dict[camera_id] = None
print(
f"[Parser] {len(imdata)} images, taken by {len(set(camera_ids))} cameras."
)
Expand Down Expand Up @@ -203,6 +203,7 @@ def __init__(
self.Ks_dict = Ks_dict # Dict of camera_id -> K
self.params_dict = params_dict # Dict of camera_id -> params
self.imsize_dict = imsize_dict # Dict of camera_id -> (width, height)
self.mask_dict = mask_dict # Dict of camera_id -> mask
self.points = points # np.ndarray, (num_points, 3)
self.points_err = points_err # np.ndarray, (num_points,)
self.points_rgb = points_rgb # np.ndarray, (num_points, 3)
Expand Down Expand Up @@ -236,16 +237,62 @@ def __init__(
), f"Missing params for camera {camera_id}"
K = self.Ks_dict[camera_id]
width, height = self.imsize_dict[camera_id]
K_undist, roi_undist = cv2.getOptimalNewCameraMatrix(
K, params, (width, height), 0
)
mapx, mapy = cv2.initUndistortRectifyMap(
K, params, None, K_undist, (width, height), cv2.CV_32FC1
)
self.Ks_dict[camera_id] = K_undist

if camtype == "perspective":
K_undist, roi_undist = cv2.getOptimalNewCameraMatrix(
K, params, (width, height), 0
)
mapx, mapy = cv2.initUndistortRectifyMap(
K, params, None, K_undist, (width, height), cv2.CV_32FC1
)
mask = None
elif camtype == "fisheye":
fx = K[0, 0]
fy = K[1, 1]
cx = K[0, 2]
cy = K[1, 2]
mapx = np.zeros((height, width), dtype=np.float32)
mapy = np.zeros((height, width), dtype=np.float32)
for i in range(0, width):
for j in range(0, height):
x = float(i)
y = float(j)
x1 = (x - cx) / fx
y1 = (y - cy) / fy
theta = np.sqrt(x1**2 + y1**2)
r = (
1.0
+ params[0] * theta**2
+ params[1] * theta**4
+ params[2] * theta**6
+ params[3] * theta**8
)
x2 = fx * x1 * r + width // 2
y2 = fy * y1 * r + height // 2
mapx[j, i] = x2
mapy[j, i] = y2
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

batchify this?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done


mask = np.logical_and(
np.logical_and(mapx > 0, mapy > 0),
np.logical_and(mapx < width - 1, mapy < height - 1),
)
y_indices, x_indices = np.nonzero(mask)
y_min, y_max = y_indices.min(), y_indices.max() + 1
x_min, x_max = x_indices.min(), x_indices.max() + 1
mask = mask[y_min:y_max, x_min:x_max]
K_undist = K.copy()
K_undist[0, 2] -= x_min
K_undist[1, 2] -= y_min
roi_undist = [x_min, y_min, x_max - x_min, y_max - y_min]
else:
assert_never(camtype)

self.mapx_dict[camera_id] = mapx
self.mapy_dict[camera_id] = mapy
self.Ks_dict[camera_id] = K_undist
self.roi_undist_dict[camera_id] = roi_undist
self.imsize_dict[camera_id] = (roi_undist[2], roi_undist[3])
self.mask_dict[camera_id] = mask

# size of the scene measured by cameras
camera_locations = camtoworlds[:, :3, 3]
Expand Down Expand Up @@ -284,6 +331,7 @@ def __getitem__(self, item: int) -> Dict[str, Any]:
K = self.parser.Ks_dict[camera_id].copy() # undistorted K
params = self.parser.params_dict[camera_id]
camtoworlds = self.parser.camtoworlds[index]
mask = self.parser.mask_dict[camera_id]

if len(params) > 0:
# Images are distorted. Undistort them.
Expand All @@ -310,6 +358,8 @@ def __getitem__(self, item: int) -> Dict[str, Any]:
"image": torch.from_numpy(image).float(),
"image_id": item, # the index of the image in the dataset
}
if mask is not None:
data["mask"] = torch.from_numpy(mask).bool()

if self.load_depths:
# projected points to image plane to get depths
Expand Down
22 changes: 14 additions & 8 deletions examples/simple_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@ class Config:
global_scale: float = 1.0
# Normalize the world space
normalize_world_space: bool = True
# Camera model
camera_model: Literal["pinhole", "fisheye"] = "pinhole"

# Port for the viewer server
port: int = 8080
Expand Down Expand Up @@ -432,6 +434,7 @@ def rasterize_splats(
Ks: Tensor,
width: int,
height: int,
masks: Optional[Tensor] = None,
**kwargs,
) -> Tuple[Tensor, Tensor, Dict]:
means = self.splats["means"] # [N, 3]
Expand Down Expand Up @@ -474,8 +477,11 @@ def rasterize_splats(
sparse_grad=self.cfg.sparse_grad,
rasterize_mode=rasterize_mode,
distributed=self.world_size > 1,
fisheye=self.cfg.camera_model == "fisheye",
**kwargs,
)
if masks is not None:
render_colors[~masks] = 0
return render_colors, render_alphas, info

def train(self):
Expand Down Expand Up @@ -555,6 +561,7 @@ def train(self):
pixels.shape[0] * pixels.shape[1] * pixels.shape[2]
)
image_ids = data["image_id"].to(device)
masks = data["mask"].to(device) if "mask" in data else None # [1, H, W]
if cfg.depth_loss:
points = data["points"].to(device) # [1, M, 2]
depths_gt = data["depths"].to(device) # [1, M]
Expand All @@ -581,6 +588,7 @@ def train(self):
far_plane=cfg.far_plane,
image_ids=image_ids,
render_mode="RGB+ED" if cfg.depth_loss else "RGB",
masks=masks,
)
if renders.shape[-1] == 4:
colors, depths = renders[..., 0:3], renders[..., 3:4]
Expand Down Expand Up @@ -806,6 +814,7 @@ def eval(self, step: int, stage: str = "val"):
camtoworlds = data["camtoworld"].to(device)
Ks = data["K"].to(device)
pixels = data["image"].to(device) / 255.0
masks = data["mask"].to(device) if "mask" in data else None
height, width = pixels.shape[1:3]

torch.cuda.synchronize()
Expand All @@ -818,6 +827,7 @@ def eval(self, step: int, stage: str = "val"):
sh_degree=cfg.sh_degree,
near_plane=cfg.near_plane,
far_plane=cfg.far_plane,
masks=masks,
) # [1, H, W, 3]
torch.cuda.synchronize()
ellipse_time += time.time() - tic
Expand Down Expand Up @@ -909,7 +919,10 @@ def render_traj(self, step: int):
K = torch.from_numpy(list(self.parser.Ks_dict.values())[0]).float().to(device)
width, height = list(self.parser.imsize_dict.values())[0]

canvas_all = []
# save to video
video_dir = f"{cfg.result_dir}/videos"
os.makedirs(video_dir, exist_ok=True)
writer = imageio.get_writer(f"{video_dir}/traj_{step}.mp4", fps=30)
for i in tqdm.trange(len(camtoworlds_all), desc="Rendering trajectory"):
camtoworlds = camtoworlds_all[i : i + 1]
Ks = K[None]
Expand All @@ -932,13 +945,6 @@ def render_traj(self, step: int):
# write images
canvas = torch.cat(canvas_list, dim=2).squeeze(0).cpu().numpy()
canvas = (canvas * 255).astype(np.uint8)
canvas_all.append(canvas)

# save to video
video_dir = f"{cfg.result_dir}/videos"
os.makedirs(video_dir, exist_ok=True)
writer = imageio.get_writer(f"{video_dir}/traj_{step}.mp4", fps=30)
for canvas in canvas_all:
writer.append_data(canvas)
writer.close()
print(f"Video saved to {video_dir}/traj_{step}.mp4")
Expand Down
68 changes: 68 additions & 0 deletions gsplat/cuda/_torch_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,71 @@ def _persp_proj(
return means2d, cov2d # [C, N, 2], [C, N, 2, 2]


def _fisheye_proj(
    means: Tensor,  # [C, N, 3]
    covars: Tensor,  # [C, N, 3, 3]
    Ks: Tensor,  # [C, 3, 3]
    width: int,
    height: int,
) -> Tuple[Tensor, Tensor]:
    """Project camera-space 3D Gaussians through a fisheye camera (PyTorch).

    Each mean is mapped to ``f * theta * (x, y) / sqrt(x^2 + y^2) + c`` where
    ``theta`` is the angle between the point and the +z axis. The covariance
    is pushed through the 2x3 Jacobian of that mapping: ``J @ covars @ J^T``.

    Args:
        means: Gaussian means in camera coordinate system. [C, N, 3].
        covars: Gaussian covariances in camera coordinate system. [C, N, 3, 3].
        Ks: Camera intrinsics. [C, 3, 3].
        width: Image width. Not used by this function.
        height: Image height. Not used by this function.

    Returns:
        A tuple:

        - **means2d**: Projected means. [C, N, 2].
        - **cov2d**: Projected covariances. [C, N, 2, 2].
    """
    num_cams, num_gauss, _ = means.shape
    x, y, z = torch.unbind(means, dim=-1)  # each [C, N]

    # Intrinsics, kept as [C, 1] so they broadcast over the N Gaussians.
    fx = Ks[..., 0, 0].unsqueeze(-1)
    fy = Ks[..., 1, 1].unsqueeze(-1)
    cx = Ks[..., 0, 2].unsqueeze(-1)
    cy = Ks[..., 1, 2].unsqueeze(-1)

    # Small offset that keeps the divisions finite for points on the optical axis.
    eps = 1e-7
    xy_len = (x**2 + y**2) ** 0.5 + eps
    theta = torch.atan2(xy_len, z + eps)

    # Projected means.
    u = x * fx * theta / xy_len + cx
    v = y * fy * theta / xy_len + cy
    means2d = torch.stack([u, v], dim=-1)

    # Shared terms of the Jacobian.
    x2 = x * x + eps
    y2 = y * y
    xy = x * y
    x2y2 = x2 + y2
    x2y2z2_inv = 1.0 / (x2y2 + z * z)
    # NOTE: the angle here is taken against z itself, without the eps offset
    # that `theta` above applies to z.
    b = torch.atan2(xy_len, z) / xy_len / x2y2
    a = z * x2y2z2_inv / (x2y2)
    a_minus_b = a - b

    # Row-major entries of the 2x3 Jacobian d(u, v) / d(x, y, z).
    jac_entries = [
        fx * (x2 * a + y2 * b),
        fx * xy * a_minus_b,
        -fx * x * x2y2z2_inv,
        fy * xy * a_minus_b,
        fy * (y2 * a + x2 * b),
        -fy * y * x2y2z2_inv,
    ]
    J = torch.stack(jac_entries, dim=-1).reshape(num_cams, num_gauss, 2, 3)

    cov2d = torch.einsum("...ij,...jk,...kl->...il", J, covars, J.transpose(-1, -2))
    return means2d, cov2d  # [C, N, 2], [C, N, 2, 2]


def _ortho_proj(
means: Tensor, # [C, N, 3]
covars: Tensor, # [C, N, 3, 3]
Expand Down Expand Up @@ -193,6 +258,7 @@ def _fully_fused_projection(
far_plane: float = 1e10,
calc_compensations: bool = False,
ortho: bool = False,
fisheye: bool = False,
) -> Tuple[Tensor, Tensor, Tensor, Tensor, Optional[Tensor]]:
"""PyTorch implementation of `gsplat.cuda._wrapper.fully_fused_projection()`

Expand All @@ -205,6 +271,8 @@ def _fully_fused_projection(

if ortho:
means2d, covars2d = _ortho_proj(means_c, covars_c, Ks, width, height)
elif fisheye:
means2d, covars2d = _fisheye_proj(means_c, covars_c, Ks, width, height)
else:
means2d, covars2d = _persp_proj(means_c, covars_c, Ks, width, height)

Expand Down
Loading
Loading