diff --git a/.github/workflows/benchmarks.yaml b/.github/workflows/benchmarks.yaml new file mode 100644 index 00000000..67ec821b --- /dev/null +++ b/.github/workflows/benchmarks.yaml @@ -0,0 +1,45 @@ +name: benchmarks + +env: + PY_COLORS: "1" + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +on: + push: + branches: + - main + pull_request: + workflow_dispatch: null + +jobs: + build: + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + python-version: ["3.11",] + + steps: + - uses: actions/checkout@v3 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v3 + with: + python-version: ${{ matrix.python-version }} + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + python -m pip install pytest pytest-codspeed + python -m pip install . + + - name: Run benchmarks + uses: CodSpeedHQ/action@v2 + with: + token: ${{ secrets.CODSPEED_TOKEN }} + run: | + git submodule update --init --recursive + pytest -vvs --codspeed diff --git a/.github/workflows/python_package.yaml b/.github/workflows/python_package.yaml index badaf56f..dcab764e 100644 --- a/.github/workflows/python_package.yaml +++ b/.github/workflows/python_package.yaml @@ -33,7 +33,7 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - python -m pip install pytest + python -m pip install pytest pytest-codspeed python -m pip install . - name: Test with pytest diff --git a/README.md b/README.md index 128a4851..b40f747c 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ **JAX port of GalSim, for parallelized, GPU accelerated, and differentiable galaxy image simulations.** -[![Contributor Covenant](https://img.shields.io/badge/Contributor%20Covenant-2.1-4baaaa.svg)](code_of_conduct.md) [![Python package](https://github.com/GalSim-developers/JAX-GalSim/actions/workflows/python_package.yaml/badge.svg)](https://github.com/GalSim-developers/JAX-GalSim/actions/workflows/python_package.yaml) [![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) [![pre-commit.ci status](https://results.pre-commit.ci/badge/github/GalSim-developers/JAX-GalSim/main.svg)](https://results.pre-commit.ci/latest/github/GalSim-developers/JAX-GalSim/main) +[![Contributor Covenant](https://img.shields.io/badge/Contributor%20Covenant-2.1-4baaaa.svg)](code_of_conduct.md) [![Python package](https://github.com/GalSim-developers/JAX-GalSim/actions/workflows/python_package.yaml/badge.svg)](https://github.com/GalSim-developers/JAX-GalSim/actions/workflows/python_package.yaml) [![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) [![pre-commit.ci status](https://results.pre-commit.ci/badge/github/GalSim-developers/JAX-GalSim/main.svg)](https://results.pre-commit.ci/latest/github/GalSim-developers/JAX-GalSim/main) [![CodSpeed Badge](https://img.shields.io/endpoint?url=https://codspeed.io/badge.json)](https://codspeed.io/GalSim-developers/JAX-GalSim) **Disclaimer**: This project is still in an early development phase, **please use the [reference GalSim implementation](https://github.com/GalSim-developers/GalSim) for any scientific applications.** diff --git a/setup.py b/setup.py index 0e59be8f..f35018ce 100644 --- a/setup.py +++ b/setup.py @@ -16,5 +16,5 @@ "astropy >= 2.0", "tensorflow-probability >= 0.21.0", ], - tests_require=["pytest"], + tests_require=["pytest", "pytest-codspeed"], ) diff --git a/tests/jax/test_benchmarks.py b/tests/jax/test_benchmarks.py new file mode 100644 index 00000000..7702e58b --- /dev/null +++ b/tests/jax/test_benchmarks.py @@ -0,0 +1,35 @@ +import jax + +import jax_galsim as jgs + + +def test_benchmarks_interpolated_image_jit_compile(benchmark): + gal = jgs.Gaussian(fwhm=1.2) + im_gal = gal.drawImage(nx=32, ny=32, scale=0.2) + igal = jgs.InterpolatedImage( + im_gal, gsparams=jgs.GSParams(minimum_fft_size=128, maximum_fft_size=128) + ) + + def f(): + return igal.drawImage(nx=32, ny=32, scale=0.2) + + benchmark(lambda: jax.jit(f)().array.block_until_ready()) + + +def test_benchmarks_interpolated_image_jit_run(benchmark): + gal = jgs.Gaussian(fwhm=1.2) + im_gal = gal.drawImage(nx=32, ny=32, scale=0.2) + igal = jgs.InterpolatedImage( + im_gal, gsparams=jgs.GSParams(minimum_fft_size=128, maximum_fft_size=128) + ) + + def f(): + return igal.drawImage(nx=32, ny=32, scale=0.2) + + jitf = jax.jit(f) + + # run once to compile + jitf().array.block_until_ready() + + # now benchmark + benchmark(lambda: jitf().array.block_until_ready())