Skip to content

Commit

Permalink
add e3nn.scatter_mean
Browse files Browse the repository at this point in the history
  • Loading branch information
mariogeiger committed Oct 9, 2023
1 parent 2e2c33f commit 2df7fbb
Show file tree
Hide file tree
Showing 4 changed files with 122 additions and 11 deletions.
1 change: 1 addition & 0 deletions ChangeLog.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [Unreleased]
### Added
- `e3nn.flax.BatchNorm`
- `e3nn.scatter_mean`


## [0.20.2] - 2023-09-25
Expand Down
3 changes: 2 additions & 1 deletion e3nn_jax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@
)
from e3nn_jax._src.gate import gate
from e3nn_jax._src.radius_graph import radius_graph
from e3nn_jax._src.scatter import index_add, scatter_sum, scatter_max
from e3nn_jax._src.scatter import index_add, scatter_sum, scatter_mean, scatter_max
from e3nn_jax._src.reduced_tensor_product import (
reduced_tensor_product_basis,
reduced_symmetric_tensor_product_basis,
Expand Down Expand Up @@ -201,6 +201,7 @@
"radius_graph",
"index_add",
"scatter_sum",
"scatter_mean",
"scatter_max",
"poly_envelope",
"soft_envelope",
Expand Down
83 changes: 77 additions & 6 deletions e3nn_jax/_src/scatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,19 @@ def scatter_sum(
) -> Union[jnp.ndarray, e3nn.IrrepsArray]:
r"""Scatter sum of data.
Performs either of the following two operations:
``output[dst[i]] += data[i]`` or ``output[i] = sum(data[sum(nel[:i]):sum(nel[:i+1])])``
Performs either of the following two operations::
output[dst[i]] += data[i]
or::
output[i] = sum(data[sum(nel[:i]):sum(nel[:i+1])])
Args:
data (`jax.numpy.ndarray` or `IrrepsArray`): array of shape ``(n, ...)``
dst (optional, `jax.numpy.ndarray`): array of shape ``(n,)``. If not specified, ``nel`` must be specified.
data (`jax.numpy.ndarray` or `IrrepsArray`): array of shape ``(n1,..nd, ...)``
dst (optional, `jax.numpy.ndarray`): array of shape ``(n1,..nd)``. If not specified, ``nel`` must be specified.
nel (optional, `jax.numpy.ndarray`): array of shape ``(output_size,)``. If not specified, ``dst`` must be specified.
output_size (optional, int): size of output array. If not specified, ``nel`` must be specified
or ``map_back`` must be ``True``.
output_size (optional, int): size of output array.
If not specified, ``nel`` must be specified or ``map_back`` must be ``True``.
map_back (bool): whether to map back to the input position
Returns:
Expand All @@ -62,6 +66,73 @@ def scatter_sum(
)


def scatter_mean(
data: Union[jnp.ndarray, e3nn.IrrepsArray],
*,
dst: Optional[jnp.ndarray] = None,
nel: Optional[jnp.ndarray] = None,
output_size: Optional[int] = None,
map_back: bool = False,
mode: str = "promise_in_bounds",
) -> Union[jnp.ndarray, e3nn.IrrepsArray]:
r"""Scatter mean of data.
Performs either of the following two operations::
n[dst[i]] += 1
output[dst[i]] += data[i] / n[i]
or::
output[i] = sum(data[sum(nel[:i]):sum(nel[:i+1])]) / nel[i]
Args:
data (`jax.numpy.ndarray` or `IrrepsArray`): array of shape ``(n1,..nd, ...)``
dst (optional, `jax.numpy.ndarray`): array of shape ``(n1,..nd)``. If not specified, ``nel`` must be specified.
nel (optional, `jax.numpy.ndarray`): array of shape ``(output_size,)``. If not specified, ``dst`` must be specified.
output_size (optional, int): size of output array.
If not specified, ``nel`` must be specified or ``map_back`` must be ``True``.
map_back (bool): whether to map back to the input position
Returns:
`jax.numpy.ndarray` or `IrrepsArray`: output array of shape ``(output_size, ...)``
"""
total = _scatter_op(
"sum",
0.0,
data,
dst=dst,
nel=nel,
output_size=output_size,
map_back=map_back,
mode=mode,
)

if dst is not None or map_back:
if dst is not None:
ones = jnp.ones(data.shape[: dst.ndim], jnp.int32)
if nel is not None:
ones = jnp.ones(data.shape[:1], jnp.int32)

nel = _scatter_op(
"sum",
0.0,
ones,
dst=dst,
nel=nel,
output_size=output_size,
map_back=map_back,
mode=mode,
)

nel = jnp.maximum(1, nel)

for _ in range(total.ndim - nel.ndim):
nel = nel[..., None]

return total / nel.astype(total.dtype)


def scatter_max(
data: Union[jnp.ndarray, e3nn.IrrepsArray],
*,
Expand Down
46 changes: 42 additions & 4 deletions e3nn_jax/_src/scatter_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,52 @@ def test_scatter_sum():
jnp.array([-9.0, 5.0, 5.0, -9.0])[:, None],
)

x = jnp.array([1.0, 2.0, 1.0, 0.5, 0.5, 0.7, 0.2, 0.1])
nel = jnp.array([3, 2, 3])
np.testing.assert_allclose( # nel
e3nn.scatter_sum(
jnp.array([1.0, 2.0, 1.0, 0.5, 0.5, 0.7, 0.2, 0.1]),
nel=jnp.array([3, 2, 3]),
),
e3nn.scatter_sum(x, nel=nel),
jnp.array([4.0, 1.0, 1.0]),
)

np.testing.assert_allclose( # nel + map_back
e3nn.scatter_sum(x, nel=nel, map_back=True),
jnp.array([4.0, 4.0, 4.0, 1.0, 1.0, 1.0, 1.0, 1.0]),
)

i = jnp.array([[0, 2], [2, 0]])
x = jnp.array([[[1.0, 0.0], [2.0, 1.0]], [[3.0, 0.0], [-10.0, -1.0]]])
np.testing.assert_allclose(
e3nn.scatter_sum(x, dst=i, output_size=3),
jnp.array([[-9.0, -1.0], [0.0, 0.0], [5.0, 1.0]]),
)


def test_scatter_mean():
x = jnp.array([[2.0, 3.0], [0.0, 3.0], [-10.0, 42.0]])
dst = jnp.array([[0, 2], [2, 2], [0, 1]])

np.testing.assert_allclose( # dst
e3nn.scatter_mean(x, dst=dst, output_size=3),
jnp.array([-4.0, 42.0, 2.0]),
)

np.testing.assert_allclose( # map_back
e3nn.scatter_mean(x, dst=dst, map_back=True),
jnp.array([[-4.0, 2.0], [2.0, 2.0], [-4.0, 42.0]]),
)

x = jnp.array([10.0, 1.0, 2.0, 3.0])
nel = jnp.array([1, 0, 3])
np.testing.assert_allclose( # nel
e3nn.scatter_mean(x, nel=nel),
jnp.array([10.0, 0.0, 2.0]),
)

np.testing.assert_allclose( # nel + map_back
e3nn.scatter_mean(x, nel=nel, map_back=True),
jnp.array([10.0, 2.0, 2.0, 2.0]),
)


def test_scatter_max():
jax.config.update("jax_debug_infs", False)
Expand Down

0 comments on commit 2df7fbb

Please sign in to comment.