Cascading flow surrogate posterior #1345

Open · wants to merge 59 commits into main
Conversation

@gisilvs (Contributor) commented May 28, 2021

Cascading Flows algorithm

@gisilvs (Contributor, Author) commented May 28, 2021

@davmre

@googlebot added the "cla: yes" label (Declares that the user has signed CLA) on May 28, 2021
@davmre (Contributor) left a comment:

Hi Gianluigi, here's a first round of comments. I think there are some fairly subtle challenges, but this is a really nice start!

return distribution


def register_cf_substitution_rule(condition, substitution_fn):
Contributor:

I'm not sure we actually need this substitution machinery for cascading flows.

We originally added it because ASVI (at least as implemented in TFP) depends on the parameterized distribution family: e.g., a Uniform(0., 1.) prior distribution is the same distribution as Beta(1., 1.), but they give rise to different posterior families as you tune the parameters. But cascading flows just use the prior distribution as-is.

I'd go ahead and delete this (along with _as_substituted_distribution and the specific registrations below) for now, unless we have a specific use that requires it. (We can always add it later if we decide it's useful.)
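
For context, a small illustration of the ASVI behavior described above (a sketch, not code from this PR; tfd is assumed to be tfp.distributions):

import tensorflow_probability as tfp

tfd = tfp.distributions

# These two priors describe the same distribution on [0, 1]...
prior_a = tfd.Uniform(low=0., high=1.)
prior_b = tfd.Beta(concentration1=1., concentration0=1.)

# ...but ASVI tunes the prior's own parameters, so prior_a can only yield
# a Uniform-family surrogate, while prior_b yields a full Beta family.
# A cascading flow pushes the prior through a learned flow instead, so
# both priors induce the same surrogate family.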

_extract_variables_from_coroutine_model(
    posterior_generator, seed=seed)))

# Temporary workaround for bijector caching issues with autobatched JDs.
Contributor:

I think we can delete this comment now.

return surrogate_posterior, variables


# todo: sample_shape is not used.. can remove?
Contributor:

We're not special-casing tfd.Sample the way that ASVI does, so yes, I think we can get rid of sample_shape (here and elsewhere).
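For reference, the special case in question: ASVI unwraps tfd.Sample priors and rebuilds the surrogate with the same sample_shape. A minimal illustration (assuming tfd = tfp.distributions), not code from this PR:

# A Sample-wrapped prior: five i.i.d. draws from a scalar Normal.
dist = tfd.Sample(tfd.Normal(loc=0., scale=1.), sample_shape=[5])

# ASVI builds a surrogate for the inner Normal and re-wraps it with
# sample_shape; cascading flows treat `dist` as an ordinary distribution,
# so no sample_shape plumbing is needed.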

# save the variables.
value_out = yield (surrogate_posterior if flat_variables
                   else (surrogate_posterior, variables))
if type(value_out) == list:
Contributor:

I assume this is to detect auxiliary variables, but it would also fire (incorrectly, I believe) on list-valued distributions like JointDistributionSequential. Would it work to directly check `num_auxiliary_variables > 0` instead?
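
A minimal rendering of that suggestion (hypothetical; assumes `num_auxiliary_variables` is in scope at this point in the coroutine):

if num_auxiliary_variables > 0:
  # `value_out` is a [sample, auxiliary_variables] pair, not a list-valued
  # sample from e.g. JointDistributionSequential.
  ...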

Contributor:

This could probably also use a short explanatory comment.

                   else (surrogate_posterior, variables))
if type(value_out) == list:
  if len(dist.event_shape) == 0:
    dist = prior_gen.send(tf.squeeze(value_out[0], -1))
Contributor:

As a style matter, we prefer array slicing notation where possible. e.g.:

x[..., 0] instead of squeeze(x, -1) (as here)

x[..., tf.newaxis] instead of expand_dims(x, -1)

etc.

[tf.random.uniform((1,), minval=0.01, maxval=1.)
 for _ in range(num_aux_vars)], bijector=tfb.Softplus()), -1)), 1)

def target_log_prob_aux_vars(z_and_eps):
Contributor:

It would be nice to do something to help the user define the loss. Maybe we could add a method along the lines of augment_target_log_prob that takes the target_log_prob_fn and the prior distribution as arguments and returns the equivalent of this target_log_prob_aux_vars method?

Contributor:

Aside from convenience, one reason to do this is that it gives us leeway on the structure of the surrogate posterior. Instead of an exact specification, we can just provide the contract that

augmented_log_prob = augment_target_log_prob(target_log_prob, prior)
surrogate_posterior = build_cf_surrogate_posterior(prior, ...)
lp = augmented_log_prob(*surrogate_posterior.sample())

works to compute a valid log-density. Then if we decide later on to change the structure of the surrogate (as in my comment below about using the Restructure bijector), we just have to make the corresponding changes to augment_target_log_prob in order to avoid breaking code that uses this pattern.
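
One possible shape for that helper (a hypothetical sketch, not the PR's implementation: it assumes the flattened sample lists all model variables followed by their auxiliary counterparts, each auxiliary variable scored under a fixed standard normal):

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

def augment_target_log_prob(target_log_prob_fn, prior):
  # `prior` would drive the auxiliary structure in a real implementation;
  # it is unused in this sketch.
  del prior
  aux_dist = tfd.Normal(loc=0., scale=1.)  # Assumed auxiliary prior.
  def _augmented_log_prob(*z_and_eps):
    # Assume the flattened sample is (z_1, ..., z_n, eps_1, ..., eps_n).
    n = len(z_and_eps) // 2
    z, eps = z_and_eps[:n], z_and_eps[n:]
    return target_log_prob_fn(*z) + sum(
        tf.reduce_sum(aux_dist.log_prob(e)) for e in eps)
  return _augmented_log_prob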



@test_util.test_all_tf_execution_regimes
class TestCFDistributionSubstitution(test_util.TestCase):
Contributor:

We can probably delete this test along with the substitution code.



# todo: sample_shape is not used.. can remove?
def _cf_convex_update_for_base_distribution(dist,
Contributor:

This isn't a convex update anymore; maybe call it something like _cascading_flow_update_for_base_distribution?

def make_prior_dist(self):
  def _prior_model_fn():
    innovation_noise = 1.
    yield tfd.HalfNormal(
Contributor:

One thing we need to think about is: how should cascading flows apply to distributions that have constrained support (like this HalfNormal)?

Since flows don't really respect constraints on their own, IMHO the natural approach would be to transform the distribution into an unconstrained space, apply the flow, and then reapply the constraint. That is, in _cf_update_for_base_distribution, you'd do something like:

constraining_bijector = dist.experimental_default_event_space_bijector()
unconstrained_dist = invert.Invert(constraining_bijector)(dist)
cascading_flow = ... # Build cascading flow from unconstrained dist.
# Now reapply the constraint to the sampled event (but not the auxiliary part).
constrained_cascading_flow = tfb.JointMap([constraining_bijector, identity.Identity()])(cascading_flow)

Then we'd want to test that the surrogate has the same support as the prior. Probably the easiest way to do this is to verify that the log-probs are finite. For example, you could add a test to _TrainableCFSurrogate that checks:

self.assertAllFinite(
    prior_dist.log_prob(
        surrogate_posterior.sample(10, seed=test_util.test_seed())))
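
Rendered as a test method, that check might look like the following (a sketch; `build_cf_surrogate_posterior` and `make_prior_dist` are assumed from the surrounding PR code):

def test_surrogate_matches_prior_support(self):
  prior_dist = self.make_prior_dist()
  surrogate_posterior = build_cf_surrogate_posterior(prior_dist)
  samples = surrogate_posterior.sample(10, seed=test_util.test_seed())
  # Finite log-probs under the prior imply every surrogate sample lies
  # inside the prior's support.
  self.assertAllFinite(prior_dist.log_prob(samples))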

surrogate posterior distribution, with the same structure and `name` as
`dist`, and with addition of global and local auxiliary variables if
`num_auxiliary_variables > 0`.
variables: Nested structure of `tf.Variable` trainable parameters for the
Contributor:

Since the components of variables are now bijectors, maybe we should call it a nested structure 'containing' tf.Variables rather than 'of' tf.Variables ?
