diff --git a/swirl_dynamics/projects/ergodic/utils.py b/swirl_dynamics/projects/ergodic/utils.py index 0efa953..eff80f3 100644 --- a/swirl_dynamics/projects/ergodic/utils.py +++ b/swirl_dynamics/projects/ergodic/utils.py @@ -260,3 +260,47 @@ def plot_error_metrics( ax[i].set_title(metric_disply[metric]) ax[i].set_yscale("log") return {"err": fig} + + +def sample_uniform_spherical_shell( + n_points: int, + radii: tuple[float, float], + shape: tuple[int, ...], + key: jax.random.KeyArray = jax.random.PRNGKey(0), +): + """Uniform sampling (in angle and radius) from an spherical shell. + + Arguments: + n_points: Number of points to sample. + radii: Interior and exterior radii of the spherical shell. + shape: Shape of the points to sample. + key: Random key for generating the random numbers. + + Returns: + A vector of size (n_points,) + shape, within the spherical shell. The + vector is chosen uniformly in both angle and radius. + """ + + inner_radius, outer_radius = radii + + # Shape to help broadcasting. + broadcasting_shape = (n_points,) + len(shape) * (1,) + # Obtain the correct axis for the sum, depending on the shape. + # Here we suppose that shape comes in the form (nx, ny, d) or (nx, d). + assert len(shape) < 4 and len(shape) >= 2, ("The shape should represent ", + "one- or two-dimensional points.", + f" Instead we have shape {shape}") + + axis_sum = (1,) if len(shape) == 2 else (1, 2,) + + key_radius, key_vec = jax.random.split(key) + + sampling_radius = jax.random.uniform(key_radius, (n_points,), + minval=inner_radius, + maxval=outer_radius) + vec = jax.random.normal(key_vec, shape=((n_points,) + shape)) + + vec_norm = jnp.linalg.norm(vec, axis=axis_sum).reshape(broadcasting_shape) + vec /= vec_norm + + return vec * sampling_radius.reshape(broadcasting_shape)