Skip to content

Commit

Permalink
finish map_gaussians
Browse files Browse the repository at this point in the history
  • Loading branch information
Zhuoyang-Pan committed Oct 2, 2023
1 parent e1f1cdd commit 69e86a7
Show file tree
Hide file tree
Showing 5 changed files with 117 additions and 2 deletions.
37 changes: 36 additions & 1 deletion diff_rast/_torch_impl.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Pure PyTorch implementations of various functions"""

import torch
import torch.nn.functional as F
import struct
from jaxtyping import Float
from torch import Tensor

Expand Down Expand Up @@ -291,3 +291,38 @@ def project_gaussians_forward(
conics = conic

return cov3d, xys, depths, radii, conics, num_tiles_hit, mask

def map_gaussian_to_intersects(
num_points,
xys,
depths,
radii,
cum_tiles_hit,
tile_bounds
):
num_intersects = cum_tiles_hit[-1]
isect_ids = torch.zeros(num_intersects, dtype=torch.int64, device=xys.device)
gaussian_ids = torch.zeros(num_intersects, dtype=torch.int32, device=xys.device)

for idx in range(num_points):
if radii[idx] <= 0: break

tile_min, tile_max = get_tile_bbox(xys[idx], radii[idx], tile_bounds)

cur_idx = 0 if idx == 0 else cum_tiles_hit[idx - 1]

# Get raw byte representation of the float value at the given index
raw_bytes = struct.pack('f', depths[idx])

# Interpret those bytes as an int32_t
depth_id_n = struct.unpack('i', raw_bytes)[0]

for i in range(tile_min[1], tile_max[1]):
for j in range(tile_min[0], tile_max[0]):
tile_id = i * tile_bounds[0] + j
isect_ids[cur_idx] = (tile_id << 32) | depth_id_n
gaussian_ids[cur_idx] = idx
cur_idx += 1

return isect_ids, gaussian_ids

1 change: 1 addition & 0 deletions diff_rast/cuda/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,4 @@ def call_cuda(*args, **kwargs):
compute_sh_forward = _make_lazy_cuda_func("compute_sh_forward")
compute_sh_backward = _make_lazy_cuda_func("compute_sh_backward")
compute_cumulative_intersects = _make_lazy_cuda_func("compute_cumulative_intersects")
map_gaussian_to_intersects = _make_lazy_cuda_func("map_gaussian_to_intersects")
2 changes: 1 addition & 1 deletion diff_rast/cuda/csrc/bindings.cu
Original file line number Diff line number Diff line change
Expand Up @@ -325,5 +325,5 @@ map_gaussian_to_intersects_tensor(
gaussian_ids_unsorted.contiguous().data_ptr<int32_t>()
);

return std::make_tuple(gaussian_ids_unsorted, isect_ids_unsorted);
return std::make_tuple(isect_ids_unsorted, gaussian_ids_unsorted);
}
1 change: 1 addition & 0 deletions diff_rast/cuda/csrc/ext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,5 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("compute_sh_forward", &compute_sh_forward_tensor);
m.def("compute_sh_backward", &compute_sh_backward_tensor);
m.def("compute_cumulative_intersects", &compute_cumulative_intersects_tensor);
m.def("map_gaussian_to_intersects", &map_gaussian_to_intersects_tensor);
}
78 changes: 78 additions & 0 deletions tests/test_map_gaussians.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import pytest
import torch


device = torch.device("cuda:0")


@pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device")
def test_map_gaussians():
from diff_rast import _torch_impl
import diff_rast.cuda as _C

torch.manual_seed(42)

num_points = 100
means3d = torch.randn((num_points, 3), device=device, requires_grad=True)
scales = torch.randn((num_points, 3), device=device)
glob_scale = 0.3
quats = torch.randn((num_points, 4), device=device)
quats /= torch.linalg.norm(quats, dim=-1, keepdim=True)
viewmat = torch.eye(4, device=device)
projmat = torch.eye(4, device=device)
fx, fy = 3.0, 3.0
H, W = 512, 512
clip_thresh = 0.01

BLOCK_X, BLOCK_Y = 16, 16
tile_bounds = (W + BLOCK_X - 1) // BLOCK_X, (H + BLOCK_Y - 1) // BLOCK_Y, 1

(
_cov3d,
_xys,
_depths,
_radii,
_conics,
_num_tiles_hit,
_masks,
) = _torch_impl.project_gaussians_forward(
means3d,
scales,
glob_scale,
quats,
viewmat,
projmat,
fx,
fy,
(H, W),
tile_bounds,
clip_thresh,
)

_cum_tiles_hit = torch.cumsum(_num_tiles_hit, dim=0, dtype=torch.int32)
_depths = _depths.contiguous()

isect_ids, gaussian_ids = _C.map_gaussian_to_intersects(
num_points,
_xys,
_depths,
_radii,
_cum_tiles_hit,
tile_bounds
)

_isect_ids, _gaussian_ids = _torch_impl.map_gaussian_to_intersects(
num_points,
_xys,
_depths,
_radii,
_cum_tiles_hit,
tile_bounds
)

torch.testing.assert_close(gaussian_ids, _gaussian_ids)
torch.testing.assert_close(isect_ids, _isect_ids)


if __name__ == "__main__":
test_map_gaussians()

0 comments on commit 69e86a7

Please sign in to comment.