Rewrite Tensor._sample to handle dice_factor as Delta's log_density #569

Draft · wants to merge 1 commit into base: master
126 changes: 45 additions & 81 deletions funsor/tensor.py
@@ -335,94 +335,58 @@ def _sample(self, sampled_vars, sample_inputs, rng_key):
         if not sampled_vars:
             return self

-        # Partition inputs into sample_inputs + batch_inputs + event_inputs.
-        sample_inputs = OrderedDict(
-            (k, d) for k, d in sample_inputs.items() if k not in self.inputs
-        )
-        sample_shape = tuple(int(d.dtype) for d in sample_inputs.values())
+        results = []
+        backend = get_backend()
+        remaining_vars = set(sampled_vars)
+        term = self
+        shape = tuple(d.size for k, d in sample_inputs.items() if k not in self.inputs)
+        shape += tuple(d.size for k, d in self.inputs.items() if k not in sampled_vars)
         batch_inputs = OrderedDict(
-            (k, d) for k, d in self.inputs.items() if k not in sampled_vars
+            (k, v) for k, v in sample_inputs.items() if k not in self.inputs
         )
-        event_inputs = OrderedDict(
-            (k, d) for k, d in self.inputs.items() if k in sampled_vars
+        batch_inputs.update(
+            (k, v) for k, v in self.inputs.items() if k not in sampled_vars
         )
-        be_inputs = batch_inputs.copy()
-        be_inputs.update(event_inputs)
-        sb_inputs = sample_inputs.copy()
-        sb_inputs.update(batch_inputs)

-        # Sample all variables in a single Categorical call.
-        logits = align_tensor(be_inputs, self)
-        batch_shape = logits.shape[: len(batch_inputs)]
-        flat_logits = logits.reshape(batch_shape + (-1,))
-        sample_shape = tuple(d.dtype for d in sample_inputs.values())
+        while remaining_vars:
+            name = remaining_vars.pop()
+            domain = self.inputs[name]
+            logits = funsor.Lambda(
+                Variable(name, domain), term.reduce(ops.logaddexp, remaining_vars)
+            )

-        backend = get_backend()
-        if backend != "numpy":
-            from importlib import import_module
+            if backend != "numpy":
+                from importlib import import_module

-            dist = import_module(
-                funsor.distribution.BACKEND_TO_DISTRIBUTIONS_BACKEND[backend]
-            )
-            sample_args = (
-                (sample_shape,) if rng_key is None else (rng_key, sample_shape)
-            )
-            flat_sample = dist.CategoricalLogits.dist_class(logits=flat_logits).sample(
-                *sample_args
-            )
-        else:  # default numpy backend
-            assert backend == "numpy"
-            shape = sample_shape + flat_logits.shape[:-1]
-            logit_max = np.amax(flat_logits, -1, keepdims=True)
-            probs = np.exp(flat_logits - logit_max)
-            probs = probs / np.sum(probs, -1, keepdims=True)
-            s = np.cumsum(probs, -1)
-            r = np.random.rand(*shape)
-            flat_sample = np.sum(s < np.expand_dims(r, -1), axis=-1)
-
-        assert flat_sample.shape == sample_shape + batch_shape
-        results = []
-        mod_sample = flat_sample
-        for name, domain in reversed(list(event_inputs.items())):
-            size = domain.dtype
-            point = Tensor(mod_sample % size, sb_inputs, size)
-            mod_sample = mod_sample // size
-            results.append(Delta(name, point))
-
-        # Account for the log normalizer factor.
-        # Derivation: Let f be a nonnormalized distribution (a funsor), and
-        #   consider operations in linear space (source code is in log space).
-        #   Let x0 ~ f/|f| be a monte carlo sample from a normalized f/|f|.
-        #                              f(x0) / |f|      # dice numerator
-        #   Let g = delta(x=x0) |f| -----------------
-        #                           detach(f(x0)/|f|)   # dice denominator
-        #                       |detach(f)| f(x0)
-        #         = delta(x=x0) -----------------  be a dice approximation of f.
-        #                           detach(f(x0))
-        #   Then g is an unbiased estimator of f in value and all derivatives.
-        #   In the special case f = detach(f), we can simplify to
-        #       g = delta(x=x0) |f|.
-        if (backend == "torch" and flat_logits.requires_grad) or backend == "jax":
-            # Apply a dice factor to preserve differentiability.
-            index = [
-                ops.new_arange(self.data, n).reshape(
-                    (n,) + (1,) * (len(flat_logits.shape) - i - 2)
-                )
-                for i, n in enumerate(flat_logits.shape[:-1])
-            ]
-            index.append(flat_sample)
-            log_prob = flat_logits[tuple(index)]
-            assert log_prob.shape == flat_sample.shape
-            results.append(
-                Tensor(
-                    ops.logsumexp(ops.detach(flat_logits), -1)
-                    + (log_prob - ops.detach(log_prob)),
-                    sb_inputs,
-                )
-            )
-        else:
-            # This is the special case f = detach(f).
-            results.append(Tensor(ops.logsumexp(flat_logits, -1), batch_inputs))
+                dist = import_module(
+                    funsor.distribution.BACKEND_TO_DISTRIBUTIONS_BACKEND[backend]
+                )
+                sample_inputs = OrderedDict(
+                    (k, v) for k, v in sample_inputs.items() if k not in logits.inputs
+                )
+                delta = dist.CategoricalLogits(logits=logits, value=name)._sample(
+                    frozenset({name}), sample_inputs, rng_key
+                )
+                point = delta.terms[0][1][0]
+                log_density = delta.terms[0][1][1] + self.reduce(
+                    ops.logaddexp, sampled_vars
+                ) / len(sampled_vars)
+                term = term(**{name: point})
+                sample = Delta(name, point, log_density)
+                results.append(sample)
+            else:  # default numpy backend
+                assert backend == "numpy"
+                probs = (logits - ops.logsumexp(logits)).exp().data
+                # shape = sample_shape + flat_logits.shape[:-1]
+                # logit_max = np.amax(flat_logits, -1, keepdims=True)
+                # probs = np.exp(flat_logits - logit_max)
+                # probs = probs / np.sum(probs, -1, keepdims=True)
+                s = np.cumsum(probs, -1)
+                r = np.random.rand(*shape)
+                flat_sample = np.sum(s < np.expand_dims(r, -1), axis=-1)
+                point = Tensor(flat_sample, batch_inputs, domain.dtype)
+                term = term(**{name: point})
+                sample = Delta(name, point)
+                results.append(sample)

         return reduce(ops.add, results)
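For reference, the dice-factor derivation in the comment block deleted above, restated in standard notation (a restatement only, nothing here is added to the diff): f is an unnormalized density, |f| = sum_x f(x) is its normalizer, and x0 ~ f/|f| is a Monte Carlo sample from the normalized distribution. The dice approximation used by the old code is

\[
g
= \delta(x = x_0)\,\lvert f\rvert \cdot
  \frac{f(x_0)/\lvert f\rvert}{\operatorname{detach}\bigl(f(x_0)/\lvert f\rvert\bigr)}
= \delta(x = x_0) \cdot
  \frac{\lvert \operatorname{detach}(f)\rvert \, f(x_0)}{\operatorname{detach}\bigl(f(x_0)\bigr)},
\]

an unbiased estimator of f in value and in all derivatives; in the special case f = detach(f) it simplifies to g = delta(x = x0) * |f|.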

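A minimal sketch (not part of this PR; the names logits, point, and log_Z are illustrative) of the two ways this factor can be carried, matching the old and new code paths above: as a separate Tensor term added next to the Delta, or folded into the Delta's log_density as this PR does. Reducing out the sampled variable with ops.logaddexp should give the same total mass for both forms.

# Sketch only: "Delta + separate normalizer term" (old style) versus a Delta
# carrying the normalizer in its log_density slot (new style).
from collections import OrderedDict

import torch

import funsor
import funsor.ops as ops
from funsor.delta import Delta
from funsor.domains import Bint
from funsor.tensor import Tensor

funsor.set_backend("torch")

# An unnormalized log-density over a discrete variable x (illustrative data).
logits = Tensor(torch.randn(4), OrderedDict(x=Bint[4]))
log_Z = logits.reduce(ops.logaddexp, "x")            # log normalizer, |f| above
point = Tensor(torch.tensor(2), OrderedDict(), 4)    # a sampled value of x

old_style = Delta("x", point) + log_Z                # normalizer as a separate term
new_style = Delta("x", point, log_Z)                 # normalizer in log_density

# Both forms should reduce to the same total mass, log_Z.
print(old_style.reduce(ops.logaddexp, "x"))
print(new_style.reduce(ops.logaddexp, "x"))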