diff --git a/tensorflow_probability/python/internal/backend/jax/rewrite.py b/tensorflow_probability/python/internal/backend/jax/rewrite.py
index 68efbd20f5..f7ac15a370 100644
--- a/tensorflow_probability/python/internal/backend/jax/rewrite.py
+++ b/tensorflow_probability/python/internal/backend/jax/rewrite.py
@@ -45,6 +45,8 @@ def main(argv):
   if FLAGS.rewrite_numpy_import:
     contents = contents.replace('\nimport numpy as np',
                                 '\nimport numpy as onp; import jax.numpy as np')
+    contents = contents.replace('\nimport numpy as tnp',
+                                '\nimport jax.numpy as tnp')
   else:
     contents = contents.replace('\nimport numpy as np',
                                 '\nimport numpy as np; onp = np')
diff --git a/tensorflow_probability/python/stats/sample_stats.py b/tensorflow_probability/python/stats/sample_stats.py
index 1c9e82166f..fecf5c2547 100644
--- a/tensorflow_probability/python/stats/sample_stats.py
+++ b/tensorflow_probability/python/stats/sample_stats.py
@@ -17,6 +17,7 @@
 # Dependency imports
 import numpy as np
 import tensorflow.compat.v2 as tf
+import tensorflow.experimental.numpy as tnp
 
 from tensorflow_probability.python.internal import assert_util
 from tensorflow_probability.python.internal import distribution_util
@@ -694,8 +695,8 @@ def cumulative_variance(x, sample_axis=0, name=None):
     excl_counts = tf.reshape(tf.range(size, dtype=x.dtype), shape=counts_shp)
     incl_counts = excl_counts + 1
     excl_sums = tf.cumsum(x, axis=sample_axis, exclusive=True)
-    discrepancies = (excl_sums / excl_counts - x)**2
-    discrepancies = tf.where(excl_counts == 0, x**2, discrepancies)
+    discrepancies = tf.math.square(excl_sums / excl_counts - x)
+    discrepancies = tf.where(excl_counts == 0, tf.math.square(x), discrepancies)
     adjustments = excl_counts / incl_counts
     # The zeroth item's residual contribution is 0, because it has no
     # other items to vary from. The preceding expressions, however,
@@ -712,11 +713,11 @@ def windowed_variance(
 
   Computes variances among data in the Tensor `x` along the given windows:
 
-  result[i] = variance(x[low_indices[i]:high_indices[i]+1])
+  result[i] = variance(x[low_indices[i]:high_indices[i]])
 
-  accurately and efficiently. To wit, if K is the size of
-  `low_indices` and `high_indices`, and `N` is the size of `x` along
-  the given `axis`, the computation takes O(K + N) work, O(log(N))
+  accurately and efficiently. To wit, if `m` is the size of
+  `low_indices` and `high_indices`, and `n` is the size of `x` along
+  the given `axis`, the computation takes O(n + m) work, O(log(n))
   depth (the length of the longest series of operations that are
   performed sequentially), and only uses O(1) TensorFlow kernel
   invocations. The underlying algorithm is an adaptation of the
@@ -726,11 +727,19 @@
   trailing-window estimators from some iterative process, such as the
   last half of an MCMC chain.
 
-  Suppose `x` has shape `Bx + [N] + E`, where the `Bx` component has
-  rank `axis`, and `low_indices` and `high_indices` broadcast to shape
-  `[M]`. Then each element of `low_indices` and `high_indices`
-  must be between 0 and N+1, and the shape of the output will be
-  `Bx + [M] + E`. Batch shape in the indices is not currently supported.
+  Suppose `x` has shape `Bx + [n] + E`, and `low_indices` and
+  `high_indices` have shape `Bi + [m] + F`, such that
+  `rank(Bx) = rank(Bi) = axis`. Then each element of `low_indices`
+  and `high_indices` must be between 0 and `n+1`, and the shape of
+  the output will be `broadcast(Bx, Bi) + [m] + broadcast(E, F)`.
+
+  The shape `Bi + [1] + F` must be implicitly broadcastable with the
+  shape of `x`. The following implicit broadcasting rules are applied:
+
+  If `rank(Bi + [m] + F) < rank(x)`, then the indices are expanded
+  with extra inner dimensions to match the rank of `x`.
+  If the rank of the indices is one, i.e. when `rank(Bi) = rank(F) = 0`,
+  the indices are reshaped to `[1] * rank(Bx) + [m] + [1] * rank(E)`.
 
   The default windows are
   `[0, 1), [1, 2), [1, 3), [2, 4), [2, 5), ...`
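To make the new shape contract concrete, a minimal usage sketch (illustrative only; it assumes the public `tfp.stats.windowed_variance` alias of the function documented here):

```python
import tensorflow as tf
import tensorflow_probability as tfp

# x has shape Bx + [n] + E = [7] + [64] + [5]; axis=1 holds the samples.
x = tf.random.normal([7, 64, 5])
# Indices have shape Bi + [m] + F = [7] + [16] + [1], with rank(Bi) = axis.
high = tf.random.uniform([7, 16, 1], minval=1, maxval=65, dtype=tf.int32)
low = high // 2
result = tfp.stats.windowed_variance(
    x, low_indices=low, high_indices=high, axis=1)
# Expected: broadcast([7], [7]) + [16] + broadcast([5], [1]) = [7, 16, 5].
print(result.shape)
```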
@@ -739,14 +748,14 @@
   in the variance of the last half of the data at each point.
 
   Args:
-    x: A numeric `Tensor` holding `N` samples along the given `axis`,
+    x: A numeric `Tensor` holding `n` samples along the given `axis`,
       whose windowed variances are desired.
     low_indices: An integer `Tensor` defining the lower boundary
       (inclusive) of each window. Default: elementwise half of
       `high_indices`.
     high_indices: An integer `Tensor` defining the upper boundary
       (exclusive) of each window. Must be broadcast-compatible with
-      `low_indices`. Default: `tf.range(1, N+1)`, i.e., N windows
+      `low_indices`. Default: `tf.range(1, n+1)`, i.e., n windows
       that each end in the corresponding datum from `x` (inclusive).
     axis: Scalar `Tensor` designating the axis holding samples. This
       is the axis of `x` along which we take windows, and therefore
@@ -769,7 +778,7 @@
   """
   with tf.name_scope(name or 'windowed_variance'):
     x = tf.convert_to_tensor(x)
-    low_indices, high_indices, low_counts, high_counts = _prepare_window_args(
+    x, indices, axis = _prepare_window_args(
         x, low_indices, high_indices, axis)
 
     # We have a problem with indexing: the standard convention demands
@@ -786,15 +795,11 @@
     def index_for_cumulative(indices):
       return tf.maximum(indices - 1, 0)
     cum_sums = tf.cumsum(x, axis=axis)
-    low_sums = tf.gather(
-        cum_sums, index_for_cumulative(low_indices), axis=axis)
-    high_sums = tf.gather(
-        cum_sums, index_for_cumulative(high_indices), axis=axis)
+    sums = tnp.take_along_axis(
+        cum_sums, index_for_cumulative(indices), axis=axis)
     cum_variances = cumulative_variance(x, sample_axis=axis)
-    low_variances = tf.gather(
-        cum_variances, index_for_cumulative(low_indices), axis=axis)
-    high_variances = tf.gather(
-        cum_variances, index_for_cumulative(high_indices), axis=axis)
+    variances = tnp.take_along_axis(
+        cum_variances, index_for_cumulative(indices), axis=axis)
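`tnp.take_along_axis` is what enables batched indices here: `tf.gather` selects the same positions along the axis for every batch member, while `take_along_axis` indexes each batch member with its own indices. A minimal illustration (not part of the patch):

```python
import tensorflow as tf
import tensorflow.experimental.numpy as tnp

x = tf.constant([[10., 20., 30.],
                 [40., 50., 60.]])
idx = tf.constant([[0, 2],
                   [1, 1]])
# tf.gather: the same column indices for every row.
print(tf.gather(x, [0, 2], axis=1))         # [[10., 30.], [40., 60.]]
# tnp.take_along_axis: each row gets its own column indices.
print(tnp.take_along_axis(x, idx, axis=1))  # [[10., 30.], [50., 50.]]
```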
 
     # This formula is the binary accurate variance merge from [1],
     # adapted to subtract and batched across the indexed counts, sums,
@@ -812,15 +817,18 @@ def index_for_cumulative(indices):
     # This formula can also be read as implementing the above variance
     # computation by "unioning" A u B with a notional "negative B"
     # multiset.
-    counts = high_counts - low_counts  # |A|
-    discrepancies = (
-        _safe_average(high_sums, high_counts) -
-        _safe_average(low_sums, low_counts))**2  # (mean(A u B) - mean(B))**2
-    adjustments = high_counts * (-low_counts) / counts  # |A u B| * -|B| / |A|
-    residuals = (high_variances * high_counts -
-                 low_variances * low_counts +
+    bounds = ps.cast(indices, sums.dtype)
+    counts = bounds[1] - bounds[0]  # |A|
+    sum_averages = tf.math.divide_no_nan(sums, bounds)
+    # (mean(A u B) - mean(B))**2
+    discrepancies = tf.square(sum_averages[1] - sum_averages[0])
+    # |A u B| * -|B| / |A|
+    adjustments = tf.math.divide_no_nan(bounds[1] * (-bounds[0]), counts)
+    variances_scaled = variances * bounds
+    residuals = (variances_scaled[1] -
+                 variances_scaled[0] +
                  adjustments * discrepancies)
-    return _safe_average(residuals, counts)
+    return tf.math.divide_no_nan(residuals, counts)
 
 
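The merge-by-subtraction above can be sanity-checked by hand; a small numpy sketch (illustrative only), using window `A = x[1:3]`, prefix `B = x[0:1]`, and `A u B = x[0:3]`:

```python
import numpy as np

x = np.array([1., 2., 3., 4.])
low, high = 1, 3                 # window A = x[1:3]
b, ab = x[:low], x[:high]        # B and A u B are both prefixes of x
n_b, n_ab = low, high
n_a = n_ab - n_b                 # |A|
discrepancy = (ab.mean() - b.mean())**2    # (mean(A u B) - mean(B))**2
adjustment = n_ab * (-n_b) / n_a           # |A u B| * -|B| / |A|
residual = ab.var() * n_ab - b.var() * n_b + adjustment * discrepancy
print(residual / n_a)            # 0.25
print(x[low:high].var())         # 0.25, matches the direct computation
```

The `divide_no_nan` calls make the empty-window case (`|A| = 0`) return 0 instead of NaN, replacing the `tf.where` guard that `_safe_average` used to provide.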
 def windowed_mean(
@@ -829,11 +837,11 @@
 
   Computes means among data in the Tensor `x` along the given windows:
 
-  result[i] = mean(x[low_indices[i]:high_indices[i]+1])
+  result[i] = mean(x[low_indices[i]:high_indices[i]])
 
-  efficiently. To wit, if K is the size of `low_indices` and
-  `high_indices`, and `N` is the size of `x` along the given `axis`,
-  the computation takes O(K + N) work, O(log(N)) depth (the length of
+  efficiently. To wit, if `m` is the size of `low_indices` and
+  `high_indices`, and `n` is the size of `x` along the given `axis`,
+  the computation takes O(m + n) work, O(log(n)) depth (the length of
   the longest series of operations that are performed sequentially),
   and only uses O(1) TensorFlow kernel invocations.
 
@@ -841,11 +849,19 @@
   trailing-window estimators from some iterative process, such as the
   last half of an MCMC chain.
 
-  Suppose `x` has shape `Bx + [N] + E`, where the `Bx` component has
-  rank `axis`, and `low_indices` and `high_indices` broadcast to shape
-  `[M]`. Then each element of `low_indices` and `high_indices`
-  must be between 0 and N+1, and the shape of the output will be
-  `Bx + [M] + E`. Batch shape in the indices is not currently supported.
+  Suppose `x` has shape `Bx + [n] + E`, and `low_indices` and
+  `high_indices` have shape `Bi + [m] + F`, such that
+  `rank(Bx) = rank(Bi) = axis`. Then each element of `low_indices`
+  and `high_indices` must be between 0 and `n+1`, and the shape of
+  the output will be `broadcast(Bx, Bi) + [m] + broadcast(E, F)`.
+
+  The shape `Bi + [1] + F` must be implicitly broadcastable with the
+  shape of `x`. The following implicit broadcasting rules are applied:
+
+  If `rank(Bi + [m] + F) < rank(x)`, then the indices are expanded
+  with extra inner dimensions to match the rank of `x`.
+  If the rank of the indices is one, i.e. when `rank(Bi) = rank(F) = 0`,
+  the indices are reshaped to `[1] * rank(Bx) + [m] + [1] * rank(E)`.
 
   The default windows are
   `[0, 1), [1, 2), [1, 3), [2, 4), [2, 5), ...`
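These documented defaults can be checked quickly; an illustrative example (values computed by hand from the window list above, via the public `tfp.stats.windowed_mean` alias):

```python
import tensorflow as tf
import tensorflow_probability as tfp

x = tf.constant([1., 2., 3., 4., 5., 6.])
# Default high_indices = tf.range(1, 7), low_indices = high_indices // 2,
# i.e. the windows [0, 1), [1, 2), [1, 3), [2, 4), [2, 5), [3, 6).
print(tfp.stats.windowed_mean(x))
# => [1., 2., 2.5, 3.5, 4., 5.]
```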
@@ -854,14 +870,14 @@
   in the variance of the last half of the data at each point.
 
   Args:
-    x: A numeric `Tensor` holding `N` samples along the given `axis`,
+    x: A numeric `Tensor` holding `n` samples along the given `axis`,
       whose windowed means are desired.
     low_indices: An integer `Tensor` defining the lower boundary
       (inclusive) of each window. Default: elementwise half of
       `high_indices`.
     high_indices: An integer `Tensor` defining the upper boundary
       (exclusive) of each window. Must be broadcast-compatible with
-      `low_indices`. Default: `tf.range(1, N+1)`, i.e., N windows
+      `low_indices`. Default: `tf.range(1, n+1)`, i.e., n windows
       that each end in the corresponding datum from `x` (inclusive).
     axis: Scalar `Tensor` designating the axis holding samples. This
       is the axis of `x` along which we take windows, and therefore
@@ -878,58 +894,60 @@
   """
   with tf.name_scope(name or 'windowed_mean'):
     x = tf.convert_to_tensor(x)
-    low_indices, high_indices, low_counts, high_counts = _prepare_window_args(
-        x, low_indices, high_indices, axis)
+    x, indices, axis = _prepare_window_args(x, low_indices, high_indices, axis)
 
     raw_cumsum = tf.cumsum(x, axis=axis)
-    cum_sums = tf.concat(
-        [tf.zeros_like(tf.gather(raw_cumsum, [0], axis=axis)), raw_cumsum],
-        axis=axis)
-    low_sums = tf.gather(cum_sums, low_indices, axis=axis)
-    high_sums = tf.gather(cum_sums, high_indices, axis=axis)
-
-    counts = high_counts - low_counts
-    return _safe_average(high_sums - low_sums, counts)
+    rank = ps.rank(x)
+    paddings = ps.reshape(ps.one_hot(2*axis, depth=2*rank, dtype=tf.int32),
+                          (rank, 2))
+    cum_sums = ps.pad(raw_cumsum, paddings)
+    sums = tnp.take_along_axis(cum_sums, indices, axis=axis)
+    counts = ps.cast(indices[1] - indices[0], dtype=sums.dtype)
+    return tf.math.divide_no_nan(sums[1] - sums[0], counts)
 
 
 def _prepare_window_args(x, low_indices=None, high_indices=None, axis=0):
   """Common argument defaulting logic for windowed statistics."""
   if high_indices is None:
-    high_indices = tf.range(ps.shape(x)[axis]) + 1
+    high_indices = ps.range(ps.shape(x)[axis]) + 1
   else:
     high_indices = tf.convert_to_tensor(high_indices)
   if low_indices is None:
     low_indices = high_indices // 2
   else:
     low_indices = tf.convert_to_tensor(low_indices)
+
+  indices_rank = tf.get_static_value(ps.rank(low_indices))
+  x_rank = tf.get_static_value(ps.rank(x))
+  if indices_rank is None or x_rank is None:
+    raise ValueError("`indices` and `x` ranks must be statically known.")
+
+  # Broadcast indices together.
   high_indices = high_indices + tf.zeros_like(low_indices)
   low_indices = low_indices + tf.zeros_like(high_indices)
-
-  # TODO(axch): Support batch low and high indices. That would
-  # complicate this shape munging (though tf.gather should work
-  # fine).
-
-  # We want to place `low_counts` and `high_counts` at the `axis`
-  # position, so we reshape them to shape `[1, 1, ..., 1, N, 1, ...,
-  # 1]`, where the `N` is at `axis`. The `counts_shp`, below,
-  # is this shape.
-  size = ps.size(high_indices)
-  counts_shp = ps.one_hot(
-      axis, depth=ps.rank(x), on_value=size, off_value=1)
-
-  low_counts = tf.reshape(tf.cast(low_indices, dtype=x.dtype),
-                          shape=counts_shp)
-  high_counts = tf.reshape(tf.cast(high_indices, dtype=x.dtype),
-                           shape=counts_shp)
-  return low_indices, high_indices, low_counts, high_counts
-
-
-def _safe_average(totals, counts):
-  # This tf.where protects `totals` from getting a gradient signal
-  # when `counts` is 0.
-  safe_totals = tf.where(~tf.equal(counts, 0), totals, 0)
-  return tf.where(~tf.equal(counts, 0), safe_totals / counts, 0)
+  indices_shape = ps.shape(low_indices)
+  if ps.rank(low_indices) < ps.rank(x):
+    if ps.rank(low_indices) == 1:
+      size = ps.size(low_indices)
+      bc_shape = ps.one_hot(axis, depth=ps.rank(x), on_value=size,
+                            off_value=1)
+    else:
+      # We assume the leading dimensions are broadcastable with `x` and
+      # add trailing dimensions to match its rank.
+      extra_dims = ps.rank(x) - ps.rank(low_indices)
+      bc_shape = ps.concat([indices_shape, [1]*extra_dims], axis=0)
+  else:
+    bc_shape = indices_shape
+
+  bc_shape = ps.concat([[2], bc_shape], axis=0)
+  indices = ps.stack([low_indices, high_indices], axis=0)
+  indices = ps.reshape(indices, bc_shape)
+  x = tf.expand_dims(x, axis=0)
+  axis += 1
+  # `take_along_axis` requires the indices to have dtype int32.
+  indices = ps.cast(indices, dtype=tf.int32)
+  return x, indices, axis
 
 
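The `one_hot`-based `paddings` construction in `windowed_mean` above is compact but non-obvious: it pads one leading zero along `axis`, so indexing the padded cumulative sum at `i` yields the sum of the first `i` elements (an empty prefix sums to 0). The same trick in plain numpy (illustrative only):

```python
import numpy as np

x = np.array([[1., 2., 3.],
              [4., 5., 6.]])
axis, rank = 1, x.ndim
# Row 2*axis of a (2*rank, 2*rank) identity, reshaped to (rank, 2), puts a
# single 1 in position [axis][0]: pad one zero before `axis`, none elsewhere.
paddings = np.eye(2 * rank, dtype=np.int64)[2 * axis].reshape(rank, 2)
cum = np.pad(np.cumsum(x, axis=axis), paddings)
print(cum)                    # [[0., 1., 3., 6.], [0., 4., 9., 15.]]
# sum(x[i, lo:hi]) == cum[i, hi] - cum[i, lo]; e.g. row 1, window [1, 3):
print(cum[1, 3] - cum[1, 1])  # 11.0 == 5. + 6.
```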
+ """ + np.testing.assert_equal(low.shape, high.shape) + + def apply_func(vector, l, h): + return func(vector[l:h]) + + apply_func_1d = np.vectorize(apply_func, signature='(n), (), ()->()') + vectorized_func = np.vectorize(apply_func_1d, + signature='(n), (k), (k)->(m)') + + # Put `axis` at the innermost dimension + dims = list(range(arr.ndim)) + dims[-1] = axis + dims[axis] = arr.ndim - 1 + t_arr = np.transpose(arr, axes=dims) + t_low = np.transpose(low, axes=dims) + t_high = np.transpose(high, axes=dims) + + t_out = vectorized_func(t_arr, t_low, t_high) + + # Replace `axis` at its place + out = np.transpose(t_out, axes=dims) + return out + + def check_gaussian_windowed_func(self, shape, indice_shape, axis, + window_func, np_func): + stat_shape = np.array(shape).astype(np.int32) + stat_shape[axis] = 1 + loc = np.arange(np.prod(stat_shape)).reshape(stat_shape) + scale = 0.1 * np.arange(np.prod(stat_shape)).reshape(stat_shape) + rng = test_util.test_np_rng() + x = rng.normal(loc=loc, scale=scale, size=shape) + indice_shape = [2] + list(indice_shape) + indices = rng.randint(shape[axis] + 1, size=indice_shape) + indices = np.sort(indices, axis=0) + low_indices, high_indices = indices[0], indices[1] + + tf_low_indices = self._make_dynamic_shape(low_indices) + tf_high_indices = self._make_dynamic_shape(high_indices) + tf_x = self._make_dynamic_shape(x) + + a = window_func(tf_x, low_indices=tf_low_indices, + high_indices=tf_high_indices, axis=axis) + + low_indices = self._maybe_expand_dims_to_make_broadcastable( + low_indices, x.shape, axis) + high_indices = self._maybe_expand_dims_to_make_broadcastable( + high_indices, x.shape, axis) + b = self.apply_slice_along_axis(np_func, x, low_indices, high_indices, + axis=axis) + b[np.isnan(b)] = 0 # We treat stats computed on empty sets as zeros + self.assertAllClose(a, b) + + def _make_dynamic_shape(self, x): + return tf1.placeholder_with_default(x, shape=(None,)*len(x.shape)) + + @parameterized.named_parameters(*[( + f"{np_func.__name__} shape={s} indices_shape={i} axis={axis}", s, i, axis, + tf_func, np_func) for s, (i, axis), (tf_func, np_func) in + itertools.product([(64, 4, 8)], + [((128, 1, 1), 0), + ((32, 1, 1), 0), + ((32, 4, 1), 0), + ((32, 4, 8), 0), + ((64, 4, 8), 0), + ((128, 1), 0), + ((32,), 0), + ((32, 4), 0), + + ((64, 64, 1), 1), + ((1, 64, 1), 1), + ((64, 2, 8), 1), + ((64, 4, 8), 1), + ((16,), 1), + ((1, 64), 1), + + ((64, 4, 64), 2), + ((1, 1, 64), 2), + ((64, 4, 4), 2), + ((1, 1, 4), 2), + ((64, 4, 8), 2), + ((16,), 2), + ((1, 4), 2), + ((64, 4), 2)], + [(sample_stats.windowed_mean, np.mean), + (sample_stats.windowed_variance, np.var)])]) + def test_windowed(self, shape, indice_shape, axis, window_func, np_func): + self.check_gaussian_windowed_func(shape, indice_shape, axis, window_func, + np_func) + + @parameterized.named_parameters(*[( + f"{np_func.__name__} shape={s} indices_shape={i} axis={axis}", s, i, axis, + tf_func, np_func) for s, (i, axis), (tf_func, np_func) in + itertools.product([(64, 4, 8)], + [((4, 1, 4), 2), ((2, 4), 2)], + [(sample_stats.windowed_mean, np.mean), + (sample_stats.windowed_variance, np.var)])]) + def test_non_broadcastable_shapes(self, shape, indice_shape, axis, + window_func, np_func): + with self.assertRaisesRegexp((IndexError, ValueError, InvalidArgumentError), + '^shape mismatch|Incompatible shapes'): + self.check_gaussian_windowed_func(shape, indice_shape, axis, window_func, + np_func) + + @test_util.test_all_tf_execution_regimes class LogAverageProbsTest(test_util.TestCase): diff --git 
diff --git a/tensorflow_probability/substrates/meta/rewrite.py b/tensorflow_probability/substrates/meta/rewrite.py
index e71a75bf4e..8cbca19144 100644
--- a/tensorflow_probability/substrates/meta/rewrite.py
+++ b/tensorflow_probability/substrates/meta/rewrite.py
@@ -29,6 +29,7 @@
 TF_REPLACEMENTS = {
     'import tensorflow ':
         'from tensorflow_probability.python.internal.backend import numpy ',
+    'import tensorflow.experimental.numpy as tnp': 'import numpy as tnp',
     'import tensorflow.compat.v1':
         'from tensorflow_probability.python.internal.backend.numpy.compat '
         'import v1',
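Together with the JAX-side replacement at the top of this diff, this mapping routes the new `tnp` import through both substrates; an illustrative trace of how the two rewrites compose:

```python
# The numpy substrate (substrates/meta/rewrite.py) first rewrites the
# TF-facing import; the JAX pass (backend/jax/rewrite.py, run with
# --rewrite_numpy_import) then retargets it at jax.numpy.
line = 'import tensorflow.experimental.numpy as tnp'
numpy_line = line.replace('import tensorflow.experimental.numpy as tnp',
                          'import numpy as tnp')
jax_line = ('\n' + numpy_line).replace('\nimport numpy as tnp',
                                       '\nimport jax.numpy as tnp')
print(numpy_line)        # import numpy as tnp
print(jax_line.strip())  # import jax.numpy as tnp
```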