-
Notifications
You must be signed in to change notification settings - Fork 296
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Update the pybinding of rasterize_forward_kernel (#28)
* add get_tile_bin_edges binding * finish pybinding for get_tile_bin_edges and fix the bug of tile_idx * use _masks to not get negative depths * finish bin_and_sort_gaussians and fix some bugs of _torch_impl * update docs * python wrappers for pybindings * update tests to use wrappers and fix num_points consistency * format * finish rasterize_forward_kernel & fix possible bugs * finish rasterize_forward_kernel * fix the naming issue * black formatting * update docs with new binding --------- Co-authored-by: maturk <[email protected]>
- Loading branch information
1 parent
38462ca
commit 3f3b9bd
Showing
12 changed files
with
378 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
"""Python bindings for forward rasterization""" | ||
|
||
from typing import Tuple, Any, Optional | ||
|
||
import torch | ||
from jaxtyping import Float, Int | ||
from torch import Tensor | ||
from torch.autograd import Function | ||
|
||
import diff_rast.cuda as _C | ||
|
||
|
||
class RasterizeForwardKernel(Function): | ||
"""Kernel function for rasterizing and alpha-composing each tile. | ||
Args: | ||
tile_bounds (Tuple): tile dimensions as a len 3 tuple (tiles.x , tiles.y, 1). | ||
block (Tuple): block dimensions as a len 3 tuple (block.x , block.y, 1). | ||
img_size (Tuple): image dimensions as a len 3 tuple (img.x , img.y, 1). | ||
gaussian_ids_sorted (Tensor): tensor that maps isect_ids back to cum_tiles_hit, sorted in ascending order. | ||
tile_bins (Tensor): range of gaussians IDs hit per tile. | ||
xys (Tensor): x,y locations of 2D gaussian projections. | ||
conics (Tensor): conics (inverse of covariance) of 2D gaussians in upper triangular format. | ||
colors (Tensor): colors associated with the gaussians. | ||
opacities (Tensor): opacity associated with the gaussians. | ||
background (Tensor): background color | ||
Returns: | ||
A tuple of {Tensor, Tensor, Tensor}: | ||
- **out_img** (Tensor): the rendered output image. | ||
- **final_Ts** (Tensor): the final transmittance values. | ||
- **final_idx** (Tensor): the final gaussian IDs. | ||
""" | ||
|
||
@staticmethod | ||
def forward( | ||
ctx, | ||
tile_bounds: Tuple[int, int, int], | ||
block: Tuple[int, int, int], | ||
img_size: Tuple[int, int, int], | ||
gaussian_ids_sorted: Int[Tensor, "num_intersects 1"], | ||
tile_bins: Int[Tensor, "num_intersects 2"], | ||
xys: Float[Tensor, "batch 2"], | ||
conics: Float[Tensor, "*batch 3"], | ||
colors: Float[Tensor, "*batch channels"], | ||
opacities: Float[Tensor, "*batch 1"], | ||
background: Optional[Float[Tensor, "channels"]] = None, | ||
): | ||
if colors.dtype == torch.uint8: | ||
# make sure colors are float [0,1] | ||
colors = colors.float() / 255 | ||
|
||
if background is not None: | ||
assert ( | ||
background.shape[0] == colors.shape[-1] | ||
), f"incorrect shape of background color tensor, expected shape {colors.shape[-1]}" | ||
else: | ||
background = torch.ones(3, dtype=torch.float32) | ||
|
||
(out_img, final_Ts, final_idx,) = _C.rasterize_forward_kernel( | ||
tile_bounds, | ||
block, | ||
img_size, | ||
gaussian_ids_sorted, | ||
tile_bins, | ||
xys, | ||
conics, | ||
colors, | ||
opacities, | ||
background, | ||
) | ||
return ( | ||
out_img, | ||
final_Ts, | ||
final_idx, | ||
) | ||
|
||
@staticmethod | ||
def backward(ctx: Any, *grad_outputs: Any) -> Any: | ||
raise NotImplementedError |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.