Skip to content

Commit

Permalink
Check for statically known rank
Browse files Browse the repository at this point in the history
Parametrize tests
  • Loading branch information
nicolaspi committed Sep 26, 2022
1 parent 169f7f5 commit c90e961
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 47 deletions.
6 changes: 6 additions & 0 deletions tensorflow_probability/python/stats/sample_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -915,6 +915,12 @@ def _prepare_window_args(x, low_indices=None, high_indices=None, axis=0):
low_indices = high_indices // 2
else:
low_indices = tf.convert_to_tensor(low_indices)

indices_rank = tf.get_static_value(ps.rank(low_indices))
x_rank = tf.get_static_value(ps.rank(x))
if indices_rank is None or x_rank is None:
raise ValueError("`indices` and `x` ranks must be statically known.")

# Broadcast indices together.
high_indices = high_indices + tf.zeros_like(low_indices)
low_indices = low_indices + tf.zeros_like(high_indices)
Expand Down
104 changes: 57 additions & 47 deletions tensorflow_probability/python/stats/sample_stats_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,14 @@
"""Tests for Sample Stats Ops."""

# Dependency imports
import functools
import itertools

import numpy as np
import tensorflow.compat.v1 as tf1
import tensorflow.compat.v2 as tf
from absl.testing import parameterized
from tensorflow.python.framework.errors_impl import InvalidArgumentError

from tensorflow_probability.python.internal import test_util
from tensorflow_probability.python.stats import sample_stats

Expand Down Expand Up @@ -721,7 +725,8 @@ def apply_func(vector, l, h):
out = np.transpose(t_out, axes=dims)
return out

def check_gaussian_windowed(self, shape, indice_shape, axis,

def check_gaussian_windowed_func(self, shape, indice_shape, axis,
window_func, np_func):
stat_shape = np.array(shape).astype(np.int32)
stat_shape[axis] = 1
Expand Down Expand Up @@ -753,51 +758,56 @@ def check_gaussian_windowed(self, shape, indice_shape, axis,
def _make_dynamic_shape(self, x):
return tf1.placeholder_with_default(x, shape=(None,)*len(x.shape))

def check_windowed(self, func, numpy_func):
check_fn = functools.partial(self.check_gaussian_windowed,
window_func=func, np_func=numpy_func)
check_fn((64, 4, 8), (128, 1, 1), axis=0)
check_fn((64, 4, 8), (32, 1, 1), axis=0)
check_fn((64, 4, 8), (32, 4, 1), axis=0)
check_fn((64, 4, 8), (32, 4, 8), axis=0)
check_fn((64, 4, 8), (64, 4, 8), axis=0)
check_fn((64, 4, 8), (128, 1), axis=0)
check_fn((64, 4, 8), (32,), axis=0)
check_fn((64, 4, 8), (32, 4), axis=0)

check_fn((64, 4, 8), (64, 64, 1), axis=1)
check_fn((64, 4, 8), (1, 64, 1), axis=1)
check_fn((64, 4, 8), (64, 2, 8), axis=1)
check_fn((64, 4, 8), (64, 4, 8), axis=1)
check_fn((64, 4, 8), (16,), axis=1)
check_fn((64, 4, 8), (1, 64), axis=1)

check_fn((64, 4, 8), (64, 4, 64), axis=2)
check_fn((64, 4, 8), (1, 1, 64), axis=2)
check_fn((64, 4, 8), (64, 4, 4), axis=2)
check_fn((64, 4, 8), (1, 1, 4), axis=2)
check_fn((64, 4, 8), (64, 4, 8), axis=2)
check_fn((64, 4, 8), (16,), axis=2)
check_fn((64, 4, 8), (1, 4), axis=2)
check_fn((64, 4, 8), (64, 4), axis=2)

with self.assertRaises(Exception):
# Non broadcastable shapes
check_fn((64, 4, 8), (4, 1, 4), axis=2)

with self.assertRaises(Exception):
# Non broadcastable shapes
check_fn((64, 4, 8), (2, 4), axis=2)

def test_windowed_mean(self):
self.check_windowed(func=sample_stats.windowed_mean, numpy_func=np.mean)

def test_windowed_mean_graph(self):
func = tf.function(sample_stats.windowed_mean)
self.check_windowed(func=func, numpy_func=np.mean)

def test_windowed_variance(self):
self.check_windowed(func=sample_stats.windowed_variance, numpy_func=np.var)
@parameterized.named_parameters(*[(
f"{np_func.__name__} shape={a} indices_shape={b} axis={axis}", a, b, axis,
tf_func, np_func) for a, (b, axis), (tf_func, np_func) in
itertools.product([(64, 4, 8), ],
[((128, 1, 1), 0),
((32, 1, 1), 0),
((32, 4, 1), 0),
((32, 4, 8), 0),
((64, 4, 8), 0),
((128, 1), 0),
((32,), 0),
((32, 4), 0),
((64, 64, 1), 1),
((1, 64, 1), 1),
((64, 2, 8), 1),
((64, 4, 8), 1),
((16,), 1),
((1, 64), 1),
((64, 4, 64), 2),
((1, 1, 64), 2),
((64, 4, 4), 2),
((1, 1, 4), 2),
((64, 4, 8), 2),
((16,), 2),
((1, 4), 2),
((64, 4), 2)],
[
(sample_stats.windowed_mean, np.mean),
(sample_stats.windowed_variance, np.var)
])])
def test_windowed(self, shape, indice_shape, axis, window_func, np_func):
self.check_gaussian_windowed_func(shape, indice_shape, axis, window_func,
np_func)


@parameterized.named_parameters(*[(
f"{np_func.__name__} shape={a} indices_shape={b} axis={axis}", a, b, axis,
tf_func, np_func) for a, (b, axis), (tf_func, np_func) in
itertools.product([(64, 4, 8), ],
[((4, 1, 4), 2), ((2, 4), 2)],
[(sample_stats.windowed_mean, np.mean),
(sample_stats.windowed_variance, np.var)])])
def test_non_broadcastable_shapes(self, shape, indice_shape, axis,
window_func, np_func):
with self.assertRaisesRegexp((IndexError, ValueError, InvalidArgumentError),
'^shape mismatch|Incompatible shapes'):
self.check_gaussian_windowed_func(shape, indice_shape, axis, window_func,
np_func)


@test_util.test_all_tf_execution_regimes
Expand Down

0 comments on commit c90e961

Please sign in to comment.