Skip to content

Commit

Permalink
initial attempt at refactoring ed.MetropolisHastings' impl
Browse files Browse the repository at this point in the history
  • Loading branch information
dustinvtran committed Oct 4, 2017
1 parent c2b61a2 commit 12509cd
Showing 1 changed file with 87 additions and 43 deletions.
130 changes: 87 additions & 43 deletions edward/inferences/metropolis_hastings.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

try:
from edward.models import Uniform
from tensorflow.contrib.bayesflow.metropolis_hastings import evolve
except Exception as e:
raise ImportError("{0}. Your TensorFlow version is not supported.".format(e))

Expand Down Expand Up @@ -64,6 +65,19 @@ def __init__(self, latent_vars, proposal_vars, data=None):

def initialize(self, *args, **kwargs):
kwargs['auto_transform'] = False

# TODO In general, each latent variable has arbitrary shape and
# dtype. We cannot simply batch them into a single tf.Tensor with
# an extra dimension.
initial_sample = tf.stack([tf.gather(qz.params, 0)
for qz in six.itervalues(self.latent_vars)])
self._state = tf.Variable(initial_sample, trainable=False, name="state")
self._state_log_density = tf.Variable(
self._log_joint(initial_sample),
trainable=False, name="state_log_density")
self._log_accept_ratio = tf.Variable(
tf.zeros_like(self._state_log_density.initialized_value()),
trainable=False, name="log_accept_ratio")
return super(MetropolisHastings, self).initialize(*args, **kwargs)

def build_update(self):
Expand All @@ -80,9 +94,75 @@ def build_update(self):
The updates assume each Empirical random variable is directly
parameterized by `tf.Variable`s.
"""
old_sample = {z: tf.gather(qz.params, tf.maximum(self.t - 1, 0))
for z, qz in six.iteritems(self.latent_vars)}
old_sample = OrderedDict(old_sample)
old_state = self._state
forward_step = evolve(self._state,
self._state_log_density,
self._log_accept_ratio,
self._log_density,
self._proposal_fn,
n_steps=1)
assign_ops = [forward_step]

with tf.control_dependencies([forward_step]):
# Update Empirical random variables.
for state, qz in zip(tf.unstack(self._state),
six.itervalues(self.latent_vars)):
variable = qz.get_variables()[0]
assign_ops.append(tf.scatter_update(variable, self.t, state))

# Increment n_accept (if accepted).
# TODO old_state might always be same. It would be great if we
# could more naturally get the acceptance rate from ``evolve``.
is_proposal_accepted = tf.where(
tf.reduce_any(tf.not_equal(old_state, self._state)), 1, 0)
assign_ops.append(self.n_accept.assign_add(is_proposal_accepted))

return tf.group(*assign_ops)

def _log_joint(self, state):
"""Utility function to calculate model's log joint density,
log p(x, z), for inputs z (and fixed data x).
Args:
state: tf.Tensor.
"""
scope = self._scope + tf.get_default_graph().unique_name("sample")
# Form dictionary in order to replace conditioning on prior or
# observed variable with conditioning on a specific value.
# TODO verify ordering is preserved
dict_swap = {z: sample for z, sample in
zip(six.iterkeys(self.latent_vars), state)}
for x, qx in six.iteritems(self.data):
if isinstance(x, RandomVariable):
if isinstance(qx, RandomVariable):
qx_copy = copy(qx, scope=scope)
dict_swap[x] = qx_copy.value()
else:
dict_swap[x] = qx

log_joint = 0.0
for z in six.iterkeys(self.latent_vars):
z_copy = copy(z, dict_swap, scope=scope)
log_joint += tf.reduce_sum(z_copy.log_prob(dict_swap[z]))

for x in six.iterkeys(self.data):
if isinstance(x, RandomVariable):
x_copy = copy(x, dict_swap, scope=scope)
log_joint += tf.reduce_sum(x_copy.log_prob(dict_swap[x]))

return log_joint

def proposal_fn(state):
"""Utility function to propose new state,
znew ~ g(znew | zold) for inputs zold, and return the log density
ratio of log g(znew | zold) - log g(zold | znew).
Args:
state: tf.Tensor.
"""
# TODO verify ordering is preserved
old_sample = {z: sample for z, sample in
zip(six.iterkeys(self.latent_vars), state)}

# Form dictionary in order to replace conditioning on prior or
# observed variable with conditioning on a specific value.
Expand All @@ -99,7 +179,6 @@ def build_update(self):
dict_swap_old.update(old_sample)
base_scope = tf.get_default_graph().unique_name("inference") + '/'
scope_old = base_scope + 'old'
scope_new = base_scope + 'new'

# Draw proposed sample and calculate acceptance ratio.
new_sample = old_sample.copy() # copy to ensure same order
Expand All @@ -114,49 +193,14 @@ def build_update(self):

dict_swap_new = dict_swap.copy()
dict_swap_new.update(new_sample)
scope_new = base_scope + 'new'

for z, proposal_z in six.iteritems(self.proposal_vars):
# Build proposal g(zold | znew).
proposal_zold = copy(proposal_z, dict_swap_new, scope=scope_new)
# Increment ratio.
ratio -= tf.reduce_sum(proposal_zold.log_prob(dict_swap_old[z]))

for z in six.iterkeys(self.latent_vars):
# Build priors p(znew) and p(zold).
znew = copy(z, dict_swap_new, scope=scope_new)
zold = copy(z, dict_swap_old, scope=scope_old)
# Increment ratio.
ratio += tf.reduce_sum(znew.log_prob(dict_swap_new[z]))
ratio -= tf.reduce_sum(zold.log_prob(dict_swap_old[z]))

for x in six.iterkeys(self.data):
if isinstance(x, RandomVariable):
# Build likelihoods p(x | znew) and p(x | zold).
x_znew = copy(x, dict_swap_new, scope=scope_new)
x_zold = copy(x, dict_swap_old, scope=scope_old)
# Increment ratio.
ratio += tf.reduce_sum(x_znew.log_prob(dict_swap[x]))
ratio -= tf.reduce_sum(x_zold.log_prob(dict_swap[x]))

# Accept or reject sample.
u = Uniform(low=tf.constant(0.0, dtype=ratio.dtype),
high=tf.constant(1.0, dtype=ratio.dtype)).sample()
accept = tf.log(u) < ratio
sample_values = tf.cond(accept, lambda: list(six.itervalues(new_sample)),
lambda: list(six.itervalues(old_sample)))
if not isinstance(sample_values, list):
# `tf.cond` returns tf.Tensor if output is a list of size 1.
sample_values = [sample_values]

sample = {z: sample_value for z, sample_value in
zip(six.iterkeys(new_sample), sample_values)}

# Update Empirical random variables.
assign_ops = []
for z, qz in six.iteritems(self.latent_vars):
variable = qz.get_variables()[0]
assign_ops.append(tf.scatter_update(variable, self.t, sample[z]))

# Increment n_accept (if accepted).
assign_ops.append(self.n_accept.assign_add(tf.where(accept, 1, 0)))
return tf.group(*assign_ops)
# TODO verify ordering is preserved
new_sample = tf.stack(list(six.itervalues(new_sample)))
return (new_sample, ratio)

0 comments on commit 12509cd

Please sign in to comment.