main -> develop (Release 0.4.1) #209

Merged 22 commits from main into develop on Nov 25, 2024.
6 changes: 3 additions & 3 deletions .github/workflows/autoblack.yml
@@ -10,11 +10,11 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v1
- name: Set up Python 3.8
- name: Set up Python 3.11
uses: actions/setup-python@v1
with:
python-version: 3.8
python-version: 3.11
- name: Install Black
run: pip install black
run: pip install black==24.4.2
- name: Run black --check .
run: black --check .
3 changes: 2 additions & 1 deletion .github/workflows/run_tests.yml
@@ -45,7 +45,7 @@ jobs:
# # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
# flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
- name: provision-with-micromamba
uses: mamba-org/provision-with-micromamba@main
uses: mamba-org/setup-micromamba@v1
with:
environment-file: environment_all_backends.yml
environment-name: torchquad
@@ -62,6 +62,7 @@ jobs:
- name: pytest coverage comment
uses: MishaKav/pytest-coverage-comment@main
if: github.event_name == 'pull_request'
continue-on-error: true
with:
pytest-coverage-path: ./torchquad/tests/pytest-coverage.txt
title: Coverage Report
14 changes: 6 additions & 8 deletions .readthedocs.yml
@@ -5,19 +5,17 @@
# Required
version: 2

build:
os: "ubuntu-22.04"
tools:
python: "mambaforge-22.9"

# Build documentation in the docs/ directory with Sphinx
sphinx:
configuration: docs/source/conf.py

# Optionally build your docs in additional formats such as PDF
formats: all

# Optionally set the version of Python and requirements required to build your docs
python:
version: 3.8
install:
- method: setuptools
path: .

conda:
environment: rtd_environment.yml
environment: rtd_environment.yml
6 changes: 3 additions & 3 deletions docs/source/tutorial.rst
@@ -724,11 +724,11 @@ sample points for both functions:

# Integrate the first integrand with the sample points
function_values, _ = integrator.evaluate_integrand(integrand1, grid_points)
integral1 = integrator.calculate_result(function_values, dim, n_per_dim, hs)
integral1 = integrator.calculate_result(function_values, dim, n_per_dim, hs, integration_domain)

# Integrate the second integrand with the same sample points
function_values, _ = integrator.evaluate_integrand(integrand2, grid_points)
integral2 = integrator.calculate_result(function_values, dim, n_per_dim, hs)
integral2 = integrator.calculate_result(function_values, dim, n_per_dim, hs, integration_domain)

print(f"Quadrature results: {integral1}, {integral2}")

@@ -745,7 +745,7 @@ As an example, here we evaluate a similar integrand many times for different val
.. code:: ipython3

def parametrized_integrand(x, a, b):
return torch.sqrt(torch.cos(torch.sin((a + b) * x)))
return torch.sqrt(torch.cos(torch.sin((a + b) * x)))

a_params = torch.arange(40)
b_params = torch.arange(10, 20)
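
For readers following the tutorial change: `calculate_result` now also takes the `integration_domain` argument. Below is a minimal end-to-end sketch of the sample-point-reuse pattern with the updated signature; it assumes the Boole rule, and the integrands, `N`, and domain are illustrative rather than taken from the tutorial.

```python
import torch
from torchquad import Boole

integrator = Boole()
dim, N = 2, 1089  # 33 points per dimension; Boole needs (n - 1) % 4 == 0
integration_domain = torch.tensor([[0.0, 1.0], [0.0, 2.0]])

# Compute the integration grid once ...
grid_points, hs, n_per_dim = integrator.calculate_grid(N, integration_domain)

# ... and reuse it for several integrands. Note the new trailing
# integration_domain argument of calculate_result.
for integrand in (
    lambda x: torch.sin(x[:, 0]) * x[:, 1],
    lambda x: torch.exp(-x[:, 0] * x[:, 1]),
):
    function_values, _ = integrator.evaluate_integrand(integrand, grid_points)
    result = integrator.calculate_result(
        function_values, dim, n_per_dim, hs, integration_domain
    )
    print(result)
```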
1 change: 0 additions & 1 deletion torchquad/integration/boole.py
@@ -6,7 +6,6 @@


class Boole(NewtonCotes):

"""Boole's rule. See https://en.wikipedia.org/wiki/Newton%E2%80%93Cotes_formulas#Closed_Newton%E2%80%93Cotes_formulas ."""

def __init__(self):
17 changes: 4 additions & 13 deletions torchquad/integration/grid_integrator.py
@@ -7,6 +7,7 @@
_linspace_with_grads,
expand_func_values_and_squeeze_integral,
_setup_integration_domain,
_torch_trace_without_warnings,
)


@@ -208,8 +209,6 @@ def compiled_integrate(fn, integration_domain):
elif backend == "torch":
# Torch requires explicit tracing with example inputs.
def do_compile(example_integrand):
import torch

# Define traceable first and third steps
def step1(integration_domain):
grid_points, hs, n_per_dim = self.calculate_grid(
@@ -218,7 +217,7 @@ def step1(integration_domain):
return (
grid_points,
hs,
torch.Tensor([n_per_dim]),
anp.array([n_per_dim], like="torch"),
) # n_per_dim is constant

dim = int(integration_domain.shape[0])
Expand All @@ -229,7 +228,7 @@ def step3(function_values, hs, integration_domain):
)

# Trace the first step
step1 = torch.jit.trace(step1, (integration_domain,))
step1 = _torch_trace_without_warnings(step1, (integration_domain,))

# Get example input for the third step
grid_points, hs, n_per_dim = step1(integration_domain)
@@ -241,15 +240,7 @@ def step3(function_values, hs, integration_domain):
)

# Trace the third step
# Avoid the warnings about a .grad attribute access of a
# non-leaf Tensor
if hs.requires_grad:
hs = hs.detach()
hs.requires_grad = True
if function_values.requires_grad:
function_values = function_values.detach()
function_values.requires_grad = True
step3 = torch.jit.trace(
step3 = _torch_trace_without_warnings(
step3, (function_values, hs, integration_domain)
)
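
For context, a hedged sketch of the user-facing path this tracing code serves, assuming the `get_jit_compiled_integrate(dim, N, backend=...)` helper of the grid integrators (the MonteCarlo change below follows the same pattern). The integrands and parameters are illustrative.

```python
import torch
from torchquad import Simpson

simp = Simpson()
domain = torch.tensor([[0.0, 1.0], [0.0, 1.0]])

# Build the compiled integrate function once; internally step1 and step3
# are recorded with _torch_trace_without_warnings as shown above.
integrate_jit = simp.get_jit_compiled_integrate(dim=2, N=10201, backend="torch")

# Reuse it with different integrands over domains of the same shape.
print(integrate_jit(lambda x: torch.sin(x[:, 0]) + x[:, 1], domain))
print(integrate_jit(lambda x: torch.cos(x[:, 0]) * x[:, 1], domain))
```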

21 changes: 11 additions & 10 deletions torchquad/integration/monte_carlo.py
@@ -3,7 +3,11 @@
from loguru import logger

from .base_integrator import BaseIntegrator
from .utils import _setup_integration_domain, expand_func_values_and_squeeze_integral
from .utils import (
_setup_integration_domain,
expand_func_values_and_squeeze_integral,
_torch_trace_without_warnings,
)
from .rng import RNG


@@ -195,8 +199,6 @@ def compiled_integrate(fn, integration_domain):
elif backend == "torch":
# Torch requires explicit tracing with example inputs.
def do_compile(example_integrand):
import torch

# Define traceable first and third steps
def step1(integration_domain):
return self.calculate_sample_points(
@@ -206,7 +208,9 @@ def step1(integration_domain):
step3 = self.calculate_result

# Trace the first step (which is non-deterministic)
step1 = torch.jit.trace(step1, (integration_domain,), check_trace=False)
step1 = _torch_trace_without_warnings(
step1, (integration_domain,), check_trace=False
)

# Get example input for the third step
sample_points = step1(integration_domain)
@@ -215,12 +219,9 @@ def step1(integration_domain):
)

# Trace the third step
if function_values.requires_grad:
# Avoid the warning about a .grad attribute access of a
# non-leaf Tensor
function_values = function_values.detach()
function_values.requires_grad = True
step3 = torch.jit.trace(step3, (function_values, integration_domain))
step3 = _torch_trace_without_warnings(
step3, (function_values, integration_domain)
)

# Define a compiled integrate function
def compiled_integrate(fn, integration_domain):
1 change: 0 additions & 1 deletion torchquad/integration/simpson.py
@@ -6,7 +6,6 @@


class Simpson(NewtonCotes):

"""Simpson's rule. See https://en.wikipedia.org/wiki/Newton%E2%80%93Cotes_formulas#Closed_Newton%E2%80%93Cotes_formulas ."""

def __init__(self):
64 changes: 51 additions & 13 deletions torchquad/integration/utils.py
@@ -2,6 +2,7 @@
Utility functions for the integrator implementations including extensions for
autoray, which are registered when importing this file
"""

import sys
from pathlib import Path

@@ -193,20 +194,11 @@ def _check_integration_domain(integration_domain):
raise ValueError("integration_domain.shape[0] needs to be 1 or larger.")
if num_bounds != 2:
raise ValueError("integration_domain must have 2 values per boundary")
# Skip the values check if an integrator.integrate method is JIT
# compiled with JAX
if any(
nam in type(integration_domain).__name__ for nam in ["Jaxpr", "JVPTracer"]
):
return dim
boundaries_are_invalid = (
anp.min(integration_domain[:, 1] - integration_domain[:, 0]) < 0.0
)
# Skip the values check if an integrator.integrate method is
# compiled with tensorflow.function
if type(boundaries_are_invalid).__name__ == "Tensor":
# The boundary values check does not work if the code is JIT compiled
# with JAX or TensorFlow.
if _is_compiling(integration_domain):
return dim
if boundaries_are_invalid:
if anp.min(integration_domain[:, 1] - integration_domain[:, 0]) < 0.0:
raise ValueError("integration_domain has invalid boundary values")
return dim

@@ -261,3 +253,49 @@ def wrap(*args, **kwargs):
return f(*args, **kwargs)

return wrap


def _is_compiling(x):
"""
Check if code is currently being compiled with PyTorch, JAX or TensorFlow

Args:
x (backend tensor): A tensor currently used for computations
Returns:
bool: True if code is currently being compiled, False otherwise
"""
backend = infer_backend(x)
if backend == "jax":
return any(nam in type(x).__name__ for nam in ["Jaxpr", "JVPTracer"])
if backend == "torch":
import torch

if hasattr(torch.jit, "is_tracing"):
# We ignore torch.jit.is_scripting() since we do not support
# compilation to TorchScript
return torch.jit.is_tracing()
# torch.jit.is_tracing() is unavailable below PyTorch version 1.11.0
return type(x.shape[0]).__name__ == "Tensor"
if backend == "tensorflow":
import tensorflow as tf

if hasattr(tf, "is_symbolic_tensor"):
return tf.is_symbolic_tensor(x)
# tf.is_symbolic_tensor() is unavailable below TensorFlow version 2.13.0
return type(x).__name__ == "Tensor"
return False


def _torch_trace_without_warnings(*args, **kwargs):
"""Execute `torch.jit.trace` on the passed arguments and hide tracer warnings

PyTorch can show warnings about traces being potentially incorrect because
the Python3 control flow is not completely recorded.
This function can be used to hide the warnings in situations where they are
false positives.
"""
import torch

with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)
return torch.jit.trace(*args, **kwargs)
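
A small usage sketch of the two new helpers together, assuming the torch backend; the function below is illustrative. `_is_compiling` lets a data-dependent check be skipped while `torch.jit.trace` records the graph, and `_torch_trace_without_warnings` keeps the tracer warnings quiet.

```python
import torch
from torchquad.integration.utils import _is_compiling, _torch_trace_without_warnings

def safe_width(domain):
    # Data-dependent checks cannot run while tracing; skip them then.
    if not _is_compiling(domain):
        if torch.min(domain[:, 1] - domain[:, 0]) < 0.0:
            raise ValueError("integration_domain has invalid boundary values")
    return domain[:, 1] - domain[:, 0]

domain = torch.tensor([[0.0, 1.0], [1.0, 3.0]])
traced = _torch_trace_without_warnings(safe_width, (domain,))
print(traced(domain))  # tensor([1., 2.])
```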
2 changes: 1 addition & 1 deletion torchquad/tests/integration_test_functions.py
@@ -44,7 +44,7 @@ def __init__(
"""
self.integration_dim = integration_dim
self.expected_result = expected_result
if type(integrand_dims) == int or hasattr(integrand_dims, "__len__"):
if type(integrand_dims) is int or hasattr(integrand_dims, "__len__"):
self.integrand_dims = integrand_dims
else:
ValueError(
38 changes: 38 additions & 0 deletions torchquad/tests/utils_integration_test.py
@@ -12,6 +12,7 @@
_linspace_with_grads,
_add_at_indices,
_setup_integration_domain,
_is_compiling,
)
from utils.set_precision import set_precision
from utils.enable_cuda import enable_cuda
@@ -196,11 +197,48 @@ def test_setup_integration_domain():
_run_tests_with_all_backends(_run_setup_integration_domain_tests)


def _run_is_compiling_tests(dtype_name, backend):
"""
Test _is_compiling with the given dtype and numerical backend
"""
dtype = to_backend_dtype(dtype_name, like=backend)
x = anp.array([[0.0, 1.0], [1.0, 2.0]], dtype=dtype, like=backend)
assert not _is_compiling(
x
), f"_is_compiling has a false positive with backend {backend}"

def check_compiling(x):
assert _is_compiling(
x
), f"_is_compiling has a false negative with backend {backend}"
return x

if backend == "jax":
import jax

jax.jit(check_compiling)(x)
elif backend == "torch":
import torch

torch.jit.trace(check_compiling, (x,), check_trace=False)(x)
elif backend == "tensorflow":
import tensorflow as tf

tf.function(check_compiling, jit_compile=True)(x)
tf.function(check_compiling, jit_compile=False)(x)


def test_is_compiling():
"""Test _is_compiling with all possible configurations"""
_run_tests_with_all_backends(_run_is_compiling_tests)


if __name__ == "__main__":
try:
# used to run this test individually
test_linspace_with_grads()
test_add_at_indices()
test_setup_integration_domain()
test_is_compiling()
except KeyboardInterrupt:
pass