Skip to content

Commit

Permalink
Updates for pybindings (#24)
Browse files Browse the repository at this point in the history
* add get_tile_bin_edges binding

* finish pybinding for get_tile_bin_edges and fix the bug of tile_idx

* use _masks to not get negative depths

* finish bin_and_sort_gaussians and fix some bugs of _torch_impl

* update docs

* python wrappers for pybindings

* update tests to use wrappers and fix num_points consistency

* format

---------

Co-authored-by: maturk <[email protected]>
  • Loading branch information
Zhuoyang-Pan and maturk authored Oct 5, 2023
1 parent 2ec6b5c commit a223e5a
Show file tree
Hide file tree
Showing 19 changed files with 566 additions and 24 deletions.
46 changes: 45 additions & 1 deletion diff_rast/_torch_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,7 @@ def map_gaussian_to_intersects(

tile_min, tile_max = get_tile_bbox(xys[idx], radii[idx], tile_bounds)

cur_idx = 0 if idx == 0 else cum_tiles_hit[idx - 1]
cur_idx = 0 if idx == 0 else cum_tiles_hit[idx - 1].item()

# Get raw byte representation of the float value at the given index
raw_bytes = struct.pack("f", depths[idx])
Expand All @@ -322,3 +322,47 @@ def map_gaussian_to_intersects(
cur_idx += 1

return isect_ids, gaussian_ids


def get_tile_bin_edges(num_intersects, isect_ids_sorted):
    """Compute, for every tile, the range of sorted intersections that hit it.

    Reference (pure PyTorch) implementation. The tile index of each
    intersection is read from the upper 32 bits of its ID (``id >> 32``).
    Row ``t`` of the result holds ``(lower, upper)`` such that the sorted
    intersections ``lower:upper`` belong to tile ``t``; rows of tiles that are
    never hit stay ``(0, 0)``.

    Args:
        num_intersects (int): number of intersections to process.
        isect_ids_sorted (Tensor): intersection IDs sorted in increasing
            order, at least ``num_intersects`` entries.

    Returns:
        Tensor: int32 tensor of shape ``(num_intersects, 2)`` on the same
        device as ``isect_ids_sorted``.

    Note:
        The output has only ``num_intersects`` rows, so a tile index greater
        than or equal to ``num_intersects`` raises ``IndexError``.
    """
    tile_bins = torch.zeros(
        (num_intersects, 2), dtype=torch.int32, device=isect_ids_sorted.device
    )
    if num_intersects == 0:
        return tile_bins

    # Extract all tile indices once instead of shifting a tensor element on
    # every loop iteration.
    tile_ids = (isect_ids_sorted[:num_intersects] >> 32).reshape(-1).tolist()

    # The first intersection always opens the bin of its tile.
    tile_bins[tile_ids[0], 0] = 0

    # Every change of tile index closes the previous bin and opens a new one.
    # This must also be checked for the last intersection, otherwise a tile
    # that is only hit by the final entry never gets its lower edge (and the
    # tile before it never gets its upper edge).
    for idx in range(1, num_intersects):
        prev_tile_idx = tile_ids[idx - 1]
        cur_tile_idx = tile_ids[idx]
        if cur_tile_idx != prev_tile_idx:
            tile_bins[prev_tile_idx, 1] = idx
            tile_bins[cur_tile_idx, 0] = idx

    # The last intersection always closes the bin of its tile; this also
    # covers the single-intersection case where first and last coincide.
    tile_bins[tile_ids[-1], 1] = num_intersects

    return tile_bins


def bin_and_sort_gaussians(
    num_points, num_intersects, xys, depths, radii, cum_tiles_hit, tile_bounds
):
    """Map gaussians to tile intersections, sort them and compute tile bins.

    Pure PyTorch reference pipeline: builds the unsorted intersection IDs and
    their gaussian IDs, sorts the intersection IDs in ascending order, applies
    the same permutation to the gaussian IDs, and finally derives the per-tile
    bin edges from the sorted IDs.

    Returns:
        tuple: ``(isect_ids, gaussian_ids, isect_ids_sorted,
        gaussian_ids_sorted, tile_bins)``; the first two are the unsorted
        outputs of ``map_gaussian_to_intersects``.
    """
    unsorted = map_gaussian_to_intersects(
        num_points, xys, depths, radii, cum_tiles_hit, tile_bounds
    )
    isect_ids, gaussian_ids = unsorted

    # Ascending sort of the intersection IDs; `order` is the permutation that
    # produced it and is reused to reorder the gaussian IDs consistently.
    isect_ids_sorted, order = torch.sort(isect_ids)
    gaussian_ids_sorted = torch.gather(gaussian_ids, 0, order)

    tile_bins = get_tile_bin_edges(num_intersects, isect_ids_sorted)

    return isect_ids, gaussian_ids, isect_ids_sorted, gaussian_ids_sorted, tile_bins
72 changes: 72 additions & 0 deletions diff_rast/bin_and_sort_gaussians.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
"""Python bindings for binning and sorting gaussians"""

from typing import Tuple, Any

from jaxtyping import Float
from torch import Tensor
from torch.autograd import Function

import diff_rast.cuda as _C


class BinAndSortGaussians(Function):
    """Bin gaussians into tiles and sort their tile intersections.

    Maps gaussians to sorted unique intersection IDs and to the tile bins used
    for fast rasterization. Both the sorted and the unsorted intersect IDs and
    gaussian IDs are returned, the unsorted ones for testing purposes.

    Args:
        num_points (int): number of gaussians.
        num_intersects (int): cumulative number of total gaussian intersections.
        xys (Tensor): x,y locations of 2D gaussian projections.
        depths (Tensor): z depth of gaussians.
        radii (Tensor): radii of 2D gaussian projections.
        cum_tiles_hit (Tensor): list of cumulative tiles hit.
        tile_bounds (Tuple): tile dimensions as a len 3 tuple (tiles.x, tiles.y, 1).

    Returns:
        isect_ids_unsorted (Tensor): unique IDs for each gaussian in the form (tile | depth id).
        gaussian_ids_unsorted (Tensor): Tensor that maps isect_ids back to cum_tiles_hit. Useful for identifying gaussians.
        isect_ids_sorted (Tensor): sorted unique IDs for each gaussian in the form (tile | depth id).
        gaussian_ids_sorted (Tensor): sorted Tensor that maps isect_ids back to cum_tiles_hit. Useful for identifying gaussians.
        tile_bins (Tensor): range of gaussians hit per tile.
    """

    @staticmethod
    def forward(
        ctx,
        num_points: int,
        num_intersects: int,
        xys: Float[Tensor, "batch 2"],
        depths: Float[Tensor, "batch 1"],
        radii: Float[Tensor, "batch 1"],
        cum_tiles_hit: Float[Tensor, "batch 1"],
        tile_bounds: Tuple[int, int, int],
    ) -> Tuple[
        Float[Tensor, "num_intersects 1"],
        Float[Tensor, "num_intersects 1"],
        Float[Tensor, "num_intersects 1"],
        Float[Tensor, "num_intersects 1"],
        Float[Tensor, "num_intersects 2"],
    ]:
        # Delegate all work to the compiled extension.
        outputs = _C.bin_and_sort_gaussians(
            num_points, num_intersects, xys, depths, radii, cum_tiles_hit, tile_bounds
        )

        # Unpack explicitly so an unexpected number of outputs fails here.
        unsorted_isects, unsorted_gaussians, sorted_isects, sorted_gaussians, bins = outputs

        return (
            unsorted_isects,
            unsorted_gaussians,
            sorted_isects,
            sorted_gaussians,
            bins,
        )

    @staticmethod
    def backward(ctx: Any, *grad_outputs: Any) -> Any:
        # No gradient is defined for this operation.
        raise NotImplementedError
37 changes: 37 additions & 0 deletions diff_rast/compute_cumulative_intersects.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
"""Python bindings for computing cumulative intersects"""

from typing import Tuple, Any

from jaxtyping import Float
from torch import Tensor
from torch.autograd import Function

import diff_rast.cuda as _C


class ComputeCumulativeIntersects(Function):
    """Compute cumulative tile intersections of gaussians.

    The result is useful for creating unique gaussian IDs and for sorting.

    Args:
        num_points (int): number of gaussians.
        num_tiles_hit (Tensor): number of intersected tiles per gaussian.

    Returns:
        num_intersects (int): total number of tile intersections.
        cum_tiles_hit (Tensor): a tensor of cumulated intersections (used for sorting).
    """

    @staticmethod
    def forward(
        ctx, num_points: int, num_tiles_hit: Float[Tensor, "batch 1"]
    ) -> Tuple[int, Float[Tensor, "batch 1"]]:
        # Delegate to the compiled extension and unpack its two outputs.
        result = _C.compute_cumulative_intersects(num_points, num_tiles_hit)
        total_intersects, cumulative_hits = result
        return (total_intersects, cumulative_hits)

    @staticmethod
    def backward(ctx: Any, *grad_outputs: Any) -> Any:
        # No gradient is defined for this operation.
        raise NotImplementedError
5 changes: 3 additions & 2 deletions diff_rast/cov2d_bounds.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,15 @@
import diff_rast.cuda as _C


class compute_cov2d_bounds(Function):
class ComputeCov2dBounds(Function):
"""Computes bounds of 2D covariance matrix
Args:
cov2d (Tensor): input cov2d of size (batch, 3) of upper triangular 2D covariance values
Returns:
conics (batch, 3) and radii (batch, 1)
conic (Tensor): conic parameters for 2D gaussian.
radii (Tensor): radii of 2D gaussian projections.
"""

@staticmethod
Expand Down
2 changes: 2 additions & 0 deletions diff_rast/cuda/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,5 @@ def call_cuda(*args, **kwargs):
# Module-level handles for the extension functions, one per exported name.
# NOTE(review): _make_lazy_cuda_func is defined earlier in this module (not
# visible here); presumably it defers loading the compiled extension until the
# handle is first called — confirm. Each string is expected to match a name
# registered with m.def(...) in csrc/ext.cpp.
compute_sh_backward = _make_lazy_cuda_func("compute_sh_backward")
compute_cumulative_intersects = _make_lazy_cuda_func("compute_cumulative_intersects")
map_gaussian_to_intersects = _make_lazy_cuda_func("map_gaussian_to_intersects")
get_tile_bin_edges = _make_lazy_cuda_func("get_tile_bin_edges")
bin_and_sort_gaussians = _make_lazy_cuda_func("bin_and_sort_gaussians")
89 changes: 84 additions & 5 deletions diff_rast/cuda/csrc/bindings.cu
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ project_gaussians_backward_tensor(
}

std::tuple<torch::Tensor, torch::Tensor> compute_cumulative_intersects_tensor(
const int num_points, torch::Tensor &num_tiles_hit
const int num_points, const torch::Tensor &num_tiles_hit
) {
// ref:
// https://nvlabs.github.io/cub/structcub_1_1_device_scan.html#a9416ac1ea26f9fde669d83ddc883795a
Expand Down Expand Up @@ -287,10 +287,10 @@ std::tuple<torch::Tensor, torch::Tensor> compute_cumulative_intersects_tensor(

std::tuple<torch::Tensor, torch::Tensor> map_gaussian_to_intersects_tensor(
const int num_points,
torch::Tensor &xys,
torch::Tensor &depths,
torch::Tensor &radii,
torch::Tensor &cum_tiles_hit,
const torch::Tensor &xys,
const torch::Tensor &depths,
const torch::Tensor &radii,
const torch::Tensor &cum_tiles_hit,
const std::tuple<int, int, int> tile_bounds
) {
CHECK_INPUT(xys);
Expand Down Expand Up @@ -326,3 +326,82 @@ std::tuple<torch::Tensor, torch::Tensor> map_gaussian_to_intersects_tensor(

return std::make_tuple(isect_ids_unsorted, gaussian_ids_unsorted);
}

// Tensor-level wrapper around the get_tile_bin_edges kernel.
//
// Args:
//   num_intersects:   number of entries of isect_ids_sorted to process.
//   isect_ids_sorted: int64 tensor of sorted intersection IDs (validated with
//                     CHECK_INPUT).
// Returns:
//   An int32 tensor of shape (num_intersects, 2), zero-initialised and then
//   filled by the kernel, allocated with the same options (device) as the
//   input.
torch::Tensor get_tile_bin_edges_tensor(
    int num_intersects,
    const torch::Tensor &isect_ids_sorted
) {
    CHECK_INPUT(isect_ids_sorted);
    torch::Tensor tile_bins = torch::zeros(
        {num_intersects, 2}, isect_ids_sorted.options().dtype(torch::kInt32)
    );

    // With no intersections the grid size below would be 0, which is an
    // invalid kernel launch configuration; there is nothing to compute, so
    // return the (empty) result directly.
    if (num_intersects == 0) {
        return tile_bins;
    }

    // One thread per intersection, rounded up to whole blocks of N_THREADS.
    get_tile_bin_edges<<<
        (num_intersects + N_THREADS - 1) / N_THREADS,
        N_THREADS>>>(
        num_intersects,
        isect_ids_sorted.contiguous().data_ptr<int64_t>(),
        // The kernel views each (lower, upper) row as a single int2.
        (int2 *)tile_bins.contiguous().data_ptr<int>()
    );
    return tile_bins;
}

// Tensor-level wrapper around bin_and_sort_gaussians.
//
// Validates the inputs, allocates zero-initialised output tensors with the
// same options (device) as xys, and forwards raw pointers to
// bin_and_sort_gaussians.
//
// Returns (in order): isect_ids_unsorted (int64), gaussian_ids_unsorted
// (int32), isect_ids_sorted (int64), gaussian_ids_sorted (int32), each with
// num_intersects elements, and tile_bins (int32, num_intersects x 2).
std::tuple<
    torch::Tensor,
    torch::Tensor,
    torch::Tensor,
    torch::Tensor,
    torch::Tensor>
bin_and_sort_gaussians_tensor(
    const int num_points,
    const int num_intersects,
    const torch::Tensor &xys,
    const torch::Tensor &depths,
    const torch::Tensor &radii,
    const torch::Tensor &cum_tiles_hit,
    const std::tuple<int, int, int> tile_bounds
) {
    CHECK_INPUT(xys);
    CHECK_INPUT(depths);
    CHECK_INPUT(radii);
    CHECK_INPUT(cum_tiles_hit);

    // Repack the Python-side tuple as the dim3 expected by the CUDA code.
    dim3 tile_grid(
        std::get<0>(tile_bounds),
        std::get<1>(tile_bounds),
        std::get<2>(tile_bounds)
    );

    // All outputs share the options of xys; only the dtype differs.
    const auto int32_opts = xys.options().dtype(torch::kInt32);
    const auto int64_opts = xys.options().dtype(torch::kInt64);

    torch::Tensor isect_ids_unsorted = torch::zeros({num_intersects}, int64_opts);
    torch::Tensor isect_ids_sorted = torch::zeros({num_intersects}, int64_opts);
    torch::Tensor gaussian_ids_unsorted = torch::zeros({num_intersects}, int32_opts);
    torch::Tensor gaussian_ids_sorted = torch::zeros({num_intersects}, int32_opts);
    torch::Tensor tile_bins = torch::zeros({num_intersects, 2}, int32_opts);

    bin_and_sort_gaussians(
        num_points,
        num_intersects,
        // Inputs: each (x, y) row of xys is read as one float2.
        (float2 *)xys.contiguous().data_ptr<float>(),
        depths.contiguous().data_ptr<float>(),
        radii.contiguous().data_ptr<int32_t>(),
        cum_tiles_hit.contiguous().data_ptr<int32_t>(),
        tile_grid,
        // Outputs: each (lower, upper) row of tile_bins is written as one int2.
        isect_ids_unsorted.contiguous().data_ptr<int64_t>(),
        gaussian_ids_unsorted.contiguous().data_ptr<int32_t>(),
        isect_ids_sorted.contiguous().data_ptr<int64_t>(),
        gaussian_ids_sorted.contiguous().data_ptr<int32_t>(),
        (int2 *)tile_bins.contiguous().data_ptr<int>()
    );

    return std::make_tuple(
        isect_ids_unsorted,
        gaussian_ids_unsorted,
        isect_ids_sorted,
        gaussian_ids_sorted,
        tile_bins
    );
}
31 changes: 26 additions & 5 deletions diff_rast/cuda/csrc/bindings.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,14 +81,35 @@ project_gaussians_backward_tensor(
);

std::tuple<torch::Tensor, torch::Tensor> compute_cumulative_intersects_tensor(
const int num_points, torch::Tensor &num_tiles_hit
const int num_points, const torch::Tensor &num_tiles_hit
);

std::tuple<torch::Tensor, torch::Tensor> map_gaussian_to_intersects_tensor(
const int num_points,
torch::Tensor &xys,
torch::Tensor &depths,
torch::Tensor &radii,
torch::Tensor &cum_tiles_hit,
const torch::Tensor &xys,
const torch::Tensor &depths,
const torch::Tensor &radii,
const torch::Tensor &cum_tiles_hit,
const std::tuple<int, int, int> tile_bounds
);

// Computes tile bin edges from sorted intersection IDs.
// Returns an int32 tensor of shape (num_intersects, 2) created with the same
// tensor options as isect_ids_sorted; isect_ids_sorted is read as int64.
// Exposed to Python as "get_tile_bin_edges".
torch::Tensor get_tile_bin_edges_tensor(
int num_intersects,
const torch::Tensor &isect_ids_sorted
);

// Bins gaussians into tiles and sorts their intersections.
// Returns, in order: isect_ids_unsorted (int64), gaussian_ids_unsorted
// (int32), isect_ids_sorted (int64), gaussian_ids_sorted (int32) — each with
// num_intersects elements — and tile_bins (int32, num_intersects x 2).
// xys and depths are read as float; radii and cum_tiles_hit as int32.
// tile_bounds is converted to a dim3 (x, y, z) by the implementation.
// Exposed to Python as "bin_and_sort_gaussians".
std::tuple<
torch::Tensor,
torch::Tensor,
torch::Tensor,
torch::Tensor,
torch::Tensor>
bin_and_sort_gaussians_tensor(
const int num_points,
const int num_intersects,
const torch::Tensor &xys,
const torch::Tensor &depths,
const torch::Tensor &radii,
const torch::Tensor &cum_tiles_hit,
const std::tuple<int, int, int> tile_bounds
);
2 changes: 2 additions & 0 deletions diff_rast/cuda/csrc/ext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("compute_sh_backward", &compute_sh_backward_tensor);
m.def("compute_cumulative_intersects", &compute_cumulative_intersects_tensor);
m.def("map_gaussian_to_intersects", &map_gaussian_to_intersects_tensor);
m.def("get_tile_bin_edges", &get_tile_bin_edges_tensor);
m.def("bin_and_sort_gaussians", &bin_and_sort_gaussians_tensor);
}
4 changes: 4 additions & 0 deletions diff_rast/cuda/csrc/forward.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -88,4 +88,8 @@ __global__ void map_gaussian_to_intersects(
const dim3 tile_bounds,
int64_t *isect_ids,
int32_t *gaussian_ids
);

// Kernel that fills tile_bins from the sorted intersection IDs.
// The host wrapper launches it on a 1-D grid of N_THREADS-sized blocks that
// covers num_intersects threads.
// NOTE(review): presumably each thread handles one sorted intersection and
// writes the (lower, upper) edge of a tile into tile_bins as an int2 —
// confirm against the kernel definition, which is not shown here.
__global__ void get_tile_bin_edges(
const int num_intersects, const int64_t *isect_ids_sorted, int2 *tile_bins
);
37 changes: 37 additions & 0 deletions diff_rast/get_tile_bin_edges.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
"""Python bindings for computing tile bins"""

from typing import Any

from jaxtyping import Int
from torch import Tensor
from torch.autograd import Function

import diff_rast.cuda as _C


class GetTileBinEdges(Function):
    """Map sorted intersection IDs to per-tile bins.

    Each bin gives the range of unique gaussian IDs belonging to a tile.
    Expects that intersection IDs are sorted by increasing tile ID.
    Indexing into tile_bins[tile_idx] returns the range (lower, upper) of
    gaussian IDs that hit tile_idx.

    Args:
        num_intersects (int): total number of gaussian intersects.
        isect_ids_sorted (Tensor): sorted unique IDs for each gaussian in the form (tile | depth id).

    Returns:
        tile_bins (Tensor): range of gaussian IDs hit per tile.
    """

    @staticmethod
    def forward(
        ctx, num_intersects: int, isect_ids_sorted: Int[Tensor, "num_intersects 1"]
    ) -> Int[Tensor, "num_intersects 2"]:
        # Delegate directly to the compiled extension.
        return _C.get_tile_bin_edges(num_intersects, isect_ids_sorted)

    @staticmethod
    def backward(ctx: Any, *grad_outputs: Any) -> Any:
        # No gradient is defined for this operation.
        raise NotImplementedError
Loading

0 comments on commit a223e5a

Please sign in to comment.