Skip to content

Commit

Permalink
further simplify
Browse files Browse the repository at this point in the history
  • Loading branch information
Jianbo Ye committed May 30, 2024
1 parent f89c03b commit d045494
Showing 1 changed file with 32 additions and 53 deletions.
85 changes: 32 additions & 53 deletions gsplat/_torch_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,11 +190,8 @@ def _all():
fC0n = x / fN0
fS0n = y / fN0
cot_theta = z / fN0
P10 = result[..., 2]
P11 = fTmpA * fN0
dP10 = P11
dP11 = - P10
d_varphi_and_theta[..., 2, 1] = dP10
dP11 = - result[..., 2]
d_varphi_and_theta[..., 2, 1] = fTmpA * fN0
d_varphi_and_theta[..., 3, 1] = dP11 * fC0n
d_varphi_and_theta[..., 1, 1] = dP11 * fS0n

Expand All @@ -221,19 +218,14 @@ def _all():
d_varphi_and_theta[..., 4, 0] = 2 * result[..., 8]

# w.r.t. theta
fN1 = fN0 * fN0
fC1n = fC1 / fN1
fS1n = fS1 / fN1
P21 = fTmpB * fN0
P22 = fTmpA * fN1
dP20 = math.sqrt(3) * P21
dP21 = 2 * P22 + P21 * cot_theta
dP22 = - P21
dP20 = math.sqrt(3) * fTmpB * fN0
dP21 = 2 * fTmpA * fN0 + fTmpB * cot_theta
dP22 = - fTmpB / fN0
d_varphi_and_theta[..., 6, 1] = dP20
d_varphi_and_theta[..., 7, 1] = dP21 * fC0n
d_varphi_and_theta[..., 5, 1] = dP21 * fS0n
d_varphi_and_theta[..., 8, 1] = dP22 * fC1n
d_varphi_and_theta[..., 4, 1] = dP22 * fS1n
d_varphi_and_theta[..., 7, 1] = dP21 * x
d_varphi_and_theta[..., 5, 1] = dP21 * y
d_varphi_and_theta[..., 8, 1] = dP22 * fC1
d_varphi_and_theta[..., 4, 1] = dP22 * fS1

if basis_dim <= 9:
return _all()
Expand Down Expand Up @@ -262,23 +254,17 @@ def _all():
d_varphi_and_theta[..., 9, 0] = 3 * result[..., 15]

# w.r.t. theta
fN2 = fN1 * fN0
fC2n = fC2 / fN2
fS2n = fS2 / fN2
P31 = fTmpC * fN0
P32 = fTmpB * fN1
P33 = fTmpA * fN2
dP30 = math.sqrt(6) * P31
dP31 = math.sqrt(10) * P32 + P31 * cot_theta
dP32 = math.sqrt(6) * P33 + 2 * P32 * cot_theta
dP33 = - math.sqrt(3 / 2) * P32
dP30 = math.sqrt(6) * fTmpC * fN0
dP31 = math.sqrt(10) * fTmpB * fN0 + fTmpC * cot_theta
dP32 = math.sqrt(6) * fTmpA * fN0 + 2 * fTmpB * cot_theta
dP33 = - math.sqrt(3 / 2) * fTmpB / fN0
d_varphi_and_theta[..., 12, 1] = dP30
d_varphi_and_theta[..., 13, 1] = dP31 * fC0n
d_varphi_and_theta[..., 11, 1] = dP31 * fS0n
d_varphi_and_theta[..., 14, 1] = dP32 * fC1n
d_varphi_and_theta[..., 10, 1] = dP32 * fS1n
d_varphi_and_theta[..., 15, 1] = dP33 * fC2n
d_varphi_and_theta[..., 9, 1] = dP33 * fS2n
d_varphi_and_theta[..., 13, 1] = dP31 * x
d_varphi_and_theta[..., 11, 1] = dP31 * y
d_varphi_and_theta[..., 14, 1] = dP32 * fC1
d_varphi_and_theta[..., 10, 1] = dP32 * fS1
d_varphi_and_theta[..., 15, 1] = dP33 * fC2
d_varphi_and_theta[..., 9, 1] = dP33 * fS2


if basis_dim <= 16:
Expand Down Expand Up @@ -315,27 +301,20 @@ def _all():
d_varphi_and_theta[..., 16, 0] = 4 * result[..., 24]

# w.r.t. theta
fN3 = fN2 * fN0
fC3n = fC3 / fN3
fS3n = fS3 / fN3
P41 = fTmpD * fN0
P42 = fTmpC * fN1
P43 = fTmpB * fN2
P44 = fTmpA * fN3
dP40 = math.sqrt(10) * P41
dP41 = math.sqrt(18) * P42 + P41 * cot_theta
dP42 = math.sqrt(14) * P43 + 2 * P42 * cot_theta
dP43 = math.sqrt(8) * P44 + 3 * P43 * cot_theta
dP44 = - math.sqrt(2) * P43
dP40 = math.sqrt(10) * fTmpD * fN0
dP41 = math.sqrt(18) * fTmpC * fN0 + fTmpD * cot_theta
dP42 = math.sqrt(14) * fTmpB * fN0 + 2 * fTmpC * cot_theta
dP43 = math.sqrt(8) * fTmpA * fN0 + 3 * fTmpB * cot_theta
dP44 = - math.sqrt(2) * fTmpB / fN0
d_varphi_and_theta[..., 20, 1] = dP40
d_varphi_and_theta[..., 21, 1] = dP41 * fC0n
d_varphi_and_theta[..., 19, 1] = dP41 * fS0n
d_varphi_and_theta[..., 22, 1] = dP42 * fC1n
d_varphi_and_theta[..., 18, 1] = dP42 * fS1n
d_varphi_and_theta[..., 23, 1] = dP43 * fC2n
d_varphi_and_theta[..., 17, 1] = dP43 * fS2n
d_varphi_and_theta[..., 24, 1] = dP44 * fC3n
d_varphi_and_theta[..., 16, 1] = dP44 * fS3n
d_varphi_and_theta[..., 21, 1] = dP41 * x
d_varphi_and_theta[..., 19, 1] = dP41 * y
d_varphi_and_theta[..., 22, 1] = dP42 * fC1
d_varphi_and_theta[..., 18, 1] = dP42 * fS1
d_varphi_and_theta[..., 23, 1] = dP43 * fC2
d_varphi_and_theta[..., 17, 1] = dP43 * fS2
d_varphi_and_theta[..., 24, 1] = dP44 * fC3
d_varphi_and_theta[..., 16, 1] = dP44 * fS3

return _all()

Expand Down

0 comments on commit d045494

Please sign in to comment.