lenstronomy/JAXtronomy


JAXtronomy


JAX port of lenstronomy, for parallelized, GPU accelerated, and differentiable gravitational lensing and image simulations.

Disclaimer: This project is still in an early development phase and serves as a skeleton for someone taking the lead on it :)

The goal of this library is to reimplement lenstronomy functionality in pure JAX to allow for automatic differentiation, GPU acceleration, and batched computations.

Guiding Principles:

  • Strive to be a drop-in replacement for lenstronomy, i.e., provide a close match to the lenstronomy API.
  • Each function/feature will be tested against the reference lenstronomy implementation.
  • This package will aim to be a subset of lenstronomy (i.e., it only contains functions with a reference lenstronomy implementation).
  • Implementations should be easy to read and understand.
  • Code should be pip-installable on any machine, with no compilation required.
  • Any notable differences between the JAX and reference implementations will be clearly documented.

Performance comparison between jaxtronomy and lenstronomy

We compare the runtimes of jaxtronomy and lenstronomy by timing 10,000 function executions. These tests were run on a single CPU; we expect the speedups to be even larger on a GPU. A notebook for the runtime comparisons is provided.
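The comparison notebook is not reproduced here. As a hedged illustration of the main pitfalls when timing JAX code, the sketch below excludes the one-time compilation cost and synchronizes every call with `block_until_ready`, since JAX dispatches work asynchronously. `time_per_call` is a hypothetical helper, and the notebook may measure runtimes differently.

```python
import time

import jax
import jax.numpy as jnp


def _block(out):
    """Wait until every array in a (possibly nested) output is computed."""
    jax.tree_util.tree_map(lambda a: a.block_until_ready(), out)


def time_per_call(fn, *args, n_runs=10_000):
    """Average wall-clock seconds per call of a jitted function (hypothetical helper)."""
    jitted = jax.jit(fn)
    _block(jitted(*args))  # first call triggers XLA compilation; keep it out of the timing
    start = time.perf_counter()
    for _ in range(n_runs):
        _block(jitted(*args))  # synchronize, otherwise we only time the dispatch
    return (time.perf_counter() - start) / n_runs


grid = jnp.ones((60, 60))
print(time_per_call(lambda g: jnp.exp(-g**2), grid, n_runs=100))
```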

LensModel ray-shooting

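The table below shows how much faster jaxtronomy is than lenstronomy for different deflector profiles and grid sizes, expressed as the lenstronomy runtime divided by the jaxtronomy runtime; values below 1x mean that jaxtronomy is slower.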

Deflector profile | 60x60 grid | 180x180 grid
CONVERGENCE | 1.5x | 6.9x
CSE | 5.0x | 5.7x
EPL | 12.0x | 16.5x
EPL (jax) vs EPL_NUMBA | 1.3x | 2.0x
EPL_Q_PHI | 0.3x | 0.4x
GAUSSIAN | 2.1x | 3.3x
GAUSSIAN_POTENTIAL | 1.9x | 3.1x
HERNQUIST | 1.6x | 2.7x
HERNQUIST_ELLIPSE_CSE | 4.7x | 5.4x
LOS | 2.9x | 7.1x
LOS_MINIMAL | 2.8x | 7.3x
NFW | 2.0x | 4.0x
NFW_ELLIPSE_CSE | 5.4x | 7.0x
NIE | 1.4x | 1.9x
PJAFFE | 1.7x | 1.7x
PJAFFE_ELLIPSE_POTENTIAL | 3.1x | 3.1x
SHEAR | 2.1x | 5.0x
SIE | 1.2x | 1.9x
SIS | 3.7x | 4.3x
SPP | 1.4x | 2.3x

Note that the runtimes of some profiles depend on the function arguments. For example, the EPL profile evaluates the hypergeometric function hyp2f1 using a power series expansion. In lenstronomy, the number of terms used depends on how quickly the series converges, whereas in jaxtronomy the power series always uses a fixed number of terms, which is required for automatic differentiation. In the table, the EPL comparison shows the case where lenstronomy requires the maximum number of terms, and the EPL_Q_PHI comparison shows the case where lenstronomy requires the fewest.
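The sketch below illustrates the fixed-term idea on the Gauss hypergeometric series 2F1(a, b; c; z) = sum over n of (a)_n (b)_n / (c)_n * z^n / n!. It is not jaxtronomy's EPL code: `hyp2f1_fixed` and the choice of 50 terms are hypothetical, and the series only converges for |z| < 1. With a static trip count, `jax.lax.fori_loop` lowers to a scan, so `jax.grad` can differentiate through it in reverse mode; a convergence-based `jax.lax.while_loop` would not support that.

```python
import jax
import jax.numpy as jnp

N_TERMS = 50  # hypothetical fixed truncation length


def hyp2f1_fixed(a, b, c, z, n_terms=N_TERMS):
    """Gauss hypergeometric series 2F1(a, b; c; z), truncated at a fixed
    number of terms. Only valid for |z| < 1 (hypothetical helper)."""
    z = jnp.asarray(z, dtype=jnp.float32)

    def body(n, carry):
        term, total = carry
        # Ratio of consecutive terms: t_{n+1} / t_n = (a+n)(b+n) z / ((c+n)(n+1))
        term = term * (a + n) * (b + n) / ((c + n) * (n + 1)) * z
        return term, total + term

    init = (jnp.ones_like(z), jnp.ones_like(z))  # t_0 = 1 and running sum = 1
    # Static loop bounds: same number of iterations for every input.
    _, total = jax.lax.fori_loop(0, n_terms - 1, body, init)
    return total


z = 0.3
print(hyp2f1_fixed(1.0, 1.0, 2.0, z))  # ≈ 1.1889, i.e. -log(1 - z) / z
print(jax.grad(hyp2f1_fixed, argnums=3)(1.0, 1.0, 2.0, z))  # derivative with respect to z
```

The special case 2F1(1, 1; 2; z) = -log(1 - z) / z gives a closed form to check both the value and the gradient against.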

LightModel surface brightness

The table below shows the corresponding speedups (lenstronomy runtime divided by jaxtronomy runtime) for different source profiles and grid sizes.

Source profile | 60x60 grid | 180x180 grid
CORE_SERSIC | 4.4x | 15x
GAUSSIAN | 3.0x | 10x
GAUSSIAN_ELLIPSE | 2.7x | 15x
MULTI_GAUSSIAN (5 components) | 8.0x | 18x
MULTI_GAUSSIAN_ELLIPSE (5 components) | 8.0x | 20x
SERSIC | 3.4x | 10x
SERSIC_ELLIPSE | 3.5x | 10x
SERSIC_ELLIPSE_Q_PHI | 4.0x | 10x
SHAPELETS (n_max=6) | 8.7x | 3.6x
SHAPELETS (n_max=10) | 10.0x | 4.0x

Image Convolution

JAX does not provide a Gaussian convolution function. Instead, jaxtronomy constructs a Gaussian pixel kernel, pads the image, and performs an FFT convolution that mimics scipy.ndimage.gaussian_filter with mode="nearest".
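A minimal sketch of this approach follows, assuming `jax.scipy.signal.fftconvolve` for the FFT step and the same kernel-radius rule as scipy.ndimage.gaussian_filter (radius = int(truncate * sigma + 0.5), with truncate=4.0 by default). The function names are hypothetical, and jaxtronomy's own implementation may differ in details. Edge padding (NumPy's mode="edge") reproduces scipy's mode="nearest", and a "valid" convolution then trims the padding off again.

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax.scipy.signal import fftconvolve


def gaussian_kernel_1d(sigma, truncate=4.0):
    """Normalized 1D Gaussian sampled on integer pixels, using the same
    radius rule as scipy.ndimage.gaussian_filter: int(truncate * sigma + 0.5)."""
    radius = int(truncate * sigma + 0.5)
    x = jnp.arange(-radius, radius + 1, dtype=jnp.float32)
    kernel = jnp.exp(-0.5 * (x / sigma) ** 2)
    return kernel / jnp.sum(kernel), radius


@partial(jax.jit, static_argnames=("sigma", "truncate"))
def gaussian_blur_nearest(image, sigma, truncate=4.0):
    """Gaussian blur with 'nearest' boundary handling: edge padding + FFT convolution."""
    k1d, radius = gaussian_kernel_1d(sigma, truncate)
    kernel = jnp.outer(k1d, k1d)  # separable Gaussian -> 2D pixel kernel
    padded = jnp.pad(image, radius, mode="edge")  # NumPy "edge" == scipy "nearest"
    # "valid" drops the padded border, so the output has the input's shape.
    return fftconvolve(padded, kernel, mode="valid")


image = jnp.zeros((60, 60)).at[30, 30].set(1.0)
blurred = gaussian_blur_nearest(image, sigma=2.0)
```

Here `sigma` is a static argument because it determines the kernel shape, and JAX needs array shapes to be known at compile time.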

When the kernel radius is less than 10, a Gaussian convolution takes jaxtronomy about 1.1x to 1.3x as long as lenstronomy; for larger kernels, it takes jaxtronomy 2x as long or more. For pixel-kernel FFT convolution, performance varies significantly with kernel size and grid size. Further details can be found in the performance comparison notebook.
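For comparison, a pixel-kernel FFT convolution in JAX can be sketched as a single call, assuming `jax.scipy.signal.fftconvolve` with mode="same" (pixels outside the image treated as zero). `pixel_kernel_convolve` is hypothetical. The FFT length along each axis grows with image size plus kernel size, and under `jax.jit` every new input shape triggers a fresh compilation, which is one reason the runtime depends on both sizes.

```python
import jax
import jax.numpy as jnp
from jax.scipy.signal import fftconvolve


@jax.jit
def pixel_kernel_convolve(image, kernel):
    """Convolve an image with an arbitrary pixel kernel (e.g. a PSF) via FFT.

    mode="same" keeps the image shape and zero-pads outside the image.
    The kernel is assumed to be normalized to unit sum and to have odd side lengths.
    """
    return fftconvolve(image, kernel, mode="same")


image = jnp.ones((180, 180))
for size in (5, 21, 61):
    kernel = jnp.ones((size, size)) / size**2  # box kernel as a stand-in PSF
    blurred = pixel_kernel_convolve(image, kernel)  # new kernel shape -> new compilation
```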

Related software packages

The following lensing software packages use JAX-accelerated computing and were in part inspired by, or make use of, lenstronomy functions:
