Skip to content

Commit

Permalink
Merge pull request #358 from UCL-CCS/fix_MC_for_one_param
Browse files Browse the repository at this point in the history
Fix MCSampler for 1D problems
  • Loading branch information
orbitfold authored Oct 21, 2021
2 parents 78ca8fb + 43c141b commit fac0b57
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 0 deletions.
4 changes: 4 additions & 0 deletions easyvvuq/sampling/mc_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,12 @@ def saltelli(self, n_mc):
# number of different sampling matrices
step = self.n_params + 2
# store M2 first, with entries separated by step places
if M_2.ndim == 1:
M_2 = M_2.reshape([-1, 1])
self.xi_mc[0:self.max_num:step] = M_2
# store M1 entries last
if M_1.ndim == 1:
M_1 = M_1.reshape([-1, 1])
self.xi_mc[(step - 1):self.max_num:step] = M_1
# store N_i entries between M2 and M1
for i in range(self.n_params):
Expand Down
24 changes: 24 additions & 0 deletions tests/test_sampling_mc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import pytest
import chaospy as cp
from easyvvuq.sampling import MCSampler
from easyvvuq.sampling.base import Vary


def test_sampling():
vary = {'a': cp.Uniform(-5, 0), 'b': cp.Uniform(2, 10)}
sampler = MCSampler(vary, 100)
assert(sampler.n_samples() == 400)
for _ in range(sampler.n_samples()):
sample = next(sampler)
assert(sample['a'] >= -5 and sample['a'] <= 0)
assert(sample['b'] >= 2 and sample['b'] <= 10)
with pytest.raises(StopIteration):
next(sampler)


def test_sampling_1D():
vary = {'a': cp.Uniform(-1, 1)}
sampler = MCSampler(vary, 100)
# This used to fail in the saltelli subroutine if there was only 1 input
for _ in range(sampler.n_samples()):
sample = next(sampler)

0 comments on commit fac0b57

Please sign in to comment.