Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support converting funsor.terms.Independent to and from data #396

Merged
merged 8 commits into from
Nov 12, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion funsor/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,7 +403,20 @@ def distribution_to_data(funsor_dist, name_to_dim=None):

@to_data.register(Independent[typing.Union[Independent, Distribution], str, str, str])
def indep_to_data(funsor_dist, name_to_dim=None):
raise NotImplementedError("TODO implement conversion of Independent")
if not isinstance(funsor_dist.fn, (Independent, Distribution, Gaussian)):
raise NotImplementedError(f"cannot convert {funsor_dist} to data")
name_to_dim = OrderedDict((name, dim - 1) for name, dim in name_to_dim.items())
name_to_dim.update({funsor_dist.bint_var: -1})
backend_dist = import_module(BACKEND_TO_DISTRIBUTIONS_BACKEND[get_backend()]).dist
result = to_data(funsor_dist.fn, name_to_dim=name_to_dim)

# collapse nested Independents into a single Independent for conversion
reinterpreted_batch_ndims = 1
while isinstance(result, backend_dist.Independent):
result = result.base_dist
reinterpreted_batch_ndims += 1

return backend_dist.Independent(result, reinterpreted_batch_ndims)


@to_data.register(Gaussian)
Expand Down
9 changes: 9 additions & 0 deletions funsor/terms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1543,6 +1543,15 @@ def eager_subs(self, subs):
result = result.reduce(ops.add, self.bint_var)
return result

def mean(self):
raise NotImplementedError("mean() not yet implemented for Independent")

def variance(self):
raise NotImplementedError("variance() not yet implemented for Independent")

def entropy(self):
raise NotImplementedError("entropy() not yet implemented for Independent")


@eager.register(Independent, Funsor, str, str, str)
def eager_independent_trivial(fn, reals_var, bint_var, diag_var):
Expand Down
32 changes: 26 additions & 6 deletions test/test_distribution_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,6 @@ def __hash__(self):
funsor.Real,
)

# TODO figure out what this should be...
# Delta
for event_shape in [(), (4,), (3, 2)]:
DistTestCase(
Expand Down Expand Up @@ -298,7 +297,6 @@ def __hash__(self):
funsor.Real,
)

# TODO implement RelaxedBernoulli._infer_param_domain for temperature
# RelaxedBernoulli
DistTestCase(
"backend_dist.RelaxedBernoulli(temperature=case.temperature, logits=case.logits)",
Expand Down Expand Up @@ -334,6 +332,23 @@ def __hash__(self):
funsor.Real
)

# Independent
for indep_shape in [(3,), (2, 3)]:
# Beta.to_event
DistTestCase(
f"backend_dist.Beta(case.concentration1, case.concentration0).to_event({len(indep_shape)})",
(("concentration1", f"ops.exp(randn({batch_shape + indep_shape}))"),
("concentration0", f"ops.exp(randn({batch_shape + indep_shape}))")),
funsor.Reals[indep_shape],
)
# Dirichlet.to_event
for event_shape in [(2,), (4,)]:
DistTestCase(
f"backend_dist.Dirichlet(case.concentration).to_event({len(indep_shape)})",
(("concentration", f"rand({batch_shape + indep_shape + event_shape})"),),
funsor.Reals[indep_shape + event_shape],
)


###########################
# Generic tests:
Expand All @@ -360,14 +375,19 @@ def test_generic_distribution_to_funsor(case):
dim_to_name, name_to_dim = _default_dim_to_name(raw_dist.batch_shape)
with interpretation(lazy):
funsor_dist = to_funsor(raw_dist, output=funsor.Real, dim_to_name=dim_to_name)
assert funsor_dist.inputs["value"] == expected_value_domain

actual_dist = to_data(funsor_dist, name_to_dim=name_to_dim)

assert isinstance(actual_dist, backend_dist.Distribution)
assert issubclass(type(actual_dist), type(raw_dist)) # subclass to handle wrappers
assert funsor_dist.inputs["value"] == expected_value_domain
for param_name in funsor_dist.params.keys():
if param_name == "value":
continue
while isinstance(raw_dist, backend_dist.Independent):
raw_dist = raw_dist.base_dist
actual_dist = actual_dist.base_dist
assert isinstance(actual_dist, backend_dist.Distribution)
assert issubclass(type(actual_dist), type(raw_dist)) # subclass to handle wrappers

for param_name, _ in case.raw_params:
assert hasattr(raw_dist, param_name)
assert_close(getattr(actual_dist, param_name), getattr(raw_dist, param_name))

Expand Down