diff --git a/edward/inferences/klpq.py b/edward/inferences/klpq.py
index a362d73e9..0dcaa71b4 100644
--- a/edward/inferences/klpq.py
+++ b/edward/inferences/klpq.py
@@ -9,6 +9,11 @@
 from edward.models import RandomVariable
 from edward.util import copy, get_descendants
 
+try:
+  from edward.models import Normal
+except Exception as e:
+  raise ImportError("{0}. Your TensorFlow version is not supported.".format(e))
+
 
 class KLpq(VariationalInference):
   """Variational inference with the KL divergence
@@ -41,8 +46,38 @@ class KLpq(VariationalInference):
   where $z^{(s)} \sim q(z; \lambda)$ and
   $\\beta^{(s)} \sim q(\\beta)$.
   """
-  def __init__(self, *args, **kwargs):
-    super(KLpq, self).__init__(*args, **kwargs)
+  def __init__(self, latent_vars=None, data=None):
+    """Create an inference algorithm.
+
+    Args:
+      latent_vars: list of RandomVariable or
+                   dict of RandomVariable to RandomVariable.
+        Collection of random variables to perform inference on. If
+        list, each random variable will be implicitly optimized using a
+        `Normal` random variable that is defined internally with a
+        free parameter per location and scale and is initialized using
+        standard normal draws. The random variables to approximate
+        must be continuous.
+    """
+    if isinstance(latent_vars, list):
+      with tf.variable_scope(None, default_name="posterior"):
+        latent_vars_dict = {}
+        continuous = \
+            ('01', 'nonnegative', 'simplex', 'real', 'multivariate_real')
+        for z in latent_vars:
+          if not hasattr(z, 'support') or z.support not in continuous:
+            raise AttributeError(
+                "Random variable {} is not continuous or a random "
+                "variable with supported continuous support.".format(z))
+          batch_event_shape = z.batch_shape.concatenate(z.event_shape)
+          loc = tf.Variable(tf.random_normal(batch_event_shape))
+          scale = tf.nn.softplus(
+              tf.Variable(tf.random_normal(batch_event_shape)))
+          latent_vars_dict[z] = Normal(loc=loc, scale=scale)
+        latent_vars = latent_vars_dict
+        del latent_vars_dict
+
+    super(KLpq, self).__init__(latent_vars, data)
 
   def initialize(self, n_samples=1, *args, **kwargs):
     """Initialize inference algorithm. It initializes hyperparameters
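For context, the new constructor accepts either the original dict form or a plain list. Below is a minimal sketch of the two now-equivalent call forms, using the same normal-normal toy model as the updated tests further down; the model and data are illustrative and not part of the patch.

import edward as ed
import numpy as np
import tensorflow as tf
from edward.models import Normal

x_data = np.array([0.0] * 50, dtype=np.float32)

mu = Normal(loc=0.0, scale=1.0)                 # prior on the latent mean
x = Normal(loc=mu, scale=1.0, sample_shape=50)  # likelihood for 50 points

# Explicit form: build the approximating Normal by hand.
qmu = Normal(loc=tf.Variable(tf.random_normal([])),
             scale=tf.nn.softplus(tf.Variable(tf.random_normal([]))))
inference = ed.KLpq({mu: qmu}, data={x: x_data})

# Default form: pass a list and let __init__ build the Normal posterior.
inference = ed.KLpq([mu], data={x: x_data})
qmu = inference.latent_vars[mu]  # the internally constructed posterior

The list form is exactly what the `default=True` branch of the updated tests exercises.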
+ """ + if isinstance(latent_vars, list): + with tf.variable_scope(None, default_name="posterior"): + latent_vars_dict = {} + continuous = \ + ('01', 'nonnegative', 'simplex', 'real', 'multivariate_real') + for z in latent_vars: + if not hasattr(z, 'support') or z.support not in continuous: + raise AttributeError( + "Random variable {} is not continuous or a random " + "variable with supported continuous support.".format(z)) + batch_event_shape = z.batch_shape.concatenate(z.event_shape) + loc = tf.Variable(tf.random_normal(batch_event_shape)) + scale = tf.nn.softplus( + tf.Variable(tf.random_normal(batch_event_shape))) + latent_vars_dict[z] = Normal(loc=loc, scale=scale) + latent_vars = latent_vars_dict + del latent_vars_dict + + super(KLqp, self).__init__(latent_vars, data) def initialize(self, n_samples=1, kl_scaling=None, *args, **kwargs): """Initialize inference algorithm. It initializes hyperparameters @@ -135,8 +165,38 @@ class ReparameterizationKLqp(VariationalInference): This class minimizes the objective using the reparameterization gradient. """ - def __init__(self, *args, **kwargs): - super(ReparameterizationKLqp, self).__init__(*args, **kwargs) + def __init__(self, latent_vars=None, data=None): + """Create an inference algorithm. + + Args: + latent_vars: list of RandomVariable or + dict of RandomVariable to RandomVariable. + Collection of random variables to perform inference on. If + list, each random variable will be implictly optimized using a + `Normal` random variable that is defined internally with a + free parameter per location and scale and is initialized using + standard normal draws. The random variables to approximate + must be continuous. + """ + if isinstance(latent_vars, list): + with tf.variable_scope(None, default_name="posterior"): + latent_vars_dict = {} + continuous = \ + ('01', 'nonnegative', 'simplex', 'real', 'multivariate_real') + for z in latent_vars: + if not hasattr(z, 'support') or z.support not in continuous: + raise AttributeError( + "Random variable {} is not continuous or a random " + "variable with supported continuous support.".format(z)) + batch_event_shape = z.batch_shape.concatenate(z.event_shape) + loc = tf.Variable(tf.random_normal(batch_event_shape)) + scale = tf.nn.softplus( + tf.Variable(tf.random_normal(batch_event_shape))) + latent_vars_dict[z] = Normal(loc=loc, scale=scale) + latent_vars = latent_vars_dict + del latent_vars_dict + + super(ReparameterizationKLqp, self).__init__(latent_vars, data) def initialize(self, n_samples=1, *args, **kwargs): """Initialize inference algorithm. It initializes hyperparameters @@ -162,8 +222,38 @@ class ReparameterizationKLKLqp(VariationalInference): This class minimizes the objective using the reparameterization gradient and an analytic KL term. """ - def __init__(self, *args, **kwargs): - super(ReparameterizationKLKLqp, self).__init__(*args, **kwargs) + def __init__(self, latent_vars=None, data=None): + """Create an inference algorithm. + + Args: + latent_vars: list of RandomVariable or + dict of RandomVariable to RandomVariable. + Collection of random variables to perform inference on. If + list, each random variable will be implictly optimized using a + `Normal` random variable that is defined internally with a + free parameter per location and scale and is initialized using + standard normal draws. The random variables to approximate + must be continuous. 
+ """ + if isinstance(latent_vars, list): + with tf.variable_scope(None, default_name="posterior"): + latent_vars_dict = {} + continuous = \ + ('01', 'nonnegative', 'simplex', 'real', 'multivariate_real') + for z in latent_vars: + if not hasattr(z, 'support') or z.support not in continuous: + raise AttributeError( + "Random variable {} is not continuous or a random " + "variable with supported continuous support.".format(z)) + batch_event_shape = z.batch_shape.concatenate(z.event_shape) + loc = tf.Variable(tf.random_normal(batch_event_shape)) + scale = tf.nn.softplus( + tf.Variable(tf.random_normal(batch_event_shape))) + latent_vars_dict[z] = Normal(loc=loc, scale=scale) + latent_vars = latent_vars_dict + del latent_vars_dict + + super(ReparameterizationKLKLqp, self).__init__(latent_vars, data) def initialize(self, n_samples=1, kl_scaling=None, *args, **kwargs): """Initialize inference algorithm. It initializes hyperparameters @@ -203,8 +293,38 @@ class ReparameterizationEntropyKLqp(VariationalInference): This class minimizes the objective using the reparameterization gradient and an analytic entropy term. """ - def __init__(self, *args, **kwargs): - super(ReparameterizationEntropyKLqp, self).__init__(*args, **kwargs) + def __init__(self, latent_vars=None, data=None): + """Create an inference algorithm. + + Args: + latent_vars: list of RandomVariable or + dict of RandomVariable to RandomVariable. + Collection of random variables to perform inference on. If + list, each random variable will be implictly optimized using a + `Normal` random variable that is defined internally with a + free parameter per location and scale and is initialized using + standard normal draws. The random variables to approximate + must be continuous. + """ + if isinstance(latent_vars, list): + with tf.variable_scope(None, default_name="posterior"): + latent_vars_dict = {} + continuous = \ + ('01', 'nonnegative', 'simplex', 'real', 'multivariate_real') + for z in latent_vars: + if not hasattr(z, 'support') or z.support not in continuous: + raise AttributeError( + "Random variable {} is not continuous or a random " + "variable with supported continuous support.".format(z)) + batch_event_shape = z.batch_shape.concatenate(z.event_shape) + loc = tf.Variable(tf.random_normal(batch_event_shape)) + scale = tf.nn.softplus( + tf.Variable(tf.random_normal(batch_event_shape))) + latent_vars_dict[z] = Normal(loc=loc, scale=scale) + latent_vars = latent_vars_dict + del latent_vars_dict + + super(ReparameterizationEntropyKLqp, self).__init__(latent_vars, data) def initialize(self, n_samples=1, *args, **kwargs): """Initialize inference algorithm. It initializes hyperparameters @@ -231,8 +351,38 @@ class ScoreKLqp(VariationalInference): This class minimizes the objective using the score function gradient. """ - def __init__(self, *args, **kwargs): - super(ScoreKLqp, self).__init__(*args, **kwargs) + def __init__(self, latent_vars=None, data=None): + """Create an inference algorithm. + + Args: + latent_vars: list of RandomVariable or + dict of RandomVariable to RandomVariable. + Collection of random variables to perform inference on. If + list, each random variable will be implictly optimized using a + `Normal` random variable that is defined internally with a + free parameter per location and scale and is initialized using + standard normal draws. The random variables to approximate + must be continuous. 
+ """ + if isinstance(latent_vars, list): + with tf.variable_scope(None, default_name="posterior"): + latent_vars_dict = {} + continuous = \ + ('01', 'nonnegative', 'simplex', 'real', 'multivariate_real') + for z in latent_vars: + if not hasattr(z, 'support') or z.support not in continuous: + raise AttributeError( + "Random variable {} is not continuous or a random " + "variable with supported continuous support.".format(z)) + batch_event_shape = z.batch_shape.concatenate(z.event_shape) + loc = tf.Variable(tf.random_normal(batch_event_shape)) + scale = tf.nn.softplus( + tf.Variable(tf.random_normal(batch_event_shape))) + latent_vars_dict[z] = Normal(loc=loc, scale=scale) + latent_vars = latent_vars_dict + del latent_vars_dict + + super(ScoreKLqp, self).__init__(latent_vars, data) def initialize(self, n_samples=1, *args, **kwargs): """Initialize inference algorithm. It initializes hyperparameters @@ -258,8 +408,38 @@ class ScoreKLKLqp(VariationalInference): This class minimizes the objective using the score function gradient and an analytic KL term. """ - def __init__(self, *args, **kwargs): - super(ScoreKLKLqp, self).__init__(*args, **kwargs) + def __init__(self, latent_vars=None, data=None): + """Create an inference algorithm. + + Args: + latent_vars: list of RandomVariable or + dict of RandomVariable to RandomVariable. + Collection of random variables to perform inference on. If + list, each random variable will be implictly optimized using a + `Normal` random variable that is defined internally with a + free parameter per location and scale and is initialized using + standard normal draws. The random variables to approximate + must be continuous. + """ + if isinstance(latent_vars, list): + with tf.variable_scope(None, default_name="posterior"): + latent_vars_dict = {} + continuous = \ + ('01', 'nonnegative', 'simplex', 'real', 'multivariate_real') + for z in latent_vars: + if not hasattr(z, 'support') or z.support not in continuous: + raise AttributeError( + "Random variable {} is not continuous or a random " + "variable with supported continuous support.".format(z)) + batch_event_shape = z.batch_shape.concatenate(z.event_shape) + loc = tf.Variable(tf.random_normal(batch_event_shape)) + scale = tf.nn.softplus( + tf.Variable(tf.random_normal(batch_event_shape))) + latent_vars_dict[z] = Normal(loc=loc, scale=scale) + latent_vars = latent_vars_dict + del latent_vars_dict + + super(ScoreKLKLqp, self).__init__(latent_vars, data) def initialize(self, n_samples=1, kl_scaling=None, *args, **kwargs): """Initialize inference algorithm. It initializes hyperparameters @@ -299,8 +479,38 @@ class ScoreEntropyKLqp(VariationalInference): This class minimizes the objective using the score function gradient and an analytic entropy term. """ - def __init__(self, *args, **kwargs): - super(ScoreEntropyKLqp, self).__init__(*args, **kwargs) + def __init__(self, latent_vars=None, data=None): + """Create an inference algorithm. + + Args: + latent_vars: list of RandomVariable or + dict of RandomVariable to RandomVariable. + Collection of random variables to perform inference on. If + list, each random variable will be implictly optimized using a + `Normal` random variable that is defined internally with a + free parameter per location and scale and is initialized using + standard normal draws. The random variables to approximate + must be continuous. 
+ """ + if isinstance(latent_vars, list): + with tf.variable_scope(None, default_name="posterior"): + latent_vars_dict = {} + continuous = \ + ('01', 'nonnegative', 'simplex', 'real', 'multivariate_real') + for z in latent_vars: + if not hasattr(z, 'support') or z.support not in continuous: + raise AttributeError( + "Random variable {} is not continuous or a random " + "variable with supported continuous support.".format(z)) + batch_event_shape = z.batch_shape.concatenate(z.event_shape) + loc = tf.Variable(tf.random_normal(batch_event_shape)) + scale = tf.nn.softplus( + tf.Variable(tf.random_normal(batch_event_shape))) + latent_vars_dict[z] = Normal(loc=loc, scale=scale) + latent_vars = latent_vars_dict + del latent_vars_dict + + super(ScoreEntropyKLqp, self).__init__(latent_vars, data) def initialize(self, n_samples=1, *args, **kwargs): """Initialize inference algorithm. It initializes hyperparameters @@ -333,8 +543,38 @@ class ScoreRBKLqp(VariationalInference): Rao-Blackwellize within a node such as when a node represents multiple random variables via non-scalar batch shape. """ - def __init__(self, *args, **kwargs): - super(ScoreRBKLqp, self).__init__(*args, **kwargs) + def __init__(self, latent_vars=None, data=None): + """Create an inference algorithm. + + Args: + latent_vars: list of RandomVariable or + dict of RandomVariable to RandomVariable. + Collection of random variables to perform inference on. If + list, each random variable will be implictly optimized using a + `Normal` random variable that is defined internally with a + free parameter per location and scale and is initialized using + standard normal draws. The random variables to approximate + must be continuous. + """ + if isinstance(latent_vars, list): + with tf.variable_scope(None, default_name="posterior"): + latent_vars_dict = {} + continuous = \ + ('01', 'nonnegative', 'simplex', 'real', 'multivariate_real') + for z in latent_vars: + if not hasattr(z, 'support') or z.support not in continuous: + raise AttributeError( + "Random variable {} is not continuous or a random " + "variable with supported continuous support.".format(z)) + batch_event_shape = z.batch_shape.concatenate(z.event_shape) + loc = tf.Variable(tf.random_normal(batch_event_shape)) + scale = tf.nn.softplus( + tf.Variable(tf.random_normal(batch_event_shape))) + latent_vars_dict[z] = Normal(loc=loc, scale=scale) + latent_vars = latent_vars_dict + del latent_vars_dict + + super(ScoreRBKLqp, self).__init__(latent_vars, data) def initialize(self, n_samples=1, *args, **kwargs): """Initialize inference algorithm. 
diff --git a/tests/inferences/test_klpq.py b/tests/inferences/test_klpq.py
index 994594a46..8447f1f8c 100644
--- a/tests/inferences/test_klpq.py
+++ b/tests/inferences/test_klpq.py
@@ -11,7 +11,7 @@
 
 class test_klpq_class(tf.test.TestCase):
 
-  def _test_normal_normal(self, Inference, *args, **kwargs):
+  def _test_normal_normal(self, Inference, default, *args, **kwargs):
     with self.test_session() as sess:
       x_data = np.array([0.0] * 50, dtype=np.float32)
 
@@ -22,8 +22,16 @@ def _test_normal_normal(self, Inference, *args, **kwargs):
       qmu_scale = tf.nn.softplus(tf.Variable(tf.random_normal([])))
       qmu = Normal(loc=qmu_loc, scale=qmu_scale)
 
-      # analytic solution: N(loc=0.0, scale=\sqrt{1/51}=0.140)
-      inference = Inference({mu: qmu}, data={x: x_data})
+      if not default:
+        qmu_loc = tf.Variable(tf.random_normal([]))
+        qmu_scale = tf.nn.softplus(tf.Variable(tf.random_normal([])))
+        qmu = Normal(loc=qmu_loc, scale=qmu_scale)
+
+        # analytic solution: N(loc=0.0, scale=\sqrt{1/51}=0.140)
+        inference = Inference({mu: qmu}, data={x: x_data})
+      else:
+        inference = Inference([mu], data={x: x_data})
+        qmu = inference.latent_vars[mu]
       inference.run(*args, **kwargs)
 
       self.assertAllClose(qmu.mean().eval(), 0, rtol=1e-1, atol=1e-1)
@@ -52,7 +60,8 @@ def _test_model_parameter(self, Inference, *args, **kwargs):
       self.assertAllClose(p.eval(), 0.2, rtol=5e-2, atol=5e-2)
 
   def test_klpq(self):
-    self._test_normal_normal(ed.KLpq, n_samples=25, n_iter=100)
+    self._test_normal_normal(ed.KLpq, default=False, n_samples=25, n_iter=100)
+    self._test_normal_normal(ed.KLpq, default=True, n_samples=25, n_iter=100)
     self._test_model_parameter(ed.KLpq, n_iter=50)
 
 if __name__ == '__main__':
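One detail worth noting in the constructors above: wrapping parameter creation in `tf.variable_scope(None, default_name="posterior")` uniquifies the scope per call, so building several inference objects in one graph does not collide on variable names. An illustrative check of that TF 1.x behavior; the printed names in the comments are what uniquification should produce:

import tensorflow as tf

# Each scope is uniquified: "posterior", then "posterior_1", and so on.
with tf.variable_scope(None, default_name="posterior"):
  v1 = tf.Variable(tf.random_normal([]), name="loc")
with tf.variable_scope(None, default_name="posterior"):
  v2 = tf.Variable(tf.random_normal([]), name="loc")

print(v1.name)  # posterior/loc:0
print(v2.name)  # posterior_1/loc:0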
diff --git a/tests/inferences/test_klqp.py b/tests/inferences/test_klqp.py
index a73e90d0c..dce58cf7f 100644
--- a/tests/inferences/test_klqp.py
+++ b/tests/inferences/test_klqp.py
@@ -11,24 +11,28 @@
 
 class test_klqp_class(tf.test.TestCase):
 
-  def _test_normal_normal(self, Inference, *args, **kwargs):
+  def _test_normal_normal(self, Inference, default, *args, **kwargs):
     with self.test_session() as sess:
       x_data = np.array([0.0] * 50, dtype=np.float32)
 
       mu = Normal(loc=0.0, scale=1.0)
       x = Normal(loc=mu, scale=1.0, sample_shape=50)
 
-      qmu_loc = tf.Variable(tf.random_normal([]))
-      qmu_scale = tf.nn.softplus(tf.Variable(tf.random_normal([])))
-      qmu = Normal(loc=qmu_loc, scale=qmu_scale)
+      if not default:
+        qmu_loc = tf.Variable(tf.random_normal([]))
+        qmu_scale = tf.nn.softplus(tf.Variable(tf.random_normal([])))
+        qmu = Normal(loc=qmu_loc, scale=qmu_scale)
 
-      # analytic solution: N(loc=0.0, scale=\sqrt{1/51}=0.140)
-      inference = Inference({mu: qmu}, data={x: x_data})
+        # analytic solution: N(loc=0.0, scale=\sqrt{1/51}=0.140)
+        inference = Inference({mu: qmu}, data={x: x_data})
+      else:
+        inference = Inference([mu], data={x: x_data})
+        qmu = inference.latent_vars[mu]
       inference.run(*args, **kwargs)
 
-      self.assertAllClose(qmu.mean().eval(), 0, rtol=1e-1, atol=1e-1)
+      self.assertAllClose(qmu.mean().eval(), 0, rtol=0.15, atol=0.15)
       self.assertAllClose(qmu.stddev().eval(), np.sqrt(1 / 51),
-                          rtol=1e-1, atol=1e-1)
+                          rtol=0.15, atol=0.15)
 
       variables = tf.get_collection(
           tf.GraphKeys.GLOBAL_VARIABLES, scope='optimizer')
@@ -52,35 +56,57 @@ def _test_model_parameter(self, Inference, *args, **kwargs):
       self.assertAllClose(p.eval(), 0.2, rtol=5e-2, atol=5e-2)
 
   def test_klqp(self):
-    self._test_normal_normal(ed.KLqp, n_iter=5000)
+    self._test_normal_normal(ed.KLqp, default=False, n_iter=5000)
+    self._test_normal_normal(ed.KLqp, default=True, n_iter=5000)
     self._test_model_parameter(ed.KLqp, n_iter=50)
 
   def test_reparameterization_entropy_klqp(self):
-    self._test_normal_normal(ed.ReparameterizationEntropyKLqp, n_iter=5000)
+    self._test_normal_normal(
+        ed.ReparameterizationEntropyKLqp, default=False, n_iter=5000)
+    self._test_normal_normal(
+        ed.ReparameterizationEntropyKLqp, default=True, n_iter=5000)
     self._test_model_parameter(ed.ReparameterizationEntropyKLqp, n_iter=50)
 
   def test_reparameterization_klqp(self):
-    self._test_normal_normal(ed.ReparameterizationKLqp, n_iter=5000)
+    self._test_normal_normal(
+        ed.ReparameterizationKLqp, default=False, n_iter=5000)
+    self._test_normal_normal(
+        ed.ReparameterizationKLqp, default=True, n_iter=5000)
     self._test_model_parameter(ed.ReparameterizationKLqp, n_iter=50)
 
   def test_reparameterization_kl_klqp(self):
-    self._test_normal_normal(ed.ReparameterizationKLKLqp, n_iter=5000)
+    self._test_normal_normal(
+        ed.ReparameterizationKLKLqp, default=False, n_iter=5000)
+    self._test_normal_normal(
+        ed.ReparameterizationKLKLqp, default=True, n_iter=5000)
     self._test_model_parameter(ed.ReparameterizationKLKLqp, n_iter=50)
 
   def test_score_entropy_klqp(self):
-    self._test_normal_normal(ed.ScoreEntropyKLqp, n_samples=5, n_iter=5000)
+    self._test_normal_normal(
+        ed.ScoreEntropyKLqp, default=False, n_samples=5, n_iter=5000)
+    self._test_normal_normal(
+        ed.ScoreEntropyKLqp, default=True, n_samples=5, n_iter=5000)
     self._test_model_parameter(ed.ScoreEntropyKLqp, n_iter=50)
 
   def test_score_klqp(self):
-    self._test_normal_normal(ed.ScoreKLqp, n_samples=5, n_iter=5000)
+    self._test_normal_normal(
+        ed.ScoreKLqp, default=False, n_samples=5, n_iter=5000)
+    self._test_normal_normal(
+        ed.ScoreKLqp, default=True, n_samples=5, n_iter=5000)
     self._test_model_parameter(ed.ScoreKLqp, n_iter=50)
 
   def test_score_kl_klqp(self):
-    self._test_normal_normal(ed.ScoreKLKLqp, n_samples=5, n_iter=5000)
+    self._test_normal_normal(
+        ed.ScoreKLKLqp, default=False, n_samples=5, n_iter=5000)
+    self._test_normal_normal(
+        ed.ScoreKLKLqp, default=True, n_samples=5, n_iter=5000)
     self._test_model_parameter(ed.ScoreKLKLqp, n_iter=50)
 
   def test_score_rb_klqp(self):
-    self._test_normal_normal(ed.ScoreRBKLqp, n_samples=5, n_iter=5000)
+    self._test_normal_normal(
+        ed.ScoreRBKLqp, default=False, n_samples=5, n_iter=5000)
+    self._test_normal_normal(
+        ed.ScoreRBKLqp, default=True, n_samples=5, n_iter=5000)
    self._test_model_parameter(ed.ScoreRBKLqp, n_iter=50)
 
 if __name__ == '__main__':
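End to end, the default-posterior path mirrors `_test_normal_normal` with `default=True`. A sketch, with the expected values taken from the tests' analytic solution N(0, sqrt(1/51) ≈ 0.140); iteration count and model are illustrative:

import edward as ed
import numpy as np
from edward.models import Normal

x_data = np.array([0.0] * 50, dtype=np.float32)
mu = Normal(loc=0.0, scale=1.0)
x = Normal(loc=mu, scale=1.0, sample_shape=50)

# List form: the Normal posterior for mu is built inside __init__.
inference = ed.KLqp([mu], data={x: x_data})
sess = ed.get_session()
inference.run(n_iter=5000)

qmu = inference.latent_vars[mu]
print(sess.run([qmu.mean(), qmu.stddev()]))  # roughly [0.0, 0.140]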