Skip to content

Commit

Permalink
Fix.
Browse files Browse the repository at this point in the history
  • Loading branch information
ameya98 committed Aug 14, 2024
1 parent 0c7fe1c commit 129706a
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion e3nn_jax/_src/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -648,7 +648,9 @@ def normal(
raise ValueError("Normalization needs to be 'norm' or 'component'")


def where(mask: jax.Array, x: e3nn.IrrepsArray, y: e3nn.IrrepsArray) -> e3nn.IrrepsArray:
def where(
mask: jax.Array, x: e3nn.IrrepsArray, y: e3nn.IrrepsArray
) -> e3nn.IrrepsArray:
"""
Selects elements from `x` or `y`, depending on `mask`.
Expand Down

0 comments on commit 129706a

Please sign in to comment.