diff --git a/chex/_src/fake.py b/chex/_src/fake.py index a6fdc74..7d3d622 100644 --- a/chex/_src/fake.py +++ b/chex/_src/fake.py @@ -69,7 +69,7 @@ def set_n_cpu_devices(n: Optional[int] = None) -> None: n = n or FLAGS['chex_n_cpu_devices'].value n_devices = get_n_cpu_devices_from_xla_flags() - cpu_backend = (jax.lib.xla_bridge._backends or {}).get('cpu', None) # pylint: disable=protected-access + cpu_backend = (jax._src.xla_bridge._backends or {}).get('cpu', None) # pylint: disable=protected-access if cpu_backend is not None and n_devices != n: raise RuntimeError( f'Attempted to set {n} devices, but {n_devices} CPUs already available:'