Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement Fisheye-GS. #398

Merged
merged 41 commits into from
Sep 26, 2024
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
Show all changes
41 commits
Select commit Hold shift + click to select a range
27fa353
baseline
jefequien Sep 9, 2024
96089a1
fisheye forward works
jefequien Sep 9, 2024
30b4579
torch implementation of fisheye projection
jefequien Sep 9, 2024
c4651c0
test basic
jefequien Sep 9, 2024
712ae95
close 0.3%
jefequien Sep 10, 2024
01daadb
19 mismatched
jefequien Sep 10, 2024
2de03b9
pass tests
jefequien Sep 10, 2024
9c61743
comment out
jefequien Sep 10, 2024
6347c49
crashing
jefequien Sep 10, 2024
6e56adf
remove dead code
jefequien Sep 10, 2024
cb08427
reduce diff
jefequien Sep 10, 2024
22060d5
video
jefequien Sep 10, 2024
139f7fe
distortion not handled correctly
jefequien Sep 10, 2024
fde16f8
test remap
jefequien Sep 10, 2024
37f40a5
remove hardcoded roi
jefequien Sep 10, 2024
98a7819
cleanup tests
jefequien Sep 10, 2024
67baaae
fix bug
jefequien Sep 10, 2024
e41679d
bug
jefequien Sep 10, 2024
1ed34f0
edit imsize_dict
jefequien Sep 11, 2024
4434f09
format c++
jefequien Sep 11, 2024
3948fc9
T
jefequien Sep 11, 2024
ce98242
use mask
jefequien Sep 11, 2024
66128ca
remove test_remap
jefequien Sep 11, 2024
419c9e1
mask roi
jefequien Sep 11, 2024
4211157
scripts
jefequien Sep 11, 2024
148c218
reduce diff
jefequien Sep 11, 2024
6adbf6d
minor
jefequien Sep 11, 2024
4663665
Merge branch 'main' into jeff/fisheye
jefequien Sep 12, 2024
c2e7ada
weird ortho bug
jefequien Sep 12, 2024
412ea62
Merge branch 'main' into jeff/fisheye
jefequien Sep 13, 2024
a31e65a
vectorize
jefequien Sep 17, 2024
41aa398
ellipse
jefequien Sep 17, 2024
7f2972a
unify python side camera_model
jefequien Sep 17, 2024
593769c
fisheye packed mode
jefequien Sep 17, 2024
b133704
Merge branch 'main' into jeff/fisheye
jefequien Sep 17, 2024
07d2087
cuda enum
jefequien Sep 18, 2024
e737073
use c++ enum
jefequien Sep 18, 2024
fa23297
download dataset
jefequien Sep 18, 2024
2065224
refactor dataset download to download zipnerf
jefequien Sep 18, 2024
6fe028d
use lists
jefequien Sep 18, 2024
e6c19d6
use bilateral grid as default for zipnerf
jefequien Sep 19, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 60 additions & 0 deletions examples/benchmarks/mcmc_zipnerf.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# Benchmark MCMC training on the ZipNeRF scenes, then print the stats files
# each run wrote under $RESULT_DIR/<scene>/stats/.
#
# Alternative configuration (undistorted images with a pinhole camera);
# swap it in for the fisheye block below to benchmark that setup instead:
# SCENE_DIR="data/zipnerf/undistort"
# RESULT_DIR="results/benchmark_zipnerf/undistort"
# CAMERA_MODEL="pinhole"

SCENE_DIR="data/zipnerf/fisheye"
RESULT_DIR="results/benchmark_zipnerf/fisheye"
CAMERA_MODEL="fisheye"
# Space-separated on purpose: the loops below rely on word splitting.
SCENE_LIST="berlin alameda london nyc"
RENDER_TRAJ_PATH="interp"

CAP_MAX=2000000
DATA_FACTOR=4

for SCENE in $SCENE_LIST;
do
    echo "Running $SCENE"

    # train without eval
    CUDA_VISIBLE_DEVICES=0 python simple_trainer.py mcmc --disable_viewer --data_factor "$DATA_FACTOR" \
        --strategy.cap-max "$CAP_MAX" \
        --opacity_reg 0.001 \
        --camera_model "$CAMERA_MODEL" \
        --render_traj_path "$RENDER_TRAJ_PATH" \
        --data_dir "$SCENE_DIR/$SCENE/" \
        --result_dir "$RESULT_DIR/$SCENE/"

    # run eval and render
    # for CKPT in "$RESULT_DIR/$SCENE"/ckpts/*;
    # do
    #     CUDA_VISIBLE_DEVICES=0 python simple_trainer.py mcmc --disable_viewer --data_factor "$DATA_FACTOR" \
    #         --strategy.cap-max "$CAP_MAX" \
    #         --camera_model "$CAMERA_MODEL" \
    #         --render_traj_path "$RENDER_TRAJ_PATH" \
    #         --data_dir "$SCENE_DIR/$SCENE/" \
    #         --result_dir "$RESULT_DIR/$SCENE/" \
    #         --ckpt "$CKPT"
    # done
done


for SCENE in $SCENE_LIST;
do
    echo "=== Eval Stats ==="

    for STATS in "$RESULT_DIR/$SCENE"/stats/val*.json;
    do
        # An unmatched glob is left as the literal pattern; skip it instead
        # of handing a non-existent path to cat.
        [ -e "$STATS" ] || continue
        echo "$STATS"
        cat "$STATS"
        echo
    done

    echo "=== Train Stats ==="

    for STATS in "$RESULT_DIR/$SCENE"/stats/train*_rank0.json;
    do
        [ -e "$STATS" ] || continue
        echo "$STATS"
        cat "$STATS"
        echo
    done
done
53 changes: 44 additions & 9 deletions examples/datasets/colmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,9 +97,9 @@ def __init__(
elif type_ == 5 or type_ == "OPENCV_FISHEYE":
params = np.array([cam.k1, cam.k2, cam.k3, cam.k4], dtype=np.float32)
camtype = "fisheye"
assert (
camtype == "perspective"
), f"Only support perspective camera model, got {type_}"
# assert (
# camtype == "perspective"
# ), f"Only support perspective camera model, got {type_}"

params_dict[camera_id] = params

Expand Down Expand Up @@ -219,12 +219,47 @@ def __init__(
), f"Missing params for camera {camera_id}"
K = self.Ks_dict[camera_id]
width, height = self.imsize_dict[camera_id]
K_undist, roi_undist = cv2.getOptimalNewCameraMatrix(
K, params, (width, height), 0
)
mapx, mapy = cv2.initUndistortRectifyMap(
K, params, None, K_undist, (width, height), cv2.CV_32FC1
)

if camtype == "perspective":
K_undist, roi_undist = cv2.getOptimalNewCameraMatrix(
K, params, (width, height), 0
)
mapx, mapy = cv2.initUndistortRectifyMap(
K, params, None, K_undist, (width, height), cv2.CV_32FC1
)
elif camtype == "fisheye":
fx = K[0, 0]
fy = K[1, 1]
cx = K[0, 2]
cy = K[1, 2]
mapx = np.zeros((height, width), dtype=np.float32)
mapy = np.zeros((height, width), dtype=np.float32)
for i in range(0, width):
for j in range(0, height):
x = float(i)
y = float(j)
x1 = (x - cx) / fx
y1 = (y - cy) / fy
theta = np.sqrt(x1**2 + y1**2)
r = (
1.0
+ params[0] * theta**2
+ params[1] * theta**4
+ params[2] * theta**6
+ params[3] * theta**8
)
x2 = fx * x1 * r + width // 2
y2 = fy * y1 * r + height // 2
mapx[j, i] = x2
mapy[j, i] = y2
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

batchify this?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done


x_crop, y_crop = (100, 70) # Hardcoded ROI crop
roi_undist = np.array(
[x_crop, y_crop, int(width - 2 * x_crop), int(height - 2 * y_crop)]
)
K_undist = K.copy()
K_undist[0, 2] -= x_crop
K_undist[1, 2] -= y_crop
self.Ks_dict[camera_id] = K_undist
self.mapx_dict[camera_id] = mapx
self.mapy_dict[camera_id] = mapy
Expand Down
17 changes: 8 additions & 9 deletions examples/simple_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ class Config:
global_scale: float = 1.0
# Normalize the world space
normalize_world_space: bool = True
# Camera model
camera_model: str = "pinhole"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

minor note: it would be nice to use Literal["fisheye", "pinhole"] here and elsewhere!


# Port for the viewer server
port: int = 8080
Expand Down Expand Up @@ -441,6 +443,7 @@ def rasterize_splats(
sparse_grad=self.cfg.sparse_grad,
rasterize_mode=rasterize_mode,
distributed=self.world_size > 1,
camera_model=self.cfg.camera_model,
**kwargs,
)
return render_colors, render_alphas, info
Expand Down Expand Up @@ -836,7 +839,10 @@ def render_traj(self, step: int):
K = torch.from_numpy(list(self.parser.Ks_dict.values())[0]).float().to(device)
width, height = list(self.parser.imsize_dict.values())[0]

canvas_all = []
# save to video
video_dir = f"{cfg.result_dir}/videos"
os.makedirs(video_dir, exist_ok=True)
writer = imageio.get_writer(f"{video_dir}/traj_{step}.mp4", fps=30)
for i in tqdm.trange(len(camtoworlds_all), desc="Rendering trajectory"):
camtoworlds = camtoworlds_all[i : i + 1]
Ks = K[None]
Expand All @@ -859,13 +865,6 @@ def render_traj(self, step: int):
# write images
canvas = torch.cat(canvas_list, dim=2).squeeze(0).cpu().numpy()
canvas = (canvas * 255).astype(np.uint8)
canvas_all.append(canvas)

# save to video
video_dir = f"{cfg.result_dir}/videos"
os.makedirs(video_dir, exist_ok=True)
writer = imageio.get_writer(f"{video_dir}/traj_{step}.mp4", fps=30)
for canvas in canvas_all:
writer.append_data(canvas)
writer.close()
print(f"Video saved to {video_dir}/traj_{step}.mp4")
Expand Down Expand Up @@ -927,7 +926,7 @@ def main(local_rank: int, world_rank, world_size: int, cfg: Config):
runner.splats[k].data = torch.cat([ckpt["splats"][k] for ckpt in ckpts])
step = ckpts[0]["step"]
runner.eval(step=step)
runner.render_traj(step=step)
# runner.render_traj(step=step)
if cfg.compression is not None:
runner.run_compression(step=step)
else:
Expand Down
63 changes: 63 additions & 0 deletions examples/test_remap.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import numpy as np
import cv2
import imageio


def init_fisheye_remap(K, params, width, height):
    """Build the per-pixel lookup maps for resampling a fisheye image.

    For every output pixel ``(i, j)`` the maps hold the source coordinates to
    sample; they are shaped and typed so they can be passed as the two map
    arguments of ``cv2.remap``.

    The computation is done for all pixels at once with NumPy arrays rather
    than a Python loop over ``width * height`` pixels.

    Args:
        K: 3x3 intrinsics matrix; ``K[0, 0]``/``K[1, 1]`` are read as the
            focal lengths and ``K[0, 2]``/``K[1, 2]`` as the principal point.
        params: Sequence of four polynomial distortion coefficients, applied
            to the 2nd, 4th, 6th and 8th powers of the normalised radius.
        width: Output image width in pixels.
        height: Output image height in pixels.

    Returns:
        A tuple ``(mapx, mapy)`` of float32 arrays with shape
        ``(height, width)``.
    """
    fx = K[0, 0]
    fy = K[1, 1]
    cx = K[0, 2]
    cy = K[1, 2]

    # cols[j, i] == i and rows[j, i] == j, both with shape (height, width).
    # float64 keeps the intermediate precision of the scalar formulation.
    cols, rows = np.meshgrid(
        np.arange(width, dtype=np.float64),
        np.arange(height, dtype=np.float64),
        indexing="xy",
    )

    # Normalised image-plane coordinates relative to the principal point.
    x1 = (cols - cx) / fx
    y1 = (rows - cy) / fy

    # The normalised radius plays the role of the angle in the distortion
    # polynomial (presumably an equidistant fisheye model -- TODO confirm).
    theta = np.sqrt(x1**2 + y1**2)
    r = (
        1.0
        + params[0] * theta**2
        + params[1] * theta**4
        + params[2] * theta**6
        + params[3] * theta**8
    )

    # NOTE(review): the result is re-centred on the image centre
    # (width // 2, height // 2) rather than on (cx, cy); kept as-is.
    mapx = (fx * x1 * r + width // 2).astype(np.float32)
    mapy = (fy * y1 * r + height // 2).astype(np.float32)
    return mapx, mapy


def main():
    """Remap one hard-coded fisheye image and save full and cropped results.

    Builds the remap tables for a fixed set of intrinsics, derives a crop
    rectangle from where the tables leave the ``[0, width]`` / ``[0, height]``
    range, resamples a sample image with ``cv2.remap`` and writes both the
    full and the cropped result under ``./results``.
    """
    intrinsics = np.array(
        [
            [610.93592297, 0.0, 876.0],
            [0.0, 610.84071973, 584.0],
            [0.0, 0.0, 1.0],
        ]
    )
    dist_coeffs = np.array([0.03699945, 0.00660936, 0.00116909, -0.00038226])
    width = 1752
    height = 1168

    mapx, mapy = init_fisheye_remap(intrinsics, dist_coeffs, width, height)

    # Crop bounds: the last column/row whose lookup falls below zero and the
    # first column/row whose lookup exceeds the image extent.
    cols_under = np.nonzero(mapx < 0)[1]
    left = cols_under.max()
    cols_over = np.nonzero(mapx > width)[1]
    right = cols_over.min()
    rows_under = np.nonzero(mapy < 0)[0]
    top = rows_under.max()
    rows_over = np.nonzero(mapy > height)[0]
    bottom = rows_over.min()
    crop_w = right - left
    crop_h = bottom - top

    # Shift the principal point so the intrinsics refer to the cropped image.
    intrinsics[0, 2] -= left
    intrinsics[1, 2] -= top

    source = imageio.imread("./data/zipnerf/fisheye/berlin/images_4/DSC00040.JPG")
    source = source[..., :3]  # keep the first three channels only
    remapped = cv2.remap(source, mapx, mapy, cv2.INTER_LINEAR)
    imageio.imwrite("./results/test_remap.png", remapped)

    cropped = remapped[top : top + crop_h, left : left + crop_w]
    imageio.imwrite("./results/test_remap_crop.png", cropped)

# Run the remap check only when this file is executed as a script, not when
# it is imported.
if __name__ == "__main__":
    main()
77 changes: 76 additions & 1 deletion gsplat/cuda/_torch_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,71 @@ def _persp_proj(
return means2d, cov2d # [C, N, 2], [C, N, 2, 2]


def _fisheye_proj(
means: Tensor, # [C, N, 3]
covars: Tensor, # [C, N, 3, 3]
Ks: Tensor, # [C, 3, 3]
width: int,
height: int,
) -> Tuple[Tensor, Tensor]:
"""PyTorch implementation of fisheye projection for 3D Gaussians.

Args:
means: Gaussian means in camera coordinate system. [C, N, 3].
covars: Gaussian covariances in camera coordinate system. [C, N, 3, 3].
Ks: Camera intrinsics. [C, 3, 3].
width: Image width.
height: Image height.

Returns:
A tuple:

- **means2d**: Projected means. [C, N, 2].
- **cov2d**: Projected covariances. [C, N, 2, 2].
"""
C, N, _ = means.shape

x, y, z = torch.unbind(means, dim=-1) # [C, N]

fx = Ks[..., 0, 0, None] # [C, 1]
fy = Ks[..., 1, 1, None] # [C, 1]
cx = Ks[..., 0, 2, None] # [C, 1]
cy = Ks[..., 1, 2, None] # [C, 1]

eps = 0.0000001
xy_len = (x**2 + y**2) ** 0.5 + eps
theta = torch.atan2(xy_len, z + eps)
means2d = torch.stack(
[
x * fx * theta / xy_len + cx,
y * fy * theta / xy_len + cy,
],
dim=-1,
)

x2 = x * x + eps
y2 = y * y
xy = x * y
x2y2 = x2 + y2
x2y2z2_inv = 1.0 / (x2y2 + z * z)
b = torch.atan2(xy_len, z) / xy_len / x2y2
a = z * x2y2z2_inv / (x2y2)
J = torch.stack(
[
fx * (x2 * a + y2 * b),
fx * xy * (a - b),
-fx * x * x2y2z2_inv,
fy * xy * (a - b),
fy * (y2 * a + x2 * b),
-fy * y * x2y2z2_inv,
],
dim=-1,
).reshape(C, N, 2, 3)

cov2d = torch.einsum("...ij,...jk,...kl->...il", J, covars, J.transpose(-1, -2))
return means2d, cov2d # [C, N, 2], [C, N, 2, 2]


def _ortho_proj(
means: Tensor, # [C, N, 3]
covars: Tensor, # [C, N, 3, 3]
Expand Down Expand Up @@ -170,7 +235,9 @@ def _world_to_cam(

def _fully_fused_projection(
means: Tensor, # [N, 3]
covars: Tensor, # [N, 3, 3]
covars: Optional[Tensor], # [N, 6] or None
quats: Optional[Tensor], # [N, 4] or None
scales: Optional[Tensor], # [N, 3] or None
viewmats: Tensor, # [C, 4, 4]
Ks: Tensor, # [C, 3, 3]
width: int,
Expand All @@ -180,6 +247,7 @@ def _fully_fused_projection(
far_plane: float = 1e10,
calc_compensations: bool = False,
ortho: bool = False,
fisheye: bool = False,
) -> Tuple[Tensor, Tensor, Tensor, Tensor, Optional[Tensor]]:
"""PyTorch implementation of `gsplat.cuda._wrapper.fully_fused_projection()`

Expand All @@ -188,10 +256,17 @@ def _fully_fused_projection(
This is a minimal implementation of fully fused version, which has more
arguments. Not all arguments are supported.
"""
if covars is None:
covars = _quat_scale_to_covar_preci(
quats, scales, compute_covar=True, compute_preci=False
)[0]

means_c, covars_c = _world_to_cam(means, covars, viewmats)

if ortho:
means2d, covars2d = _ortho_proj(means_c, covars_c, Ks, width, height)
elif fisheye:
means2d, covars2d = _fisheye_proj(means_c, covars_c, Ks, width, height)
else:
means2d, covars2d = _persp_proj(means_c, covars_c, Ks, width, height)

Expand Down
Loading
Loading