diff --git a/pyproject.toml b/pyproject.toml index d48a92525..d1014be4d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -60,7 +60,7 @@ Changelog = "https://github.com/ott-jax/ott/releases" [project.optional-dependencies] neural = [ "flax>=0.6.6", - "optax>=0.1.1", + "optax>=0.2.4", "diffrax>=0.4.1", ] dev = [