Commit 668aa70

Support conversion of Transforms and TransformedDistributions to and from Funsors (#365)

eb8680 authored Nov 20, 2020
1 parent 5432c7d commit 668aa70
Showing 5 changed files with 231 additions and 33 deletions.
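
For orientation, a minimal sketch of the two conversion directions this commit enables (torch backend assumed; names are illustrative):

    import torch
    import funsor

    funsor.set_backend("torch")

    # torch -> funsor: a TransformedDistribution becomes a lazy funsor
    # log-density with a free "value" variable.
    log_normal = torch.distributions.TransformedDistribution(
        torch.distributions.Normal(0., 1.),
        [torch.distributions.transforms.ExpTransform()])
    f = funsor.to_funsor(log_normal, output=funsor.Real)

    # funsor -> torch: a unary transform expression becomes a torch Transform.
    x = funsor.Variable("x", funsor.Real)
    t = funsor.to_data(funsor.ops.exp(x))  # an ExpTransform()
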
9 changes: 9 additions & 0 deletions funsor/cnf.py
@@ -10,6 +10,7 @@
import opt_einsum
from multipledispatch.variadic import Variadic

import funsor
import funsor.ops as ops
from funsor.affine import affine_inputs
from funsor.delta import Delta
@@ -127,6 +128,14 @@ def unscaled_sample(self, sampled_vars, sample_inputs, rng_key=None):
terms.append(-gaussian.log_normalizer)
terms.append(term.unscaled_sample(greedy_vars, sample_inputs, rng_keys[0]))
result = Contraction(self.red_op, self.bin_op, self.reduced_vars, *terms)
elif any(isinstance(term, funsor.distribution.Distribution)
and not greedy_vars.isdisjoint(term.value.inputs) for term in greedy_terms):
sampled_terms = [
term.unscaled_sample(greedy_vars.intersection(term.value.inputs), sample_inputs)
for term in greedy_terms if isinstance(term, funsor.distribution.Distribution)
and not greedy_vars.isdisjoint(term.value.inputs)
]
result = Contraction(self.red_op, self.bin_op, self.reduced_vars, *(terms + sampled_terms))
else:
raise NotImplementedError('Unhandled case: {}'.format(
', '.join(str(type(t)) for t in greedy_terms)))
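
Read in isolation, the guard in the new branch selects exactly the distribution terms whose value expression mentions a variable being sampled. In plain Python, with stand-in names rather than funsor objects:

    greedy_vars = frozenset({"y"})
    # Stand-ins for (term, term.value.inputs) pairs.
    greedy_terms = [("p_x", {"x"}), ("p_y_given_x", {"x", "y"})]

    sampled = [name for name, value_inputs in greedy_terms
               if not greedy_vars.isdisjoint(value_inputs)]
    print(sampled)  # ['p_y_given_x'] -- only the term mentioning "y" is sampled
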
58 changes: 39 additions & 19 deletions funsor/distribution.py
@@ -21,7 +21,8 @@
from funsor.interpreter import gensym
from funsor.tensor import (Tensor, align_tensors, dummy_numeric_array, get_default_prototype,
ignore_jit_warnings, numeric_array, stack)
from funsor.terms import Funsor, FunsorMeta, Independent, Number, Variable, eager, to_data, to_funsor
from funsor.terms import (Funsor, FunsorMeta, Independent, Number, Variable,
eager, to_data, to_funsor)
from funsor.util import broadcast_shape, get_backend, getargspec, lazy_property


@@ -109,17 +110,16 @@ def _get_raw_dist(self):
"""
Internal method for working with underlying distribution attributes
"""
if isinstance(self.value, Variable):
value_name = self.value.name
else:
raise NotImplementedError("cannot get raw dist for {}".format(self))
value_name = [name for name, domain in self.value.inputs.items() # TODO is this right?
if domain == self.value.output][0]
# arbitrary name-dim mapping, since we're converting back to a funsor anyway
name_to_dim = {name: -dim-1 for dim, (name, domain) in enumerate(self.inputs.items())
if isinstance(domain.dtype, int) and name != value_name}
raw_dist = to_data(self, name_to_dim=name_to_dim)
dim_to_name = {dim: name for name, dim in name_to_dim.items()}
# also return value output, dim_to_name for converting results back to funsor
return raw_dist, self.value.output, dim_to_name
value_output = self.inputs[value_name]
return raw_dist, value_name, value_output, dim_to_name
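
A toy rendering of the name-dim bookkeeping above; the real code also restricts to integer-typed (batch) inputs:

    from collections import OrderedDict

    inputs = OrderedDict([("a", "Bint[2]"), ("b", "Bint[3]"), ("value", "Real")])
    value_name = "value"

    # Arbitrary negative dims for batch inputs; the value input becomes the event.
    name_to_dim = {name: -dim - 1 for dim, name in enumerate(inputs)
                   if name != value_name}
    dim_to_name = {dim: name for name, dim in name_to_dim.items()}
    print(name_to_dim)  # {'a': -1, 'b': -2}
    print(dim_to_name)  # {-1: 'a', -2: 'b'}
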

@property
def has_rsample(self):
@@ -130,16 +130,15 @@ def has_enumerate_support(self):
return getattr(self.dist_class, "has_enumerate_support", False)

def unscaled_sample(self, sampled_vars, sample_inputs, rng_key=None):
params = OrderedDict(self.params)
value = params.pop("value")
assert all(isinstance(v, (Number, Tensor)) for v in params.values())
assert isinstance(value, Variable) and value.name in sampled_vars

value_name = value.name
raw_dist, value_output, dim_to_name = self._get_raw_dist()
# note this should handle transforms correctly via distribution_to_data
raw_dist, value_name, value_output, dim_to_name = self._get_raw_dist()
for d, name in zip(range(len(sample_inputs), 0, -1), sample_inputs.keys()):
dim_to_name[-d - len(raw_dist.batch_shape)] = name

if value_name not in sampled_vars:
return self

sample_shape = tuple(v.size for v in sample_inputs.values())
sample_args = (sample_shape,) if get_backend() == "torch" else (rng_key, sample_shape)
if self.has_rsample:
@@ -161,23 +160,23 @@ def unscaled_sample(self, sampled_vars, sample_inputs, rng_key=None):

def enumerate_support(self, expand=False):
assert self.has_enumerate_support and isinstance(self.value, Variable)
raw_dist, value_output, dim_to_name = self._get_raw_dist()
raw_dist, value_name, value_output, dim_to_name = self._get_raw_dist()
raw_value = raw_dist.enumerate_support(expand=expand)
dim_to_name[min(dim_to_name.keys(), default=0)-1] = self.value.name
dim_to_name[min(dim_to_name.keys(), default=0)-1] = value_name
return to_funsor(raw_value, output=value_output, dim_to_name=dim_to_name)
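
For example, enumerating a wrapped Categorical (a sketch assuming the torch backend):

    import torch
    import funsor

    funsor.set_backend("torch")
    from funsor.torch.distributions import Categorical

    probs = funsor.Tensor(torch.tensor([0.2, 0.8]))
    d = Categorical(probs=probs, value="x")
    support = d.enumerate_support()  # a funsor with a new "x" input of size 2
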

def entropy(self):
raw_dist, value_output, dim_to_name = self._get_raw_dist()
raw_dist, value_name, value_output, dim_to_name = self._get_raw_dist()
raw_value = raw_dist.entropy()
return to_funsor(raw_value, output=self.output, dim_to_name=dim_to_name)

def mean(self):
raw_dist, value_output, dim_to_name = self._get_raw_dist()
raw_dist, value_name, value_output, dim_to_name = self._get_raw_dist()
raw_value = raw_dist.mean
return to_funsor(raw_value, output=value_output, dim_to_name=dim_to_name)

def variance(self):
raw_dist, value_output, dim_to_name = self._get_raw_dist()
raw_dist, value_name, value_output, dim_to_name = self._get_raw_dist()
raw_value = raw_dist.variance
return to_funsor(raw_value, output=value_output, dim_to_name=dim_to_name)
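
The statistics methods above all follow the same pattern: convert to a raw backend distribution, query it, convert back. A quick sketch (torch backend, scalar parameters):

    import funsor

    funsor.set_backend("torch")
    from funsor.torch.distributions import Normal

    d = Normal(0., 1., value="x")
    print(d.mean())      # ~ Tensor(0.)
    print(d.variance())  # ~ Tensor(1.)
    print(d.entropy())   # ~ Tensor(1.4189), i.e. 0.5 * log(2 * pi * e)
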

@@ -234,7 +233,6 @@ def _infer_param_domain(cls, name, raw_shape):
# Distribution Wrappers
################################################################################


def make_dist(backend_dist_class, param_names=(), generate_eager=True, generate_to_funsor=True):
if not param_names:
param_names = tuple(name for name in inspect.getfullargspec(backend_dist_class.__init__)[0][1:]
@@ -312,8 +310,19 @@ def maskeddist_to_funsor(backend_dist, output=None, dim_to_name=None):
return mask * funsor_base_dist


# Converts a (possibly nested) TransformedDistribution by unwrapping its transforms,
# converting the base distribution, and substituting the inverse transform into "value".
def transformeddist_to_funsor(backend_dist, output=None, dim_to_name=None):
raise NotImplementedError("TODO implement conversion of TransformedDistribution")
dist_module = import_module(BACKEND_TO_DISTRIBUTIONS_BACKEND[get_backend()]).dist
base_dist, transforms = backend_dist, []
while isinstance(base_dist, dist_module.TransformedDistribution):
transforms = base_dist.transforms + transforms
base_dist = base_dist.base_dist
funsor_base_dist = to_funsor(base_dist, output=output, dim_to_name=dim_to_name)
# TODO make this work with transforms that change the output type
transform = to_funsor(dist_module.transforms.ComposeTransform(transforms),
funsor_base_dist.inputs["value"], dim_to_name)
_, inv_transform, ldj = funsor.delta.solve(transform, to_funsor("value", funsor_base_dist.inputs["value"]))
return -ldj + funsor_base_dist(value=inv_transform)
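
For example, a LogNormal written as a TransformedDistribution converts to the base Normal evaluated at the inverse transform, minus the log-Jacobian (a sketch, torch backend):

    import torch
    import funsor

    funsor.set_backend("torch")

    td = torch.distributions.TransformedDistribution(
        torch.distributions.Normal(0., 1.),
        [torch.distributions.transforms.ExpTransform()])
    f = funsor.to_funsor(td, output=funsor.Real)

    # Evaluating f at a point should agree with td.log_prob:
    y = torch.tensor(0.5)
    print(f(value=funsor.Tensor(y)))  # funsor Tensor
    print(td.log_prob(y))             # matching torch scalar
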


class CoerceDistributionToFunsor:
@@ -396,6 +405,17 @@ def distribution_to_data(funsor_dist, name_to_dim=None):
pyro_dist = funsor_dist.dist_class(**dict(zip(funsor_dist._ast_fields[:-1], params)))
funsor_event_shape = funsor_dist.value.output.shape
pyro_dist = pyro_dist.to_event(max(len(funsor_event_shape) - len(pyro_dist.event_shape), 0))

# TODO get this working for all backends
if not isinstance(funsor_dist.value, Variable):
if get_backend() != "torch":
raise NotImplementedError("transformed distributions not yet supported under this backend,"
"try set_backend('torch')")
inv_value = funsor.delta.solve(funsor_dist.value, Variable("value", funsor_dist.value.output))[1]
transforms = to_data(inv_value, name_to_dim=name_to_dim)
backend_dist = import_module(BACKEND_TO_DISTRIBUTIONS_BACKEND[get_backend()]).dist
pyro_dist = backend_dist.TransformedDistribution(pyro_dist, transforms)

if pyro_dist.event_shape != funsor_event_shape:
raise ValueError("Event shapes don't match, something went wrong")
return pyro_dist
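
Conversely, a distribution funsor whose value is a transform of a variable, rather than a plain Variable, now converts to a torch TransformedDistribution (a sketch; other backends raise as above):

    import torch
    import funsor

    funsor.set_backend("torch")
    from funsor.torch.distributions import Normal

    x = funsor.Variable("x", funsor.Real)
    d = Normal(0., 1., value=funsor.ops.exp(x))  # value is not a plain Variable
    raw = funsor.to_data(d)
    print(type(raw))  # torch.distributions.TransformedDistribution
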
86 changes: 85 additions & 1 deletion funsor/torch/distributions.py
@@ -42,7 +42,7 @@
from funsor.domains import Real, Reals
import funsor.ops as ops
from funsor.tensor import Tensor, dummy_numeric_array
from funsor.terms import Binary, Funsor, Reduce, Variable, eager, to_data, to_funsor
from funsor.terms import Binary, Funsor, Reduce, Unary, Variable, eager, to_data, to_funsor
from funsor.util import methodof


@@ -222,10 +222,94 @@ def deltadist_to_data(funsor_dist, name_to_dim=None):
return dist.Delta(v, log_density, event_dim=len(funsor_dist.v.output.shape))


@functools.singledispatch
def op_to_torch_transform(op, name_to_dim=None):
raise NotImplementedError("cannot convert {} to a Transform".format(op))


@op_to_torch_transform.register(ops.TransformOp)
def transform_to_torch_transform(op, name_to_dim=None):
raise NotImplementedError("{} is not a currently supported transform".format(op))


@op_to_torch_transform.register(ops.ExpOp)
def exp_to_torch_transform(op, name_to_dim=None):
return torch.distributions.transforms.ExpTransform()


@op_to_torch_transform.register(ops.LogOp)
def log_to_torch_transform(op, name_to_dim=None):
return torch.distributions.transforms.ExpTransform().inv


@op_to_torch_transform.register(ops.SigmoidOp)
def sigmoid_to_torch_transform(op, name_to_dim=None):
return torch.distributions.transforms.SigmoidTransform()


@op_to_torch_transform.register(ops.TanhOp)
def tanh_to_torch_transform(op, name_to_dim=None):
return torch.distributions.transforms.TanhTransform()


@op_to_torch_transform.register(ops.AtanhOp)
def atanh_to_torch_transform(op, name_to_dim=None):
return torch.distributions.transforms.TanhTransform().inv


@to_data.register(Unary[ops.TransformOp, Union[Unary, Variable]])
def transform_to_data(expr, name_to_dim=None):
if isinstance(expr.op, ops.TransformOp):
tfm = op_to_torch_transform(expr.op, name_to_dim=name_to_dim)
if isinstance(expr.arg, Unary):
tfm = torch.distributions.transforms.ComposeTransform([to_data(expr.arg, name_to_dim=name_to_dim), tfm])
return tfm
raise NotImplementedError("cannot convert to data: {}".format(expr))


###############################################
# Converting PyTorch Distributions to funsors
###############################################

@to_funsor.register(torch.distributions.Transform)
def transform_to_funsor(tfm, output=None, dim_to_name=None, real_inputs=None):
raise NotImplementedError("{} is not a currently supported transform".format(tfm))


@to_funsor.register(torch.distributions.transforms.ExpTransform)
def exptransform_to_funsor(tfm, output=None, dim_to_name=None, real_inputs=None):
name = next(iter(real_inputs)) if real_inputs else "value"
return ops.exp(Variable(name, output))


@to_funsor.register(torch.distributions.transforms.TanhTransform)
def tanhtransform_to_funsor(tfm, output=None, dim_to_name=None, real_inputs=None):
name = next(iter(real_inputs)) if real_inputs else "value"
return ops.tanh(Variable(name, output))


@to_funsor.register(torch.distributions.transforms.SigmoidTransform)
def sigmoidtransform_to_funsor(tfm, output=None, dim_to_name=None, real_inputs=None):
name = next(iter(real_inputs)) if real_inputs else "value"
return ops.sigmoid(Variable(name, output))


@to_funsor.register(torch.distributions.transforms._InverseTransform)
def inversetransform_to_funsor(tfm, output=None, dim_to_name=None, real_inputs=None):
expr = to_funsor(tfm._inv, output=output, dim_to_name=dim_to_name, real_inputs=real_inputs)
assert isinstance(expr, Unary)
return expr.op.inv(expr.arg)


@to_funsor.register(torch.distributions.transforms.ComposeTransform)
def composetransform_to_funsor(tfm, output=None, dim_to_name=None, real_inputs=None):
name = next(iter(real_inputs)) if real_inputs else "value"
expr = Variable(name, output)
for part in tfm.parts:
expr = to_funsor(part, output=output, dim_to_name=dim_to_name, real_inputs=real_inputs)(**{name: expr})
return expr
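
A ComposeTransform converts by threading a single "value" variable through each part in order, so inverse parts fold in through their op's inverse. A sketch:

    import torch
    import funsor

    funsor.set_backend("torch")

    tfm = torch.distributions.transforms.ComposeTransform([
        torch.distributions.transforms.ExpTransform().inv,  # i.e. log
        torch.distributions.transforms.TanhTransform(),
    ])
    f = funsor.to_funsor(tfm, output=funsor.Real)
    # f is tanh(log("value")) as a funsor expression.
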


to_funsor.register(torch.distributions.Independent)(indepdist_to_funsor)
to_funsor.register(MaskedDistribution)(maskeddist_to_funsor)
to_funsor.register(torch.distributions.TransformedDistribution)(transformeddist_to_funsor)