Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor multinormal for better masking #129

Merged
merged 3 commits into from
Aug 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 1 addition & 5 deletions src/stream_ml/pytorch/builtin/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
"Parallax2DistMod",
# -- multivariate
"MultivariateNormal",
"MultivariateMissingNormal",
]

from dataclasses import field, make_dataclass
Expand All @@ -42,10 +41,7 @@
StreamMassFunction,
UniformStreamMassFunction,
)
from stream_ml.pytorch.builtin._multinormal import (
MultivariateMissingNormal,
MultivariateNormal,
)
from stream_ml.pytorch.builtin._multinormal import MultivariateNormal
from stream_ml.pytorch.builtin._skewnorm import SkewNormal
from stream_ml.pytorch.builtin._sloped import Sloped
from stream_ml.pytorch.builtin._truncskewnorm import TruncatedSkewNormal
Expand Down
1 change: 0 additions & 1 deletion src/stream_ml/pytorch/builtin/_isochrone/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,6 @@ def ln_likelihood(
- mean[sel]
)

lnliks = xp.zeros((len(data), len(self._gamma_points))) # (N, I)
lnliks = -0.5 * ( # (N, I, 1, 1) -> (N, I)
D[:, None] * _log2pi
+ logdet
Expand Down
162 changes: 76 additions & 86 deletions src/stream_ml/pytorch/builtin/_multinormal.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@

__all__: list[str] = []

from dataclasses import KW_ONLY, dataclass
from dataclasses import dataclass
from typing import TYPE_CHECKING

import torch as xp
from torch.distributions import MultivariateNormal as TorchMultivariateNormal

from stream_ml.core.builtin._utils import WhereRequiredError

from stream_ml.pytorch._base import ModelBase

Expand All @@ -34,7 +35,9 @@ def ln_likelihood(
/,
data: Data[Array],
*,
where: Data[Array] | None = None,
correlation_matrix: Array | None = None,
correlation_det: Array | None = None,
**kwargs: Array,
) -> Array:
r"""Log-likelihood of the distribution.
Expand All @@ -47,16 +50,26 @@ def ln_likelihood(
data : Data[Array]
Data (phi1, phi2, ...).

correlation_matrix : Array[(N,F,F)], optional keyword-only
The correlation matrix. If not provided, then the covariance matrix is
assumed to be diagonal.
The covariance matrix is computed as:
where : Data[Array[(N,), bool]] | None, optional keyword-only
Where to evaluate the log-likelihood. If not provided, then the
log-likelihood is evaluated at all data points. ``where`` must
contain the fields in ``phot_names``. Each field must be a boolean
array of the same length as `data`. `True` indicates that the data
point is available, and `False` indicates that the data point is not
available.

correlation_matrix : Array[(N,F,F)] | None, optional keyword-only
The correlation matrix. If not provided, then the covariance matrix
is assumed to be diagonal. The covariance matrix is computed as:

.. math::

\rm{cov}(X) = \rm{diag}(\vec{\sigma})
\cdot \rm{corr}
\cdot \rm{diag}(\vec{\sigma})
\cdot \rm{corr} \cdot \rm{diag}(\vec{\sigma})
correlation_det: Array[(N,)] | None, optional keyword-only
The determinant of the correlation matrix. If not provided, then
the determinant is only the product of the diagonal elements of the
covariance matrix.

**kwargs : Array
Additional arguments.
Expand All @@ -65,88 +78,65 @@ def ln_likelihood(
-------
Array
"""
marginals = xp.diag_embed(
self.xp.exp(self._stack_param(mpars, "ln-sigma", self.coord_names))
# 'where' is used to indicate which data points are available. If
# 'where' is not provided, then all data points are assumed to be
# available.
where_: Array # (N, F)
if where is not None:
where_ = where[self.coord_names].array
elif self.require_where:
raise WhereRequiredError
else:
where_ = self.xp.ones((len(data), self.nF), dtype=bool)

if correlation_matrix is not None and correlation_det is None:
msg = "Must provide `correlation_det`."
raise ValueError(msg)

# Covariance data "The covariance matrix can be written as the rescaling
# of a correlation matrix by the marginal variances:"
# (https://en.wikipedia.org/wiki/Covariance_matrix#Correlation_matrix)
std_data = (
xp.diag_embed(data[self.coord_err_names].array)
if self.coord_err_names is not None
else self.xp.zeros(1)
)
cov = (
marginals @ marginals
cov_data = (
std_data**2
if correlation_matrix is None
else marginals @ correlation_matrix @ marginals
else std_data @ correlation_matrix[:, :, :] @ std_data
)

return TorchMultivariateNormal(
self._stack_param(mpars, "mu", self.coord_names),
covariance_matrix=cov,
).log_prob(data[self.coord_names].array)


##############################################################################


@dataclass(unsafe_hash=True)
class MultivariateMissingNormal(MultivariateNormal): # (MultivariateNormal)
"""Multivariate Normal with missing data.

.. note::

Currently this requires a diagonal covariance matrix.
"""

_: KW_ONLY
require_mask: bool = True

def ln_likelihood(
self,
mpars: Params[Array],
/,
data: Data[Array],
*,
mask: Data[Array] | None = None,
**kwargs: Array,
) -> Array:
"""Negative log-likelihood.

Parameters
----------
mpars : Params[Array], positional-only
Model parameters. Note that these are different from the ML
parameters.
data : Data[Array]
Labelled data.
mask : Data[Array[bool]] | None, optional
Data availability. `True` if data is available, `False` if not.
Should have the same keys as `data`.
**kwargs : Array
Additional arguments.
"""
# Normal
x = data[self.coord_names].array
# Covariance model (N, F, F)
lnsigma = self._stack_param(mpars, "ln-sigma", self.coord_names)
cov_model = xp.diag_embed(self.xp.exp(2 * lnsigma))
# The covariance, setting non-observed dimensions to 0. (N, F, F)
# positive definite.
idx_cov = xp.diag_embed(where_.to(dtype=data.dtype))
cov = idx_cov @ (cov_data + cov_model) @ idx_cov
# The determinant, dropping the dimensionality of non-observed
# dimensions.
logdet = xp.log(
xp.linalg.det(cov + (xp.eye(self.nF)[None, None] - idx_cov))
) # (N, [I])

# Dimensionality, dropping missing dimensions (N, [I])
D = where_.sum(dim=-1) # noqa: N806

# Construct the data - mean (N, I, F), setting non-observed dimensions to 0.
mu = self._stack_param(mpars, "mu", self.coord_names)
sigma = self.xp.exp(self._stack_param(mpars, "ln-sigma", self.coord_names))

idx: Array
if mask is not None:
idx = mask[tuple(self.coord_bounds.keys())].array
elif self.require_mask:
msg = "mask is required"
raise ValueError(msg)
else:
idx = xp.ones_like(x, dtype=xp.int)
# shape (1, F) so that it can broadcast with (N, F)

D = idx.sum(dim=1) # Dimensionality (N,) # noqa: N806
dmm = idx * (x - mu) # Data - model (N, F)

# Covariance related
cov = idx * sigma**2 # (N, F) positive definite
det = (cov * idx + (1 - idx)).prod(dim=1) # (N,)
sel = where_[:, None, :].expand(-1, self.nI, -1)
x = xp.zeros((len(data), self.nI, self.nF), dtype=data.dtype)
x[sel] = (
data[self.coord_names].array[:, None, :].expand(-1, self.nI, -1)[sel]
- mu[sel]
)

return -0.5 * (
return -0.5 * ( # (N, I, 1, 1) -> (N, I)
D * _log2pi
+ xp.log(det)
+ logdet
+ (
dmm[:, None, :] # (N, 1, F)
@ xp.linalg.pinv(xp.diag_embed(cov)) # (N, F, F)
@ dmm[..., None] # (N, F, 1)
).flatten() # (N, 1, 1) -> (N,)
) # (N,)
x[:, None, :] # (N, 1, F)
@ xp.linalg.pinv(cov) # (N, F, F)
@ x[..., None] # (N, F, 1)
)[..., 0, 0]
)