diff --git a/pyproject.toml b/pyproject.toml
index 40cd74c..9300ed6 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -23,10 +23,12 @@ keywords = [
 # Installed locally with `pip install -e .`
 dependencies = [
     "absl-py",
+    "chex",
     "clu",
+    "etils",
     "flax",
     "h5py",
-    # "grain@git+https://github.com/google/grain",
+    "grain@git+https://github.com/google/grain.git@b80f7066ce1f69317519bf64739e5ff9a463059a",
     "gin-config",
     "jax",
     "numpy",
diff --git a/swirl_dynamics/data/tfgrain_transforms.py b/swirl_dynamics/data/tfgrain_transforms.py
deleted file mode 100644
index 6119f05..0000000
--- a/swirl_dynamics/data/tfgrain_transforms.py
+++ /dev/null
@@ -1,168 +0,0 @@
-# Copyright 2023 The swirl_dynamics Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Module for reusable TfGrain transformations."""
-
-from collections.abc import MutableMapping
-import dataclasses
-import grain.tensorflow as tfgrain
-import jax
-import tensorflow as tf
-
-Array = jax.Array
-FeatureDict = MutableMapping[str, tf.Tensor]
-Scalar = float | int
-
-
-def _check_valid_scale_range(
-    scale_range: tuple[Scalar, Scalar], name: str
-) -> None:
-  """Checks if a given scale range is valid."""
-  if scale_range[0] >= scale_range[1]:
-    raise ValueError(
-        f"Lower bound of {name} ({scale_range[0]}) must be strictly smaller"
-        f" than its upper bound ({scale_range[1]})"
-    )
-
-
-@dataclasses.dataclass(frozen=True, kw_only=True)
-class LinearRescale(tfgrain.MapTransform):
-  """Apply linear rescaling to a feature."""
-
-  feature_name: str
-  input_range: tuple[Scalar, Scalar]
-  output_range: tuple[Scalar, Scalar]
-
-  def __post_init__(self):
-    _check_valid_scale_range(self.input_range, name="input_range")
-    _check_valid_scale_range(self.output_range, name="output_range")
-
-  def map(self, features: FeatureDict) -> FeatureDict:
-    normalized = (features[self.feature_name] - self.input_range[0]) / (
-        self.input_range[1] - self.input_range[0]
-    )
-    rescaled = (
-        normalized * (self.output_range[1] - self.output_range[0])
-        + self.output_range[0]
-    )
-    features[self.feature_name] = rescaled
-    return features
-
-
-@dataclasses.dataclass(frozen=True, kw_only=True)
-class Normalize(tfgrain.MapTransform):
-  """Apply normalization to a feature."""
-
-  feature_name: str
-  mean: Array
-  std: Array
-
-  def map(self, features: FeatureDict) -> FeatureDict:
-    normalized = (features[self.feature_name] - self.mean) / self.std
-    features[self.feature_name] = normalized
-    return features
-
-
-@dataclasses.dataclass(frozen=True, kw_only=True)
-class RandomSection(tfgrain.RandomMapTransform):
-  """Samples a random section in a given trajectory.
-
-  Sampling always happens along the leading dimension. First a start index is
-  randomly selected amongst all permissible locations and the sample is taken to
-  be the contiguous section immediately after and including the start index,
-  with the specified number of steps and stride.
-
-  Attributes:
-    feature_names: names of the features to be sampled. All sampled features
-      share the same start index (besides number of steps and stride). They must
-      have the same dimension in the leading axis.
-    num_steps: the number of steps in the sample.
-    stride: the stride (i.e. downsample) in the sampled section wrt the original
-      features.
-  """
-
-  feature_names: tuple[str, ...]
-  num_steps: int
-  stride: int = 1
-
-  def random_map(self, features: FeatureDict, seed: tf.Tensor) -> FeatureDict:
-    total_length = features[self.feature_names[0]].shape.as_list()[0]
-    sample_length = self.stride * (self.num_steps - 1) + 1
-
-    for name in self.feature_names[1:]:
-      feature_length = features[name].shape.as_list()[0]
-      if feature_length != total_length:
-        raise ValueError(
-            "Features must have the same dimension along axis 0:"
-            f" {self.feature_names[0]} ({total_length}) vs."
-            f" {name} ({feature_length})"
-        )
-
-    if sample_length > total_length:
-      raise ValueError(
-          f"Not enough steps [{total_length}] "
-          f"for desired sample length [{sample_length}] "
-          f"= stride [{self.stride}] * (num_steps [{self.num_steps}] - 1) + 1"
-      )
-    elif sample_length == total_length:
-      start_idx, end_idx = 0, total_length
-    else:
-      start_idx = tf.random.stateless_uniform(
-          shape=(),
-          seed=seed,
-          maxval=total_length - sample_length + 1,
-          dtype=tf.int32,
-      )
-      end_idx = start_idx + sample_length
-
-    for name in self.feature_names:
-      features[name] = features[name][start_idx : end_idx : self.stride]
-    return features
-
-
-@dataclasses.dataclass(frozen=True, kw_only=True)
-class Split(tfgrain.MapTransform):
-  """Splits a tensor feature into multiple sub tensors.
-
-  Attributes:
-    feature_name: name of the feature to be split.
-    split_sizes: the sizes of each output feature along `axis`. Must sum to the
-      dimension of the presplit feature along `axis`.
-    split_names: the name of the output features. Must have the same length as
-      the `split_sizes`.
-    axis: the axis along which splitting happens.
-    keep_presplit: whether to keep the presplit feature in the processed batch.
-  """
-
-  feature_name: str
-  split_sizes: tuple[int, ...]
-  split_names: tuple[str, ...]
-  axis: int = 0
-  keep_presplit: bool = False
-
-  def __post_init__(self) -> None:
-    if len(self.split_names) != len(self.split_sizes):
-      raise ValueError(
-          f"Length of `split_sizes` [{self.split_sizes}] must match "
-          f"that of `split_names` [{self.split_names}]"
-      )
-
-  def map(self, features: FeatureDict) -> FeatureDict:
-    feature = features[self.feature_name]
-    splits = tf.split(feature, self.split_sizes, axis=self.axis)
-    for split_name, split_value in zip(self.split_names, splits):
-      features[split_name] = split_value
-    if not self.keep_presplit:
-      features.pop(self.feature_name)
-    return features
diff --git a/swirl_dynamics/data/tfgrain_transforms_test.py b/swirl_dynamics/data/tfgrain_transforms_test.py
deleted file mode 100644
index d6f8ab6..0000000
--- a/swirl_dynamics/data/tfgrain_transforms_test.py
+++ /dev/null
@@ -1,170 +0,0 @@
-# Copyright 2023 The swirl_dynamics Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from absl.testing import absltest
-from absl.testing import parameterized
-from swirl_dynamics.data import tfgrain_transforms as transforms
-import tensorflow as tf
-
-
-class LinearRescaleTest(tf.test.TestCase, parameterized.TestCase):
-
-  @parameterized.parameters(
-      # test positive, negative and mixed output ranges
-      {"output_range": (3, 4)},
-      {"output_range": (-5, 5)},
-      {"output_range": (-6, -3)},
-  )
-  def test_rescales_to_correct_range(self, output_range):
-    input_max = 5
-    raw_feature = tf.range(input_max)
-    raw_sample = {"x": raw_feature}
-    transform = transforms.LinearRescale(
-        feature_name="x",
-        input_range=(0, input_max - 1),
-        output_range=output_range,
-    )
-    transformed_sample = transform.map(raw_sample)
-    self.assertEqual(transformed_sample["x"].shape, (input_max,))
-    self.assertEqual(tf.reduce_max(transformed_sample["x"]), output_range[1])
-    self.assertEqual(tf.reduce_min(transformed_sample["x"]), output_range[0])
-
-  @parameterized.parameters(
-      {"input_range": (1, 0), "output_range": (3, 4)},
-      {"input_range": (0, 1), "output_range": (-3, -3)},
-  )
-  def test_raises_invalid_range(self, input_range, output_range):
-    with self.assertRaisesRegex(ValueError, "strictly smaller"):
-      transforms.LinearRescale(
-          feature_name="x", input_range=input_range, output_range=output_range
-      )
-
-
-class NormalizeTest(tf.test.TestCase, parameterized.TestCase):
-
-  def test_normalizes_to_correct_statistics(self):
-    raw_feature = tf.constant(
-        [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]],
-    )
-    raw_sample = {"x": raw_feature}
-    mean = tf.math.reduce_mean(raw_feature, axis=0)
-    std = tf.math.reduce_std(raw_feature, axis=0)
-    transform = transforms.Normalize(
-        feature_name="x",
-        mean=mean,
-        std=std,
-    )
-    transformed_sample = transform.map(raw_sample)
-    self.assertEqual(transformed_sample["x"].shape, raw_feature.shape)
-    self.assertNear(tf.math.reduce_mean(transformed_sample["x"]), 0.0, 1e-5)
-    self.assertNear(tf.math.reduce_std(transformed_sample["x"]), 1.0, 1e-5)
-
-
-class RandomSectionTransformTest(tf.test.TestCase, parameterized.TestCase):
-
-  @parameterized.parameters(
-      {"num_steps": 4, "stride": 1},  # sample length < total length
-      {"num_steps": 5, "stride": 2},  # sample length = total length
-  )
-  def test_correct_shapes_and_strides(self, num_steps, stride):
-    raw_feature = tf.tile(
-        tf.expand_dims(tf.expand_dims(tf.range(0, 9), axis=1), axis=2),
-        multiples=(1, 16, 1),
-    )  # shape = (9, 16, 1)
-    raw_sample = {"u": raw_feature, "t": raw_feature}
-    transform = transforms.RandomSection(
-        feature_names=("u", "t"), num_steps=num_steps, stride=stride
-    )
-    transformed_sample = transform.random_map(
-        raw_sample, seed=tf.constant((2, 3))
-    )
-    with self.subTest(checking="shapes"):
-      self.assertEqual(transformed_sample["u"].shape, (num_steps, 16, 1))
-      self.assertEqual(transformed_sample["t"].shape, (num_steps, 16, 1))
-    with self.subTest(checking="stride"):
-      self.assertEqual(
-          transformed_sample["t"][1, 0, 0] - transformed_sample["t"][0, 0, 0],
-          stride,
-      )
-
-  @parameterized.parameters(
-      # sample length > total length with various strides
-      {"num_steps": 10, "stride": 1},
-      {"num_steps": 5, "stride": 2},
-  )
-  def test_raises_not_enough_steps(self, num_steps, stride):
-    sample = {"u": tf.zeros((8, 16, 1))}
-    transform = transforms.RandomSection(
-        feature_names=("u",), num_steps=num_steps, stride=stride
-    )
-    with self.assertRaisesRegex(ValueError, "Not enough steps"):
-      transform.random_map(sample, seed=tf.constant((2, 3)))
-
-  def test_raises_unequal_feature_dim(self):
-    sample = {"u": tf.zeros((8, 16, 1)), "t": tf.zeros((4, 16, 1))}
-    transform = transforms.RandomSection(feature_names=("u", "t"), num_steps=5)
-    with self.assertRaisesRegex(ValueError, "same dimension"):
-      transform.random_map(sample, seed=tf.constant((2, 3)))
-
-
-class SplitTransformTest(tf.test.TestCase, parameterized.TestCase):
-
-  @parameterized.parameters(
-      {
-          "split_sizes": (2, 2),
-          "axis": 0,
-          "split_names": ("u_0", "u_1"),
-          "u0_shape": (2, 8),
-      },
-      {
-          "split_sizes": (1, 3, 4),
-          "axis": 1,
-          "split_names": ("u_0", "u_1", "u_3"),
-          "u0_shape": (4, 1),
-      },
-  )
-  def test_splits_to_correct_shapes(
-      self, split_sizes, split_names, axis, u0_shape
-  ):
-    sample = {"u": tf.ones((4, 8))}
-    transform = transforms.Split(
-        feature_name="u",
-        split_sizes=split_sizes,
-        split_names=split_names,
-        axis=axis,
-    )
-    transformed_sample = transform.map(sample)
-    self.assertEqual(transformed_sample["u_0"].shape, u0_shape)
-
-  def test_raises_length_mismatch(self):
-    with self.assertRaisesRegex(ValueError, "Length .* must match"):
-      transforms.Split(
-          feature_name="u", split_sizes=(8,), split_names=("setup", "pred")
-      )
-
-  def test_keeps_presplit_feature(self):
-    sample = {"u": tf.zeros((4, 8))}
-    transform = transforms.Split(
-        feature_name="u",
-        split_sizes=(4, 4),
-        split_names=("u_setup", "u_pred"),
-        axis=1,
-        keep_presplit=True,
-    )
-    transformed_sample = transform.map(sample)
-    self.assertIn("u", transformed_sample)
-
-
-if __name__ == "__main__":
-  absltest.main()
diff --git a/swirl_dynamics/data/utils_test.py b/swirl_dynamics/data/utils_test.py
index 4340c78..c5859ec 100644
--- a/swirl_dynamics/data/utils_test.py
+++ b/swirl_dynamics/data/utils_test.py
@@ -14,10 +14,13 @@
 import os
 
+from absl import flags
 from absl.testing import absltest
 from absl.testing import parameterized
 from swirl_dynamics.data import utils
 
+FLAGS = flags.FLAGS
+
 
 class UtilsTest(parameterized.TestCase):
diff --git a/swirl_dynamics/lib/networks/rational_networks_test.py b/swirl_dynamics/lib/networks/rational_networks_test.py
index 773484d..0f92cd4 100644
--- a/swirl_dynamics/lib/networks/rational_networks_test.py
+++ b/swirl_dynamics/lib/networks/rational_networks_test.py
@@ -32,11 +32,15 @@ def test_rational_default_init(self):
     x = jax.random.normal(rng, (10,))
     params = rat_net.init(rng, x)['params']
 
-    # Checking that the arrays are the same up to 1e-6.
+    # Checking that the arrays are the same up to 1e-5.
     expected_p_params = jnp.array([1.1915, 1.5957, 0.5, 0.0218])
     expected_q_params = jnp.array([2.383, 0.0, 1.0])
-    self.assertSequenceAlmostEqual(params['p_coeffs'], expected_p_params)
-    self.assertSequenceAlmostEqual(params['q_coeffs'], expected_q_params)
+    self.assertSequenceAlmostEqual(
+        params['p_coeffs'], expected_p_params, places=5
+    )
+    self.assertSequenceAlmostEqual(
+        params['q_coeffs'], expected_q_params, places=5
+    )
 
     test_apply = rat_net.apply({'params': params}, x)
 
@@ -52,7 +56,6 @@ def test_rational_default_init(self):
         0.750805377960,
         0.586633801460,
     ])
-
     self.assertSequenceAlmostEqual(test_apply, expected_apply, places=5)
 
   def test_unshared_rational_default_init(self):
@@ -64,7 +67,7 @@
     x = jax.random.normal(rng, (10,))
     params = rat_net.init(rng, x)['params']
 
-    # Checking that the arrays are the same up to 1e-6.
+    # Checking that the arrays are the same up to 1e-5.
     expected_p_params = jnp.array([1.1915, 1.5957, 0.5, 0.0218]) * jnp.ones(
         (x.shape[-1], 1)
     )
@@ -72,10 +75,14 @@
         (x.shape[-1], 1)
     )
     self.assertSequenceAlmostEqual(
-        params['p_params'].reshape((-1,)), expected_p_params.reshape((-1,))
+        params['p_params'].reshape((-1,)),
+        expected_p_params.reshape((-1,)),
+        places=5,
    )
     self.assertSequenceAlmostEqual(
-        params['q_params'].reshape((-1,)), expected_q_params.reshape((-1,))
+        params['q_params'].reshape((-1,)),
+        expected_q_params.reshape((-1,)),
+        places=5,
     )
 
     test_apply = rat_net.apply({'params': params}, x)
 
@@ -92,7 +99,6 @@
        0.750805377960,
        0.586633801460,
     ])
-
     self.assertSequenceAlmostEqual(test_apply, expected_apply, places=5)
diff --git a/swirl_dynamics/projects/ergodic/stable_ar.py b/swirl_dynamics/projects/ergodic/stable_ar.py
index f3729c7..1583a8e 100644
--- a/swirl_dynamics/projects/ergodic/stable_ar.py
+++ b/swirl_dynamics/projects/ergodic/stable_ar.py
@@ -83,7 +83,7 @@ def loss_fn(
       self,
       params: PyTree,
       batch: models.BatchType,
-      rng: jax.random.KeyArray,
+      rng: Array,
       mutables: PyTree,
   ) -> models.LossAndAux:
     """Computes training loss and metrics."""
@@ -182,7 +182,7 @@ def eval_fn(
       variables: PyTree,
       # batch is dict with keys: ['ic', 'true', 'tspan', 'normalize_stats']
       batch: models.BatchType,
-      rng: jax.random.KeyArray,
+      rng: Array,
       **kwargs,
   ) -> models.ArrayDict:
     tspan = batch["tspan"].reshape((-1,))
@@ -308,7 +308,7 @@ def _preprocess_train_batch(
     )
 
   def preprocess_train_batch(
-      self, batch_data: trainers.BatchType, step: int, rng: jax.random.KeyArray
+      self, batch_data: trainers.BatchType, step: int, rng: Array
   ) -> trainers.BatchType:
    """Wrapper method for _preprocess_train_batch.
 
@@ -343,7 +343,7 @@ def preprocess_train_batch(
     return self._preprocess_train_batch(batch_data, num_time_steps)
 
   def preprocess_eval_batch(
-      self, batch_data: trainers.BatchType, rng: jax.random.KeyArray
+      self, batch_data: trainers.BatchType, rng: Array
   ) -> trainers.BatchType:
     """Preprocessed batch data."""
     if self.conf.num_lookback_steps > 1:
@@ -432,7 +432,7 @@ def _preprocess_train_batch(
     return jax.jit(trainers.reshape_for_pmap)(batch_dict)
 
   def preprocess_train_batch(
-      self, batch_data: trainers.BatchType, step: int, rng: jax.random.KeyArray
+      self, batch_data: trainers.BatchType, step: int, rng: Array
   ) -> trainers.BatchType:
     """Wrapper method for _preprocess_train_batch.
 
@@ -461,7 +461,7 @@ def preprocess_train_batch(
     return self._preprocess_train_batch(batch_data, num_time_steps)
 
   def preprocess_eval_batch(
-      self, batch_data: trainers.BatchType, rng: jax.random.KeyArray
+      self, batch_data: trainers.BatchType, rng: Array
   ) -> trainers.BatchType:
     """Preprocessed batch data."""
     if self.conf.num_lookback_steps > 1:
diff --git a/swirl_dynamics/projects/ergodic/utils.py b/swirl_dynamics/projects/ergodic/utils.py
index dc2f200..ff42f2e 100644
--- a/swirl_dynamics/projects/ergodic/utils.py
+++ b/swirl_dynamics/projects/ergodic/utils.py
@@ -304,7 +304,7 @@ def sobolev_norm(
   )
 
   # Performing the integration using trapezoidal rule.
-  norm_squared = jnp.sum(mult * u_fft_squared, axis=axes) / (n_x)**dim
+  norm_squared = jnp.sum(mult * u_fft_squared, axis=axes) / (n_x) ** dim
 
   # Returns the mean.
   return jnp.mean(norm_squared)
@@ -358,7 +358,7 @@ def sample_uniform_spherical_shell(
     n_points: int,
     radii: tuple[float, float],
     shape: tuple[int, ...],
-    key: jax.random.KeyArray,
+    key: Array,
 ):
   """Uniform sampling (in angle and radius) from an spherical shell.
@@ -379,17 +379,19 @@ def sample_uniform_spherical_shell(
   broadcasting_shape = (n_points,) + len(shape) * (1,)
   # Obtain the correct axis for the sum, depending on the shape.
   # Here we suppose that shape comes in the form (nx, ny, d) or (nx, d).
-  assert len(shape) < 4 and len(shape) >= 2, ("The shape should represent ",
-                                              "one- or two-dimensional points.",
-                                              f" Instead we have shape {shape}")
+  assert len(shape) < 4 and len(shape) >= 2, (
+      "The shape should represent ",
+      "one- or two-dimensional points.",
+      f" Instead we have shape {shape}",
+  )
 
-  axis_sum = (1,) if len(shape) == 2 else (1, 2,)
+  axis_sum = (1,) if len(shape) == 2 else (1, 2)
 
   key_radius, key_vec = jax.random.split(key)
-  sampling_radius = jax.random.uniform(key_radius, (n_points,),
-                                       minval=inner_radius,
-                                       maxval=outer_radius)
+  sampling_radius = jax.random.uniform(
+      key_radius, (n_points,), minval=inner_radius, maxval=outer_radius
+  )
 
   vec = jax.random.normal(key_vec, shape=((n_points,) + shape))
   vec_norm = jnp.linalg.norm(vec, axis=axis_sum).reshape(broadcasting_shape)
@@ -416,6 +418,7 @@ def linear_scale_dissipative_target(inputs: Array, scale: float = 1.0):
 
 def plot_cos_sims(dt: Array, traj_length: int, trajs: Array, pred_trajs: Array):
   """Plot cosine similarities over time."""
+
   def sum_non_batch_dims(x: Array) -> Array:
     """Helper method to sum array along all dimensions except the 0th."""
     ndim = x.ndim
@@ -429,17 +432,16 @@ def state_cos_sim(x: Array, y: Array) -> Array:
     Args:
       x: array of states; shape: batch_size x state_dimension
       y: array of states; shape: batch_size x state_dimension
+
     Returns:
       cosine similarity averaged along batch dimension.
     """
     x_norm = jnp.expand_dims(
-        jnp.sqrt(sum_non_batch_dims((x ** 2))),
-        axis=tuple(range(1, x.ndim))
+        jnp.sqrt(sum_non_batch_dims((x**2))), axis=tuple(range(1, x.ndim))
     )
     x /= x_norm
     y_norm = jnp.expand_dims(
-        jnp.sqrt(sum_non_batch_dims((y ** 2))),
-        axis=tuple(range(1, y.ndim))
+        jnp.sqrt(sum_non_batch_dims((y**2))), axis=tuple(range(1, y.ndim))
     )
     y /= y_norm
     return sum_non_batch_dims(x * y).mean(axis=0)
@@ -450,20 +452,22 @@ def state_cos_sim(x: Array, y: Array) -> Array:
   # Plot 0.9, 0.8 threshold lines
   ax.plot(
       plot_time,
-      jnp.ones(traj_length)*0.9,
-      color="black", linestyle="dashed",
-      label="0.9 threshold"
+      jnp.ones(traj_length) * 0.9,
+      color="black",
+      linestyle="dashed",
+      label="0.9 threshold",
   )
   ax.plot(
       plot_time,
-      jnp.ones(traj_length)*0.8,
-      color="red", linestyle="dashed",
-      label="0.8 threshold"
+      jnp.ones(traj_length) * 0.8,
+      color="red",
+      linestyle="dashed",
+      label="0.8 threshold",
   )
   # Plot correlation lines
-  cosine_sims = jax.vmap(
-      state_cos_sim, in_axes=(1, 1)
-  )(trajs[:, :traj_length, :], pred_trajs[:, :traj_length, :])
+  cosine_sims = jax.vmap(state_cos_sim, in_axes=(1, 1))(
+      trajs[:, :traj_length, :], pred_trajs[:, :traj_length, :]
+  )
   ax.plot(plot_time, cosine_sims)
   ax.set_xlim([0, t_max])
   ax.set_xlabel(r"$t$")
diff --git a/swirl_dynamics/projects/evolve_smoothly/batch_decode.py b/swirl_dynamics/projects/evolve_smoothly/batch_decode.py
index 8ab7e56..beda159 100644
--- a/swirl_dynamics/projects/evolve_smoothly/batch_decode.py
+++ b/swirl_dynamics/projects/evolve_smoothly/batch_decode.py
@@ -46,7 +46,7 @@ def __init__(self, ansatz: ansatzes.Ansatz, num_snapshots: int) -> None:
     self.ansatz = ansatz
     self.num_snapshots = num_snapshots
 
-  def initialize(self, rng: jax.random.KeyArray) -> models.ModelVariable:
+  def initialize(self, rng: jax.Array) -> models.ModelVariable:
     """Initializes the variables of the ansatz."""
     return jax.vmap(self.ansatz.model.init, in_axes=(0, None))(
         jax.random.split(rng, self.num_snapshots),
@@ -57,7 +57,7 @@ def loss_fn(
       self,
       params: PyTree,
       batch: models.BatchType,
-      rng: jax.random.KeyArray,
+      rng: jax.Array,
       mutables: PyTree,
   ) -> models.LossAndAux:
     """Computes the l2 reconstruction loss."""
@@ -69,7 +69,7 @@ def loss_fn(
     return loss, ({"loss": loss}, mutables)
 
   def eval_fn(
-      self, variables: PyTree, batch: models.BatchType, rng: jax.random.KeyArray
+      self, variables: PyTree, batch: models.BatchType, rng: jax.Array
   ) -> models.ArrayDict:
     """Evaluates mean, worst-case and std relative l2 errors."""
     del rng
diff --git a/swirl_dynamics/projects/evolve_smoothly/encode_decode.py b/swirl_dynamics/projects/evolve_smoothly/encode_decode.py
index def6364..c7a1d28 100644
--- a/swirl_dynamics/projects/evolve_smoothly/encode_decode.py
+++ b/swirl_dynamics/projects/evolve_smoothly/encode_decode.py
@@ -61,7 +61,7 @@ def __init__(
     self.snapshot_dims = snapshot_dims
     self.consistency_weight = consistency_weight
 
-  def initialize(self, rng: jax.random.KeyArray) -> models.ModelVariable:
+  def initialize(self, rng: Array) -> models.ModelVariable:
     """Initializes the variables of the encoder."""
     enc_output, enc_vars = self.encoder.init_with_output(
         rng, jnp.ones((1,) + self.snapshot_dims)
@@ -78,7 +78,7 @@ def loss_fn(
       self,
       params: PyTree,
       batch: models.BatchType,
-      rng: jax.random.KeyArray,
+      rng: Array,
       mutables: PyTree,
   ) -> models.LossAndAux:
     """Computes the reconstruction and consistency loss."""
@@ -108,7 +108,7 @@ def loss_fn(
     return loss, (metric, mutables)
 
   def eval_fn(
-      self, variables: PyTree, batch: models.BatchType, rng: jax.random.KeyArray
+      self, variables: PyTree, batch: models.BatchType, rng: Array
   ) -> models.ArrayDict:
     """Evaluates mean, worst-case and std relative l2 errors."""
     del rng
diff --git a/swirl_dynamics/projects/evolve_smoothly/latent_dynamics.py b/swirl_dynamics/projects/evolve_smoothly/latent_dynamics.py
index 2e91fa8..d41e015 100644
--- a/swirl_dynamics/projects/evolve_smoothly/latent_dynamics.py
+++ b/swirl_dynamics/projects/evolve_smoothly/latent_dynamics.py
@@ -70,7 +70,7 @@ class LatentDynamics(models.BaseModel):
   latent_weight: float = 1.0
   consistency_weight: float = 1.0
 
-  def initialize(self, rng: jax.random.KeyArray) -> models.ModelVariable:
+  def initialize(self, rng: Array) -> models.ModelVariable:
     """Initializes the variables of the dynamics model."""
     sample_input = jnp.ones((1, self.ansatz.num_params))
     out, variables = self.latent_dynamics_model.init_with_output(
@@ -83,7 +83,7 @@ def loss_fn(
      self,
       params: PyTree,
       batch: models.BatchType,
-      rng: jax.random.KeyArray,
+      rng: Array,
       mutables: PyTree,
   ) -> models.LossAndAux:
     """Computes the reconstruction and consistency loss."""
@@ -123,7 +123,7 @@ def eval_fn(
       self,
       variables: PyTree,
       batch: models.BatchType,
-      rng: jax.random.KeyArray,
+      rng: Array,
   ) -> models.ArrayDict:
     """Evaluates mean, worst-case and std relative l2 errors."""
     ambient_ic, ambient_target = batch["u"][:, 0], batch["u"]
diff --git a/swirl_dynamics/templates/callbacks_test.py b/swirl_dynamics/templates/callbacks_test.py
index 4f5d74d..0fa23f2 100644
--- a/swirl_dynamics/templates/callbacks_test.py
+++ b/swirl_dynamics/templates/callbacks_test.py
@@ -15,6 +15,7 @@
 import io
 import os
 
+from absl import flags
 from absl.testing import absltest
 from absl.testing import parameterized
 from clu import metric_writers
@@ -27,6 +28,9 @@
 from swirl_dynamics.templates import trainers
 from swirl_dynamics.templates import utils
 
+jax.config.parse_flags_with_absl()
+
+FLAGS = flags.FLAGS
 
 mock = absltest.mock
diff --git a/swirl_dynamics/templates/models.py b/swirl_dynamics/templates/models.py
index cf5159d..4194572 100644
--- a/swirl_dynamics/templates/models.py
+++ b/swirl_dynamics/templates/models.py
@@ -42,7 +42,7 @@ class BaseModel(metaclass=abc.ABCMeta):
   """
 
   @abc.abstractmethod
-  def initialize(self, rng: jax.random.KeyArray) -> ModelVariable:
+  def initialize(self, rng: jax.Array) -> ModelVariable:
     """Initializes variables of the wrapped flax module(s).
 
     This method by design does not take any sample input in its argument. Input
@@ -75,7 +75,7 @@ def loss_fn(
       self,
       params: PyTree,
       batch: BatchType,
-      rng: jax.random.KeyArray,
+      rng: jax.Array,
       mutables: PyTree,
       **kwargs,
   ) -> LossAndAux:
@@ -107,7 +107,7 @@ def eval_fn(
       self,
       variables: PyTree,
       batch: BatchType,
-      rng: jax.random.KeyArray,
+      rng: jax.Array,
       **kwargs,
   ) -> ArrayDict:
     """Computes evaluation metrics."""
diff --git a/swirl_dynamics/templates/train_states_test.py b/swirl_dynamics/templates/train_states_test.py
index c21ed02..c6e2338 100644
--- a/swirl_dynamics/templates/train_states_test.py
+++ b/swirl_dynamics/templates/train_states_test.py
@@ -14,6 +14,7 @@
 import os
 
+from absl import flags
 from absl.testing import absltest
 from absl.testing import parameterized
 import flax
@@ -24,7 +25,9 @@
 from orbax import checkpoint
 from swirl_dynamics.templates import train_states
 
-mock = absltest.mock
+jax.config.parse_flags_with_absl()
+
+FLAGS = flags.FLAGS
 
 
 class TrainStateTest(parameterized.TestCase):
diff --git a/swirl_dynamics/templates/trainers.py b/swirl_dynamics/templates/trainers.py
index f038261..a006132 100644
--- a/swirl_dynamics/templates/trainers.py
+++ b/swirl_dynamics/templates/trainers.py
@@ -26,6 +26,7 @@
 from swirl_dynamics.templates import models
 from swirl_dynamics.templates import train_states
 
+Array = jax.Array
 BatchType = Mapping[str, jax.typing.ArrayLike]
 Metrics = clu_metrics.Collection
 PyTree = Any
@@ -72,11 +73,7 @@
     class (corresponding to the metrics computed on an evaluation batch). See
     `clu.metrics.Collection` for how to define this class.
   """
 
-  def __init__(
-      self,
-      model: M,
-      rng: jax.random.KeyArray,
-  ):
+  def __init__(self, model: M, rng: Array):
     self.model = model
 
     train_rng, eval_rng, self._init_rng = jax.random.split(rng, 3)
@@ -100,11 +97,11 @@ def train_state(self, train_state: S) -> None:
   # Convenience properties/functions
   # **********************************
 
-  def get_train_rng(self, num: int = 1) -> jax.random.KeyArray:
+  def get_train_rng(self, num: int = 1) -> Array:
     rng = jax.random.fold_in(self._train_rng, self.train_state.int_step)
     return jax.random.split(rng, num=num)
 
-  def get_eval_rng(self, num: int = 1) -> jax.random.KeyArray:
+  def get_eval_rng(self, num: int = 1) -> Array:
     rng, self._eval_rng = jax.random.split(self._eval_rng)
     return jax.random.split(rng, num=num)
 
@@ -124,7 +121,7 @@ def _maybe_unreplicate(self, tree: PyTree) -> PyTree:
   def _maybe_replicate(self, tree: PyTree) -> PyTree:
     return flax.jax_utils.replicate(tree) if self.is_distributed else tree
 
-  def _maybe_split(self, rng: jax.random.KeyArray) -> jax.random.KeyArray:
+  def _maybe_split(self, rng: Array) -> Array:
     if self.is_distributed:
       rng = jax.random.split(rng, num=jax.local_device_count())
     return rng
@@ -175,7 +172,7 @@ def eval(self, batch_iter: Iterator[BatchType], num_steps: int) -> Metrics:
   # **********************************
 
   @abc.abstractmethod
-  def initialize_train_state(self, rng: jax.random.KeyArray) -> S:
+  def initialize_train_state(self, rng: Array) -> S:
     """Instantiate the initial train state.
 
     Args:
@@ -219,7 +216,7 @@ def _train_step(state, batch):
 
   @property
   @abc.abstractmethod
-  def eval_step(self) -> Callable[[S, BatchType, jax.random.KeyArray], Metrics]:
+  def eval_step(self) -> Callable[[S, BatchType, Array], Metrics]:
     """Returns the evaluation step function.
 
     Same as `BaseTrainer.train_step`, except for evaluation step.
@@ -231,7 +228,7 @@ def eval_step(self) -> Callable[[S, BatchType, jax.random.KeyArray], Metrics]:
   # **********************************
 
   def preprocess_train_batch(
-      self, batch_data: BatchType, step: int, rng: jax.random.KeyArray
+      self, batch_data: BatchType, step: int, rng: Array
   ) -> BatchType:
     """Preprocesses batch data before calling the training step function.
 
@@ -254,7 +251,7 @@ def preprocess_train_batch(
     return batch_data
 
   def preprocess_eval_batch(
-      self, batch_data: BatchType, rng: jax.random.KeyArray
+      self, batch_data: BatchType, rng: Array
   ) -> BatchType:
     """Preprocesses batch before calling the eval step function.
@@ -339,11 +336,11 @@ def __init__(self, optimizer: optax.GradientTransformation, *args, **kwargs):
   def train_step(
       self,
   ) -> Callable[
-      [BasicTrainState, BatchType, jax.random.KeyArray],
+      [BasicTrainState, BatchType, Array],
       tuple[BasicTrainState, Metrics],
   ]:
     def _train_step(
-        train_state: BasicTrainState, batch: BatchType, rng: jax.random.KeyArray
+        train_state: BasicTrainState, batch: BatchType, rng: Array
     ) -> tuple[BasicTrainState, Metrics]:
       """Performs gradient step and compute training metrics."""
       grad_fn = jax.grad(self.model.loss_fn, argnums=0, has_aux=True)
@@ -385,9 +382,9 @@ def _update_train_state(
   @property
   def eval_step(
       self,
-  ) -> Callable[[BasicTrainState, BatchType, jax.random.KeyArray], Metrics]:
+  ) -> Callable[[BasicTrainState, BatchType, Array], Metrics]:
     def _eval_step(
-        train_state: BasicTrainState, batch: BatchType, rng: jax.random.KeyArray
+        train_state: BasicTrainState, batch: BatchType, rng: Array
     ) -> Metrics:
       """Use model to compute the evaluation metrics."""
       eval_metrics = self.model.eval_fn(
@@ -398,7 +395,7 @@ def _eval_step(
 
     return _eval_step
 
-  def initialize_train_state(self, rng: jax.random.KeyArray) -> BasicTrainState:
+  def initialize_train_state(self, rng: Array) -> BasicTrainState:
     """Initializes the model variables and the train state."""
     init_vars = self.model.initialize(rng)
     mutables, params = flax.core.pop(init_vars, "params")
@@ -434,11 +431,11 @@ class BasicDistributedTrainer(BasicTrainer[BasicModel, BasicTrainState]):
   def train_step(
      self,
   ) -> Callable[
-      [BasicTrainState, BatchType, jax.random.KeyArray],
+      [BasicTrainState, BatchType, Array],
       tuple[BasicTrainState, Metrics],
   ]:
     def _train_step(
-        train_state: BasicTrainState, batch: BatchType, rng: jax.random.KeyArray
+        train_state: BasicTrainState, batch: BatchType, rng: Array
     ) -> tuple[BasicTrainState, Metrics]:
       """Performs gradient step and compute training metrics."""
       grad_fn = jax.grad(self.model.loss_fn, argnums=0, has_aux=True)
@@ -464,9 +461,9 @@ def _train_step(
   @property
   def eval_step(
       self,
-  ) -> Callable[[BasicTrainState, BatchType, jax.random.KeyArray], Metrics]:
+  ) -> Callable[[BasicTrainState, BatchType, Array], Metrics]:
     def _eval_step(
-        train_state: BasicTrainState, batch: BatchType, rng: jax.random.KeyArray
+        train_state: BasicTrainState, batch: BatchType, rng: Array
     ) -> Metrics:
       """Use model to compute the evaluation metrics."""
       eval_metrics = self.model.eval_fn(
@@ -479,7 +476,7 @@ def _eval_step(
 
     return _eval_step
 
   def preprocess_train_batch(
-      self, batch_data: BatchType, step: int, rng: jax.random.KeyArray
+      self, batch_data: BatchType, step: int, rng: Array
   ) -> BatchType:
     del step, rng
     # Preprocessed batch should always be reshaped for distributed training