Skip to content

Commit

Permalink
FEAT: implement use_jit flag for lambdifier functions (#542)
Browse files Browse the repository at this point in the history
* BREAK: enforce keyword argument on `use_cse`
* BREAK: move `max_complexity` to last argument
* ENH: type `jit_compile` decorator with `ParamSpec`
* ENH: warn if backend does not support JIT compilation
* FEAT: implement `use_jit` flag in `create_function()` etc.
* MAINT: remove redundant newlines in docstrings
* MAINT: update kernel Python versions
* MAINT: update lock files
  • Loading branch information
redeboer authored Feb 4, 2025
1 parent 3e918a9 commit 0af0191
Show file tree
Hide file tree
Showing 11 changed files with 429 additions and 310 deletions.
10 changes: 5 additions & 5 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ repos:
metadata.vscode
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.9.3
rev: v0.9.4
hooks:
- id: ruff
args: [--fix]
Expand Down Expand Up @@ -115,7 +115,7 @@ repos:
- --in-place

- repo: https://github.com/python-jsonschema/check-jsonschema
rev: 0.31.0
rev: 0.31.1
hooks:
- id: check-jsonschema
name: Check CITATION.cff
Expand All @@ -128,7 +128,7 @@ repos:
pass_filenames: false

- repo: https://github.com/streetsidesoftware/cspell-cli
rev: v8.17.1
rev: v8.17.2
hooks:
- id: cspell

Expand All @@ -154,11 +154,11 @@ repos:
- python

- repo: https://github.com/ComPWA/pyright-pre-commit
rev: v1.1.392
rev: v1.1.393
hooks:
- id: pyright

- repo: https://github.com/astral-sh/uv-pre-commit
rev: 0.5.25
rev: 0.5.27
hooks:
- id: uv-lock
2 changes: 1 addition & 1 deletion benchmarks/ampform.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,8 @@ def create_function(
return create_parametrized_function(
expression=model.expression.doit(),
parameters=model.parameter_defaults,
max_complexity=max_complexity,
backend=backend,
max_complexity=max_complexity,
)


Expand Down
6 changes: 3 additions & 3 deletions docs/usage/faster-lambdify.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -308,8 +308,8 @@
"split_function = fast_lambdify(\n",
" expression,\n",
" sorted_symbols,\n",
" max_complexity=100,\n",
" backend=\"numpy\",\n",
" max_complexity=100,\n",
")"
]
},
Expand Down Expand Up @@ -392,8 +392,8 @@
"function = create_parametrized_function(\n",
" expression=model.expression.doit(),\n",
" parameters=model.parameter_defaults,\n",
" max_complexity=100,\n",
" backend=\"numpy\",\n",
" max_complexity=100,\n",
")"
]
}
Expand All @@ -417,7 +417,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.17"
"version": "3.12.8"
}
},
"nbformat": 4,
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ dependencies = [
"numpy",
"sympy >=1.9", # lambdify cse
"tqdm >=4.24.0", # autonotebook
'typing-extensions; python_version < "3.10"',
]
description = "Python fitter package for multiple computational back-ends"
dynamic = ["version"]
Expand Down
2 changes: 2 additions & 0 deletions src/tensorwaves/data/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ def from_sympy(
backend: str,
*,
use_cse: bool = True,
use_jit: bool | None = None,
max_complexity: int | None = None,
) -> SympyDataTransformer:
expanded_expressions: dict[str, sp.Expr] = {
Expand All @@ -101,6 +102,7 @@ def from_sympy(
ordered_symbols,
backend,
use_cse=use_cse,
use_jit=use_jit,
max_complexity=max_complexity,
)
functions[variable_name] = PositionalArgumentFunction(
Expand Down
4 changes: 1 addition & 3 deletions src/tensorwaves/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def create_cached_function(
parameters: Mapping[sp.Symbol, ParameterValue],
backend: str,
free_parameters: Iterable[sp.Symbol],
*,
use_cse: bool = True,
) -> tuple[ParametrizedFunction[DataSample, np.ndarray], DataTransformer]:
"""Create a function and data transformer for cached computations.
Expand All @@ -45,7 +46,6 @@ def create_cached_function(
returned :attr:`.ParametrizedFunction.parameters`.
backend: The computational backend to which in which to express the
input :code:`expression`.
free_parameters: Symbols in the expression that change and should not be cached.
use_cse: See :func:`.create_parametrized_function`.
Expand Down Expand Up @@ -106,12 +106,10 @@ class ChiSquared(Estimator):
a set of free `~.ParametrizedFunction.parameters` :math:`\mathbf{p}`.
domain: Input data-set :math:`\mathbf{x}` of :math:`n` events
:math:`x_i` over which to compute :code:`function` :math:`f_\mathbf{p}`.
observed_values: Observed values :math:`y_i`.
weights: Optional weights :math:`w_i`. Default: :math:`w_i=1`
(unweighted). A common choice is :math:`w_i = 1/\sigma_i^2`, with
:math:`\sigma_i` the uncertainty in each measured value of :math:`y_i`.
backend: Computational backend with which to compute the sum
:math:`\sum_{i=1}^n`.
Expand Down
32 changes: 30 additions & 2 deletions src/tensorwaves/function/_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,20 @@
from __future__ import annotations

from functools import partial
from typing import Callable
from typing import TYPE_CHECKING, Callable
from warnings import warn

if TYPE_CHECKING:
import sys
from typing import TypeVar

if sys.version_info >= (3, 10):
from typing import ParamSpec
else:
from typing_extensions import ParamSpec

P = ParamSpec("P")
T = TypeVar("T")


def find_function(function_name: str, backend: str) -> Callable:
Expand Down Expand Up @@ -59,7 +72,20 @@ def get_backend_modules(backend: str | tuple | dict) -> str | tuple | dict:
return backend


def jit_compile(backend: str) -> Callable[[Callable], Callable]:
def get_jit_compile_dectorator(
    backend: str, use_jit: bool | None
) -> Callable[[Callable[P, T]], Callable[P, T]]:
    """Select a JIT-compilation decorator for the given backend.

    Args:
        backend: Name of the computational backend (case-insensitive).
        use_jit: Tri-state flag. ``True`` forces JIT compilation, ``False``
            disables it, and ``None`` enables it automatically only for
            backends known to support it (``"jax"`` and ``"numba"``).

    Returns:
        A decorator that JIT-compiles the wrapped function, or an identity
        decorator that leaves the function untouched when JIT is disabled.
    """
    if use_jit is None:
        # Auto-detect: enable JIT only for backends that support it.
        use_jit = backend.lower() in {"jax", "numba"}
    if use_jit:
        return jit_compile(backend)
    return lambda func: func


def jit_compile(backend: str) -> Callable[[Callable[P, T]], Callable[P, T]]:
backend = backend.lower()
if backend == "jax":
try:
Expand All @@ -75,6 +101,8 @@ def jit_compile(backend: str) -> Callable[[Callable], Callable]:
raise_missing_module_error("numba", extras_require="numba")
return partial(numba.jit, forceobj=True, parallel=True)

msg = f"Backend {backend} does not yet support JIT compilation"
warn(msg, category=UserWarning, stacklevel=3)
return lambda x: x


Expand Down
Loading

0 comments on commit 0af0191

Please sign in to comment.