Skip to content

Commit

Permalink
Add where to flow (#126)
Browse files Browse the repository at this point in the history
  • Loading branch information
nstarman authored Aug 4, 2023
1 parent 3faedd2 commit 5bba9c6
Showing 1 changed file with 34 additions and 7 deletions.
41 changes: 34 additions & 7 deletions src/stream_ml/pytorch/builtin/compat/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

import torch as xp

from stream_ml.core.params.scaler import scale_params
from stream_ml.core.builtin._utils import WhereRequiredError

from stream_ml.pytorch._base import ModelBase
from stream_ml.pytorch.utils import names_intersect
Expand All @@ -30,7 +30,13 @@ class FlowModel(ModelBase):
with_grad: bool = True

def ln_likelihood(
self, mpars: Params[Array], /, data: Data[Array], **kwargs: Array
self,
mpars: Params[Array],
/,
data: Data[Array],
*,
where: Data[Array] | None = None,
**kwargs: Array,
) -> Array:
"""Log-likelihood of the array.
Expand All @@ -43,26 +49,47 @@ def ln_likelihood(
data : Data[Array]
Data (phi1, phi2).
where : Data[Array[(N,), bool]] | None, optional keyword-only
Where to evaluate the log-likelihood. If not provided, then the
log-likelihood is evaluated at all data points. ``where`` must
contain the fields in ``coord_names``. Each field must be a boolean
array of the same length as `data`. `True` indicates that the data
point is available, and `False` indicates that the data point is not
available.
**kwargs : Array
Additional arguments.
Returns
-------
Array
"""
# TODO: support `where` argument.
# 'where' is used to indicate which data points are available. If
# 'where' is not provided, then all data points are assumed to be
# available.
where_: Array # (N, F)
if where is not None:
where_ = where[self.coord_names].array
elif self.require_where:
raise WhereRequiredError
else:
where_ = self.xp.ones((len(data), self.ndim), dtype=bool)
idx = where_.all(axis=1)
# TODO: allow for missing data in only some of the dimensions

data = self.data_scaler.transform(
data, names=names_intersect(data, self.data_scaler), xp=self.xp
)
mpars = scale_params(self, mpars)

out = self.xp.zeros(len(data), dtype=data.dtype)
with nullcontext() if self.with_grad else xp.no_grad():
return self.jacobian_logdet + self.net.log_prob(
inputs=data[self.coord_names].array,
context=data[self.indep_coord_names].array
out[idx] = self.jacobian_logdet + self.net.log_prob(
inputs=data[self.coord_names].array[idx],
context=data[self.indep_coord_names].array[idx]
if self.indep_coord_names is not None
else None,
)
return out

def forward(self, data: Data[Array]) -> Array:
"""Forward pass.
Expand Down

0 comments on commit 5bba9c6

Please sign in to comment.