Upgrade JAX to 0.4.31 to access jax.lax.map(batch_size=X) #188

Open
georgematheos opened this issue Sep 18, 2024 · 0 comments
@georgematheos
Collaborator

@eightysteele

On the gen3d branch, we are running into issues upgrading JAX to 0.4.31. We want this upgrade because, per the top answer here, starting with JAX 0.4.31 the function jax.lax.map supports a batch_size argument, and we would like to pass batch_size at this line. One way to test that this works is to add batch_size=1000 at the indicated line and confirm that all the tests in tests/gen3d/ pass. More specifically, if this line runs without failing, we should be good to go. [In this blob it looks like we commented this out -- oops! But this test should be uncommented and work.]

When we change this line to request jaxlib ==0.4.31 and run pixi install, we get:

(gpu) georgematheos@pixi-vm-2:~/b3d$ pixi install
 WARN Defined custom mapping channel https://conda.anaconda.org/conda-forge/ is missing from project channels
  × failed to solve the conda requirements of 'gpu' 'linux-64'
  ╰─▶ Cannot solve the request because of: The following packages are incompatible
      ├─ pytorch ==2.3.0 cuda12* can be installed with any of the following options:
      │  └─ pytorch 2.3.0 | 2.3.0 | 2.3.0 | 2.3.0 | 2.3.0 | 2.3.0 | 2.3.0 | 2.3.0 | 2.3.0 | 2.3.0 | 2.3.0 would require
      │     └─ cudnn >=8.9.7.29,<9.0a0, which can be installed with any of the following options:
      │        └─ cudnn 8.9.7.29
      └─ jaxlib ==0.4.31 cuda12* cannot be installed because there are no viable options:
         └─ jaxlib 0.4.31 | 0.4.31 | 0.4.31 would require
            └─ cudnn >=9.2.1.18,<10.0a0, which cannot be installed because there are no viable options:
               └─ cudnn 9.2.1.18 | 9.2.1.18, which conflicts with the versions reported above.
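The solver's conclusion can be checked by hand. pytorch 2.3.0 (cuda12) needs cudnn in [8.9.7.29, 9.0a0), while jaxlib 0.4.31 (cuda12) needs cudnn in [9.2.1.18, 10.0a0), and these two ranges do not overlap. Below is a minimal stdlib sketch of that check, using the bounds copied from the log above. It is a simplification, not how conda's solver works: the helpers parse_version and intersect are hypothetical, and pre-release suffixes such as "a0" are dropped.

```python
import re


def parse_version(v: str) -> tuple[int, ...]:
    # "8.9.7.29" -> (8, 9, 7, 29). Pre-release suffixes such as the "a0" in
    # "9.0a0" are dropped, so "<9.0a0" is treated as "<9.0" (a simplification).
    return tuple(int(re.match(r"\d+", part).group()) for part in v.split("."))


def intersect(a, b):
    # Intersect two half-open version ranges [lower, upper); None if empty.
    lo = max(a[0], b[0])
    hi = min(a[1], b[1])
    return (lo, hi) if lo < hi else None


# cudnn ranges as reported in the pixi solver log: [lower, upper)
pytorch_needs = (parse_version("8.9.7.29"), parse_version("9.0a0"))
jaxlib_needs = (parse_version("9.2.1.18"), parse_version("10.0a0"))

overlap = intersect(pytorch_needs, jaxlib_needs)
print(overlap)  # None: no cudnn version satisfies both packages
```

Because the overlap is empty, no cudnn build can satisfy both pins at once. At least one of the two constraints has to change before the environment can be solved.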