Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Updates for pybindings #24

Merged
merged 10 commits into from
Oct 5, 2023
46 changes: 45 additions & 1 deletion diff_rast/_torch_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,7 @@ def map_gaussian_to_intersects(

tile_min, tile_max = get_tile_bbox(xys[idx], radii[idx], tile_bounds)

cur_idx = 0 if idx == 0 else cum_tiles_hit[idx - 1]
cur_idx = 0 if idx == 0 else cum_tiles_hit[idx - 1].item()

# Get raw byte representation of the float value at the given index
raw_bytes = struct.pack("f", depths[idx])
Expand All @@ -322,3 +322,47 @@ def map_gaussian_to_intersects(
cur_idx += 1

return isect_ids, gaussian_ids


def get_tile_bin_edges(num_intersects, isect_ids_sorted):
    """Compute, for every tile, the range of sorted intersections that hit it.

    The tile ID is read from the upper 32 bits of each intersection ID
    (``isect_id >> 32``).  The intersection IDs are expected to be sorted by
    increasing tile ID so that all intersections of one tile are adjacent.

    Args:
        num_intersects: number of leading entries of ``isect_ids_sorted`` to
            process; also the number of rows of the returned tensor.
        isect_ids_sorted: integer tensor of sorted intersection IDs.

    Returns:
        An int32 tensor of shape ``(num_intersects, 2)`` on the same device as
        ``isect_ids_sorted``.  Row ``t`` holds ``(lower, upper)`` such that
        ``isect_ids_sorted[lower:upper]`` are the intersections of tile ``t``.
        Rows of tiles that are never hit stay ``(0, 0)``.

    NOTE(review): rows are indexed by tile ID but the tensor only has
    ``num_intersects`` rows, so a tile ID >= ``num_intersects`` raises an
    ``IndexError``.  The shape is kept as-is for compatibility with callers.
    """
    tile_bins = torch.zeros(
        (num_intersects, 2), dtype=torch.int32, device=isect_ids_sorted.device
    )
    if num_intersects <= 0:
        return tile_bins

    # Extract all tile IDs once instead of shifting one element per iteration.
    tile_ids = (isect_ids_sorted.reshape(-1)[:num_intersects] >> 32).tolist()

    # The first intersection always opens the bin of its tile.
    tile_bins[tile_ids[0], 0] = 0

    # Every change of tile ID closes the previous bin and opens the next one.
    # This check must also run for the last element: it may start a new tile.
    for idx in range(1, num_intersects):
        prev_tile_idx = tile_ids[idx - 1]
        cur_tile_idx = tile_ids[idx]
        if cur_tile_idx != prev_tile_idx:
            tile_bins[prev_tile_idx, 1] = idx
            tile_bins[cur_tile_idx, 0] = idx

    # The last intersection always closes the bin of its tile, including the
    # case of a single intersection where first and last coincide.
    tile_bins[tile_ids[-1], 1] = num_intersects

    return tile_bins


def bin_and_sort_gaussians(
    num_points, num_intersects, xys, depths, radii, cum_tiles_hit, tile_bounds
):
    """Build intersection IDs, sort them, and derive the per-tile bins.

    Runs three steps: ``map_gaussian_to_intersects`` produces the unsorted
    intersection IDs and their gaussian IDs, ``torch.sort`` orders the
    intersection IDs, and ``get_tile_bin_edges`` turns the sorted IDs into
    tile bins.

    Returns:
        A 5-tuple ``(isect_ids, gaussian_ids, isect_ids_sorted,
        gaussian_ids_sorted, tile_bins)`` containing both the unsorted and the
        sorted IDs followed by the tile bins.
    """
    unsorted_isects, unsorted_gaussians = map_gaussian_to_intersects(
        num_points, xys, depths, radii, cum_tiles_hit, tile_bounds
    )

    # Sort the intersection IDs and keep the permutation so the gaussian IDs
    # can be reordered the same way.
    ordering = torch.sort(unsorted_isects)
    isect_ids_sorted = ordering.values
    gaussian_ids_sorted = torch.gather(unsorted_gaussians, 0, ordering.indices)

    return (
        unsorted_isects,
        unsorted_gaussians,
        isect_ids_sorted,
        gaussian_ids_sorted,
        get_tile_bin_edges(num_intersects, isect_ids_sorted),
    )
72 changes: 72 additions & 0 deletions diff_rast/bin_and_sort_gaussians.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
"""Python bindings for binning and sorting gaussians"""

from typing import Tuple, Any

from jaxtyping import Float
from torch import Tensor
from torch.autograd import Function

import diff_rast.cuda as _C


class BinAndSortGaussians(Function):
    """Bin gaussians into tiles and sort their intersection IDs.

    Forwards its arguments to the ``bin_and_sort_gaussians`` CUDA binding and
    returns the five tensors it produces. Both the unsorted and the sorted
    intersection/gaussian IDs are returned so they can be inspected in tests.

    Args:
        num_points (int): number of gaussians.
        num_intersects (int): cumulative number of total gaussian intersections.
        xys (Tensor): x,y locations of the 2D gaussian projections.
        depths (Tensor): z depth of the gaussians.
        radii (Tensor): radii of the 2D gaussian projections.
        cum_tiles_hit (Tensor): cumulative number of tiles hit.
        tile_bounds (Tuple): tile dimensions as a length-3 tuple
            (tiles.x, tiles.y, 1).

    Returns:
        isect_ids_unsorted (Tensor): unique IDs for each gaussian in the form
            (tile | depth id).
        gaussian_ids_unsorted (Tensor): maps isect_ids back to cum_tiles_hit;
            useful for identifying gaussians.
        isect_ids_sorted (Tensor): sorted version of isect_ids_unsorted.
        gaussian_ids_sorted (Tensor): gaussian IDs in the sorted order.
        tile_bins (Tensor): range of gaussians hit per tile.
    """

    @staticmethod
    def forward(
        ctx,
        num_points: int,
        num_intersects: int,
        xys: Float[Tensor, "batch 2"],
        depths: Float[Tensor, "batch 1"],
        radii: Float[Tensor, "batch 1"],
        cum_tiles_hit: Float[Tensor, "batch 1"],
        tile_bounds: Tuple[int, int, int],
    ) -> Tuple[
        Float[Tensor, "num_intersects 1"],
        Float[Tensor, "num_intersects 1"],
        Float[Tensor, "num_intersects 1"],
        Float[Tensor, "num_intersects 1"],
        Float[Tensor, "num_intersects 2"],
    ]:
        outputs = _C.bin_and_sort_gaussians(
            num_points, num_intersects, xys, depths, radii, cum_tiles_hit, tile_bounds
        )

        # Unpack explicitly so an unexpected number of outputs fails here.
        unsorted_isects, unsorted_gaussians, sorted_isects, sorted_gaussians, bins = (
            outputs
        )

        return (
            unsorted_isects,
            unsorted_gaussians,
            sorted_isects,
            sorted_gaussians,
            bins,
        )

    @staticmethod
    def backward(ctx: Any, *grad_outputs: Any) -> Any:
        # No gradient is defined for this operation.
        raise NotImplementedError
37 changes: 37 additions & 0 deletions diff_rast/compute_cumulative_intersects.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
"""Python bindings for computing cumulative intersects"""

from typing import Tuple, Any

from jaxtyping import Float
from torch import Tensor
from torch.autograd import Function

import diff_rast.cuda as _C


class ComputeCumulativeIntersects(Function):
    """Accumulate the per-gaussian tile counts into cumulative intersections.

    Forwards its arguments to the ``compute_cumulative_intersects`` CUDA
    binding. The result is useful for creating unique gaussian IDs and for
    sorting.

    Args:
        num_points (int): number of gaussians.
        num_tiles_hit (Tensor): number of intersected tiles per gaussian.

    Returns:
        num_intersects (int): total number of tile intersections.
        cum_tiles_hit (Tensor): cumulated intersections (used for sorting).
    """

    @staticmethod
    def forward(
        ctx, num_points: int, num_tiles_hit: Float[Tensor, "batch 1"]
    ) -> Tuple[int, Float[Tensor, "batch 1"]]:
        result = _C.compute_cumulative_intersects(num_points, num_tiles_hit)

        # Unpack explicitly so an unexpected number of outputs fails here.
        total_intersects, cumulative_hits = result
        return (total_intersects, cumulative_hits)

    @staticmethod
    def backward(ctx: Any, *grad_outputs: Any) -> Any:
        # No gradient is defined for this operation.
        raise NotImplementedError
5 changes: 3 additions & 2 deletions diff_rast/cov2d_bounds.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,15 @@
import diff_rast.cuda as _C


class compute_cov2d_bounds(Function):
class ComputeCov2dBounds(Function):
"""Computes bounds of 2D covariance matrix

Args:
cov2d (Tensor): input cov2d of size (batch, 3) of upper triangular 2D covariance values

Returns:
conics (batch, 3) and radii (batch, 1)
conic (Tensor): conic parameters for 2D gaussian.
radii (Tensor): radii of 2D gaussian projections.
"""

@staticmethod
Expand Down
2 changes: 2 additions & 0 deletions diff_rast/cuda/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,5 @@ def call_cuda(*args, **kwargs):
compute_sh_backward = _make_lazy_cuda_func("compute_sh_backward")
compute_cumulative_intersects = _make_lazy_cuda_func("compute_cumulative_intersects")
map_gaussian_to_intersects = _make_lazy_cuda_func("map_gaussian_to_intersects")
get_tile_bin_edges = _make_lazy_cuda_func("get_tile_bin_edges")
bin_and_sort_gaussians = _make_lazy_cuda_func("bin_and_sort_gaussians")
89 changes: 84 additions & 5 deletions diff_rast/cuda/csrc/bindings.cu
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ project_gaussians_backward_tensor(
}

std::tuple<torch::Tensor, torch::Tensor> compute_cumulative_intersects_tensor(
const int num_points, torch::Tensor &num_tiles_hit
const int num_points, const torch::Tensor &num_tiles_hit
) {
// ref:
// https://nvlabs.github.io/cub/structcub_1_1_device_scan.html#a9416ac1ea26f9fde669d83ddc883795a
Expand Down Expand Up @@ -287,10 +287,10 @@ std::tuple<torch::Tensor, torch::Tensor> compute_cumulative_intersects_tensor(

std::tuple<torch::Tensor, torch::Tensor> map_gaussian_to_intersects_tensor(
const int num_points,
torch::Tensor &xys,
torch::Tensor &depths,
torch::Tensor &radii,
torch::Tensor &cum_tiles_hit,
const torch::Tensor &xys,
const torch::Tensor &depths,
const torch::Tensor &radii,
const torch::Tensor &cum_tiles_hit,
const std::tuple<int, int, int> tile_bounds
) {
CHECK_INPUT(xys);
Expand Down Expand Up @@ -326,3 +326,82 @@ std::tuple<torch::Tensor, torch::Tensor> map_gaussian_to_intersects_tensor(

return std::make_tuple(isect_ids_unsorted, gaussian_ids_unsorted);
}

/**
 * Launch the get_tile_bin_edges kernel on sorted intersection IDs.
 *
 * @param num_intersects   number of intersection IDs to process; also the
 *                         number of rows of the returned tensor.
 * @param isect_ids_sorted int64 tensor of sorted intersection IDs (validated
 *                         with CHECK_INPUT).
 * @return int32 tensor of shape (num_intersects, 2), zero-initialised and
 *         then filled by the kernel; created with the options (device) of
 *         isect_ids_sorted.
 */
torch::Tensor get_tile_bin_edges_tensor(
    int num_intersects,
    const torch::Tensor &isect_ids_sorted
) {
    CHECK_INPUT(isect_ids_sorted);
    torch::Tensor tile_bins = torch::zeros(
        {num_intersects, 2}, isect_ids_sorted.options().dtype(torch::kInt32)
    );

    // With no intersections the grid size below would be 0, which is an
    // invalid kernel launch configuration. There is nothing to compute, so
    // return the empty tensor directly.
    if (num_intersects <= 0) {
        return tile_bins;
    }

    // One thread per intersection, rounded up to whole blocks.
    get_tile_bin_edges<<<
        (num_intersects + N_THREADS - 1) / N_THREADS,
        N_THREADS>>>(
        num_intersects,
        isect_ids_sorted.contiguous().data_ptr<int64_t>(),
        // Each (lower, upper) row of int32 values is viewed as one int2.
        (int2 *)tile_bins.contiguous().data_ptr<int>()
    );
    return tile_bins;
}

/**
 * Tensor-level wrapper around bin_and_sort_gaussians.
 *
 * Validates the inputs, allocates the five zero-initialised output tensors
 * with the options (device) of xys, and passes raw pointers to
 * bin_and_sort_gaussians.
 *
 * @return (isect_ids_unsorted, gaussian_ids_unsorted, isect_ids_sorted,
 *          gaussian_ids_sorted, tile_bins). The ID tensors have
 *          num_intersects elements (int64 for intersection IDs, int32 for
 *          gaussian IDs); tile_bins is int32 of shape (num_intersects, 2).
 */
std::tuple<
    torch::Tensor,
    torch::Tensor,
    torch::Tensor,
    torch::Tensor,
    torch::Tensor>
bin_and_sort_gaussians_tensor(
    const int num_points,
    const int num_intersects,
    const torch::Tensor &xys,
    const torch::Tensor &depths,
    const torch::Tensor &radii,
    const torch::Tensor &cum_tiles_hit,
    const std::tuple<int, int, int> tile_bounds
){
    CHECK_INPUT(xys);
    CHECK_INPUT(depths);
    CHECK_INPUT(radii);
    CHECK_INPUT(cum_tiles_hit);

    // Repack the Python-side tuple into the dim3 expected by the CUDA code.
    dim3 tile_bounds_dim3(
        std::get<0>(tile_bounds),
        std::get<1>(tile_bounds),
        std::get<2>(tile_bounds)
    );

    // All outputs share the options of xys and differ only in dtype.
    const auto int32_options = xys.options().dtype(torch::kInt32);
    const auto int64_options = xys.options().dtype(torch::kInt64);

    torch::Tensor isect_ids_unsorted =
        torch::zeros({num_intersects}, int64_options);
    torch::Tensor isect_ids_sorted =
        torch::zeros({num_intersects}, int64_options);
    torch::Tensor gaussian_ids_unsorted =
        torch::zeros({num_intersects}, int32_options);
    torch::Tensor gaussian_ids_sorted =
        torch::zeros({num_intersects}, int32_options);
    torch::Tensor tile_bins =
        torch::zeros({num_intersects, 2}, int32_options);

    bin_and_sort_gaussians(
        num_points,
        num_intersects,
        // Inputs.
        (float2 *)xys.contiguous().data_ptr<float>(),
        depths.contiguous().data_ptr<float>(),
        radii.contiguous().data_ptr<int32_t>(),
        cum_tiles_hit.contiguous().data_ptr<int32_t>(),
        tile_bounds_dim3,
        // Outputs.
        isect_ids_unsorted.contiguous().data_ptr<int64_t>(),
        gaussian_ids_unsorted.contiguous().data_ptr<int32_t>(),
        isect_ids_sorted.contiguous().data_ptr<int64_t>(),
        gaussian_ids_sorted.contiguous().data_ptr<int32_t>(),
        (int2 *)tile_bins.contiguous().data_ptr<int>()
    );

    return std::make_tuple(
        isect_ids_unsorted,
        gaussian_ids_unsorted,
        isect_ids_sorted,
        gaussian_ids_sorted,
        tile_bins
    );
}
31 changes: 26 additions & 5 deletions diff_rast/cuda/csrc/bindings.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,14 +81,35 @@ project_gaussians_backward_tensor(
);

std::tuple<torch::Tensor, torch::Tensor> compute_cumulative_intersects_tensor(
const int num_points, torch::Tensor &num_tiles_hit
const int num_points, const torch::Tensor &num_tiles_hit
);

std::tuple<torch::Tensor, torch::Tensor> map_gaussian_to_intersects_tensor(
const int num_points,
torch::Tensor &xys,
torch::Tensor &depths,
torch::Tensor &radii,
torch::Tensor &cum_tiles_hit,
const torch::Tensor &xys,
const torch::Tensor &depths,
const torch::Tensor &radii,
const torch::Tensor &cum_tiles_hit,
const std::tuple<int, int, int> tile_bounds
);

// Tensor-level wrapper around the get_tile_bin_edges kernel.
// Takes the number of intersections and the sorted intersection IDs and
// returns the tile bins tensor.
torch::Tensor get_tile_bin_edges_tensor(
int num_intersects,
const torch::Tensor &isect_ids_sorted
);

// Tensor-level wrapper for binning and sorting gaussians.
// Returns five tensors; the parameter list mirrors
// map_gaussian_to_intersects_tensor with an additional num_intersects.
// tile_bounds is passed by value as an (x, y, z) tuple of ints.
std::tuple<
torch::Tensor,
torch::Tensor,
torch::Tensor,
torch::Tensor,
torch::Tensor>
bin_and_sort_gaussians_tensor(
const int num_points,
const int num_intersects,
const torch::Tensor &xys,
const torch::Tensor &depths,
const torch::Tensor &radii,
const torch::Tensor &cum_tiles_hit,
const std::tuple<int, int, int> tile_bounds
);
2 changes: 2 additions & 0 deletions diff_rast/cuda/csrc/ext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("compute_sh_backward", &compute_sh_backward_tensor);
m.def("compute_cumulative_intersects", &compute_cumulative_intersects_tensor);
m.def("map_gaussian_to_intersects", &map_gaussian_to_intersects_tensor);
m.def("get_tile_bin_edges", &get_tile_bin_edges_tensor);
m.def("bin_and_sort_gaussians", &bin_and_sort_gaussians_tensor);
}
4 changes: 4 additions & 0 deletions diff_rast/cuda/csrc/forward.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -88,4 +88,8 @@ __global__ void map_gaussian_to_intersects(
const dim3 tile_bounds,
int64_t *isect_ids,
int32_t *gaussian_ids
);

// Kernel that fills tile_bins from the sorted intersection IDs.
// isect_ids_sorted is read-only; tile_bins is the output, one int2 per entry.
// NOTE(review): the kernel body is not visible here — presumably each int2
// holds the (lower, upper) range of intersections for one tile; confirm
// against the definition.
__global__ void get_tile_bin_edges(
const int num_intersects, const int64_t *isect_ids_sorted, int2 *tile_bins
);
37 changes: 37 additions & 0 deletions diff_rast/get_tile_bin_edges.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
"""Python bindings for computing tile bins"""

from typing import Any

from jaxtyping import Int
from torch import Tensor
from torch.autograd import Function

import diff_rast.cuda as _C


class GetTileBinEdges(Function):
    """Map sorted intersection IDs to per-tile bins.

    Forwards its arguments to the ``get_tile_bin_edges`` CUDA binding. The
    intersection IDs are expected to be sorted by increasing tile ID.

    ``tile_bins[tile_idx]`` gives the range (lower, upper) of gaussian IDs
    that hit ``tile_idx``.

    Args:
        num_intersects (int): total number of gaussian intersects.
        isect_ids_sorted (Tensor): sorted unique IDs for each gaussian in the
            form (tile | depth id).

    Returns:
        tile_bins (Tensor): range of gaussian IDs hit per tile.
    """

    @staticmethod
    def forward(
        ctx, num_intersects: int, isect_ids_sorted: Int[Tensor, "num_intersects 1"]
    ) -> Int[Tensor, "num_intersects 2"]:
        return _C.get_tile_bin_edges(num_intersects, isect_ids_sorted)

    @staticmethod
    def backward(ctx: Any, *grad_outputs: Any) -> Any:
        # No gradient is defined for this operation.
        raise NotImplementedError
Loading