diff --git a/.github/workflows/autoblack.yml b/.github/workflows/autoblack.yml index 58b10247..12d48bc0 100644 --- a/.github/workflows/autoblack.yml +++ b/.github/workflows/autoblack.yml @@ -10,11 +10,11 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v1 - - name: Set up Python 3.8 + - name: Set up Python 3.11 uses: actions/setup-python@v1 with: - python-version: 3.8 + python-version: 3.11 - name: Install Black - run: pip install black + run: pip install black==24.4.2 - name: Run black --check . run: black --check . diff --git a/.github/workflows/run_tests.yml b/.github/workflows/run_tests.yml index c2b0a6c2..00fecde7 100644 --- a/.github/workflows/run_tests.yml +++ b/.github/workflows/run_tests.yml @@ -45,7 +45,7 @@ jobs: # # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide # flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics - name: provision-with-micromamba - uses: mamba-org/provision-with-micromamba@main + uses: mamba-org/setup-micromamba@v1 with: environment-file: environment_all_backends.yml environment-name: torchquad @@ -62,6 +62,7 @@ jobs: - name: pytest coverage comment uses: MishaKav/pytest-coverage-comment@main if: github.event_name == 'pull_request' + continue-on-error: true with: pytest-coverage-path: ./torchquad/tests/pytest-coverage.txt title: Coverage Report diff --git a/.readthedocs.yml b/.readthedocs.yml index 879ddda2..2ebcbd7a 100644 --- a/.readthedocs.yml +++ b/.readthedocs.yml @@ -5,6 +5,11 @@ # Required version: 2 +build: + os: "ubuntu-22.04" + tools: + python: "mambaforge-22.9" + # Build documentation in the docs/ directory with Sphinx sphinx: configuration: docs/source/conf.py @@ -12,12 +17,5 @@ sphinx: # Optionally build your docs in additional formats such as PDF formats: all -# Optionally set the version of Python and requirements required to build your docs -python: - version: 3.8 - install: - - method: setuptools - path: . - conda: - environment: rtd_environment.yml \ No newline at end of file + environment: rtd_environment.yml diff --git a/docs/source/tutorial.rst b/docs/source/tutorial.rst index ac148beb..9d1c8ff5 100644 --- a/docs/source/tutorial.rst +++ b/docs/source/tutorial.rst @@ -724,11 +724,11 @@ sample points for both functions: # Integrate the first integrand with the sample points function_values, _ = integrator.evaluate_integrand(integrand1, grid_points) - integral1 = integrator.calculate_result(function_values, dim, n_per_dim, hs) + integral1 = integrator.calculate_result(function_values, dim, n_per_dim, hs, integration_domain) # Integrate the second integrand with the same sample points function_values, _ = integrator.evaluate_integrand(integrand2, grid_points) - integral2 = integrator.calculate_result(function_values, dim, n_per_dim, hs) + integral2 = integrator.calculate_result(function_values, dim, n_per_dim, hs, integration_domain) print(f"Quadrature results: {integral1}, {integral2}") @@ -745,7 +745,7 @@ As an example, here we evaluate a similar integrand many times for different val .. code:: ipython3 def parametrized_integrand(x, a, b): - return torch.sqrt(torch.cos(torch.sin((a + b) * x))) + return torch.sqrt(torch.cos(torch.sin((a + b) * x))) a_params = torch.arange(40) b_params = torch.arange(10, 20) diff --git a/torchquad/integration/boole.py b/torchquad/integration/boole.py index 99e1b22a..a064dff3 100644 --- a/torchquad/integration/boole.py +++ b/torchquad/integration/boole.py @@ -6,7 +6,6 @@ class Boole(NewtonCotes): - """Boole's rule. See https://en.wikipedia.org/wiki/Newton%E2%80%93Cotes_formulas#Closed_Newton%E2%80%93Cotes_formulas .""" def __init__(self): diff --git a/torchquad/integration/grid_integrator.py b/torchquad/integration/grid_integrator.py index 2243399d..6c574d81 100644 --- a/torchquad/integration/grid_integrator.py +++ b/torchquad/integration/grid_integrator.py @@ -7,6 +7,7 @@ _linspace_with_grads, expand_func_values_and_squeeze_integral, _setup_integration_domain, + _torch_trace_without_warnings, ) @@ -208,8 +209,6 @@ def compiled_integrate(fn, integration_domain): elif backend == "torch": # Torch requires explicit tracing with example inputs. def do_compile(example_integrand): - import torch - # Define traceable first and third steps def step1(integration_domain): grid_points, hs, n_per_dim = self.calculate_grid( @@ -218,7 +217,7 @@ def step1(integration_domain): return ( grid_points, hs, - torch.Tensor([n_per_dim]), + anp.array([n_per_dim], like="torch"), ) # n_per_dim is constant dim = int(integration_domain.shape[0]) @@ -229,7 +228,7 @@ def step3(function_values, hs, integration_domain): ) # Trace the first step - step1 = torch.jit.trace(step1, (integration_domain,)) + step1 = _torch_trace_without_warnings(step1, (integration_domain,)) # Get example input for the third step grid_points, hs, n_per_dim = step1(integration_domain) @@ -241,15 +240,7 @@ def step3(function_values, hs, integration_domain): ) # Trace the third step - # Avoid the warnings about a .grad attribute access of a - # non-leaf Tensor - if hs.requires_grad: - hs = hs.detach() - hs.requires_grad = True - if function_values.requires_grad: - function_values = function_values.detach() - function_values.requires_grad = True - step3 = torch.jit.trace( + step3 = _torch_trace_without_warnings( step3, (function_values, hs, integration_domain) ) diff --git a/torchquad/integration/monte_carlo.py b/torchquad/integration/monte_carlo.py index 3d854b61..432e6242 100644 --- a/torchquad/integration/monte_carlo.py +++ b/torchquad/integration/monte_carlo.py @@ -3,7 +3,11 @@ from loguru import logger from .base_integrator import BaseIntegrator -from .utils import _setup_integration_domain, expand_func_values_and_squeeze_integral +from .utils import ( + _setup_integration_domain, + expand_func_values_and_squeeze_integral, + _torch_trace_without_warnings, +) from .rng import RNG @@ -195,8 +199,6 @@ def compiled_integrate(fn, integration_domain): elif backend == "torch": # Torch requires explicit tracing with example inputs. def do_compile(example_integrand): - import torch - # Define traceable first and third steps def step1(integration_domain): return self.calculate_sample_points( @@ -206,7 +208,9 @@ def step1(integration_domain): step3 = self.calculate_result # Trace the first step (which is non-deterministic) - step1 = torch.jit.trace(step1, (integration_domain,), check_trace=False) + step1 = _torch_trace_without_warnings( + step1, (integration_domain,), check_trace=False + ) # Get example input for the third step sample_points = step1(integration_domain) @@ -215,12 +219,9 @@ def step1(integration_domain): ) # Trace the third step - if function_values.requires_grad: - # Avoid the warning about a .grad attribute access of a - # non-leaf Tensor - function_values = function_values.detach() - function_values.requires_grad = True - step3 = torch.jit.trace(step3, (function_values, integration_domain)) + step3 = _torch_trace_without_warnings( + step3, (function_values, integration_domain) + ) # Define a compiled integrate function def compiled_integrate(fn, integration_domain): diff --git a/torchquad/integration/simpson.py b/torchquad/integration/simpson.py index 68bea5c7..e67ce47b 100644 --- a/torchquad/integration/simpson.py +++ b/torchquad/integration/simpson.py @@ -6,7 +6,6 @@ class Simpson(NewtonCotes): - """Simpson's rule. See https://en.wikipedia.org/wiki/Newton%E2%80%93Cotes_formulas#Closed_Newton%E2%80%93Cotes_formulas .""" def __init__(self): diff --git a/torchquad/integration/utils.py b/torchquad/integration/utils.py index c2ab3fcc..9e78ac4a 100644 --- a/torchquad/integration/utils.py +++ b/torchquad/integration/utils.py @@ -2,6 +2,7 @@ Utility functions for the integrator implementations including extensions for autoray, which are registered when importing this file """ + import sys from pathlib import Path @@ -193,20 +194,11 @@ def _check_integration_domain(integration_domain): raise ValueError("integration_domain.shape[0] needs to be 1 or larger.") if num_bounds != 2: raise ValueError("integration_domain must have 2 values per boundary") - # Skip the values check if an integrator.integrate method is JIT - # compiled with JAX - if any( - nam in type(integration_domain).__name__ for nam in ["Jaxpr", "JVPTracer"] - ): - return dim - boundaries_are_invalid = ( - anp.min(integration_domain[:, 1] - integration_domain[:, 0]) < 0.0 - ) - # Skip the values check if an integrator.integrate method is - # compiled with tensorflow.function - if type(boundaries_are_invalid).__name__ == "Tensor": + # The boundary values check does not work if the code is JIT compiled + # with JAX or TensorFlow. + if _is_compiling(integration_domain): return dim - if boundaries_are_invalid: + if anp.min(integration_domain[:, 1] - integration_domain[:, 0]) < 0.0: raise ValueError("integration_domain has invalid boundary values") return dim @@ -261,3 +253,49 @@ def wrap(*args, **kwargs): return f(*args, **kwargs) return wrap + + +def _is_compiling(x): + """ + Check if code is currently being compiled with PyTorch, JAX or TensorFlow + + Args: + x (backend tensor): A tensor currently used for computations + Returns: + bool: True if code is currently being compiled, False otherwise + """ + backend = infer_backend(x) + if backend == "jax": + return any(nam in type(x).__name__ for nam in ["Jaxpr", "JVPTracer"]) + if backend == "torch": + import torch + + if hasattr(torch.jit, "is_tracing"): + # We ignore torch.jit.is_scripting() since we do not support + # compilation to TorchScript + return torch.jit.is_tracing() + # torch.jit.is_tracing() is unavailable below PyTorch version 1.11.0 + return type(x.shape[0]).__name__ == "Tensor" + if backend == "tensorflow": + import tensorflow as tf + + if hasattr(tf, "is_symbolic_tensor"): + return tf.is_symbolic_tensor(x) + # tf.is_symbolic_tensor() is unavailable below TensorFlow version 2.13.0 + return type(x).__name__ == "Tensor" + return False + + +def _torch_trace_without_warnings(*args, **kwargs): + """Execute `torch.jit.trace` on the passed arguments and hide tracer warnings + + PyTorch can show warnings about traces being potentially incorrect because + the Python3 control flow is not completely recorded. + This function can be used to hide the warnings in situations where they are + false positives. + """ + import torch + + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=torch.jit.TracerWarning) + return torch.jit.trace(*args, **kwargs) diff --git a/torchquad/tests/integration_test_functions.py b/torchquad/tests/integration_test_functions.py index 9e2c63e5..e5284d2f 100644 --- a/torchquad/tests/integration_test_functions.py +++ b/torchquad/tests/integration_test_functions.py @@ -44,7 +44,7 @@ def __init__( """ self.integration_dim = integration_dim self.expected_result = expected_result - if type(integrand_dims) == int or hasattr(integrand_dims, "__len__"): + if type(integrand_dims) is int or hasattr(integrand_dims, "__len__"): self.integrand_dims = integrand_dims else: ValueError( diff --git a/torchquad/tests/utils_integration_test.py b/torchquad/tests/utils_integration_test.py index 9a3b02ba..ad35de41 100644 --- a/torchquad/tests/utils_integration_test.py +++ b/torchquad/tests/utils_integration_test.py @@ -12,6 +12,7 @@ _linspace_with_grads, _add_at_indices, _setup_integration_domain, + _is_compiling, ) from utils.set_precision import set_precision from utils.enable_cuda import enable_cuda @@ -196,11 +197,48 @@ def test_setup_integration_domain(): _run_tests_with_all_backends(_run_setup_integration_domain_tests) +def _run_is_compiling_tests(dtype_name, backend): + """ + Test _is_compiling with the given dtype and numerical backend + """ + dtype = to_backend_dtype(dtype_name, like=backend) + x = anp.array([[0.0, 1.0], [1.0, 2.0]], dtype=dtype, like=backend) + assert not _is_compiling( + x + ), f"_is_compiling has a false positive with backend {backend}" + + def check_compiling(x): + assert _is_compiling( + x + ), f"_is_compiling has a false negative with backend {backend}" + return x + + if backend == "jax": + import jax + + jax.jit(check_compiling)(x) + elif backend == "torch": + import torch + + torch.jit.trace(check_compiling, (x,), check_trace=False)(x) + elif backend == "tensorflow": + import tensorflow as tf + + tf.function(check_compiling, jit_compile=True)(x) + tf.function(check_compiling, jit_compile=False)(x) + + +def test_is_compiling(): + """Test _is_compiling with all possible configurations""" + _run_tests_with_all_backends(_run_is_compiling_tests) + + if __name__ == "__main__": try: # used to run this test individually test_linspace_with_grads() test_add_at_indices() test_setup_integration_domain() + test_is_compiling() except KeyboardInterrupt: pass