From d9dfdd93febe926c22c34632042b27c1bae8bd76 Mon Sep 17 00:00:00 2001 From: Zhuoyang Date: Tue, 3 Oct 2023 03:10:46 -0700 Subject: [PATCH 1/8] add get_tile_bin_edges binding --- diff_rast/cuda/csrc/bindings.cu | 17 +++++++++++++++++ diff_rast/cuda/csrc/bindings.h | 5 +++++ diff_rast/cuda/csrc/ext.cpp | 1 + diff_rast/cuda/csrc/forward.cuh | 4 ++++ 4 files changed, 27 insertions(+) diff --git a/diff_rast/cuda/csrc/bindings.cu b/diff_rast/cuda/csrc/bindings.cu index 9b5cc8753..bc43fc5bc 100644 --- a/diff_rast/cuda/csrc/bindings.cu +++ b/diff_rast/cuda/csrc/bindings.cu @@ -327,3 +327,20 @@ map_gaussian_to_intersects_tensor( return std::make_tuple(isect_ids_unsorted, gaussian_ids_unsorted); } + +torch::Tensor get_tile_bin_edges_tensor( + int num_intersects, + torch::Tensor &isect_ids_sorted +) { + CHECK_INPUT(isect_ids_sorted); + torch::Tensor tile_bins = + torch::zeros({num_intersects, 2}, isect_ids_sorted.options().dtype(torch::kInt32)); + get_tile_bin_edges<<< + (num_intersects + N_THREADS - 1) / N_THREADS, + N_THREADS>>>( + num_intersects, + isect_ids_sorted.contiguous().data_ptr(), + (int2 *)tile_bins.contiguous().data_ptr() + ); + return tile_bins; +} \ No newline at end of file diff --git a/diff_rast/cuda/csrc/bindings.h b/diff_rast/cuda/csrc/bindings.h index 852004824..3ec4f1cff 100644 --- a/diff_rast/cuda/csrc/bindings.h +++ b/diff_rast/cuda/csrc/bindings.h @@ -98,4 +98,9 @@ map_gaussian_to_intersects_tensor( torch::Tensor &radii, torch::Tensor &cum_tiles_hit, const std::tuple tile_bounds +); + +torch::Tensor get_tile_bin_edges_tensor( + int num_intersects, + torch::Tensor &isect_ids_sorted ); \ No newline at end of file diff --git a/diff_rast/cuda/csrc/ext.cpp b/diff_rast/cuda/csrc/ext.cpp index 1df3cd8ae..c733bf55a 100644 --- a/diff_rast/cuda/csrc/ext.cpp +++ b/diff_rast/cuda/csrc/ext.cpp @@ -12,4 +12,5 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("compute_sh_backward", &compute_sh_backward_tensor); m.def("compute_cumulative_intersects", &compute_cumulative_intersects_tensor); m.def("map_gaussian_to_intersects", &map_gaussian_to_intersects_tensor); + m.def("get_tile_bin_edges", &get_tile_bin_edges_tensor); } diff --git a/diff_rast/cuda/csrc/forward.cuh b/diff_rast/cuda/csrc/forward.cuh index 624b8ef50..57b39d7b3 100644 --- a/diff_rast/cuda/csrc/forward.cuh +++ b/diff_rast/cuda/csrc/forward.cuh @@ -88,4 +88,8 @@ __global__ void map_gaussian_to_intersects( const dim3 tile_bounds, int64_t *isect_ids, int32_t *gaussian_ids +); + +__global__ void get_tile_bin_edges( + const int num_intersects, const int64_t *isect_ids_sorted, int2 *tile_bins ); \ No newline at end of file From 6e529826bc65203ba781f20fd6030449a9881a28 Mon Sep 17 00:00:00 2001 From: Zhuoyang Date: Wed, 4 Oct 2023 03:30:31 -0700 Subject: [PATCH 2/8] finish pybinding for get_tile_bin_edges and fix the bug of tile_idx --- diff_rast/_torch_impl.py | 28 +++++++++++- diff_rast/cuda/__init__.py | 1 + diff_rast/cuda/csrc/forward.cu | 2 +- tests/test_get_tile_bin_edges.py | 76 ++++++++++++++++++++++++++++++++ 4 files changed, 105 insertions(+), 2 deletions(-) create mode 100644 tests/test_get_tile_bin_edges.py diff --git a/diff_rast/_torch_impl.py b/diff_rast/_torch_impl.py index f100d1624..a2ea58f5b 100644 --- a/diff_rast/_torch_impl.py +++ b/diff_rast/_torch_impl.py @@ -312,7 +312,7 @@ def map_gaussian_to_intersects( raw_bytes = struct.pack("f", depths[idx]) # Interpret those bytes as an int32_t - depth_id_n = struct.unpack("i", raw_bytes)[0] + depth_id_n = struct.unpack("I", raw_bytes)[0] for i in range(tile_min[1], tile_max[1]): for j in range(tile_min[0], tile_max[0]): @@ -322,3 +322,29 @@ def map_gaussian_to_intersects( cur_idx += 1 return isect_ids, gaussian_ids + + +def get_tile_bin_edges(num_intersects, isect_ids_sorted): + tile_bins = torch.zeros( + (num_intersects, 2), dtype=torch.int32, device=isect_ids_sorted.device + ) + + for idx in range(num_intersects): + + cur_tile_idx = isect_ids_sorted[idx] >> 32 + + if idx == 0: + tile_bins[cur_tile_idx, 0] = 0 + continue + + if idx == num_intersects - 1: + tile_bins[cur_tile_idx, 1] = num_intersects + break + + prev_tile_idx = isect_ids_sorted[idx - 1] >> 32 + + if cur_tile_idx != prev_tile_idx: + tile_bins[prev_tile_idx, 1] = idx + tile_bins[cur_tile_idx, 0] = idx + + return tile_bins diff --git a/diff_rast/cuda/__init__.py b/diff_rast/cuda/__init__.py index d418081cd..e2ca4dc32 100644 --- a/diff_rast/cuda/__init__.py +++ b/diff_rast/cuda/__init__.py @@ -20,3 +20,4 @@ def call_cuda(*args, **kwargs): compute_sh_backward = _make_lazy_cuda_func("compute_sh_backward") compute_cumulative_intersects = _make_lazy_cuda_func("compute_cumulative_intersects") map_gaussian_to_intersects = _make_lazy_cuda_func("map_gaussian_to_intersects") +get_tile_bin_edges = _make_lazy_cuda_func("get_tile_bin_edges") diff --git a/diff_rast/cuda/csrc/forward.cu b/diff_rast/cuda/csrc/forward.cu index aa12b771b..4f25f9b03 100644 --- a/diff_rast/cuda/csrc/forward.cu +++ b/diff_rast/cuda/csrc/forward.cu @@ -166,7 +166,7 @@ __global__ void map_gaussian_to_intersects( // update the intersection info for all tiles this gaussian hits int32_t cur_idx = (idx == 0) ? 0 : cum_tiles_hit[idx - 1]; // printf("point %d starting at %d\n", idx, cur_idx); - int64_t depth_id = (int64_t) * (int32_t *)&(depths[idx]); + u_int64_t depth_id = (u_int64_t) * (u_int32_t *)&(depths[idx]); for (int i = tile_min.y; i < tile_max.y; ++i) { for (int j = tile_min.x; j < tile_max.x; ++j) { // isect_id is tile ID and depth as int32 diff --git a/tests/test_get_tile_bin_edges.py b/tests/test_get_tile_bin_edges.py new file mode 100644 index 000000000..b71676f3e --- /dev/null +++ b/tests/test_get_tile_bin_edges.py @@ -0,0 +1,76 @@ +import pytest +import torch + + +device = torch.device("cuda:0") + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device") +def test_get_tile_bin_edges(): + from diff_rast import _torch_impl + import diff_rast.cuda as _C + + torch.manual_seed(42) + + num_points = 100 + means3d = torch.randn((num_points, 3), device=device, requires_grad=True) + scales = torch.randn((num_points, 3), device=device) + glob_scale = 0.3 + quats = torch.randn((num_points, 4), device=device) + quats /= torch.linalg.norm(quats, dim=-1, keepdim=True) + viewmat = torch.eye(4, device=device) + projmat = torch.eye(4, device=device) + fx, fy = 3.0, 3.0 + H, W = 512, 512 + clip_thresh = 0.01 + + BLOCK_X, BLOCK_Y = 16, 16 + tile_bounds = (W + BLOCK_X - 1) // BLOCK_X, (H + BLOCK_Y - 1) // BLOCK_Y, 1 + + ( + _cov3d, + _xys, + _depths, + _radii, + _conics, + _num_tiles_hit, + _masks, + ) = _torch_impl.project_gaussians_forward( + means3d, + scales, + glob_scale, + quats, + viewmat, + projmat, + fx, + fy, + (H, W), + tile_bounds, + clip_thresh, + ) + + _cum_tiles_hit = torch.cumsum(_num_tiles_hit, dim=0, dtype=torch.int32) + _num_intersects = _cum_tiles_hit[-1].item() + _depths = _depths.contiguous() + + ( + _isect_ids_unsorted, + _gaussian_ids_unsorted, + ) = _torch_impl.map_gaussian_to_intersects( + num_points, _xys, _depths, _radii, _cum_tiles_hit, tile_bounds + ) + + # Sorting isect_ids_unsorted + sorted_values, sorted_indices = torch.sort(_isect_ids_unsorted) + + _isect_ids_sorted = sorted_values + _gaussian_ids_sorted = torch.gather(_gaussian_ids_unsorted, 0, sorted_indices) + + _tile_bins = _torch_impl.get_tile_bin_edges(_num_intersects, _isect_ids_sorted) + tile_bins = _C.get_tile_bin_edges(_num_intersects, _isect_ids_sorted) + + torch.testing.assert_close(_tile_bins, tile_bins) + + +if __name__ == "__main__": + test_get_tile_bin_edges() From 3021ffa01ad47210b7c9e3a95641efb7a03fc676 Mon Sep 17 00:00:00 2001 From: maturk Date: Wed, 4 Oct 2023 23:11:55 +0300 Subject: [PATCH 3/8] use _masks to not get negative depths --- tests/test_get_tile_bin_edges.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/tests/test_get_tile_bin_edges.py b/tests/test_get_tile_bin_edges.py index b71676f3e..fe69c8d45 100644 --- a/tests/test_get_tile_bin_edges.py +++ b/tests/test_get_tile_bin_edges.py @@ -49,6 +49,14 @@ def test_get_tile_bin_edges(): clip_thresh, ) + _xys = _xys[_masks] + _depths = _depths[_masks] + _radii = _radii[_masks] + _conics = _conics[_masks] + _num_tiles_hit = _num_tiles_hit[_masks] + + num_points = num_points - torch.count_nonzero(~_masks).item() + _cum_tiles_hit = torch.cumsum(_num_tiles_hit, dim=0, dtype=torch.int32) _num_intersects = _cum_tiles_hit[-1].item() _depths = _depths.contiguous() From 329fc9fe6bd5305aa63947ea4400fc3672879313 Mon Sep 17 00:00:00 2001 From: Zhuoyang Date: Wed, 4 Oct 2023 16:58:36 -0700 Subject: [PATCH 4/8] finish bin_and_sort_gaussians and fix some bugs of _torch_impl --- diff_rast/_torch_impl.py | 22 ++++++- diff_rast/cuda/__init__.py | 1 + diff_rast/cuda/csrc/bindings.cu | 74 ++++++++++++++++++++-- diff_rast/cuda/csrc/bindings.h | 28 +++++++-- diff_rast/cuda/csrc/ext.cpp | 1 + diff_rast/cuda/csrc/forward.cu | 2 +- tests/test_bin_and_sort_gaussians.py | 92 ++++++++++++++++++++++++++++ tests/test_map_gaussians.py | 12 +++- 8 files changed, 214 insertions(+), 18 deletions(-) create mode 100644 tests/test_bin_and_sort_gaussians.py diff --git a/diff_rast/_torch_impl.py b/diff_rast/_torch_impl.py index a2ea58f5b..52fd2ba87 100644 --- a/diff_rast/_torch_impl.py +++ b/diff_rast/_torch_impl.py @@ -306,13 +306,13 @@ def map_gaussian_to_intersects( tile_min, tile_max = get_tile_bbox(xys[idx], radii[idx], tile_bounds) - cur_idx = 0 if idx == 0 else cum_tiles_hit[idx - 1] + cur_idx = 0 if idx == 0 else cum_tiles_hit[idx - 1].item() # Get raw byte representation of the float value at the given index raw_bytes = struct.pack("f", depths[idx]) # Interpret those bytes as an int32_t - depth_id_n = struct.unpack("I", raw_bytes)[0] + depth_id_n = struct.unpack("i", raw_bytes)[0] for i in range(tile_min[1], tile_max[1]): for j in range(tile_min[0], tile_max[0]): @@ -348,3 +348,21 @@ def get_tile_bin_edges(num_intersects, isect_ids_sorted): tile_bins[cur_tile_idx, 0] = idx return tile_bins + + +def bin_and_sort_gaussians( + num_points, num_intersects, xys, depths, radii, cum_tiles_hit, tile_bounds +): + isect_ids, gaussian_ids = map_gaussian_to_intersects( + num_points, xys, depths, radii, cum_tiles_hit, tile_bounds + ) + + # Sorting isect_ids_unsorted + sorted_values, sorted_indices = torch.sort(isect_ids) + + isect_ids_sorted = sorted_values + gaussian_ids_sorted = torch.gather(gaussian_ids, 0, sorted_indices) + + tile_bins = get_tile_bin_edges(num_intersects, isect_ids_sorted) + + return isect_ids, gaussian_ids, isect_ids_sorted, gaussian_ids_sorted, tile_bins diff --git a/diff_rast/cuda/__init__.py b/diff_rast/cuda/__init__.py index e2ca4dc32..df7dcd614 100644 --- a/diff_rast/cuda/__init__.py +++ b/diff_rast/cuda/__init__.py @@ -21,3 +21,4 @@ def call_cuda(*args, **kwargs): compute_cumulative_intersects = _make_lazy_cuda_func("compute_cumulative_intersects") map_gaussian_to_intersects = _make_lazy_cuda_func("map_gaussian_to_intersects") get_tile_bin_edges = _make_lazy_cuda_func("get_tile_bin_edges") +bin_and_sort_gaussians = _make_lazy_cuda_func("bin_and_sort_gaussians") diff --git a/diff_rast/cuda/csrc/bindings.cu b/diff_rast/cuda/csrc/bindings.cu index 1747fff51..63b0752d1 100644 --- a/diff_rast/cuda/csrc/bindings.cu +++ b/diff_rast/cuda/csrc/bindings.cu @@ -258,7 +258,7 @@ project_gaussians_backward_tensor( } std::tuple compute_cumulative_intersects_tensor( - const int num_points, torch::Tensor &num_tiles_hit + const int num_points, const torch::Tensor &num_tiles_hit ) { // ref: // https://nvlabs.github.io/cub/structcub_1_1_device_scan.html#a9416ac1ea26f9fde669d83ddc883795a @@ -287,10 +287,10 @@ std::tuple compute_cumulative_intersects_tensor( std::tuple map_gaussian_to_intersects_tensor( const int num_points, - torch::Tensor &xys, - torch::Tensor &depths, - torch::Tensor &radii, - torch::Tensor &cum_tiles_hit, + const torch::Tensor &xys, + const torch::Tensor &depths, + const torch::Tensor &radii, + const torch::Tensor &cum_tiles_hit, const std::tuple tile_bounds ) { CHECK_INPUT(xys); @@ -329,7 +329,7 @@ std::tuple map_gaussian_to_intersects_tensor( torch::Tensor get_tile_bin_edges_tensor( int num_intersects, - torch::Tensor &isect_ids_sorted + const torch::Tensor &isect_ids_sorted ) { CHECK_INPUT(isect_ids_sorted); torch::Tensor tile_bins = @@ -342,4 +342,66 @@ torch::Tensor get_tile_bin_edges_tensor( (int2 *)tile_bins.contiguous().data_ptr() ); return tile_bins; +} + +std::tuple< + torch::Tensor, + torch::Tensor, + torch::Tensor, + torch::Tensor, + torch::Tensor> +bin_and_sort_gaussians_tensor( + const int num_points, + const int num_intersects, + const torch::Tensor &xys, + const torch::Tensor &depths, + const torch::Tensor &radii, + const torch::Tensor &cum_tiles_hit, + const std::tuple tile_bounds +){ + CHECK_INPUT(xys); + CHECK_INPUT(depths); + CHECK_INPUT(radii); + CHECK_INPUT(cum_tiles_hit); + + dim3 tile_bounds_dim3; + tile_bounds_dim3.x = std::get<0>(tile_bounds); + tile_bounds_dim3.y = std::get<1>(tile_bounds); + tile_bounds_dim3.z = std::get<2>(tile_bounds); + + torch::Tensor gaussian_ids_unsorted = + torch::zeros({num_intersects}, xys.options().dtype(torch::kInt32)); + torch::Tensor gaussian_ids_sorted = + torch::zeros({num_intersects}, xys.options().dtype(torch::kInt32)); + torch::Tensor isect_ids_unsorted = + torch::zeros({num_intersects}, xys.options().dtype(torch::kInt64)); + torch::Tensor isect_ids_sorted = + torch::zeros({num_intersects}, xys.options().dtype(torch::kInt64)); + torch::Tensor tile_bins = + torch::zeros({num_intersects, 2}, xys.options().dtype(torch::kInt32)); + + bin_and_sort_gaussians( + num_points, + num_intersects, + (float2 *)xys.contiguous().data_ptr(), + depths.contiguous().data_ptr(), + radii.contiguous().data_ptr(), + cum_tiles_hit.contiguous().data_ptr(), + tile_bounds_dim3, + // Outputs. + isect_ids_unsorted.contiguous().data_ptr(), + gaussian_ids_unsorted.contiguous().data_ptr(), + isect_ids_sorted.contiguous().data_ptr(), + gaussian_ids_sorted.contiguous().data_ptr(), + (int2 *)tile_bins.contiguous().data_ptr() + ); + + return std::make_tuple( + isect_ids_unsorted, + gaussian_ids_unsorted, + isect_ids_sorted, + gaussian_ids_sorted, + tile_bins + ); + } \ No newline at end of file diff --git a/diff_rast/cuda/csrc/bindings.h b/diff_rast/cuda/csrc/bindings.h index e236d19ab..3b20e62b3 100644 --- a/diff_rast/cuda/csrc/bindings.h +++ b/diff_rast/cuda/csrc/bindings.h @@ -81,19 +81,35 @@ project_gaussians_backward_tensor( ); std::tuple compute_cumulative_intersects_tensor( - const int num_points, torch::Tensor &num_tiles_hit + const int num_points, const torch::Tensor &num_tiles_hit ); std::tuple map_gaussian_to_intersects_tensor( const int num_points, - torch::Tensor &xys, - torch::Tensor &depths, - torch::Tensor &radii, - torch::Tensor &cum_tiles_hit, + const torch::Tensor &xys, + const torch::Tensor &depths, + const torch::Tensor &radii, + const torch::Tensor &cum_tiles_hit, const std::tuple tile_bounds ); torch::Tensor get_tile_bin_edges_tensor( int num_intersects, - torch::Tensor &isect_ids_sorted + const torch::Tensor &isect_ids_sorted +); + +std::tuple< + torch::Tensor, + torch::Tensor, + torch::Tensor, + torch::Tensor, + torch::Tensor> +bin_and_sort_gaussians_tensor( + const int num_points, + const int num_intersects, + const torch::Tensor &xys, + const torch::Tensor &depths, + const torch::Tensor &radii, + const torch::Tensor &cum_tiles_hit, + const std::tuple tile_bounds ); \ No newline at end of file diff --git a/diff_rast/cuda/csrc/ext.cpp b/diff_rast/cuda/csrc/ext.cpp index c733bf55a..0b8ee4b15 100644 --- a/diff_rast/cuda/csrc/ext.cpp +++ b/diff_rast/cuda/csrc/ext.cpp @@ -13,4 +13,5 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("compute_cumulative_intersects", &compute_cumulative_intersects_tensor); m.def("map_gaussian_to_intersects", &map_gaussian_to_intersects_tensor); m.def("get_tile_bin_edges", &get_tile_bin_edges_tensor); + m.def("bin_and_sort_gaussians", &bin_and_sort_gaussians_tensor); } diff --git a/diff_rast/cuda/csrc/forward.cu b/diff_rast/cuda/csrc/forward.cu index 4f25f9b03..aa12b771b 100644 --- a/diff_rast/cuda/csrc/forward.cu +++ b/diff_rast/cuda/csrc/forward.cu @@ -166,7 +166,7 @@ __global__ void map_gaussian_to_intersects( // update the intersection info for all tiles this gaussian hits int32_t cur_idx = (idx == 0) ? 0 : cum_tiles_hit[idx - 1]; // printf("point %d starting at %d\n", idx, cur_idx); - u_int64_t depth_id = (u_int64_t) * (u_int32_t *)&(depths[idx]); + int64_t depth_id = (int64_t) * (int32_t *)&(depths[idx]); for (int i = tile_min.y; i < tile_max.y; ++i) { for (int j = tile_min.x; j < tile_max.x; ++j) { // isect_id is tile ID and depth as int32 diff --git a/tests/test_bin_and_sort_gaussians.py b/tests/test_bin_and_sort_gaussians.py new file mode 100644 index 000000000..24b2bae31 --- /dev/null +++ b/tests/test_bin_and_sort_gaussians.py @@ -0,0 +1,92 @@ +import pytest +import torch + + +device = torch.device("cuda:0") + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device") +def test_bin_and_sort_gaussians(): + from diff_rast import _torch_impl + import diff_rast.cuda as _C + + torch.manual_seed(42) + + num_points = 100 + means3d = torch.randn((num_points, 3), device=device, requires_grad=True) + scales = torch.randn((num_points, 3), device=device) + glob_scale = 0.3 + quats = torch.randn((num_points, 4), device=device) + quats /= torch.linalg.norm(quats, dim=-1, keepdim=True) + viewmat = torch.eye(4, device=device) + projmat = torch.eye(4, device=device) + fx, fy = 3.0, 3.0 + H, W = 512, 512 + clip_thresh = 0.01 + + BLOCK_X, BLOCK_Y = 16, 16 + tile_bounds = (W + BLOCK_X - 1) // BLOCK_X, (H + BLOCK_Y - 1) // BLOCK_Y, 1 + + ( + _cov3d, + _xys, + _depths, + _radii, + _conics, + _num_tiles_hit, + _masks, + ) = _torch_impl.project_gaussians_forward( + means3d, + scales, + glob_scale, + quats, + viewmat, + projmat, + fx, + fy, + (H, W), + tile_bounds, + clip_thresh, + ) + + _xys = _xys[_masks] + _depths = _depths[_masks] + _radii = _radii[_masks] + _conics = _conics[_masks] + _num_tiles_hit = _num_tiles_hit[_masks] + + num_points = num_points - torch.count_nonzero(~_masks).item() + + _cum_tiles_hit = torch.cumsum(_num_tiles_hit, dim=0, dtype=torch.int32) + _num_intersects = _cum_tiles_hit[-1].item() + _depths = _depths.contiguous() + + ( + _isect_ids_unsorted, + _gaussian_ids_unsorted, + _isect_ids_sorted, + _gaussian_ids_sorted, + _tile_bins, + ) = _torch_impl.bin_and_sort_gaussians( + num_points, _num_intersects, _xys, _depths, _radii, _cum_tiles_hit, tile_bounds + ) + + ( + isect_ids_unsorted, + gaussian_ids_unsorted, + isect_ids_sorted, + gaussian_ids_sorted, + tile_bins, + ) = _C.bin_and_sort_gaussians( + num_points, _num_intersects, _xys, _depths, _radii, _cum_tiles_hit, tile_bounds + ) + + torch.testing.assert_close(_isect_ids_unsorted, isect_ids_unsorted) + torch.testing.assert_close(_gaussian_ids_unsorted, gaussian_ids_unsorted) + torch.testing.assert_close(_isect_ids_sorted, isect_ids_sorted) + torch.testing.assert_close(_gaussian_ids_sorted, gaussian_ids_sorted) + torch.testing.assert_close(_tile_bins, tile_bins) + + +if __name__ == "__main__": + test_bin_and_sort_gaussians() diff --git a/tests/test_map_gaussians.py b/tests/test_map_gaussians.py index 394fe7199..912d22226 100644 --- a/tests/test_map_gaussians.py +++ b/tests/test_map_gaussians.py @@ -48,15 +48,21 @@ def test_map_gaussians(): tile_bounds, clip_thresh, ) + _xys = _xys[_masks] + _depths = _depths[_masks] + _radii = _radii[_masks] + _conics = _conics[_masks] + _num_tiles_hit = _num_tiles_hit[_masks] + + num_points = num_points - torch.count_nonzero(~_masks).item() _cum_tiles_hit = torch.cumsum(_num_tiles_hit, dim=0, dtype=torch.int32) _depths = _depths.contiguous() - isect_ids, gaussian_ids = _C.map_gaussian_to_intersects( + _isect_ids, _gaussian_ids = _torch_impl.map_gaussian_to_intersects( num_points, _xys, _depths, _radii, _cum_tiles_hit, tile_bounds ) - - _isect_ids, _gaussian_ids = _torch_impl.map_gaussian_to_intersects( + isect_ids, gaussian_ids = _C.map_gaussian_to_intersects( num_points, _xys, _depths, _radii, _cum_tiles_hit, tile_bounds ) From cd7a63b24f43c422d2d8eed6d6fc322ba79bedb5 Mon Sep 17 00:00:00 2001 From: maturk Date: Thu, 5 Oct 2023 11:07:27 +0300 Subject: [PATCH 5/8] update docs --- diff_rast/cov2d_bounds.py | 5 +++-- diff_rast/project_gaussians.py | 10 +++++++++- diff_rast/rasterize.py | 5 ++++- 3 files changed, 16 insertions(+), 4 deletions(-) diff --git a/diff_rast/cov2d_bounds.py b/diff_rast/cov2d_bounds.py index 0a946c612..9d6250eea 100644 --- a/diff_rast/cov2d_bounds.py +++ b/diff_rast/cov2d_bounds.py @@ -10,14 +10,15 @@ import diff_rast.cuda as _C -class compute_cov2d_bounds(Function): +class ComputeCov2dBounds(Function): """Computes bounds of 2D covariance matrix Args: cov2d (Tensor): input cov2d of size (batch, 3) of upper triangular 2D covariance values Returns: - conics (batch, 3) and radii (batch, 1) + conic (Tensor): conic parameters for 2D gaussian. + radii (Tensor): radii of 2D gaussian projections. """ @staticmethod diff --git a/diff_rast/project_gaussians.py b/diff_rast/project_gaussians.py index ad9e53a54..e6470a64f 100644 --- a/diff_rast/project_gaussians.py +++ b/diff_rast/project_gaussians.py @@ -10,7 +10,7 @@ class ProjectGaussians(Function): - """Project 3D Gaussians to 2D. + """This function projects 3D gaussians to 2D using the EWA splatting method for gaussian splatting. Args: means3d (Tensor): xyzs of gaussians. @@ -24,6 +24,14 @@ class ProjectGaussians(Function): img_height (int): height of the rendered image. img_width (int): width of the rendered image. tile_bounds (Tuple): tile dimensions as a len 3 tuple (tiles.x , tiles.y, 1). + + Returns: + xys (Tensor): x,y locations of 2D gaussian projections. + depths (Tensor): z depth of gaussians. + radii (Tensor): radii of 2D gaussian projections. + conics (Tensor): conic parameters for 2D gaussian. + num_tiles_hit (int): number of tiles hit. + cov3d (Tensor): 3D covariances. """ @staticmethod diff --git a/diff_rast/rasterize.py b/diff_rast/rasterize.py index e219c4d02..efc3f64ed 100644 --- a/diff_rast/rasterize.py +++ b/diff_rast/rasterize.py @@ -11,7 +11,7 @@ class RasterizeGaussians(Function): - """Rasterize 2D gaussians. + """Rasterizes 2D gaussians by sorting and binning gaussian intersections for each tile and returns an output image using alpha-compositing. Args: xys (Tensor): xy coords of 2D gaussians. @@ -24,6 +24,9 @@ class RasterizeGaussians(Function): img_height (int): height of the rendered image. img_width (int): width of the rendered image. background (Tensor): background color + + Returns: + out_img (Tensor): the rendered output image. """ @staticmethod From 46160ffd2fd5c9ebf5e1d6f84c638d3013e408cd Mon Sep 17 00:00:00 2001 From: maturk Date: Thu, 5 Oct 2023 11:08:17 +0300 Subject: [PATCH 6/8] python wrappers for pybindings --- diff_rast/bin_and_sort_gaussians.py | 72 ++++++++++++++++++++++ diff_rast/compute_cumulative_intersects.py | 37 +++++++++++ diff_rast/get_tile_bin_edges.py | 38 ++++++++++++ diff_rast/map_gaussian_to_intersects.py | 46 ++++++++++++++ 4 files changed, 193 insertions(+) create mode 100644 diff_rast/bin_and_sort_gaussians.py create mode 100644 diff_rast/compute_cumulative_intersects.py create mode 100644 diff_rast/get_tile_bin_edges.py create mode 100644 diff_rast/map_gaussian_to_intersects.py diff --git a/diff_rast/bin_and_sort_gaussians.py b/diff_rast/bin_and_sort_gaussians.py new file mode 100644 index 000000000..8e1f504f5 --- /dev/null +++ b/diff_rast/bin_and_sort_gaussians.py @@ -0,0 +1,72 @@ +"""Python bindings for binning and sorting gaussians""" + +from typing import Tuple, Any + +from jaxtyping import Float +from torch import Tensor +from torch.autograd import Function + +import diff_rast.cuda as _C + + +class BinAndSortGaussians(Function): + """Function for mapping gaussians to sorted unique intersection IDs and tile bins used for fast rasterization. + + We return both sorted and unsorted versions of intersect IDs and gaussian IDs for testing purposes. + + Args: + num_points (int): number of gaussians. + num_intersects (int): cumulative number of total gaussian intersections + xys (Tensor): x,y locations of 2D gaussian projections. + depths (Tensor): z depth of gaussians. + radii (Tensor): radii of 2D gaussian projections. + cum_tiles_hit (Tensor): list of cumulative tiles hit. + tile_bounds (Tuple): tile dimensions as a len 3 tuple (tiles.x , tiles.y, 1). + + Returns: + isect_ids_unsorted (Tensor): unique IDs for each gaussian in the form (tile | depth id). + gaussian_ids_unsorted (Tensor): Tensor that maps isect_ids back to cum_tiles_hit. Useful for identifying gaussians. + isect_ids_sorted (Tensor): sorted unique IDs for each gaussian in the form (tile | depth id). + gaussian_ids_sorted (Tensor): sorted Tensor that maps isect_ids back to cum_tiles_hit. Useful for identifying gaussians. + tile_bins (Tensor): range of gaussians hit per tile. + """ + + @staticmethod + def forward( + ctx, + num_points: int, + num_intersects: int, + xys: Float[Tensor, "batch 2"], + depths: Float[Tensor, "batch 1"], + radii: Float[Tensor, "batch 1"], + cum_tiles_hit: Float[Tensor, "batch 1"], + tile_bounds: Tuple[int, int, int], + ) -> Tuple[ + Float[Tensor, "num_intersects 1"], + Float[Tensor, "num_intersects 1"], + Float[Tensor, "num_intersects 1"], + Float[Tensor, "num_intersects 1"], + Float[Tensor, "num_intersects 2"], + ]: + + ( + isect_ids_unsorted, + gaussian_ids_unsorted, + isect_ids_sorted, + gaussian_ids_sorted, + tile_bins, + ) = _C.bin_and_sort_gaussians( + num_points, num_intersects, xys, depths, radii, cum_tiles_hit, tile_bounds + ) + + return ( + isect_ids_unsorted, + gaussian_ids_unsorted, + isect_ids_sorted, + gaussian_ids_sorted, + tile_bins, + ) + + @staticmethod + def backward(ctx: Any, *grad_outputs: Any) -> Any: + raise NotImplementedError diff --git a/diff_rast/compute_cumulative_intersects.py b/diff_rast/compute_cumulative_intersects.py new file mode 100644 index 000000000..645e76cb5 --- /dev/null +++ b/diff_rast/compute_cumulative_intersects.py @@ -0,0 +1,37 @@ +"""Python bindings for computing cumulative intersects""" + +from typing import Tuple, Any + +from jaxtyping import Float +from torch import Tensor +from torch.autograd import Function + +import diff_rast.cuda as _C + + +class ComputeCumulativeIntersects(Function): + """Computes cumulative intersections of gaussians. This is useful for creating unique gaussian IDs and for sorting. + + Args: + num_points (int): number of gaussians. + num_tiles_hit (Tensor): number of intersected tiles per gaussian. + + Returns: + num_intersects (int): total number of tile intersections. + cum_tiles_hit (Tensor): a tensor of cumulated intersections (used for sorting). + """ + + @staticmethod + def forward( + ctx, num_points: int, num_tiles_hit: Float[Tensor, "batch 1"] + ) -> Tuple[int, Float[Tensor, "batch 1"]]: + + num_intersects, cum_tiles_hit = _C.compute_cumulative_intersects( + num_points, num_tiles_hit + ) + + return (num_intersects, cum_tiles_hit) + + @staticmethod + def backward(ctx: Any, *grad_outputs: Any) -> Any: + raise NotImplementedError diff --git a/diff_rast/get_tile_bin_edges.py b/diff_rast/get_tile_bin_edges.py new file mode 100644 index 000000000..ad9573d6c --- /dev/null +++ b/diff_rast/get_tile_bin_edges.py @@ -0,0 +1,38 @@ +"""Python bindings for computing tile bins""" + +from typing import Any + +from jaxtyping import Int +from torch import Tensor +from torch.autograd import Function + +import diff_rast.cuda as _C + + +class GetTileBinEdges(Function): + """Function to map sorted intersection IDs to tile bins which give the range of unique gaussian IDs belonging to each tile. + + Expects that intersection IDs are sorted by increasing tile ID. + + Indexing into tile_bins[tile_idx] returns the range (lower,upper) of gaussian IDs that hit tile_idx. + + Args: + num_intersects (int): total number of gaussian intersects. + isect_ids_sorted (Tensor): sorted unique IDs for each gaussian in the form (tile | depth id). + + Returns: + tile_bins (Tensor): range of gaussians IDs hit per tile. + """ + + @staticmethod + def forward( + ctx, num_intersects: int, isect_ids_sorted: Int[Tensor, "num_intersects 1"] + ) -> Int[Tensor, "num_intersects 2"]: + + tile_bins = _C.get_tile_bin_edges(num_intersects, isect_ids_sorted) + return tile_bins + + + @staticmethod + def backward(ctx: Any, *grad_outputs: Any) -> Any: + raise NotImplementedError diff --git a/diff_rast/map_gaussian_to_intersects.py b/diff_rast/map_gaussian_to_intersects.py new file mode 100644 index 000000000..195ac7bfb --- /dev/null +++ b/diff_rast/map_gaussian_to_intersects.py @@ -0,0 +1,46 @@ +"""Python bindings for mapping gaussians to interset IDs""" + +from typing import Tuple, Any + +from jaxtyping import Float +from torch import Tensor +from torch.autograd import Function + +import diff_rast.cuda as _C + + +class MapGaussiansToIntersects(Function): + """Function to map each gaussian intersection to a unique tile ID and depth value for sorting. + + Args: + num_points (int): number of gaussians. + xys (Tensor): x,y locations of 2D gaussian projections. + depths (Tensor): z depth of gaussians. + radii (Tensor): radii of 2D gaussian projections. + cum_tiles_hit (Tensor): list of cumulative tiles hit. + tile_bounds (Tuple): tile dimensions as a len 3 tuple (tiles.x , tiles.y, 1). + + Returns: + isect_ids (Tensor): unique IDs for each gaussian in the form (tile | depth id). + gaussian_ids (Tensor): Tensor that maps isect_ids back to cum_tiles_hit. Useful for identifying gaussians. + """ + + @staticmethod + def forward( + ctx, + num_points: int, + xys: Float[Tensor, "batch 2"], + depths: Float[Tensor, "batch 1"], + radii: Float[Tensor, "batch 1"], + cum_tiles_hit: Float[Tensor, "batch 1"], + tile_bounds: Tuple[int, int, int], + ) -> Tuple[Float[Tensor, "cum_tiles_hit 1"], Float[Tensor, "cum_tiles_hit 1"]]: + + isect_ids, gaussian_ids = _C.map_gaussian_to_intersects( + num_points, xys, depths, radii, cum_tiles_hit, tile_bounds + ) + return (isect_ids, gaussian_ids) + + @staticmethod + def backward(ctx: Any, *grad_outputs: Any) -> Any: + raise NotImplementedError From ef736ffc57ee9180b4414a1e4541e4675b71be1d Mon Sep 17 00:00:00 2001 From: maturk Date: Thu, 5 Oct 2023 11:08:49 +0300 Subject: [PATCH 7/8] update tests to use wrappers and fix num_points consistency --- tests/test_bin_and_sort_gaussians.py | 4 ++-- tests/test_cov2d_bounds.py | 6 +++--- tests/test_cumulative_intersects.py | 6 +++--- tests/test_get_tile_bin_edges.py | 5 +++-- tests/test_map_gaussians.py | 6 ++++-- tests/test_project_gaussians.py | 1 + 6 files changed, 16 insertions(+), 12 deletions(-) diff --git a/tests/test_bin_and_sort_gaussians.py b/tests/test_bin_and_sort_gaussians.py index 24b2bae31..cb044b79c 100644 --- a/tests/test_bin_and_sort_gaussians.py +++ b/tests/test_bin_and_sort_gaussians.py @@ -8,7 +8,7 @@ @pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device") def test_bin_and_sort_gaussians(): from diff_rast import _torch_impl - import diff_rast.cuda as _C + from diff_rast.bin_and_sort_gaussians import BinAndSortGaussians torch.manual_seed(42) @@ -77,7 +77,7 @@ def test_bin_and_sort_gaussians(): isect_ids_sorted, gaussian_ids_sorted, tile_bins, - ) = _C.bin_and_sort_gaussians( + ) = BinAndSortGaussians.apply( num_points, _num_intersects, _xys, _depths, _radii, _cum_tiles_hit, tile_bounds ) diff --git a/tests/test_cov2d_bounds.py b/tests/test_cov2d_bounds.py index 156f2a268..74c931601 100644 --- a/tests/test_cov2d_bounds.py +++ b/tests/test_cov2d_bounds.py @@ -8,11 +8,11 @@ @pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device") def test_compare_binding_to_pytorch(): from diff_rast._torch_impl import compute_cov2d_bounds as _compute_cov2d_bounds - from diff_rast.cov2d_bounds import compute_cov2d_bounds + from diff_rast.cov2d_bounds import ComputeCov2dBounds torch.manual_seed(42) - num_cov2ds = 2 + num_cov2ds = 100 _covs2d = torch.rand( (num_cov2ds, 2, 2), dtype=torch.float32, device=device, requires_grad=True @@ -26,7 +26,7 @@ def test_compare_binding_to_pytorch(): dim=-1, ) - conic, radii = compute_cov2d_bounds.apply(covs2d) + conic, radii = ComputeCov2dBounds.apply(covs2d) _conic, _radii, _mask = _compute_cov2d_bounds(_covs2d) radii = radii.squeeze(-1) diff --git a/tests/test_cumulative_intersects.py b/tests/test_cumulative_intersects.py index bebef8bc0..8285ae34d 100644 --- a/tests/test_cumulative_intersects.py +++ b/tests/test_cumulative_intersects.py @@ -7,17 +7,17 @@ @pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device") def test_cumulative_intersects(): - import diff_rast.cuda as _C + from diff_rast.compute_cumulative_intersects import ComputeCumulativeIntersects torch.manual_seed(42) - num_points = 10 + num_points = 100 num_tiles_hit = torch.randint( 0, 100, (num_points,), device=device, dtype=torch.int32 ) - num_intersects, cum_tiles_hit = _C.compute_cumulative_intersects( + num_intersects, cum_tiles_hit = ComputeCumulativeIntersects.apply( num_points, num_tiles_hit ) diff --git a/tests/test_get_tile_bin_edges.py b/tests/test_get_tile_bin_edges.py index fe69c8d45..b943b973d 100644 --- a/tests/test_get_tile_bin_edges.py +++ b/tests/test_get_tile_bin_edges.py @@ -8,11 +8,12 @@ @pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device") def test_get_tile_bin_edges(): from diff_rast import _torch_impl - import diff_rast.cuda as _C + from diff_rast.get_tile_bin_edges import GetTileBinEdges torch.manual_seed(42) num_points = 100 + means3d = torch.randn((num_points, 3), device=device, requires_grad=True) scales = torch.randn((num_points, 3), device=device) glob_scale = 0.3 @@ -75,7 +76,7 @@ def test_get_tile_bin_edges(): _gaussian_ids_sorted = torch.gather(_gaussian_ids_unsorted, 0, sorted_indices) _tile_bins = _torch_impl.get_tile_bin_edges(_num_intersects, _isect_ids_sorted) - tile_bins = _C.get_tile_bin_edges(_num_intersects, _isect_ids_sorted) + tile_bins = GetTileBinEdges.apply(_num_intersects, _isect_ids_sorted) torch.testing.assert_close(_tile_bins, tile_bins) diff --git a/tests/test_map_gaussians.py b/tests/test_map_gaussians.py index 912d22226..915138574 100644 --- a/tests/test_map_gaussians.py +++ b/tests/test_map_gaussians.py @@ -8,11 +8,12 @@ @pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device") def test_map_gaussians(): from diff_rast import _torch_impl - import diff_rast.cuda as _C + from diff_rast.map_gaussian_to_intersects import MapGaussiansToIntersects torch.manual_seed(42) num_points = 100 + means3d = torch.randn((num_points, 3), device=device, requires_grad=True) scales = torch.randn((num_points, 3), device=device) glob_scale = 0.3 @@ -62,7 +63,8 @@ def test_map_gaussians(): _isect_ids, _gaussian_ids = _torch_impl.map_gaussian_to_intersects( num_points, _xys, _depths, _radii, _cum_tiles_hit, tile_bounds ) - isect_ids, gaussian_ids = _C.map_gaussian_to_intersects( + + isect_ids, gaussian_ids = MapGaussiansToIntersects.apply( num_points, _xys, _depths, _radii, _cum_tiles_hit, tile_bounds ) diff --git a/tests/test_project_gaussians.py b/tests/test_project_gaussians.py index b0e761e43..242e8b4b7 100644 --- a/tests/test_project_gaussians.py +++ b/tests/test_project_gaussians.py @@ -13,6 +13,7 @@ def test_project_gaussians_forward(): torch.manual_seed(42) num_points = 100 + means3d = torch.randn((num_points, 3), device=device, requires_grad=True) scales = torch.randn((num_points, 3), device=device) glob_scale = 0.3 From 2e35e1da5ca1198153ad7445c92b6d52d47fa8dc Mon Sep 17 00:00:00 2001 From: maturk Date: Thu, 5 Oct 2023 11:13:04 +0300 Subject: [PATCH 8/8] format --- diff_rast/get_tile_bin_edges.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/diff_rast/get_tile_bin_edges.py b/diff_rast/get_tile_bin_edges.py index ad9573d6c..67f80e2c3 100644 --- a/diff_rast/get_tile_bin_edges.py +++ b/diff_rast/get_tile_bin_edges.py @@ -11,7 +11,7 @@ class GetTileBinEdges(Function): """Function to map sorted intersection IDs to tile bins which give the range of unique gaussian IDs belonging to each tile. - + Expects that intersection IDs are sorted by increasing tile ID. Indexing into tile_bins[tile_idx] returns the range (lower,upper) of gaussian IDs that hit tile_idx. @@ -19,7 +19,7 @@ class GetTileBinEdges(Function): Args: num_intersects (int): total number of gaussian intersects. isect_ids_sorted (Tensor): sorted unique IDs for each gaussian in the form (tile | depth id). - + Returns: tile_bins (Tensor): range of gaussians IDs hit per tile. """ @@ -31,7 +31,6 @@ def forward( tile_bins = _C.get_tile_bin_edges(num_intersects, isect_ids_sorted) return tile_bins - @staticmethod def backward(ctx: Any, *grad_outputs: Any) -> Any: