diff --git a/e3nn_jax/_src/basic.py b/e3nn_jax/_src/basic.py index e187c01..e57585d 100644 --- a/e3nn_jax/_src/basic.py +++ b/e3nn_jax/_src/basic.py @@ -648,7 +648,9 @@ def normal( raise ValueError("Normalization needs to be 'norm' or 'component'") -def where(mask: jax.Array, x: e3nn.IrrepsArray, y: e3nn.IrrepsArray) -> e3nn.IrrepsArray: +def where( + mask: jax.Array, x: e3nn.IrrepsArray, y: e3nn.IrrepsArray +) -> e3nn.IrrepsArray: """ Selects elements from `x` or `y`, depending on `mask`.