Commit

Merge pull request #211 from esa/develop
develop -> release for 0.4.1
gomezzz authored Nov 25, 2024
2 parents e2caa2c + 37fa291 commit 5729761
Showing 18 changed files with 180 additions and 88 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/autoblack.yml
@@ -10,11 +10,11 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v1
- name: Set up Python 3.8
- name: Set up Python 3.11
uses: actions/setup-python@v1
with:
python-version: 3.8
python-version: 3.11
- name: Install Black
run: pip install black
run: pip install black==24.4.2
- name: Run black --check .
run: black --check .
3 changes: 2 additions & 1 deletion .github/workflows/run_tests.yml
@@ -45,7 +45,7 @@ jobs:
# # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
# flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
- name: provision-with-micromamba
uses: mamba-org/provision-with-micromamba@main
uses: mamba-org/setup-micromamba@v1
with:
environment-file: environment_all_backends.yml
environment-name: torchquad
@@ -62,6 +62,7 @@ jobs:
- name: pytest coverage comment
uses: MishaKav/pytest-coverage-comment@main
if: github.event_name == 'pull_request'
continue-on-error: true
with:
pytest-coverage-path: ./torchquad/tests/pytest-coverage.txt
title: Coverage Report
14 changes: 6 additions & 8 deletions .readthedocs.yml
@@ -5,19 +5,17 @@
# Required
version: 2

build:
os: "ubuntu-22.04"
tools:
python: "mambaforge-22.9"

# Build documentation in the docs/ directory with Sphinx
sphinx:
configuration: docs/source/conf.py

# Optionally build your docs in additional formats such as PDF
formats: all

# Optionally set the version of Python and requirements required to build your docs
python:
version: 3.8
install:
- method: setuptools
path: .

conda:
environment: rtd_environment.yml
environment: rtd_environment.yml
2 changes: 1 addition & 1 deletion README.md
@@ -125,7 +125,7 @@ Note also that installing PyTorch with *pip* may **not** set it up with CUDA sup
Here are installation instructions for other numerical backends:
```sh
conda install "tensorflow>=2.6.0=cuda*" -c conda-forge
pip install "jax[cuda]>=0.2.22" --find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html # linux only
pip install "jax[cuda]>=0.4.17" --find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html # linux only
conda install "numpy>=1.19.5" -c conda-forge
```

2 changes: 1 addition & 1 deletion docs/source/install.rst
@@ -57,7 +57,7 @@ Here are installation instructions for other numerical backends:
.. code-block:: bash
conda install "tensorflow>=2.6.0=cuda*" -c conda-forge
pip install "jax[cuda]>=0.2.22" --find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html # linux only
pip install "jax[cuda]>=0.4.17" --find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html # linux only
conda install "numpy>=1.19.5" -c conda-forge
More installation instructions for numerical backends can be found in
6 changes: 3 additions & 3 deletions docs/source/tutorial.rst
@@ -724,11 +724,11 @@ sample points for both functions:
# Integrate the first integrand with the sample points
function_values, _ = integrator.evaluate_integrand(integrand1, grid_points)
integral1 = integrator.calculate_result(function_values, dim, n_per_dim, hs)
integral1 = integrator.calculate_result(function_values, dim, n_per_dim, hs, integration_domain)
# Integrate the second integrand with the same sample points
function_values, _ = integrator.evaluate_integrand(integrand2, grid_points)
integral2 = integrator.calculate_result(function_values, dim, n_per_dim, hs)
integral2 = integrator.calculate_result(function_values, dim, n_per_dim, hs, integration_domain)
print(f"Quadrature results: {integral1}, {integral2}")
@@ -745,7 +745,7 @@ As an example, here we evaluate a similar integrand many times for different val
.. code:: ipython3
def parametrized_integrand(x, a, b):
return torch.sqrt(torch.cos(torch.sin((a + b) * x)))
return torch.sqrt(torch.cos(torch.sin((a + b) * x)))
a_params = torch.arange(40)
b_params = torch.arange(10, 20)
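For reference, a minimal sketch (not part of the diff) of the updated tutorial call sequence; it assumes `calculate_grid`, `evaluate_integrand`, and `calculate_result` keep the signatures shown above, with `calculate_result` now also taking the integration domain:

```python
import torch
from torchquad import Simpson

integrator = Simpson()
dim, N = 2, 81
integration_domain = torch.tensor([[0.0, 1.0], [0.0, 1.0]])

# Compute the grid once, then evaluate and integrate
grid_points, hs, n_per_dim = integrator.calculate_grid(N, integration_domain)
function_values, _ = integrator.evaluate_integrand(
    lambda x: x[:, 0] * x[:, 1], grid_points
)
integral = integrator.calculate_result(
    function_values, dim, n_per_dim, hs, integration_domain
)
```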
6 changes: 3 additions & 3 deletions environment_all_backends.yml
@@ -7,7 +7,7 @@ dependencies:
- loguru>=0.5.3
- matplotlib>=3.3.3
- pytest>=6.2.1
- python>=3.8
- python==3.12
- scipy>=1.6.0
- sphinx>=3.4.3
- sphinx_rtd_theme>=0.5.1
@@ -16,9 +16,9 @@ dependencies:
- numpy>=1.19.5
- cudatoolkit>=11.1
- pytorch>=1.9 # CPU version
- tensorflow>=2.10.0 # CPU version
# jaxlib with CUDA support is not available for conda
- pip:
- --find-links https://storage.googleapis.com/jax-releases/jax_releases.html
- jax[cpu]>=0.2.22 # this will only work on linux. for win see e.g. https://github.com/cloudhan/jax-windows-builder
- tensorflow>=2.18.0 # CPU version
- jax[cpu]>=0.4.17 # this will only work on linux. for win see e.g. https://github.com/cloudhan/jax-windows-builder
# CPU version
7 changes: 6 additions & 1 deletion torchquad/integration/base_integrator.py
@@ -83,8 +83,13 @@ def evaluate_integrand(fn, points, weights=None, args=None):
len(result.shape) > 1
): # if the the integrand is multi-dimensional, we need to reshape/repeat weights so they can be broadcast in the *=
integrand_shape = anp.array(
result.shape[1:], like=infer_backend(points)
[
dim if isinstance(dim, int) else dim.as_list()
for dim in result.shape[1:]
],
like=infer_backend(points),
)

weights = anp.repeat(
anp.expand_dims(weights, axis=1), anp.prod(integrand_shape)
).reshape((weights.shape[0], *(integrand_shape)))
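For context, a small NumPy illustration (not part of the diff) of the reshape/repeat above: the per-sample-point quadrature weights are expanded so they broadcast against a vector-valued integrand's output.

```python
import numpy as np

weights = np.array([0.25, 0.5, 0.25])         # one quadrature weight per sample point
result = np.ones((3, 2, 2))                   # integrand output: 3 points, 2x2 values each
integrand_shape = np.array(result.shape[1:])  # (2, 2)

w = np.repeat(np.expand_dims(weights, axis=1), np.prod(integrand_shape)).reshape(
    (weights.shape[0], *integrand_shape)
)
assert w.shape == result.shape  # weights can now be broadcast in `result *= weights`
```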
1 change: 0 additions & 1 deletion torchquad/integration/boole.py
@@ -6,7 +6,6 @@


class Boole(NewtonCotes):

"""Boole's rule. See https://en.wikipedia.org/wiki/Newton%E2%80%93Cotes_formulas#Closed_Newton%E2%80%93Cotes_formulas ."""

def __init__(self):
17 changes: 4 additions & 13 deletions torchquad/integration/grid_integrator.py
@@ -7,6 +7,7 @@
_linspace_with_grads,
expand_func_values_and_squeeze_integral,
_setup_integration_domain,
_torch_trace_without_warnings,
)


@@ -208,8 +209,6 @@ def compiled_integrate(fn, integration_domain):
elif backend == "torch":
# Torch requires explicit tracing with example inputs.
def do_compile(example_integrand):
import torch

# Define traceable first and third steps
def step1(integration_domain):
grid_points, hs, n_per_dim = self.calculate_grid(
@@ -218,7 +217,7 @@ def step1(integration_domain):
return (
grid_points,
hs,
torch.Tensor([n_per_dim]),
anp.array([n_per_dim], like="torch"),
) # n_per_dim is constant

dim = int(integration_domain.shape[0])
@@ -229,7 +228,7 @@ def step3(function_values, hs, integration_domain):
)

# Trace the first step
step1 = torch.jit.trace(step1, (integration_domain,))
step1 = _torch_trace_without_warnings(step1, (integration_domain,))

# Get example input for the third step
grid_points, hs, n_per_dim = step1(integration_domain)
@@ -241,15 +240,7 @@ def step3(function_values, hs, integration_domain):
)

# Trace the third step
# Avoid the warnings about a .grad attribute access of a
# non-leaf Tensor
if hs.requires_grad:
hs = hs.detach()
hs.requires_grad = True
if function_values.requires_grad:
function_values = function_values.detach()
function_values.requires_grad = True
step3 = torch.jit.trace(
step3 = _torch_trace_without_warnings(
step3, (function_values, hs, integration_domain)
)

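As a side note, a short sketch (illustration only) of the autoray idiom used above: `anp.array(..., like="torch")` creates the tensor through the named backend, which removes the need for a local `import torch` in `do_compile`.

```python
from autoray import numpy as anp

n_per_dim = 5
n = anp.array([n_per_dim], like="torch")  # a torch tensor, created without importing torch here
print(type(n), n)                         # <class 'torch.Tensor'> tensor([5])
```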
21 changes: 11 additions & 10 deletions torchquad/integration/monte_carlo.py
@@ -3,7 +3,11 @@
from loguru import logger

from .base_integrator import BaseIntegrator
from .utils import _setup_integration_domain, expand_func_values_and_squeeze_integral
from .utils import (
_setup_integration_domain,
expand_func_values_and_squeeze_integral,
_torch_trace_without_warnings,
)
from .rng import RNG


@@ -195,8 +199,6 @@ def compiled_integrate(fn, integration_domain):
elif backend == "torch":
# Torch requires explicit tracing with example inputs.
def do_compile(example_integrand):
import torch

# Define traceable first and third steps
def step1(integration_domain):
return self.calculate_sample_points(
@@ -206,7 +208,9 @@ def step1(integration_domain):
step3 = self.calculate_result

# Trace the first step (which is non-deterministic)
step1 = torch.jit.trace(step1, (integration_domain,), check_trace=False)
step1 = _torch_trace_without_warnings(
step1, (integration_domain,), check_trace=False
)

# Get example input for the third step
sample_points = step1(integration_domain)
@@ -215,12 +219,9 @@ def step1(integration_domain):
)

# Trace the third step
if function_values.requires_grad:
# Avoid the warning about a .grad attribute access of a
# non-leaf Tensor
function_values = function_values.detach()
function_values.requires_grad = True
step3 = torch.jit.trace(step3, (function_values, integration_domain))
step3 = _torch_trace_without_warnings(
step3, (function_values, integration_domain)
)

# Define a compiled integrate function
def compiled_integrate(fn, integration_domain):
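For orientation, a hedged sketch of how this compiled path is typically reached; it assumes torchquad's existing `get_jit_compiled_integrate` API (inside which `do_compile` and `compiled_integrate` live) and is not part of the diff:

```python
import torch
from torchquad import MonteCarlo, set_up_backend

set_up_backend("torch", data_type="float32")

mc = MonteCarlo()
domain = torch.tensor([[0.0, 1.0], [0.0, 1.0]])

# Compiling once triggers the traced step1/step3 functions shown above;
# the returned callable can then be reused for many integrands.
integrate_jit = mc.get_jit_compiled_integrate(dim=2, N=10000, backend="torch")
result = integrate_jit(lambda x: torch.sin(x[:, 0]) * x[:, 1], domain)
```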
1 change: 0 additions & 1 deletion torchquad/integration/simpson.py
@@ -6,7 +6,6 @@


class Simpson(NewtonCotes):

"""Simpson's rule. See https://en.wikipedia.org/wiki/Newton%E2%80%93Cotes_formulas#Closed_Newton%E2%80%93Cotes_formulas ."""

def __init__(self):
81 changes: 60 additions & 21 deletions torchquad/integration/utils.py
@@ -2,6 +2,7 @@
Utility functions for the integrator implementations including extensions for
autoray, which are registered when importing this file
"""

import sys
from pathlib import Path

@@ -138,20 +139,21 @@ def _setup_integration_domain(dim, integration_domain, backend):
# Get a globally default backend
backend = _get_default_backend()
dtype_arg = _get_precision(backend)
if dtype_arg is not None:
# For NumPy and Tensorflow there is no global dtype, so set the
# configured default dtype here
integration_domain = anp.array(
integration_domain, like=backend, dtype=dtype_arg
)
else:
integration_domain = anp.array(integration_domain, like=backend)
if backend == "tensorflow":
import tensorflow as tf

dtype_arg = dtype_arg or tf.keras.backend.floatx()

integration_domain = anp.array(
integration_domain, like=backend, dtype=dtype_arg
)

if integration_domain.shape != (dim, 2):
raise ValueError(
"The integration domain has an unexpected shape. "
f"Expected {(dim, 2)}, got {integration_domain.shape}"
)

return integration_domain


@@ -193,20 +195,11 @@ def _check_integration_domain(integration_domain):
raise ValueError("integration_domain.shape[0] needs to be 1 or larger.")
if num_bounds != 2:
raise ValueError("integration_domain must have 2 values per boundary")
# Skip the values check if an integrator.integrate method is JIT
# compiled with JAX
if any(
nam in type(integration_domain).__name__ for nam in ["Jaxpr", "JVPTracer"]
):
return dim
boundaries_are_invalid = (
anp.min(integration_domain[:, 1] - integration_domain[:, 0]) < 0.0
)
# Skip the values check if an integrator.integrate method is
# compiled with tensorflow.function
if type(boundaries_are_invalid).__name__ == "Tensor":
# The boundary values check does not work if the code is JIT compiled
# with JAX or TensorFlow.
if _is_compiling(integration_domain):
return dim
if boundaries_are_invalid:
if anp.min(integration_domain[:, 1] - integration_domain[:, 0]) < 0.0:
raise ValueError("integration_domain has invalid boundary values")
return dim

@@ -261,3 +254,49 @@ def wrap(*args, **kwargs):
return f(*args, **kwargs)

return wrap


def _is_compiling(x):
"""
Check if code is currently being compiled with PyTorch, JAX or TensorFlow
Args:
x (backend tensor): A tensor currently used for computations
Returns:
bool: True if code is currently being compiled, False otherwise
"""
backend = infer_backend(x)
if backend == "jax":
return any(nam in type(x).__name__ for nam in ["Jaxpr", "JVPTracer"])
if backend == "torch":
import torch

if hasattr(torch.jit, "is_tracing"):
# We ignore torch.jit.is_scripting() since we do not support
# compilation to TorchScript
return torch.jit.is_tracing()
# torch.jit.is_tracing() is unavailable below PyTorch version 1.11.0
return type(x.shape[0]).__name__ == "Tensor"
if backend == "tensorflow":
import tensorflow as tf

if hasattr(tf, "is_symbolic_tensor"):
return tf.is_symbolic_tensor(x)
# tf.is_symbolic_tensor() is unavailable below TensorFlow version 2.13.0
return type(x).__name__ == "Tensor"
return False


def _torch_trace_without_warnings(*args, **kwargs):
"""Execute `torch.jit.trace` on the passed arguments and hide tracer warnings
PyTorch can show warnings about traces being potentially incorrect because
the Python3 control flow is not completely recorded.
This function can be used to hide the warnings in situations where they are
false positives.
"""
import torch

with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)
return torch.jit.trace(*args, **kwargs)
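To illustrate the two new helpers (a sketch with a hypothetical function, not from the diff): `_is_compiling` lets data-dependent checks be skipped while a backend is tracing, and `_torch_trace_without_warnings` wraps `torch.jit.trace` with the tracer warnings suppressed.

```python
import torch
from torchquad.integration.utils import _is_compiling, _torch_trace_without_warnings

def scaled_width(integration_domain):
    # Data-dependent checks like this one must be skipped while tracing
    if not _is_compiling(integration_domain):
        assert (integration_domain[:, 1] >= integration_domain[:, 0]).all()
    return 0.5 * (integration_domain[:, 1] - integration_domain[:, 0])

domain = torch.tensor([[0.0, 1.0], [-1.0, 1.0]])
traced = _torch_trace_without_warnings(scaled_width, (domain,))
print(traced(torch.tensor([[0.0, 2.0], [0.0, 4.0]])))  # tensor([1., 2.])
```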
4 changes: 2 additions & 2 deletions torchquad/tests/integration_test_functions.py
@@ -44,7 +44,7 @@ def __init__(
"""
self.integration_dim = integration_dim
self.expected_result = expected_result
if type(integrand_dims) == int or hasattr(integrand_dims, "__len__"):
if type(integrand_dims) is int or hasattr(integrand_dims, "__len__"):
self.integrand_dims = integrand_dims
else:
ValueError(
@@ -214,7 +214,7 @@ def _poly(self, x):
# Tensorflow does not automatically cast float32 to complex128,
# so we do it here explicitly.
assert self.is_complex
exponentials = anp.cast(exponentials, self.coeffs.dtype)
exponentials = exponentials.astype(self.coeffs.dtype)

# multiply by coefficients
exponentials = anp.multiply(exponentials, self.coeffs)