-
Notifications
You must be signed in to change notification settings - Fork 20
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support conversion of Transforms and TransformedDistributions to and from Funsors #365
Conversation
…sformedDist) case
else: | ||
raise NotImplementedError("cannot get raw dist for {}".format(self)) | ||
value_name = [name for name, domain in self.value.inputs.items() # TODO is this right? | ||
if domain == self.value.output][0] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This looks weird. Can you explain what's going on? When is value_name != "value"
? The assertion in self.__init__()
suggests value_name == "value"
IIUC.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
value_name
can potentially be anything, since self.value
is a lazy expression. This logic is meant to solve the problem of identifying the value name when self.value
is not a Variable
or even has more than one input (which happens when constructing the Funsor version of a TransformedDistribution
).
The simplest nontrivial example of the latter case would be an affine or power transform where the parameters are funsor.Tensor
s with nontrivial .inputs
, although these are not handled in this PR since funsor.delta.solve
does not yet support inverting such expressions.
The main use for value_name
in this PR is in Distribution.unscaled_sample
, which needs to know value_name
to construct a sample Delta
with the correct .inputs
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for addressing nits!
Addresses #309, #386.
Blocked by #387, #388, #389.This PR adds support for converting
Transform
s andTransformedDistribution
s to and from Funsors, and for sampling and scoring withfunsor.Distribution
s with transforms in theirvalue
slot.The latter is achieved with a new term I've called(Update: I'm movingLebesgue
for want of a better idea, which computes log-det-jacobians of lazy expressions substituted into its free variable usingfunsor.delta.solve
.Lebesgue
to another PR, since it no longer seems necessary for the basic task of convertingTransformedDistribution
s.)There are several ways this functionality could have been instantiated. The design choices in this version were made with a view toward locality of changes - the overall distribution API is left unchanged, as is the behavior of
normalize
andeager
, and I've reusedfunsor.delta.solve
even though it should probably be broken up into separate inversion and linearization transformations.The number of transforms currently supported is very limited, partly because we don't have many transforms wrapped as
TransformOp
s in Funsor and partly becausesolve
cannot handle inverting more complex expressions. I've at least tried to get the two most important higher-order transformsComposeTransform
andInverseTransform
working, so that adding more transforms later will be relatively straightforward. For ease of review, I have also chosen to leave support for the JAX backend to a followup PR, though it should simply involve copying the additions tofunsor.torch.distributions
verbatim.Tested:
TransformedDistribution
test cases using the infrastructure from Automate distribution testing #389Tasks remaining:
Out of scope for this PR:
Independent
- added oneIndependent
case, but conversion failingLebesgue
- movingLebesgue
to another PR