Skip to content

Commit

Permalink
Merge pull request #65 from invrs-io/x64
Browse files Browse the repository at this point in the history
ensure x64 types
  • Loading branch information
mfschubert authored Dec 22, 2024
2 parents 18dd393 + 07f6883 commit c89c92d
Showing 1 changed file with 5 additions and 0 deletions.
5 changes: 5 additions & 0 deletions src/invrs_leaderboard/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import jax
import jax.numpy as jnp
import numpy as onp
from jax import tree_util
from invrs_gym import challenges
from invrs_gym.challenges.bayer import challenge as bayer_challenge
from invrs_gym.challenges.diffract import metagrating_challenge, splitter_challenge
Expand Down Expand Up @@ -140,6 +141,9 @@ def evaluate_solutions_to_challenge(
with jax.default_device(jax.devices("cpu")[0]):

def evaluation_fn(params):
params = tree_util.tree_map(
lambda x: x.astype(jnp.promote_types(x.dtype, jnp.float64)), params
)
response, aux = challenge.component.response(params)
metrics = challenge.metrics(response=response, params=params, aux=aux)
eval_metric = challenge.eval_metric(response)
Expand All @@ -150,6 +154,7 @@ def evaluation_fn(params):

for solution_path, solution in solutions.items():
eval_metric, other_metrics = evaluation_fn(params=solution)
assert eval_metric.dtype == jnp.float64
minimum_width, minimum_spacing = compute_length_scale(solution)
results = {
"path": solution_path,
Expand Down

0 comments on commit c89c92d

Please sign in to comment.