Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Redo python package #24

Merged
merged 10 commits into from
Feb 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 0 additions & 10 deletions .bumpversion.cfg

This file was deleted.

22 changes: 22 additions & 0 deletions .bumpversion.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
[tool.bumpversion]
current_version = "v0.3.0"
commit = true
commit_args = "--no-verify"
tag = true
tag_name = "{new_version}"
allow_dirty = true
parse = "(?P<major>\\d+)\\.(?P<minor>\\d+)\\.(?P<patch>\\d+)(\\.(?P<dev>dev)\\d+\\+[-_a-zA-Z0-9]+)?"
serialize = [
"v{major}.{minor}.{patch}.{dev}{distance_to_latest_tag}+{short_branch_name}",
"v{major}.{minor}.{patch}"
]
message = "Version updated from {current_version} to {new_version}"

[[tool.bumpversion.files]]
filename = "pyproject.toml"

[[tool.bumpversion.files]]
filename = "README.md"

[[tool.bumpversion.files]]
filename = "src/agjax/__init__.py"
84 changes: 49 additions & 35 deletions .github/workflows/test_code.yml → .github/workflows/build-ci.yml
Original file line number Diff line number Diff line change
@@ -1,83 +1,97 @@
name: Test pre-commit, code and docs
name: CI

on:
pull_request:
push:
branches:
- main
schedule:
- cron: "0 13 * * 1" # Every Monday at 9AM EST

jobs:
pre-commit:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Checkout repository
uses: actions/checkout@v4

- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: "3.10"
cache: "pip"
cache-dependency-path: pyproject.toml

- name: Test pre-commit hooks
run: |
python -m pip install --upgrade pip
pip install pre-commit
pre-commit run -a
test_code:
needs: [pre-commit]
runs-on: ${{ matrix.os }}
strategy:
max-parallel: 12
matrix:
python-version: ["3.10"]
os: [ubuntu-latest, windows-latest, macos-latest]

validate-types-and-docstrings:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
- name: Checkout repository
uses: actions/checkout@v4

- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
python-version: "3.10"
cache: "pip"
cache-dependency-path: pyproject.toml
- name: Install dependencies

- name: Setup environment
run: |
python -m pip install --upgrade pip
pip install ".[tests,dev]"

- name: mypy type validation
run: |
pip install -e .[tests]
- name: Test with pytest
run: pytest
test_code_coverage:
mypy src

- name: darglint docstring validation
run: |
darglint src --strictness=short --ignore-raise=ValueError

tests:
runs-on: ubuntu-latest
needs: [pre-commit]
steps:
- uses: actions/checkout@v4
- name: Checkout repository
uses: actions/checkout@v4

- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: '3.10'
- name: Install dependencies
python-version: "3.10"
cache: "pip"
cache-dependency-path: pyproject.toml

- name: Setup environment
run: |
pip install -e .[tests]
- name: Test with pytest
python -m pip install --upgrade pip
pip install ".[tests,dev]"

- name: Run Python tests
run: |
pytest --cov=agjax tests
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v4
with:
token: ${{ secrets.CODECOV_TOKEN }}
fail_ci_if_error: false

test_docs:
runs-on: ubuntu-latest
needs: [pre-commit]

steps:
- uses: actions/checkout@v4
- name: Checkout repository
uses: actions/checkout@v4

- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: '3.10'
cache: "pip"
cache-dependency-path: pyproject.toml

- name: Install dependencies
run: |
make dev
run: make dev

- name: Test documentation
run: |
make docs
run: make docs
14 changes: 9 additions & 5 deletions .github/workflows/pages.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,24 +11,28 @@ jobs:
runs-on: ubuntu-latest
name: Sphinx docs to gh-pages
steps:
- uses: actions/checkout@v4
- name: Checkout repository
uses: actions/checkout@v4

- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: '3.10'
cache: "pip"
cache-dependency-path: pyproject.toml

- name: Installing the library
shell: bash -l {0}
run: |
make dev
run: make dev

- name: make docs
run: |
make docs
run: make docs

- name: Upload artifact
uses: actions/upload-pages-artifact@v3
with:
path: "./docs/_build/html/"

deploy-docs:
needs: build-docs
permissions:
Expand Down
31 changes: 14 additions & 17 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: "9260cbc9c84c06022993bfbcc42fdbf0305c5b8e"
rev: v4.4.0
hooks:
- id: check-added-large-files
- id: check-case-conflict
Expand All @@ -14,44 +14,41 @@ repos:
args: ["--pytest-test-first"]
- id: trailing-whitespace

# Black formatting
- repo: https://github.com/psf/black
rev: "d9b8a6407e2f46304a8d36b18e4a73d8e0613519"
rev: 23.9.1
hooks:
- id: black

# Upgrade python syntax
- repo: https://github.com/asottile/pyupgrade
rev: ddb39ad37166dbc938d853cc77606526a0b1622a
rev: v3.10.1
hooks:
- id: pyupgrade
args: [--py37-plus, --keep-runtime-typing]

# Validate shell scripts
- repo: https://github.com/shellcheck-py/shellcheck-py
rev: 953faa6870f6663ac0121ab4a800f1ce76bca31f
rev: v0.9.0.5
hooks:
- id: shellcheck

# Security linter
- repo: https://github.com/PyCQA/bandit
rev: fe1361fdcc274850d4099885a802f2c9f28aca08
rev: 1.7.5
hooks:
- id: bandit
args: [--exit-zero]
# ignore all tests, not just tests data
exclude: ^tests/

- repo: https://github.com/pre-commit/mirrors-mypy
rev: "v1.0.1"
hooks:
- id: mypy
exclude: ^(docs/|example-plugin/|tests/|fixtures/)
additional_dependencies:
- "pydantic"

- repo: https://github.com/charliermarsh/ruff-pre-commit
rev: "6a0ba1854991b693612486cc84a2254de82d071d"
# Linter
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.0.287
hooks:
- id: ruff

# Strip whitespace from notebooks.
- repo: https://github.com/kynan/nbstripout
rev: 0.3.9
rev: 0.6.0
hooks:
- id: nbstripout
3 changes: 2 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ pylint:
pylint agjax

ruff:
ruff --fix agjax/*.py
ruff --fix src/agjax/*.py
ruff --fix src/agjax/experimental/*.py

git-rm-merged:
git branch -D `git branch --merged | grep -v \* | xargs`
Expand Down
34 changes: 27 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
# Overview v0.2.0
Agjax is a jax wrapper for autograd-differentiable functions. It allows existing code built with autograd to be used with the jax framework. In particular, agjax allows an arbitrary autograd function to be differentiated using `jax.grad`. Several other function transformations (e.g. compilation via `jax.jit`) are not supported.
# Agjax -- jax wrapper for autograd-differentiable functions.
`v0.3.0`

Agjax allows existing code built with autograd to be used with the jax framework.

In particular, `agjax.wrap_for_jax` allows arbitrary autograd functions ot be differentiated using `jax.grad`. Several other function transformations (e.g. compilation via `jax.jit`) are not supported.

Meanwhile, `agjax.experimental.wrap_for_jax` supports `grad`, `jit`, `vmap`, and `jacrev`. However, it depends on certain under-the-hood behavior by jax, which is not guaranteed to remain unchanged. It also is more restrictive in terms of the valid function signatures of functions to be wrapped: all arguments and outputs must be convertible to valid jax types. (`agjax.wrap_for_jax` also supports non-jax inputs and outputs, e.g. strings.)

## Installation
```
Expand All @@ -13,14 +19,28 @@ Basic usage is as follows:
def fn(x, y):
return x * npa.cos(y)

grad = jax.grad(fn, argnums=(0, 1))(1.0, 0.0)
print(f"grad = {grad}")
```
jax.grad(fn, argnums=(0, 1))(1.0, 0.0)

# (Array(1., dtype=float32), Array(0., dtype=float32))
```
grad = (Array(1., dtype=float32), Array(0., dtype=float32))

The experimental wrapper is similar, but requires that the function outputs and datatypes be specified, simiilar to `jax.pure_callback`.
```python
wrapped_fn = agjax.experimental.wrap_for_jax(
lambda x, y: x * npa.cos(y),
result_shape_dtypes=jnp.ones((5,)),
)

jax.jacrev(wrapped_fn, argnums=0)(jnp.arange(5, dtype=float), jnp.arange(5, 10, dtype=float))

# [[ 0.28366217 0. 0. 0. 0. ]
# [ 0. 0.96017027 0. 0. 0. ]
# [ 0. 0. 0.75390226 0. 0. ]
# [ 0. 0. 0. -0.14550003 0. ]
# [ 0. 0. 0. 0. -0.91113025]]
```

Agjax is intended to be quite general, and can support functions with multiple inputs and outputs as well as functions that have nondifferentiable outputs or arguments that cannot be differentiated with respect to. These should be specified using `nondiff_argnums` and `nondiff_outputnums` arguments to `wrap_for_jax`.
Agjax wrappers are intended to be quite general, and can support functions with multiple inputs and outputs as well as functions that have nondifferentiable outputs or arguments that cannot be differentiated with respect to. These should be specified using `nondiff_argnums` and `nondiff_outputnums` arguments. In the experimental wrapper, these must still be jax-convertible types, while in the standard wrapper they may have arbitrary typess.

```python
@functools.partial(
Expand Down
8 changes: 0 additions & 8 deletions agjax/__init__.py

This file was deleted.

27 changes: 26 additions & 1 deletion docs/basic_example.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
"import functools\n",
"import autograd.numpy as npa\n",
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"import agjax"
]
Expand Down Expand Up @@ -46,6 +47,22 @@
"print(f\"grad = {grad}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "61d32051",
"metadata": {},
"outputs": [],
"source": [
"wrapped_fn = agjax.experimental.wrap_for_jax(\n",
" lambda x, y: x * npa.cos(y),\n",
" result_shape_dtypes=jnp.ones((5,)),\n",
")\n",
"\n",
"jac = jax.jacrev(wrapped_fn, argnums=0)(jnp.arange(5, dtype=float), jnp.arange(5, 10, dtype=float))\n",
"print(f\"jac = \\n{jac}\")"
]
},
{
"cell_type": "markdown",
"id": "a331b186",
Expand Down Expand Up @@ -77,6 +94,14 @@
"print(f\" aux = {aux}\")\n",
"print(f\" grad = {grad}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0a8e13f0",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
Expand All @@ -95,7 +120,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.16"
"version": "3.10.12"
}
},
"nbformat": 4,
Expand Down
Loading