Skip to content

Commit

Permalink
test: add benchmarks for interpolated images (#100)
Browse files Browse the repository at this point in the history
* test: add benchmarks for interpolated images

* fix: blacken

* ci: add codspeed config

* fix: add pytest-codspeed to the setup.py

* test: need to list the dep here

* ci: bump the CI for codspeed

* doc: add codspeed badge

* test: make sure to block until ready so that benchmarks work on devices

* fix: block at proper part of code for benchmarks
  • Loading branch information
beckermr authored Jun 18, 2024
1 parent 018f7f3 commit b8348bd
Show file tree
Hide file tree
Showing 5 changed files with 83 additions and 3 deletions.
45 changes: 45 additions & 0 deletions .github/workflows/benchmarks.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
name: benchmarks

env:
PY_COLORS: "1"

concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true

on:
push:
branches:
- main
pull_request:
workflow_dispatch: null

jobs:
build:
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
python-version: ["3.11",]

steps:
- uses: actions/checkout@v3

- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v3
with:
python-version: ${{ matrix.python-version }}

- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install pytest pytest-codspeed
python -m pip install .
- name: Run benchmarks
uses: CodSpeedHQ/action@v2
with:
token: ${{ secrets.CODSPEED_TOKEN }}
run: |
git submodule update --init --recursive
pytest -vvs --codspeed
2 changes: 1 addition & 1 deletion .github/workflows/python_package.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install pytest
python -m pip install pytest pytest-codspeed
python -m pip install .
- name: Test with pytest
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

**JAX port of GalSim, for parallelized, GPU accelerated, and differentiable galaxy image simulations.**

[![Contributor Covenant](https://img.shields.io/badge/Contributor%20Covenant-2.1-4baaaa.svg)](code_of_conduct.md) [![Python package](https://github.com/GalSim-developers/JAX-GalSim/actions/workflows/python_package.yaml/badge.svg)](https://github.com/GalSim-developers/JAX-GalSim/actions/workflows/python_package.yaml) [![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) [![pre-commit.ci status](https://results.pre-commit.ci/badge/github/GalSim-developers/JAX-GalSim/main.svg)](https://results.pre-commit.ci/latest/github/GalSim-developers/JAX-GalSim/main)
[![Contributor Covenant](https://img.shields.io/badge/Contributor%20Covenant-2.1-4baaaa.svg)](code_of_conduct.md) [![Python package](https://github.com/GalSim-developers/JAX-GalSim/actions/workflows/python_package.yaml/badge.svg)](https://github.com/GalSim-developers/JAX-GalSim/actions/workflows/python_package.yaml) [![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) [![pre-commit.ci status](https://results.pre-commit.ci/badge/github/GalSim-developers/JAX-GalSim/main.svg)](https://results.pre-commit.ci/latest/github/GalSim-developers/JAX-GalSim/main) [![CodSpeed Badge](https://img.shields.io/endpoint?url=https://codspeed.io/badge.json)](https://codspeed.io/GalSim-developers/JAX-GalSim)

**Disclaimer**: This project is still in an early development phase, **please use the [reference GalSim implementation](https://github.com/GalSim-developers/GalSim) for any scientific applications.**

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,5 @@
"astropy >= 2.0",
"tensorflow-probability >= 0.21.0",
],
tests_require=["pytest"],
tests_require=["pytest", "pytest-codspeed"],
)
35 changes: 35 additions & 0 deletions tests/jax/test_benchmarks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import jax

import jax_galsim as jgs


def test_benchmarks_interpolated_image_jit_compile(benchmark):
gal = jgs.Gaussian(fwhm=1.2)
im_gal = gal.drawImage(nx=32, ny=32, scale=0.2)
igal = jgs.InterpolatedImage(
im_gal, gsparams=jgs.GSParams(minimum_fft_size=128, maximum_fft_size=128)
)

def f():
return igal.drawImage(nx=32, ny=32, scale=0.2)

benchmark(lambda: jax.jit(f)().array.block_until_ready())


def test_benchmarks_interpolated_image_jit_run(benchmark):
gal = jgs.Gaussian(fwhm=1.2)
im_gal = gal.drawImage(nx=32, ny=32, scale=0.2)
igal = jgs.InterpolatedImage(
im_gal, gsparams=jgs.GSParams(minimum_fft_size=128, maximum_fft_size=128)
)

def f():
return igal.drawImage(nx=32, ny=32, scale=0.2)

jitf = jax.jit(f)

# run once to compile
jitf().array.block_until_ready()

# now benchmark
benchmark(lambda: jitf().array.block_until_ready())

0 comments on commit b8348bd

Please sign in to comment.