From 238aa7a26bc00882837234a29138749c85b5703b Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Mon, 18 Nov 2024 15:25:44 -0800 Subject: [PATCH] Avoid relying on jax.lib.xla_bridge._backends, as this is private and will soon be deleted There is no public API for this, but this change switches to a private API that is not yet planned for removal. We'd encourage the code authors to not rely on private APIs if at all possible. PiperOrigin-RevId: 697777022 --- chex/_src/fake.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/chex/_src/fake.py b/chex/_src/fake.py index a6fdc74..7d3d622 100644 --- a/chex/_src/fake.py +++ b/chex/_src/fake.py @@ -69,7 +69,7 @@ def set_n_cpu_devices(n: Optional[int] = None) -> None: n = n or FLAGS['chex_n_cpu_devices'].value n_devices = get_n_cpu_devices_from_xla_flags() - cpu_backend = (jax.lib.xla_bridge._backends or {}).get('cpu', None) # pylint: disable=protected-access + cpu_backend = (jax._src.xla_bridge._backends or {}).get('cpu', None) # pylint: disable=protected-access if cpu_backend is not None and n_devices != n: raise RuntimeError( f'Attempted to set {n} devices, but {n_devices} CPUs already available:'