Skip to content

Commit

Permalink
black formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
Zhuoyang-Pan committed Oct 2, 2023
1 parent 69e86a7 commit 22c1e21
Show file tree
Hide file tree
Showing 5 changed files with 40 additions and 39 deletions.
26 changes: 11 additions & 15 deletions diff_rast/_torch_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,37 +292,33 @@ def project_gaussians_forward(

return cov3d, xys, depths, radii, conics, num_tiles_hit, mask


def map_gaussian_to_intersects(
num_points,
xys,
depths,
radii,
cum_tiles_hit,
tile_bounds
num_points, xys, depths, radii, cum_tiles_hit, tile_bounds
):
num_intersects = cum_tiles_hit[-1]
isect_ids = torch.zeros(num_intersects, dtype=torch.int64, device=xys.device)
gaussian_ids = torch.zeros(num_intersects, dtype=torch.int32, device=xys.device)

for idx in range(num_points):
if radii[idx] <= 0: break

if radii[idx] <= 0:
break

tile_min, tile_max = get_tile_bbox(xys[idx], radii[idx], tile_bounds)

cur_idx = 0 if idx == 0 else cum_tiles_hit[idx - 1]

# Get raw byte representation of the float value at the given index
raw_bytes = struct.pack('f', depths[idx])
raw_bytes = struct.pack("f", depths[idx])

# Interpret those bytes as an int32_t
depth_id_n = struct.unpack('i', raw_bytes)[0]
depth_id_n = struct.unpack("i", raw_bytes)[0]

for i in range(tile_min[1], tile_max[1]):
for j in range(tile_min[0], tile_max[0]):
tile_id = i * tile_bounds[0] + j
isect_ids[cur_idx] = (tile_id << 32) | depth_id_n
gaussian_ids[cur_idx] = idx
cur_idx += 1

return isect_ids, gaussian_ids

8 changes: 7 additions & 1 deletion diff_rast/project_gaussians.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,13 @@ def backward(ctx, v_xys, v_depths, v_radii, v_conics, v_num_tiles_hit, v_cov3d):
conics,
) = ctx.saved_tensors

(v_cov2d, v_cov3d, v_mean3d, v_scale, v_quat,) = _C.project_gaussians_backward(
(
v_cov2d,
v_cov3d,
v_mean3d,
v_scale,
v_quat,
) = _C.project_gaussians_backward(
ctx.num_points,
means3d,
scales,
Expand Down
12 changes: 8 additions & 4 deletions tests/test_cumulative_intersects.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,14 @@ def test_cumulative_intersects():

num_points = 10

num_tiles_hit = torch.randint(0, 100, (num_points,), device=device, dtype=torch.int32)

num_intersects, cum_tiles_hit = _C.compute_cumulative_intersects(num_points, num_tiles_hit)

num_tiles_hit = torch.randint(
0, 100, (num_points,), device=device, dtype=torch.int32
)

num_intersects, cum_tiles_hit = _C.compute_cumulative_intersects(
num_points, num_tiles_hit
)

_cum_tiles_hit = torch.cumsum(num_tiles_hit, dim=0, dtype=torch.int32)
_num_intersects = _cum_tiles_hit[-1]

Expand Down
22 changes: 6 additions & 16 deletions tests/test_map_gaussians.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,28 +48,18 @@ def test_map_gaussians():
tile_bounds,
clip_thresh,
)

_cum_tiles_hit = torch.cumsum(_num_tiles_hit, dim=0, dtype=torch.int32)
_depths = _depths.contiguous()

isect_ids, gaussian_ids = _C.map_gaussian_to_intersects(
num_points,
_xys,
_depths,
_radii,
_cum_tiles_hit,
tile_bounds
num_points, _xys, _depths, _radii, _cum_tiles_hit, tile_bounds
)

_isect_ids, _gaussian_ids = _torch_impl.map_gaussian_to_intersects(
num_points,
_xys,
_depths,
_radii,
_cum_tiles_hit,
tile_bounds
num_points, _xys, _depths, _radii, _cum_tiles_hit, tile_bounds
)

torch.testing.assert_close(gaussian_ids, _gaussian_ids)
torch.testing.assert_close(isect_ids, _isect_ids)

Expand Down
11 changes: 8 additions & 3 deletions tests/test_project_gaussians.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def test_project_gaussians_forward():
projmat,
fx,
fy,
H,
H,
W,
tile_bounds,
clip_thresh,
Expand Down Expand Up @@ -77,12 +77,17 @@ def test_project_gaussians_forward():
atol=1e-5,
rtol=1e-5,
)
torch.testing.assert_close(xys[_masks], _xys[_masks], atol=1e-4, rtol=1e-4,)
torch.testing.assert_close(
xys[_masks],
_xys[_masks],
atol=1e-4,
rtol=1e-4,
)
torch.testing.assert_close(depths[_masks], _depths[_masks])
torch.testing.assert_close(radii[_masks], _radii[_masks])
torch.testing.assert_close(conics[_masks], _conics[_masks])
torch.testing.assert_close(num_tiles_hit[_masks], _num_tiles_hit[_masks])


if __name__ == "__main__":
test_project_gaussians_forward()
test_project_gaussians_forward()

0 comments on commit 22c1e21

Please sign in to comment.