Commit 10b3432

Remove Gaussian.negate attribute

fritzo committed Oct 11, 2021
1 parent fe0c7c5 commit 10b3432

Showing 8 changed files with 30 additions and 133 deletions.
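For context: after this commit a Gaussian funsor always represents the concave log-density term `-0.5 * ||x @ prec_sqrt - white_vec||^2`; the convex `negate=True` variant is gone. A minimal NumPy sketch of this parametrization (illustrative names, not funsor's API):

```python
import numpy as np

rng = np.random.default_rng(0)
dim = 3

# A full-rank square root of the precision matrix: precision = S @ S.T.
S = np.tril(rng.normal(size=(dim, dim)), k=-1) + np.diag(rng.uniform(1.0, 2.0, size=dim))
mean = rng.normal(size=dim)
white_vec = S.T @ mean  # per the Gaussian docstring: white_vec = prec_sqrt.T @ mean

def log_density(x):
    # The non-normalized form a Gaussian funsor encodes.
    return -0.5 * np.sum((x @ S - white_vec) ** 2)

# Equivalent to the usual quadratic form with precision P = S @ S.T.
x = rng.normal(size=dim)
P = S @ S.T
assert np.allclose(log_density(x), -0.5 * (x - mean) @ P @ (x - mean))
```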
funsor/distribution.py (6 changes: 1 addition & 5 deletions)
@@ -848,7 +848,6 @@ def eager_normal(loc, scale, value):
         white_vec=white_vec,
         prec_sqrt=prec_sqrt,
         inputs=inputs,
-        negate=False,
     )
     return gaussian(**{var: value - loc})

@@ -870,10 +869,7 @@ def eager_mvn(loc, scale_tril, value):
     var = gensym("value")
     inputs[var] = Reals[scale_diag.shape[0]]
     gaussian = log_prob + Gaussian(
-        white_vec=white_vec,
-        prec_sqrt=prec_sqrt,
-        inputs=inputs,
-        negate=False,
+        white_vec=white_vec, prec_sqrt=prec_sqrt, inputs=inputs
     )
     return gaussian(**{var: value - loc})
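The pattern here is a constant normalizer (`log_prob`) plus a non-normalized Gaussian evaluated at `value - loc`. A hedged NumPy sketch of why that decomposition recovers the full MVN log-density; taking `prec_sqrt = inv(scale_tril).T` is an assumption for illustration, not necessarily funsor's internal choice:

```python
import numpy as np

rng = np.random.default_rng(1)
d = 2
loc = rng.normal(size=d)
L = np.tril(rng.normal(size=(d, d)), k=-1) + np.diag(rng.uniform(1.0, 2.0, size=d))  # scale_tril
value = rng.normal(size=d)

# Constant normalizer of MVN(loc, L @ L.T).
log_prob = -0.5 * d * np.log(2 * np.pi) - np.log(np.diag(L)).sum()

# Non-normalized Gaussian term, evaluated at value - loc.
prec_sqrt = np.linalg.inv(L).T  # one valid square root of the precision
quad = -0.5 * np.sum(((value - loc) @ prec_sqrt) ** 2)

# Together they match the full MVN log-density.
cov = L @ L.T
diff = value - loc
full = (
    -0.5 * d * np.log(2 * np.pi)
    - 0.5 * np.linalg.slogdet(cov)[1]
    - 0.5 * diff @ np.linalg.solve(cov, diff)
)
assert np.allclose(log_prob + quad, full)
```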

funsor/gaussian.py (98 changes: 18 additions & 80 deletions)
@@ -12,17 +12,15 @@
 from funsor.affine import affine_inputs, extract_affine, is_affine
 from funsor.delta import Delta
 from funsor.domains import Real, Reals
-from funsor.ops import AddOp, NegOp, SubOp
+from funsor.ops import AddOp
 from funsor.tensor import Tensor, align_tensor, align_tensors
 from funsor.terms import (
     Align,
     Binary,
     Funsor,
     FunsorMeta,
     Number,
     Slice,
     Subs,
-    Unary,
     Variable,
     eager,
     reflect,
@@ -334,23 +332,13 @@ def __call__(
         white_vec=None,
         prec_sqrt=None,
         inputs=None,
-        negate=None,
         *,
         mean=None,
         info_vec=None,
         precision=None,
         scale_tril=None,
         covariance=None,
     ):
-        # We intentionally avoid a default value for negate
-        # so as to loudly error in old code that uses the obsolete interface
-        # Gaussian(info_vec, precision, inputs).
-        if negate is None:
-            raise ValueError(
-                "Missing negate argument to Gaussian(). Note interface changes."
-            )
-        assert isinstance(negate, bool)
-
         # Convert inputs.
         assert inputs is not None
         if isinstance(inputs, OrderedDict):
@@ -396,12 +384,10 @@ def __call__(
         white_vec, prec_sqrt, shift = _compress_rank(white_vec, prec_sqrt)
 
         # Create a Gaussian.
-        result = super().__call__(white_vec, prec_sqrt, inputs, negate)
+        result = super().__call__(white_vec, prec_sqrt, inputs)
 
         # Add compression byproducts.
         if shift is not None:
-            if negate:
-                shift = -shift
             int_inputs = OrderedDict((k, v) for k, v in inputs if v.dtype != "real")
             result += Tensor(shift, int_inputs)
@@ -431,8 +417,7 @@ class Gaussian(Funsor, metaclass=GaussianMeta):
     Not only are Gaussians non-normalized, but they may be rank deficient
     and non-normalizable, in which case sampling and marginalization are
-    not supported. See the :meth:`rank` , :meth:`is_full_rank` , and
-    :meth:`is_normalizable` properties.
+    not supported. See the :meth:`rank` and :meth:`is_full_rank` properties.
 
     :param torch.Tensor white_vec: A batched white noise vector, where
         ``white_vec = prec_sqrt.T @ mean``. Alternatively you can specify one
@@ -447,12 +432,9 @@ class Gaussian(Funsor, metaclass=GaussianMeta):
         be converted to ``prec_sqrt``.
     :param OrderedDict inputs: Mapping from name to
         :class:`~funsor.domains.Domain` .
-    :param bool negate: If false this represents a concave function 🙁. If true,
-        this represents a convex function 🙂. Convex log densities are not
-        normalizable and do not support marginalization or sampling.
     """

-    def __init__(self, white_vec, prec_sqrt, inputs, negate):
+    def __init__(self, white_vec, prec_sqrt, inputs):
         assert ops.is_numeric_array(white_vec) and ops.is_numeric_array(prec_sqrt)
         assert isinstance(inputs, tuple)
         inputs = OrderedDict(inputs)

@@ -480,7 +462,6 @@ def __init__(self, white_vec, prec_sqrt, inputs, negate):
         super().__init__(inputs, output, fresh, bound)
         self.white_vec = white_vec
         self.prec_sqrt = prec_sqrt
-        self.negate = negate
         self.batch_shape = batch_shape
         self.event_shape = (dim,)
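The keyword-only alternatives in the constructor above (`mean`, `info_vec`, `precision`, `scale_tril`, `covariance`) are all convertible to the canonical `(white_vec, prec_sqrt)` pair. A hedged sketch of those conversions; any matrix square root works, and the names below are illustrative rather than funsor's internals:

```python
import numpy as np

rng = np.random.default_rng(2)
d = 3
A = rng.normal(size=(d, d))
covariance = A @ A.T + d * np.eye(d)
mean = rng.normal(size=d)

precision = np.linalg.inv(covariance)
prec_sqrt = np.linalg.cholesky(precision)  # any S with S @ S.T == precision
white_vec = prec_sqrt.T @ mean             # per the docstring above
info_vec = precision @ mean                # information-vector form

# All describe the same log-density:
x = rng.normal(size=d)
lhs = -0.5 * np.sum((x @ prec_sqrt - white_vec) ** 2)
rhs = -0.5 * (x - mean) @ precision @ (x - mean)
assert np.allclose(lhs, rhs)
# and prec_sqrt @ white_vec recovers the information vector:
assert np.allclose(info_vec, prec_sqrt @ white_vec)
```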

@@ -495,18 +476,9 @@ def rank(self):

     @property
     def is_full_rank(self):
-        if self.negate:
-            return False
         dim, rank = self.prec_sqrt.shape[-2:]
         return rank == dim

-    @property
-    def is_normalizable(self):
-        """
-        Whether this Gaussian is full rank and not negated.
-        """
-        return self.is_full_rank and not self.negate
-
     # TODO Consider weak-memoizing these so they persist through alpha conversion.
     # https://github.com/pyro-ppl/pyro/blob/ac3c588/pyro/distributions/coalescent.py#L412
     @lazy_property
@@ -559,7 +531,7 @@ def align(self, names):
         inputs = OrderedDict((name, self.inputs[name]) for name in names)
         inputs.update(self.inputs)
         white_vec, prec_sqrt = align_gaussian(inputs, self)
-        return Gaussian(white_vec, prec_sqrt, inputs, self.negate)
+        return Gaussian(white_vec, prec_sqrt, inputs)

     def eager_subs(self, subs):
         assert isinstance(subs, tuple)
@@ -616,7 +588,7 @@ def _eager_subs_var(self, subs, remaining_subs):
         inputs = OrderedDict((rename.get(k, k), d) for k, d in self.inputs.items())
         if len(inputs) != len(self.inputs):
             raise ValueError("Variable substitution name conflict")
-        var_result = Gaussian(self.white_vec, self.prec_sqrt, inputs, self.negate)
+        var_result = Gaussian(self.white_vec, self.prec_sqrt, inputs)
         return Subs(var_result, remaining_subs) if remaining_subs else var_result

     def _eager_subs_int(self, subs, remaining_subs):
@@ -631,7 +603,7 @@ def _eager_subs_int(self, subs, remaining_subs):
         funsors = [Subs(Tensor(x, int_inputs), subs) for x in tensors]
         inputs = funsors[0].inputs.copy()
         inputs.update(real_inputs)
-        int_result = Gaussian(funsors[0].data, funsors[1].data, inputs, self.negate)
+        int_result = Gaussian(funsors[0].data, funsors[1].data, inputs)
         return Subs(int_result, remaining_subs) if remaining_subs else int_result

     def _eager_subs_real(self, subs, remaining_subs):
@@ -674,9 +646,7 @@ def _eager_subs_real(self, subs, remaining_subs):
             value = value.as_tensor()

         # Evaluate the non-normalized log density.
-        result = (0.5 if self.negate else -0.5) * _norm2(
-            _vm(value, prec_sqrt) - white_vec
-        )
+        result = -0.5 * _norm2(_vm(value, prec_sqrt) - white_vec)
         result = Tensor(result, int_inputs)
         assert result.output == Real
         return Subs(result, remaining_subs) if remaining_subs else result
@@ -701,7 +671,7 @@ def _eager_subs_real(self, subs, remaining_subs):
         for k, d in self.inputs.items():
             if k not in subs:
                 inputs[k] = d
-        result = Gaussian(white_vec_a, prec_sqrt_a, inputs, self.negate)
+        result = Gaussian(white_vec_a, prec_sqrt_a, inputs)
         return Subs(result, remaining_subs) if remaining_subs else result

     def _eager_subs_affine(self, subs, remaining_subs):
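The conditioning step just above (substituting values for a subset of real inputs) is pure algebra in this parametrization: splitting the rows of `prec_sqrt` into a kept block a and a substituted block b, `||x_a @ S_a + x_b @ S_b - w||^2 == ||x_a @ S_a - (w - x_b @ S_b)||^2`. A hedged NumPy check with illustrative names:

```python
import numpy as np

rng = np.random.default_rng(3)
da, db, rank = 2, 2, 4
S = rng.normal(size=(da + db, rank))  # prec_sqrt, rows split as [a; b]
w = rng.normal(size=rank)             # white_vec
S_a, S_b = S[:da], S[da:]

x_a = rng.normal(size=da)
value_b = rng.normal(size=db)         # substituted real value

joint = -0.5 * np.sum((np.concatenate([x_a, value_b]) @ S - w) ** 2)

# Conditioned Gaussian over the remaining block a:
white_vec_a = w - value_b @ S_b
cond = -0.5 * np.sum((x_a @ S_a - white_vec_a) ** 2)
assert np.allclose(joint, cond)
```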
@@ -796,14 +766,14 @@ def _eager_subs_affine(self, subs, remaining_subs):
         # where P' = A P and w' = w - b P parametrize the new Gaussian.
         white_vec = white_vec - _vm(subs_vector, prec_sqrt)
         prec_sqrt = subs_matrix @ prec_sqrt
-        result = Gaussian(white_vec, prec_sqrt, new_inputs, self.negate)
+        result = Gaussian(white_vec, prec_sqrt, new_inputs)
         return Subs(result, remaining_subs) if remaining_subs else result

     def eager_reduce(self, op, reduced_vars):
         assert reduced_vars.issubset(self.inputs)
         if op is ops.logaddexp:
             # Marginalize out real variables, but keep mixtures lazy.
-            assert self.is_normalizable
+            assert self.is_full_rank
             assert all(v in self.inputs for v in reduced_vars)
             real_vars = frozenset(
                 k for k, d in self.inputs.items() if d.dtype == "real"
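The affine-substitution comment in the hunk above states the identity `P' = A P` and `w' = w - b P`. A hedged NumPy check that substituting `x = b + y @ A` into `-0.5 * ||x @ P - w||^2` yields exactly those updated parameters (illustrative names):

```python
import numpy as np

rng = np.random.default_rng(4)
old_dim, new_dim, rank = 3, 2, 3
P = rng.normal(size=(old_dim, rank))     # prec_sqrt over x
w = rng.normal(size=rank)                # white_vec
A = rng.normal(size=(new_dim, old_dim))  # subs_matrix
b = rng.normal(size=old_dim)             # subs_vector

y = rng.normal(size=new_dim)
x = b + y @ A                            # the affine substitution

old = -0.5 * np.sum((x @ P - w) ** 2)
new = -0.5 * np.sum((y @ (A @ P) - (w - b @ P)) ** 2)
assert np.allclose(old, new)
```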
Expand Down Expand Up @@ -877,9 +847,7 @@ def eager_reduce(self, op, reduced_vars):
proj_b = _mtm(ops.triangular_solve(prec_sqrt_b, precision_chol_b))
prec_sqrt = prec_sqrt_a - prec_sqrt_a @ proj_b
white_vec = self.white_vec - _vm(self.white_vec, proj_b)
result = b_log_normalizer + Gaussian(
white_vec, prec_sqrt, inputs, False
)
result = b_log_normalizer + Gaussian(white_vec, prec_sqrt, inputs)
else:
raise NotImplementedError(
f"rank = {self.rank:d}, marginalised_dim = {dim_b:d}, "
Expand Down Expand Up @@ -921,15 +889,15 @@ def eager_reduce(self, op, reduced_vars):
assert prec_sqrt.shape[:-2] == white_vec.shape[:-1]
assert prec_sqrt.shape[-1] == white_vec.shape[-1]

return Gaussian(white_vec, prec_sqrt, inputs, self.negate)
return Gaussian(white_vec, prec_sqrt, inputs)

return None # defer to default implementation

def _sample(self, sampled_vars, sample_inputs, rng_key):
sampled_vars = sampled_vars.intersection(self.inputs)
if not sampled_vars:
return self
assert self.is_normalizable
assert self.is_full_rank
if any(self.inputs[k].dtype != "real" for k in sampled_vars):
raise ValueError(
"Sampling from non-normalized Gaussian mixtures is intentionally "
@@ -996,40 +964,10 @@ def eager_add_gaussian_gaussian(op, lhs, rhs):
     lhs_white_vec, lhs_prec_sqrt = align_gaussian(inputs, lhs, expand=True)
     rhs_white_vec, rhs_prec_sqrt = align_gaussian(inputs, rhs, expand=True)

-    if lhs.negate == rhs.negate:
-        # Fuse aligned Gaussians via concatenation.
-        white_vec = ops.cat([lhs_white_vec, rhs_white_vec], -1)
-        prec_sqrt = ops.cat([lhs_prec_sqrt, rhs_prec_sqrt], -1)
-        return Gaussian(white_vec, prec_sqrt, inputs, lhs.negate)
-
-    # Subtract Gaussians.
-    lhs_info_vec = _mv(lhs_prec_sqrt, lhs_white_vec)
-    rhs_info_vec = _mv(rhs_prec_sqrt, rhs_white_vec)
-    lhs_precision = _mmt(lhs_prec_sqrt)
-    rhs_precision = _mmt(rhs_prec_sqrt)
-    if lhs.negate:
-        info_vec = rhs_info_vec - lhs_info_vec
-        precision = rhs_precision - lhs_precision
-    else:
-        info_vec = lhs_info_vec - rhs_info_vec
-        precision = lhs_precision - rhs_precision
-    return Gaussian(
-        info_vec=info_vec,
-        precision=precision,
-        inputs=inputs,
-        negate=False,
-    )
-
-
-@eager.register(Binary, SubOp, Gaussian, (Funsor, Align, Gaussian))
-@eager.register(Binary, SubOp, (Funsor, Align, Delta), Gaussian)
-def eager_sub(op, lhs, rhs):
-    return lhs + -rhs
-
-
-@eager.register(Unary, NegOp, Gaussian)
-def eager_neg(op, arg):
-    return Gaussian(arg.white_vec, arg.prec_sqrt, arg.inputs, not arg.negate)
+    # Fuse aligned Gaussians via concatenation.
+    white_vec = ops.cat([lhs_white_vec, rhs_white_vec], -1)
+    prec_sqrt = ops.cat([lhs_prec_sqrt, rhs_prec_sqrt], -1)
+    return Gaussian(white_vec, prec_sqrt, inputs)
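The concatenation trick that survives this commit is exact: adding two log-densities in this parametrization concatenates the `white_vec` entries and `prec_sqrt` columns. A hedged NumPy check (illustrative names):

```python
import numpy as np

rng = np.random.default_rng(6)
dim, r1, r2 = 3, 2, 4
S1, S2 = rng.normal(size=(dim, r1)), rng.normal(size=(dim, r2))
w1, w2 = rng.normal(size=r1), rng.normal(size=r2)
x = rng.normal(size=dim)

def g(S, w):
    return -0.5 * np.sum((x @ S - w) ** 2)

# lhs + rhs equals the fused Gaussian with concatenated parameters.
S = np.concatenate([S1, S2], axis=-1)
w = np.concatenate([w1, w2], axis=-1)
assert np.allclose(g(S1, w1) + g(S2, w2), g(S, w))
```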


__all__ = [
funsor/integrate.py (7 changes: 2 additions & 5 deletions)
@@ -194,7 +194,7 @@ def eager_integrate(log_measure, integrand, reduced_vars):

 @eager.register(Integrate, Gaussian, Gaussian, frozenset)
 def eager_integrate(log_measure, integrand, reduced_vars):
-    assert not log_measure.negate
+    assert log_measure.is_full_rank
     reduced_names = frozenset(v.name for v in reduced_vars)
     real_vars = frozenset(v.name for v in reduced_vars if v.dtype == "real")
     if real_vars:
@@ -212,10 +212,7 @@ def eager_integrate(log_measure, integrand, reduced_vars):
         lhs_white_vec, lhs_prec_sqrt = align_gaussian(inputs, log_measure)
         rhs_white_vec, rhs_prec_sqrt = align_gaussian(inputs, integrand)
         lhs = Gaussian(
-            white_vec=lhs_white_vec,
-            prec_sqrt=lhs_prec_sqrt,
-            inputs=inputs,
-            negate=False,
+            white_vec=lhs_white_vec, prec_sqrt=lhs_prec_sqrt, inputs=inputs
         )

         # Compute the expectation of a non-normalized quadratic form.
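That expectation has a closed form: under a normalized Gaussian with mean m and covariance Sigma, `E[-0.5 * ||x @ S - w||^2] = -0.5 * (||m @ S - w||^2 + tr(S.T @ Sigma @ S))`. A hedged Monte Carlo check of the formula in NumPy (illustrative names, not funsor's implementation):

```python
import numpy as np

rng = np.random.default_rng(7)
d, rank = 2, 3
m = rng.normal(size=d)
A = rng.normal(size=(d, d))
Sigma = A @ A.T + np.eye(d)
S = rng.normal(size=(d, rank))  # integrand's prec_sqrt
w = rng.normal(size=rank)       # integrand's white_vec

closed = -0.5 * (np.sum((m @ S - w) ** 2) + np.trace(S.T @ Sigma @ S))

xs = rng.multivariate_normal(m, Sigma, size=200_000)
mc = np.mean(-0.5 * np.sum((xs @ S - w) ** 2, axis=-1))
assert np.isclose(closed, mc, rtol=0.1)
```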
funsor/joint.py (7 changes: 1 addition & 6 deletions)

@@ -131,12 +131,7 @@ def moment_matching_contract_joint(red_op, bin_op, reduced_vars, discrete, gaussian):

     new_inputs = new_loc.inputs.copy()
     new_inputs.update((k, d) for k, d in gaussian.inputs.items() if d.dtype == "real")
-    new_gaussian = Gaussian(
-        mean=new_loc.data,
-        covariance=new_cov,
-        inputs=new_inputs,
-        negate=False,
-    )
+    new_gaussian = Gaussian(mean=new_loc.data, covariance=new_cov, inputs=new_inputs)
     new_discrete -= new_gaussian.log_normalizer

     return new_discrete + new_gaussian
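Subtracting `log_normalizer` here keeps total mass unchanged: the moment-matched Gaussian funsor is non-normalized, and its integral over the reals is `exp(log_normalizer)`. A hedged sketch of that constant for a full-rank Gaussian, checked by grid integration (illustrative names, assuming log_normalizer is the log of that integral):

```python
import numpy as np

rng = np.random.default_rng(8)
d = 2
A = rng.normal(size=(d, d))
cov = A @ A.T + np.eye(d)
mean = rng.normal(size=d)
S = np.linalg.cholesky(np.linalg.inv(cov))  # prec_sqrt
w = S.T @ mean                              # white_vec

# integral of exp(-0.5 * ||x @ S - w||^2) over R^d:
log_normalizer = 0.5 * d * np.log(2 * np.pi) + 0.5 * np.linalg.slogdet(cov)[1]

# Numerical check on a grid (d = 2).
grid = np.linspace(-10, 10, 401)
xx, yy = np.meshgrid(grid, grid)
pts = np.stack([xx, yy], axis=-1).reshape(-1, d)
vals = np.exp(-0.5 * np.sum((pts @ S - w) ** 2, axis=-1))
integral = vals.sum() * (grid[1] - grid[0]) ** 2
assert np.isclose(np.log(integral), log_normalizer, atol=1e-3)
```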
funsor/pyro/convert.py (13 changes: 2 additions & 11 deletions)
@@ -154,10 +154,7 @@ def mvn_to_funsor(pyro_dist, event_inputs=(), real_inputs=OrderedDict()):
     )
     inputs.update(real_inputs)
     return discrete + Gaussian(
-        white_vec=gaussian.white_vec,
-        prec_sqrt=gaussian.prec_sqrt,
-        inputs=inputs,
-        negate=False,
+        white_vec=gaussian.white_vec, prec_sqrt=gaussian.prec_sqrt, inputs=inputs
     )


@@ -265,12 +262,7 @@ def matrix_and_mvn_to_funsor(
         inputs[i.name] = i.dtype
     inputs[x_name] = Reals[x_size]
     inputs[y_i.name] = Real
-    g_i = Gaussian(
-        white_vec=white_vec,
-        prec_sqrt=prec_sqrt,
-        inputs=inputs,
-        negate=False,
-    )
+    g_i = Gaussian(white_vec=white_vec, prec_sqrt=prec_sqrt, inputs=inputs)

     # Convert to a joint Gaussian over x and y, possibly lazily.
     # This expands the y part of the matrix from linear to square,
Expand Down Expand Up @@ -304,6 +296,5 @@ def matrix_and_mvn_to_funsor(
white_vec=white_vec.expand(batch_shape + (-1,)),
prec_sqrt=prec_sqrt.expand(batch_shape + (-1, -1)),
inputs=inputs,
negate=False,
)
return g + log_prob
funsor/testing.py (2 changes: 1 addition & 1 deletion)

@@ -416,7 +416,7 @@ def random_gaussian(inputs):
     prec_sqrt = ops.cholesky(precision)
     loc = randn(batch_shape + event_shape)
     white_vec = ops.matmul(prec_sqrt, ops.unsqueeze(loc, -1)).squeeze(-1)
-    return Gaussian(white_vec, prec_sqrt, inputs, False)
+    return Gaussian(white_vec=white_vec, prec_sqrt=prec_sqrt, inputs=inputs)


 def random_mvn(batch_shape, dim, diag=False):