Cascading flow surrogate posterior #1345
base: main
Conversation
Hi Gianluigi, here's a first round of comments---I think there are some fairly subtle challenges, but this is a really nice start!
return distribution


def register_cf_substitution_rule(condition, substitution_fn):
I'm not sure we actually need this substitution machinery for cascading flows. We originally added it because ASVI (at least as implemented in TFP) depends on the parameterized distribution family: e.g., a Uniform(0., 1.) prior is the same distribution as Beta(1., 1.), but the two give rise to different posterior families as you tune the parameters. Cascading flows, by contrast, just use the prior distribution as-is.

I'd go ahead and delete this (and _as_substituted_distribution, and the specific registrations below) for now, unless we have a specific use that requires it. (We can always add it back later if we decide it's useful.)
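To make that concrete, here's a minimal illustration (not part of the PR) of why the parameterization matters for ASVI but not for cascading flows:

import tensorflow_probability as tfp
tfd = tfp.distributions

# These define the same distribution on (0, 1)...
prior_as_uniform = tfd.Uniform(low=0., high=1.)
prior_as_beta = tfd.Beta(concentration1=1., concentration0=1.)

# ...but ASVI builds its surrogate by perturbing the prior's *parameters*,
# so the two induce different trainable families: tuning (low, high) can
# only shift and rescale a flat density, while tuning the concentrations
# can reshape the density on (0, 1). A cascading flow instead transforms
# samples from the prior as-is, so the parameterization doesn't matter.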
    _extract_variables_from_coroutine_model(
        posterior_generator, seed=seed)))

# Temporary workaround for bijector caching issues with autobatched JDs.
I think we can delete this comment now.
return surrogate_posterior, variables


# todo: sample_shape is not used.. can remove?
We're not special-casing tfd.Sample the way that ASVI does, so yes, I think we can get rid of sample_shape (here and elsewhere).
# save the variables.
value_out = yield (surrogate_posterior if flat_variables
                   else (surrogate_posterior, variables))
if type(value_out) == list:
I assume this is to detect auxiliary variables, but it would also fire (incorrectly, I believe) on list-valued distributions like JointDistributionSequential. Would it work to directly check if num_auxiliary_variables > 0?

This could probably also use a short explanatory comment.
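For instance, samples from a list-valued joint distribution are themselves Python lists, so the type check alone can't tell them apart from the [latent, auxiliary] pairs. A small runnable demonstration (not part of the PR):

import tensorflow_probability as tfp
tfd = tfp.distributions

jd = tfd.JointDistributionSequential(
    [tfd.Normal(0., 1.), lambda loc: tfd.Normal(loc, 1.)])
value_out = jd.sample()
print(type(value_out) == list)  # True, even with no auxiliary variables.

Checking num_auxiliary_variables > 0 directly (with a brief comment noting that the caller sends back a [latent, auxiliary] pair in that case) would address both points.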
                   else (surrogate_posterior, variables))
if type(value_out) == list:
  if len(dist.event_shape) == 0:
    dist = prior_gen.send(tf.squeeze(value_out[0], -1))
As a style matter, we prefer array slicing notation where possible, e.g.:

x[..., 0] instead of squeeze(x, -1) (as here)
x[..., tf.newaxis] instead of expand_dims(x, -1)

etc.
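Applied to the line above, the suggested rewrite would be (a sketch):

dist = prior_gen.send(value_out[0][..., 0])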
    [tf.random.uniform((1,), minval=0.01, maxval=1.)
     for _ in range(num_aux_vars)], bijector=tfb.Softplus()), -1)), 1)

def target_log_prob_aux_vars(z_and_eps):
It would be nice to do something to help the user define the loss. Maybe we could add a method along the lines of augment_target_log_prob that takes the target_log_prob_fn and the prior distribution as arguments, and returns the equivalent of this method target_log_prob_aux_vars?
Aside from convenience, one reason to do this is that it gives us leeway on the structure of the surrogate posterior. Instead of an exact specification, we can just provide the contract that

augmented_log_prob = augment_target_log_prob(target_log_prob, prior)
surrogate_posterior = build_cf_surrogate_posterior(prior, ...)
lp = augmented_log_prob(*surrogate_posterior.sample())

works to compute a valid log-density. Then if we decide later on to change the structure of the surrogate (as in my comment below about using the Restructure bijector), we just have to make the corresponding changes to augment_target_log_prob in order to avoid breaking code that uses this pattern.
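A minimal sketch of what such a helper could look like, assuming the surrogate samples all latent variables first and then the auxiliary variables, and that the auxiliaries are scored under a standard-normal model (both assumptions; the PR's actual auxiliary model may differ):

import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

def augment_target_log_prob(target_log_prob_fn, prior):
  # Hypothetical helper: returns a log-density over (latents, auxiliaries).
  num_latents = len(prior.event_shape)  # One entry per latent variable.

  def augmented_log_prob(*z_and_eps):
    z, eps = z_and_eps[:num_latents], z_and_eps[num_latents:]
    # Score the auxiliary variables under a fixed standard-normal model.
    aux_lp = sum(
        tf.reduce_sum(tfd.Normal(0., 1.).log_prob(e)) for e in eps)
    return target_log_prob_fn(*z) + aux_lp

  return augmented_log_prob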
@test_util.test_all_tf_execution_regimes
class TestCFDistributionSubstitution(test_util.TestCase):
We can probably delete this test along with the substitution code.
# todo: sample_shape is not used.. can remove?
def _cf_convex_update_for_base_distribution(dist,
This isn't a convex update any more --- maybe call it something like _cascading_flow_update_for_base_distribution?
def make_prior_dist(self):
  def _prior_model_fn():
    innovation_noise = 1.
    yield tfd.HalfNormal(
One thing we need to think about is: how should cascading flows apply to distributions that have constrained support (like this HalfNormal)?

Since flows don't really respect constraints on their own, IMHO the natural approach would be to transform the distribution into an unconstrained space, apply the flow, and then reapply the constraint. That is, in _cf_update_for_base_distribution, you'd do something like:

constraining_bijector = dist.experimental_default_event_space_bijector()
unconstrained_dist = tfb.Invert(constraining_bijector)(dist)
cascading_flow = ...  # Build the cascading flow from unconstrained_dist.
# Now reapply the constraint to the sampled event (but not the auxiliary part).
constrained_cascading_flow = tfb.JointMap(
    [constraining_bijector, tfb.Identity()])(cascading_flow)

Then we'd want to test that the surrogate has the same support as the prior. Probably the easiest way to do this is to verify that the log-probs are finite. For example, you could add a test to _TrainableCFSurrogate that checks:

self.assertAllFinite(
    prior_dist.log_prob(
        surrogate_posterior.sample(10, seed=test_util.test_seed())))
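Sketched as a test method (the names prior_dist and surrogate_posterior are assumed from the review comment; the actual fixture wiring may differ):

def test_surrogate_has_prior_support(self):
  # `prior_dist` and `surrogate_posterior` as built by the existing
  # _TrainableCFSurrogate setup.
  samples = surrogate_posterior.sample(10, seed=test_util.test_seed())
  # Finite log-probs under the prior imply the surrogate's samples
  # stay inside the prior's support.
  self.assertAllFinite(prior_dist.log_prob(samples))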
  surrogate posterior distribution, with the same structure and `name` as
    `dist`, and with the addition of global and local auxiliary variables if
    `num_auxiliary_variables > 0`.
  variables: Nested structure of `tf.Variable` trainable parameters for the
Since the components of variables are now bijectors, maybe we should call it a nested structure 'containing' tf.Variables rather than 'of' tf.Variables?