Skip to content

Commit

Permalink
move sh code in a folder
Browse files Browse the repository at this point in the history
  • Loading branch information
mariogeiger committed Nov 5, 2023
1 parent 349254c commit 22ed24d
Show file tree
Hide file tree
Showing 5 changed files with 231 additions and 219 deletions.
2 changes: 1 addition & 1 deletion e3nn_jax/_src/s2grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import e3nn_jax as e3nn

from .activation import parity_function
from .spherical_harmonics import _sh_alpha, _sh_beta
from .spherical_harmonics.legendre import _sh_alpha, _sh_beta


class SphericalSignal:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import math
from functools import partial
from typing import Dict, List, Sequence, Tuple, Union
from typing import List, Sequence, Tuple, Union

import jax
import jax.numpy as jnp
import sympy

import e3nn_jax as e3nn
from e3nn_jax._src.utils.sympy import sqrtQarray_to_sympy

from .legendre import legendre_spherical_harmonics
from .recursive import recursive_spherical_harmonics


def sh(
Expand Down Expand Up @@ -189,12 +189,12 @@ def _spherical_harmonics(
ls: Tuple[int, ...], x: jnp.ndarray, normalization: str, algorithm: Tuple[str]
) -> List[jnp.ndarray]:
if "legendre" in algorithm:
out = _legendre_spherical_harmonics(max(ls), x, False, normalization)
out = legendre_spherical_harmonics(max(ls), x, False, normalization)
return [out[..., l**2 : (l + 1) ** 2] for l in ls]
if "recursive" in algorithm:
context = dict()
for l in ls:
_recursive_spherical_harmonics(l, context, x, normalization, algorithm)
recursive_spherical_harmonics(l, context, x, normalization, algorithm)
return [context[l] for l in ls]
raise ValueError("Unknown algorithm: must be 'legendre' or 'recursive'")

Expand Down Expand Up @@ -250,215 +250,3 @@ def h(l: int, r: jnp.ndarray) -> jnp.ndarray:

tangent = [h(l, r) if l > 0 else jnp.zeros_like(r) for l, r in zip(ls, res)]
return primal, tangent


def _recursive_spherical_harmonics(
l: int,
context: Dict[int, jnp.ndarray],
input: jnp.ndarray,
normalization: str,
algorithm: Tuple[str],
) -> sympy.Array:
context.update(dict(jnp=jnp, clebsch_gordan=e3nn.clebsch_gordan))

if l == 0:
if 0 not in context:
if normalization == "integral":
context[0] = math.sqrt(1 / (4 * math.pi)) * jnp.ones_like(
input[..., :1]
)
elif normalization == "component":
context[0] = jnp.ones_like(input[..., :1])
else:
context[0] = jnp.ones_like(input[..., :1])

return sympy.Array([1])

if l == 1:
if 1 not in context:
if normalization == "integral":
context[1] = math.sqrt(3 / (4 * math.pi)) * input
elif normalization == "component":
context[1] = math.sqrt(3) * input
else:
context[1] = input

return sympy.Array([1, 0, 0])

def sh_var(l):
return [sympy.symbols(f"sh{l}_{m}") for m in range(2 * l + 1)]

l2 = biggest_power_of_two(l - 1)
l1 = l - l2

w = sqrtQarray_to_sympy(e3nn.clebsch_gordan(l1, l2, l))
yx = sympy.Array(
[
sum(
sh_var(l1)[i] * sh_var(l2)[j] * w[i, j, k]
for i in range(2 * l1 + 1)
for j in range(2 * l2 + 1)
)
for k in range(2 * l + 1)
]
)

sph_1_l1 = _recursive_spherical_harmonics(
l1, context, input, normalization, algorithm
)
sph_1_l2 = _recursive_spherical_harmonics(
l2, context, input, normalization, algorithm
)

y1 = yx.subs(zip(sh_var(l1), sph_1_l1)).subs(zip(sh_var(l2), sph_1_l2))
norm = sympy.sqrt(sum(y1.applyfunc(lambda x: x**2)))
y1 = y1 / norm

if l not in context:
if normalization == "integral":
x = math.sqrt((2 * l + 1) / (4 * math.pi)) / (
math.sqrt((2 * l1 + 1) / (4 * math.pi))
* math.sqrt((2 * l2 + 1) / (4 * math.pi))
)
elif normalization == "component":
x = math.sqrt((2 * l + 1) / ((2 * l1 + 1) * (2 * l2 + 1)))
else:
x = 1

w = (x / float(norm)) * e3nn.clebsch_gordan(l1, l2, l)
w = w.astype(input.dtype)

if "dense" in algorithm:
context[l] = jnp.einsum("...i,...j,ijk->...k", context[l1], context[l2], w)
elif "sparse" in algorithm:
context[l] = jnp.stack(
[
sum(
[
w[i, j, k] * context[l1][..., i] * context[l2][..., j]
for i in range(2 * l1 + 1)
for j in range(2 * l2 + 1)
if w[i, j, k] != 0
]
)
for k in range(2 * l + 1)
],
axis=-1,
)
else:
raise ValueError("Unknown algorithm: must be 'dense' or 'sparse'")

return y1


def biggest_power_of_two(n):
return 2 ** (n.bit_length() - 1)


def legendre(
lmax: int, x: jnp.ndarray, phase: float, is_normalized: bool = False
) -> jnp.ndarray:
r"""Associated Legendre polynomials.
en.wikipedia.org/wiki/Associated_Legendre_polynomials
Args:
lmax (int): maximum l value
x (jnp.ndarray): input array of shape ``(...)``
phase (float): -1 or 1, multiplies by :math:`(-1)^m`
is_normalized (bool): True if the associated Legendre functions are normalized.
Returns:
jnp.ndarray: Associated Legendre polynomials ``P(l,m)``
In an array of shape ``(lmax + 1, lmax + 1, ...)``
"""
x = jnp.asarray(x)
return _legendre(lmax, x, phase, is_normalized)


@partial(jax.jit, static_argnums=(0, 3))
def _legendre(
lmax: int, x: jnp.ndarray, phase: float, is_normalized: bool
) -> jnp.ndarray:
p = jax.scipy.special.lpmn_values(
lmax, lmax, x.flatten(), is_normalized
) # [m, l, x]
p = (-phase) ** jnp.arange(lmax + 1)[:, None, None] * p
p = jnp.transpose(p, (1, 0, 2)) # [l, m, x]
p = jnp.reshape(p, (lmax + 1, lmax + 1) + x.shape)
return p


def _sh_alpha(l: int, alpha: jnp.ndarray) -> jnp.ndarray:
r"""Alpha dependence of spherical harmonics.
Args:
l: l value
alpha: input array of shape ``(...)``
Returns:
Array of shape ``(..., 2 * l + 1)``
"""
alpha = alpha[..., None] # [..., 1]
m = jnp.arange(1, l + 1) # [1, 2, 3, ..., l]
cos = jnp.cos(m * alpha) # [..., m]

m = jnp.arange(l, 0, -1) # [l, l-1, l-2, ..., 1]
sin = jnp.sin(m * alpha) # [..., m]

return jnp.concatenate(
[
jnp.sqrt(2) * sin,
jnp.ones_like(alpha),
jnp.sqrt(2) * cos,
],
axis=-1,
)


def _sh_beta(lmax: int, cos_betas: jnp.ndarray) -> jnp.ndarray:
r"""Beta dependence of spherical harmonics.
Args:
lmax: l value
cos_betas: input array of shape ``(...)``
Returns:
Array of shape ``(..., l, m)``
"""
sh_y = legendre(lmax, cos_betas, phase=1.0, is_normalized=True) # [l, m, ...]
sh_y = jnp.moveaxis(sh_y, 0, -1) # [m, ..., l]
sh_y = jnp.moveaxis(sh_y, 0, -1) # [..., l, m]
return sh_y


def _legendre_spherical_harmonics(
lmax: int, x: jnp.ndarray, normalize: bool, normalization: str
) -> jnp.ndarray:
alpha = jnp.arctan2(x[..., 0], x[..., 2])
sh_alpha = _sh_alpha(lmax, alpha) # [..., 2 * l + 1]

n = jnp.linalg.norm(x, axis=-1, keepdims=True)
x = x / jnp.where(n > 0, n, 1.0)

sh_y = _sh_beta(lmax, x[..., 1]) # [..., l, m]

sh = jnp.zeros(x.shape[:-1] + ((lmax + 1) ** 2,), x.dtype)

def f(l, sh):
def g(m, sh):
y = sh_y[..., l, jnp.abs(m)]
if not normalize:
y = y * n[..., 0] ** l
if normalization == "norm":
y = y * (jnp.sqrt(4 * jnp.pi) / jnp.sqrt(2 * l + 1))
elif normalization == "component":
y = y * jnp.sqrt(4 * jnp.pi)

a = sh_alpha[..., lmax + m]
return sh.at[..., l**2 + l + m].set(y * a)

return jax.lax.fori_loop(-l, l + 1, g, sh)

sh = jax.lax.fori_loop(0, lmax + 1, f, sh)
return sh
113 changes: 113 additions & 0 deletions e3nn_jax/_src/spherical_harmonics/legendre.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
from functools import partial

import jax
import jax.numpy as jnp


def legendre(
lmax: int, x: jnp.ndarray, phase: float, is_normalized: bool = False
) -> jnp.ndarray:
r"""Associated Legendre polynomials.
en.wikipedia.org/wiki/Associated_Legendre_polynomials
Args:
lmax (int): maximum l value
x (jnp.ndarray): input array of shape ``(...)``
phase (float): -1 or 1, multiplies by :math:`(-1)^m`
is_normalized (bool): True if the associated Legendre functions are normalized.
Returns:
jnp.ndarray: Associated Legendre polynomials ``P(l,m)``
In an array of shape ``(lmax + 1, lmax + 1, ...)``
"""
x = jnp.asarray(x)
return _legendre(lmax, x, phase, is_normalized)


@partial(jax.jit, static_argnums=(0, 3))
def _legendre(
lmax: int, x: jnp.ndarray, phase: float, is_normalized: bool
) -> jnp.ndarray:
p = jax.scipy.special.lpmn_values(
lmax, lmax, x.flatten(), is_normalized
) # [m, l, x]
p = (-phase) ** jnp.arange(lmax + 1)[:, None, None] * p
p = jnp.transpose(p, (1, 0, 2)) # [l, m, x]
p = jnp.reshape(p, (lmax + 1, lmax + 1) + x.shape)
return p


def _sh_alpha(l: int, alpha: jnp.ndarray) -> jnp.ndarray:
r"""Alpha dependence of spherical harmonics.
Args:
l: l value
alpha: input array of shape ``(...)``
Returns:
Array of shape ``(..., 2 * l + 1)``
"""
alpha = alpha[..., None] # [..., 1]
m = jnp.arange(1, l + 1) # [1, 2, 3, ..., l]
cos = jnp.cos(m * alpha) # [..., m]

m = jnp.arange(l, 0, -1) # [l, l-1, l-2, ..., 1]
sin = jnp.sin(m * alpha) # [..., m]

return jnp.concatenate(
[
jnp.sqrt(2) * sin,
jnp.ones_like(alpha),
jnp.sqrt(2) * cos,
],
axis=-1,
)


def _sh_beta(lmax: int, cos_betas: jnp.ndarray) -> jnp.ndarray:
r"""Beta dependence of spherical harmonics.
Args:
lmax: l value
cos_betas: input array of shape ``(...)``
Returns:
Array of shape ``(..., l, m)``
"""
sh_y = legendre(lmax, cos_betas, phase=1.0, is_normalized=True) # [l, m, ...]
sh_y = jnp.moveaxis(sh_y, 0, -1) # [m, ..., l]
sh_y = jnp.moveaxis(sh_y, 0, -1) # [..., l, m]
return sh_y


def legendre_spherical_harmonics(
lmax: int, x: jnp.ndarray, normalize: bool, normalization: str
) -> jnp.ndarray:
alpha = jnp.arctan2(x[..., 0], x[..., 2])
sh_alpha = _sh_alpha(lmax, alpha) # [..., 2 * l + 1]

n = jnp.linalg.norm(x, axis=-1, keepdims=True)
x = x / jnp.where(n > 0, n, 1.0)

sh_y = _sh_beta(lmax, x[..., 1]) # [..., l, m]

sh = jnp.zeros(x.shape[:-1] + ((lmax + 1) ** 2,), x.dtype)

def f(l, sh):
def g(m, sh):
y = sh_y[..., l, jnp.abs(m)]
if not normalize:
y = y * n[..., 0] ** l
if normalization == "norm":
y = y * (jnp.sqrt(4 * jnp.pi) / jnp.sqrt(2 * l + 1))
elif normalization == "component":
y = y * jnp.sqrt(4 * jnp.pi)

a = sh_alpha[..., lmax + m]
return sh.at[..., l**2 + l + m].set(y * a)

return jax.lax.fori_loop(-l, l + 1, g, sh)

sh = jax.lax.fori_loop(0, lmax + 1, f, sh)
return sh
Loading

0 comments on commit 22ed24d

Please sign in to comment.