Skip to content

Commit

Permalink
Merge pull request #164 from invrs-io/fix-callback
Browse files Browse the repository at this point in the history
Avoid type promotion when computing permittivity
  • Loading branch information
mfschubert authored Jan 17, 2025
2 parents cc74921 + dfc0df0 commit b477cca
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 16 deletions.
2 changes: 1 addition & 1 deletion .bumpversion.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[tool.bumpversion]
current_version = "v1.4.6"
current_version = "v1.4.7"
commit = true
commit_args = "--no-verify"
tag = true
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[project]

name = "invrs_gym"
version = "v1.4.6"
version = "v1.4.7"
description = "A collection of inverse design challenges"
keywords = ["topology", "optimization", "jax", "inverse design"]
readme = "README.md"
Expand Down
2 changes: 1 addition & 1 deletion src/invrs_gym/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
Copyright (c) 2023 The INVRS-IO authors.
"""

__version__ = "v1.4.6"
__version__ = "v1.4.7"
__author__ = "Martin F. Schubert <[email protected]>"

from invrs_gym import challenges as challenges
Expand Down
28 changes: 15 additions & 13 deletions src/invrs_gym/utils/materials.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,23 +61,19 @@ def permittivity_from_database(
background_extinction_coeff: float,
) -> jnp.ndarray:
"""Return the permittivity for the specified material from the database."""
is_x64 = jax.config.read("jax_enable_x64")

def _refractive_index_fn(wavelength_um: jnp.ndarray) -> onp.ndarray:
numpy_wavelength_um = onp.asarray(wavelength_um)
dtype = onp.promote_types(wavelength_um.dtype, onp.complex64)
try:
epsilon = material.get_epsilon(
wavelength_um=onp.asarray(wavelength_um),
)
epsilon = material.get_epsilon(numpy_wavelength_um)
refractive_index = onp.sqrt(epsilon)
except ri.refractiveindex.NoExtinctionCoefficient:
refractive_index = material.get_refractive_index(
wavelength_um=onp.asarray(wavelength_um),
)
return onp.asarray(
refractive_index, dtype=(onp.complex128 if is_x64 else onp.complex64)
)
refractive_index = material.get_refractive_index(numpy_wavelength_um)
return onp.asarray(refractive_index, dtype=dtype)

result_shape_dtypes = jnp.zeros_like(wavelength_um, dtype=complex)
dtype = jnp.promote_types(wavelength_um.dtype, jnp.complex64)
result_shape_dtypes = jnp.zeros_like(wavelength_um, dtype=dtype)
refractive_index = jax.pure_callback(
_refractive_index_fn, result_shape_dtypes, wavelength_um
)
Expand All @@ -89,7 +85,10 @@ def permittivity_vacuum(
background_extinction_coeff: float = 0.0,
) -> jnp.ndarray:
"""Return the permittivity of vacuum, with optional background extinction coeff."""
return jnp.full(wavelength_um.shape, 1.0 + 1j * background_extinction_coeff)
dtype = jnp.promote_types(wavelength_um.dtype, jnp.complex64)
return jnp.full(
wavelength_um.shape, 1.0 + 1j * background_extinction_coeff, dtype=dtype
)


class PermittivityFn(Protocol):
Expand Down Expand Up @@ -118,9 +117,12 @@ def _permittivity_fn(
wavelength_um: jnp.ndarray,
background_extinction_coeff: float,
) -> jnp.ndarray:
dtype = jnp.promote_types(wavelength_um.dtype, jnp.complex64)
refractive_index = jnp.sqrt(
jnp.interp(wavelength_um, data_wavelength_um, data_permittivity)
)
return (refractive_index + 1j * background_extinction_coeff) ** 2
return jnp.asarray(
(refractive_index + 1j * background_extinction_coeff) ** 2, dtype=dtype
)

PERMITTIVITY_FNS[name] = _permittivity_fn

0 comments on commit b477cca

Please sign in to comment.