Skip to content

Commit

Permalink
Code update
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 561654476
  • Loading branch information
The swirl_dynamics Authors committed Aug 31, 2023
1 parent 8aeccd9 commit ca86207
Show file tree
Hide file tree
Showing 16 changed files with 1,673 additions and 798 deletions.
80 changes: 80 additions & 0 deletions swirl_dynamics/lib/solvers/ode.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from jax.experimental import checkify
from jax.experimental import ode
import jax.numpy as jnp
import numpy as np

Array = jax.Array
PyTree = Any
Expand Down Expand Up @@ -136,3 +137,82 @@ def __call__(
mxstep=self.max_steps,
hmax=self.max_dt,
)


class MultiStepScanOdeSolver:
"""ODE solver based on `jax.lax.scan` that uses one than one time step.
Rather than x_{n+1} = f(x_n, t_n), we have
x_{n+1} = f(x_{n-k}, x_{n-k+1}, ..., x_{n-1}, x_n, t_{n-k}, ..., t_{n-1}, t_n)
for some 'num_lookback_steps' window k.
"""

@staticmethod
def stack_timesteps_along_channel_dim(x: Array) -> Array:
"""Helper method to package batches for multi-step solvers.
Args:
x: Array of shape: (num_lookback_steps, state_dims, channels), where
state_dims can have ndim >= 1
Returns:
Array of shape (state_dims, num_lookback_steps*channels)
"""
orig_shape = x.shape
num_lookback_steps = orig_shape[0]
stacked_shape = list(orig_shape)
# All the previous timesteps are collapsed into one
stacked_shape[0] = 1
# Concatenate steps along channel dim
stacked_shape[-1] *= num_lookback_steps
# For state_dims with ndim > 1, e.g. 2D grids, flatten state dims:
flattened_state_size = np.prod(list(x.shape[1:-1]))
x_flattened_state_shape = (
(x.shape[0],) + (flattened_state_size,) + (x.shape[-1],)
)
x = x.reshape(x_flattened_state_shape)
return x.swapaxes(0, 1).reshape(stacked_shape).squeeze(axis=0)

def step(
self, func: OdeDynamics, x0: Array, t0: Array, dt: Array, params: PyTree
) -> Array:
"""Advances the current state one step forward in time."""
raise NotImplementedError

def __call__(
self,
func: OdeDynamics,
x0: Array,
tspan: Array,
params: PyTree,
) -> Array:
"""Solves an ODE at given time stamps by using k previous steps."""

def scan_fun(
state: tuple[Array, Array], t_next: Array
) -> tuple[tuple[Array, Array], Array]:
# x0 assumed to have shape: (lookback, state_dims, channels)
x0, t0 = state
x0_stack = self.stack_timesteps_along_channel_dim(x0)[None, ...]
# input to func has shape: (state_dims, channels*lookback)
dt = t_next - t0
# return item (x_next) has shape: state_dims x channels
x_next = self.step(func, x0_stack, t0, dt, params)
# carry item has same shape as x0, where we first shift over the original
# input and append the new predicted state along the time dimension
x_carry = jnp.concatenate([x0[1:, ...], x_next], axis=0)
return (x_carry, t_next), x_next.squeeze(axis=0)

_, out = jax.lax.scan(scan_fun, (x0, tspan[0]), tspan[1:])
# output of scan has shape: len(tspan) - 1 x state_dims x channels
return jnp.concatenate([x0, out], axis=0)


class MultiStepDirect(MultiStepScanOdeSolver):
"""Solver that directly returns function output as next time step."""

def step(
self, func: OdeDynamics, x0: Array, t0: Array, dt: Array, params: PyTree
) -> Array:
"""Performs a single prediction step."""
return func(x0, t0, params)
36 changes: 36 additions & 0 deletions swirl_dynamics/lib/solvers/ode_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from absl.testing import absltest
from absl.testing import parameterized
import jax
import jax.numpy as jnp
import numpy as np
from swirl_dynamics.lib.solvers import ode
Expand Down Expand Up @@ -52,5 +53,40 @@ def test_dopri45_backward_error(self, tspan):
ode.DoPri45()(dummy_ode_dynamics, jnp.array(0), tspan, {})


class MultiStepOdeSolversTest(parameterized.TestCase):

@parameterized.product(
state_dim=((512,), (64, 64), (32, 32, 32)),
channels=(1, 2, 3),
num_lookback_steps=(2, 4, 8),
)
def test_stacked_output_shape_and_value(
self,
state_dim,
channels,
num_lookback_steps,
):
rng = jax.random.PRNGKey(0)
input_shape = (num_lookback_steps,) + state_dim + (channels,)
input_state = jax.random.normal(rng, input_shape)
input_state_stacked = (
ode.MultiStepScanOdeSolver.stack_timesteps_along_channel_dim(
input_state
)
)
# Check that expected shapes match
self.assertEqual(
input_state_stacked.shape, state_dim + (channels * num_lookback_steps,)
)
# Check that timesteps were correctly concatenated along channel dim
for w in range(num_lookback_steps):
c_start = channels * w
c_end = channels * (w + 1)
np.testing.assert_array_equal(
input_state[w, ...],
input_state_stacked[..., c_start:c_end],
)


if __name__ == "__main__":
absltest.main()
81 changes: 66 additions & 15 deletions swirl_dynamics/projects/ergodic/choices.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,71 +12,104 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""Training choices."""
"""Training choices.
enum.Enum classes that define the various experiment choices, such as system,
learned model integrator, neural architecture, etc.
"""

from collections.abc import Callable
import enum
from typing import Union
import functools

from flax import linen as nn
import jax
import ml_collections
from swirl_dynamics.lib.networks import convnets
from swirl_dynamics.lib.networks import fno
from swirl_dynamics.lib.networks import nonlinear_fourier
from swirl_dynamics.lib.solvers import ode
from swirl_dynamics.projects.ergodic import measure_distances


Array = jax.Array
MeasureDistFn = Callable[[Array, Array], Union[Array, float]]
MeasureDistFn = Callable[[Array, Array], float | Array]


class Experiment(enum.Enum):
"""Experiment choices."""
"""Experiment choices.
1. Lorenz 63 system (lorenz63)
2. Kuramoto-Sivashinsky system, on a 1D grid (ks_1d)
3. Navier-Stokes with Kolmogorov forcing, on a 2D grid (ns_2d)
"""

L63 = "lorenz63"
KS_1D = "ks_1d"
NS_2D = "ns_2d"


class Integrator(enum.Enum):
"""Integrator choices."""

EULER = "ExplicitEuler"
RK4 = "RungeKutta4"
ONE_STEP_DIRECT = "OneStepDirect"
MULTI_STEP_DIRECT = "MultiStepDirect"

def dispatch(self):
def dispatch(
self,
) -> type[ode.ScanOdeSolver] | type[ode.MultiStepScanOdeSolver]:
"""Dispatch integator.
Returns:
ScanOdeSolver
ScanOdeSolver | MultiStepScanOdeSolver
"""
return {
"ExplicitEuler": ode.ExplicitEuler,
"RungeKutta4": ode.RungeKutta4,
"OneStepDirect": ode.OneStepDirect,
"MultiStepDirect": ode.MultiStepDirect,
}[self.value]


class MeasureDistance(enum.Enum):
"""Measure distance choices."""

MMD = "MMD"
MMD_DIST = "MMD_DIST"
SD = "SD"

def dispatch(self) -> MeasureDistFn:
def dispatch(
self,
downsample_factor: int = 1,
) -> measure_distances.MeasureDistFn:
"""Dispatch measure distance.
Args:
downsample_factor: downsample factor for empirical distribution samples.
Returns:
Measure distance function.
"""
return {
dist_fn = {
"MMD": measure_distances.mmd,
"MMD_DIST": measure_distances.mmd_distributed,
"SD": measure_distances.sinkhorn_div
"SD": measure_distances.sinkhorn_div,
}[self.value]
if downsample_factor > 1:
return functools.partial(
measure_distances.spatial_downsampled_dist,
dist_fn,
spatial_downsample=downsample_factor,
)
return dist_fn


class Model(enum.Enum):
"""Model choices."""

FNO = "FNO"
MLP = "MLP"
PERIODIC_CONV_NET_MODEL = "PeriodicConvNetModel"

Expand All @@ -90,11 +123,29 @@ def dispatch(self, conf: ml_collections.ConfigDict) -> nn.Module:
nn.Module
"""
if self.value == Model.MLP.value:
return nonlinear_fourier.MLP(
features=conf.mlp_sizes,
act_fn=nn.swish
)
return nonlinear_fourier.MLP(features=conf.mlp_sizes, act_fn=nn.swish)
if self.value == Model.PERIODIC_CONV_NET_MODEL.value:
# TODO(yairschiff): Add convent args to config.
return convnets.PeriodicConvNetModel()
return convnets.PeriodicConvNetModel(
latent_dim=conf.latent_dim,
num_levels=conf.num_levels,
num_processors=conf.num_processors,
encoder_kernel_size=conf.encoder_kernel_size,
decoder_kernel_size=conf.decoder_kernel_size,
processor_kernel_size=conf.processor_kernel_size,
padding=conf.padding,
is_input_residual=conf.is_input_residual,
)
if self.value == Model.FNO.value:
return fno.Fno(
out_channels=conf.out_channels,
hidden_channels=conf.hidden_channels,
num_modes=conf.num_modes,
lifting_channels=conf.lifting_channels,
projection_channels=conf.projection_channels,
num_blocks=conf.num_blocks,
layers_per_block=conf.layers_per_block,
block_skip_type=conf.block_skip_type,
fft_norm=conf.fft_norm,
separable=conf.separable,
)
raise ValueError()
Loading

0 comments on commit ca86207

Please sign in to comment.