From 0f904b9505899f409fe7bcf44a1c077df28814d2 Mon Sep 17 00:00:00 2001 From: Eli Date: Thu, 29 Oct 2020 17:32:06 -0400 Subject: [PATCH] add generic stat methods to distribution --- funsor/distribution.py | 42 +++++++++++++++++++++++++++++++-------- test/test_distribution.py | 23 ++++++++++----------- 2 files changed, 44 insertions(+), 21 deletions(-) diff --git a/funsor/distribution.py b/funsor/distribution.py index 98a98c0be..51fbd0a68 100644 --- a/funsor/distribution.py +++ b/funsor/distribution.py @@ -105,6 +105,22 @@ def eager_log_prob(cls, *params): data = cls.dist_class(**params).log_prob(value) return Tensor(data, inputs) + def _get_raw_dist(self): + """ + Internal method for working with underlying distribution attributes + """ + if isinstance(self.value, Variable): + value_name = self.value.name + else: + raise NotImplementedError("cannot get raw dist for {}".format(self)) + # arbitrary name-dim mapping, since we're converting back to a funsor anyway + name_to_dim = {name: -dim-1 for dim, (name, domain) in enumerate(self.inputs.items()) + if isinstance(domain.dtype, int) and name != value_name} + raw_dist = to_data(self, name_to_dim=name_to_dim) + dim_to_name = {dim: name for name, dim in name_to_dim.items()} + # also return value output, dim_to_name for converting results back to funsor + return raw_dist, self.value.output, dim_to_name + @property def has_rsample(self): return getattr(self.dist_class, "has_rsample", False) @@ -139,16 +155,26 @@ def unscaled_sample(self, sampled_vars, sample_inputs, rng_key=None): return result def enumerate_support(self, expand=False): - if not self.has_enumerate_support or not isinstance(self.value, Variable): - raise ValueError("cannot enumerate support of {}".format(repr(self))) - # arbitrary name-dim mapping, since we're converting back to a funsor anyway - name_to_dim = {name: -dim-1 for dim, (name, domain) in enumerate(self.inputs.items()) - if isinstance(domain.dtype, int) and name != self.value.name} - raw_dist = to_data(self, name_to_dim=name_to_dim) + assert self.has_enumerate_support and isinstance(self.value, Variable) + raw_dist, value_output, dim_to_name = self._get_raw_dist() raw_value = raw_dist.enumerate_support(expand=expand) - dim_to_name = {dim: name for name, dim in name_to_dim.items()} dim_to_name[min(dim_to_name.keys(), default=0)-1] = self.value.name - return to_funsor(raw_value, output=self.value.output, dim_to_name=dim_to_name) + return to_funsor(raw_value, output=value_output, dim_to_name=dim_to_name) + + def entropy(self): + raw_dist, value_output, dim_to_name = self._get_raw_dist() + raw_value = raw_dist.entropy() + return to_funsor(raw_value, output=self.output, dim_to_name=dim_to_name) + + def mean(self): + raw_dist, value_output, dim_to_name = self._get_raw_dist() + raw_value = raw_dist.mean + return to_funsor(raw_value, output=value_output, dim_to_name=dim_to_name) + + def variance(self): + raw_dist, value_output, dim_to_name = self._get_raw_dist() + raw_value = raw_dist.variance + return to_funsor(raw_value, output=value_output, dim_to_name=dim_to_name) def __getattribute__(self, attr): if attr in type(self)._ast_fields and attr != 'name': diff --git a/test/test_distribution.py b/test/test_distribution.py index 0e7489da5..f789414b4 100644 --- a/test/test_distribution.py +++ b/test/test_distribution.py @@ -17,7 +17,7 @@ from funsor.domains import Bint, Real, Reals from funsor.integrate import Integrate from funsor.interpreter import interpretation, reinterpret -from funsor.tensor import Einsum, Tensor, align_tensors, numeric_array, stack +from funsor.tensor import Einsum, Tensor, numeric_array, stack from funsor.terms import Independent, Variable, eager, lazy, to_funsor from funsor.testing import assert_close, check_funsor, rand, randint, randn, random_mvn, random_tensor, xfail_param from funsor.util import get_backend @@ -701,29 +701,26 @@ def _get_stat_diff(funsor_dist_class, sample_inputs, inputs, num_samples, statis check_funsor(sample_value, expected_inputs, Real) if sample_inputs: - - actual_mean = Integrate( - sample_value, Variable('value', funsor_dist.inputs['value']), frozenset(['value']) - ).reduce(ops.add, frozenset(sample_inputs)) - - inputs, tensors = align_tensors(*list(funsor_dist.params.values())[:-1]) - raw_dist = funsor_dist.dist_class(**dict(zip(funsor_dist._ast_fields[:-1], tensors))) - expected_mean = Tensor(raw_dist.mean, inputs) - if statistic == "mean": - actual_stat, expected_stat = actual_mean, expected_mean + actual_stat = Integrate( + sample_value, Variable('value', funsor_dist.inputs['value']), frozenset(['value']) + ).reduce(ops.add, frozenset(sample_inputs)) + expected_stat = funsor_dist.mean() elif statistic == "variance": + actual_mean = Integrate( + sample_value, Variable('value', funsor_dist.inputs['value']), frozenset(['value']) + ).reduce(ops.add, frozenset(sample_inputs)) actual_stat = Integrate( sample_value, (Variable('value', funsor_dist.inputs['value']) - actual_mean) ** 2, frozenset(['value']) ).reduce(ops.add, frozenset(sample_inputs)) - expected_stat = Tensor(raw_dist.variance, inputs) + expected_stat = funsor_dist.variance() elif statistic == "entropy": actual_stat = -Integrate( sample_value, funsor_dist, frozenset(['value']) ).reduce(ops.add, frozenset(sample_inputs)) - expected_stat = Tensor(raw_dist.entropy(), inputs) + expected_stat = funsor_dist.entropy() else: raise ValueError("invalid test statistic")