diff --git a/tensorflow_probability/python/stats/sample_stats_test.py b/tensorflow_probability/python/stats/sample_stats_test.py index de1e4bfb96..ce0d91b357 100644 --- a/tensorflow_probability/python/stats/sample_stats_test.py +++ b/tensorflow_probability/python/stats/sample_stats_test.py @@ -699,24 +699,26 @@ def apply_slice_along_axis(self, func, arr, low, high, axis): specified through `low` and `high`. Support broadcasting. """ np.testing.assert_equal(low.shape, high.shape) - ni, _, nk = arr.shape[:axis], arr.shape[axis], arr.shape[axis + 1:] - si, j, sk = low.shape[:axis], low.shape[axis], low.shape[axis + 1:] - mk = max(nk, sk) - mi = max(ni, si) - out = np.empty(mi + (j,) + mk) - for ki in np.ndindex(ni): - for kk in np.ndindex(mk): - ak = tuple(np.mod(kk, nk)) - ik = tuple(np.mod(kk, sk)) - ai = tuple(np.mod(ki, ni)) - ii = tuple(np.mod(ki, si)) - a_1d = arr[ai + np.s_[:, ] + ak] - out_1d = out[ki + np.s_[:, ] + kk] - low_1d = low[ii + np.s_[:, ] + ik] - high_1d = high[ii + np.s_[:, ] + ik] - - for r in range(j): - out_1d[r] = func(a_1d[low_1d[r]:high_1d[r]]) + + def apply_func(vector, l, h): + return func(vector[l:h]) + + apply_func_1d = np.vectorize(apply_func, signature='(n), (), ()->()') + vectorized_func = np.vectorize(apply_func_1d, + signature='(n), (k), (k)->(m)') + + # Put `axis` at the innermost dimension + dims = list(range(arr.ndim)) + dims[-1] = axis + dims[axis] = arr.ndim - 1 + t_arr = np.transpose(arr, axes=dims) + t_low = np.transpose(low, axes=dims) + t_high = np.transpose(high, axes=dims) + + t_out = vectorized_func(t_arr, t_low, t_high) + + # Replace `axis` at its place + out = np.transpose(t_out, axes=dims) return out def check_gaussian_windowed(self, shape, indice_shape, axis, @@ -797,10 +799,6 @@ def test_windowed_mean_graph(self): def test_windowed_variance(self): self.check_windowed(func=sample_stats.windowed_variance, numpy_func=np.var) - def test_windowed_variance_graph(self): - func = tf.function(sample_stats.windowed_variance) - self.check_windowed(func=func, numpy_func=np.var) - @test_util.test_all_tf_execution_regimes class LogAverageProbsTest(test_util.TestCase):