On the `gen3d` branch, we are running into issues upgrading JAX to 0.4.31. We want to upgrade because, per the top answer here, starting with JAX 0.4.31 the function `jax.lax.map` supports a `batch_size` argument, and we would like to pass this `batch_size` parameter at this line. One way to test that this works is to ensure that all the tests in `tests/gen3d/` pass after adding `batch_size=1000` at the indicated line. More specifically, if this line runs without failing, we should be good to go. [In this blob it looks like we commented this test out (oops!), but it should be uncommented and work.]
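For context, here is a minimal sketch of what the change amounts to. It assumes JAX 0.4.31 or later for the `batch_size` path; `per_point` and `batched_map` are hypothetical names, not the actual gen3d code, and `batch_size=1000` is the value proposed above. On older JAX the helper falls back to hand-rolled chunking (each full chunk vectorized with `jax.vmap` and the chunks walked sequentially by `jax.lax.map`, with leftover rows vmapped separately), which is our approximation of the batching behavior, not a copy of JAX's implementation.

```python
import inspect

import jax
import jax.numpy as jnp


def per_point(x):
    # Hypothetical stand-in for the per-element function mapped in gen3d:
    # reduces one row of shape (3,) to a scalar.
    return jnp.sum(jnp.sin(x) ** 2)


def batched_map(f, xs, batch_size):
    """Map f over the leading axis of xs, evaluating batch_size rows at a time.

    Hypothetical helper. Assumes xs is a non-empty array and f returns a single array.
    """
    if "batch_size" in inspect.signature(jax.lax.map).parameters:
        # JAX >= 0.4.31: lax.map accepts batch_size and handles chunking itself.
        return jax.lax.map(f, xs, batch_size=batch_size)

    # Older JAX: emulate the chunking by hand. Full chunks are processed
    # sequentially by lax.map, each chunk vectorized with vmap.
    n = xs.shape[0]
    num_batches, remainder = divmod(n, batch_size)
    pieces = []
    if num_batches > 0:
        full = num_batches * batch_size
        chunks = xs[:full].reshape((num_batches, batch_size) + xs.shape[1:])
        ys = jax.lax.map(jax.vmap(f), chunks)
        # Flatten (num_batches, batch_size, ...) back to (full, ...).
        pieces.append(ys.reshape((full,) + ys.shape[2:]))
    if remainder > 0:
        # Leftover rows that do not fill a whole chunk go through vmap directly.
        pieces.append(jax.vmap(f)(xs[num_batches * batch_size:]))
    return jnp.concatenate(pieces, axis=0)


xs = jnp.linspace(0.0, 1.0, 2500 * 3).reshape(2500, 3)
ys = batched_map(per_point, xs, batch_size=1000)
print(ys.shape)  # (2500,)
```

With 2500 rows and `batch_size=1000`, this runs two vectorized chunks of 1000 rows plus a remainder of 500, which bounds peak memory compared with a single `jax.vmap` over all rows.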
When we change this line to request `jaxlib ==0.4.31` and run `pixi install`, we get:
```
(gpu) georgematheos@pixi-vm-2:~/b3d$ pixi install
WARN Defined custom mapping channel https://conda.anaconda.org/conda-forge/ is missing from project channels
  × failed to solve the conda requirements of 'gpu' 'linux-64'
  ╰─▶ Cannot solve the request because of: The following packages are incompatible
      ├─ pytorch ==2.3.0 cuda12* can be installed with any of the following options:
      │  └─ pytorch 2.3.0 | 2.3.0 | 2.3.0 | 2.3.0 | 2.3.0 | 2.3.0 | 2.3.0 | 2.3.0 | 2.3.0 | 2.3.0 | 2.3.0 would require
      │     └─ cudnn >=8.9.7.29,<9.0a0, which can be installed with any of the following options:
      │        └─ cudnn 8.9.7.29
      └─ jaxlib ==0.4.31 cuda12* cannot be installed because there are no viable options:
         └─ jaxlib 0.4.31 | 0.4.31 | 0.4.31 would require
            └─ cudnn >=9.2.1.18,<10.0a0, which cannot be installed because there are no viable options:
               └─ cudnn 9.2.1.18 | 9.2.1.18, which conflicts with the versions reported above.
```
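The solver's complaint reduces to two cudnn ranges that do not intersect: every `pytorch ==2.3.0` cuda12 build caps cudnn below 9.0, while `jaxlib ==0.4.31` cuda12 needs at least 9.2.1.18. The toy check below makes that explicit. Version strings are hand-converted to integer tuples, and the `<9.0a0` pre-release bound is approximated as "below 9"; this is a simplification for illustration, not how conda or pixi actually compare versions.

```python
# Each constraint is a half-open interval [lower, upper) of cudnn versions,
# written as integer tuples. Python compares tuples element by element, and a
# shorter prefix sorts first, so (9,) < (9, 0, 0, 0): the bound (9,) admits
# every 8.x version and excludes every 9.x version.
PYTORCH_2_3_0_CUDNN = ((8, 9, 7, 29), (9,))   # cudnn >=8.9.7.29,<9.0a0
JAXLIB_0_4_31_CUDNN = ((9, 2, 1, 18), (10,))  # cudnn >=9.2.1.18,<10.0a0


def intervals_overlap(a, b):
    """True if some version satisfies both half-open [lower, upper) constraints."""
    lower = max(a[0], b[0])
    upper = min(a[1], b[1])
    return lower < upper


print(intervals_overlap(PYTORCH_2_3_0_CUDNN, JAXLIB_0_4_31_CUDNN))  # False
```

Since no cudnn version satisfies both ranges, the environment cannot be solved as long as both pins stay as they are; one of the two has to move.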
@eightysteele