Skip to content

Commit

Permalink
Add changes based on second review
Browse files Browse the repository at this point in the history
  • Loading branch information
k-shep committed Mar 28, 2024
1 parent 513891e commit 9be1b06
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 49 deletions.
94 changes: 48 additions & 46 deletions pints/_log_likelihoods.py
Original file line number Diff line number Diff line change
Expand Up @@ -865,11 +865,11 @@ def evaluateS1(self, x):
class CensoredGaussianLogLikelihood(pints.ProblemLogLikelihood):
r""" Calculates a log-likelihood assuming independent Gaussian noise at
each time point, and that data above and below certain limits are censored.
In other words, any data values less than the lower limit are recorded as
equal to it; and any values greater than the upper limit are recorded as
equal to it. This likelihood is useful for censored data - see for
instance [3]_. The parameter sigma represents the standard deviation
of the noise on each output.
In other words, any data values less than the lower limit are only known to
be less than or equal to it; and any values greater than the upper limit
are only known to be greater than or equal to it. This likelihood is useful
for censored data - see for instance [3]_. The parameter sigma represents
the standard deviation of the noise on each output.
For a noise level of ``sigma``, and left and right censoring (below the
``lower`` limit cutoff and above the ``upper`` limit), the likelihood
Expand Down Expand Up @@ -936,8 +936,6 @@ class CensoredGaussianLogLikelihood(pints.ProblemLogLikelihood):
The lower limit for censoring.
upper
The upper limit for censoring.
verbose
When True, the input data and the values that are censored are printed.
References
----------
Expand All @@ -947,7 +945,7 @@ class CensoredGaussianLogLikelihood(pints.ProblemLogLikelihood):
"""

def __init__(self, problem, lower=None, upper=None, verbose=True):
def __init__(self, problem, lower=None, upper=None):
super(CensoredGaussianLogLikelihood, self).__init__(problem)

# Get number of times, number of outputs
Expand Down Expand Up @@ -978,13 +976,16 @@ def __init__(self, problem, lower=None, upper=None, verbose=True):
self._a = a
self._b = b

# Define the condition for whether a point is not censored
self._condition = (self._a < self._values
) & (self._values < self._b)
# Define the conditions for whether a point is lower censored,
# upper censored and not censored
self._lower_condition = self._values <= self._a
self._upper_condition = self._values >= self._b
self._not_censored_condition = (self._a < self._values
) & (self._values < self._b)

# Number of points that aren't censored (for each observation
# in the multioutput case)
self._n_not_censored = np.sum(self._condition, axis=0)
self._n_not_censored = np.sum(self._not_censored_condition, axis=0)

def _convert_type(self, limit, limit_type="lower"):

Expand All @@ -997,25 +998,13 @@ def _convert_type(self, limit, limit_type="lower"):
limit = np.inf

# Convert the limit to an object of the correct type
if limit is not None:
if np.isscalar(limit):
limit = np.ones(self._no) * float(limit)
else:
limit = pints.vector(limit)
if np.isscalar(limit):
limit = np.ones(self._no) * float(limit)
else:
limit = pints.vector(limit)

return limit

def print_censored_values(self):

# Print data and values that are censored
print("The data are {}. \n The lower censored values"
" are {}. \n The upper censored values"
" are {}.".format(self._values,
np.extract(self._values <= self._a,
self._values),
np.extract(self._values >= self._b,
self._values)))

def __call__(self, x):
theta = np.asarray(x[:-self._no])
sigma = np.asarray(x[-self._no:])
Expand All @@ -1027,16 +1016,16 @@ def __call__(self, x):
output = self._problem.evaluate(theta)

squared_error = np.sum((self._values - output)**2,
axis=0, where=self._condition)
axis=0, where=self._not_censored_condition)

# Calculate part of the likelihood corresponding to the censored data
lower_censored_sum = np.sum(np.log(
scipy.stats.norm.cdf(x=self._a, loc=output, scale=sigma)),
where=self._values <= self._a)
where=self._lower_condition)
upper_censored_sum = np.sum(
np.log(1 - scipy.stats.norm.
cdf(x=self._b, loc=output, scale=sigma)),
where=self._values >= self._b)
where=self._upper_condition)

# Calculate part of the likelihood corresponding to
# the data that isn't censored
Expand Down Expand Up @@ -1073,25 +1062,26 @@ def evaluateS1(self, x):
# Make conditions for where data isn't censored and is lower/upper
# censored into 3D arrays

lower_condition = self._values <= self._a
upper_condition = self._values >= self._b

if self._values.ndim == 1:
where_condition = np.reshape(
self._condition, newshape=(np.shape(self._condition)[0], 1, 1))
self._not_censored_condition,
newshape=(np.shape(self._not_censored_condition)[0], 1, 1))
lower_where_condition = np.reshape(
lower_condition, newshape=(np.shape(lower_condition)[0], 1, 1))
self._lower_condition,
newshape=(np.shape(self._lower_condition)[0], 1, 1))
upper_where_condition = np.reshape(
upper_condition, newshape=(np.shape(upper_condition)[0], 1, 1))
self._upper_condition,
newshape=(np.shape(self._upper_condition)[0], 1, 1))
else:
where_condition = np.repeat(self._condition[:, :, np.newaxis],
np.shape(self._condition)[-1], axis=2)
where_condition = np.repeat(
self._not_censored_condition[:, :, np.newaxis],
np.shape(self._not_censored_condition)[-1], axis=2)
lower_where_condition = np.repeat(
lower_condition[:, :, np.newaxis],
np.shape(lower_condition)[-1], axis=2)
self._lower_condition[:, :, np.newaxis],
np.shape(self._lower_condition)[-1], axis=2)
upper_where_condition = np.repeat(
upper_condition[:, :, np.newaxis],
np.shape(upper_condition)[-1], axis=2)
self._upper_condition[:, :, np.newaxis],
np.shape(self._upper_condition)[-1], axis=2)

# 1. Parts of the derivative corresponding to the data that
# isn't censored
Expand All @@ -1103,7 +1093,8 @@ def evaluateS1(self, x):

# Calculate derivative wrt sigma
not_censored_dsigma = -self._n_not_censored / sigma + sigma**(-3.0) *\
np.sum((self._values - y)**2, axis=0, where=self._condition)
np.sum((self._values - y)**2, axis=0,
where=self._not_censored_condition)

# 2. Parts of the derivative corresponding to the data that is
# censored
Expand Down Expand Up @@ -1140,13 +1131,13 @@ def evaluateS1(self, x):
(self._values - y).T) / (lower_cdf.T)
lower_censored_dsigma = - sigma**(-2) *\
np.sum(lower_dsigma_inner_val.T,
where=self._values <= self._a, axis=0).T
where=self._lower_condition, axis=0).T

upper_dsigma_inner_val = (upper_pdf.T *
(self._values - y).T) / (1 - upper_cdf.T)
upper_censored_dsigma = sigma**(-2) *\
np.sum(upper_dsigma_inner_val.T,
where=self._values >= self._b, axis=0).T
where=self._upper_condition, axis=0).T

dL = not_censored_dL + lower_censored_dL + upper_censored_dL
dsigma = not_censored_dsigma + lower_censored_dsigma + \
Expand All @@ -1157,6 +1148,17 @@ def evaluateS1(self, x):
# Return
return L, dL

def print_censored_values(self):

# Print data and values that are censored
print("The data are {}. \n The lower censored values"
" are {}. \n The upper censored values"
" are {}.".format(self._values,
np.extract(self._lower_condition,
self._values),
np.extract(self._upper_condition,
self._values)))


class KnownNoiseLogLikelihood(GaussianKnownSigmaLogLikelihood):
""" Deprecated alias of :class:`GaussianKnownSigmaLogLikelihood`. """
Expand Down
6 changes: 3 additions & 3 deletions pints/tests/test_log_likelihoods.py
Original file line number Diff line number Diff line change
Expand Up @@ -2805,8 +2805,8 @@ def test_bad_constructor(self):
ValueError, pints.CensoredGaussianLogLikelihood, problem,
lower=0.1, upper=0)

def test_stdout(self):
# Check prints correct output when verbose=True
def test_print_censored_values(self):
# Check the function print_censored_values prints the correct output

with unittest.mock.patch('sys.stdout',
new_callable=io.StringIO) as mock_stdout:
Expand All @@ -2817,7 +2817,7 @@ def test_stdout(self):
# Create log_likelihood
log_likelihood = pints.\
CensoredGaussianLogLikelihood(problem, lower=0.2,
upper=0.8, verbose=True)
upper=0.8)
log_likelihood.print_censored_values()

# Add \n at end due to how print statements work in Python
Expand Down

0 comments on commit 9be1b06

Please sign in to comment.