From 58b12fad4d23bfcab4717cf31510e29dc815804d Mon Sep 17 00:00:00 2001
From: Zhong Yi Wan
Date: Fri, 17 Nov 2023 09:55:49 -0800
Subject: [PATCH] Code update

PiperOrigin-RevId: 583417368
---
 swirl_dynamics/data/hdf5_utils.py                 | 97 +++++++++++++++++++
 .../{utils_test.py => hdf5_utils_test.py}         | 36 ++++++-
 swirl_dynamics/data/utils.py                      | 62 ------------
 .../projects/ergodic/generate_traj.py             | 14 +--
 swirl_dynamics/projects/ergodic/utils.py          | 13 +--
 .../evolve_smoothly/data_pipelines.py             | 14 +--
 6 files changed, 144 insertions(+), 92 deletions(-)
 create mode 100644 swirl_dynamics/data/hdf5_utils.py
 rename swirl_dynamics/data/{utils_test.py => hdf5_utils_test.py} (50%)
 delete mode 100644 swirl_dynamics/data/utils.py

diff --git a/swirl_dynamics/data/hdf5_utils.py b/swirl_dynamics/data/hdf5_utils.py
new file mode 100644
index 0000000..6bfb0c4
--- /dev/null
+++ b/swirl_dynamics/data/hdf5_utils.py
@@ -0,0 +1,97 @@
+# Copyright 2023 The swirl_dynamics Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Utility functions for HDF5 file reading and writing."""
+
+from collections.abc import Mapping, Sequence
+import io
+from typing import Any
+
+from etils import epath
+import h5py
+import numpy as np
+import tensorflow as tf
+
+gfile = tf.io.gfile
+exists = tf.io.gfile.exists
+
+
+def read_arrays_as_tuple(
+    file_path: epath.PathLike, keys: Sequence[str], dtype: Any = np.float32
+) -> tuple[np.ndarray, ...]:
+  """Reads specified fields from a given HDF5 file as numpy arrays."""
+  if not exists(file_path):
+    raise FileNotFoundError(f"No data file found at {file_path}")
+
+  with gfile.GFile(file_path, "rb") as f:
+    with h5py.File(f, "r") as hf:
+      data = tuple(np.asarray(hf[key], dtype=dtype) for key in keys)
+  return data
+
+
+def read_single_array(
+    file_path: epath.PathLike, key: str, dtype: Any = np.float32
+) -> np.ndarray:
+  """Reads a single field from a given HDF5 file."""
+  return read_arrays_as_tuple(file_path, [key], dtype)[0]
+
+
+def _read_group(
+    group: h5py.Group, array_dtype: Any = np.float32
+) -> Mapping[str, Any]:
+  """Recursively reads an HDF5 group."""
+  out = {}
+  for key in group.keys():
+    if isinstance(group[key], h5py.Group):
+      out[key] = _read_group(group[key], array_dtype)
+    elif isinstance(group[key], h5py.Dataset):
+      if group[key].shape:
+        out[key] = np.asarray(group[key], dtype=array_dtype)
+      else:
+        out[key] = group[key][()]
+    else:
+      raise ValueError(f"Unknown type for key {key}")
+  return out
+
+
+def read_all_arrays_as_dict(
+    file_path: epath.PathLike, array_dtype: Any = np.float32
+) -> Mapping[str, Any]:
+  """Reads the entire contents of a file as a (possibly nested) dictionary."""
+  if not exists(file_path):
+    raise FileNotFoundError(f"No data file found at {file_path}")
+
+  with gfile.GFile(file_path, "rb") as f:
+    with h5py.File(f, "r") as hf:
+      return _read_group(hf, array_dtype)
+
+
+def _save_array_dict(group: h5py.Group, data: Mapping[str, Any]) -> None:
+  """Saves a nested python dictionary to HDF5 groups recursively."""
+  for key, value in data.items():
+    if isinstance(value, dict):
+      subgroup = group.create_group(key)
+      _save_array_dict(subgroup, value)
+    else:
+      group.create_dataset(key, data=value)
+
+
+def save_array_dict(save_path: epath.PathLike, data: Mapping[str, Any]) -> None:
+  """Saves a (possibly nested) dictionary to an HDF5 file."""
+  bio = io.BytesIO()
+  with h5py.File(bio, "w") as f:
+    _save_array_dict(f, data)
+
+  with gfile.GFile(save_path, "w") as f:
+    f.write(bio.getvalue())
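
For orientation, a minimal round-trip sketch of the module added above (illustrative only, not part of the patch; the /tmp path and key names are hypothetical):

    import numpy as np
    from swirl_dynamics.data import hdf5_utils

    # A nested dict of arrays and scalars, mirroring the test cases below.
    data = {"train": {"u": np.ones((4, 8)), "t": np.linspace(0.0, 1.0, 8)}}
    hdf5_utils.save_array_dict("/tmp/example.hdf5", data)

    # Whole-file read: nested groups come back as a nested dict; arrays are
    # cast to float32 by default.
    restored = hdf5_utils.read_all_arrays_as_dict("/tmp/example.hdf5")

    # Field-level reads address nested datasets with "/"-separated keys.
    u = hdf5_utils.read_single_array("/tmp/example.hdf5", "train/u")
    u2, t = hdf5_utils.read_arrays_as_tuple(
        "/tmp/example.hdf5", ("train/u", "train/t")
    )
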
diff --git a/swirl_dynamics/data/utils_test.py b/swirl_dynamics/data/hdf5_utils_test.py
similarity index 50%
rename from swirl_dynamics/data/utils_test.py
rename to swirl_dynamics/data/hdf5_utils_test.py
index c5859ec..a0bdffa 100644
--- a/swirl_dynamics/data/utils_test.py
+++ b/swirl_dynamics/data/hdf5_utils_test.py
@@ -17,24 +17,50 @@
 from absl import flags
 from absl.testing import absltest
 from absl.testing import parameterized
-from swirl_dynamics.data import utils
+from jax import tree_util
+import numpy as np
+from swirl_dynamics.data import hdf5_utils
 
 FLAGS = flags.FLAGS
 
 
-class UtilsTest(parameterized.TestCase):
+class Hdf5UtilsTest(parameterized.TestCase):
+
+  @parameterized.parameters(
+      ({"a": 0, "b": 1},),
+      ({"a": {"b": np.ones((10,)), "c": 2}},),
+      ({"a": {"b": np.ones((10,)), "c": 2.0 * np.ones((3, 3))}, "d": 8},),
+  )
+  def test_save_and_load_whole_dicts(self, test_input):
+    tmp_dir = self.create_tempdir().full_path
+    save_path = os.path.join(tmp_dir, "test.hdf5")
+    hdf5_utils.save_array_dict(save_path, test_input)
+    self.assertTrue(os.path.exists(save_path))
+
+    restored = hdf5_utils.read_all_arrays_as_dict(save_path)
+    self.assertEqual(
+        tree_util.tree_flatten(test_input)[1],
+        tree_util.tree_flatten(restored)[1],
+    )
+    self.assertTrue(
+        np.all(
+            tree_util.tree_flatten(
+                tree_util.tree_map(np.array_equal, test_input, restored)
+            )[0]
+        )
+    )
 
   @parameterized.parameters(
       ({"a": 0, "b": 1}, {"a": 0, "b": 1}),
       ({"a": {"b": 1, "c": 2}}, {"a/b": 1, "a/c": 2}),
   )
-  def test_save_and_load_nparrays_from_hdf5(self, test_input, check_items):
+  def test_save_and_load_nparrays(self, test_input, check_items):
     tmp_dir = self.create_tempdir().full_path
     save_path = os.path.join(tmp_dir, "test.hdf5")
-    utils.save_dict_to_hdf5(save_path, test_input)
+    hdf5_utils.save_array_dict(save_path, test_input)
     self.assertTrue(os.path.exists(save_path))
 
     for key, value in check_items.items():
-      (saved,) = utils.read_nparray_from_hdf5(save_path, key)
+      saved = hdf5_utils.read_single_array(save_path, key)
       self.assertEqual(saved, value)
 
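
The structure-and-values assertion in test_save_and_load_whole_dicts condenses to the following standalone idiom (an illustrative sketch; dicts_equal is a hypothetical helper, not part of the patch):

    from jax import tree_util
    import numpy as np

    def dicts_equal(expected, actual) -> bool:
      # Same nesting structure: the flattened treedefs must match.
      if tree_util.tree_flatten(expected)[1] != tree_util.tree_flatten(actual)[1]:
        return False
      # Same values: map np.array_equal over corresponding leaves and
      # require every pairwise comparison to hold.
      matches = tree_util.tree_flatten(
          tree_util.tree_map(np.array_equal, expected, actual)
      )[0]
      return bool(np.all(matches))
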
diff --git a/swirl_dynamics/data/utils.py b/swirl_dynamics/data/utils.py
deleted file mode 100644
index 46c680e..0000000
--- a/swirl_dynamics/data/utils.py
+++ /dev/null
@@ -1,62 +0,0 @@
-# Copyright 2023 The swirl_dynamics Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Some utility functions for data i/o and processing."""
-
-from collections.abc import Mapping
-import io
-from typing import Any
-
-from etils import epath
-import h5py
-import numpy as np
-import tensorflow as tf
-
-gfile = tf.io.gfile
-exists = tf.io.gfile.exists
-
-
-def read_nparray_from_hdf5(
-    file_path: epath.PathLike, *keys
-) -> tuple[np.ndarray, ...]:
-  """Read specified fields from a given hdf5 file as numpy arrays."""
-  if not exists(file_path):
-    raise FileNotFoundError(f"No data file found at {file_path}")
-
-  with gfile.GFile(file_path, "rb") as f:
-    with h5py.File(f, "r") as hf:
-      data = tuple(np.asarray(hf[key]) for key in keys)
-  return data
-
-
-def _save_dict_to_hdf5(group: h5py.Group, data: Mapping[str, Any]) -> None:
-  """Save fields to hdf5 groups recursively."""
-  for key, value in data.items():
-    if isinstance(value, dict):
-      subgroup = group.create_group(key)
-      _save_dict_to_hdf5(subgroup, value)
-    else:
-      group.create_dataset(key, data=value)
-
-
-def save_dict_to_hdf5(
-    save_path: epath.PathLike, data: Mapping[str, Any]
-) -> None:
-  """Save a (possibly nested) dictionary to hdf5 file."""
-  bio = io.BytesIO()
-  with h5py.File(bio, "w") as f:
-    _save_dict_to_hdf5(f, data)
-
-  with gfile.GFile(save_path, "w") as f:
-    f.write(bio.getvalue())
diff --git a/swirl_dynamics/projects/ergodic/generate_traj.py b/swirl_dynamics/projects/ergodic/generate_traj.py
index 7116790..168f156 100644
--- a/swirl_dynamics/projects/ergodic/generate_traj.py
+++ b/swirl_dynamics/projects/ergodic/generate_traj.py
@@ -27,7 +27,7 @@
 import numpy as np
 from orbax import checkpoint
 import pandas as pd
-from swirl_dynamics.data import utils as data_utils
+from swirl_dynamics.data import hdf5_utils
 from swirl_dynamics.lib.solvers import utils as solver_utils
 from swirl_dynamics.projects.ergodic import choices
 import tensorflow as tf
@@ -157,7 +157,7 @@ def generate_pred_traj(exps_df, all_steps, dt, trajs, mean=None, std=None):
       pt += mean
     del params
     print("Generated.", end=" ")
-    data_utils.save_dict_to_hdf5(traj_file, {"pred_traj": pt})
+    hdf5_utils.save_array_dict(traj_file, {"pred_traj": pt})
    print(f"Saved to file {traj_file}.")
     cnt += 1
 
@@ -168,10 +168,8 @@ def main(argv):
   exp_dir = FLAGS.exp_dir
   exps_df = parse_dir(exp_dir)
   dataset_path = exps_df["dataset_path"].unique().tolist()[0]
-  trajs, tspan = data_utils.read_nparray_from_hdf5(
-      dataset_path,
-      "test/u",
-      "test/t",
+  trajs, tspan = hdf5_utils.read_arrays_as_tuple(
+      dataset_path, ("test/u", "test/t")
   )
   all_steps = trajs.shape[1]
   dt = jnp.mean(jnp.diff(tspan, axis=1))
@@ -185,9 +183,7 @@ def main(argv):
   trajs = trajs[:, :, ::spatial_downsample, ::spatial_downsample, :]
   print("Spatial resolution:", trajs.shape[2:-1])
 
-  train_snapshots = data_utils.read_nparray_from_hdf5(dataset_path, "train/u")[
-      0
-  ]
+  train_snapshots = hdf5_utils.read_single_array(dataset_path, "train/u")
   mean = jnp.mean(train_snapshots, axis=(0, 1))
   std = jnp.std(train_snapshots, axis=(0, 1))
   del train_snapshots
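
The call-site updates above show the migration pattern in one place: the variadic read_nparray_from_hdf5(path, *keys) becomes read_arrays_as_tuple(path, keys) with an explicit key sequence, and the read-one-field-then-[0] pattern becomes read_single_array. Side by side, using lines taken from the hunks above (shown outside the patch for contrast only):

    # Before: variadic keys; single fields needed [0]-indexing on the tuple.
    trajs, tspan = data_utils.read_nparray_from_hdf5(
        dataset_path, "test/u", "test/t"
    )
    train_snapshots = data_utils.read_nparray_from_hdf5(dataset_path, "train/u")[0]

    # After: keys travel as one explicit sequence; a dedicated helper
    # covers the single-field case.
    trajs, tspan = hdf5_utils.read_arrays_as_tuple(
        dataset_path, ("test/u", "test/t")
    )
    train_snapshots = hdf5_utils.read_single_array(dataset_path, "train/u")
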
diff --git a/swirl_dynamics/projects/ergodic/utils.py b/swirl_dynamics/projects/ergodic/utils.py
index ff42f2e..f6210d5 100644
--- a/swirl_dynamics/projects/ergodic/utils.py
+++ b/swirl_dynamics/projects/ergodic/utils.py
@@ -21,8 +21,8 @@
 import jax
 import jax.numpy as jnp
 import matplotlib.pyplot as plt
+from swirl_dynamics.data import hdf5_utils
 from swirl_dynamics.data import tfgrain_transforms as transforms
-from swirl_dynamics.data import utils as data_utils
 from swirl_dynamics.lib.solvers import ode
 import tensorflow as tf
 
@@ -90,10 +90,8 @@ def create_loader_from_hdf5(
 
     mean and std stats (if normalize=True, else dict contains NoneType values).
   """
-  snapshots, tspan = data_utils.read_nparray_from_hdf5(
-      dataset_path,
-      f"{split}/u",
-      f"{split}/t",
+  snapshots, tspan = hdf5_utils.read_arrays_as_tuple(
+      dataset_path, (f"{split}/u", f"{split}/t")
   )
   if spatial_downsample_factor > 1:
     if snapshots.ndim == 3:
@@ -116,10 +114,7 @@ def create_loader_from_hdf5(
     std = normalize_stats["std"]
   else:
     if split != "train":
-      data_for_stats = data_utils.read_nparray_from_hdf5(
-          dataset_path,
-          "train/u",
-      )
+      data_for_stats = hdf5_utils.read_single_array(dataset_path, "train/u")
     else:
       data_for_stats = snapshots
     mean = jnp.mean(data_for_stats, axis=(0, 1))
diff --git a/swirl_dynamics/projects/evolve_smoothly/data_pipelines.py b/swirl_dynamics/projects/evolve_smoothly/data_pipelines.py
index fe9b675..419b541 100644
--- a/swirl_dynamics/projects/evolve_smoothly/data_pipelines.py
+++ b/swirl_dynamics/projects/evolve_smoothly/data_pipelines.py
@@ -16,8 +16,8 @@
 
 import grain.tensorflow as tfgrain
 import numpy as np
+from swirl_dynamics.data import hdf5_utils
 from swirl_dynamics.data import tfgrain_transforms as transforms
-from swirl_dynamics.data import utils as data_utils
 import tensorflow as tf
 
 _DEFAULT_LINEAR_RESCALE = transforms.LinearRescale(
@@ -57,8 +57,8 @@ def create_batch_decode_pipeline(
   Returns:
     A TfGrain dataloader.
   """
-  snapshots, grid = data_utils.read_nparray_from_hdf5(
-      hdf5_file_path, snapshot_field, grid_field
+  snapshots, grid = hdf5_utils.read_arrays_as_tuple(
+      hdf5_file_path, (snapshot_field, grid_field)
   )
   snapshots = np.reshape(snapshots, (-1,) + snapshots.shape[-2:])
   # select a subset of snapshots to train
@@ -119,8 +119,8 @@ def create_encode_decode_pipeline(
   Returns:
     A TfGrain dataloader.
   """
-  snapshots, grid = data_utils.read_nparray_from_hdf5(
-      hdf5_file_path, snapshot_field, grid_field
+  snapshots, grid = hdf5_utils.read_arrays_as_tuple(
+      hdf5_file_path, (snapshot_field, grid_field)
   )
   snapshots = np.reshape(snapshots, (-1,) + snapshots.shape[-2:])
   # select a subset of snapshots to train if applicable
@@ -190,8 +190,8 @@ def create_latent_dynamics_pipeline(
   Returns:
     A TfGrain dataloader.
   """
-  snapshots, tspan, grid = data_utils.read_nparray_from_hdf5(
-      hdf5_file_path, snapshot_field, tspan_field, grid_field
+  snapshots, tspan, grid = hdf5_utils.read_arrays_as_tuple(
+      hdf5_file_path, (snapshot_field, tspan_field, grid_field)
   )
   source = tfgrain.TfInMemoryDataSource.from_dataset(
       tf.data.Dataset.from_tensor_slices({