Skip to content

Commit

Permalink
format C files
Browse files Browse the repository at this point in the history
  • Loading branch information
Ruilong Li committed Jul 16, 2024
1 parent e623311 commit 7562a39
Show file tree
Hide file tree
Showing 9 changed files with 92 additions and 86 deletions.
8 changes: 5 additions & 3 deletions gsplat/cuda/csrc/ext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("rasterize_to_indices_in_range", &rasterize_to_indices_in_range_tensor);

// packed version
m.def("fully_fused_projection_packed_fwd", &fully_fused_projection_packed_fwd_tensor);
m.def("fully_fused_projection_packed_bwd", &fully_fused_projection_packed_bwd_tensor);

m.def("fully_fused_projection_packed_fwd",
&fully_fused_projection_packed_fwd_tensor);
m.def("fully_fused_projection_packed_bwd",
&fully_fused_projection_packed_bwd_tensor);

m.def("compute_relocation", &compute_relocation_tensor);
}
36 changes: 18 additions & 18 deletions gsplat/cuda/csrc/fully_fused_projection_fwd.cu
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@

namespace cg = cooperative_groups;


/****************************************************************************
* Projection of Gaussians (Single Batch) Forward Pass
****************************************************************************/
Expand All @@ -25,10 +24,10 @@ fully_fused_projection_fwd_kernel(const uint32_t C, const uint32_t N,
const T *__restrict__ viewmats, // [C, 4, 4]
const T *__restrict__ Ks, // [C, 3, 3]
const int32_t image_width, const int32_t image_height,
const T eps2d, const T near_plane,
const T far_plane, const T radius_clip,
const T eps2d, const T near_plane, const T far_plane,
const T radius_clip,
// outputs
int32_t *__restrict__ radii, // [C, N]
int32_t *__restrict__ radii, // [C, N]
T *__restrict__ means2d, // [C, N, 2]
T *__restrict__ depths, // [C, N]
T *__restrict__ conics, // [C, N, 3]
Expand Down Expand Up @@ -74,16 +73,17 @@ fully_fused_projection_fwd_kernel(const uint32_t C, const uint32_t N,
// compute from quaternions and scales
quats += gid * 4;
scales += gid * 3;
quat_scale_to_covar_preci<T>(glm::make_vec4(quats), glm::make_vec3(scales), &covar, nullptr);
quat_scale_to_covar_preci<T>(glm::make_vec4(quats), glm::make_vec3(scales),
&covar, nullptr);
}
mat3<T> covar_c;
covar_world_to_cam(R, covar, covar_c);

// perspective projection
mat2<T> covar2d;
vec2<T> mean2d;
persp_proj<T>(mean_c, covar_c, Ks[0], Ks[4], Ks[2], Ks[5], image_width, image_height,
covar2d, mean2d);
persp_proj<T>(mean_c, covar_c, Ks[0], Ks[4], Ks[2], Ks[5], image_width,
image_height, covar2d, mean2d);

T compensation;
T det = add_blur(eps2d, covar2d, compensation);
Expand Down Expand Up @@ -128,7 +128,6 @@ fully_fused_projection_fwd_kernel(const uint32_t C, const uint32_t N,
}
}


std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
fully_fused_projection_fwd_tensor(
const torch::Tensor &means, // [N, 3]
Expand Down Expand Up @@ -166,16 +165,17 @@ fully_fused_projection_fwd_tensor(
compensations = torch::zeros({C, N}, means.options());
}
if (C && N) {
fully_fused_projection_fwd_kernel<float><<<(C * N + N_THREADS - 1) / N_THREADS, N_THREADS, 0, stream>>>(
C, N, means.data_ptr<float>(),
covars.has_value() ? covars.value().data_ptr<float>() : nullptr,
quats.has_value() ? quats.value().data_ptr<float>() : nullptr,
scales.has_value() ? scales.value().data_ptr<float>() : nullptr,
viewmats.data_ptr<float>(), Ks.data_ptr<float>(), image_width, image_height,
eps2d, near_plane, far_plane, radius_clip, radii.data_ptr<int32_t>(),
means2d.data_ptr<float>(), depths.data_ptr<float>(),
conics.data_ptr<float>(),
calc_compensations ? compensations.data_ptr<float>() : nullptr);
fully_fused_projection_fwd_kernel<float>
<<<(C * N + N_THREADS - 1) / N_THREADS, N_THREADS, 0, stream>>>(
C, N, means.data_ptr<float>(),
covars.has_value() ? covars.value().data_ptr<float>() : nullptr,
quats.has_value() ? quats.value().data_ptr<float>() : nullptr,
scales.has_value() ? scales.value().data_ptr<float>() : nullptr,
viewmats.data_ptr<float>(), Ks.data_ptr<float>(), image_width,
image_height, eps2d, near_plane, far_plane, radius_clip,
radii.data_ptr<int32_t>(), means2d.data_ptr<float>(),
depths.data_ptr<float>(), conics.data_ptr<float>(),
calc_compensations ? compensations.data_ptr<float>() : nullptr);
}
return std::make_tuple(radii, means2d, depths, conics, compensations);
}
2 changes: 1 addition & 1 deletion gsplat/cuda/csrc/helpers.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>

#include <ATen/cuda/Atomic.cuh>
#include <ATen/Dispatch.h>
#include <ATen/cuda/Atomic.cuh>

#define PRAGMA_UNROLL _Pragma("unroll")

Expand Down
21 changes: 11 additions & 10 deletions gsplat/cuda/csrc/persp_proj_fwd.cu
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@

namespace cg = cooperative_groups;


/****************************************************************************
* Perspective Projection Forward Pass
****************************************************************************/
Expand All @@ -23,10 +22,10 @@ __global__ void persp_proj_fwd_kernel(const uint32_t C, const uint32_t N,
const uint32_t width, const uint32_t height,
T *__restrict__ means2d, // [C, N, 2]
T *__restrict__ covars2d // [C, N, 2, 2]
) {
) {
// For now we'll upcast float16 and bfloat16 to float32
using OpT = typename OpType<T>::type;

// parallelize over C * N.
uint32_t idx = cg::this_grid().thread_rank();
if (idx >= C * N) {
Expand Down Expand Up @@ -63,7 +62,6 @@ __global__ void persp_proj_fwd_kernel(const uint32_t C, const uint32_t N,
}
}


std::tuple<torch::Tensor, torch::Tensor>
persp_proj_fwd_tensor(const torch::Tensor &means, // [C, N, 3]
const torch::Tensor &covars, // [C, N, 3, 3]
Expand All @@ -82,12 +80,15 @@ persp_proj_fwd_tensor(const torch::Tensor &means, // [C, N, 3]

if (C && N) {
at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, means.scalar_type(), "persp_proj_fwd", [&]() {
persp_proj_fwd_kernel<scalar_t><<<(C * N + N_THREADS - 1) / N_THREADS, N_THREADS, 0, stream>>>(
C, N, means.data_ptr<scalar_t>(), covars.data_ptr<scalar_t>(),
Ks.data_ptr<scalar_t>(), width, height, means2d.data_ptr<scalar_t>(),
covars2d.data_ptr<scalar_t>());
});
AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half, at::ScalarType::BFloat16, means.scalar_type(),
"persp_proj_fwd", [&]() {
persp_proj_fwd_kernel<scalar_t>
<<<(C * N + N_THREADS - 1) / N_THREADS, N_THREADS, 0, stream>>>(
C, N, means.data_ptr<scalar_t>(), covars.data_ptr<scalar_t>(),
Ks.data_ptr<scalar_t>(), width, height,
means2d.data_ptr<scalar_t>(), covars2d.data_ptr<scalar_t>());
});
}
return std::make_tuple(means2d, covars2d);
}
19 changes: 10 additions & 9 deletions gsplat/cuda/csrc/quat_scale_to_covar_preci_fwd.cu
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,7 @@ quat_scale_to_covar_preci_fwd_kernel(const uint32_t N,
mat3<OpT> covar, preci;
const vec4<OpT> quat = glm::make_vec4(quats);
const vec3<OpT> scale = glm::make_vec3(scales);
quat_scale_to_covar_preci(quat, scale,
covars ? &covar : nullptr,
quat_scale_to_covar_preci(quat, scale, covars ? &covar : nullptr,
precis ? &preci : nullptr);

// write to outputs: glm is column-major but we want row-major
Expand Down Expand Up @@ -118,13 +117,15 @@ quat_scale_to_covar_preci_fwd_tensor(const torch::Tensor &quats, // [N, 4]

if (N) {
at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, quats.scalar_type(), "quat_scale_to_covar_preci_fwd", [&]() {
quat_scale_to_covar_preci_fwd_kernel<<<(N + N_THREADS - 1) / N_THREADS,
N_THREADS, 0, stream>>>(
N, quats.data_ptr<float>(), scales.data_ptr<float>(), triu,
compute_covar ? covars.data_ptr<float>() : nullptr,
compute_preci ? precis.data_ptr<float>() : nullptr);
});
AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half, at::ScalarType::BFloat16, quats.scalar_type(),
"quat_scale_to_covar_preci_fwd", [&]() {
quat_scale_to_covar_preci_fwd_kernel<<<(N + N_THREADS - 1) / N_THREADS,
N_THREADS, 0, stream>>>(
N, quats.data_ptr<float>(), scales.data_ptr<float>(), triu,
compute_covar ? covars.data_ptr<float>() : nullptr,
compute_preci ? precis.data_ptr<float>() : nullptr);
});
}
return std::make_tuple(covars, precis);
}
5 changes: 2 additions & 3 deletions gsplat/cuda/csrc/types.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
#define GSPLAT_CUDA_TYPES_H

#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>

#include <glm/glm.hpp>
#include <glm/gtc/type_ptr.hpp>
Expand All @@ -24,12 +24,11 @@ template <typename T> using mat4 = glm::mat<4, 4, T>;

template <typename T> using mat3x2 = glm::mat<3, 2, T>;


// Trait that names the type used for arithmetic on values stored as T.
// The primary template is the identity mapping: OpType<T>::type is T itself.
// Specializations may substitute a different computation type.
template <typename T> struct OpType {
    using type = T;
};

template<> struct OpType<__nv_bfloat16> {
// Specialization for __nv_bfloat16: the computation type is float rather
// than __nv_bfloat16 itself.
template <> struct OpType<__nv_bfloat16> {
    using type = float;
};

Expand Down
36 changes: 17 additions & 19 deletions gsplat/cuda/csrc/utils.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,7 @@
#include <cuda.h>
#include <cuda_runtime.h>


template <typename T>
inline __device__ mat3<T> quat_to_rotmat(const vec4<T> quat) {
template <typename T> inline __device__ mat3<T> quat_to_rotmat(const vec4<T> quat) {
T w = quat[0], x = quat[1], y = quat[2], z = quat[3];
// normalize
T inv_norm = rsqrt(x * x + y * y + z * z + w * w);
Expand All @@ -29,7 +27,8 @@ inline __device__ mat3<T> quat_to_rotmat(const vec4<T> quat) {
}

template <typename T>
inline __device__ void quat_to_rotmat_vjp(const vec4<T> quat, const mat3<T> v_R, vec4<T> &v_quat) {
inline __device__ void quat_to_rotmat_vjp(const vec4<T> quat, const mat3<T> v_R,
vec4<T> &v_quat) {
T w = quat[0], x = quat[1], y = quat[2], z = quat[3];
// normalize
T inv_norm = rsqrt(x * x + y * y + z * z + w * w);
Expand Down Expand Up @@ -65,8 +64,8 @@ inline __device__ void quat_scale_to_covar_preci(const vec4<T> quat,
}
if (preci != nullptr) {
// P = R * S^-1 * S^-1 * Rt
mat3<T> S = mat3<T>(1.0f / scale[0], 0.f, 0.f, 0.f, 1.0f / scale[1], 0.f,
0.f, 0.f, 1.0f / scale[2]);
mat3<T> S = mat3<T>(1.0f / scale[0], 0.f, 0.f, 0.f, 1.0f / scale[1], 0.f, 0.f,
0.f, 1.0f / scale[2]);
mat3<T> M = R * S;
*preci = M * glm::transpose(M);
}
Expand Down Expand Up @@ -146,8 +145,8 @@ inline __device__ void quat_scale_to_preci_vjp(
template <typename T>
inline __device__ void persp_proj(
// inputs
const vec3<T> mean3d, const mat3<T> cov3d, const T fx, const T fy,
const T cx, const T cy, const uint32_t width, const uint32_t height,
const vec3<T> mean3d, const mat3<T> cov3d, const T fx, const T fy, const T cx,
const T cy, const uint32_t width, const uint32_t height,
// outputs
mat2<T> &cov2d, vec2<T> &mean2d) {
T x = mean3d[0], y = mean3d[1], z = mean3d[2];
Expand All @@ -174,8 +173,8 @@ inline __device__ void persp_proj(
template <typename T>
inline __device__ void persp_proj_vjp(
// fwd inputs
const vec3<T> mean3d, const mat3<T> cov3d, const T fx, const T fy,
const T cx, const T cy, const uint32_t width, const uint32_t height,
const vec3<T> mean3d, const mat3<T> cov3d, const T fx, const T fy, const T cx,
const T cy, const uint32_t width, const uint32_t height,
// grad outputs
const mat2<T> v_cov2d, const vec2<T> v_mean2d,
// grad inputs
Expand All @@ -194,8 +193,8 @@ inline __device__ void persp_proj_vjp(

// mat3x2 is 3 columns x 2 rows.
mat3x2<T> J = mat3x2<T>(fx * rz, 0.f, // 1st column
0.f, fy * rz, // 2nd column
-fx * tx * rz2, -fy * ty * rz2 // 3rd column
0.f, fy * rz, // 2nd column
-fx * tx * rz2, -fy * ty * rz2 // 3rd column
);

// cov = J * V * Jt; G = df/dcov = v_cov
Expand All @@ -214,7 +213,8 @@ inline __device__ void persp_proj_vjp(
// df/dz = -fx * rz2 * df/dJ_00 - fy * rz2 * df/dJ_11
// + 2 * fx * tx * rz3 * df/dJ_02 + 2 * fy * ty * rz3 * df/dJ_12
T rz3 = rz2 * rz;
mat3x2<T> v_J = v_cov2d * J * glm::transpose(cov3d) + glm::transpose(v_cov2d) * J * cov3d;
mat3x2<T> v_J =
v_cov2d * J * glm::transpose(cov3d) + glm::transpose(v_cov2d) * J * cov3d;

// fov clipping
if (x * rz <= lim_x && x * rz >= -lim_x) {
Expand Down Expand Up @@ -279,8 +279,7 @@ inline __device__ void covar_world_to_cam_vjp(
v_covar += glm::transpose(R) * v_covar_c * R;
}

template <typename T>
inline __device__ T inverse(const mat2<T> M, mat2<T> &Minv) {
template <typename T> inline __device__ T inverse(const mat2<T> M, mat2<T> &Minv) {
T det = M[0][0] * M[1][1] - M[0][1] * M[1][0];
if (det <= 0.f) {
return det;
Expand All @@ -301,8 +300,7 @@ inline __device__ void inverse_vjp(const T Minv, const T v_Minv, T &v_M) {
}

template <typename T>
inline __device__ T add_blur(const T eps2d, mat2<T> &covar,
T &compensation) {
inline __device__ T add_blur(const T eps2d, mat2<T> &covar, T &compensation) {
T det_orig = covar[0][0] * covar[1][1] - covar[0][1] * covar[1][0];
covar[0][0] += eps2d;
covar[1][1] += eps2d;
Expand All @@ -313,8 +311,8 @@ inline __device__ T add_blur(const T eps2d, mat2<T> &covar,

template <typename T>
inline __device__ void add_blur_vjp(const T eps2d, const mat2<T> conic_blur,
const T compensation,
const T v_compensation, mat2<T> &v_covar) {
const T compensation, const T v_compensation,
mat2<T> &v_covar) {
// comp = sqrt(det(covar) / det(covar_blur))

// d [det(M)] / d M = adj(M)
Expand Down
28 changes: 16 additions & 12 deletions gsplat/cuda/csrc/world_to_cam_bwd.cu
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@

namespace cg = cooperative_groups;


/****************************************************************************
* World to Camera Transformation Backward Pass
****************************************************************************/
Expand Down Expand Up @@ -113,7 +112,6 @@ world_to_cam_bwd_kernel(const uint32_t C, const uint32_t N,
}
}


std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>
world_to_cam_bwd_tensor(const torch::Tensor &means, // [N, 3]
const torch::Tensor &covars, // [N, 3, 3]
Expand Down Expand Up @@ -148,16 +146,22 @@ world_to_cam_bwd_tensor(const torch::Tensor &means, // [N, 3]

if (C && N) {
at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, means.scalar_type(), "world_to_cam_bwd", [&]() {
world_to_cam_bwd_kernel<scalar_t><<<(C * N + N_THREADS - 1) / N_THREADS, N_THREADS, 0, stream>>>(
C, N, means.data_ptr<scalar_t>(), covars.data_ptr<scalar_t>(),
viewmats.data_ptr<scalar_t>(),
v_means_c.has_value() ? v_means_c.value().data_ptr<scalar_t>() : nullptr,
v_covars_c.has_value() ? v_covars_c.value().data_ptr<scalar_t>() : nullptr,
means_requires_grad ? v_means.data_ptr<scalar_t>() : nullptr,
covars_requires_grad ? v_covars.data_ptr<scalar_t>() : nullptr,
viewmats_requires_grad ? v_viewmats.data_ptr<scalar_t>() : nullptr);
});
AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half, at::ScalarType::BFloat16, means.scalar_type(),
"world_to_cam_bwd", [&]() {
world_to_cam_bwd_kernel<scalar_t>
<<<(C * N + N_THREADS - 1) / N_THREADS, N_THREADS, 0, stream>>>(
C, N, means.data_ptr<scalar_t>(), covars.data_ptr<scalar_t>(),
viewmats.data_ptr<scalar_t>(),
v_means_c.has_value() ? v_means_c.value().data_ptr<scalar_t>()
: nullptr,
v_covars_c.has_value() ? v_covars_c.value().data_ptr<scalar_t>()
: nullptr,
means_requires_grad ? v_means.data_ptr<scalar_t>() : nullptr,
covars_requires_grad ? v_covars.data_ptr<scalar_t>() : nullptr,
viewmats_requires_grad ? v_viewmats.data_ptr<scalar_t>()
: nullptr);
});
}
return std::make_tuple(v_means, v_covars, v_viewmats);
}
Loading

0 comments on commit 7562a39

Please sign in to comment.