diff --git a/tensorflow_probability/python/experimental/vi/BUILD b/tensorflow_probability/python/experimental/vi/BUILD
index e57f884ca5..d693fa80ea 100644
--- a/tensorflow_probability/python/experimental/vi/BUILD
+++ b/tensorflow_probability/python/experimental/vi/BUILD
@@ -31,6 +31,7 @@ py_library(
     srcs_version = "PY3",
     deps = [
         ":automatic_structured_vi",
+        ":cascading_flows",
         ":surrogate_posteriors",
         "//tensorflow_probability/python/experimental/vi/util",
         "//tensorflow_probability/python/internal:all_util",
@@ -67,6 +68,34 @@ py_library(
     ],
 )
 
+py_library(
+    name = "cascading_flows",
+    srcs = ["cascading_flows.py"],
+    srcs_version = "PY3",
+    deps = [
+        # tensorflow dep,
+        "//tensorflow_probability/python/bijectors:chain",
+        "//tensorflow_probability/python/bijectors:identity",
+        "//tensorflow_probability/python/bijectors:invert",
+        "//tensorflow_probability/python/bijectors:joint_map",
+        "//tensorflow_probability/python/bijectors:reshape",
+        "//tensorflow_probability/python/bijectors:restructure",
+        "//tensorflow_probability/python/bijectors:split",
+        "//tensorflow_probability/python/distributions:batch_broadcast",
+        "//tensorflow_probability/python/distributions:blockwise",
+        "//tensorflow_probability/python/distributions:deterministic",
+        "//tensorflow_probability/python/distributions:independent",
+        "//tensorflow_probability/python/distributions:joint_distribution_auto_batched",
+        "//tensorflow_probability/python/distributions:joint_distribution_coroutine",
+        "//tensorflow_probability/python/distributions:normal",
+        "//tensorflow_probability/python/distributions:sample",
+        "//tensorflow_probability/python/distributions:transformed_distribution",
+        "//tensorflow_probability/python/experimental/bijectors:build_trainable_highway_flow",
+        "//tensorflow_probability/python/internal:prefer_static",
+        "//tensorflow_probability/python/internal:samplers",
+    ],
+)
+
 py_library(
     name = "surrogate_posteriors",
     srcs = ["surrogate_posteriors.py"],
@@ -111,6 +140,23 @@ py_test(
     ],
 )
 
+py_test(
+    name = "cascading_flows_test",
+    size = "large",
+    srcs = ["cascading_flows_test.py"],
+    python_version = "PY3",
+    shard_count = 4,
+    srcs_version = "PY3",
+    deps = [
+        # absl/testing:parameterized dep,
+        # numpy dep,
+        # tensorflow dep,
+        "//tensorflow_probability",
+        "//tensorflow_probability/python/internal:prefer_static",
+        "//tensorflow_probability/python/internal:test_util",
+    ],
+)
+
 py_test(
     name = "surrogate_posteriors_test",
     size = "large",
diff --git a/tensorflow_probability/python/experimental/vi/__init__.py b/tensorflow_probability/python/experimental/vi/__init__.py
index 0cb4971fcc..1f2fa2f900 100644
--- a/tensorflow_probability/python/experimental/vi/__init__.py
+++ b/tensorflow_probability/python/experimental/vi/__init__.py
@@ -17,6 +17,7 @@
 from tensorflow_probability.python.experimental.vi import util
 from tensorflow_probability.python.experimental.vi.automatic_structured_vi import build_asvi_surrogate_posterior
 from tensorflow_probability.python.experimental.vi.automatic_structured_vi import register_asvi_substitution_rule
+from tensorflow_probability.python.experimental.vi.cascading_flows import build_cascading_flow_surrogate_posterior
 from tensorflow_probability.python.experimental.vi.surrogate_posteriors import build_affine_surrogate_posterior
 from tensorflow_probability.python.experimental.vi.surrogate_posteriors import build_affine_surrogate_posterior_from_base_distribution
 from tensorflow_probability.python.experimental.vi.surrogate_posteriors import build_factored_surrogate_posterior
@@ -29,6 +30,7 @@
     'build_affine_surrogate_posterior',
     'build_affine_surrogate_posterior_from_base_distribution',
     'build_asvi_surrogate_posterior',
+    'build_cascading_flow_surrogate_posterior',
     'build_factored_surrogate_posterior',
     'build_split_flow_surrogate_posterior',
     'build_trainable_location_scale_distribution',
diff --git a/tensorflow_probability/python/experimental/vi/cascading_flows.py b/tensorflow_probability/python/experimental/vi/cascading_flows.py
new file mode 100644
index 0000000000..e8796b35f7
--- /dev/null
+++ b/tensorflow_probability/python/experimental/vi/cascading_flows.py
@@ -0,0 +1,615 @@
+# Copyright 2021 The TensorFlow Probability Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Utilities for constructing structured surrogate posteriors."""
+
+from __future__ import absolute_import
+from __future__ import division
+# [internal] enable type annotations
+from __future__ import print_function
+
+import copy
+import functools
+
+import tensorflow.compat.v2 as tf
+
+from tensorflow_probability.python.bijectors import chain
+from tensorflow_probability.python.bijectors import identity
+from tensorflow_probability.python.bijectors import invert
+from tensorflow_probability.python.bijectors import joint_map
+from tensorflow_probability.python.bijectors import reshape
+from tensorflow_probability.python.bijectors import restructure
+from tensorflow_probability.python.bijectors import split
+from tensorflow_probability.python.distributions import batch_broadcast
+from tensorflow_probability.python.distributions import blockwise
+from tensorflow_probability.python.distributions import deterministic
+from tensorflow_probability.python.distributions import independent
+from tensorflow_probability.python.distributions import \
+    joint_distribution_auto_batched
+from tensorflow_probability.python.distributions import \
+    joint_distribution_coroutine
+from tensorflow_probability.python.distributions import normal
+from tensorflow_probability.python.distributions import sample
+from tensorflow_probability.python.distributions import transformed_distribution
+
+from tensorflow_probability.python.experimental.bijectors import build_trainable_highway_flow
+from tensorflow_probability.python.internal import samplers
+from tensorflow_probability.python.internal import prefer_static as ps
+
+__all__ = [
+    'build_cascading_flow_surrogate_posterior'
+]
+
+Root = joint_distribution_coroutine.JointDistributionCoroutine.Root
+
+
+def build_cascading_flow_surrogate_posterior(
+    prior,
+    num_auxiliary_variables=0,
+    initial_prior_weight=0.98,
+    num_layers=3,
+    use_global_auxiliary_variables=False,
+    seed=None,
+    name=None):
+  """Builds a structured surrogate posterior with cascading flows.
+
+  Cascading Flows (CF) [1] is a method that automatically constructs a
+  variational approximation given an input probabilistic program. CF combines
+  ASVI [2] with the flexibility of normalizing flows, by transforming the
+  conditional distributions of the prior program with HighwayFlow architectures,
+  to steer the prior towards the observed data. More details on the HighwayFlow
+  architecture can be found in [1] and in the tfp bijector `HighwayFlow`.
+  It is possible to add auxiliary variables to the prior program to further
+  increase the flexibility of cascading flows, useful especially in the
+  cases where the input program has low dimensionality. The auxiliary variables
+  are sampled from a global linear flow, to account for statistical dependencies
+  among variables, and then transformed with local HighwayFlows together with
+  samples from the prior. Note that when using auxiliary variables it is
+  necessary to modify the variational lower bound [3].
+
+  Args:
+    prior: tfd.JointDistribution instance of the prior.
+    num_auxiliary_variables: The number of auxiliary variables to use for each
+      variable in the input program. Default value: `0`.
+    initial_prior_weight: Optional float value (either static or tensor value)
+      on the interval [0, 1]. A larger value creates an initial surrogate
+      distribution with more dependence on the prior structure. Default value:
+      `0.98`.
+    num_layers: Number of layers to use in each Highway Flow architecture. All
+      the layers will have `softplus` activation function, apart from the last
+      one which will have linear activation. Default value: `3`.
+    use_global_auxiliary_variables: Python `bool`. If `True`, one set of
+      `num_auxiliary_variables` auxiliary variables is sampled from a trainable
+      flow and shared by all variables of the program; otherwise every variable
+      gets its own standard normal auxiliary variables. Requires
+      `num_auxiliary_variables >= 1`. Default value: `False`.
+    seed: PRNG seed used to initialize the trainable flow parameters.
+    name: Optional string. Default value:
+      `build_cascading_flow_surrogate_posterior`.
+
+  Returns:
+    surrogate_posterior: A trainable `tfd.Distribution` built around a
+      `tfd.JointDistributionCoroutineAutoBatched`. Without auxiliary variables
+      its samples have shape and structure matching that of `prior`. With
+      auxiliary variables each sampled part is a
+      `[variable, auxiliary_variables]` pair, preceded by the global auxiliary
+      variables if `use_global_auxiliary_variables` is `True`.
+
+  Raises:
+    ValueError: If `use_global_auxiliary_variables` is `True` while
+      `num_auxiliary_variables` is less than `1`.
+
+  ### Examples
+
+  Consider a Brownian motion model expressed as a JointDistribution:
+
+  ```python
+  prior_loc = 0.
+  innovation_noise = .1
+
+  def model_fn():
+    new = yield tfd.Normal(loc=prior_loc, scale=innovation_noise)
+    for i in range(4):
+      new = yield tfd.Normal(loc=new, scale=innovation_noise)
+
+  prior = tfd.JointDistributionCoroutineAutoBatched(model_fn)
+  ```
+
+  Let's use variational inference to approximate the posterior. We'll build a
+  surrogate posterior distribution by feeding in the prior distribution.
+
+  ```python
+  surrogate_posterior = (
+      tfp.experimental.vi.build_cascading_flow_surrogate_posterior(prior))
+  ```
+
+  This creates a trainable joint distribution, defined by variables in
+  `surrogate_posterior.trainable_variables`. We use `fit_surrogate_posterior`
+  to fit this distribution by minimizing a divergence to the true posterior.
+
+  ```python
+  losses = tfp.vi.fit_surrogate_posterior(
+    target_log_prob_fn,
+    surrogate_posterior=surrogate_posterior,
+    num_steps=100,
+    optimizer=tf.optimizers.Adam(0.1),
+    sample_size=10)
+
+  # After optimization, samples from the surrogate will approximate
+  # samples from the true posterior.
+  samples = surrogate_posterior.sample(100)
+  posterior_mean = [tf.reduce_mean(x) for x in samples]
+  posterior_std = [tf.math.reduce_std(x) for x in samples]
+  ```
+
+  When using auxiliary variables, we need some modifications for loss and
+  samples, as samples will return also the global variables and transformed
+  auxiliary variables:
+
+  ```python
+  num_aux_vars = 10
+  num_samples = 100
+  event_len = len(prior.event_shape_tensor())
+  target_dist = tfd.Independent(
+    tfd.Normal(
+      loc=tf.Variable(tf.random.normal((event_len, num_aux_vars))),
+      scale=tfp.util.TransformedVariable(
+        tf.random.uniform((event_len, num_aux_vars), minval=0.01, maxval=1.),
+        bijector=tfb.Softplus())),
+    2)
+
+  # The first sampled part holds the global auxiliary variables; every other
+  # part is a `[variable, auxiliary_variables]` pair.
+  def target_log_prob_aux_vars(z_and_eps):
+    z = [x[0] for x in z_and_eps[1:]]
+    eps = [x[1] for x in z_and_eps[1:]]
+    lp_z = target_log_prob_fn(z)
+    lp_eps = tf.reshape(
+      tf.reduce_sum(target_dist.log_prob(eps), 0), lp_z.shape)
+    return lp_z + lp_eps
+
+  target_log_prob = lambda *values: target_log_prob_aux_vars(values)
+  cascading_flow_surrogate_posterior = (
+    tfp.experimental.vi.build_cascading_flow_surrogate_posterior(
+      prior,
+      num_auxiliary_variables=num_aux_vars,
+      use_global_auxiliary_variables=True))
+  trainable_variables = list(
+    cascading_flow_surrogate_posterior.trainable_variables)
+  trainable_variables.extend(list(target_dist.trainable_variables))
+  cascading_flow_losses = tfp.vi.fit_surrogate_posterior(
+    target_log_prob,
+    cascading_flow_surrogate_posterior,
+    optimizer=tf.optimizers.Adam(0.01),
+    num_steps=8000,
+    sample_size=50,
+    trainable_variables=trainable_variables)
+
+  cascading_flow_posterior_samples = (
+    cascading_flow_surrogate_posterior.sample(num_samples))
+  cascading_flow_posterior_samples = tf.convert_to_tensor(
+    [s[0] for s in cascading_flow_posterior_samples[1:]])
+  ```
+
+  #### References
+  [1]: Ambrogioni, Luca, Gianluigi Silvestri, and Marcel van Gerven. "Automatic
+       variational inference with cascading flows." arXiv preprint
+       arXiv:2102.04801 (2021).
+
+  [2]: Ambrogioni, Luca, et al. "Automatic structured variational inference."
+       International Conference on Artificial Intelligence and Statistics. PMLR,
+       2021.
+
+  [3]: Ranganath, Rajesh, Dustin Tran, and David Blei. "Hierarchical variational
+       models." International Conference on Machine Learning. PMLR, 2016.
+  """
+  if use_global_auxiliary_variables and num_auxiliary_variables < 1:
+    raise ValueError(
+        'cannot use global auxiliary variables if auxiliary variables is 0')
+  with tf.name_scope(name or 'build_cascading_flow_surrogate_posterior'):
+    surrogate_posterior, variables = _cascading_flow_surrogate_for_distribution(
+        dist=prior,
+        base_distribution_surrogate_fn=functools.partial(
+            _cascading_flow_update_for_base_distribution,
+            initial_prior_weight=initial_prior_weight,
+            num_auxiliary_variables=num_auxiliary_variables,
+            num_layers=num_layers,
+            use_global_auxiliary_variables=use_global_auxiliary_variables),
+        num_auxiliary_variables=num_auxiliary_variables,
+        num_layers=num_layers,
+        use_global_auxiliary_variables=use_global_auxiliary_variables,
+        seed=seed)
+    # Keep a reference to the flow bijectors on the returned distribution
+    # (presumably what lets `trainable_variables` find their parameters).
+    surrogate_posterior.also_track = variables
+    return surrogate_posterior
+
+
+def _cascading_flow_surrogate_for_distribution(dist,
+                                               base_distribution_surrogate_fn,
+                                               num_auxiliary_variables,
+                                               num_layers,
+                                               use_global_auxiliary_variables,
+                                               global_auxiliary_variables=None,
+                                               variables=None,
+                                               seed=None):
+  """Recursively creates CF surrogates, and creates new variables if needed.
+
+  Args:
+    dist: a `tfd.Distribution` instance.
+    base_distribution_surrogate_fn: Callable to build a surrogate posterior
+      for a 'base' (non-meta and non-joint) distribution. It is called with
+      the keyword arguments `dist`, `variables`,
+      `use_global_auxiliary_variables`, `global_auxiliary_variables`,
+      `num_layers` and `seed`, and must return a
+      `(surrogate_posterior, variables)` pair.
+    num_auxiliary_variables: The number of auxiliary variables to use for each
+      variable in the input program.
+    num_layers: Number of layers to use in each Highway Flow architecture.
+    use_global_auxiliary_variables: Python `bool`, whether the auxiliary
+      variables are shared by all variables of the program.
+    global_auxiliary_variables: The sampled global auxiliary variables
+      (available only if using global auxiliary variables).
+      Default value: `None`.
+    variables: Optional nested structure containing `tf.Variable`s returned
+      from a previous call to `_cascading_flow_surrogate_for_distribution`. If
+      `None`, new variables will be created; otherwise, constructs a surrogate
+      posterior backed by the passed-in variables.
+      Default value: `None`.
+    seed: PRNG seed for random initialization.
+
+  Returns:
+    surrogate_posterior: Instance of `tfd.Distribution` representing a trainable
+      surrogate posterior distribution, with the same structure as `dist`, and
+      with addition of global and local auxiliary variables if
+      `num_auxiliary_variables > 0`.
+    variables: Nested structure containing `tf.Variable` trainable parameters
+      for the surrogate posterior. If `dist` is a base distribution, this is
+      a `tfb.Chain` of HighwayFlow bijectors. If `dist` is a joint
+      distribution, this is a structure of such `tfb.Chain`s.
+  """
+  # Anything exposing a model coroutine is handled as a joint distribution.
+  if hasattr(dist, '_model_coroutine'):
+    return _cascading_flow_surrogate_for_joint_distribution(
+        dist,
+        base_distribution_surrogate_fn=base_distribution_surrogate_fn,
+        variables=variables,
+        num_auxiliary_variables=num_auxiliary_variables,
+        num_layers=num_layers,
+        use_global_auxiliary_variables=use_global_auxiliary_variables,
+        global_auxiliary_variables=global_auxiliary_variables,
+        seed=seed)
+  return base_distribution_surrogate_fn(
+      dist=dist,
+      variables=variables,
+      use_global_auxiliary_variables=use_global_auxiliary_variables,
+      global_auxiliary_variables=global_auxiliary_variables,
+      num_layers=num_layers,
+      seed=seed)
+
+
+def _build_highway_flow_block(num_layers, width,
+                              residual_fraction_initial_value, gate_first_n,
+                              seed):
+  """Builds the trainable `HighwayFlow` layers of one cascading-flow block.
+
+  Args:
+    num_layers: Python `int` number of layers. The final layer is always
+      built, so values below `1` still yield one layer.
+    width: Forwarded to `build_trainable_highway_flow` for every layer.
+    residual_fraction_initial_value: Forwarded to
+      `build_trainable_highway_flow` for every layer.
+    gate_first_n: Forwarded to `build_trainable_highway_flow` for every layer.
+    seed: PRNG seed. It is split into one seed per layer so that the layers
+      do not all reuse the same seed for their initial weights.
+
+  Returns:
+    bijectors: Python `list` holding `num_layers - 1` layers with `softplus`
+      activation followed by one layer without activation function.
+  """
+  num_hidden_layers = max(num_layers - 1, 0)
+  layer_seeds = samplers.split_seed(seed, n=num_hidden_layers + 1)
+  activation_fns = [tf.nn.softplus] * num_hidden_layers + [None]
+  return [
+      build_trainable_highway_flow(
+          width,
+          residual_fraction_initial_value=residual_fraction_initial_value,
+          activation_fn=activation_fn,
+          gate_first_n=gate_first_n,
+          seed=layer_seeds[i])
+      for i, activation_fn in enumerate(activation_fns)]
+
+
+def _cascading_flow_surrogate_for_joint_distribution(
+    dist, base_distribution_surrogate_fn, variables,
+    num_auxiliary_variables, num_layers, use_global_auxiliary_variables,
+    global_auxiliary_variables, seed=None):
+  """Builds a structured joint surrogate posterior for a joint model.
+
+  Args:
+    dist: Joint distribution exposing `_model_coroutine`.
+    base_distribution_surrogate_fn: Callable building the surrogate of a base
+      distribution; forwarded to `_cascading_flow_surrogate_for_distribution`.
+    variables: Structure of flow bijectors returned by a previous call, or
+      `None` to create them.
+    num_auxiliary_variables: Number of auxiliary variables per variable.
+    num_layers: Number of layers of each Highway Flow block.
+    use_global_auxiliary_variables: Python `bool`. If `True`, the surrogate
+      first samples shared auxiliary variables from a trainable flow.
+    global_auxiliary_variables: Only forwarded to the recursive call; the
+      surrogate program samples the global auxiliary variables itself.
+    seed: PRNG seed for random initialization.
+
+  Returns:
+    surrogate_posterior: Trainable joint surrogate posterior.
+    variables: Structure holding the flow bijectors of every component.
+  """
+  # pylint: disable=protected-access
+  flat_variables = dist._model_flatten(variables) if variables else None
+  prior_coroutine = dist._model_coroutine
+  prior_batch_shape = dist.batch_shape_tensor()
+  # `batch_shape_tensor()` may be a nested structure of shapes; broadcast its
+  # parts into the single shape used for the global auxiliary variables.
+  if tf.nest.is_nested(prior_batch_shape):
+    prior_batch_shape = functools.reduce(
+        ps.broadcast_shape, dist._model_flatten(prior_batch_shape))
+  # pylint: enable=protected-access
+
+  def posterior_generator(seed=seed):
+    """Probabilistic program for the CF surrogate posterior."""
+    prior_gen = prior_coroutine()
+    prior_dist = next(prior_gen)
+    index = 0
+    global_aux_sample = None
+
+    if use_global_auxiliary_variables:
+      if flat_variables:
+        global_flow = flat_variables[0]
+      else:
+        # Use a dedicated seed so the global flow does not share its seed with
+        # the per-variable flows created below.
+        seed, global_seed = samplers.split_seed(seed)
+        global_flow = chain.Chain(bijectors=list(reversed(
+            _build_highway_flow_block(
+                num_layers,
+                width=num_auxiliary_variables,
+                residual_fraction_initial_value=0,  # not used
+                gate_first_n=0,
+                seed=global_seed))))
+      eps = Root(transformed_distribution.TransformedDistribution(
+          distribution=batch_broadcast.BatchBroadcast(
+              sample.Sample(normal.Normal(0., 1.), num_auxiliary_variables),
+              prior_batch_shape),
+          bijector=global_flow))
+      global_aux_sample = yield (eps if flat_variables
+                                 else (eps, global_flow))
+      index = 1
+
+    try:
+      while True:
+        was_root = isinstance(prior_dist, Root)
+        if was_root:
+          prior_dist = prior_dist.distribution
+
+        seed, init_seed = samplers.split_seed(seed)
+        surrogate, flow = _cascading_flow_surrogate_for_distribution(
+            prior_dist,
+            base_distribution_surrogate_fn=base_distribution_surrogate_fn,
+            num_auxiliary_variables=num_auxiliary_variables,
+            num_layers=num_layers,
+            variables=flat_variables[index] if flat_variables else None,
+            use_global_auxiliary_variables=use_global_auxiliary_variables,
+            global_auxiliary_variables=global_aux_sample,
+            seed=init_seed)
+
+        # With global auxiliary variables every component is built from their
+        # sampled value, so no component is marked as a root.
+        if was_root and not use_global_auxiliary_variables:
+          surrogate = Root(surrogate)
+        # If variables were not given---i.e., we're creating new
+        # variables---then yield the new variables along with the surrogate
+        # posterior. This assumes an execution context such as
+        # `_extract_variables_from_coroutine_model` below that will capture and
+        # save the variables.
+        value_out = yield (surrogate if flat_variables
+                           else (surrogate, flow))
+
+        # When using auxiliary variables, value out is a list containing
+        # [latent_variables, auxiliary_variables].
+        if num_auxiliary_variables > 0:
+          prior_dist = prior_gen.send(value_out[0])
+        else:
+          prior_dist = prior_gen.send(value_out)
+        index += 1
+    except StopIteration:
+      pass
+
+  if variables is None:
+    # Run the generator to create variables, then call ourselves again
+    # to construct the surrogate JD from these variables. Note that we can't
+    # just create a JDC from the current `posterior_generator`, because it will
+    # try to build new variables on every invocation; the recursive call will
+    # define a new `posterior_generator` that knows about the variables we're
+    # about to create.
+    return _cascading_flow_surrogate_for_joint_distribution(
+        dist=dist,
+        base_distribution_surrogate_fn=base_distribution_surrogate_fn,
+        num_auxiliary_variables=num_auxiliary_variables,
+        num_layers=num_layers,
+        use_global_auxiliary_variables=use_global_auxiliary_variables,
+        global_auxiliary_variables=global_auxiliary_variables,
+        variables=dist._model_unflatten(  # pylint: disable=protected-access
+            _extract_variables_from_coroutine_model(
+                posterior_generator, seed=seed)),
+        seed=seed)
+
+  surrogate_posterior = (
+      joint_distribution_auto_batched.JointDistributionCoroutineAutoBatched(
+          posterior_generator,
+          use_vectorized_map=dist.use_vectorized_map,
+          name=_get_name(dist)))
+
+  if num_auxiliary_variables == 0:
+    # Ensure that the surrogate posterior structure matches that of the prior.
+    try:
+      tf.nest.assert_same_structure(dist.dtype, surrogate_posterior.dtype)
+    except TypeError:
+      tokenize = lambda jd: tf.nest.pack_sequence_as(  # pylint: disable=g-long-lambda
+          jd.dtype, range(len(tf.nest.flatten(jd.dtype))))
+      surrogate_posterior = restructure.Restructure(
+          output_structure=tokenize(dist),
+          input_structure=tokenize(surrogate_posterior))(
+              surrogate_posterior, name=_get_name(dist))
+  # NOTE: With auxiliary variables the surrogate is returned as built by
+  # `posterior_generator`: each part of a sample is a
+  # `[variable, auxiliary_variables]` pair, with the global auxiliary
+  # variables (if any) as an extra leading part. Regrouping them with
+  # `Restructure` was disabled because of its memory consumption.
+
+  return surrogate_posterior, variables
+
+
+def _cascading_flow_update_for_base_distribution(dist,
+                                                 initial_prior_weight,
+                                                 num_auxiliary_variables,
+                                                 num_layers,
+                                                 use_global_auxiliary_variables,
+                                                 global_auxiliary_variables,
+                                                 variables,
+                                                 seed=None):
+  """Creates a trainable surrogate for a (non-meta, non-joint) distribution.
+
+  `dist` is mapped to a flat vector through the inverse of its default event
+  space bijector, optionally joined with auxiliary variables, transformed by a
+  chain of HighwayFlow layers and mapped back.
+
+  Args:
+    dist: Base `tfd.Distribution` to build the surrogate for.
+    initial_prior_weight: Initial residual fraction of the HighwayFlow layers.
+    num_auxiliary_variables: Number of auxiliary variables joined with `dist`.
+    num_layers: Number of HighwayFlow layers.
+    use_global_auxiliary_variables: Python `bool`. If `True`, the auxiliary
+      variables are the given `global_auxiliary_variables`; otherwise they are
+      drawn from a standard normal.
+    global_auxiliary_variables: Sampled global auxiliary variables. Only read
+      if `use_global_auxiliary_variables` and `num_auxiliary_variables > 0`.
+    variables: `tfb.Chain` of flow layers to reuse, or `None` to create it.
+    seed: PRNG seed for random initialization.
+
+  Returns:
+    cascading_flows: The trainable surrogate. With auxiliary variables its
+      samples are `[variable, auxiliary_variables]` pairs.
+    variables: The `tfb.Chain` of flow layers backing the surrogate.
+  """
+  event_shape = dist.event_shape_tensor()
+  flat_event_size = tf.nest.map_structure(
+      ps.reduce_prod, tf.nest.flatten(event_shape))
+  ndims = ps.reduce_sum(flat_event_size)
+  # Maps events of `dist` to flat, unconstrained vectors.
+  constraining_and_flattening_bijector = chain.Chain([
+      reshape.Reshape(event_shape_out=flat_event_size,
+                      event_shape_in=event_shape),
+      invert.Invert(dist.experimental_default_event_space_bijector())])
+  processed_dist = transformed_distribution.TransformedDistribution(
+      distribution=dist, bijector=constraining_and_flattening_bijector)
+  unflatten_and_constrain = invert.Invert(constraining_and_flattening_bijector)
+
+  if variables is None:
+    variables = chain.Chain(bijectors=list(reversed(
+        _build_highway_flow_block(
+            num_layers,
+            # `ndims` is already a scalar, so it needs no further reduction.
+            width=ndims + num_auxiliary_variables,
+            residual_fraction_initial_value=initial_prior_weight,
+            gate_first_n=ndims,
+            seed=seed))))
+
+  if num_auxiliary_variables > 0:
+    if use_global_auxiliary_variables:
+      batch_shape = global_auxiliary_variables.shape[:-1] if len(
+          global_auxiliary_variables.shape) > 1 else []
+      prior_part = batch_broadcast.BatchBroadcast(
+          processed_dist, to_shape=batch_shape)
+      # The shared auxiliary variables were sampled upstream, so they enter
+      # as a point mass.
+      auxiliary_part = independent.Independent(
+          deterministic.Deterministic(global_auxiliary_variables),
+          reinterpreted_batch_ndims=1)
+    else:
+      prior_part = processed_dist
+      auxiliary_part = batch_broadcast.BatchBroadcast(
+          sample.Sample(normal.Normal(0., 1.), num_auxiliary_variables),
+          to_shape=processed_dist.batch_shape)
+    # Transform prior and auxiliary variables jointly, then split them apart
+    # again; only the prior part is mapped back through the event bijector.
+    cascading_flows = split.Split([-1, num_auxiliary_variables])(
+        transformed_distribution.TransformedDistribution(
+            distribution=blockwise.Blockwise([prior_part, auxiliary_part]),
+            bijector=variables))
+    cascading_flows = joint_map.JointMap(
+        [unflatten_and_constrain, identity.Identity()])(cascading_flows)
+  else:
+    cascading_flows = unflatten_and_constrain(
+        transformed_distribution.TransformedDistribution(
+            distribution=processed_dist, bijector=variables))
+
+  return cascading_flows, variables
+
+
+def _extract_variables_from_coroutine_model(model_fn, seed=None):
+  """Extracts variables from a generator that yields (dist, variables) pairs.
+
+  Args:
+    model_fn: Callable returning a generator that yields
+      `(distribution, variables)` pairs and is sent a sample of each yielded
+      distribution.
+    seed: PRNG seed used to sample the yielded distributions.
+
+  Returns:
+    flat_variables: Python `list` of the yielded `variables`, in yield order.
+      It is empty if the generator yields nothing.
+  """
+  gen = model_fn()
+  # Created before the first `next` so that an immediately exhausted generator
+  # yields an empty list instead of an unbound name.
+  flat_variables = []
+  try:
+    dist, dist_variables = next(gen)
+    while True:
+      flat_variables.append(dist_variables)
+      seed, local_seed = samplers.split_seed(seed, n=2)
+      sampled_value = (dist.distribution.sample(seed=local_seed)
+                       if isinstance(dist, Root)
+                       else dist.sample(seed=local_seed))
+      dist, dist_variables = gen.send(sampled_value)
+  except StopIteration:
+    pass
+  return flat_variables
+
+
+def _set_name(dist, name):
+  """Copies a distribution-like object, replacing its name."""
+  if hasattr(dist, 'copy'):
+    return dist.copy(name=name)
+  # Some distribution-like entities such as JointDistributionPinned don't
+  # inherit from tfd.Distribution and don't define `self.copy`. We'll try to set
+  # the name directly.
+  dist = copy.copy(dist)
+  dist._name = name  # pylint: disable=protected-access
+  return dist
+
+
+def _get_name(dist):
+  """Attempts to get a distribution's short name, excluding the name scope."""
+  return getattr(dist, 'parameters', {}).get('name', dist.name)
diff --git a/tensorflow_probability/python/experimental/vi/cascading_flows_test.py b/tensorflow_probability/python/experimental/vi/cascading_flows_test.py
new file mode 100644
index 0000000000..de10179486
--- /dev/null
+++ b/tensorflow_probability/python/experimental/vi/cascading_flows_test.py
@@ -0,0 +1,305 @@
+# Copyright 2021 The TensorFlow Probability Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Tests for structured surrogate posteriors."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow.compat.v1 as tf1
+import tensorflow.compat.v2 as tf
+
+import tensorflow_probability as tfp
+from tensorflow_probability.python.internal import prefer_static as ps
+from tensorflow_probability.python.internal import test_util
+from tensorflow.python.util import nest
+
+# Dependency imports
+
+tfb = tfp.bijectors
+tfd = tfp.distributions
+Root = tfd.JointDistributionCoroutine.Root
+
+@test_util.test_all_tf_execution_regimes
+class CascadingFlowTests(test_util.TestCase):
+  """Checks sample shapes of a surrogate built from a batched prior."""
+  def test_shapes(self):
+    """Samples the surrogate of a matrix/vector/scalar model twice."""
+    def test_shapes_model():
+      # Matrix-valued random variable with batch shape [3].
+      A = yield Root(
+          tfd.WishartTriL(df=2, scale_tril=tf.eye(2, batch_shape=[3]), name='A'))
+      # Vector-valued random variable with batch shape [3] (inherited from `A`)
+      x = yield tfd.MultivariateNormalTriL(loc=tf.zeros([2]),
+                                           scale_tril=tf.linalg.cholesky(A),
+                                           name='x')
+      # Scalar-valued random variable, with batch shape `[3]`.
+      y = yield tfd.Normal(loc=tf.reduce_sum(x, axis=-1), scale=tf.ones([3]))
+
+    prior = tfd.JointDistributionCoroutineAutoBatched(test_shapes_model, batch_ndims=1)
+    surrogate_posterior = tfp.experimental.vi.build_cascading_flow_surrogate_posterior(prior)  # Default: no auxiliary variables.
+
+    x1 = surrogate_posterior.sample()
+    x2 = nest.map_structure_up_to(
+        x1,
+        # Strip auxiliary variables.
+        lambda *rv_and_aux: rv_and_aux[0],
+        surrogate_posterior.sample())
+    # NOTE(review): `x1` and `x2` are both surrogate samples, so this checks
+    # that repeated samples agree in shape; the prior is not compared here.
+    get_shapes = lambda x: tf.nest.map_structure(lambda xp: xp.shape, x)
+    self.assertAllEqualNested(get_shapes(x1), get_shapes(x2))
+
+
+@test_util.test_all_tf_execution_regimes
+class _TrainableCFSurrogate(object):
+  """Mixin with shared surrogate tests; subclasses define `make_prior_dist`."""
+  def _expected_num_trainable_variables(self, prior_dist, num_layers):
+    """Infers the expected number of trainable variables for a non-nested JD."""
+    prior_dists = prior_dist._get_single_sample_distributions()  # pylint: disable=protected-access
+    expected_num_trainable_variables = 0
+
+    # For each distribution in the prior, we will have one highway flow with
+    # `num_layers` blocks, and each block has 4 trainable variables:
+    # `residual_fraction`, `lower_diagonal_weights_matrix`,
+    # `upper_diagonal_weights_matrix` and `bias`.
+    for original_dist in prior_dists:
+      expected_num_trainable_variables += (4 * num_layers)
+    return expected_num_trainable_variables
+
+  def test_dims_and_gradients(self):
+    prior_dist = self.make_prior_dist()
+    num_layers = 3
+    surrogate_posterior = tfp.experimental.vi.build_cascading_flow_surrogate_posterior(
+        prior=prior_dist, num_layers=num_layers)
+
+    # Test that the correct number of trainable variables are being tracked.
+    self.assertLen(surrogate_posterior.trainable_variables,
+                   self._expected_num_trainable_variables(prior_dist,
+                                                          num_layers))
+
+    # Test that the sample shape is correct.
+    three_posterior_samples = surrogate_posterior.sample(
+        3, seed=(0,0))
+    three_prior_samples = prior_dist.sample(
+        3, seed=(0,0))
+    self.assertAllEqualNested(
+        [s.shape for s in tf.nest.flatten(three_prior_samples)],
+        [s.shape for s in tf.nest.flatten(three_posterior_samples)])
+
+    # Test that gradients are available wrt the variational parameters.
+    with tf.GradientTape() as tape:
+      posterior_sample = surrogate_posterior.sample(
+          seed=(0,0))
+      posterior_logprob = surrogate_posterior.log_prob(posterior_sample)
+    grad = tape.gradient(posterior_logprob,
+                         surrogate_posterior.trainable_variables)
+    self.assertTrue(all(g is not None for g in grad))
+
+  def test_initialization_is_deterministic_following_seed(self):
+    prior_dist = self.make_prior_dist()
+
+    surrogate_posterior = tfp.experimental.vi.build_cascading_flow_surrogate_posterior(
+        prior=prior_dist,
+        seed=(0,0))
+    self.evaluate(
+        [v.initializer for v in surrogate_posterior.trainable_variables])
+    posterior_sample = surrogate_posterior.sample(
+        seed=(0,0))
+
+    surrogate_posterior2 = tfp.experimental.vi.build_cascading_flow_surrogate_posterior(
+        prior=prior_dist,
+        seed=(0,0))
+    self.evaluate(
+        [v.initializer for v in surrogate_posterior2.trainable_variables])
+    posterior_sample2 = surrogate_posterior2.sample(
+        seed=(0,0))
+
+    self.assertAllEqualNested(posterior_sample, posterior_sample2)
+
+  def test_surrogate_and_prior_have_same_domain(self):
+    prior_dist = self.make_prior_dist()
+    surrogate_posterior = tfp.experimental.vi.build_cascading_flow_surrogate_posterior(
+        prior=prior_dist,
+        seed=(0,0))
+    self.assertAllFinite(prior_dist.log_prob(
+        surrogate_posterior.sample(10, seed=(0,0))))
+
+@test_util.test_all_tf_execution_regimes
+class CFSurrogatePosteriorTestBrownianMotion(test_util.TestCase,
+                                             _TrainableCFSurrogate):
+  """Runs the shared tests on a five-step Brownian motion prior."""
+  def make_prior_dist(self):
+    """Returns a chain of five `Normal`s, each centered on the previous."""
+    def _prior_model_fn():
+      innovation_noise = 0.1
+      prior_loc = 0.
+      new = yield tfd.Normal(loc=prior_loc, scale=innovation_noise)
+      for _ in range(4):
+        new = yield tfd.Normal(loc=new, scale=innovation_noise)
+
+    return tfd.JointDistributionCoroutineAutoBatched(_prior_model_fn)
+
+  def make_likelihood_model(self, x, observation_noise):
+    """Returns independent `Normal` observations centered on `x[0..4]`."""
+    def _likelihood_model():
+      for i in range(5):
+        yield tfd.Normal(loc=x[i], scale=observation_noise)
+
+    return tfd.JointDistributionCoroutineAutoBatched(_likelihood_model)
+
+  def get_observations(self, prior_dist):
+    observation_noise = 0.15
+    ground_truth = prior_dist.sample()
+    likelihood = self.make_likelihood_model(
+        x=ground_truth, observation_noise=observation_noise)
+
+    return likelihood.sample(1)
+
+  def get_target_log_prob(self, observations, prior_dist):
+    """Returns the unnormalized log posterior given `observations`."""
+    def target_log_prob(*x):
+      observation_noise = 0.15
+      likelihood_dist = self.make_likelihood_model(
+          x=x, observation_noise=observation_noise)
+      return likelihood_dist.log_prob(observations) + prior_dist.log_prob(
+          x)
+
+    return target_log_prob
+
+  def test_fitting_surrogate_posterior(self):
+    """Smoke-tests `tfp.vi.fit_surrogate_posterior` with a CF surrogate."""
+    prior_dist = self.make_prior_dist()
+    observations = self.get_observations(prior_dist)
+    surrogate_posterior = tfp.experimental.vi.build_cascading_flow_surrogate_posterior(
+        prior=prior_dist)
+    target_log_prob = self.get_target_log_prob(observations, prior_dist)
+
+    # Test vi fit surrogate posterior works.
+    losses = tfp.vi.fit_surrogate_posterior(
+        target_log_prob,
+        surrogate_posterior,
+        num_steps=5,  # Don't optimize to completion.
+        optimizer=tf.optimizers.Adam(1e-3),
+        sample_size=10)
+
+    # Compute posterior statistics.
+    with tf.control_dependencies([losses]):
+      posterior_samples = surrogate_posterior.sample(100)
+      posterior_mean = tf.nest.map_structure(tf.reduce_mean,
+                                             posterior_samples)
+      posterior_stddev = tf.nest.map_structure(tf.math.reduce_std,
+                                               posterior_samples)
+
+    self.evaluate(tf1.global_variables_initializer())
+    _ = self.evaluate(losses)
+    _ = self.evaluate(posterior_mean)
+    _ = self.evaluate(posterior_stddev)
+
+
+@test_util.test_all_tf_execution_regimes
+class CFSurrogatePosteriorTestEightSchools(test_util.TestCase,
+                                           _TrainableCFSurrogate):
+  def make_prior_dist(self):
+    treatment_effects = tf.constant([28, 8, -3, 7, -1, 1, 18, 12],
+                                    dtype=tf.float32)
+    num_schools = ps.shape(treatment_effects)[-1]
+
+    return tfd.JointDistributionNamed({
+        'avg_effect':
+            tfd.Normal(loc=0., scale=10., name='avg_effect'),
+        'log_stddev':
+            tfd.Normal(loc=5., scale=1., name='log_stddev'),
+        'school_effects':
+            lambda log_stddev, avg_effect: (
+                # pylint: disable=g-long-lambda
+                tfd.Independent(
+                    tfd.Normal(
+                        loc=avg_effect[..., None] * tf.ones(num_schools),
+                        scale=tf.exp(log_stddev[..., None]) * tf.ones(
+                            num_schools),
+                        name='school_effects'),
+                    reinterpreted_batch_ndims=1))
+    })
+
+
+@test_util.test_all_tf_execution_regimes
+class CFSurrogatePosteriorTestEightSchoolsSample(test_util.TestCase,
+                                                 _TrainableCFSurrogate):
+  """Runs the shared tests on an eight-schools prior using `tfd.Sample`."""
+  def make_prior_dist(self):
+    return tfd.JointDistributionNamed({
+        'avg_effect':
+            tfd.Normal(loc=0., scale=10., name='avg_effect'),
+        'log_stddev':
+            tfd.Normal(loc=5., scale=1., name='log_stddev'),
+        'school_effects':
+            lambda log_stddev, avg_effect: (
+                # pylint: disable=g-long-lambda
+                tfd.Sample(
+                    tfd.Normal(
+                        loc=avg_effect[..., None],
+                        scale=tf.exp(log_stddev[..., None]),
+                        name='school_effects'),
+                    sample_shape=[8]))
+    })
+
+
+@test_util.test_all_tf_execution_regimes
+class CFSurrogatePosteriorTestHalfNormal(test_util.TestCase,
+                                         _TrainableCFSurrogate):
+  """Runs the shared tests on a single `HalfNormal` variable."""
+  def make_prior_dist(self):
+    def _prior_model_fn():
+      innovation_noise = 1.
+      yield tfd.HalfNormal(
+          scale=innovation_noise, validate_args=True,
+          allow_nan_stats=False)
+
+    return tfd.JointDistributionCoroutineAutoBatched(_prior_model_fn)
+
+@test_util.test_all_tf_execution_regimes
+class CFSurrogatePosteriorTestNesting(test_util.TestCase,
+                                      _TrainableCFSurrogate):
+  """Runs the shared tests on a prior containing nested joint models."""
+  def make_prior_dist(self):
+    def nested_model():
+      a = yield tfd.Sample(
+          tfd.Sample(
+              tfd.Normal(0., 1.),
+              sample_shape=4),
+          sample_shape=[2],
+          name='a')
+      b = yield tfb.Sigmoid()(
+          tfb.Square()(
+              tfd.Exponential(rate=tf.exp(a))),
+          name='b')
+      # pylint: disable=g-long-lambda
+      yield tfd.JointDistributionSequential(
+          [tfd.Laplace(loc=a, scale=b),
+           lambda c1: tfd.Independent(
+               tfd.Beta(concentration1=1.,
+                        concentration0=tf.nn.softplus(c1)),
+               reinterpreted_batch_ndims=1),
+           lambda c1, c2: tfd.JointDistributionNamed({
+               'x': tfd.Gamma(concentration=tf.nn.softplus(c1), rate=c2)})
+          ], name='c')
+      # pylint: enable=g-long-lambda
+
+    return tfd.JointDistributionCoroutineAutoBatched(nested_model)
+
+
+if __name__ == '__main__':
+  tf.test.main()