diff --git a/src/mici/adapters.py b/src/mici/adapters.py index a099016..372afed 100644 --- a/src/mici/adapters.py +++ b/src/mici/adapters.py @@ -5,24 +5,24 @@ from abc import ABC, abstractmethod from math import exp, log from typing import TYPE_CHECKING + import numpy as np -from mici.errors import IntegratorError, AdaptationError -from mici.matrices import PositiveDiagonalMatrix, DensePositiveDefiniteMatrix + +from mici.errors import AdaptationError, IntegratorError +from mici.matrices import DensePositiveDefiniteMatrix, PositiveDiagonalMatrix if TYPE_CHECKING: - from typing import Collection, Optional, Iterable, Union + from typing import Collection, Iterable, Optional, Union + from numpy.random import Generator from numpy.typing import ArrayLike + from mici.integrators import Integrator from mici.states import ChainState from mici.systems import System from mici.transitions import Transition - from mici.types import ( - AdaptationStatisticFunction, - AdapterState, - ReducerFunction, - TransitionStatistics, - ) + from mici.types import (AdaptationStatisticFunction, AdapterState, + ReducerFunction, TransitionStatistics) class Adapter(ABC): @@ -486,7 +486,7 @@ def finalize( mean_est /= n_iter var_est += adapt_state["sum_diff_sq"] var_est += ( - mean_diff ** 2 * (adapt_state["iter"] * n_iter_prev) / n_iter + mean_diff**2 * (adapt_state["iter"] * n_iter_prev) / n_iter ) if n_iter < 2: raise AdaptationError( diff --git a/src/mici/autodiff.py b/src/mici/autodiff.py index c77a971..46d050f 100644 --- a/src/mici/autodiff.py +++ b/src/mici/autodiff.py @@ -3,6 +3,7 @@ from __future__ import annotations from typing import TYPE_CHECKING + import mici.autograd_wrapper as autograd_wrapper if TYPE_CHECKING: diff --git a/src/mici/autograd_wrapper.py b/src/mici/autograd_wrapper.py index 2bbb471..99199f3 100644 --- a/src/mici/autograd_wrapper.py +++ b/src/mici/autograd_wrapper.py @@ -7,24 +7,19 @@ AUTOGRAD_AVAILABLE = True try: - from autograd.wrap_util import unary_to_nary + import autograd.numpy as np from autograd.builtins import tuple as atuple from autograd.core import make_vjp from autograd.extend import vspace - import autograd.numpy as np + from autograd.wrap_util import unary_to_nary except ImportError: AUTOGRAD_AVAILABLE = False if TYPE_CHECKING: from typing import Callable - from mici.types import ( - ScalarLike, - ArrayLike, - ScalarFunction, - ArrayFunction, - MatrixHessianProduct, - MatrixTressianProduct, - ) + + from mici.types import (ArrayFunction, ArrayLike, MatrixHessianProduct, + MatrixTressianProduct, ScalarFunction, ScalarLike) def _wrapped_unary_to_nary(func: Callable) -> Callable: diff --git a/src/mici/integrators.py b/src/mici/integrators.py index ca725ab..91fac2e 100644 --- a/src/mici/integrators.py +++ b/src/mici/integrators.py @@ -4,21 +4,21 @@ from abc import ABC, abstractmethod from typing import TYPE_CHECKING + import numpy as np from numpy.typing import ArrayLike -from mici.errors import NonReversibleStepError, AdaptationError -from mici.solvers import ( - maximum_norm, - solve_fixed_point_direct, - solve_projection_onto_manifold_newton, - FixedPointSolver, - ProjectionSolver, -) + +from mici.errors import AdaptationError, NonReversibleStepError +from mici.solvers import (FixedPointSolver, ProjectionSolver, maximum_norm, + solve_fixed_point_direct, + solve_projection_onto_manifold_newton) if TYPE_CHECKING: from typing import Any, Callable, Optional, Sequence + from mici.states import ChainState - from mici.systems import ConstrainedTractableFlowSystem, System, TractableFlowSystem + from mici.systems import (ConstrainedTractableFlowSystem, System, + TractableFlowSystem) from mici.types import NormFunction diff --git a/src/mici/interop.py b/src/mici/interop.py index 29cd50d..89f6ecc 100644 --- a/src/mici/interop.py +++ b/src/mici/interop.py @@ -4,15 +4,18 @@ import importlib import os -import mici -import numpy as np from typing import TYPE_CHECKING +import numpy as np + +import mici + if TYPE_CHECKING: from typing import Literal, Optional, Union - from numpy.typing import ArrayLike + import arviz import pymc + from numpy.typing import ArrayLike def convert_to_inference_data( diff --git a/src/mici/matrices.py b/src/mici/matrices.py index b620e87..a2e397b 100644 --- a/src/mici/matrices.py +++ b/src/mici/matrices.py @@ -5,15 +5,19 @@ import abc import numbers from typing import TYPE_CHECKING + import numpy as np -from mici.errors import LinAlgError import numpy.linalg as nla import scipy.linalg as sla + +from mici.errors import LinAlgError from mici.utils import hash_array if TYPE_CHECKING: from typing import Iterable, Literal, Optional, Tuple, Union + from numpy.typing import NDArray + from mici.types import MatrixLike, ScalarLike @@ -648,7 +652,7 @@ def grad_log_abs_det(self): return self.shape[0] / self._scalar def grad_quadratic_form_inv(self, vector: NDArray) -> float: - return -np.sum(vector ** 2) / self._scalar ** 2 + return -np.sum(vector**2) / self._scalar**2 def __str__(self) -> str: return f"(shape={self.shape}, scalar={self._scalar})" @@ -681,7 +685,7 @@ def _construct_inv(self) -> PositiveScaledIdentityMatrix: return PositiveScaledIdentityMatrix(1 / self._scalar, self.shape[0]) def _construct_sqrt(self) -> PositiveScaledIdentityMatrix: - return PositiveScaledIdentityMatrix(self._scalar ** 0.5, self.shape[0]) + return PositiveScaledIdentityMatrix(self._scalar**0.5, self.shape[0]) class DiagonalMatrix(SymmetricMatrix, DifferentiableMatrix, ImplicitArrayMatrix): @@ -766,7 +770,7 @@ def _construct_inv(self) -> PositiveDiagonalMatrix: return PositiveDiagonalMatrix(1.0 / self.diagonal) def _construct_sqrt(self) -> PositiveDiagonalMatrix: - return PositiveDiagonalMatrix(self.diagonal ** 0.5) + return PositiveDiagonalMatrix(self.diagonal**0.5) def _make_array_triangular(array: NDArray, lower: bool) -> NDArray: @@ -1042,7 +1046,7 @@ def __init__( def _scalar_multiply(self, scalar: ScalarLike) -> TriangularFactoredDefiniteMatrix: if scalar > 0: return TriangularFactoredPositiveDefiniteMatrix( - factor=scalar ** 0.5 * self.factor + factor=scalar**0.5 * self.factor ) else: return super()._scalar_multiply(scalar) @@ -1542,7 +1546,7 @@ def _construct_inv(self) -> EigendecomposedPositiveDefiniteMatrix: return EigendecomposedPositiveDefiniteMatrix(self.eigvec, 1 / self.eigval) def _construct_sqrt(self) -> EigendecomposedPositiveDefiniteMatrix: - return EigendecomposedPositiveDefiniteMatrix(self.eigvec, self.eigval ** 0.5) + return EigendecomposedPositiveDefiniteMatrix(self.eigvec, self.eigval**0.5) class SoftAbsRegularizedPositiveDefiniteMatrix( diff --git a/src/mici/progressbars.py b/src/mici/progressbars.py index 817c305..1281fa5 100644 --- a/src/mici/progressbars.py +++ b/src/mici/progressbars.py @@ -225,7 +225,7 @@ def __init__( @property def description(self) -> str: - """"Description of task being tracked.""" + """Description of task being tracked.""" return self._description @property diff --git a/src/mici/samplers.py b/src/mici/samplers.py index a15145a..05ac1db 100644 --- a/src/mici/samplers.py +++ b/src/mici/samplers.py @@ -2,62 +2,49 @@ from __future__ import annotations +import logging import os import queue +import signal +import tempfile from contextlib import ExitStack, contextmanager, nullcontext from pathlib import Path from pickle import PicklingError -import logging -import tempfile -import signal -from warnings import warn from typing import TYPE_CHECKING, NamedTuple +from warnings import warn + import numpy as np from numpy.random import default_rng -from mici.transitions import ( - IndependentMomentumTransition, - MetropolisRandomIntegrationTransition, - MetropolisStaticIntegrationTransition, - MultinomialDynamicIntegrationTransition, - SliceDynamicIntegrationTransition, - euclidean_no_u_turn_criterion, - riemannian_no_u_turn_criterion, -) -from mici.states import ChainState -from mici.progressbars import ( - SequenceProgressBar, - LabelledSequenceProgressBar, - DummyProgressBar, - _ProxySequenceProgressBar, -) -from mici.errors import AdaptationError + from mici.adapters import DualAveragingStepSizeAdapter +from mici.errors import AdaptationError +from mici.progressbars import (DummyProgressBar, LabelledSequenceProgressBar, + SequenceProgressBar, _ProxySequenceProgressBar) from mici.stagers import WarmUpStager, WindowedWarmUpStager +from mici.states import ChainState +from mici.transitions import (IndependentMomentumTransition, + MetropolisRandomIntegrationTransition, + MetropolisStaticIntegrationTransition, + MultinomialDynamicIntegrationTransition, + SliceDynamicIntegrationTransition, + euclidean_no_u_turn_criterion, + riemannian_no_u_turn_criterion) if TYPE_CHECKING: - from typing import ( - Container, - Generator, - Iterable, - Optional, - Sequence, - Union, - ) + from typing import (Container, Generator, Iterable, Optional, Sequence, + Union) + from numpy.typing import ArrayLike, DTypeLike, NDArray + from mici.adapters import Adapter from mici.integrators import Integrator from mici.progressbars import ProgressBar from mici.stagers import Stager from mici.systems import System - from mici.transitions import IntegrationTransition, MomentumTransition, Transition - from mici.types import ( - AdapterState, - ChainIterator, - ScalarLike, - PyTree, - TraceFunction, - TerminationCriterion, - ) + from mici.transitions import (IntegrationTransition, MomentumTransition, + Transition) + from mici.types import (AdapterState, ChainIterator, PyTree, ScalarLike, + TerminationCriterion, TraceFunction) # Preferentially import from multiprocess library if available as able to # serialize much wider range of types including autograd functions @@ -145,7 +132,10 @@ def _generate_memmap_filenames( def _open_new_memmap( - file_path: str, shape: tuple[int, ...], default_val: ScalarLike, dtype: DTypeLike, + file_path: str, + shape: tuple[int, ...], + default_val: ScalarLike, + dtype: DTypeLike, ) -> np.memmap: """Open a new memory-mapped array object and fill with a default-value. @@ -298,7 +288,15 @@ def _init_traces( ] else: traces[key] = list( - np.full((n_chain, n_iter,) + val.shape, init, val.dtype) + np.full( + ( + n_chain, + n_iter, + ) + + val.shape, + init, + val.dtype, + ) ) return traces @@ -961,7 +959,11 @@ def sample_chains( ) ) stats = _init_stats( - self.transitions, n_chain, n_trace_iter, use_memmap, memmap_path, + self.transitions, + n_chain, + n_trace_iter, + use_memmap, + memmap_path, ) per_chain_rngs = _get_per_chain_rngs(self.rng, n_chain) per_chain_traces = ( @@ -1051,7 +1053,7 @@ class HMCSampleChainsOutputs(NamedTuple): and main sampling stages otherwise. The key for each value is the corresponding key in the dictionary returned by the trace function which computed the traced value. - statistics: Dictionary of chain transition statistic dictionaries. Values in + statistics: Dictionary of chain transition statistic dictionaries. Values in dictionary are lists of arrays of chain statistic values with each array in the list corresponding to a single chain and the leading dimension of each array corresponding to the iteration (draw) index, within the main diff --git a/src/mici/solvers.py b/src/mici/solvers.py index ce1f84a..584c5f7 100644 --- a/src/mici/solvers.py +++ b/src/mici/solvers.py @@ -2,17 +2,17 @@ from __future__ import annotations -from typing import Protocol, TYPE_CHECKING -from mici.errors import ConvergenceError, LinAlgError +from typing import TYPE_CHECKING, Protocol + import numpy as np +from mici.errors import ConvergenceError, LinAlgError + if TYPE_CHECKING: from mici.states import ChainState - from mici.systems import ( - ConstrainedEuclideanMetricSystem, - ConstrainedTractableFlowSystem, - ) - from mici.types import ScalarFunction, ArrayFunction, ArrayLike + from mici.systems import (ConstrainedEuclideanMetricSystem, + ConstrainedTractableFlowSystem) + from mici.types import ArrayFunction, ArrayLike, ScalarFunction def euclidean_norm(vct): diff --git a/src/mici/stagers.py b/src/mici/stagers.py index 8ee295f..0cfe018 100644 --- a/src/mici/stagers.py +++ b/src/mici/stagers.py @@ -3,10 +3,11 @@ from __future__ import annotations import abc -from typing import NamedTuple, TYPE_CHECKING +from typing import TYPE_CHECKING, NamedTuple if TYPE_CHECKING: from typing import Iterable, Optional + from mici.adapters import Adapter from mici.types import TraceFunction diff --git a/src/mici/states.py b/src/mici/states.py index 7dfd18d..255436d 100644 --- a/src/mici/states.py +++ b/src/mici/states.py @@ -3,14 +3,17 @@ from __future__ import annotations import copy -from functools import wraps from collections import Counter +from functools import wraps from typing import TYPE_CHECKING + from mici.errors import ReadOnlyStateError if TYPE_CHECKING: from typing import Any, Callable, Iterable, Optional + from numpy.typing import ArrayLike + from mici.systems import System diff --git a/src/mici/systems.py b/src/mici/systems.py index 5ed8b07..827c32e 100644 --- a/src/mici/systems.py +++ b/src/mici/systems.py @@ -4,32 +4,26 @@ from abc import ABC, abstractmethod from typing import TYPE_CHECKING + import numpy as np -from mici.states import cache_in_state, cache_in_state_with_aux + import mici.matrices as matrices from mici.autodiff import autodiff_fallback +from mici.states import cache_in_state, cache_in_state_with_aux if TYPE_CHECKING: from typing import Any, Optional, Type + from numpy.random import Generator + from mici.states import ChainState - from mici.types import ( - ScalarLike, - ArrayLike, - MatrixLike, - MetricLike, - ArrayFunction, - ScalarFunction, - GradientFunction, - HessianFunction, - MatrixTressianProduct, - MatrixTressianProductFunction, - JacobianFunction, - MatrixHessianProduct, - MatrixHessianProductFunction, - VectorJacobianProduct, - VectorJacobianProductFunction, - ) + from mici.types import (ArrayFunction, ArrayLike, GradientFunction, + HessianFunction, JacobianFunction, + MatrixHessianProduct, MatrixHessianProductFunction, + MatrixLike, MatrixTressianProduct, + MatrixTressianProductFunction, MetricLike, + ScalarFunction, ScalarLike, VectorJacobianProduct, + VectorJacobianProductFunction) class System(ABC): diff --git a/src/mici/transitions.py b/src/mici/transitions.py index 7441a46..4841704 100644 --- a/src/mici/transitions.py +++ b/src/mici/transitions.py @@ -2,23 +2,22 @@ from __future__ import annotations -from abc import ABC, abstractmethod, abstractproperty import logging -from typing import NamedTuple, TYPE_CHECKING +from abc import ABC, abstractmethod, abstractproperty +from typing import TYPE_CHECKING, NamedTuple + import numpy as np + +from mici.errors import (ConvergenceError, Error, HamiltonianDivergenceError, + IntegratorError, NonReversibleStepError) from mici.utils import LogRepFloat -from mici.errors import ( - Error, - IntegratorError, - NonReversibleStepError, - ConvergenceError, - HamiltonianDivergenceError, -) if TYPE_CHECKING: from typing import Optional + from numpy.random import Generator from numpy.typing import ArrayLike, DTypeLike + from mici.integrators import Integrator from mici.states import ChainState from mici.systems import System @@ -184,7 +183,7 @@ def sample( state.mom = self.system.sample_momentum(state, rng) elif self.mom_resample_coeff != 0: mom_ind = self.system.sample_momentum(state, rng) - state.mom *= (1.0 - self.mom_resample_coeff ** 2) ** 0.5 + state.mom *= (1.0 - self.mom_resample_coeff**2) ** 0.5 state.mom += self.mom_resample_coeff * mom_ind return state, None @@ -198,7 +197,7 @@ class IntegrationTransition(Transition): """ @property - def state_variables(self) -> Set[str]: + def state_variables(self) -> set[str]: return {"pos", "mom", "dir"} @property @@ -588,7 +587,9 @@ def _weight_function( @abstractmethod def _weight_ratio( - self, numerator: ScalarLike, denominator: ScalarLike, + self, + numerator: ScalarLike, + denominator: ScalarLike, ) -> ScalarLike: pass diff --git a/src/mici/types.py b/src/mici/types.py index 70270ab..a16b8fc 100644 --- a/src/mici/types.py +++ b/src/mici/types.py @@ -2,23 +2,16 @@ from __future__ import annotations -from typing import ( - Any, - Callable, - Collection, - Iterable, - TypeVar, - Union, -) -from numpy.typing import ArrayLike +from typing import Any, Callable, Collection, Iterable, TypeVar, Union + from numpy import number +from numpy.typing import ArrayLike -from mici.states import ChainState from mici.matrices import Matrix, PositiveDefiniteMatrix +from mici.states import ChainState from mici.systems import System from mici.utils import LogRepFloat - ScalarLike = Union[bool, int, float, LogRepFloat, number] MatrixLike = Union[ArrayLike, Matrix] MetricLike = Union[ArrayLike, PositiveDefiniteMatrix] diff --git a/src/mici/utils.py b/src/mici/utils.py index 14c2570..fb04830 100644 --- a/src/mici/utils.py +++ b/src/mici/utils.py @@ -2,9 +2,10 @@ from __future__ import annotations +from math import exp, expm1, inf, log, log1p, nan from typing import TYPE_CHECKING + import numpy as np -from math import log, exp, log1p, expm1, inf, nan try: import xxhash @@ -15,6 +16,7 @@ if TYPE_CHECKING: from typing import Optional + from mici.types import ScalarLike diff --git a/tests/test_adapters.py b/tests/test_adapters.py index a5bba6d..4c148e1 100644 --- a/tests/test_adapters.py +++ b/tests/test_adapters.py @@ -1,11 +1,12 @@ -import pytest import copy from collections.abc import Mapping -from types import MappingProxyType from math import exp +from types import MappingProxyType + import numpy as np -import mici +import pytest +import mici SEED = 3046987125 STATE_DIM = 10 @@ -126,7 +127,7 @@ def chain_state(self, rng): @pytest.fixture def system(self): return mici.systems.EuclideanMetricSystem( - neg_log_dens=lambda pos: np.sum(pos ** 2) / 2, + neg_log_dens=lambda pos: np.sum(pos**2) / 2, grad_neg_log_dens=lambda pos: pos, ) @@ -151,16 +152,16 @@ def adapter(self): @pytest.fixture def chain_state(self, rng): pos, mom = rng.standard_normal((2, STATE_DIM)) - pos /= np.sum(pos ** 2) ** 0.5 + pos /= np.sum(pos**2) ** 0.5 mom -= (mom @ pos) * pos return mici.states.ChainState(pos=pos, mom=mom, dir=1) @pytest.fixture def system(self): return mici.systems.DenseConstrainedEuclideanMetricSystem( - neg_log_dens=lambda pos: np.sum(pos ** 2) / 2, + neg_log_dens=lambda pos: np.sum(pos**2) / 2, grad_neg_log_dens=lambda pos: pos, - constr=lambda pos: np.sum(pos ** 2)[None] - 1, + constr=lambda pos: np.sum(pos**2)[None] - 1, jacob_constr=lambda pos: 2 * pos[None], ) @@ -182,7 +183,7 @@ def adapter(self): def system(self, rng): var = np.exp(rng.standard_normal(STATE_DIM)) return mici.systems.EuclideanMetricSystem( - neg_log_dens=lambda pos: np.sum(pos ** 2 / var) / 2, + neg_log_dens=lambda pos: np.sum(pos**2 / var) / 2, grad_neg_log_dens=lambda pos: pos / var, ) diff --git a/tests/test_integrators.py b/tests/test_integrators.py index 705c2e0..14f7792 100644 --- a/tests/test_integrators.py +++ b/tests/test_integrators.py @@ -1,11 +1,12 @@ -import pytest import numpy as np +import pytest + import mici.integrators as integrators -import mici.systems as systems import mici.matrices as matrices import mici.solvers as solvers -from mici.states import ChainState +import mici.systems as systems from mici.errors import IntegratorError +from mici.states import ChainState SEED = 3046987125 N_STEPS = {1, 5, 20} @@ -72,7 +73,6 @@ def _integrate_with_reversal(integrator, init_state, n_step): class IntegratorTests: - h_diff_tol = 5e-3 @pytest.mark.parametrize("n_step", N_STEPS) @@ -234,7 +234,7 @@ class LinearEuclideanMetricSystemTests(LinearSystemIntegratorTests): @pytest.fixture def system(self, metric): return systems.EuclideanMetricSystem( - neg_log_dens=lambda q: 0.5 * np.sum(q ** 2), + neg_log_dens=lambda q: 0.5 * np.sum(q**2), metric=metric, grad_neg_log_dens=lambda q: q, ) @@ -244,9 +244,9 @@ class NonLinearEuclideanMetricSystemTests(IntegratorTests): @pytest.fixture def system(self, metric): return systems.EuclideanMetricSystem( - neg_log_dens=lambda q: 0.25 * np.sum(q ** 4), + neg_log_dens=lambda q: 0.25 * np.sum(q**4), metric=metric, - grad_neg_log_dens=lambda q: q ** 3, + grad_neg_log_dens=lambda q: q**3, ) @@ -262,9 +262,9 @@ class NonLinearGaussianEuclideanMetricSystem(IntegratorTests): @pytest.fixture def system(self, metric): return systems.GaussianEuclideanMetricSystem( - neg_log_dens=lambda q: 0.125 * np.sum(q ** 4), + neg_log_dens=lambda q: 0.125 * np.sum(q**4), metric=metric, - grad_neg_log_dens=lambda q: 0.5 * q ** 3, + grad_neg_log_dens=lambda q: 0.5 * q**3, ) @@ -277,7 +277,6 @@ def integrator(self, system): class TestLeapfrogIntegratorLinearEuclideanMetricSystem( LeapfrogIntegratorTests, LinearEuclideanMetricSystemTests ): - h_diff_tol = 2e-3 step_size = 0.25 @@ -285,7 +284,6 @@ class TestLeapfrogIntegratorLinearEuclideanMetricSystem( class TestLeapfrogIntegratorNonLinearEuclideanMetricSystem( LeapfrogIntegratorTests, NonLinearEuclideanMetricSystemTests ): - h_diff_tol = 1e-3 step_size = 0.05 @@ -293,7 +291,6 @@ class TestLeapfrogIntegratorNonLinearEuclideanMetricSystem( class TestLeapfrogIntegratorLinearGaussianEuclideanMetricSystem( LeapfrogIntegratorTests, LinearGaussianEuclideanMetricSystem ): - h_diff_tol = 1e-10 step_size = 0.5 @@ -301,7 +298,6 @@ class TestLeapfrogIntegratorLinearGaussianEuclideanMetricSystem( class TestLeapfrogIntegratorNonLinearGaussianEuclideanMetricSystem( LeapfrogIntegratorTests, NonLinearGaussianEuclideanMetricSystem ): - h_diff_tol = 2e-3 step_size = 0.1 @@ -315,7 +311,6 @@ def integrator(self, system): class TestBCSSTwoStageIntegratorLinearEuclideanMetricSystem( BCSSTwoStageIntegratorTests, LinearEuclideanMetricSystemTests ): - h_diff_tol = 2e-4 step_size = 0.25 @@ -323,7 +318,6 @@ class TestBCSSTwoStageIntegratorLinearEuclideanMetricSystem( class TestBCSSTwoStageIntegratorNonLinearEuclideanMetricSystem( BCSSTwoStageIntegratorTests, NonLinearEuclideanMetricSystemTests ): - h_diff_tol = 1e-3 step_size = 0.05 @@ -331,15 +325,14 @@ class TestBCSSTwoStageIntegratorNonLinearEuclideanMetricSystem( class TestBCSSTwoStageIntegratorLinearGaussianEuclideanMetricSystem( BCSSTwoStageIntegratorTests, LinearGaussianEuclideanMetricSystem ): - h_diff_tol = 1e-10 step_size = 0.5 class TestBCSSTwoStageIntegratorNonLinearGaussianEuclideanMetricSystem( - BCSSTwoStageIntegratorTests, NonLinearGaussianEuclideanMetricSystem, + BCSSTwoStageIntegratorTests, + NonLinearGaussianEuclideanMetricSystem, ): - h_diff_tol = 2e-3 step_size = 0.1 @@ -353,7 +346,6 @@ def integrator(self, system): class TestBCSSThreeStageIntegratorLinearEuclideanMetricSystem( BCSSThreeStageIntegratorTests, LinearEuclideanMetricSystemTests ): - h_diff_tol = 5e-5 step_size = 0.25 @@ -361,7 +353,6 @@ class TestBCSSThreeStageIntegratorLinearEuclideanMetricSystem( class TestBCSSThreeStageIntegratorNonLinearEuclideanMetricSystem( BCSSThreeStageIntegratorTests, NonLinearEuclideanMetricSystemTests ): - h_diff_tol = 5e-4 step_size = 0.25 @@ -369,15 +360,14 @@ class TestBCSSThreeStageIntegratorNonLinearEuclideanMetricSystem( class TestBCSSThreeStageIntegratorLinearGaussianEuclideanMetricSystem( BCSSThreeStageIntegratorTests, LinearGaussianEuclideanMetricSystem ): - h_diff_tol = 1e-10 step_size = 0.5 class TestBCSSThreeStageIntegratorNonLinearGaussianEuclideanMetricSystem( - BCSSThreeStageIntegratorTests, NonLinearGaussianEuclideanMetricSystem, + BCSSThreeStageIntegratorTests, + NonLinearGaussianEuclideanMetricSystem, ): - h_diff_tol = 5e-4 step_size = 0.5 @@ -391,7 +381,6 @@ def integrator(self, system): class TestBCSSFourStageIntegratorLinearEuclideanMetricSystem( BCSSFourStageIntegratorTests, LinearEuclideanMetricSystemTests ): - h_diff_tol = 2e-5 step_size = 1.0 @@ -399,7 +388,6 @@ class TestBCSSFourStageIntegratorLinearEuclideanMetricSystem( class TestBCSSFourStageIntegratorNonLinearEuclideanMetricSystem( BCSSFourStageIntegratorTests, NonLinearEuclideanMetricSystemTests ): - h_diff_tol = 1e-3 step_size = 0.25 @@ -407,15 +395,14 @@ class TestBCSSFourStageIntegratorNonLinearEuclideanMetricSystem( class TestBCSSFourStageIntegratorLinearGaussianEuclideanMetricSystem( BCSSFourStageIntegratorTests, LinearGaussianEuclideanMetricSystem ): - h_diff_tol = 1e-10 step_size = 1.0 class TestBCSSFourStageIntegratorNonLinearGaussianEuclideanMetricSystem( - BCSSFourStageIntegratorTests, NonLinearGaussianEuclideanMetricSystem, + BCSSFourStageIntegratorTests, + NonLinearGaussianEuclideanMetricSystem, ): - h_diff_tol = 5e-4 step_size = 0.5 @@ -455,7 +442,6 @@ def integrator(self, system, fixed_point_solver, reverse_check_norm): class TestImplicitLeapfrogIntegratorLinearEuclideanMetricSystem( ImplicitLeapfrogIntegratorTests, LinearEuclideanMetricSystemTests ): - h_diff_tol = 5e-3 step_size = 0.25 @@ -463,7 +449,6 @@ class TestImplicitLeapfrogIntegratorLinearEuclideanMetricSystem( class TestImplicitMidpointIntegratorLinearEuclideanMetricSystem( ImplicitMidpointIntegratorTests, LinearEuclideanMetricSystemTests ): - h_diff_tol = 1e-7 step_size = 0.25 @@ -471,7 +456,6 @@ class TestImplicitMidpointIntegratorLinearEuclideanMetricSystem( class TestImplicitLeapfrogIntegratorNonLinearEuclideanMetricSystem( ImplicitLeapfrogIntegratorTests, NonLinearEuclideanMetricSystemTests ): - h_diff_tol = 6e-3 step_size = 0.1 @@ -479,7 +463,6 @@ class TestImplicitLeapfrogIntegratorNonLinearEuclideanMetricSystem( class TestImplicitMidpointIntegratorNonLinearEuclideanMetricSystem( ImplicitMidpointIntegratorTests, NonLinearEuclideanMetricSystemTests ): - h_diff_tol = 5e-3 step_size = 0.1 @@ -488,9 +471,9 @@ class NonLinearDiagonalRiemannianMetricSystemTests(IntegratorTests): @pytest.fixture def system(self): return systems.DiagonalRiemannianMetricSystem( - lambda q: np.sum(q ** 2) / 2 + np.sum(q ** 4) / 12, - grad_neg_log_dens=lambda q: q + q ** 3 / 3, - metric_diagonal_func=lambda q: 1 + q ** 2, + lambda q: np.sum(q**2) / 2 + np.sum(q**4) / 12, + grad_neg_log_dens=lambda q: q + q**3 / 3, + metric_diagonal_func=lambda q: 1 + q**2, vjp_metric_diagonal_func=lambda q: lambda m: 2 * m * q, ) @@ -498,7 +481,6 @@ def system(self): class TestImplicitLeapfrogIntegratorNonLinearDiagonalRiemannianMetricSystem( ImplicitLeapfrogIntegratorTests, NonLinearDiagonalRiemannianMetricSystemTests ): - h_diff_tol = 1e-3 step_size = 0.1 @@ -506,7 +488,6 @@ class TestImplicitLeapfrogIntegratorNonLinearDiagonalRiemannianMetricSystem( class TestImplicitMidpointIntegratorNonLinearDiagonalRiemannianMetricSystem( ImplicitMidpointIntegratorTests, NonLinearDiagonalRiemannianMetricSystemTests ): - h_diff_tol = 2e-4 step_size = 0.1 @@ -514,13 +495,12 @@ class TestImplicitMidpointIntegratorNonLinearDiagonalRiemannianMetricSystem( class TestConstrainedLeapfrogIntegratorLinearSystem( ConstrainedLinearSystemIntegratorTests ): - h_diff_tol = 1e-2 @pytest.fixture def integrator(self, metric): system = systems.DenseConstrainedEuclideanMetricSystem( - neg_log_dens=lambda q: 0.5 * np.sum(q ** 2), + neg_log_dens=lambda q: 0.5 * np.sum(q**2), metric=metric, grad_neg_log_dens=lambda q: q, constr=lambda q: q[:1], @@ -532,15 +512,14 @@ def integrator(self, metric): class TestConstrainedLeapfrogIntegratorNonLinearSystem( ConstrainedNonLinearSystemIntegratorTests ): - h_diff_tol = 1e-2 @pytest.fixture def system(self, metric): return systems.DenseConstrainedEuclideanMetricSystem( - neg_log_dens=lambda q: 0.125 * np.sum(q ** 4), + neg_log_dens=lambda q: 0.125 * np.sum(q**4), metric=metric, - grad_neg_log_dens=lambda q: 0.5 * q ** 3, + grad_neg_log_dens=lambda q: 0.5 * q**3, constr=lambda q: q[0:1] ** 2 + q[1:2] ** 2 - 1.0, jacob_constr=lambda q: np.concatenate( [2 * q[0:1], 2 * q[1:2], np.zeros(q.shape[0] - 2)] @@ -563,7 +542,6 @@ def integrator(self, system, request): class TestConstrainedLeapfrogIntegratorGaussianLinearSystem( ConstrainedLinearSystemIntegratorTests ): - h_diff_tol = 1e-4 @pytest.fixture @@ -582,15 +560,14 @@ def integrator(self, metric): class TestConstrainedLeapfrogIntegratorGaussianNonLinearSystem( ConstrainedNonLinearSystemIntegratorTests ): - h_diff_tol = 5e-2 @pytest.fixture def system(self, metric): return systems.GaussianDenseConstrainedEuclideanMetricSystem( - neg_log_dens=lambda q: 0.125 * np.sum(q ** 4), + neg_log_dens=lambda q: 0.125 * np.sum(q**4), metric=metric, - grad_neg_log_dens=lambda q: 0.5 * q ** 3, + grad_neg_log_dens=lambda q: 0.5 * q**3, constr=lambda q: q[0:1] ** 2 + q[1:2] ** 2 - 1.0, jacob_constr=lambda q: np.concatenate( [2 * q[0:1], 2 * q[1:2], np.zeros(q.shape[0] - 2)] diff --git a/tests/test_interop.py b/tests/test_interop.py index f1b33e0..7680530 100644 --- a/tests/test_interop.py +++ b/tests/test_interop.py @@ -1,5 +1,6 @@ -import pytest import numpy as np +import pytest + import mici try: diff --git a/tests/test_matrices.py b/tests/test_matrices.py index 9dba823..6c7448f 100644 --- a/tests/test_matrices.py +++ b/tests/test_matrices.py @@ -1,18 +1,20 @@ -import pytest -from itertools import product -from functools import partial, wraps, reduce from copy import copy, deepcopy +from functools import partial, reduce, wraps +from itertools import product + import numpy as np import numpy.linalg as nla -import scipy.linalg as sla import numpy.testing as npt +import pytest +import scipy.linalg as sla + import mici.matrices as matrices AUTOGRAD_AVAILABLE = True try: import autograd.numpy as anp from autograd import grad - from autograd.core import primitive, defvjp + from autograd.core import defvjp, primitive except ImportError: AUTOGRAD_AVAILABLE = False import warnings @@ -312,7 +314,6 @@ def test_sqrt_array(self, matrix, np_matrix): class DifferentiableMatrixTests(MatrixTests): - if AUTOGRAD_AVAILABLE: @pytest.fixture @@ -383,7 +384,6 @@ class TestScaledIdentityMatrix( ExplicitShapeSymmetricMatrixTests, ExplicitShapeInvertibleMatrixTests, ): - matrix_class = matrices.ScaledIdentityMatrix @staticmethod @@ -394,7 +394,6 @@ def generate_scalar(rng): class TestPositiveScaledIdentityMatrix( DifferentiableScaledIdentityMatrixTests, ExplicitShapePositiveDefiniteMatrixTests ): - matrix_class = matrices.PositiveScaledIdentityMatrix @staticmethod @@ -422,7 +421,6 @@ class TestDiagonalMatrix( ExplicitShapeSymmetricMatrixTests, ExplicitShapeInvertibleMatrixTests, ): - matrix_class = matrices.DiagonalMatrix @staticmethod @@ -433,7 +431,6 @@ def generate_diagonal(size, rng): class TestPositiveDiagonalMatrix( DifferentiableDiagonalMatrixTests, ExplicitShapePositiveDefiniteMatrixTests ): - matrix_class = matrices.PositiveDiagonalMatrix @staticmethod @@ -488,7 +485,6 @@ class TestTriangularFactoredDefiniteMatrix( ExplicitShapeSymmetricMatrixTests, ExplicitShapeInvertibleMatrixTests, ): - matrix_class = matrices.TriangularFactoredDefiniteMatrix @pytest.fixture(params=(+1, -1)) @@ -532,7 +528,6 @@ class TestDenseDefiniteMatrix( ExplicitShapeSymmetricMatrixTests, ExplicitShapeInvertibleMatrixTests, ): - matrix_class = matrices.DenseDefiniteMatrix @pytest.fixture(params=(+1, -1)) @@ -785,7 +780,6 @@ def matrix_pair(self, rng, size, n_block): if AUTOGRAD_AVAILABLE: - # Define new block_diag primitive and corresponding vector-Jacobian-product @primitive @@ -793,7 +787,6 @@ def block_diag(blocks): return sla.block_diag(*blocks) def vjp_block_diag(ans, blocks): - blocks = tuple(blocks) def vjp(g): diff --git a/tests/test_progressbars.py b/tests/test_progressbars.py index e8384e0..6512a0d 100644 --- a/tests/test_progressbars.py +++ b/tests/test_progressbars.py @@ -1,8 +1,10 @@ -import pytest -import mici from collections import OrderedDict from queue import SimpleQueue +import pytest + +import mici + def test_format_time(): assert mici.progressbars._format_time(100) == "01:40" @@ -60,4 +62,3 @@ def test_progress_bar_iter(progress_bar_and_sequence): for (val, iter_dict), val_orig in zip(progress_bar, sequence): assert val == val_orig assert isinstance(iter_dict, dict) - diff --git a/tests/test_samplers.py b/tests/test_samplers.py index aabb4c2..65d6717 100644 --- a/tests/test_samplers.py +++ b/tests/test_samplers.py @@ -1,7 +1,7 @@ -import pytest import numpy as np -import mici +import pytest +import mici SEED = 3046987125 STATE_DIM = 2 @@ -14,7 +14,7 @@ def rng(): def neg_log_dens(pos): - return np.sum(pos ** 2) / 2 + return np.sum(pos**2) / 2 def grad_neg_log_dens(pos): @@ -333,7 +333,6 @@ def check_stats_dict(self, stats, n_iter, n_chain, transitions): class TestStaticMetropolisHMC(HamiltonianMCMCTests): - n_step = 2 def test_max_tree_depth(self, sampler): @@ -351,7 +350,6 @@ def sampler(self, integrator, system, rng): class TestRandomMetropolisHMC(HamiltonianMCMCTests): - n_step_range = (1, 3) def test_max_tree_depth(self, sampler): @@ -374,7 +372,6 @@ def sampler(self, integrator, system, rng): class DynamicHMCTests(HamiltonianMCMCTests): - max_tree_depth = 2 max_delta_h = 1000 diff --git a/tests/test_solvers.py b/tests/test_solvers.py index 758cf9b..ab11333 100644 --- a/tests/test_solvers.py +++ b/tests/test_solvers.py @@ -1,7 +1,9 @@ +from collections import namedtuple + +import numpy as np import pytest + import mici -import numpy as np -from collections import namedtuple @pytest.fixture(params=("direct", "steffensen")) @@ -21,10 +23,10 @@ def fixed_point_solver(request): def convergent_fixed_point_problem(request): if request.param == "babylonian": y = np.array([3.0, 5.0, 7.0]) - return FixedPointProblem(lambda x: (y / x + x) / 2, y ** 0.5, np.ones_like(y)) + return FixedPointProblem(lambda x: (y / x + x) / 2, y**0.5, np.ones_like(y)) elif request.param == "ratio": y = np.array([3.0, 5.0, 7.0]) - return FixedPointProblem(lambda x: (x + y) / (x + 1), y ** 0.5, np.ones_like(y)) + return FixedPointProblem(lambda x: (x + y) / (x + 1), y**0.5, np.ones_like(y)) elif request.param == "cosine": return FixedPointProblem( lambda x: np.cos(x), np.array([0.7390851332151607]), np.array([1.0]) @@ -36,9 +38,7 @@ def divergent_fixed_point_problem(request): if request.param == "doubling": return FixedPointProblem(lambda x: 2 * x, None, np.arange(3)) elif request.param == "quadratic": - return FixedPointProblem( - lambda x: 1 + x**2, None, np.arange(3) - ) + return FixedPointProblem(lambda x: 1 + x**2, None, np.arange(3)) @pytest.fixture(params=(1e-6, 1e-8, 1e-10)) @@ -106,4 +106,3 @@ def func(x): with pytest.raises(mici.errors.ConvergenceError): fixed_point_solver(func=func, x0=np.array([1.0])) - diff --git a/tests/test_stagers.py b/tests/test_stagers.py index 838cbf3..0262dd9 100644 --- a/tests/test_stagers.py +++ b/tests/test_stagers.py @@ -1,4 +1,5 @@ import pytest + import mici N_MAIN_ITER = 1 diff --git a/tests/test_states.py b/tests/test_states.py index 9c5d837..832a06a 100644 --- a/tests/test_states.py +++ b/tests/test_states.py @@ -1,8 +1,10 @@ -import pytest -from unittest.mock import Mock -from itertools import combinations import pickle +from itertools import combinations +from unittest.mock import Mock + import numpy as np +import pytest + import mici diff --git a/tests/test_utils.py b/tests/test_utils.py index c628064..54a21b2 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,7 +1,7 @@ -import pytest import numpy as np -import mici +import pytest +import mici SEED = 3046987125 @@ -41,7 +41,6 @@ def get_val(obj): class TestLogRepFloat: - VALS = sorted((0.0, 0.9, 1, 1.1, 2.1, 120.0)) FIXED = VALS[1]