diff --git a/funsor/distribution.py b/funsor/distribution.py
index 63197d69..b5bef194 100644
--- a/funsor/distribution.py
+++ b/funsor/distribution.py
@@ -848,7 +848,6 @@ def eager_normal(loc, scale, value):
         white_vec=white_vec,
         prec_sqrt=prec_sqrt,
         inputs=inputs,
-        negate=False,
     )
     return gaussian(**{var: value - loc})
 
@@ -870,10 +869,7 @@ def eager_mvn(loc, scale_tril, value):
     var = gensym("value")
     inputs[var] = Reals[scale_diag.shape[0]]
     gaussian = log_prob + Gaussian(
-        white_vec=white_vec,
-        prec_sqrt=prec_sqrt,
-        inputs=inputs,
-        negate=False,
+        white_vec=white_vec, prec_sqrt=prec_sqrt, inputs=inputs
     )
     return gaussian(**{var: value - loc})
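The eager_normal hunks above keep the same construction: a constant normalizer plus a standard-normal Gaussian in whitened coordinates, evaluated at the centered value. A standalone numpy sanity check of that parametrization, assuming the usual N(loc, scale**2) pieces (illustrative names only, not funsor API):

import math
import numpy as np

loc, scale, value = 1.5, 2.0, 0.7

# The pieces eager_normal assembles: a constant normalizer, plus a
# white-noise Gaussian with prec_sqrt = 1/scale and white_vec = 0,
# evaluated at the centered point (value - loc).
log_normalizer = -0.5 * math.log(2 * math.pi) - math.log(scale)
prec_sqrt = np.array([[1.0 / scale]])
white_vec = np.zeros(1)
x = np.array([value - loc])

log_density = log_normalizer - 0.5 * np.sum((x @ prec_sqrt - white_vec) ** 2)
expected = -0.5 * ((value - loc) / scale) ** 2 - math.log(scale * math.sqrt(2 * math.pi))
assert abs(log_density - expected) < 1e-12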
""" - def __init__(self, white_vec, prec_sqrt, inputs, negate): + def __init__(self, white_vec, prec_sqrt, inputs): assert ops.is_numeric_array(white_vec) and ops.is_numeric_array(prec_sqrt) assert isinstance(inputs, tuple) inputs = OrderedDict(inputs) @@ -480,7 +462,6 @@ def __init__(self, white_vec, prec_sqrt, inputs, negate): super().__init__(inputs, output, fresh, bound) self.white_vec = white_vec self.prec_sqrt = prec_sqrt - self.negate = negate self.batch_shape = batch_shape self.event_shape = (dim,) @@ -495,18 +476,9 @@ def rank(self): @property def is_full_rank(self): - if self.negate: - return False dim, rank = self.prec_sqrt.shape[-2:] return rank == dim - @property - def is_normalizable(self): - """ - Whether this Gaussian is full rank and not negated. - """ - return self.is_full_rank and not self.negate - # TODO Consider weak-memoizing these so they persist through alpha conversion. # https://github.com/pyro-ppl/pyro/blob/ac3c588/pyro/distributions/coalescent.py#L412 @lazy_property @@ -559,7 +531,7 @@ def align(self, names): inputs = OrderedDict((name, self.inputs[name]) for name in names) inputs.update(self.inputs) white_vec, prec_sqrt = align_gaussian(inputs, self) - return Gaussian(white_vec, prec_sqrt, inputs, self.negate) + return Gaussian(white_vec, prec_sqrt, inputs) def eager_subs(self, subs): assert isinstance(subs, tuple) @@ -616,7 +588,7 @@ def _eager_subs_var(self, subs, remaining_subs): inputs = OrderedDict((rename.get(k, k), d) for k, d in self.inputs.items()) if len(inputs) != len(self.inputs): raise ValueError("Variable substitution name conflict") - var_result = Gaussian(self.white_vec, self.prec_sqrt, inputs, self.negate) + var_result = Gaussian(self.white_vec, self.prec_sqrt, inputs) return Subs(var_result, remaining_subs) if remaining_subs else var_result def _eager_subs_int(self, subs, remaining_subs): @@ -631,7 +603,7 @@ def _eager_subs_int(self, subs, remaining_subs): funsors = [Subs(Tensor(x, int_inputs), subs) for x in tensors] inputs = funsors[0].inputs.copy() inputs.update(real_inputs) - int_result = Gaussian(funsors[0].data, funsors[1].data, inputs, self.negate) + int_result = Gaussian(funsors[0].data, funsors[1].data, inputs) return Subs(int_result, remaining_subs) if remaining_subs else int_result def _eager_subs_real(self, subs, remaining_subs): @@ -674,9 +646,7 @@ def _eager_subs_real(self, subs, remaining_subs): value = value.as_tensor() # Evaluate the non-normalized log density. - result = (0.5 if self.negate else -0.5) * _norm2( - _vm(value, prec_sqrt) - white_vec - ) + result = -0.5 * _norm2(_vm(value, prec_sqrt) - white_vec) result = Tensor(result, int_inputs) assert result.output == Real return Subs(result, remaining_subs) if remaining_subs else result @@ -701,7 +671,7 @@ def _eager_subs_real(self, subs, remaining_subs): for k, d in self.inputs.items(): if k not in subs: inputs[k] = d - result = Gaussian(white_vec_a, prec_sqrt_a, inputs, self.negate) + result = Gaussian(white_vec_a, prec_sqrt_a, inputs) return Subs(result, remaining_subs) if remaining_subs else result def _eager_subs_affine(self, subs, remaining_subs): @@ -796,14 +766,14 @@ def _eager_subs_affine(self, subs, remaining_subs): # where P' = A P and w' = w - b P parametrize the new Gaussian. 
diff --git a/funsor/integrate.py b/funsor/integrate.py
index 80193b15..3c980d27 100644
--- a/funsor/integrate.py
+++ b/funsor/integrate.py
@@ -194,7 +194,7 @@ def eager_integrate(log_measure, integrand, reduced_vars):
 
 @eager.register(Integrate, Gaussian, Gaussian, frozenset)
 def eager_integrate(log_measure, integrand, reduced_vars):
-    assert not log_measure.negate
+    assert log_measure.is_full_rank
     reduced_names = frozenset(v.name for v in reduced_vars)
     real_vars = frozenset(v.name for v in reduced_vars if v.dtype == "real")
     if real_vars:
@@ -212,10 +212,7 @@
         lhs_white_vec, lhs_prec_sqrt = align_gaussian(inputs, log_measure)
         rhs_white_vec, rhs_prec_sqrt = align_gaussian(inputs, integrand)
         lhs = Gaussian(
-            white_vec=lhs_white_vec,
-            prec_sqrt=lhs_prec_sqrt,
-            inputs=inputs,
-            negate=False,
+            white_vec=lhs_white_vec, prec_sqrt=lhs_prec_sqrt, inputs=inputs
         )
 
         # Compute the expectation of a non-normalized quadratic form.
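With negate gone, the eager Integrate rule above and Gaussian.eager_reduce simply assert full rank. A quick smoke sketch of the marginalization path, assuming a funsor build with this patch applied (random_gaussian is the helper updated in funsor/testing.py below):

from collections import OrderedDict

import funsor.ops as ops
from funsor import Bint, Reals
from funsor.testing import random_gaussian

g = random_gaussian(OrderedDict(i=Bint[2], x=Reals[3]))
assert g.is_full_rank  # prec_sqrt is a Cholesky factor, hence square and full rank
marginal = g.reduce(ops.logaddexp, "x")  # eager marginalization over the real input
assert set(marginal.inputs) == {"i"}  # only the discrete mixture input remains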
diff --git a/funsor/joint.py b/funsor/joint.py
index 0bf32d0e..72fd8426 100644
--- a/funsor/joint.py
+++ b/funsor/joint.py
@@ -131,12 +131,7 @@ def moment_matching_contract_joint(red_op, bin_op, reduced_vars, discrete, gauss
 
     new_inputs = new_loc.inputs.copy()
     new_inputs.update((k, d) for k, d in gaussian.inputs.items() if d.dtype == "real")
-    new_gaussian = Gaussian(
-        mean=new_loc.data,
-        covariance=new_cov,
-        inputs=new_inputs,
-        negate=False,
-    )
+    new_gaussian = Gaussian(mean=new_loc.data, covariance=new_cov, inputs=new_inputs)
     new_discrete -= new_gaussian.log_normalizer
     return new_discrete + new_gaussian
 
diff --git a/funsor/pyro/convert.py b/funsor/pyro/convert.py
index 26f98195..53d8fa45 100644
--- a/funsor/pyro/convert.py
+++ b/funsor/pyro/convert.py
@@ -154,10 +154,7 @@ def mvn_to_funsor(pyro_dist, event_inputs=(), real_inputs=OrderedDict()):
     )
     inputs.update(real_inputs)
     return discrete + Gaussian(
-        white_vec=gaussian.white_vec,
-        prec_sqrt=gaussian.prec_sqrt,
-        inputs=inputs,
-        negate=False,
+        white_vec=gaussian.white_vec, prec_sqrt=gaussian.prec_sqrt, inputs=inputs
     )
 
 
@@ -265,12 +262,7 @@ def matrix_and_mvn_to_funsor(
         inputs[i.name] = i.dtype
     inputs[x_name] = Reals[x_size]
     inputs[y_i.name] = Real
-    g_i = Gaussian(
-        white_vec=white_vec,
-        prec_sqrt=prec_sqrt,
-        inputs=inputs,
-        negate=False,
-    )
+    g_i = Gaussian(white_vec=white_vec, prec_sqrt=prec_sqrt, inputs=inputs)
 
     # Convert to a joint Gaussian over x and y, possibly lazily.
     # This expands the y part of the matrix from linear to square,
@@ -304,6 +296,5 @@
         white_vec=white_vec.expand(batch_shape + (-1,)),
         prec_sqrt=prec_sqrt.expand(batch_shape + (-1, -1)),
         inputs=inputs,
-        negate=False,
     )
     return g + log_prob
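For downstream users of the conversion helpers only the keyword disappears; shapes and semantics are unchanged. A hedged usage sketch for mvn_to_funsor (assumes the torch backend and this patch; the real_inputs sizes must sum to the MVN event size):

from collections import OrderedDict

import torch

import funsor
from funsor import Reals
from funsor.pyro.convert import mvn_to_funsor

funsor.set_backend("torch")
mvn = torch.distributions.MultivariateNormal(torch.zeros(3), torch.eye(3))
f = mvn_to_funsor(mvn, real_inputs=OrderedDict(x=Reals[2], y=Reals[1]))
assert set(f.inputs) == {"x", "y"}  # a log-density funsor over two real inputs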
diff --git a/funsor/testing.py b/funsor/testing.py
index 13a60c8c..223a6a08 100644
--- a/funsor/testing.py
+++ b/funsor/testing.py
@@ -416,7 +416,7 @@ def random_gaussian(inputs):
     prec_sqrt = ops.cholesky(precision)
     loc = randn(batch_shape + event_shape)
     white_vec = ops.matmul(prec_sqrt, ops.unsqueeze(loc, -1)).squeeze(-1)
-    return Gaussian(white_vec, prec_sqrt, inputs, False)
+    return Gaussian(white_vec=white_vec, prec_sqrt=prec_sqrt, inputs=inputs)
 
 
 def random_mvn(batch_shape, dim, diag=False):
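random_gaussian now passes everything by keyword; the same construction works directly in user code. A minimal sketch under the default numpy backend (the matrix values here are arbitrary; white_vec follows the docstring relation white_vec = prec_sqrt.T @ mean):

from collections import OrderedDict

import numpy as np

from funsor import Reals
from funsor.gaussian import Gaussian

cov = np.array([[2.0, 0.5], [0.5, 1.0]])
prec_sqrt = np.linalg.cholesky(np.linalg.inv(cov))  # any square root of the precision
mean = np.array([0.3, -0.7])
white_vec = mean @ prec_sqrt  # equals prec_sqrt.T @ mean
g = Gaussian(white_vec=white_vec, prec_sqrt=prec_sqrt, inputs=OrderedDict(x=Reals[2]))
assert g.is_full_rank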
+299,7 @@ def test_reduce_moment_matching_multivariate(): prec_sqrt = zeros(4, 1, 1) + ops.new_eye(loc, (2,)) white_vec = (loc[..., None, :] @ prec_sqrt)[..., 0, :] discrete = Tensor(zeros(4), int_inputs) - gaussian = Gaussian( - white_vec=white_vec, - prec_sqrt=prec_sqrt, - inputs=inputs, - negate=False, - ) + gaussian = Gaussian(white_vec=white_vec, prec_sqrt=prec_sqrt, inputs=inputs) gaussian -= gaussian.log_normalizer joint = discrete + gaussian with moment_matching: @@ -322,10 +310,7 @@ def test_reduce_moment_matching_multivariate(): expected_precision = numeric_array([[1 / 101.0, 0.0], [0.0, 1 / 2.0]]) expected_prec_sqrt = ops.pow(expected_precision, 0.5) expected_gaussian = Gaussian( - white_vec=expected_white_vec, - prec_sqrt=expected_prec_sqrt, - inputs=real_inputs, - negate=False, + white_vec=expected_white_vec, prec_sqrt=expected_prec_sqrt, inputs=real_inputs ) expected_gaussian -= expected_gaussian.log_normalizer expected_discrete = Tensor(ops.log(numeric_array(4.0)))