Merge branch 'master' into credibility-interval
williamjameshandley committed Feb 2, 2025
2 parents 3e6bd4b + 42b1ead commit e435542
Showing 8 changed files with 126 additions and 27 deletions.
8 changes: 7 additions & 1 deletion .github/workflows/CI.yaml
@@ -44,6 +44,7 @@ jobs:
- name: Upgrade pip and install doc requirements
run: |
python -m pip install --upgrade pip
python -m pip install numpy~=1.0
python -m pip install -e ".[all,docs]"
- name: build documentation
run: |
@@ -172,7 +173,12 @@ jobs:
run: python -m pip install -e ".[test]"

- name: Test with pytest
run: python -m pytest tests
run: python -m pytest --cov=anesthetic tests

- name: Upload coverage to Codecov
uses: codecov/codecov-action@v4
with:
token: ${{ secrets.CODECOV_TOKEN }}

latest-dependencies:
runs-on: ubuntu-latest
2 changes: 1 addition & 1 deletion README.rst
@@ -2,7 +2,7 @@
anesthetic: nested sampling post-processing
===========================================
:Authors: Will Handley and Lukas Hergt
:Version: 2.9.0
:Version: 2.9.1
:Homepage: https://github.com/handley-lab/anesthetic
:Documentation: http://anesthetic.readthedocs.io/

2 changes: 1 addition & 1 deletion anesthetic/_version.py
@@ -1 +1 @@
__version__ = '2.9.0'
__version__ = '2.9.1'
13 changes: 11 additions & 2 deletions anesthetic/plot.py
@@ -9,6 +9,7 @@
to create a set of axes and legend proxies.
"""
from packaging import version
import numpy as np
from pandas import Series, DataFrame
import matplotlib.pyplot as plt
@@ -839,7 +840,11 @@ def fastkde_plot_1d(ax, data, *args, **kwargs):
p /= p.max()
i = ((x > quantile(x, q[0], p)) & (x < quantile(x, q[-1], p)))

area = np.trapz(x=x[i], y=p[i]) if density else 1
if version.parse(np.__version__) >= version.parse("2.0.0"):
trapezoid = np.trapezoid
else:
trapezoid = np.trapz
area = trapezoid(x=x[i], y=p[i]) if density else 1
if ax.get_xaxis().get_scale() == 'log':
x = 10**x
ans = ax.plot(x[i], p[i]/area, color=color, *args, **kwargs)
@@ -962,7 +967,11 @@ def kde_plot_1d(ax, data, *args, **kwargs):
bw = np.sqrt(kde.covariance[0, 0])
pp = cut_and_normalise_gaussian(x, p, bw, xmin=data.min(), xmax=data.max())
pp /= pp.max()
area = np.trapz(x=x, y=pp) if density else 1
if version.parse(np.__version__) >= version.parse("2.0.0"):
trapezoid = np.trapezoid
else:
trapezoid = np.trapz
area = trapezoid(x=x, y=pp) if density else 1
if ax.get_xaxis().get_scale() == 'log':
x = 10**x
ans = ax.plot(x, pp/area, color=color, *args, **kwargs)
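Both plot.py hunks above apply the same NumPy 2.0 compatibility pattern. As a standalone sketch (the helper name _trapezoid and the function integrate_density are illustrative, not part of this commit), the version-gated alias can be written once and reused:

from packaging import version
import numpy as np

# NumPy >= 2.0 provides np.trapezoid; older releases only have np.trapz.
if version.parse(np.__version__) >= version.parse("2.0.0"):
    _trapezoid = np.trapezoid
else:
    _trapezoid = np.trapz

def integrate_density(x, p):
    """Integrate a sampled density p(x) with whichever trapezoidal rule is available."""
    return _trapezoid(y=p, x=x)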
57 changes: 43 additions & 14 deletions anesthetic/samples.py
@@ -761,7 +761,7 @@ def ns_output(self, *args, **kwargs):
" as well as average loglikelihoods: help(samples.logL_P)"
)

def stats(self, nsamples=None, beta=None):
def stats(self, nsamples=None, beta=None, norm=None):
r"""Compute Nested Sampling statistics.
Using nested sampling we can compute:
@@ -822,20 +822,27 @@ def stats(self, nsamples=None, beta=None):
beta : float, array-like, optional
inverse temperature(s) beta=1/kT. Default self.beta
norm : Series, :class:`Samples`, optional
:meth:`NestedSamples.stats` output used for normalisation.
Can be either a Series of mean values or Samples produced with
matching `nsamples` and `beta`. In addition to the columns
['logZ', 'D_KL', 'logL_P', 'd_G'], this adds the normalised
versions ['Delta_logZ', 'Delta_D_KL', 'Delta_logL_P', 'Delta_d_G'].
Returns
-------
if beta is scalar and nsamples is None:
Series, index ['logZ', 'd_G', 'DK_L', 'logL_P']
Series, index ['logZ', 'd_G', 'D_KL', 'logL_P']
elif beta is scalar and nsamples is int:
:class:`Samples`, index range(nsamples),
columns ['logZ', 'd_G', 'DK_L', 'logL_P']
columns ['logZ', 'd_G', 'D_KL', 'logL_P']
elif beta is array-like and nsamples is None:
:class:`Samples`, index beta,
columns ['logZ', 'd_G', 'DK_L', 'logL_P']
columns ['logZ', 'd_G', 'D_KL', 'logL_P']
elif beta is array-like and nsamples is int:
:class:`Samples`, index :class:`pandas.MultiIndex` the product of
beta and range(nsamples)
columns ['logZ', 'd_G', 'DK_L', 'logL_P']
columns ['logZ', 'd_G', 'D_KL', 'logL_P']
"""
logw = self.logw(nsamples, beta)
if nsamples is None and beta is None:
@@ -861,6 +868,26 @@ def stats(self, nsamples=None, beta=None):
samples.set_label('d_G', r'$d_\mathrm{G}$')

samples.label = self.label

if norm is not None:
samples['Delta_logZ'] = samples['logZ'] - norm['logZ']
samples.set_label('Delta_logZ',
r"$\Delta\ln\mathcal{Z}$")

samples['Delta_D_KL'] = samples['D_KL'] - norm['D_KL']
samples.set_label('Delta_D_KL',
r"$\Delta\mathcal{D}_\mathrm{KL}$")

samples['Delta_logL_P'] = samples['logL_P'] - norm['logL_P']
samples.set_label(
'Delta_logL_P',
r"$\Delta\langle\ln\mathcal{L}\rangle_\mathcal{P}$"
)

samples['Delta_d_G'] = samples['d_G'] - norm['d_G']
samples.set_label('Delta_d_G',
r"$\Delta d_\mathrm{G}$")

return samples

def logX(self, nsamples=None):
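A minimal usage sketch of the new norm keyword, assuming the example chains shipped with the test suite (variable names are illustrative):

from anesthetic import read_chains

run = read_chains("./tests/example_data/pc")   # chain root taken from the tests
reference = run.stats()                        # Series with logZ, D_KL, logL_P, d_G

# Passing a reference via norm adds Delta_logZ, Delta_D_KL, Delta_logL_P
# and Delta_d_G columns alongside the usual outputs.
stats = run.stats(nsamples=100, norm=reference)
print(stats['Delta_logZ'].mean(), stats['Delta_logZ'].std())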
@@ -1302,15 +1329,17 @@ def recompute(self, logL_birth=None, inplace=False):
n_bad = invalid.sum()
n_equal = (samples.logL == samples.logL_birth).sum()
if n_bad:
warnings.warn("%i out of %i samples have logL <= logL_birth,"
"\n%i of which have logL == logL_birth."
"\nThis may just indicate numerical rounding "
"errors at the peak of the likelihood, but "
"further investigation of the chains files is "
"recommended."
"\nDropping the invalid samples." %
(n_bad, len(samples), n_equal),
RuntimeWarning)
n_inf = ((samples.logL == samples.logL_birth) &
(samples.logL == -np.inf)).sum()
if n_bad > n_inf:
warnings.warn(
"%i out of %i samples have logL <= logL_birth,\n"
"%i of which have logL == logL_birth.\n"
"This may just indicate numerical rounding errors at "
"the peak of the likelihood, but further "
"investigation of the chains files is recommended.\n"
"Dropping the invalid samples."
% (n_bad, len(samples), n_equal), RuntimeWarning)
samples = samples[~invalid].reset_index(drop=True)

samples.sort_values('logL', inplace=True)
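The rewritten recompute check warns only when some invalid rows are not the benign logL == logL_birth == -inf case. A standalone sketch of that counting logic (NumPy only, function name illustrative):

import numpy as np

def should_warn(logL, logL_birth):
    """Warn only if some invalid rows are not the benign logL == logL_birth == -inf case."""
    logL = np.asarray(logL, dtype=float)
    logL_birth = np.asarray(logL_birth, dtype=float)
    invalid = logL <= logL_birth                            # n_bad in the hunk above
    both_inf = (logL == logL_birth) & (logL == -np.inf)     # n_inf in the hunk above
    return invalid.sum() > both_inf.sum()

assert not should_warn([-np.inf, 1.0], [-np.inf, 0.5])  # only the benign -inf case: no warning
assert should_warn([0.2, 1.0], [0.3, 0.5])              # a genuinely invalid sample: warn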
6 changes: 3 additions & 3 deletions pyproject.toml
@@ -32,10 +32,10 @@ readme = "README.rst"
license = {file = "LICENSE"}
requires-python = ">=3.8"
dependencies = [
"scipy",
"numpy",
"scipy<2.0.0",
"numpy>=1.26.0,<3.0.0",
"pandas~=2.2.0",
"matplotlib>=3.6.1,<3.9.0",
"matplotlib>=3.6.1,<3.10.0",
]
classifiers = [
"Programming Language :: Python :: 3",
55 changes: 55 additions & 0 deletions tests/test_samples.py
@@ -946,39 +946,93 @@ def test_stats():
beta = [0., 0.5, 1.]

vals = ['logZ', 'D_KL', 'logL_P', 'd_G']
delta_vals = ['Delta_logZ', 'Delta_D_KL', 'Delta_logL_P', 'Delta_d_G']

labels = [r'$\ln\mathcal{Z}$',
r'$\mathcal{D}_\mathrm{KL}$',
r'$\langle\ln\mathcal{L}\rangle_\mathcal{P}$',
r'$d_\mathrm{G}$']
delta_labels = [r'$\Delta\ln\mathcal{Z}$',
r'$\Delta\mathcal{D}_\mathrm{KL}$',
r'$\Delta\langle\ln\mathcal{L}\rangle_\mathcal{P}$',
r'$\Delta d_\mathrm{G}$']

stats = pc.stats()
assert isinstance(stats, WeightedLabelledSeries)
assert_array_equal(stats.drop_labels().index, vals)
assert_array_equal(stats.get_labels(), labels)

stats = pc.stats(norm=pc.stats())
assert isinstance(stats, WeightedLabelledSeries)
assert_array_equal(stats.drop_labels().index, vals + delta_vals)
assert_array_equal(stats.get_labels(), labels + delta_labels)

stats = pc.stats(nsamples=nsamples)
assert isinstance(stats, WeightedLabelledDataFrame)
assert_array_equal(stats.drop_labels().columns, vals)
assert_array_equal(stats.get_labels(), labels)
assert stats.index.name == 'samples'
assert_array_equal(stats.index, range(nsamples))

stats = pc.stats(nsamples=nsamples, norm=pc.stats())
assert isinstance(stats, WeightedLabelledDataFrame)
assert_array_equal(stats.drop_labels().columns, vals + delta_vals)
assert_array_equal(stats.get_labels(), labels + delta_labels)
assert stats.index.name == 'samples'
assert_array_equal(stats.index, range(nsamples))

stats = pc.stats(nsamples=nsamples, norm=pc.stats(nsamples=nsamples))
assert isinstance(stats, WeightedLabelledDataFrame)
assert_array_equal(stats.drop_labels().columns, vals + delta_vals)
assert_array_equal(stats.get_labels(), labels + delta_labels)
assert stats.index.name == 'samples'
assert_array_equal(stats.index, range(nsamples))

stats = pc.stats(beta=beta)
assert isinstance(stats, WeightedLabelledDataFrame)
assert_array_equal(stats.drop_labels().columns, vals)
assert_array_equal(stats.get_labels(), labels)
assert stats.index.name == 'beta'
assert_array_equal(stats.index, beta)

stats = pc.stats(beta=beta, norm=pc.stats())
assert isinstance(stats, WeightedLabelledDataFrame)
assert_array_equal(stats.drop_labels().columns, vals + delta_vals)
assert_array_equal(stats.get_labels(), labels + delta_labels)
assert stats.index.name == 'beta'
assert_array_equal(stats.index, beta)

stats = pc.stats(beta=beta, norm=pc.stats(beta=beta))
assert isinstance(stats, WeightedLabelledDataFrame)
assert_array_equal(stats.drop_labels().columns, vals + delta_vals)
assert_array_equal(stats.get_labels(), labels + delta_labels)
assert stats.index.name == 'beta'
assert_array_equal(stats.index, beta)

stats = pc.stats(nsamples=nsamples, beta=beta)
assert isinstance(stats, WeightedLabelledDataFrame)
assert_array_equal(stats.drop_labels().columns, vals)
assert_array_equal(stats.get_labels(), labels)
assert stats.index.names == ['beta', 'samples']
assert stats.index.levshape == (len(beta), nsamples)

stats = pc.stats(nsamples=nsamples, beta=beta, norm=pc.stats())
assert isinstance(stats, WeightedLabelledDataFrame)
assert_array_equal(stats.drop_labels().columns, vals + delta_vals)
assert_array_equal(stats.get_labels(), labels + delta_labels)
assert stats.index.names == ['beta', 'samples']
assert stats.index.levshape == (len(beta), nsamples)

stats = pc.stats(nsamples=nsamples, beta=beta,
norm=pc.stats(nsamples=nsamples, beta=beta))
assert isinstance(stats, WeightedLabelledDataFrame)
assert_array_equal(stats.drop_labels().columns, vals + delta_vals)
assert_array_equal(stats.get_labels(), labels + delta_labels)
assert stats.index.names == ['beta', 'samples']
assert stats.index.levshape == (len(beta), nsamples)

for beta in [1., 0., 0.5]:
np.random.seed(42)
pc.beta = beta
n = 1000
PC = pc.stats(n, beta)
@@ -1133,6 +1187,7 @@ def test_beta():
def test_beta_with_logL_infinities():
ns = read_chains("./tests/example_data/pc")
ns.loc[:10, ('logL', r'$\ln\mathcal{L}$')] = -np.inf
ns.loc[1000, ('logL', r'$\ln\mathcal{L}$')] = -np.inf
with pytest.warns(RuntimeWarning):
ns.recompute(inplace=True)
assert (ns.logL == -np.inf).sum() == 0
10 changes: 5 additions & 5 deletions tests/utils.py
@@ -1,12 +1,12 @@
from importlib.util import find_spec
import pytest
import sys

try:
import astropy # noqa: F401
except ImportError:
pass

condition = 'astropy' not in sys.modules
condition = find_spec('astropy') is None
reason = "requires astropy package"
raises = ImportError
astropy_mark_skip = pytest.mark.skipif(condition, reason=reason)
@@ -22,7 +22,7 @@ def skipif_no_astropy(param):
except ImportError:
pass
reason = "requires fastkde package"
condition = 'fastkde' not in sys.modules
condition = find_spec('fastkde') is None
raises = ImportError
fastkde_mark_skip = pytest.mark.skipif(condition, reason=reason)
fastkde_mark_xfail = pytest.mark.xfail(condition, raises=raises, reason=reason)
@@ -36,7 +36,7 @@ def skipif_no_fastkde(param):
import getdist # noqa: F401
except ImportError:
pass
condition = 'getdist' not in sys.modules
condition = find_spec('getdist') is None
reason = "requires getdist package"
raises = ImportError
getdist_mark_skip = pytest.mark.skipif(condition, reason=reason)
@@ -65,7 +65,7 @@ def skipif_no_getdist(param):
except ImportError:
pass

condition = 'h5py' not in sys.modules
condition = find_spec('h5py') is None
reason = "requires h5py package"
raises = ImportError
h5py_mark_skip = pytest.mark.skipif(condition, reason=reason)
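The tests/utils.py change swaps the "was it imported?" check (probing sys.modules after a guarded import) for importlib.util.find_spec, which reports availability without executing the module and is independent of import order. A sketch of the resulting skip/xfail pattern for a hypothetical optional package somepkg:

from importlib.util import find_spec
import pytest

condition = find_spec('somepkg') is None   # True when the package is not installed
reason = "requires somepkg package"

somepkg_mark_skip = pytest.mark.skipif(condition, reason=reason)
somepkg_mark_xfail = pytest.mark.xfail(condition, raises=ImportError, reason=reason)

@somepkg_mark_skip
def test_feature_needing_somepkg():
    import somepkg  # noqa: F401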
