Skip to content

Commit

Permalink
avoid divide zero
Browse files Browse the repository at this point in the history
  • Loading branch information
Jianbo Ye committed May 30, 2024
1 parent d045494 commit 71eb305
Showing 1 changed file with 36 additions and 30 deletions.
66 changes: 36 additions & 30 deletions gsplat/_torch_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,9 +187,8 @@ def _all():

# w.r.t. theta
fN0 = torch.sqrt(x * x + y * y) # sin(theta)
fC0n = x / fN0
fS0n = y / fN0
cot_theta = z / fN0
fC0n = x / torch.clamp(fN0, min=1e-6)
fS0n = y / torch.clamp(fN0, min=1e-6)
dP11 = - result[..., 2]
d_varphi_and_theta[..., 2, 1] = fTmpA * fN0
d_varphi_and_theta[..., 3, 1] = dP11 * fC0n
Expand Down Expand Up @@ -218,14 +217,17 @@ def _all():
d_varphi_and_theta[..., 4, 0] = 2 * result[..., 8]

# w.r.t. theta
fN1 = fN0 * fN0
fC1n = fC1 / torch.clamp(fN0, min=1e-6)
fS1n = fS1 / torch.clamp(fN0, min=1e-6)
dP20 = math.sqrt(3) * fTmpB * fN0
dP21 = 2 * fTmpA * fN0 + fTmpB * cot_theta
dP22 = - fTmpB / fN0
dP21 = 2 * fTmpA * fN1 + fTmpB * z
dP22 = - fTmpB
d_varphi_and_theta[..., 6, 1] = dP20
d_varphi_and_theta[..., 7, 1] = dP21 * x
d_varphi_and_theta[..., 5, 1] = dP21 * y
d_varphi_and_theta[..., 8, 1] = dP22 * fC1
d_varphi_and_theta[..., 4, 1] = dP22 * fS1
d_varphi_and_theta[..., 7, 1] = dP21 * fC0n
d_varphi_and_theta[..., 5, 1] = dP21 * fS0n
d_varphi_and_theta[..., 8, 1] = dP22 * fC1n
d_varphi_and_theta[..., 4, 1] = dP22 * fS1n

if basis_dim <= 9:
return _all()
Expand Down Expand Up @@ -254,17 +256,19 @@ def _all():
d_varphi_and_theta[..., 9, 0] = 3 * result[..., 15]

# w.r.t. theta
fC2n = fC2 / torch.clamp(fN0, min=1e-6)
fS2n = fS2 / torch.clamp(fN0, min=1e-6)
dP30 = math.sqrt(6) * fTmpC * fN0
dP31 = math.sqrt(10) * fTmpB * fN0 + fTmpC * cot_theta
dP32 = math.sqrt(6) * fTmpA * fN0 + 2 * fTmpB * cot_theta
dP33 = - math.sqrt(3 / 2) * fTmpB / fN0
dP31 = math.sqrt(10) * fTmpB * fN1 + fTmpC * z
dP32 = math.sqrt(6) * fTmpA * fN1 + 2 * fTmpB * z
dP33 = - math.sqrt(3 / 2) * fTmpB
d_varphi_and_theta[..., 12, 1] = dP30
d_varphi_and_theta[..., 13, 1] = dP31 * x
d_varphi_and_theta[..., 11, 1] = dP31 * y
d_varphi_and_theta[..., 14, 1] = dP32 * fC1
d_varphi_and_theta[..., 10, 1] = dP32 * fS1
d_varphi_and_theta[..., 15, 1] = dP33 * fC2
d_varphi_and_theta[..., 9, 1] = dP33 * fS2
d_varphi_and_theta[..., 13, 1] = dP31 * fC0n
d_varphi_and_theta[..., 11, 1] = dP31 * fS0n
d_varphi_and_theta[..., 14, 1] = dP32 * fC1n
d_varphi_and_theta[..., 10, 1] = dP32 * fS1n
d_varphi_and_theta[..., 15, 1] = dP33 * fC2n
d_varphi_and_theta[..., 9, 1] = dP33 * fS2n


if basis_dim <= 16:
Expand Down Expand Up @@ -301,20 +305,22 @@ def _all():
d_varphi_and_theta[..., 16, 0] = 4 * result[..., 24]

# w.r.t. theta
fC3n = fC3 / torch.clamp(fN0, min=1e-6)
fS3n = fS3 / torch.clamp(fN0, min=1e-6)
dP40 = math.sqrt(10) * fTmpD * fN0
dP41 = math.sqrt(18) * fTmpC * fN0 + fTmpD * cot_theta
dP42 = math.sqrt(14) * fTmpB * fN0 + 2 * fTmpC * cot_theta
dP43 = math.sqrt(8) * fTmpA * fN0 + 3 * fTmpB * cot_theta
dP44 = - math.sqrt(2) * fTmpB / fN0
dP41 = math.sqrt(18) * fTmpC * fN1 + fTmpD * z
dP42 = math.sqrt(14) * fTmpB * fN1 + 2 * fTmpC * z
dP43 = math.sqrt(8) * fTmpA * fN1 + 3 * fTmpB * z
dP44 = - math.sqrt(2) * fTmpB
d_varphi_and_theta[..., 20, 1] = dP40
d_varphi_and_theta[..., 21, 1] = dP41 * x
d_varphi_and_theta[..., 19, 1] = dP41 * y
d_varphi_and_theta[..., 22, 1] = dP42 * fC1
d_varphi_and_theta[..., 18, 1] = dP42 * fS1
d_varphi_and_theta[..., 23, 1] = dP43 * fC2
d_varphi_and_theta[..., 17, 1] = dP43 * fS2
d_varphi_and_theta[..., 24, 1] = dP44 * fC3
d_varphi_and_theta[..., 16, 1] = dP44 * fS3
d_varphi_and_theta[..., 21, 1] = dP41 * fC0n
d_varphi_and_theta[..., 19, 1] = dP41 * fS0n
d_varphi_and_theta[..., 22, 1] = dP42 * fC1n
d_varphi_and_theta[..., 18, 1] = dP42 * fS1n
d_varphi_and_theta[..., 23, 1] = dP43 * fC2n
d_varphi_and_theta[..., 17, 1] = dP43 * fS2n
d_varphi_and_theta[..., 24, 1] = dP44 * fC3n
d_varphi_and_theta[..., 16, 1] = dP44 * fS3n

return _all()

Expand Down

0 comments on commit 71eb305

Please sign in to comment.