Skip to content

Commit

Permalink
Update the pybinding of rasterize_forward_kernel (#28)
Browse files Browse the repository at this point in the history
* add get_tile_bin_edges binding

* finish pybinding for get_tile_bin_edges and fix the bug of tile_idx

* use _masks to not get negative depths

* finish bin_and_sort_gaussians and fix some bugs of _torch_impl

* update docs

* python wrappers for pybindings

* update tests to use wrappers and fix num_points consistency

* format

* finish rasterize_forward_kernel & fix possible bugs

* finish rasterize_forward_kernel

* fix the naming issue

* black formatting

* update docs with new binding

---------

Co-authored-by: maturk <[email protected]>
  • Loading branch information
Zhuoyang-Pan and maturk authored Oct 6, 2023
1 parent 38462ca commit 3f3b9bd
Show file tree
Hide file tree
Showing 12 changed files with 378 additions and 3 deletions.
2 changes: 2 additions & 0 deletions diff_rast/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from .get_tile_bin_edges import GetTileBinEdges
from .map_gaussian_to_intersects import MapGaussiansToIntersects
from .sh import SphericalHarmonics
from .rasterize_forward_kernel import RasterizeForwardKernel
from .version import __version__

__all__ = [
Expand All @@ -18,4 +19,5 @@
"GetTileBinEdges",
"MapGaussiansToIntersects",
"SphericalHarmonics",
"RasterizeForwardKernel",
]
70 changes: 70 additions & 0 deletions diff_rast/_torch_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,3 +366,73 @@ def bin_and_sort_gaussians(
tile_bins = get_tile_bin_edges(num_intersects, isect_ids_sorted)

return isect_ids, gaussian_ids, isect_ids_sorted, gaussian_ids_sorted, tile_bins


def rasterize_forward_kernel(
tile_bounds,
block,
img_size,
gaussian_ids_sorted,
tile_bins,
xys,
conics,
colors,
opacities,
background,
):
channels = colors.shape[1]
out_img = torch.zeros(
(img_size[1], img_size[0], channels), dtype=torch.float32, device=xys.device
)
final_Ts = torch.zeros(
(img_size[1], img_size[0]), dtype=torch.float32, device=xys.device
)
final_idx = torch.zeros(
(img_size[1], img_size[0]), dtype=torch.int32, device=xys.device
)
for i in range(img_size[1]):
for j in range(img_size[0]):
tile_id = (i // block[0]) * tile_bounds[0] + (j // block[1])
tile_bin_start = tile_bins[tile_id, 0]
tile_bin_end = tile_bins[tile_id, 1]
T = 1.0

for idx in range(tile_bin_start, tile_bin_end):
gaussian_id = gaussian_ids_sorted[idx]
conic = conics[gaussian_id]
center = xys[gaussian_id]
delta = center - torch.tensor(
[j, i], dtype=torch.float32, device=xys.device
)

sigma = (
0.5
* (conic[0] * delta[0] * delta[0] + conic[2] * delta[1] * delta[1])
+ conic[1] * delta[0] * delta[1]
)

if sigma < 0:
continue

opac = opacities[gaussian_id]
alpha = min(0.999, opac * torch.exp(-sigma))

if alpha < 1 / 255:
continue

next_T = T * (1 - alpha)

if next_T <= 1e-4:
idx -= 1
break

vis = alpha * T

out_img[i, j] += vis * colors[gaussian_id]
T = next_T

final_Ts[i, j] = T
final_idx[i, j] = idx
out_img[i, j] += T * background

return out_img, final_Ts, final_idx
1 change: 1 addition & 0 deletions diff_rast/cuda/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,4 @@ def call_cuda(*args, **kwargs):
map_gaussian_to_intersects = _make_lazy_cuda_func("map_gaussian_to_intersects")
get_tile_bin_edges = _make_lazy_cuda_func("get_tile_bin_edges")
bin_and_sort_gaussians = _make_lazy_cuda_func("bin_and_sort_gaussians")
rasterize_forward_kernel = _make_lazy_cuda_func("rasterize_forward_kernel")
72 changes: 72 additions & 0 deletions diff_rast/cuda/csrc/bindings.cu
Original file line number Diff line number Diff line change
Expand Up @@ -404,4 +404,76 @@ bin_and_sort_gaussians_tensor(
tile_bins
);

}

std::tuple<
torch::Tensor,
torch::Tensor,
torch::Tensor
> rasterize_forward_kernel_tensor(
const std::tuple<int, int, int> tile_bounds,
const std::tuple<int, int, int> block,
const std::tuple<int, int, int> img_size,
const torch::Tensor &gaussian_ids_sorted,
const torch::Tensor &tile_bins,
const torch::Tensor &xys,
const torch::Tensor &conics,
const torch::Tensor &colors,
const torch::Tensor &opacities,
const torch::Tensor &background
){
CHECK_INPUT(gaussian_ids_sorted);
CHECK_INPUT(tile_bins);
CHECK_INPUT(xys);
CHECK_INPUT(conics);
CHECK_INPUT(colors);
CHECK_INPUT(opacities);
CHECK_INPUT(background);

dim3 tile_bounds_dim3;
tile_bounds_dim3.x = std::get<0>(tile_bounds);
tile_bounds_dim3.y = std::get<1>(tile_bounds);
tile_bounds_dim3.z = std::get<2>(tile_bounds);

dim3 block_dim3;
block_dim3.x = std::get<0>(block);
block_dim3.y = std::get<1>(block);
block_dim3.z = std::get<2>(block);

dim3 img_size_dim3;
img_size_dim3.x = std::get<0>(img_size);
img_size_dim3.y = std::get<1>(img_size);
img_size_dim3.z = std::get<2>(img_size);

const int channels = colors.size(1);
const int img_width = img_size_dim3.x;
const int img_height = img_size_dim3.y;

torch::Tensor out_img = torch::zeros(
{img_height, img_width, channels}, xys.options().dtype(torch::kFloat32)
);
torch::Tensor final_Ts = torch::zeros(
{img_height, img_width}, xys.options().dtype(torch::kFloat32)
);
torch::Tensor final_idx = torch::zeros(
{img_height, img_width}, xys.options().dtype(torch::kInt32)
);


rasterize_forward_kernel<3><<<tile_bounds_dim3, block_dim3>>>(
tile_bounds_dim3,
img_size_dim3,
gaussian_ids_sorted.contiguous().data_ptr<int32_t>(),
(int2 *)tile_bins.contiguous().data_ptr<int>(),
(float2 *)xys.contiguous().data_ptr<float>(),
(float3 *)conics.contiguous().data_ptr<float>(),
colors.contiguous().data_ptr<float>(),
opacities.contiguous().data_ptr<float>(),
final_Ts.contiguous().data_ptr<float>(),
final_idx.contiguous().data_ptr<int>(),
out_img.contiguous().data_ptr<float>(),
background.contiguous().data_ptr<float>()
);

return std::make_tuple(out_img, final_Ts, final_idx);
}
17 changes: 17 additions & 0 deletions diff_rast/cuda/csrc/bindings.h
Original file line number Diff line number Diff line change
Expand Up @@ -112,4 +112,21 @@ bin_and_sort_gaussians_tensor(
const torch::Tensor &radii,
const torch::Tensor &cum_tiles_hit,
const std::tuple<int, int, int> tile_bounds
);

std::tuple<
torch::Tensor,
torch::Tensor,
torch::Tensor
> rasterize_forward_kernel_tensor(
const std::tuple<int, int, int> tile_bounds,
const std::tuple<int, int, int> block,
const std::tuple<int, int, int> img_size,
const torch::Tensor &gaussian_ids_sorted,
const torch::Tensor &tile_bins,
const torch::Tensor &xys,
const torch::Tensor &conics,
const torch::Tensor &colors,
const torch::Tensor &opacities,
const torch::Tensor &background
);
1 change: 1 addition & 0 deletions diff_rast/cuda/csrc/ext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,5 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("map_gaussian_to_intersects", &map_gaussian_to_intersects_tensor);
m.def("get_tile_bin_edges", &get_tile_bin_edges_tensor);
m.def("bin_and_sort_gaussians", &bin_and_sort_gaussians_tensor);
m.def("rasterize_forward_kernel", &rasterize_forward_kernel_tensor);
}
3 changes: 2 additions & 1 deletion diff_rast/cuda/csrc/forward.cu
Original file line number Diff line number Diff line change
Expand Up @@ -399,7 +399,8 @@ __global__ void rasterize_forward_kernel(
T = next_T;
}
final_Ts[pix_id] = T; // transmittance at last gaussian in this pixel
final_index[pix_id] = idx; // index of in bin of last gaussian in this pixel
final_index[pix_id] = (idx == range.y) ? idx - 1 : idx; // index of in bin of last gaussian in this pixel

for (int c = 0; c < CHANNELS; ++c) {
out_img[CHANNELS * pix_id + c] += T * background[c];
}
Expand Down
16 changes: 16 additions & 0 deletions diff_rast/cuda/csrc/forward.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -92,4 +92,20 @@ __global__ void map_gaussian_to_intersects(

__global__ void get_tile_bin_edges(
const int num_intersects, const int64_t *isect_ids_sorted, int2 *tile_bins
);

template <int CHANNELS>
__global__ void rasterize_forward_kernel(
const dim3 tile_bounds,
const dim3 img_size,
const int32_t *gaussian_ids_sorted,
const int2 *tile_bins,
const float2 *xys,
const float3 *conics,
const float *colors,
const float *opacities,
float *final_Ts,
int *final_index,
float *out_img,
const float *background
);
81 changes: 81 additions & 0 deletions diff_rast/rasterize_forward_kernel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
"""Python bindings for forward rasterization"""

from typing import Tuple, Any, Optional

import torch
from jaxtyping import Float, Int
from torch import Tensor
from torch.autograd import Function

import diff_rast.cuda as _C


class RasterizeForwardKernel(Function):
"""Kernel function for rasterizing and alpha-composing each tile.
Args:
tile_bounds (Tuple): tile dimensions as a len 3 tuple (tiles.x , tiles.y, 1).
block (Tuple): block dimensions as a len 3 tuple (block.x , block.y, 1).
img_size (Tuple): image dimensions as a len 3 tuple (img.x , img.y, 1).
gaussian_ids_sorted (Tensor): tensor that maps isect_ids back to cum_tiles_hit, sorted in ascending order.
tile_bins (Tensor): range of gaussians IDs hit per tile.
xys (Tensor): x,y locations of 2D gaussian projections.
conics (Tensor): conics (inverse of covariance) of 2D gaussians in upper triangular format.
colors (Tensor): colors associated with the gaussians.
opacities (Tensor): opacity associated with the gaussians.
background (Tensor): background color
Returns:
A tuple of {Tensor, Tensor, Tensor}:
- **out_img** (Tensor): the rendered output image.
- **final_Ts** (Tensor): the final transmittance values.
- **final_idx** (Tensor): the final gaussian IDs.
"""

@staticmethod
def forward(
ctx,
tile_bounds: Tuple[int, int, int],
block: Tuple[int, int, int],
img_size: Tuple[int, int, int],
gaussian_ids_sorted: Int[Tensor, "num_intersects 1"],
tile_bins: Int[Tensor, "num_intersects 2"],
xys: Float[Tensor, "batch 2"],
conics: Float[Tensor, "*batch 3"],
colors: Float[Tensor, "*batch channels"],
opacities: Float[Tensor, "*batch 1"],
background: Optional[Float[Tensor, "channels"]] = None,
):
if colors.dtype == torch.uint8:
# make sure colors are float [0,1]
colors = colors.float() / 255

if background is not None:
assert (
background.shape[0] == colors.shape[-1]
), f"incorrect shape of background color tensor, expected shape {colors.shape[-1]}"
else:
background = torch.ones(3, dtype=torch.float32)

(out_img, final_Ts, final_idx,) = _C.rasterize_forward_kernel(
tile_bounds,
block,
img_size,
gaussian_ids_sorted,
tile_bins,
xys,
conics,
colors,
opacities,
background,
)
return (
out_img,
final_Ts,
final_idx,
)

@staticmethod
def backward(ctx: Any, *grad_outputs: Any) -> Any:
raise NotImplementedError
4 changes: 3 additions & 1 deletion docs/source/apis/utils.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,6 @@ In addition to the main projection and rasterization functions, a few CUDA kerne

.. autoclass:: MapGaussiansToIntersects

.. autoclass:: ComputeCumulativeIntersects
.. autoclass:: ComputeCumulativeIntersects

.. autoclass:: RasterizeForwardKernel
2 changes: 1 addition & 1 deletion docs/source/tests/tests.rst
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,5 @@ The tests include:
./test_get_tile_bin_edges.py
./test_map_gaussians.py
./test_project_gaussians
./test_rasterize_forward_kernel.py
./test_sh.py
Loading

0 comments on commit 3f3b9bd

Please sign in to comment.