diff --git a/.bumpversion.toml b/.bumpversion.toml
index 8e03e5d..9678091 100644
--- a/.bumpversion.toml
+++ b/.bumpversion.toml
@@ -1,5 +1,5 @@
 [tool.bumpversion]
-current_version = "v1.4.7"
+current_version = "v1.4.8"
 commit = true
 commit_args = "--no-verify"
 tag = true
diff --git a/pyproject.toml b/pyproject.toml
index fbcfa70..5737f44 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,7 +1,7 @@
 [project]
 
 name = "invrs_gym"
-version = "v1.4.7"
+version = "v1.4.8"
 description = "A collection of inverse design challenges"
 keywords = ["topology", "optimization", "jax", "inverse design"]
 readme = "README.md"
diff --git a/src/invrs_gym/__init__.py b/src/invrs_gym/__init__.py
index 26ed250..6dfbb79 100644
--- a/src/invrs_gym/__init__.py
+++ b/src/invrs_gym/__init__.py
@@ -3,7 +3,7 @@
 Copyright (c) 2023 The INVRS-IO authors.
 """
 
-__version__ = "v1.4.7"
+__version__ = "v1.4.8"
 __author__ = "Martin F. Schubert <mfschubert@gmail.com>"
 
 from invrs_gym import challenges as challenges
diff --git a/src/invrs_gym/utils/materials.py b/src/invrs_gym/utils/materials.py
index cae2f21..eb52b13 100644
--- a/src/invrs_gym/utils/materials.py
+++ b/src/invrs_gym/utils/materials.py
@@ -6,6 +6,7 @@
 import functools
 import pathlib
 import warnings
+from packaging import version
 from typing import Dict, Protocol, Union
 
 import jax
@@ -24,6 +25,12 @@
 VACUUM = "vacuum"  # Permittivity is computed via dedicated function.
 
 
+if version.Version(jax.__version__) > version.Version("0.4.31"):
+    callback = functools.partial(jax.pure_callback, vmap_method="broadcast_all")
+else:
+    callback = functools.partial(jax.pure_callback, vectorized=True)
+
+
 def permittivity(
     material: str | ri.RefractiveIndexMaterial,
     wavelength_um: jnp.ndarray,
@@ -74,7 +81,7 @@ def _refractive_index_fn(wavelength_um: jnp.ndarray) -> onp.ndarray:
 
     dtype = jnp.promote_types(wavelength_um.dtype, jnp.complex64)
     result_shape_dtypes = jnp.zeros_like(wavelength_um, dtype=dtype)
-    refractive_index = jax.pure_callback(
+    refractive_index = callback(
         _refractive_index_fn, result_shape_dtypes, wavelength_um
     )
     return (refractive_index + 1j * background_extinction_coeff) ** 2