Skip to content

Commit

Permalink
Code update
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 561807285
  • Loading branch information
The swirl_dynamics Authors committed Sep 1, 2023
1 parent ca86207 commit e0402b1
Showing 1 changed file with 44 additions and 0 deletions.
44 changes: 44 additions & 0 deletions swirl_dynamics/projects/ergodic/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,3 +260,47 @@ def plot_error_metrics(
ax[i].set_title(metric_disply[metric])
ax[i].set_yscale("log")
return {"err": fig}


def sample_uniform_spherical_shell(
n_points: int,
radii: tuple[float, float],
shape: tuple[int, ...],
key: jax.random.KeyArray = jax.random.PRNGKey(0),
):
"""Uniform sampling (in angle and radius) from an spherical shell.
Arguments:
n_points: Number of points to sample.
radii: Interior and exterior radii of the spherical shell.
shape: Shape of the points to sample.
key: Random key for generating the random numbers.
Returns:
A vector of size (n_points,) + shape, within the spherical shell. The
vector is chosen uniformly in both angle and radius.
"""

inner_radius, outer_radius = radii

# Shape to help broadcasting.
broadcasting_shape = (n_points,) + len(shape) * (1,)
# Obtain the correct axis for the sum, depending on the shape.
# Here we suppose that shape comes in the form (nx, ny, d) or (nx, d).
assert len(shape) < 4 and len(shape) >= 2, ("The shape should represent ",
"one- or two-dimensional points.",
f" Instead we have shape {shape}")

axis_sum = (1,) if len(shape) == 2 else (1, 2,)

key_radius, key_vec = jax.random.split(key)

sampling_radius = jax.random.uniform(key_radius, (n_points,),
minval=inner_radius,
maxval=outer_radius)
vec = jax.random.normal(key_vec, shape=((n_points,) + shape))

vec_norm = jnp.linalg.norm(vec, axis=axis_sum).reshape(broadcasting_shape)
vec /= vec_norm

return vec * sampling_radius.reshape(broadcasting_shape)

0 comments on commit e0402b1

Please sign in to comment.