Skip to content

Commit

Permalink
Code update
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 666896095
  • Loading branch information
zhong1wan authored and The swirl_dynamics Authors committed Aug 23, 2024
1 parent 456cab2 commit 25a1a23
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions swirl_dynamics/templates/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ def compute(self) -> jax.Array:


def TensorRatio( # pylint: disable=invalid-name
axis: int | tuple[int, ...] | None = None
axis: int | tuple[int, ...] | None = None,
):
"""Computes the ratio between two aggregated metrics.
Expand Down Expand Up @@ -462,7 +462,9 @@ def evaluate(
rng = jax.random.fold_in(self.rng, self.state.step)
batch_agg_update = {}
for key, inf_fn in self._compiled_inf_fns.items():
inference_rng = jax.random.fold_in(rng, hash(key))
inference_rng = jax.random.fold_in(
rng, np.int32(hash(key) % (2**31 - 1)) # Prevent overflows.
)
pred = inf_fn(batch, inference_rng)
batch_collect, batch_res = self._compiled_metrics_compute(pred, batch)
collected.collect_batch_result(key, batch_collect)
Expand Down

0 comments on commit 25a1a23

Please sign in to comment.