Skip to content

Commit

Permalink
add generic stat methods to distribution (#388)
Browse files Browse the repository at this point in the history
  • Loading branch information
eb8680 authored Oct 30, 2020
1 parent cf46731 commit 94d4251
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 21 deletions.
42 changes: 34 additions & 8 deletions funsor/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,22 @@ def eager_log_prob(cls, *params):
data = cls.dist_class(**params).log_prob(value)
return Tensor(data, inputs)

def _get_raw_dist(self):
    """
    Internal helper exposing the backend distribution underlying this funsor.

    Returns a triple ``(raw_dist, value_output, dim_to_name)`` where
    ``raw_dist`` is the backend distribution object, ``value_output`` is the
    output domain of the value variable, and ``dim_to_name`` maps the
    (arbitrary) negative batch dims back to funsor input names.
    """
    value = self.value
    if not isinstance(value, Variable):
        raise NotImplementedError("cannot get raw dist for {}".format(self))
    value_name = value.name
    # arbitrary name-dim mapping, since we're converting back to a funsor anyway
    name_to_dim = {}
    for dim, (name, domain) in enumerate(self.inputs.items()):
        if name != value_name and isinstance(domain.dtype, int):
            name_to_dim[name] = -dim - 1
    raw_dist = to_data(self, name_to_dim=name_to_dim)
    dim_to_name = {dim: name for name, dim in name_to_dim.items()}
    # also return value output, dim_to_name for converting results back to funsor
    return raw_dist, value.output, dim_to_name

@property
def has_rsample(self):
    """Whether the wrapped backend distribution class supports rsample."""
    dist_cls = self.dist_class
    # Default to False for backends/classes that do not declare the flag.
    return getattr(dist_cls, "has_rsample", False)
Expand Down Expand Up @@ -139,16 +155,26 @@ def unscaled_sample(self, sampled_vars, sample_inputs, rng_key=None):
return result

def enumerate_support(self, expand=False):
    """
    Enumerate the support of this distribution as a funsor.

    :param bool expand: whether to expand the result over batch dims
        (forwarded to the backend distribution's ``enumerate_support``).
    :raises ValueError: if the distribution has no enumerable support or its
        value is not a free :class:`Variable`.
    """
    # Raise (not assert): assertions are stripped under ``python -O`` and
    # callers rely on catching ValueError for non-enumerable distributions.
    if not self.has_enumerate_support or not isinstance(self.value, Variable):
        raise ValueError("cannot enumerate support of {}".format(repr(self)))
    raw_dist, value_output, dim_to_name = self._get_raw_dist()
    raw_value = raw_dist.enumerate_support(expand=expand)
    # The new enumeration dim sits one slot left of the leftmost batch dim.
    dim_to_name[min(dim_to_name.keys(), default=0) - 1] = self.value.name
    return to_funsor(raw_value, output=value_output, dim_to_name=dim_to_name)

def entropy(self):
    """
    Compute the entropy via the backend distribution and convert it back
    to a funsor over this distribution's batch inputs.
    """
    # value_output is deliberately unused here: unlike mean/variance, the
    # entropy shares this funsor's own (real scalar) output, so self.output
    # is the correct output domain.
    raw_dist, _, dim_to_name = self._get_raw_dist()
    raw_value = raw_dist.entropy()
    return to_funsor(raw_value, output=self.output, dim_to_name=dim_to_name)

def mean(self):
    """Return the backend distribution's mean as a funsor."""
    raw_dist, value_output, dim_to_name = self._get_raw_dist()
    # ``mean`` is a property on backend distribution objects, not a method.
    return to_funsor(raw_dist.mean, output=value_output, dim_to_name=dim_to_name)

def variance(self):
    """Return the backend distribution's variance as a funsor."""
    raw_dist, value_output, dim_to_name = self._get_raw_dist()
    # ``variance`` is a property on backend distribution objects, not a method.
    return to_funsor(raw_dist.variance, output=value_output, dim_to_name=dim_to_name)

def __getattribute__(self, attr):
if attr in type(self)._ast_fields and attr != 'name':
Expand Down
23 changes: 10 additions & 13 deletions test/test_distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from funsor.domains import Bint, Real, Reals
from funsor.integrate import Integrate
from funsor.interpreter import interpretation, reinterpret
from funsor.tensor import Einsum, Tensor, align_tensors, numeric_array, stack
from funsor.tensor import Einsum, Tensor, numeric_array, stack
from funsor.terms import Independent, Variable, eager, lazy, to_funsor
from funsor.testing import assert_close, check_funsor, rand, randint, randn, random_mvn, random_tensor, xfail_param
from funsor.util import get_backend
Expand Down Expand Up @@ -701,29 +701,26 @@ def _get_stat_diff(funsor_dist_class, sample_inputs, inputs, num_samples, statis
check_funsor(sample_value, expected_inputs, Real)

if sample_inputs:

actual_mean = Integrate(
sample_value, Variable('value', funsor_dist.inputs['value']), frozenset(['value'])
).reduce(ops.add, frozenset(sample_inputs))

inputs, tensors = align_tensors(*list(funsor_dist.params.values())[:-1])
raw_dist = funsor_dist.dist_class(**dict(zip(funsor_dist._ast_fields[:-1], tensors)))
expected_mean = Tensor(raw_dist.mean, inputs)

if statistic == "mean":
actual_stat, expected_stat = actual_mean, expected_mean
actual_stat = Integrate(
sample_value, Variable('value', funsor_dist.inputs['value']), frozenset(['value'])
).reduce(ops.add, frozenset(sample_inputs))
expected_stat = funsor_dist.mean()
elif statistic == "variance":
actual_mean = Integrate(
sample_value, Variable('value', funsor_dist.inputs['value']), frozenset(['value'])
).reduce(ops.add, frozenset(sample_inputs))
actual_stat = Integrate(
sample_value,
(Variable('value', funsor_dist.inputs['value']) - actual_mean) ** 2,
frozenset(['value'])
).reduce(ops.add, frozenset(sample_inputs))
expected_stat = Tensor(raw_dist.variance, inputs)
expected_stat = funsor_dist.variance()
elif statistic == "entropy":
actual_stat = -Integrate(
sample_value, funsor_dist, frozenset(['value'])
).reduce(ops.add, frozenset(sample_inputs))
expected_stat = Tensor(raw_dist.entropy(), inputs)
expected_stat = funsor_dist.entropy()
else:
raise ValueError("invalid test statistic")

Expand Down

0 comments on commit 94d4251

Please sign in to comment.