Skip to content

Commit

Permalink
Code update
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 567737414
  • Loading branch information
zhong1wan authored and The swirl_dynamics Authors committed Sep 22, 2023
1 parent 9bbddf9 commit 2a051d2
Show file tree
Hide file tree
Showing 5 changed files with 23 additions and 60 deletions.
2 changes: 1 addition & 1 deletion swirl_dynamics/lib/diffusion/diffusion_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class DiffusionTest(parameterized.TestCase):
def test_tangent_noise_schedule(self, clip_max, start, end):
sigma = diffusion.tangent_noise_schedule(clip_max, start, end)
self.assertAlmostEqual(sigma(1.0), clip_max, places=3)
self.assertEqual(sigma(0.0), 0)
self.assertAlmostEqual(sigma(0.0), 0, places=8)

test_points = np.random.default_rng(1234).uniform(0.05, 1.0, size=(10,))
np.testing.assert_allclose(
Expand Down
35 changes: 1 addition & 34 deletions swirl_dynamics/lib/networks/rational_networks_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,22 +42,6 @@ def test_rational_default_init(self):
params['q_coeffs'], expected_q_params, places=5
)

test_apply = rat_net.apply({'params': params}, x)

expected_apply = jnp.array([
-0.003531166352,
0.246322900057,
-0.021814700216,
0.018779350445,
0.006357953884,
-0.020405692980,
0.021597648039,
0.020564671606,
0.750805377960,
0.586633801460,
])
self.assertSequenceAlmostEqual(test_apply, expected_apply, places=5)

def test_unshared_rational_default_init(self):
"""Function to test that the rational networks are properly initialized."""
rat_net = rational_networks.UnsharedRationalLayer()
Expand Down Expand Up @@ -85,28 +69,11 @@ def test_unshared_rational_default_init(self):
places=5,
)

test_apply = rat_net.apply({'params': params}, x)

expected_apply = jnp.array([
-0.003531166352,
0.246322900057,
-0.021814700216,
0.018779350445,
0.006357953884,
-0.020405692980,
0.021597648039,
0.020564671606,
0.750805377960,
0.586633801460,
])
self.assertSequenceAlmostEqual(test_apply, expected_apply, places=5)


class RationalMLPTest(absltest.TestCase):

def test_number_params(self):
# Testing that the network has the correct number of parameters.

"""Tests that the network has the correct number of parameters."""
features = (2, 2)
periodic_mlp_small = rational_networks.RationalMLP(features=features)

Expand Down
10 changes: 5 additions & 5 deletions swirl_dynamics/lib/solvers/sde.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def __call__(
dynamics: SdeDynamics,
x0: Array,
tspan: Array,
rng: jax.random.KeyArray,
rng: Array,
params: PyTree,
) -> Array:
"""Solves a SDE at given time stamps.
Expand Down Expand Up @@ -95,7 +95,7 @@ def step(
x0: Array,
t0: Array,
dt: Array,
rng: jax.random.KeyArray,
rng: Array,
params: SdeParams,
) -> Array:
"""Advances the current state one step forward in time."""
Expand All @@ -106,14 +106,14 @@ def __call__(
dynamics: SdeDynamics,
x0: Array,
tspan: Array,
rng: jax.random.KeyArray,
rng: Array,
params: SdeParams,
) -> Array:
"""Solves a SDE by integrating the step function with `jax.lax.scan`."""

def scan_fun(
state: tuple[Array, Array],
ext: tuple[Array, jax.random.KeyArray],
ext: tuple[Array, Array],
) -> tuple[tuple[Array, Array], Array]:
x0, t0 = state
t_next, step_rng = ext
Expand All @@ -138,7 +138,7 @@ def step(
x0: Array,
t0: Array,
dt: Array,
rng: jax.random.KeyArray,
rng: Array,
params: SdeParams,
) -> Array:
"""Makes one Euler-Maruyama integration step in time."""
Expand Down
24 changes: 13 additions & 11 deletions swirl_dynamics/templates/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ def run(
eval_dataloader: Iterable[Any] | None = None,
eval_every_steps: int = 100,
num_batches_per_eval: int = 10,
# callbacks
# other configs
metric_writer: metric_writers.MultiWriter | None = None,
callbacks: Sequence[cb.Callback] = tuple(),
) -> None:
"""Runs trainer for a training task.
Expand All @@ -64,6 +65,8 @@ def run(
runs. Must be an integer multiple of `metric_aggregation_steps`.
num_batches_per_eval: The number of batches to step through every time
evaluation is run (resulting metrics are aggregated).
metric_writer: A metric writer that writes scalar metrics to disc. It is
also accessible to callbacks for custom writing in other formats.
callbacks: Self-contained programs executing non-essential logic (e.g.
checkpoint saving, logging, timing, profiling etc.).
"""
Expand All @@ -82,12 +85,13 @@ def run(
)
eval_iter = iter(eval_dataloader)

writer = metric_writers.create_default_writer(
workdir, just_logging=jax.process_index() > 0
)
if metric_writer is None:
metric_writer = metric_writers.create_default_writer(
workdir, just_logging=jax.process_index() > 0
)

for callback in callbacks:
callback.metric_writer = writer
callback.metric_writer = metric_writer
callback.on_train_begin(trainer)

cur_step = trainer.train_state.int_step
Expand All @@ -98,7 +102,7 @@ def run(
num_steps = min(total_train_steps - cur_step, metric_aggregation_steps)
train_metrics = trainer.train(train_iter, num_steps).compute()
cur_step += num_steps
writer.write_scalars(cur_step, train_metrics)
metric_writer.write_scalars(cur_step, train_metrics)

# At train/eval batch end, callbacks are called in reverse order so that
# they are last-in-first-out, loosely resembling nested python contexts.
Expand All @@ -113,16 +117,14 @@ def run(
assert eval_iter is not None
eval_metrics = trainer.eval(eval_iter, num_batches_per_eval).compute()
eval_metrics_to_log = {
k: v
for k, v in eval_metrics.items()
if utils.is_scalar(v)
k: v for k, v in eval_metrics.items() if utils.is_scalar(v)
}
writer.write_scalars(cur_step, eval_metrics_to_log)
metric_writer.write_scalars(cur_step, eval_metrics_to_log)

for callback in reversed(callbacks):
callback.on_eval_batches_end(trainer, eval_metrics)

for callback in reversed(callbacks):
callback.on_train_end(trainer)

writer.flush()
metric_writer.flush()
12 changes: 3 additions & 9 deletions swirl_dynamics/templates/train_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,14 @@
from absl.testing import absltest
from absl.testing import parameterized
from clu import metrics as clu_metrics
import grain.tensorflow as grain
import grain.python as pygrain
import jax.numpy as jnp
import numpy as np
from swirl_dynamics.templates import callbacks
from swirl_dynamics.templates import train
from swirl_dynamics.templates import train_states
from swirl_dynamics.templates import trainers
from swirl_dynamics.templates import utils
import tensorflow as tf

mock = absltest.mock

Expand Down Expand Up @@ -92,13 +91,8 @@ class TrainTest(parameterized.TestCase):

def setUp(self):
super().setUp()
source = grain.TfInMemoryDataSource.from_dataset(
tf.data.Dataset.from_tensor_slices(np.ones(10))
)
sampler = grain.TfDefaultIndexSampler(
num_records=10, seed=12, shard_options=grain.NoSharding()
)
self.dummy_dataloader = grain.TfDataLoader(source=source, sampler=sampler)
source = pygrain.RangeDataSource(start=1, stop=10, step=1)
self.dummy_dataloader = pygrain.load(source, seed=12, batch_size=1)

# mock trainer with constant metrics returned
self.test_trainer = mock.Mock(spec=trainers.BasicTrainer)
Expand Down

0 comments on commit 2a051d2

Please sign in to comment.