Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add random state for reproducibility #86

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 26 additions & 11 deletions swyft/bounds.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,15 @@
class Bound(StateDictSaveable):
"""A bound region on the hypercube.

.. note::
Notes:
The Bound object provides methods to sample from subregions of the
hypercube, to evaluate the volume of the constrained region, and to
evaluate the bound.
"""


rng = np.random.default_rng()

def __init__(self) -> None:
pass

Expand All @@ -38,6 +41,10 @@ def n_parameters(self) -> int:
"""Number of dimensions."""
raise NotImplementedError

def set_seed(self, seed):
"""Set seed for random number generator."""
self.rng = np.random.default_rng(seed=seed)

def sample(self, n_samples: int) -> np.ndarray:
"""Sample.

Expand All @@ -46,6 +53,10 @@ def sample(self, n_samples: int) -> np.ndarray:

Returns:
s (n_samples x n_parameters)

Notes:
Overriding methods should generate random numbers using `self.rng`,
which can be seeded via the method `Bound.set_seed`.
"""
raise NotImplementedError

Expand Down Expand Up @@ -122,7 +133,9 @@ def sample(self, n_samples):
Args:
n_samples (int): Number of samples
"""
return np.random.rand(n_samples, self.n_parameters)

return self.rng.random((n_samples, self.n_parameters))


def __call__(self, u):
"""Evaluate bound.
Expand Down Expand Up @@ -167,7 +180,7 @@ def n_parameters(self):
return len(self._rec_bounds)

def sample(self, n_samples):
u = np.random.rand(n_samples, self.n_parameters)
u = self.rng.random((n_samples, self.n_parameters))
for i in range(self.n_parameters):
u[:, i] *= self._rec_bounds[i, 1] - self._rec_bounds[i, 0]
u[:, i] += self._rec_bounds[i, 0]
Expand Down Expand Up @@ -200,7 +213,7 @@ def __init__(self, points, scale=1.0):
self._n_parameters = self.X.shape[-1]
self.bt = BallTree(self.X, leaf_size=2)
self.epsilon = self._set_epsilon(self.X, self.bt, scale)
self._volume = self._get_volume(self.X, self.epsilon, self.bt)
self._volume = self._get_volume(self.X, self.epsilon, self.bt, self.rng)

@property
def volume(self) -> float:
Expand All @@ -219,16 +232,16 @@ def _set_epsilon(X, bt, scale):
return epsilon

@staticmethod
def _get_volume(X, epsilon, bt):
def _get_volume(X, epsilon, bt, rng):
n_samples = 100
vol_est = []
d = X.shape[-1]
area = {1: 2 * epsilon, 2: np.pi * epsilon ** 2}[d]
for i in range(n_samples):
n = np.random.randn(*X.shape)
n = rng.standard_normal(X.shape)
norm = (n ** 2).sum(axis=1) ** 0.5
n = n / norm.reshape(-1, 1)
r = np.random.rand(len(X)) ** (1 / d) * epsilon
r = rng.random(len(X)) ** (1 / d) * epsilon
Y = X + n * r.reshape(-1, 1)
in_bounds = ((Y >= 0.0) & (Y <= 1.0)).prod(axis=1, dtype="bool")
Y = Y[in_bounds]
Expand All @@ -245,22 +258,24 @@ def sample(self, n_samples):
counter = 0
samples = []
d = self.X.shape[-1]

while counter < n_samples:
n = np.random.randn(*self.X.shape)
n = self.rng.standard_normal(self.X.shape)
norm = (n ** 2).sum(axis=1) ** 0.5
n = n / norm.reshape(-1, 1)
r = np.random.rand(len(self.X)) ** (1 / d) * self.epsilon
r = self.rng.random(len(self.X)) ** (1 / d) * self.epsilon
Y = self.X + n * r.reshape(-1, 1)
in_bounds = ((Y >= 0.0) & (Y <= 1.0)).prod(axis=1, dtype="bool")
Y = Y[in_bounds]
counts = self.bt.query_radius(Y, r=self.epsilon, count_only=True)
p = 1.0 / counts
w = np.random.rand(len(p))
w = self.rng.random(len(p))
Y = Y[p >= w]
samples.append(Y)
counter += len(Y)
samples = np.vstack(samples)
ind = np.random.choice(range(len(samples)), size=n_samples, replace=False)
ind = self.rng.choice(range(len(samples)), size=n_samples, replace=False)

return samples[ind]

def __call__(self, u):
Expand Down
30 changes: 30 additions & 0 deletions tests/bound_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import numpy as np
from swyft.prior import get_uniform_prior
from swyft.bounds import UnitCubeBound, RectangleBound, BallsBound
from swyft import PriorTruncator

# Define a prior with 3 parameters
low = np.zeros(3)
high = np.array([0.7, 1.0, 0.5])
prior = get_uniform_prior(low, high)


class TestBoundSeed:
def test_UnitCubeBound_seed(self):
bound1 = UnitCubeBound(3)
bound1.set_seed(1234)
pdf1 = PriorTruncator(prior, bound1)
bound2 = UnitCubeBound(3)
bound2.set_seed(1234)
pdf2 = PriorTruncator(prior, bound2)
assert np.all(pdf1.sample(10) == pdf2.sample(10))

def test_RectangleBound_seed(self):
bound1 = RectangleBound(np.stack((low, high), axis=1))
bound1.set_seed(1234)
pdf1 = PriorTruncator(prior, bound1)
bound2 = RectangleBound(np.stack((low, high), axis=1))
bound2.set_seed(1234)
pdf2 = PriorTruncator(prior, bound2)
assert np.all(pdf1.sample(10) == pdf2.sample(10))

22 changes: 22 additions & 0 deletions tests/store_test.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from re import A
import tempfile
from pathlib import Path

Expand All @@ -8,6 +9,7 @@
from swyft.prior import get_uniform_prior
from swyft.store.simulator import DaskSimulator, SimulationStatus, Simulator
from swyft.store.store import DirectoryStore, MemoryStore
from swyft.bounds import UnitCubeBound


def model(params):
Expand Down Expand Up @@ -98,6 +100,26 @@ def test_store_add(self):
store.add(20, prior)
assert store.sims.x1.shape[0] > 0

def test_store_add_with_bound_seed(self):
store1 = MemoryStore(simulator=sim_multi_out)
store2 = MemoryStore(simulator=sim_multi_out)
bound = UnitCubeBound(2)
bound.set_seed(1234)
store1.add(20, prior, bound)
bound.set_seed(1234)
store2.add(20, prior, bound)
ind = np.min((store1.v.shape[0], store2.v.shape[0]))
assert np.all(store1.v[:ind] == store2.v[:ind])

def test_store_add_without_bound_seed(self):
store1 = MemoryStore(simulator=sim_multi_out)
store2 = MemoryStore(simulator=sim_multi_out)
bound = UnitCubeBound(2)
store1.add(20, prior, bound)
store2.add(20, prior, bound)
ind = np.min((store1.v.shape[0], store2.v.shape[0]))
assert not (np.all(store1.v[:ind] == store2.v[:ind]))

def test_memory_store_simulate(self):
store = MemoryStore(simulator=sim_multi_out)
indices = store.sample(100, prior, add=True)
Expand Down