
Commit a536997: Code update
PiperOrigin-RevId: 561182058
The swirl_dynamics Authors committed Aug 30, 2023
1 parent e3ed675 commit a536997
Showing 2 changed files with 55 additions and 15 deletions.
17 changes: 13 additions & 4 deletions swirl_dynamics/projects/ergodic/configs/ks_1d.py
@@ -12,7 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""Default Hyperparameter configuration."""
r"""Default Hyperparameter configuration.
Usage:
gxm third_party/py/swirl_dynamics/projects/ergodic/google/xm_launch.py \
--exp=test_sequential_sobolev_norm \
--config=third_party/py/swirl_dynamics/projects/ergodic/configs/ks_1d.py \
--xm_resource_alloc=group:research-training/sim-research-xm \
--platform=v100=1 --cell=oi --priority=200
"""

import ml_collections

@@ -39,10 +47,10 @@ def get_config():
config.time_steps_increase_per_cycle = 0
config.use_curriculum = False # Sweepable
config.use_pushfwd = False # Sweepable
-config.measure_dist_type = 'MMD' # Sweepable
+config.measure_dist_type = 'MMD_DIST' # Sweepable
config.regularize_measure_dist = False # Sweepable
-config.regularize_measure_dist_k = False # Sweepable
-config.measure_dist_lambda = 0.0 # Sweepable
+config.regularize_measure_dist_k = True # Sweepable
+config.measure_dist_lambda = 1.0 # Sweepable
config.measure_dist_step_start = 0
# Train params
config.experiment = 'ks_1d'
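
For orientation, the three regularization flags above combine multiplicatively inside preprocess_train_batch (shown later in this diff). A minimal standalone sketch with the new default values; the variable names mirror the config fields and the print is purely illustrative:

measure_dist_lambda = 1.0         # config.measure_dist_lambda
regularize_measure_dist = False   # config.regularize_measure_dist
regularize_measure_dist_k = True  # config.regularize_measure_dist_k
measure_dist_step_start = 0       # config.measure_dist_step_start
step = 100                        # current training step

# Each boolean flag gates its term; nothing contributes before
# measure_dist_step_start.
lam = (measure_dist_lambda * regularize_measure_dist
       * (step > measure_dist_step_start))
lam_k = (measure_dist_lambda * regularize_measure_dist_k
         * (step > measure_dist_step_start))
print(lam, lam_k)  # 0.0 1.0 -> only the k-step regularizer is active
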
@@ -79,6 +87,7 @@ def skip(


# TODO(yairschiff): Refactor sweeps and experiment definition to use gin.
+# Use option --sweep=False on the command line to avoid sweeping.
def sweep(add):
"""Define param sweep."""
for seed in [1, 11, 21, 42, 84]:
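The sweep body above is truncated in this view; below is a hedged sketch of how an XM-style launcher typically consumes the sweep(add) hook, passing a callback that records one work unit per hyperparameter setting. The collection loop is illustrative and not part of this commit:

def sweep(add):
  """Define param sweep."""
  for seed in [1, 11, 21, 42, 84]:
    add(seed=seed)

work_units = []
sweep(lambda **hparams: work_units.append(hparams))
print(len(work_units))  # -> 5, one work unit per seed
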
53 changes: 42 additions & 11 deletions swirl_dynamics/projects/ergodic/ks_1d.py
@@ -77,30 +77,42 @@ def loss_fn(
mutables: PyTree,
):
"""Computes training loss and metrics."""
+# Shape: (batch_size, num_rollouts, num_states, 1)
true = batch["true"][:, 1:, :]
+# Shape: (batch_size, 1, num_states, 1)
x0 = batch["x0"]
+# Remove the extra dimension induced by tiling and the pmap reshape.
+# Shape: (1, num_rollouts)
+tspan = batch["tspan"][0]
+# The shape is (1, 1), but we require a scalar, so we extract it from the
+# array; this is due to the tiling and pmap reshape.
+measure_dist_lambda = batch["measure_dist_lambda"][0][0]
+measure_dist_lambda_k = batch["measure_dist_lambda_k"][0][0]

if self.conf.add_noise:
noise = self.conf.noise_level + jax.random.normal(rng, x0.shape)
x0 += noise
if self.conf.use_pushfwd:
# Rollout for t-1 steps with stop gradient
pred_pushfwd = jax.lax.stop_gradient(
jax.vmap(self.pred_integrator, in_axes=(0, None, None))(
x0, batch["tspan"][:-1], dict(params=params, **mutables)
x0, tspan[:-1], dict(params=params, **mutables)
)
)[:, -1, :]
# Pushforward for final step
pred = jax.vmap(self.pred_integrator, in_axes=(0, None, None))(
-pred_pushfwd, batch["tspan"][-2:], dict(params=params, **mutables)
+pred_pushfwd, tspan[-2:], dict(params=params, **mutables)
)[:, -1, :]
# Compare to true trajectory last step
true = true[:, -1, :]

# Compute distance between measures.
measure_dist = self.measure_dist(pred, batch["x0"])
measure_dist_k = self.measure_dist(pred, true)

else: # Regular unrolling without stop-gradient
pred = jax.vmap(self.pred_integrator, in_axes=(0, None, None))(
x0, batch["tspan"], dict(params=params, **mutables)
x0, tspan, dict(params=params, **mutables)
)[:, 1:, :]
measure_dist = jnp.mean(
jax.vmap(
@@ -110,15 +122,19 @@ def loss_fn(
jax.vmap(self.measure_dist, in_axes=(1, 1))(pred, true)
)
l2 = jnp.mean(jnp.square(pred - true))

# This is a scalar.
loss = l2
-loss += batch["measure_dist_lambda"] * measure_dist
-loss += batch["measure_dist_lambda_k"] * measure_dist_k
+loss += measure_dist_lambda * measure_dist
+loss += measure_dist_lambda_k * measure_dist_k

metric = dict(
loss=loss,
l2=l2,
measure_dist=measure_dist,
measure_dist_k=measure_dist_k,
-rollout=batch["tspan"].shape[0]-1
+rollout=tspan.shape[0]-1
)
return loss, (metric, mutables)
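
The pushforward branch above blocks gradients through the first t-1 rollout steps and differentiates only the final one, which keeps long rollouts trainable without backpropagating through the whole trajectory. A self-contained sketch under toy dynamics; integrator and pushfwd_loss are illustrative names, not the repo's API, and the real model is the learned KS integrator:

import jax
import jax.numpy as jnp

def integrator(x0, tspan, params):
  # Toy forward-Euler stand-in: tspan holds time points, so we take
  # len(tspan) - 1 steps and return the state at every time point.
  def step(x, dt):
    x_next = x + dt * params["a"] * x
    return x_next, x_next
  _, traj = jax.lax.scan(step, x0, jnp.diff(tspan))
  return jnp.concatenate([x0[None], traj], axis=0)

def pushfwd_loss(params, x0, tspan, true_final):
  # Roll out to t-1 with gradients blocked, then one differentiable step.
  warm = jax.lax.stop_gradient(integrator(x0, tspan[:-1], params))[-1]
  pred = integrator(warm, tspan[-2:], params)[-1]
  return jnp.mean(jnp.square(pred - true_final))

params = {"a": jnp.array(0.1)}
x0 = jnp.ones((8,))
tspan = jnp.linspace(0.0, 1.0, 11)
grads = jax.grad(pushfwd_loss)(params, x0, tspan, jnp.ones((8,)))
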

@@ -139,7 +155,8 @@ class KS1dTrainerConfig:
time_steps_increase_per_cycle: int = 1


-class KS1dTrainer(trainers.BasicTrainer):
+# class KS1dTrainer(trainers.BasicTrainer):
+class KS1dTrainer(trainers.BasicDistributedTrainer):
"""Kuramoto Sivashinsky 1D trainer."""

@flax.struct.dataclass
@@ -183,13 +200,27 @@ def preprocess_train_batch(self, batch_data, step, rng):
measure_dist *= (step > self.conf.measure_dist_step_start)
measure_dist_k = self.conf.regularize_measure_dist_k
measure_dist_k *= step > self.conf.measure_dist_step_start
-return dict(
+batch_dict = dict(
x0=batch_data["u"][:, 0, :, :],
true=batch_data["u"][:, :num_time_steps_gt:time_stride, :, :],
-tspan=tspan,
-measure_dist_lambda=self.conf.measure_dist_lambda * measure_dist,
-measure_dist_lambda_k=self.conf.measure_dist_lambda * measure_dist_k,
+tspan=np.tile(tspan, (jax.device_count(), 1)),
+measure_dist_lambda=np.tile(
+    self.conf.measure_dist_lambda * measure_dist,
+    (jax.device_count(), 1)),
+measure_dist_lambda_k=np.tile(
+    self.conf.measure_dist_lambda * measure_dist_k,
+    (jax.device_count(), 1)),
)
+return jax.jit(trainers.reshape_for_pmap)(batch_dict)


# TODO(yairschiff): add this extra
# class DistributedKS1dTrainer(
# KS1dTrainer,
# trainers.BasicDistributedTrainer[KS1dModel, TrainState],
# ):
# # MRO: KS1dTrainer > BasicDistributedTrainer > BasicTrainer
# ...
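
A quick toy check of the MRO claimed in the comment above (stand-in classes, not the swirl_dynamics ones):

class BasicTrainer: ...
class BasicDistributedTrainer(BasicTrainer): ...
class KS1dTrainer(BasicTrainer): ...
class DistributedKS1dTrainer(KS1dTrainer, BasicDistributedTrainer): ...

print([c.__name__ for c in DistributedKS1dTrainer.__mro__])
# ['DistributedKS1dTrainer', 'KS1dTrainer', 'BasicDistributedTrainer',
#  'BasicTrainer', 'object']
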


def pipeline(conf: ml_collections.ConfigDict) -> PipelinePayload:
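preprocess_train_batch now tiles the scalar and per-step fields across jax.device_count() and pushes the whole dict through trainers.reshape_for_pmap. A sketch of the reshape convention this is assumed to follow (the real helper lives in swirl_dynamics; here it simply splits each leading axis into (num_devices, per_device)):

import jax
import numpy as np

def reshape_for_pmap(batch):
  # Assumed behavior: split each array's leading axis across devices so
  # jax.pmap hands one (per_device, ...) slice to each device.
  n = jax.device_count()
  return jax.tree_util.tree_map(
      lambda x: x.reshape((n, x.shape[0] // n) + x.shape[1:]), batch)

batch = {"tspan": np.tile(np.arange(5), (jax.device_count(), 1))}
sharded = reshape_for_pmap(batch)
# sharded["tspan"].shape == (jax.device_count(), 1, 5): each device sees a
# (1, num_rollouts) copy, which batch["tspan"][0] in loss_fn reduces to
# (num_rollouts,).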
