- Clone the repository:
# Clone the repository. git needs the repository URL itself; a GitHub web
# page URL such as ".../tree/main" is not a clonable remote.
git clone https://github.com/pnnl/JAX-CanVeg.git
# Enter the checkout so the relative paths used in the following steps
# (environment.yml, ./src/...) resolve from the repository root.
cd JAX-CanVeg
- From the repository root, create the conda virtual environment:
# Build the conda environment described by environment.yml in the current directory.
conda env create --file environment.yml
- Activate the virtual environment:
# Activate the environment; run all of the following pip/python3 commands inside it.
conda activate jax-canveg
- Install JAX, either the CPU-only build:
# CPU-only JAX.
pip install -U "jax[cpu]"
or a GPU-enabled build (run only the command matching your CUDA version):
# GPU-enabled JAX: run exactly ONE of the two commands below, matching the
# CUDA major version available on the machine. These wheels exist only for Linux.
# NOTE(review): recent JAX releases install CUDA 12 support straight from
# PyPI with `pip install -U "jax[cuda12]"` and no longer publish CUDA 11
# wheels; the `*_pip` extras and the find-links index below target older
# JAX versions. Confirm against the JAX version this project expects.

# CUDA 12
pip install -U "jax[cuda12_pip]" \
  --find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

# CUDA 11
pip install -U "jax[cuda11_pip]" \
  --find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
- Install the remaining packages with pip:
# Additional dependencies, installed with the environment's pip.
# Use the same `pip` as every other step (the original used `pip3` for the
# torch line), so all packages go through one installer/interpreter.
# NOTE(review): the build step runs `python3 -m pybind11`, but pybind11 is not
# installed here -- presumably environment.yml provides it; confirm.
pip install equinox diffrax optax pre-commit optimistix lineax hydroeval pyproj
pip install -U scikit-learn
pip install torch torchvision torchaudio
- Compile the C++ code that generates the dispersion matrix into a Python extension module with pybind11 (make sure pybind11 and a suitable C++ compiler are installed):
# Build the pybind11 extension in the directory that holds DispersionMatrix.cpp.
# Run only the compiler command for your platform. The `pybind11 --includes`
# substitution is intentionally unquoted: it expands to several -I flags.
cd src/jax_canveg/physics/energy_fluxes/

# Linux
g++ -O3 -Wall -shared -std=c++11 -ftemplate-depth=2048 -fPIC \
  $(python3 -m pybind11 --includes) \
  DispersionMatrix.cpp \
  -o dispersion$(python3-config --extension-suffix)

# macOS: -undefined dynamic_lookup defers resolving Python symbols until import
c++ -O3 -Wall -shared -std=c++11 -ftemplate-depth=2048 -undefined dynamic_lookup \
  $(python3 -m pybind11 --includes) \
  DispersionMatrix.cpp \
  -o dispersion$(python3-config --extension-suffix)
- Add the source directory `src` to the `PYTHONPATH` environment variable (e.g. `export PYTHONPATH="/path/to/JAX-CanVeg/src:$PYTHONPATH"`).