From 9fc9c55d631d28645a1c9aa66b01740a4dd83e81 Mon Sep 17 00:00:00 2001 From: nstarman Date: Fri, 4 Aug 2023 00:42:38 -0700 Subject: [PATCH 1/3] refactor multinormal for better masking Signed-off-by: nstarman --- src/stream_ml/pytorch/builtin/__init__.py | 6 +- .../pytorch/builtin/_isochrone/core.py | 1 - src/stream_ml/pytorch/builtin/_multinormal.py | 163 +++++++++--------- 3 files changed, 78 insertions(+), 92 deletions(-) diff --git a/src/stream_ml/pytorch/builtin/__init__.py b/src/stream_ml/pytorch/builtin/__init__.py index 782864f..422f38e 100644 --- a/src/stream_ml/pytorch/builtin/__init__.py +++ b/src/stream_ml/pytorch/builtin/__init__.py @@ -20,7 +20,6 @@ "Parallax2DistMod", # -- multivariate "MultivariateNormal", - "MultivariateMissingNormal", ] from dataclasses import field, make_dataclass @@ -42,10 +41,7 @@ StreamMassFunction, UniformStreamMassFunction, ) -from stream_ml.pytorch.builtin._multinormal import ( - MultivariateMissingNormal, - MultivariateNormal, -) +from stream_ml.pytorch.builtin._multinormal import MultivariateNormal from stream_ml.pytorch.builtin._skewnorm import SkewNormal from stream_ml.pytorch.builtin._sloped import Sloped from stream_ml.pytorch.builtin._truncskewnorm import TruncatedSkewNormal diff --git a/src/stream_ml/pytorch/builtin/_isochrone/core.py b/src/stream_ml/pytorch/builtin/_isochrone/core.py index 08a0e1c..09e88b7 100644 --- a/src/stream_ml/pytorch/builtin/_isochrone/core.py +++ b/src/stream_ml/pytorch/builtin/_isochrone/core.py @@ -336,7 +336,6 @@ def ln_likelihood( - mean[sel] ) - lnliks = xp.zeros((len(data), len(self._gamma_points))) # (N, I) lnliks = -0.5 * ( # (N, I, 1, 1) -> (N, I) D[:, None] * _log2pi + logdet diff --git a/src/stream_ml/pytorch/builtin/_multinormal.py b/src/stream_ml/pytorch/builtin/_multinormal.py index e9cdce8..9f6fc9d 100644 --- a/src/stream_ml/pytorch/builtin/_multinormal.py +++ b/src/stream_ml/pytorch/builtin/_multinormal.py @@ -4,11 +4,12 @@ __all__: list[str] = [] -from dataclasses import 
KW_ONLY, dataclass +from dataclasses import dataclass from typing import TYPE_CHECKING import torch as xp -from torch.distributions import MultivariateNormal as TorchMultivariateNormal + +from stream_ml.core.builtin._utils import WhereRequiredError from stream_ml.pytorch._base import ModelBase @@ -34,7 +35,9 @@ def ln_likelihood( /, data: Data[Array], *, + where: Data[Array] | None = None, correlation_matrix: Array | None = None, + correlation_det: Array | None = None, **kwargs: Array, ) -> Array: r"""Log-likelihood of the distribution. @@ -47,16 +50,26 @@ def ln_likelihood( data : Data[Array] Data (phi1, phi2, ...). - correlation_matrix : Array[(N,F,F)], optional keyword-only - The correlation matrix. If not provided, then the covariance matrix is - assumed to be diagonal. - The covariance matrix is computed as: + where : Data[Array[(N,), bool]] | None, optional keyword-only + Where to evaluate the log-likelihood. If not provided, then the + log-likelihood is evaluated at all data points. ``where`` must + contain the fields in ``coord_names``. Each field must be a boolean + array of the same length as `data`. `True` indicates that the data + point is available, and `False` indicates that the data point is not + available. + + correlation_matrix : Array[(N,F,F)] | None, optional keyword-only + The correlation matrix. If not provided, then the covariance matrix + is assumed to be diagonal. The covariance matrix is computed as: .. math:: \rm{cov}(X) = \rm{diag}(\vec{\sigma}) - \cdot \rm{corr} - \cdot \rm{diag}(\vec{\sigma}) + \cdot \rm{corr} \cdot \rm{diag}(\vec{\sigma}) + correlation_det : Array[(N,)] | None, optional keyword-only + The determinant of the correlation matrix. If not provided, then + the determinant is only the product of the diagonal elements of the + covariance matrix. **kwargs : Array Additional arguments. 
@@ -65,88 +78,66 @@ def ln_likelihood( ------- Array """ - marginals = xp.diag_embed( - self.xp.exp(self._stack_param(mpars, "ln-sigma", self.coord_names)) + # 'where' is used to indicate which data points are available. If + # 'where' is not provided, then all data points are assumed to be + # available. + where_: Array # (N, F) + if where is not None: + where_ = where[self.coord_names].array + elif self.require_where: + raise WhereRequiredError + else: + where_ = self.xp.ones((len(data), self.nF), dtype=bool) + + if correlation_matrix is not None and correlation_det is None: + msg = "Must provide `correlation_det`." + raise ValueError(msg) + + # Covariance: model (N, F, F) + lnsigma = self.xp.exp(self._stack_param(mpars, "ln-sigma", self.coord_names)) + cov_model = xp.diag_embed(xp.exp(2 * lnsigma)) + # Covariance data "The covariance matrix can be written as the rescaling + # of a correlation matrix by the marginal variances:" + # (https://en.wikipedia.org/wiki/Covariance_matrix#Correlation_matrix) + stds = data[self.coord_err_names].array + std_data = ( + xp.diag_embed(stds) + if self.coord_err_names is not None + else self.xp.zeros(1) ) - cov = ( - marginals @ marginals + cov_data = ( + std_data**2 if correlation_matrix is None - else marginals @ correlation_matrix @ marginals + else std_data @ correlation_matrix[:, :, :] @ std_data ) - - return TorchMultivariateNormal( - self._stack_param(mpars, "mu", self.coord_names), - covariance_matrix=cov, - ).log_prob(data[self.coord_names].array) - - -############################################################################## - - -@dataclass(unsafe_hash=True) -class MultivariateMissingNormal(MultivariateNormal): # (MultivariateNormal) - """Multivariate Normal with missing data. - - .. note:: - - Currently this requires a diagonal covariance matrix. 
- """ - - _: KW_ONLY - require_mask: bool = True - - def ln_likelihood( - self, - mpars: Params[Array], - /, - data: Data[Array], - *, - mask: Data[Array] | None = None, - **kwargs: Array, - ) -> Array: - """Negative log-likelihood. - - Parameters - ---------- - mpars : Params[Array], positional-only - Model parameters. Note that these are different from the ML - parameters. - data : Data[Array] - Labelled data. - mask : Data[Array[bool]] | None, optional - Data availability. `True` if data is available, `False` if not. - Should have the same keys as `data`. - **kwargs : Array - Additional arguments. - """ - # Normal - x = data[self.coord_names].array + # The covariance, setting non-observed dimensions to 0. (N, F, F) + # positive definite. + idx_cov = xp.diag_embed(where_.to(dtype=data.dtype)) # (N, F, F) + cov = idx_cov @ (cov_data + cov_model) @ idx_cov + # The determinant, dropping the dimensionality of non-observed + # dimensions. + logdet = xp.log( + xp.linalg.det(cov + (xp.eye(self.nF)[None, None] - idx_cov)) + ) # (N, [I]) + + # Dimensionality, dropping missing dimensions (N, [I]) + D = where_.sum(dim=-1) # noqa: N806 + + # Construct the data - mean (N, I, F), setting non-observed dimensions to 0. 
mu = self._stack_param(mpars, "mu", self.coord_names) - sigma = self.xp.exp(self._stack_param(mpars, "ln-sigma", self.coord_names)) - - idx: Array - if mask is not None: - idx = mask[tuple(self.coord_bounds.keys())].array - elif self.require_mask: - msg = "mask is required" - raise ValueError(msg) - else: - idx = xp.ones_like(x, dtype=xp.int) - # shape (1, F) so that it can broadcast with (N, F) - - D = idx.sum(dim=1) # Dimensionality (N,) # noqa: N806 - dmm = idx * (x - mu) # Data - model (N, F) - - # Covariance related - cov = idx * sigma**2 # (N, F) positive definite - det = (cov * idx + (1 - idx)).prod(dim=1) # (N,) + sel = where_[:, None, :].expand(-1, self.nI, -1) + x = xp.zeros((len(data), self.nI, self.nF), dtype=data.dtype) + x[sel] = ( + data[self.coord_names].array[:, None, :].expand(-1, self.nI, -1)[sel] + - mu[sel] + ) - return -0.5 * ( + return -0.5 * ( # (N, I, 1, 1) -> (N, I) D * _log2pi - + xp.log(det) + + logdet + ( - dmm[:, None, :] # (N, 1, F) - @ xp.linalg.pinv(xp.diag_embed(cov)) # (N, F, F) - @ dmm[..., None] # (N, F, 1) - ).flatten() # (N, 1, 1) -> (N,) - ) # (N,) + x[:, None, :] # (N, 1, F) + @ xp.linalg.pinv(cov) # (N, F, F) + @ x[..., None] # (N, F, 1) + )[..., 0, 0] + ) From 8a87b513a226398faf97079cf3d6ed196751b39c Mon Sep 17 00:00:00 2001 From: nstarman Date: Fri, 11 Aug 2023 10:56:13 -0400 Subject: [PATCH 2/3] skip if no err names Signed-off-by: nstarman --- src/stream_ml/pytorch/builtin/_multinormal.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/stream_ml/pytorch/builtin/_multinormal.py b/src/stream_ml/pytorch/builtin/_multinormal.py index 9f6fc9d..c1bde27 100644 --- a/src/stream_ml/pytorch/builtin/_multinormal.py +++ b/src/stream_ml/pytorch/builtin/_multinormal.py @@ -99,9 +99,8 @@ def ln_likelihood( # Covariance data "The covariance matrix can be written as the rescaling # of a correlation matrix by the marginal variances:" # (https://en.wikipedia.org/wiki/Covariance_matrix#Correlation_matrix) - stds = 
data[self.coord_err_names].array std_data = ( - xp.diag_embed(stds) + xp.diag_embed(data[self.coord_err_names].array) if self.coord_err_names is not None else self.xp.zeros(1) ) From 8a9242b2e6e9337f274021e152ec9acc15f0a008 Mon Sep 17 00:00:00 2001 From: nstarman Date: Fri, 11 Aug 2023 11:02:45 -0400 Subject: [PATCH 3/3] fix exp Signed-off-by: nstarman --- src/stream_ml/pytorch/builtin/_multinormal.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/stream_ml/pytorch/builtin/_multinormal.py b/src/stream_ml/pytorch/builtin/_multinormal.py index c1bde27..941af12 100644 --- a/src/stream_ml/pytorch/builtin/_multinormal.py +++ b/src/stream_ml/pytorch/builtin/_multinormal.py @@ -93,9 +93,6 @@ def ln_likelihood( msg = "Must provide `correlation_det`." raise ValueError(msg) - # Covariance: model (N, F, F) - lnsigma = self.xp.exp(self._stack_param(mpars, "ln-sigma", self.coord_names)) - cov_model = xp.diag_embed(xp.exp(2 * lnsigma)) # Covariance data "The covariance matrix can be written as the rescaling # of a correlation matrix by the marginal variances:" # (https://en.wikipedia.org/wiki/Covariance_matrix#Correlation_matrix) @@ -109,9 +106,12 @@ def ln_likelihood( if correlation_matrix is None else std_data @ correlation_matrix[:, :, :] @ std_data ) + # Covariance model (N, F, F) + lnsigma = self._stack_param(mpars, "ln-sigma", self.coord_names) + cov_model = xp.diag_embed(self.xp.exp(2 * lnsigma)) # The covariance, setting non-observed dimensions to 0. (N, F, F) # positive definite. - idx_cov = xp.diag_embed(where_.to(dtype=data.dtype)) # (N, F, F) + idx_cov = xp.diag_embed(where_.to(dtype=data.dtype)) cov = idx_cov @ (cov_data + cov_model) @ idx_cov # The determinant, dropping the dimensionality of non-observed # dimensions.