diff --git a/swirl_dynamics/templates/evaluate.py b/swirl_dynamics/templates/evaluate.py index 1f25076..99c376a 100644 --- a/swirl_dynamics/templates/evaluate.py +++ b/swirl_dynamics/templates/evaluate.py @@ -182,7 +182,7 @@ def compute(self) -> jax.Array: def TensorRatio( # pylint: disable=invalid-name - axis: int | tuple[int, ...] | None = None + axis: int | tuple[int, ...] | None = None, ): """Computes the ratio between two aggregated metrics. @@ -462,7 +462,9 @@ def evaluate( rng = jax.random.fold_in(self.rng, self.state.step) batch_agg_update = {} for key, inf_fn in self._compiled_inf_fns.items(): - inference_rng = jax.random.fold_in(rng, hash(key)) + inference_rng = jax.random.fold_in( + rng, np.int32(hash(key) % (2**31 - 1)) # Prevent overflows. + ) pred = inf_fn(batch, inference_rng) batch_collect, batch_res = self._compiled_metrics_compute(pred, batch) collected.collect_batch_result(key, batch_collect)