Skip to content

Commit

Permalink
Expose tensorflow.experimental.numpy API to numpy and jax backends
Browse files Browse the repository at this point in the history
  • Loading branch information
nicolaspi committed Sep 26, 2022
1 parent 56c5c16 commit 26f4f12
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 14 deletions.
2 changes: 2 additions & 0 deletions tensorflow_probability/python/internal/backend/jax/rewrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ def main(argv):
if FLAGS.rewrite_numpy_import:
contents = contents.replace('\nimport numpy as np',
'\nimport numpy as onp; import jax.numpy as np')
contents = contents.replace('\nimport numpy as tnp',
'\nimport jax.numpy as tnp')
else:
contents = contents.replace('\nimport numpy as np',
'\nimport numpy as np; onp = np')
Expand Down
14 changes: 4 additions & 10 deletions tensorflow_probability/python/stats/sample_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,7 @@
# Dependency imports
import numpy as np
import tensorflow.compat.v2 as tf

if NUMPY_MODE:
take_along_axis = np.take_along_axis
elif JAX_MODE:
from jax.numpy import take_along_axis
else:
from tensorflow.python.ops.numpy_ops import take_along_axis
import tensorflow.experimental.numpy as tnp

from tensorflow_probability.python.internal import assert_util
from tensorflow_probability.python.internal import distribution_util
Expand Down Expand Up @@ -802,10 +796,10 @@ def windowed_variance(
def index_for_cumulative(indices):
return tf.maximum(indices - 1, 0)
cum_sums = tf.cumsum(x, axis=axis)
sums = take_along_axis(
sums = tnp.take_along_axis(
cum_sums, index_for_cumulative(indices), axis=axis)
cum_variances = cumulative_variance(x, sample_axis=axis)
variances = take_along_axis(
variances = tnp.take_along_axis(
cum_variances, index_for_cumulative(indices), axis=axis)

# This formula is the binary accurate variance merge from [1],
Expand Down Expand Up @@ -906,7 +900,7 @@ def windowed_mean(
paddings = ps.reshape(ps.one_hot(2*axis, depth=2*rank, dtype=tf.int32),
(rank, 2))
cum_sums = ps.pad(raw_cumsum, paddings)
sums = take_along_axis(cum_sums, indices, axis=axis)
sums = tnp.take_along_axis(cum_sums, indices, axis=axis)
counts = ps.cast(indices[1] - indices[0], dtype=sums.dtype)
return tf.math.divide_no_nan(sums[1] - sums[0], counts)

Expand Down
8 changes: 4 additions & 4 deletions tensorflow_probability/python/stats/sample_stats_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -788,17 +788,17 @@ def check_windowed(self, func, numpy_func):
check_fn((64, 4, 8), (2, 4), axis=2)

def test_windowed_mean(self):
self.check_windowed(func=tfp.stats.windowed_mean, numpy_func=np.mean)
self.check_windowed(func=sample_stats.windowed_mean, numpy_func=np.mean)

def test_windowed_mean_graph(self):
func = tf.function(tfp.stats.windowed_mean)
func = tf.function(sample_stats.windowed_mean)
self.check_windowed(func=func, numpy_func=np.mean)

def test_windowed_variance(self):
self.check_windowed(func=tfp.stats.windowed_variance, numpy_func=np.var)
self.check_windowed(func=sample_stats.windowed_variance, numpy_func=np.var)

def test_windowed_variance_graph(self):
func = tf.function(tfp.stats.windowed_variance)
func = tf.function(sample_stats.windowed_variance)
self.check_windowed(func=func, numpy_func=np.var)


Expand Down
1 change: 1 addition & 0 deletions tensorflow_probability/substrates/meta/rewrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
TF_REPLACEMENTS = {
'import tensorflow ':
'from tensorflow_probability.python.internal.backend import numpy ',
'import tensorflow.experimental.numpy as tnp': 'import numpy as tnp',
'import tensorflow.compat.v1':
'from tensorflow_probability.python.internal.backend.numpy.compat '
'import v1',
Expand Down

0 comments on commit 26f4f12

Please sign in to comment.