diff --git a/e3nn_jax/_src/s2grid.py b/e3nn_jax/_src/s2grid.py index b8e492a..4d45f58 100644 --- a/e3nn_jax/_src/s2grid.py +++ b/e3nn_jax/_src/s2grid.py @@ -10,7 +10,7 @@ import e3nn_jax as e3nn from .activation import parity_function -from .spherical_harmonics import _sh_alpha, _sh_beta +from .spherical_harmonics.legendre import _sh_alpha, _sh_beta class SphericalSignal: diff --git a/e3nn_jax/_src/spherical_harmonics.py b/e3nn_jax/_src/spherical_harmonics/__init__.py similarity index 57% rename from e3nn_jax/_src/spherical_harmonics.py rename to e3nn_jax/_src/spherical_harmonics/__init__.py index 085dce2..7ba3cb6 100644 --- a/e3nn_jax/_src/spherical_harmonics.py +++ b/e3nn_jax/_src/spherical_harmonics/__init__.py @@ -1,13 +1,13 @@ -import math from functools import partial -from typing import Dict, List, Sequence, Tuple, Union +from typing import List, Sequence, Tuple, Union import jax import jax.numpy as jnp -import sympy import e3nn_jax as e3nn -from e3nn_jax._src.utils.sympy import sqrtQarray_to_sympy + +from .legendre import legendre_spherical_harmonics +from .recursive import recursive_spherical_harmonics def sh( @@ -189,12 +189,12 @@ def _spherical_harmonics( ls: Tuple[int, ...], x: jnp.ndarray, normalization: str, algorithm: Tuple[str] ) -> List[jnp.ndarray]: if "legendre" in algorithm: - out = _legendre_spherical_harmonics(max(ls), x, False, normalization) + out = legendre_spherical_harmonics(max(ls), x, False, normalization) return [out[..., l**2 : (l + 1) ** 2] for l in ls] if "recursive" in algorithm: context = dict() for l in ls: - _recursive_spherical_harmonics(l, context, x, normalization, algorithm) + recursive_spherical_harmonics(l, context, x, normalization, algorithm) return [context[l] for l in ls] raise ValueError("Unknown algorithm: must be 'legendre' or 'recursive'") @@ -250,215 +250,3 @@ def h(l: int, r: jnp.ndarray) -> jnp.ndarray: tangent = [h(l, r) if l > 0 else jnp.zeros_like(r) for l, r in zip(ls, res)] return primal, tangent - - -def _recursive_spherical_harmonics( - l: int, - context: Dict[int, jnp.ndarray], - input: jnp.ndarray, - normalization: str, - algorithm: Tuple[str], -) -> sympy.Array: - context.update(dict(jnp=jnp, clebsch_gordan=e3nn.clebsch_gordan)) - - if l == 0: - if 0 not in context: - if normalization == "integral": - context[0] = math.sqrt(1 / (4 * math.pi)) * jnp.ones_like( - input[..., :1] - ) - elif normalization == "component": - context[0] = jnp.ones_like(input[..., :1]) - else: - context[0] = jnp.ones_like(input[..., :1]) - - return sympy.Array([1]) - - if l == 1: - if 1 not in context: - if normalization == "integral": - context[1] = math.sqrt(3 / (4 * math.pi)) * input - elif normalization == "component": - context[1] = math.sqrt(3) * input - else: - context[1] = input - - return sympy.Array([1, 0, 0]) - - def sh_var(l): - return [sympy.symbols(f"sh{l}_{m}") for m in range(2 * l + 1)] - - l2 = biggest_power_of_two(l - 1) - l1 = l - l2 - - w = sqrtQarray_to_sympy(e3nn.clebsch_gordan(l1, l2, l)) - yx = sympy.Array( - [ - sum( - sh_var(l1)[i] * sh_var(l2)[j] * w[i, j, k] - for i in range(2 * l1 + 1) - for j in range(2 * l2 + 1) - ) - for k in range(2 * l + 1) - ] - ) - - sph_1_l1 = _recursive_spherical_harmonics( - l1, context, input, normalization, algorithm - ) - sph_1_l2 = _recursive_spherical_harmonics( - l2, context, input, normalization, algorithm - ) - - y1 = yx.subs(zip(sh_var(l1), sph_1_l1)).subs(zip(sh_var(l2), sph_1_l2)) - norm = sympy.sqrt(sum(y1.applyfunc(lambda x: x**2))) - y1 = y1 / norm - - if l not in context: - if normalization == "integral": - x = math.sqrt((2 * l + 1) / (4 * math.pi)) / ( - math.sqrt((2 * l1 + 1) / (4 * math.pi)) - * math.sqrt((2 * l2 + 1) / (4 * math.pi)) - ) - elif normalization == "component": - x = math.sqrt((2 * l + 1) / ((2 * l1 + 1) * (2 * l2 + 1))) - else: - x = 1 - - w = (x / float(norm)) * e3nn.clebsch_gordan(l1, l2, l) - w = w.astype(input.dtype) - - if "dense" in algorithm: - context[l] = jnp.einsum("...i,...j,ijk->...k", context[l1], context[l2], w) - elif "sparse" in algorithm: - context[l] = jnp.stack( - [ - sum( - [ - w[i, j, k] * context[l1][..., i] * context[l2][..., j] - for i in range(2 * l1 + 1) - for j in range(2 * l2 + 1) - if w[i, j, k] != 0 - ] - ) - for k in range(2 * l + 1) - ], - axis=-1, - ) - else: - raise ValueError("Unknown algorithm: must be 'dense' or 'sparse'") - - return y1 - - -def biggest_power_of_two(n): - return 2 ** (n.bit_length() - 1) - - -def legendre( - lmax: int, x: jnp.ndarray, phase: float, is_normalized: bool = False -) -> jnp.ndarray: - r"""Associated Legendre polynomials. - - en.wikipedia.org/wiki/Associated_Legendre_polynomials - - Args: - lmax (int): maximum l value - x (jnp.ndarray): input array of shape ``(...)`` - phase (float): -1 or 1, multiplies by :math:`(-1)^m` - is_normalized (bool): True if the associated Legendre functions are normalized. - - Returns: - jnp.ndarray: Associated Legendre polynomials ``P(l,m)`` - In an array of shape ``(lmax + 1, lmax + 1, ...)`` - """ - x = jnp.asarray(x) - return _legendre(lmax, x, phase, is_normalized) - - -@partial(jax.jit, static_argnums=(0, 3)) -def _legendre( - lmax: int, x: jnp.ndarray, phase: float, is_normalized: bool -) -> jnp.ndarray: - p = jax.scipy.special.lpmn_values( - lmax, lmax, x.flatten(), is_normalized - ) # [m, l, x] - p = (-phase) ** jnp.arange(lmax + 1)[:, None, None] * p - p = jnp.transpose(p, (1, 0, 2)) # [l, m, x] - p = jnp.reshape(p, (lmax + 1, lmax + 1) + x.shape) - return p - - -def _sh_alpha(l: int, alpha: jnp.ndarray) -> jnp.ndarray: - r"""Alpha dependence of spherical harmonics. - - Args: - l: l value - alpha: input array of shape ``(...)`` - - Returns: - Array of shape ``(..., 2 * l + 1)`` - """ - alpha = alpha[..., None] # [..., 1] - m = jnp.arange(1, l + 1) # [1, 2, 3, ..., l] - cos = jnp.cos(m * alpha) # [..., m] - - m = jnp.arange(l, 0, -1) # [l, l-1, l-2, ..., 1] - sin = jnp.sin(m * alpha) # [..., m] - - return jnp.concatenate( - [ - jnp.sqrt(2) * sin, - jnp.ones_like(alpha), - jnp.sqrt(2) * cos, - ], - axis=-1, - ) - - -def _sh_beta(lmax: int, cos_betas: jnp.ndarray) -> jnp.ndarray: - r"""Beta dependence of spherical harmonics. - - Args: - lmax: l value - cos_betas: input array of shape ``(...)`` - - Returns: - Array of shape ``(..., l, m)`` - """ - sh_y = legendre(lmax, cos_betas, phase=1.0, is_normalized=True) # [l, m, ...] - sh_y = jnp.moveaxis(sh_y, 0, -1) # [m, ..., l] - sh_y = jnp.moveaxis(sh_y, 0, -1) # [..., l, m] - return sh_y - - -def _legendre_spherical_harmonics( - lmax: int, x: jnp.ndarray, normalize: bool, normalization: str -) -> jnp.ndarray: - alpha = jnp.arctan2(x[..., 0], x[..., 2]) - sh_alpha = _sh_alpha(lmax, alpha) # [..., 2 * l + 1] - - n = jnp.linalg.norm(x, axis=-1, keepdims=True) - x = x / jnp.where(n > 0, n, 1.0) - - sh_y = _sh_beta(lmax, x[..., 1]) # [..., l, m] - - sh = jnp.zeros(x.shape[:-1] + ((lmax + 1) ** 2,), x.dtype) - - def f(l, sh): - def g(m, sh): - y = sh_y[..., l, jnp.abs(m)] - if not normalize: - y = y * n[..., 0] ** l - if normalization == "norm": - y = y * (jnp.sqrt(4 * jnp.pi) / jnp.sqrt(2 * l + 1)) - elif normalization == "component": - y = y * jnp.sqrt(4 * jnp.pi) - - a = sh_alpha[..., lmax + m] - return sh.at[..., l**2 + l + m].set(y * a) - - return jax.lax.fori_loop(-l, l + 1, g, sh) - - sh = jax.lax.fori_loop(0, lmax + 1, f, sh) - return sh diff --git a/e3nn_jax/_src/spherical_harmonics/legendre.py b/e3nn_jax/_src/spherical_harmonics/legendre.py new file mode 100644 index 0000000..3a4ca8f --- /dev/null +++ b/e3nn_jax/_src/spherical_harmonics/legendre.py @@ -0,0 +1,113 @@ +from functools import partial + +import jax +import jax.numpy as jnp + + +def legendre( + lmax: int, x: jnp.ndarray, phase: float, is_normalized: bool = False +) -> jnp.ndarray: + r"""Associated Legendre polynomials. + + en.wikipedia.org/wiki/Associated_Legendre_polynomials + + Args: + lmax (int): maximum l value + x (jnp.ndarray): input array of shape ``(...)`` + phase (float): -1 or 1, multiplies by :math:`(-1)^m` + is_normalized (bool): True if the associated Legendre functions are normalized. + + Returns: + jnp.ndarray: Associated Legendre polynomials ``P(l,m)`` + In an array of shape ``(lmax + 1, lmax + 1, ...)`` + """ + x = jnp.asarray(x) + return _legendre(lmax, x, phase, is_normalized) + + +@partial(jax.jit, static_argnums=(0, 3)) +def _legendre( + lmax: int, x: jnp.ndarray, phase: float, is_normalized: bool +) -> jnp.ndarray: + p = jax.scipy.special.lpmn_values( + lmax, lmax, x.flatten(), is_normalized + ) # [m, l, x] + p = (-phase) ** jnp.arange(lmax + 1)[:, None, None] * p + p = jnp.transpose(p, (1, 0, 2)) # [l, m, x] + p = jnp.reshape(p, (lmax + 1, lmax + 1) + x.shape) + return p + + +def _sh_alpha(l: int, alpha: jnp.ndarray) -> jnp.ndarray: + r"""Alpha dependence of spherical harmonics. + + Args: + l: l value + alpha: input array of shape ``(...)`` + + Returns: + Array of shape ``(..., 2 * l + 1)`` + """ + alpha = alpha[..., None] # [..., 1] + m = jnp.arange(1, l + 1) # [1, 2, 3, ..., l] + cos = jnp.cos(m * alpha) # [..., m] + + m = jnp.arange(l, 0, -1) # [l, l-1, l-2, ..., 1] + sin = jnp.sin(m * alpha) # [..., m] + + return jnp.concatenate( + [ + jnp.sqrt(2) * sin, + jnp.ones_like(alpha), + jnp.sqrt(2) * cos, + ], + axis=-1, + ) + + +def _sh_beta(lmax: int, cos_betas: jnp.ndarray) -> jnp.ndarray: + r"""Beta dependence of spherical harmonics. + + Args: + lmax: l value + cos_betas: input array of shape ``(...)`` + + Returns: + Array of shape ``(..., l, m)`` + """ + sh_y = legendre(lmax, cos_betas, phase=1.0, is_normalized=True) # [l, m, ...] + sh_y = jnp.moveaxis(sh_y, 0, -1) # [m, ..., l] + sh_y = jnp.moveaxis(sh_y, 0, -1) # [..., l, m] + return sh_y + + +def legendre_spherical_harmonics( + lmax: int, x: jnp.ndarray, normalize: bool, normalization: str +) -> jnp.ndarray: + alpha = jnp.arctan2(x[..., 0], x[..., 2]) + sh_alpha = _sh_alpha(lmax, alpha) # [..., 2 * l + 1] + + n = jnp.linalg.norm(x, axis=-1, keepdims=True) + x = x / jnp.where(n > 0, n, 1.0) + + sh_y = _sh_beta(lmax, x[..., 1]) # [..., l, m] + + sh = jnp.zeros(x.shape[:-1] + ((lmax + 1) ** 2,), x.dtype) + + def f(l, sh): + def g(m, sh): + y = sh_y[..., l, jnp.abs(m)] + if not normalize: + y = y * n[..., 0] ** l + if normalization == "norm": + y = y * (jnp.sqrt(4 * jnp.pi) / jnp.sqrt(2 * l + 1)) + elif normalization == "component": + y = y * jnp.sqrt(4 * jnp.pi) + + a = sh_alpha[..., lmax + m] + return sh.at[..., l**2 + l + m].set(y * a) + + return jax.lax.fori_loop(-l, l + 1, g, sh) + + sh = jax.lax.fori_loop(0, lmax + 1, f, sh) + return sh diff --git a/e3nn_jax/_src/spherical_harmonics/recursive.py b/e3nn_jax/_src/spherical_harmonics/recursive.py new file mode 100644 index 0000000..30c860d --- /dev/null +++ b/e3nn_jax/_src/spherical_harmonics/recursive.py @@ -0,0 +1,111 @@ +import math +from typing import Dict, Tuple + +import jax.numpy as jnp +import sympy + +import e3nn_jax as e3nn +from e3nn_jax._src.utils.sympy import sqrtQarray_to_sympy + + +def recursive_spherical_harmonics( + l: int, + context: Dict[int, jnp.ndarray], + input: jnp.ndarray, + normalization: str, + algorithm: Tuple[str], +) -> sympy.Array: + context.update(dict(jnp=jnp, clebsch_gordan=e3nn.clebsch_gordan)) + + if l == 0: + if 0 not in context: + if normalization == "integral": + context[0] = math.sqrt(1 / (4 * math.pi)) * jnp.ones_like( + input[..., :1] + ) + elif normalization == "component": + context[0] = jnp.ones_like(input[..., :1]) + else: + context[0] = jnp.ones_like(input[..., :1]) + + return sympy.Array([1]) + + if l == 1: + if 1 not in context: + if normalization == "integral": + context[1] = math.sqrt(3 / (4 * math.pi)) * input + elif normalization == "component": + context[1] = math.sqrt(3) * input + else: + context[1] = input + + return sympy.Array([1, 0, 0]) + + def sh_var(l): + return [sympy.symbols(f"sh{l}_{m}") for m in range(2 * l + 1)] + + l2 = biggest_power_of_two(l - 1) + l1 = l - l2 + + w = sqrtQarray_to_sympy(e3nn.clebsch_gordan(l1, l2, l)) + yx = sympy.Array( + [ + sum( + sh_var(l1)[i] * sh_var(l2)[j] * w[i, j, k] + for i in range(2 * l1 + 1) + for j in range(2 * l2 + 1) + ) + for k in range(2 * l + 1) + ] + ) + + sph_1_l1 = recursive_spherical_harmonics( + l1, context, input, normalization, algorithm + ) + sph_1_l2 = recursive_spherical_harmonics( + l2, context, input, normalization, algorithm + ) + + y1 = yx.subs(zip(sh_var(l1), sph_1_l1)).subs(zip(sh_var(l2), sph_1_l2)) + norm = sympy.sqrt(sum(y1.applyfunc(lambda x: x**2))) + y1 = y1 / norm + + if l not in context: + if normalization == "integral": + x = math.sqrt((2 * l + 1) / (4 * math.pi)) / ( + math.sqrt((2 * l1 + 1) / (4 * math.pi)) + * math.sqrt((2 * l2 + 1) / (4 * math.pi)) + ) + elif normalization == "component": + x = math.sqrt((2 * l + 1) / ((2 * l1 + 1) * (2 * l2 + 1))) + else: + x = 1 + + w = (x / float(norm)) * e3nn.clebsch_gordan(l1, l2, l) + w = w.astype(input.dtype) + + if "dense" in algorithm: + context[l] = jnp.einsum("...i,...j,ijk->...k", context[l1], context[l2], w) + elif "sparse" in algorithm: + context[l] = jnp.stack( + [ + sum( + [ + w[i, j, k] * context[l1][..., i] * context[l2][..., j] + for i in range(2 * l1 + 1) + for j in range(2 * l2 + 1) + if w[i, j, k] != 0 + ] + ) + for k in range(2 * l + 1) + ], + axis=-1, + ) + else: + raise ValueError("Unknown algorithm: must be 'dense' or 'sparse'") + + return y1 + + +def biggest_power_of_two(n): + return 2 ** (n.bit_length() - 1) diff --git a/e3nn_jax/_src/spherical_harmonics_test.py b/e3nn_jax/_src/spherical_harmonics/sh_test.py similarity index 100% rename from e3nn_jax/_src/spherical_harmonics_test.py rename to e3nn_jax/_src/spherical_harmonics/sh_test.py