
Commit

Add an optional epsilon to avoid NaNs when big and small are too close in computing log_prob.

PiperOrigin-RevId: 558746118
DistraxDev authored and DistraxDev committed Aug 21, 2023
1 parent 09c0ce1 commit 8a0b5ff
Showing 1 changed file with 10 additions and 0 deletions.
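
For context: the diff below references `math.log_expbig_minus_expsmall(big, small)`, which computes `log(exp(big) - exp(small))`. Here is a minimal sketch of the failure mode this commit guards against, assuming a standard `log1p`-based implementation of that helper (the exact distrax internals may differ): when `big == small`, the log collapses to -inf and its gradient is NaN.

import jax
import jax.numpy as jnp

def log_expbig_minus_expsmall(big, small):
  # Stable form of log(exp(big) - exp(small)); assumes big >= small.
  return big + jnp.log1p(-jnp.exp(small - big))

# When big == small, log1p(-exp(0)) = log(0) = -inf, and differentiating
# through the two coupled arguments produces NaN.
f = lambda x: log_expbig_minus_expsmall(x, x)
print(f(0.5))            # -inf
print(jax.grad(f)(0.5))  # nan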
distrax/_src/distributions/quantized.py (10 additions, 0 deletions)
@@ -20,6 +20,7 @@
 from distrax._src.distributions import distribution as base_distribution
 from distrax._src.utils import conversion
 from distrax._src.utils import math
+import jax
 import jax.numpy as jnp
 from tensorflow_probability.substrates import jax as tfp

@@ -47,12 +48,16 @@ class Quantized(

   def __init__(self,
                distribution: DistributionLike,
+               eps: Optional[Numeric] = None,
                low: Optional[Numeric] = None,
                high: Optional[Numeric] = None):
     """Initializes a Quantized distribution.

     Args:
       distribution: The base distribution to be quantized.
+      eps: An optional gap to enforce between "big" and "small". Useful for
+        avoiding NaNs when computing log_probs in cases where "big" and
+        "small" are too close.
       low: Lowest possible quantized value, such that samples are `y >=
         ceil(low)`. Its shape must broadcast with the shape of samples from
         `distribution` and must not result in additional batch dimensions after
@@ -64,6 +69,7 @@ def __init__(self,
     """
     self._dist: base_distribution.Distribution[Array, Tuple[
         int, ...], jnp.dtype] = conversion.as_distribution(distribution)
+    self._eps = eps
     if self._dist.event_shape:
       raise ValueError(f'The base distribution must be univariate, but its '
                        f'`event_shape` is {self._dist.event_shape}.')
@@ -180,6 +186,10 @@ def log_prob(self, value: EventT) -> Array:
     # which happens to the right of the median of the distribution.
     big = jnp.where(log_sf < log_cdf, log_sf_m1, log_cdf)
     small = jnp.where(log_sf < log_cdf, log_sf, log_cdf_m1)
+    if self._eps is not None:
+      # Clamp so that big - small >= eps; stop_gradient keeps gradients
+      # from flowing through `small` into the clamped value of `big`.
+      big = jnp.where(big - small > self._eps, big,
+                      jax.lax.stop_gradient(small) + self._eps)
     log_probs = math.log_expbig_minus_expsmall(big, small)

     # Return -inf when evaluating on non-integer value.
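A hypothetical usage sketch of the new argument, assuming this module backs the public `distrax.Quantized` wrapper; the `eps` value here is illustrative only:

import distrax
import jax.numpy as jnp

base = distrax.Normal(loc=0., scale=1.)

# Far in a tail, log_cdf/log_sf of adjacent integers can coincide
# numerically, making big == small and log_prob -inf with NaN gradients.
# With eps set, big - small >= eps is enforced before the subtraction.
quantized = distrax.Quantized(base, eps=1e-6, low=-12., high=12.)
print(quantized.log_prob(jnp.asarray(10.)))

Note the design choice in the clamp: because `big` is replaced by `stop_gradient(small) + eps`, it carries no parameter gradient when the clamp is active; the gradient then flows only through `small`, with the gap between the two terms held at `eps`, which cuts the coupling that produces NaN.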
