-
Notifications
You must be signed in to change notification settings - Fork 290
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* add get_tile_bin_edges binding * finish pybinding for get_tile_bin_edges and fix the bug of tile_idx * use _masks to not get negative depths * finish bin_and_sort_gaussians and fix some bugs of _torch_impl * update docs * python wrappers for pybindings * update tests to use wrappers and fix num_points consistency * format --------- Co-authored-by: maturk <[email protected]>
- Loading branch information
1 parent
2ec6b5c
commit a223e5a
Showing
19 changed files
with
566 additions
and
24 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
"""Python bindings for binning and sorting gaussians""" | ||
|
||
from typing import Tuple, Any | ||
|
||
from jaxtyping import Float | ||
from torch import Tensor | ||
from torch.autograd import Function | ||
|
||
import diff_rast.cuda as _C | ||
|
||
|
||
class BinAndSortGaussians(Function): | ||
"""Function for mapping gaussians to sorted unique intersection IDs and tile bins used for fast rasterization. | ||
We return both sorted and unsorted versions of intersect IDs and gaussian IDs for testing purposes. | ||
Args: | ||
num_points (int): number of gaussians. | ||
num_intersects (int): cumulative number of total gaussian intersections | ||
xys (Tensor): x,y locations of 2D gaussian projections. | ||
depths (Tensor): z depth of gaussians. | ||
radii (Tensor): radii of 2D gaussian projections. | ||
cum_tiles_hit (Tensor): list of cumulative tiles hit. | ||
tile_bounds (Tuple): tile dimensions as a len 3 tuple (tiles.x , tiles.y, 1). | ||
Returns: | ||
isect_ids_unsorted (Tensor): unique IDs for each gaussian in the form (tile | depth id). | ||
gaussian_ids_unsorted (Tensor): Tensor that maps isect_ids back to cum_tiles_hit. Useful for identifying gaussians. | ||
isect_ids_sorted (Tensor): sorted unique IDs for each gaussian in the form (tile | depth id). | ||
gaussian_ids_sorted (Tensor): sorted Tensor that maps isect_ids back to cum_tiles_hit. Useful for identifying gaussians. | ||
tile_bins (Tensor): range of gaussians hit per tile. | ||
""" | ||
|
||
@staticmethod | ||
def forward( | ||
ctx, | ||
num_points: int, | ||
num_intersects: int, | ||
xys: Float[Tensor, "batch 2"], | ||
depths: Float[Tensor, "batch 1"], | ||
radii: Float[Tensor, "batch 1"], | ||
cum_tiles_hit: Float[Tensor, "batch 1"], | ||
tile_bounds: Tuple[int, int, int], | ||
) -> Tuple[ | ||
Float[Tensor, "num_intersects 1"], | ||
Float[Tensor, "num_intersects 1"], | ||
Float[Tensor, "num_intersects 1"], | ||
Float[Tensor, "num_intersects 1"], | ||
Float[Tensor, "num_intersects 2"], | ||
]: | ||
|
||
( | ||
isect_ids_unsorted, | ||
gaussian_ids_unsorted, | ||
isect_ids_sorted, | ||
gaussian_ids_sorted, | ||
tile_bins, | ||
) = _C.bin_and_sort_gaussians( | ||
num_points, num_intersects, xys, depths, radii, cum_tiles_hit, tile_bounds | ||
) | ||
|
||
return ( | ||
isect_ids_unsorted, | ||
gaussian_ids_unsorted, | ||
isect_ids_sorted, | ||
gaussian_ids_sorted, | ||
tile_bins, | ||
) | ||
|
||
@staticmethod | ||
def backward(ctx: Any, *grad_outputs: Any) -> Any: | ||
raise NotImplementedError |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
"""Python bindings for computing cumulative intersects""" | ||
|
||
from typing import Tuple, Any | ||
|
||
from jaxtyping import Float | ||
from torch import Tensor | ||
from torch.autograd import Function | ||
|
||
import diff_rast.cuda as _C | ||
|
||
|
||
class ComputeCumulativeIntersects(Function): | ||
"""Computes cumulative intersections of gaussians. This is useful for creating unique gaussian IDs and for sorting. | ||
Args: | ||
num_points (int): number of gaussians. | ||
num_tiles_hit (Tensor): number of intersected tiles per gaussian. | ||
Returns: | ||
num_intersects (int): total number of tile intersections. | ||
cum_tiles_hit (Tensor): a tensor of cumulated intersections (used for sorting). | ||
""" | ||
|
||
@staticmethod | ||
def forward( | ||
ctx, num_points: int, num_tiles_hit: Float[Tensor, "batch 1"] | ||
) -> Tuple[int, Float[Tensor, "batch 1"]]: | ||
|
||
num_intersects, cum_tiles_hit = _C.compute_cumulative_intersects( | ||
num_points, num_tiles_hit | ||
) | ||
|
||
return (num_intersects, cum_tiles_hit) | ||
|
||
@staticmethod | ||
def backward(ctx: Any, *grad_outputs: Any) -> Any: | ||
raise NotImplementedError |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
"""Python bindings for computing tile bins""" | ||
|
||
from typing import Any | ||
|
||
from jaxtyping import Int | ||
from torch import Tensor | ||
from torch.autograd import Function | ||
|
||
import diff_rast.cuda as _C | ||
|
||
|
||
class GetTileBinEdges(Function): | ||
"""Function to map sorted intersection IDs to tile bins which give the range of unique gaussian IDs belonging to each tile. | ||
Expects that intersection IDs are sorted by increasing tile ID. | ||
Indexing into tile_bins[tile_idx] returns the range (lower,upper) of gaussian IDs that hit tile_idx. | ||
Args: | ||
num_intersects (int): total number of gaussian intersects. | ||
isect_ids_sorted (Tensor): sorted unique IDs for each gaussian in the form (tile | depth id). | ||
Returns: | ||
tile_bins (Tensor): range of gaussians IDs hit per tile. | ||
""" | ||
|
||
@staticmethod | ||
def forward( | ||
ctx, num_intersects: int, isect_ids_sorted: Int[Tensor, "num_intersects 1"] | ||
) -> Int[Tensor, "num_intersects 2"]: | ||
|
||
tile_bins = _C.get_tile_bin_edges(num_intersects, isect_ids_sorted) | ||
return tile_bins | ||
|
||
@staticmethod | ||
def backward(ctx: Any, *grad_outputs: Any) -> Any: | ||
raise NotImplementedError |
Oops, something went wrong.