Skip to content

Commit

Permalink
Code update
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 561202548
  • Loading branch information
The swirl_dynamics Authors committed Aug 30, 2023
1 parent a536997 commit 8aeccd9
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 3 deletions.
4 changes: 3 additions & 1 deletion swirl_dynamics/projects/ergodic/configs/ks_1d.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def get_config():
config.stride = 1
config.normalize = False
config.add_noise = False
config.use_sobolev = True
# Model params
config.integrator = 'OneStepDirect'
config.model = 'PeriodicConvNetModel'
Expand All @@ -47,7 +48,8 @@ def get_config():
config.time_steps_increase_per_cycle = 0
config.use_curriculum = False # Sweepable
config.use_pushfwd = False # Sweepable
config.measure_dist_type = 'MMD_DIST' # Sweepable
config.measure_dist_type = 'MMD' # Sweepable
# config.measure_dist_type = 'MMD_DIST' # Sweepable
config.regularize_measure_dist = False # Sweepable
config.regularize_measure_dist_k = True # Sweepable
config.measure_dist_lambda = 1.0 # Sweepable
Expand Down
8 changes: 7 additions & 1 deletion swirl_dynamics/projects/ergodic/ks_1d.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ class KS1dModelConfig:
measure_dist: choices.MeasureDistance
use_pushfwd: bool = False
add_noise: bool = False
use_sobolev: bool = False
noise_level: float = 1e-3


Expand Down Expand Up @@ -121,7 +122,11 @@ def loss_fn(
measure_dist_k = jnp.mean(
jax.vmap(self.measure_dist, in_axes=(1, 1))(pred, true)
)
l2 = jnp.mean(jnp.square(pred - true))

if self.conf.use_sobolev:
l2 = utils.sobolev_norm(pred - true, dim=1, s=1, length=32.)
else:
l2 = jnp.mean(jnp.square(pred - true))

# This is a scalar.
loss = l2
Expand Down Expand Up @@ -247,6 +252,7 @@ def pipeline(conf: ml_collections.ConfigDict) -> PipelinePayload:
measure_dist=choices.MeasureDistance(conf.measure_dist_type),
use_pushfwd=conf.use_pushfwd,
add_noise=conf.add_noise,
use_sobolev=conf.use_sobolev,
)
trainer_config = KS1dTrainerConfig(
time_stride=conf.time_stride,
Expand Down
2 changes: 1 addition & 1 deletion swirl_dynamics/projects/ergodic/measure_distances.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def mmd_distributed(x: Array, y: Array) -> Array:

# we merge the batch per device and device dimensions
x = jnp.reshape(x, (x_shape[0]*x_shape[1],) + x_shape[2:])
y = jnp.reshape(y, (y_shape[0]*y_shape[1],) + x_shape[2:])
y = jnp.reshape(y, (y_shape[0]*y_shape[1],) + y_shape[2:])

return mmd(x, y)

Expand Down
64 changes: 64 additions & 0 deletions swirl_dynamics/projects/ergodic/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,3 +118,67 @@ def load_data_from_hdf5(
batch_fn=tfgrain.TfBatch(batch_size=batch_size, drop_remainder=False),
)
return loader


# TODO(lzepedanunez): find a better place for this function and refactor with
# vmap.
def sobolev_norm(
u: Array, s: int = 1, dim: int = 2, length: float = 1.
):
r"""Sobolev Norm computed via the Plancherel equality.
Arguments:
u: Input to compute the norm (n_batch, n_frames, n_x, n_y, d).
s: Order of the Sobolev norm.
dim: Dimension of the input (either 1 or 2).
length: The length of the domain (we assume that we have a square.)
Returns:
The average H^s squared along trajectories and batch size.
We compute the Sobolov norm using the Fourier Transform following:
\| x \|_{H^s}^2 = \int (\sum_{i=0}^s \|k\|^2i)) | \hat{u}(k) |^2 dk.
In particular we assemble the multipliers, and we approximate the quadrature
using a trapezoidal rule.
"""
n_x = u.shape[-2]
k_x = jnp.fft.fftfreq(n_x, length / (2 * jnp.pi * n_x))

# Reusing the same expression for both one and two dimensional.
axes = (-2,) if dim == 1 else (-3, -2)
u_fft = jnp.fft.fftn(u, axes=axes)

# Computing the base multiplier: \| k \|^2.
if dim == 1:
multiplier = jnp.square(k_x)
multiplier = multiplier.reshape(
len(u.shape[:-2]) * (1,) + (n_x, u.shape[-1])
)
elif dim == 2:
k_x, k_y = jnp.meshgrid(k_x, k_x)
multiplier = jnp.square(k_x) + jnp.square(k_y)
multiplier = multiplier.reshape(
len(u.shape[:-3]) * (1,) + (n_x, n_x, u.shape[-1])
)
else:
raise ValueError(f"Unsupported dim: {dim}")

# Computing the different in Fourier space | \hat{u}(k) |^2.
u_fft_squared = jnp.square(jnp.abs(u_fft))

# Building the set of multipliers following:
# \left ( \sum_{i=0}^s \| k \|^{2i} \right).
mult = jnp.sum(
jnp.power(
multiplier[..., None], # add an extra dimension for broadcasting.
jnp.arange(s + 1).reshape(len(multiplier.shape) * (1,) + (-1,)),
),
axis=-1,
)

# Performing the integration using trapezoidal rule.
norm_squared = jnp.sum(mult * u_fft_squared, axis=axes) / (n_x)**dim

# Returns the mean.
return jnp.mean(norm_squared)

0 comments on commit 8aeccd9

Please sign in to comment.