diff --git a/swirl_dynamics/lib/diffusion/diffusion_test.py b/swirl_dynamics/lib/diffusion/diffusion_test.py index 8a1f6bf..74779f5 100644 --- a/swirl_dynamics/lib/diffusion/diffusion_test.py +++ b/swirl_dynamics/lib/diffusion/diffusion_test.py @@ -26,7 +26,7 @@ class DiffusionTest(parameterized.TestCase): def test_tangent_noise_schedule(self, clip_max, start, end): sigma = diffusion.tangent_noise_schedule(clip_max, start, end) self.assertAlmostEqual(sigma(1.0), clip_max, places=3) - self.assertEqual(sigma(0.0), 0) + self.assertAlmostEqual(sigma(0.0), 0, places=8) test_points = np.random.default_rng(1234).uniform(0.05, 1.0, size=(10,)) np.testing.assert_allclose( diff --git a/swirl_dynamics/lib/networks/rational_networks_test.py b/swirl_dynamics/lib/networks/rational_networks_test.py index 0f92cd4..9f6c9ab 100644 --- a/swirl_dynamics/lib/networks/rational_networks_test.py +++ b/swirl_dynamics/lib/networks/rational_networks_test.py @@ -42,22 +42,6 @@ def test_rational_default_init(self): params['q_coeffs'], expected_q_params, places=5 ) - test_apply = rat_net.apply({'params': params}, x) - - expected_apply = jnp.array([ - -0.003531166352, - 0.246322900057, - -0.021814700216, - 0.018779350445, - 0.006357953884, - -0.020405692980, - 0.021597648039, - 0.020564671606, - 0.750805377960, - 0.586633801460, - ]) - self.assertSequenceAlmostEqual(test_apply, expected_apply, places=5) - def test_unshared_rational_default_init(self): """Function to test that the rational networks are properly initialized.""" rat_net = rational_networks.UnsharedRationalLayer() @@ -85,28 +69,11 @@ def test_unshared_rational_default_init(self): places=5, ) - test_apply = rat_net.apply({'params': params}, x) - - expected_apply = jnp.array([ - -0.003531166352, - 0.246322900057, - -0.021814700216, - 0.018779350445, - 0.006357953884, - -0.020405692980, - 0.021597648039, - 0.020564671606, - 0.750805377960, - 0.586633801460, - ]) - self.assertSequenceAlmostEqual(test_apply, expected_apply, places=5) - 
class RationalMLPTest(absltest.TestCase): def test_number_params(self): - # Testing that the network has the correct number of parameters. - + """Tests that the network has the correct number of parameters.""" features = (2, 2) periodic_mlp_small = rational_networks.RationalMLP(features=features) diff --git a/swirl_dynamics/lib/solvers/sde.py b/swirl_dynamics/lib/solvers/sde.py index 30a9f84..4d8f8a3 100644 --- a/swirl_dynamics/lib/solvers/sde.py +++ b/swirl_dynamics/lib/solvers/sde.py @@ -56,7 +56,7 @@ def __call__( dynamics: SdeDynamics, x0: Array, tspan: Array, - rng: jax.random.KeyArray, + rng: Array, params: PyTree, ) -> Array: """Solves a SDE at given time stamps. @@ -95,7 +95,7 @@ def step( x0: Array, t0: Array, dt: Array, - rng: jax.random.KeyArray, + rng: Array, params: SdeParams, ) -> Array: """Advances the current state one step forward in time.""" @@ -106,14 +106,14 @@ def __call__( dynamics: SdeDynamics, x0: Array, tspan: Array, - rng: jax.random.KeyArray, + rng: Array, params: SdeParams, ) -> Array: """Solves a SDE by integrating the step function with `jax.lax.scan`.""" def scan_fun( state: tuple[Array, Array], - ext: tuple[Array, jax.random.KeyArray], + ext: tuple[Array, Array], ) -> tuple[tuple[Array, Array], Array]: x0, t0 = state t_next, step_rng = ext @@ -138,7 +138,7 @@ def step( x0: Array, t0: Array, dt: Array, - rng: jax.random.KeyArray, + rng: Array, params: SdeParams, ) -> Array: """Makes one Euler-Maruyama integration step in time.""" diff --git a/swirl_dynamics/templates/train.py b/swirl_dynamics/templates/train.py index 8bf17eb..b057644 100644 --- a/swirl_dynamics/templates/train.py +++ b/swirl_dynamics/templates/train.py @@ -38,7 +38,8 @@ def run( eval_dataloader: Iterable[Any] | None = None, eval_every_steps: int = 100, num_batches_per_eval: int = 10, - # callbacks + # other configs + metric_writer: metric_writers.MultiWriter | None = None, callbacks: Sequence[cb.Callback] = tuple(), ) -> None: """Runs trainer for a training task. 
@@ -64,6 +65,8 @@ def run( runs. Must be an integer multiple of `metric_aggregation_steps`. num_batches_per_eval: The number of batches to step through every time evaluation is run (resulting metrics are aggregated). + metric_writer: A metric writer that writes scalar metrics to disk. It is + also accessible to callbacks for custom writing in other formats. callbacks: Self-contained programs executing non-essential logic (e.g. checkpoint saving, logging, timing, profiling etc.). """ @@ -82,12 +85,13 @@ def run( ) eval_iter = iter(eval_dataloader) - writer = metric_writers.create_default_writer( - workdir, just_logging=jax.process_index() > 0 - ) + if metric_writer is None: + metric_writer = metric_writers.create_default_writer( + workdir, just_logging=jax.process_index() > 0 + ) for callback in callbacks: - callback.metric_writer = writer + callback.metric_writer = metric_writer callback.on_train_begin(trainer) cur_step = trainer.train_state.int_step @@ -98,7 +102,7 @@ def run( num_steps = min(total_train_steps - cur_step, metric_aggregation_steps) train_metrics = trainer.train(train_iter, num_steps).compute() cur_step += num_steps - writer.write_scalars(cur_step, train_metrics) + metric_writer.write_scalars(cur_step, train_metrics) # At train/eval batch end, callbacks are called in reverse order so that # they are last-in-first-out, loosely resembling nested python contexts. 
@@ -113,11 +117,9 @@ def run( assert eval_iter is not None eval_metrics = trainer.eval(eval_iter, num_batches_per_eval).compute() eval_metrics_to_log = { - k: v - for k, v in eval_metrics.items() - if utils.is_scalar(v) + k: v for k, v in eval_metrics.items() if utils.is_scalar(v) } - writer.write_scalars(cur_step, eval_metrics_to_log) + metric_writer.write_scalars(cur_step, eval_metrics_to_log) for callback in reversed(callbacks): callback.on_eval_batches_end(trainer, eval_metrics) @@ -125,4 +127,4 @@ def run( for callback in reversed(callbacks): callback.on_train_end(trainer) - writer.flush() + metric_writer.flush() diff --git a/swirl_dynamics/templates/train_test.py b/swirl_dynamics/templates/train_test.py index 37db32f..6dcf863 100644 --- a/swirl_dynamics/templates/train_test.py +++ b/swirl_dynamics/templates/train_test.py @@ -17,7 +17,7 @@ from absl.testing import absltest from absl.testing import parameterized from clu import metrics as clu_metrics -import grain.tensorflow as grain +import grain.python as pygrain import jax.numpy as jnp import numpy as np from swirl_dynamics.templates import callbacks @@ -25,7 +25,6 @@ from swirl_dynamics.templates import train_states from swirl_dynamics.templates import trainers from swirl_dynamics.templates import utils -import tensorflow as tf mock = absltest.mock @@ -92,13 +91,8 @@ class TrainTest(parameterized.TestCase): def setUp(self): super().setUp() - source = grain.TfInMemoryDataSource.from_dataset( - tf.data.Dataset.from_tensor_slices(np.ones(10)) - ) - sampler = grain.TfDefaultIndexSampler( - num_records=10, seed=12, shard_options=grain.NoSharding() - ) - self.dummy_dataloader = grain.TfDataLoader(source=source, sampler=sampler) + source = pygrain.RangeDataSource(start=1, stop=10, step=1) + self.dummy_dataloader = pygrain.load(source, seed=12, batch_size=1) # mock trainer with constant metrics returned self.test_trainer = mock.Mock(spec=trainers.BasicTrainer)