diff --git a/.bumpversion.toml b/.bumpversion.toml index 8e03e5d..9678091 100644 --- a/.bumpversion.toml +++ b/.bumpversion.toml @@ -1,5 +1,5 @@ [tool.bumpversion] -current_version = "v1.4.7" +current_version = "v1.4.8" commit = true commit_args = "--no-verify" tag = true diff --git a/pyproject.toml b/pyproject.toml index fbcfa70..5737f44 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [project] name = "invrs_gym" -version = "v1.4.7" +version = "v1.4.8" description = "A collection of inverse design challenges" keywords = ["topology", "optimization", "jax", "inverse design"] readme = "README.md" diff --git a/src/invrs_gym/__init__.py b/src/invrs_gym/__init__.py index 26ed250..6dfbb79 100644 --- a/src/invrs_gym/__init__.py +++ b/src/invrs_gym/__init__.py @@ -3,7 +3,7 @@ Copyright (c) 2023 The INVRS-IO authors. """ -__version__ = "v1.4.7" +__version__ = "v1.4.8" __author__ = "Martin F. Schubert <mfschubert@gmail.com>" from invrs_gym import challenges as challenges diff --git a/src/invrs_gym/utils/materials.py b/src/invrs_gym/utils/materials.py index cae2f21..eb52b13 100644 --- a/src/invrs_gym/utils/materials.py +++ b/src/invrs_gym/utils/materials.py @@ -6,6 +6,7 @@ import functools import pathlib import warnings +from packaging import version from typing import Dict, Protocol, Union import jax @@ -24,6 +25,12 @@ VACUUM = "vacuum" # Permittivity is computed via dedicated function. +if version.Version(jax.__version__) > version.Version("0.4.31"): + callback = functools.partial(jax.pure_callback, vmap_method="broadcast_all") +else: + callback = functools.partial(jax.pure_callback, vectorized=True) + + def permittivity( material: str | ri.RefractiveIndexMaterial, wavelength_um: jnp.ndarray, @@ -74,7 +81,7 @@ def _refractive_index_fn(wavelength_um: jnp.ndarray) -> onp.ndarray: dtype = jnp.promote_types(wavelength_um.dtype, jnp.complex64) result_shape_dtypes = jnp.zeros_like(wavelength_um, dtype=dtype) - refractive_index = jax.pure_callback( + refractive_index = callback( _refractive_index_fn, result_shape_dtypes, wavelength_um ) return (refractive_index + 1j * background_extinction_coeff) ** 2