Skip to content

Commit

Permalink
Formatting.
Browse files Browse the repository at this point in the history
  • Loading branch information
ameya98 committed Sep 28, 2024
1 parent 68e5d84 commit aa15f91
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 6 deletions.
6 changes: 3 additions & 3 deletions e3nn_jax/_src/linear_flax.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,9 @@ class Linear(flax.linen.Module):
gradient_normalization: Optional[Union[float, str]] = None
path_normalization: Optional[Union[float, str]] = None
biases: bool = False
parameter_initializer: Optional[Callable[[], jax.nn.initializers.Initializer]] = (
None
)
parameter_initializer: Optional[
Callable[[], jax.nn.initializers.Initializer]
] = None
instructions: Optional[List[Tuple[int, int]]] = None
num_indexed_weights: Optional[int] = None
weights_per_channel: bool = False
Expand Down
11 changes: 8 additions & 3 deletions tests/_src/s2grid_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,9 +357,11 @@ def test_integrate_spherical_harmonics(key: int, degree: int):
expected_integral = 0.0

assert jnp.isclose(integral, expected_integral, atol=1e-5, rtol=1e-5), (
integral, expected_integral
integral,
expected_integral,
)


@pytest.mark.parametrize("degree", range(10))
@pytest.mark.parametrize("key", range(3))
def test_integrate_polynomials(key: int, degree: int):
Expand Down Expand Up @@ -389,11 +391,14 @@ def f(coords):
else:
alphas = jnp.asarray([x_degree, y_degree, z_degree])
alphas = (alphas + 1) / 2
log_dirichlet = jnp.sum(jax.scipy.special.gammaln(alphas)) - jax.scipy.special.gammaln(jnp.sum(alphas))
log_dirichlet = jnp.sum(
jax.scipy.special.gammaln(alphas)
) - jax.scipy.special.gammaln(jnp.sum(alphas))
expected_integral = 2 * jnp.exp(log_dirichlet)

assert jnp.isclose(integral, expected_integral, atol=1e-5, rtol=1e-5), (
integral, expected_integral
integral,
expected_integral,
)


Expand Down

0 comments on commit aa15f91

Please sign in to comment.