Skip to content

Commit

Permalink
Rewrite apply_slice_along_axis using np.vectorize
Browse files Browse the repository at this point in the history
  • Loading branch information
nicolaspi committed Sep 26, 2022
1 parent 26f4f12 commit 169f7f5
Showing 1 changed file with 20 additions and 22 deletions.
42 changes: 20 additions & 22 deletions tensorflow_probability/python/stats/sample_stats_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -699,24 +699,26 @@ def apply_slice_along_axis(self, func, arr, low, high, axis):
specified through `low` and `high`. Support broadcasting.
"""
np.testing.assert_equal(low.shape, high.shape)
ni, _, nk = arr.shape[:axis], arr.shape[axis], arr.shape[axis + 1:]
si, j, sk = low.shape[:axis], low.shape[axis], low.shape[axis + 1:]
mk = max(nk, sk)
mi = max(ni, si)
out = np.empty(mi + (j,) + mk)
for ki in np.ndindex(ni):
for kk in np.ndindex(mk):
ak = tuple(np.mod(kk, nk))
ik = tuple(np.mod(kk, sk))
ai = tuple(np.mod(ki, ni))
ii = tuple(np.mod(ki, si))
a_1d = arr[ai + np.s_[:, ] + ak]
out_1d = out[ki + np.s_[:, ] + kk]
low_1d = low[ii + np.s_[:, ] + ik]
high_1d = high[ii + np.s_[:, ] + ik]

for r in range(j):
out_1d[r] = func(a_1d[low_1d[r]:high_1d[r]])

def apply_func(vector, l, h):
return func(vector[l:h])

apply_func_1d = np.vectorize(apply_func, signature='(n), (), ()->()')
vectorized_func = np.vectorize(apply_func_1d,
signature='(n), (k), (k)->(m)')

# Put `axis` at the innermost dimension
dims = list(range(arr.ndim))
dims[-1] = axis
dims[axis] = arr.ndim - 1
t_arr = np.transpose(arr, axes=dims)
t_low = np.transpose(low, axes=dims)
t_high = np.transpose(high, axes=dims)

t_out = vectorized_func(t_arr, t_low, t_high)

# Replace `axis` at its place
out = np.transpose(t_out, axes=dims)
return out

def check_gaussian_windowed(self, shape, indice_shape, axis,
Expand Down Expand Up @@ -797,10 +799,6 @@ def test_windowed_mean_graph(self):
def test_windowed_variance(self):
self.check_windowed(func=sample_stats.windowed_variance, numpy_func=np.var)

def test_windowed_variance_graph(self):
func = tf.function(sample_stats.windowed_variance)
self.check_windowed(func=func, numpy_func=np.var)


@test_util.test_all_tf_execution_regimes
class LogAverageProbsTest(test_util.TestCase):
Expand Down

0 comments on commit 169f7f5

Please sign in to comment.