Skip to content

Commit

Permalink
Merge pull request #165 from invrs-io/cb-fallback
Browse files Browse the repository at this point in the history
Use version-specific callback
  • Loading branch information
mfschubert authored Jan 17, 2025
2 parents b477cca + cdcc3bf commit 6d5fd90
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 4 deletions.
2 changes: 1 addition & 1 deletion .bumpversion.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[tool.bumpversion]
current_version = "v1.4.7"
current_version = "v1.4.8"
commit = true
commit_args = "--no-verify"
tag = true
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[project]

name = "invrs_gym"
version = "v1.4.7"
version = "v1.4.8"
description = "A collection of inverse design challenges"
keywords = ["topology", "optimization", "jax", "inverse design"]
readme = "README.md"
Expand Down
2 changes: 1 addition & 1 deletion src/invrs_gym/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
Copyright (c) 2023 The INVRS-IO authors.
"""

__version__ = "v1.4.7"
__version__ = "v1.4.8"
__author__ = "Martin F. Schubert <[email protected]>"

from invrs_gym import challenges as challenges
Expand Down
9 changes: 8 additions & 1 deletion src/invrs_gym/utils/materials.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import functools
import pathlib
import warnings
from packaging import version
from typing import Dict, Protocol, Union

import jax
Expand All @@ -24,6 +25,12 @@
VACUUM = "vacuum" # Permittivity is computed via dedicated function.


if version.Version(jax.__version__) > version.Version("0.4.31"):
callback = functools.partial(jax.pure_callback, vmap_method="broadcast_all")
else:
callback = functools.partial(jax.pure_callback, vectorized=True)


def permittivity(
material: str | ri.RefractiveIndexMaterial,
wavelength_um: jnp.ndarray,
Expand Down Expand Up @@ -74,7 +81,7 @@ def _refractive_index_fn(wavelength_um: jnp.ndarray) -> onp.ndarray:

dtype = jnp.promote_types(wavelength_um.dtype, jnp.complex64)
result_shape_dtypes = jnp.zeros_like(wavelength_um, dtype=dtype)
refractive_index = jax.pure_callback(
refractive_index = callback(
_refractive_index_fn, result_shape_dtypes, wavelength_um
)
return (refractive_index + 1j * background_extinction_coeff) ** 2
Expand Down

0 comments on commit 6d5fd90

Please sign in to comment.