Skip to content

Commit

Permalink
Add an optional epsilon to avoid NANs when big and small are too clos…
Browse files Browse the repository at this point in the history
…e in computing log_prob.

PiperOrigin-RevId: 558746118
  • Loading branch information
DistraxDev authored and DistraxDev committed Aug 21, 2023
1 parent 09c0ce1 commit 56a13dc
Showing 1 changed file with 11 additions and 1 deletion.
12 changes: 11 additions & 1 deletion distrax/_src/distributions/quantized.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from distrax._src.distributions import distribution as base_distribution
from distrax._src.utils import conversion
from distrax._src.utils import math
import jax
import jax.numpy as jnp
from tensorflow_probability.substrates import jax as tfp

Expand Down Expand Up @@ -48,7 +49,8 @@ class Quantized(
def __init__(self,
distribution: DistributionLike,
low: Optional[Numeric] = None,
high: Optional[Numeric] = None):
high: Optional[Numeric] = None,
eps: Optional[Numeric] = None):
"""Initializes a Quantized distribution.
Args:
Expand All @@ -61,9 +63,13 @@ def __init__(self,
floor(high)`. Its shape must broadcast with the shape of samples from
`distribution` and must not result in additional batch dimensions after
broadcasting.
eps: An optional gap to enforce between "big" and "small". Useful for
avoiding NANs in computing log_probs, when "big" and "small"
are too close.
"""
self._dist: base_distribution.Distribution[Array, Tuple[
int, ...], jnp.dtype] = conversion.as_distribution(distribution)
self._eps = eps
if self._dist.event_shape:
raise ValueError(f'The base distribution must be univariate, but its '
f'`event_shape` is {self._dist.event_shape}.')
Expand Down Expand Up @@ -180,6 +186,10 @@ def log_prob(self, value: EventT) -> Array:
# which happens to the right of the median of the distribution.
big = jnp.where(log_sf < log_cdf, log_sf_m1, log_cdf)
small = jnp.where(log_sf < log_cdf, log_sf, log_cdf_m1)
if self._eps is not None:
# use stop_gradient to block updating in this case
big = jnp.where(big - small > self._eps, big,
jax.lax.stop_gradient(small) + self._eps)
log_probs = math.log_expbig_minus_expsmall(big, small)

# Return -inf when evaluating on non-integer value.
Expand Down

0 comments on commit 56a13dc

Please sign in to comment.