
Commit 58b12fa
Code update
PiperOrigin-RevId: 583417368
zhong1wan authored and The swirl_dynamics Authors committed Nov 17, 2023
1 parent 61afc04 commit 58b12fa
Showing 6 changed files with 144 additions and 92 deletions.
97 changes: 97 additions & 0 deletions swirl_dynamics/data/hdf5_utils.py
@@ -0,0 +1,97 @@
# Copyright 2023 The swirl_dynamics Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utility functions for hdf5 file reading and writing."""

from collections.abc import Mapping, Sequence
import io
from typing import Any

from etils import epath
import h5py
import numpy as np
import tensorflow as tf

gfile = tf.io.gfile
exists = tf.io.gfile.exists


def read_arrays_as_tuple(
    file_path: epath.PathLike, keys: Sequence[str], dtype: Any = np.float32
) -> tuple[np.ndarray, ...]:
  """Reads specified fields from a given file as numpy arrays."""
  if not exists(file_path):
    raise FileNotFoundError(f"No data file found at {file_path}")

  with gfile.GFile(file_path, "rb") as f:
    with h5py.File(f, "r") as hf:
      data = tuple(np.asarray(hf[key], dtype=dtype) for key in keys)
  return data


def read_single_array(
    file_path: epath.PathLike, key: str, dtype: Any = np.float32
) -> np.ndarray:
  """Reads a single field from a given file."""
  return read_arrays_as_tuple(file_path, [key], dtype)[0]


def _read_group(
    group: h5py.Group, array_dtype: Any = np.float32
) -> Mapping[str, Any]:
  """Recursively reads a hdf5 group."""
  out = {}
  for key in group.keys():
    if isinstance(group[key], h5py.Group):
      # Recurse into subgroups, propagating the requested array dtype.
      out[key] = _read_group(group[key], array_dtype)
    elif isinstance(group[key], h5py.Dataset):
      if group[key].shape:
        out[key] = np.asarray(group[key], dtype=array_dtype)
      else:
        # Scalar (zero-dimensional) datasets are returned as-is, uncast.
        out[key] = group[key][()]
    else:
      raise ValueError(f"Unknown type for key {key}")
  return out


def read_all_arrays_as_dict(
    file_path: epath.PathLike, array_dtype: Any = np.float32
) -> Mapping[str, Any]:
  """Reads the entire contents of a file as a (possibly nested) dictionary."""
  if not exists(file_path):
    raise FileNotFoundError(f"No data file found at {file_path}")

  with gfile.GFile(file_path, "rb") as f:
    with h5py.File(f, "r") as hf:
      return _read_group(hf, array_dtype)


def _save_array_dict(group: h5py.Group, data: Mapping[str, Any]) -> None:
  """Saves a nested python dictionary to hdf5 groups recursively."""
  for key, value in data.items():
    if isinstance(value, dict):
      subgroup = group.create_group(key)
      _save_array_dict(subgroup, value)
    else:
      group.create_dataset(key, data=value)


def save_array_dict(save_path: epath.PathLike, data: Mapping[str, Any]) -> None:
  """Saves a dictionary (possibly nested) to hdf5 file."""
  # Serialize to an in-memory buffer first, then write the bytes out in one
  # shot; h5py never touches the (possibly non-local) target path directly.
  bio = io.BytesIO()
  with h5py.File(bio, "w") as f:
    _save_array_dict(f, data)

  with gfile.GFile(save_path, "wb") as f:
    f.write(bio.getvalue())
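
A minimal round-trip sketch of the new module (the temp path, array shapes, and
field names below are hypothetical, chosen only for illustration):

import numpy as np
from swirl_dynamics.data import hdf5_utils

data = {"train": {"u": np.ones((4, 8)), "t": np.linspace(0.0, 1.0, 8)}}
hdf5_utils.save_array_dict("/tmp/example.hdf5", data)

# Read selected fields back; nested keys are addressed with "/" paths.
u, t = hdf5_utils.read_arrays_as_tuple(
    "/tmp/example.hdf5", ("train/u", "train/t")
)
u_only = hdf5_utils.read_single_array("/tmp/example.hdf5", "train/u")
restored = hdf5_utils.read_all_arrays_as_dict("/tmp/example.hdf5")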
36 changes: 31 additions & 5 deletions swirl_dynamics/data/hdf5_utils_test.py
@@ -17,24 +17,50 @@
from absl import flags
from absl.testing import absltest
from absl.testing import parameterized
-from swirl_dynamics.data import utils
+from jax import tree_util
+import numpy as np
+from swirl_dynamics.data import hdf5_utils

FLAGS = flags.FLAGS


-class UtilsTest(parameterized.TestCase):
+class Hdf5UtilsTest(parameterized.TestCase):

+  @parameterized.parameters(
+      ({"a": 0, "b": 1},),
+      ({"a": {"b": np.ones((10,)), "c": 2}},),
+      ({"a": {"b": np.ones((10,)), "c": 2.0 * np.ones((3, 3))}, "d": 8},),
+  )
+  def test_save_and_load_whole_dicts(self, test_input):
+    tmp_dir = self.create_tempdir().full_path
+    save_path = os.path.join(tmp_dir, "test.hdf5")
+    hdf5_utils.save_array_dict(save_path, test_input)
+    self.assertTrue(os.path.exists(save_path))
+
+    restored = hdf5_utils.read_all_arrays_as_dict(save_path)
+    # Matching treedefs means the nested dict structure survived the round trip.
+    self.assertEqual(
+        tree_util.tree_flatten(test_input)[1],
+        tree_util.tree_flatten(restored)[1],
+    )
+    # Leaf-wise equality: every saved array/scalar is restored unchanged.
+    self.assertTrue(
+        np.all(
+            tree_util.tree_flatten(
+                tree_util.tree_map(np.array_equal, test_input, restored)
+            )[0]
+        )
+    )

  @parameterized.parameters(
      ({"a": 0, "b": 1}, {"a": 0, "b": 1}),
      ({"a": {"b": 1, "c": 2}}, {"a/b": 1, "a/c": 2}),
  )
-  def test_save_and_load_nparrays_from_hdf5(self, test_input, check_items):
+  def test_save_and_load_nparrays(self, test_input, check_items):
    tmp_dir = self.create_tempdir().full_path
    save_path = os.path.join(tmp_dir, "test.hdf5")
-    utils.save_dict_to_hdf5(save_path, test_input)
+    hdf5_utils.save_array_dict(save_path, test_input)
    self.assertTrue(os.path.exists(save_path))
    for key, value in check_items.items():
-      (saved,) = utils.read_nparray_from_hdf5(save_path, key)
+      saved = hdf5_utils.read_single_array(save_path, key)
      self.assertEqual(saved, value)


62 changes: 0 additions & 62 deletions swirl_dynamics/data/utils.py

This file was deleted.
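
For reference, the call-site migration pattern applied throughout this commit
(a sketch; `path` and the field names are placeholders, not part of the diff):

# Before (swirl_dynamics.data.utils, deleted here):
#   utils.save_dict_to_hdf5(path, data)
#   u, t = utils.read_nparray_from_hdf5(path, "train/u", "train/t")
# After (swirl_dynamics.data.hdf5_utils):
hdf5_utils.save_array_dict(path, data)
u, t = hdf5_utils.read_arrays_as_tuple(path, ("train/u", "train/t"))
u = hdf5_utils.read_single_array(path, "train/u")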

14 changes: 5 additions & 9 deletions swirl_dynamics/projects/ergodic/generate_traj.py
@@ -27,7 +27,7 @@
import numpy as np
from orbax import checkpoint
import pandas as pd
-from swirl_dynamics.data import utils as data_utils
+from swirl_dynamics.data import hdf5_utils
from swirl_dynamics.lib.solvers import utils as solver_utils
from swirl_dynamics.projects.ergodic import choices
import tensorflow as tf
@@ -157,7 +157,7 @@ def generate_pred_traj(exps_df, all_steps, dt, trajs, mean=None, std=None):
pt += mean
del params
print("Generated.", end=" ")
-data_utils.save_dict_to_hdf5(traj_file, {"pred_traj": pt})
+hdf5_utils.save_array_dict(traj_file, {"pred_traj": pt})
print(f"Saved to file {traj_file}.")
cnt += 1

@@ -168,10 +168,8 @@ def main(argv):
exp_dir = FLAGS.exp_dir
exps_df = parse_dir(exp_dir)
dataset_path = exps_df["dataset_path"].unique().tolist()[0]
-trajs, tspan = data_utils.read_nparray_from_hdf5(
-    dataset_path,
-    "test/u",
-    "test/t",
+trajs, tspan = hdf5_utils.read_arrays_as_tuple(
+    dataset_path, ("test/u", "test/t")
)
all_steps = trajs.shape[1]
dt = jnp.mean(jnp.diff(tspan, axis=1))
@@ -185,9 +183,7 @@
trajs = trajs[:, :, ::spatial_downsample, ::spatial_downsample, :]
print("Spatial resolution:", trajs.shape[2:-1])

-train_snapshots = data_utils.read_nparray_from_hdf5(dataset_path, "train/u")[
-    0
-]
+train_snapshots = hdf5_utils.read_single_array(dataset_path, "train/u")
mean = jnp.mean(train_snapshots, axis=(0, 1))
std = jnp.std(train_snapshots, axis=(0, 1))
del train_snapshots
13 changes: 4 additions & 9 deletions swirl_dynamics/projects/ergodic/utils.py
@@ -21,8 +21,8 @@
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
+from swirl_dynamics.data import hdf5_utils
from swirl_dynamics.data import tfgrain_transforms as transforms
-from swirl_dynamics.data import utils as data_utils
from swirl_dynamics.lib.solvers import ode
import tensorflow as tf

@@ -90,10 +90,8 @@ def create_loader_from_hdf5(
mean and std stats (if normalize=True; otherwise the dict
contains NoneType values).
"""
-snapshots, tspan = data_utils.read_nparray_from_hdf5(
-    dataset_path,
-    f"{split}/u",
-    f"{split}/t",
+snapshots, tspan = hdf5_utils.read_arrays_as_tuple(
+    dataset_path, (f"{split}/u", f"{split}/t")
)
if spatial_downsample_factor > 1:
if snapshots.ndim == 3:
@@ -116,10 +114,7 @@
std = normalize_stats["std"]
else:
if split != "train":
-data_for_stats = data_utils.read_nparray_from_hdf5(
-    dataset_path,
-    "train/u",
-)
+data_for_stats = hdf5_utils.read_single_array(dataset_path, "train/u")
else:
data_for_stats = snapshots
mean = jnp.mean(data_for_stats, axis=(0, 1))
14 changes: 7 additions & 7 deletions swirl_dynamics/projects/evolve_smoothly/data_pipelines.py
@@ -16,8 +16,8 @@

import grain.tensorflow as tfgrain
import numpy as np
+from swirl_dynamics.data import hdf5_utils
from swirl_dynamics.data import tfgrain_transforms as transforms
-from swirl_dynamics.data import utils as data_utils
import tensorflow as tf

_DEFAULT_LINEAR_RESCALE = transforms.LinearRescale(
@@ -57,8 +57,8 @@ def create_batch_decode_pipeline(
Returns:
A TfGrain dataloader.
"""
-snapshots, grid = data_utils.read_nparray_from_hdf5(
-    hdf5_file_path, snapshot_field, grid_field
+snapshots, grid = hdf5_utils.read_arrays_as_tuple(
+    hdf5_file_path, (snapshot_field, grid_field)
)
snapshots = np.reshape(snapshots, (-1,) + snapshots.shape[-2:])
# select a subset of snapshots to train
@@ -119,8 +119,8 @@ def create_encode_decode_pipeline(
Returns:
A TfGrain dataloader.
"""
-snapshots, grid = data_utils.read_nparray_from_hdf5(
-    hdf5_file_path, snapshot_field, grid_field
+snapshots, grid = hdf5_utils.read_arrays_as_tuple(
+    hdf5_file_path, (snapshot_field, grid_field)
)
snapshots = np.reshape(snapshots, (-1,) + snapshots.shape[-2:])
# select a subset of snapshots to train if applicable
@@ -190,8 +190,8 @@ def create_latent_dynamics_pipeline(
Returns:
A TfGrain dataloader.
"""
-snapshots, tspan, grid = data_utils.read_nparray_from_hdf5(
-    hdf5_file_path, snapshot_field, tspan_field, grid_field
+snapshots, tspan, grid = hdf5_utils.read_arrays_as_tuple(
+    hdf5_file_path, (snapshot_field, tspan_field, grid_field)
)
source = tfgrain.TfInMemoryDataSource.from_dataset(
tf.data.Dataset.from_tensor_slices({
