diff --git a/distrax/_src/distributions/deterministic_test.py b/distrax/_src/distributions/deterministic_test.py index 0f67ee1..beca6c2 100644 --- a/distrax/_src/distributions/deterministic_test.py +++ b/distrax/_src/distributions/deterministic_test.py @@ -20,6 +20,7 @@ import chex from distrax._src.distributions import deterministic from distrax._src.utils import equivalence +import jax.experimental import jax.numpy as jnp import numpy as np @@ -107,10 +108,11 @@ def test_sample_shape(self, loc, sample_shape): ('float32', jnp.float32), ('float64', jnp.float64)) def test_sample_dtype(self, dtype): - dist = self.distrax_cls(loc=jnp.zeros((), dtype=dtype)) - samples = self.variant(dist.sample)(seed=self.key) - self.assertEqual(samples.dtype, dist.dtype) - chex.assert_type(samples, dtype) + with jax.experimental.enable_x64(dtype.dtype.itemsize == 8): + dist = self.distrax_cls(loc=jnp.zeros((), dtype=dtype)) + samples = self.variant(dist.sample)(seed=self.key) + self.assertEqual(samples.dtype, dist.dtype) + chex.assert_type(samples, dtype) @chex.all_variants @parameterized.named_parameters( diff --git a/distrax/_src/distributions/epsilon_greedy_test.py b/distrax/_src/distributions/epsilon_greedy_test.py index 6e99e56..a7ae3a3 100644 --- a/distrax/_src/distributions/epsilon_greedy_test.py +++ b/distrax/_src/distributions/epsilon_greedy_test.py @@ -22,6 +22,7 @@ import chex from distrax._src.distributions import epsilon_greedy from distrax._src.utils import equivalence +import jax.experimental import jax.numpy as jnp import numpy as np @@ -51,11 +52,12 @@ def test_num_categories(self): ('float32', jnp.float32), ('float64', jnp.float64)) def test_sample_dtype(self, dtype): - dist = self.distrax_cls( - preferences=self.preferences, epsilon=self.epsilon, dtype=dtype) - samples = self.variant(dist.sample)(seed=self.key) - self.assertEqual(samples.dtype, dist.dtype) - chex.assert_type(samples, dtype) + with jax.experimental.enable_x64(dtype.dtype.itemsize == 8): + dist = self.distrax_cls( + preferences=self.preferences, epsilon=self.epsilon, dtype=dtype) + samples = self.variant(dist.sample)(seed=self.key) + self.assertEqual(samples.dtype, dist.dtype) + chex.assert_type(samples, dtype) def test_jittable(self): super()._test_jittable( diff --git a/distrax/_src/distributions/gamma_test.py b/distrax/_src/distributions/gamma_test.py index 62f619f..f5def6f 100644 --- a/distrax/_src/distributions/gamma_test.py +++ b/distrax/_src/distributions/gamma_test.py @@ -20,6 +20,7 @@ import chex from distrax._src.distributions import gamma from distrax._src.utils import equivalence +import jax.experimental import jax.numpy as jnp import numpy as np @@ -73,11 +74,12 @@ def test_sample_shape(self, distr_params, sample_shape): ('float32', jnp.float32), ('float64', jnp.float64)) def test_sample_dtype(self, dtype): - dist = self.distrax_cls( - concentration=jnp.ones((), dtype), rate=jnp.ones((), dtype)) - samples = self.variant(dist.sample)(seed=self.key) - self.assertEqual(samples.dtype, dist.dtype) - chex.assert_type(samples, dtype) + with jax.experimental.enable_x64(dtype.dtype.itemsize == 8): + dist = self.distrax_cls( + concentration=jnp.ones((), dtype), rate=jnp.ones((), dtype)) + samples = self.variant(dist.sample)(seed=self.key) + self.assertEqual(samples.dtype, dist.dtype) + chex.assert_type(samples, dtype) @chex.all_variants @parameterized.named_parameters( diff --git a/distrax/_src/distributions/greedy_test.py b/distrax/_src/distributions/greedy_test.py index 3f8280c..6c24139 100644 --- a/distrax/_src/distributions/greedy_test.py +++ b/distrax/_src/distributions/greedy_test.py @@ -20,6 +20,7 @@ import chex from distrax._src.distributions import greedy from distrax._src.utils import equivalence +import jax.experimental import jax.numpy as jnp import numpy as np @@ -48,10 +49,11 @@ def test_num_categories(self): ('float32', jnp.float32), ('float64', jnp.float64)) def test_sample_dtype(self, dtype): - dist = self.distrax_cls(preferences=self.preferences, dtype=dtype) - samples = self.variant(dist.sample)(seed=self.key) - self.assertEqual(samples.dtype, dist.dtype) - chex.assert_type(samples, dtype) + with jax.experimental.enable_x64(dtype.dtype.itemsize == 8): + dist = self.distrax_cls(preferences=self.preferences, dtype=dtype) + samples = self.variant(dist.sample)(seed=self.key) + self.assertEqual(samples.dtype, dist.dtype) + chex.assert_type(samples, dtype) def test_jittable(self): super()._test_jittable((np.array([0., 4., -1., 4.]),)) diff --git a/distrax/_src/distributions/gumbel_test.py b/distrax/_src/distributions/gumbel_test.py index 0dbae60..36dcd34 100644 --- a/distrax/_src/distributions/gumbel_test.py +++ b/distrax/_src/distributions/gumbel_test.py @@ -20,6 +20,7 @@ import chex from distrax._src.distributions import gumbel from distrax._src.utils import equivalence +import jax.experimental import jax.numpy as jnp import numpy as np @@ -67,11 +68,12 @@ def test_sample_shape(self, distr_params, sample_shape): ('float32', jnp.float32), ('float64', jnp.float64)) def test_sample_dtype(self, dtype): - dist = self.distrax_cls( - loc=jnp.zeros((), dtype), scale=jnp.ones((), dtype)) - samples = self.variant(dist.sample)(seed=self.key) - self.assertEqual(samples.dtype, dist.dtype) - chex.assert_type(samples, dtype) + with jax.experimental.enable_x64(dtype.dtype.itemsize == 8): + dist = self.distrax_cls( + loc=jnp.zeros((), dtype), scale=jnp.ones((), dtype)) + samples = self.variant(dist.sample)(seed=self.key) + self.assertEqual(samples.dtype, dist.dtype) + chex.assert_type(samples, dtype) @chex.all_variants @parameterized.named_parameters( diff --git a/distrax/_src/distributions/laplace_test.py b/distrax/_src/distributions/laplace_test.py index 01a71f7..92c4e1c 100644 --- a/distrax/_src/distributions/laplace_test.py +++ b/distrax/_src/distributions/laplace_test.py @@ -20,6 +20,7 @@ import chex from distrax._src.distributions import laplace from distrax._src.utils import equivalence +import jax.experimental import jax.numpy as jnp import numpy as np @@ -65,11 +66,12 @@ def test_sample_shape(self, distr_params, sample_shape): ('float32', jnp.float32), ('float64', jnp.float64)) def test_sample_dtype(self, dtype): - dist = self.distrax_cls( - loc=jnp.zeros((), dtype), scale=jnp.ones((), dtype)) - samples = self.variant(dist.sample)(seed=self.key) - self.assertEqual(samples.dtype, dist.dtype) - chex.assert_type(samples, dtype) + with jax.experimental.enable_x64(dtype.dtype.itemsize == 8): + dist = self.distrax_cls( + loc=jnp.zeros((), dtype), scale=jnp.ones((), dtype)) + samples = self.variant(dist.sample)(seed=self.key) + self.assertEqual(samples.dtype, dist.dtype) + chex.assert_type(samples, dtype) @chex.all_variants @parameterized.named_parameters( diff --git a/distrax/_src/distributions/log_stddev_normal_test.py b/distrax/_src/distributions/log_stddev_normal_test.py index 1b434f7..6611636 100644 --- a/distrax/_src/distributions/log_stddev_normal_test.py +++ b/distrax/_src/distributions/log_stddev_normal_test.py @@ -21,6 +21,7 @@ from distrax._src.distributions import log_stddev_normal as lsn from distrax._src.distributions import normal import jax +import jax.experimental import jax.numpy as jnp import mock import numpy as np @@ -105,11 +106,12 @@ def test_sampling_batched_custom_dim(self): ('float32', jnp.float32), ('float64', jnp.float64)) def test_sample_dtype(self, dtype): - dist = lsn.LogStddevNormal( - loc=jnp.zeros((), dtype), log_scale=jnp.zeros((), dtype)) - samples = self.variant(dist.sample)(seed=jax.random.PRNGKey(0)) - self.assertEqual(samples.dtype, dist.dtype) - chex.assert_type(samples, dtype) + with jax.experimental.enable_x64(dtype.dtype.itemsize == 8): + dist = lsn.LogStddevNormal( + loc=jnp.zeros((), dtype), log_scale=jnp.zeros((), dtype)) + samples = self.variant(dist.sample)(seed=jax.random.PRNGKey(0)) + self.assertEqual(samples.dtype, dist.dtype) + chex.assert_type(samples, dtype) def test_kl_versus_normal(self): loc, scale = jnp.array([2.0]), jnp.array([2.0]) diff --git a/distrax/_src/distributions/logistic_test.py b/distrax/_src/distributions/logistic_test.py index d37cee7..47aaf17 100644 --- a/distrax/_src/distributions/logistic_test.py +++ b/distrax/_src/distributions/logistic_test.py @@ -20,6 +20,7 @@ import chex from distrax._src.distributions import logistic from distrax._src.utils import equivalence +import jax.experimental import jax.numpy as jnp import numpy as np @@ -66,11 +67,12 @@ def test_sample_shape(self, distr_params, sample_shape): ('float32', jnp.float32), ('float64', jnp.float64)) def test_sample_dtype(self, dtype): - dist = self.distrax_cls( - loc=jnp.zeros((), dtype), scale=jnp.ones((), dtype)) - samples = self.variant(dist.sample)(seed=self.key) - self.assertEqual(samples.dtype, dist.dtype) - chex.assert_type(samples, dtype) + with jax.experimental.enable_x64(dtype.dtype.itemsize == 8): + dist = self.distrax_cls( + loc=jnp.zeros((), dtype), scale=jnp.ones((), dtype)) + samples = self.variant(dist.sample)(seed=self.key) + self.assertEqual(samples.dtype, dist.dtype) + chex.assert_type(samples, dtype) @chex.all_variants @parameterized.named_parameters( diff --git a/distrax/_src/distributions/multinomial_test.py b/distrax/_src/distributions/multinomial_test.py index 6739b6a..4498234 100644 --- a/distrax/_src/distributions/multinomial_test.py +++ b/distrax/_src/distributions/multinomial_test.py @@ -22,6 +22,7 @@ from distrax._src.utils import equivalence from distrax._src.utils import math import jax +import jax.experimental import jax.numpy as jnp import numpy as np from scipy import stats @@ -405,12 +406,16 @@ def test_sample_and_log_prob(self, dist_params, sample_shape): ('float32', jnp.float32), ('float64', jnp.float64)) def test_sample_dtype(self, dtype): - dist_params = { - 'logits': self.logits, 'dtype': dtype, 'total_count': self.total_count} - dist = self.distrax_cls(**dist_params) - samples = self.variant(dist.sample)(seed=self.key) - self.assertEqual(samples.dtype, dist.dtype) - chex.assert_type(samples, dtype) + with jax.experimental.enable_x64(dtype.dtype.itemsize == 8): + dist_params = { + 'logits': self.logits, + 'dtype': dtype, + 'total_count': self.total_count, + } + dist = self.distrax_cls(**dist_params) + samples = self.variant(dist.sample)(seed=self.key) + self.assertEqual(samples.dtype, dist.dtype) + chex.assert_type(samples, dtype) @chex.all_variants def test_sample_extreme_probs(self): diff --git a/distrax/_src/distributions/mvn_diag_plus_low_rank_test.py b/distrax/_src/distributions/mvn_diag_plus_low_rank_test.py index 8863e29..759312c 100644 --- a/distrax/_src/distributions/mvn_diag_plus_low_rank_test.py +++ b/distrax/_src/distributions/mvn_diag_plus_low_rank_test.py @@ -22,6 +22,7 @@ from distrax._src.utils import equivalence import jax +import jax.experimental import jax.numpy as jnp import numpy as np from tensorflow_probability.substrates import jax as tfp @@ -180,13 +181,14 @@ def test_sample_shape(self, sample_shape, loc_shape, scale_diag_shape, ('float32', jnp.float32), ('float64', jnp.float64)) def test_sample_dtype(self, dtype): - dist_params = { - 'loc': np.array([0., 0.], dtype), - 'scale_diag': np.array([1., 1.], dtype)} - dist = MultivariateNormalDiagPlusLowRank(**dist_params) - samples = self.variant(dist.sample)(seed=jax.random.PRNGKey(0)) - self.assertEqual(samples.dtype, dist.dtype) - chex.assert_type(samples, dtype) + with jax.experimental.enable_x64(dtype.dtype.itemsize == 8): + dist_params = { + 'loc': np.array([0., 0.], dtype), + 'scale_diag': np.array([1., 1.], dtype)} + dist = MultivariateNormalDiagPlusLowRank(**dist_params) + samples = self.variant(dist.sample)(seed=jax.random.PRNGKey(0)) + self.assertEqual(samples.dtype, dist.dtype) + chex.assert_type(samples, dtype) @chex.all_variants @parameterized.named_parameters( diff --git a/distrax/_src/distributions/mvn_diag_test.py b/distrax/_src/distributions/mvn_diag_test.py index 04a84b0..f4704b0 100644 --- a/distrax/_src/distributions/mvn_diag_test.py +++ b/distrax/_src/distributions/mvn_diag_test.py @@ -22,6 +22,7 @@ from distrax._src.distributions import normal from distrax._src.utils import equivalence import jax +import jax.experimental import jax.numpy as jnp import numpy as np @@ -214,13 +215,14 @@ def test_sample_and_log_prob(self, distr_params, sample_shape): ('float32', jnp.float32), ('float64', jnp.float64)) def test_sample_dtype(self, dtype): - dist_params = { - 'loc': np.array([0., 0.], dtype), - 'scale_diag': np.array([1., 1.], dtype)} - dist = self.distrax_cls(**dist_params) - samples = self.variant(dist.sample)(seed=self.key) - self.assertEqual(samples.dtype, dist.dtype) - chex.assert_type(samples, dtype) + with jax.experimental.enable_x64(dtype.dtype.itemsize == 8): + dist_params = { + 'loc': np.array([0., 0.], dtype), + 'scale_diag': np.array([1., 1.], dtype)} + dist = self.distrax_cls(**dist_params) + samples = self.variant(dist.sample)(seed=self.key) + self.assertEqual(samples.dtype, dist.dtype) + chex.assert_type(samples, dtype) @chex.all_variants @parameterized.named_parameters( diff --git a/distrax/_src/distributions/mvn_full_covariance_test.py b/distrax/_src/distributions/mvn_full_covariance_test.py index b47f6be..04e4d86 100644 --- a/distrax/_src/distributions/mvn_full_covariance_test.py +++ b/distrax/_src/distributions/mvn_full_covariance_test.py @@ -20,6 +20,7 @@ import chex from distrax._src.distributions.mvn_full_covariance import MultivariateNormalFullCovariance from distrax._src.utils import equivalence +import jax.experimental import jax.numpy as jnp import numpy as np @@ -106,13 +107,14 @@ def test_sample_shape(self, sample_shape, loc_shape, covariance_matrix_shape): ('float32', jnp.float32), ('float64', jnp.float64)) def test_sample_dtype(self, dtype): - dist_params = { - 'loc': np.array([0., 0.], dtype), - 'covariance_matrix': np.array([[1., 0.], [0., 1.]], dtype)} - dist = self.distrax_cls(**dist_params) - samples = self.variant(dist.sample)(seed=self.key) - self.assertEqual(samples.dtype, dist.dtype) - chex.assert_type(samples, dtype) + with jax.experimental.enable_x64(dtype.dtype.itemsize == 8): + dist_params = { + 'loc': np.array([0., 0.], dtype), + 'covariance_matrix': np.array([[1., 0.], [0., 1.]], dtype)} + dist = self.distrax_cls(**dist_params) + samples = self.variant(dist.sample)(seed=self.key) + self.assertEqual(samples.dtype, dist.dtype) + chex.assert_type(samples, dtype) @chex.all_variants @parameterized.named_parameters( diff --git a/distrax/_src/distributions/mvn_tri_test.py b/distrax/_src/distributions/mvn_tri_test.py index a2c6a14..b1d7916 100644 --- a/distrax/_src/distributions/mvn_tri_test.py +++ b/distrax/_src/distributions/mvn_tri_test.py @@ -20,6 +20,7 @@ import chex from distrax._src.distributions.mvn_tri import MultivariateNormalTri from distrax._src.utils import equivalence +import jax.experimental import jax.numpy as jnp import numpy as np @@ -114,13 +115,14 @@ def test_sample_shape(self, sample_shape, loc_shape, scale_tri_shape): ('float32', jnp.float32), ('float64', jnp.float64)) def test_sample_dtype(self, dtype): - dist_params = { - 'loc': np.array([0., 0.], dtype), - 'scale_tri': np.array([[1., 0.], [0., 1.]], dtype)} - dist = self.distrax_cls(**dist_params) - samples = self.variant(dist.sample)(seed=self.key) - self.assertEqual(samples.dtype, dist.dtype) - chex.assert_type(samples, dtype) + with jax.experimental.enable_x64(dtype.dtype.itemsize == 8): + dist_params = { + 'loc': np.array([0., 0.], dtype), + 'scale_tri': np.array([[1., 0.], [0., 1.]], dtype)} + dist = self.distrax_cls(**dist_params) + samples = self.variant(dist.sample)(seed=self.key) + self.assertEqual(samples.dtype, dist.dtype) + chex.assert_type(samples, dtype) @chex.all_variants @parameterized.named_parameters( diff --git a/distrax/_src/distributions/one_hot_categorical_test.py b/distrax/_src/distributions/one_hot_categorical_test.py index 6899f6a..d1b7bc7 100644 --- a/distrax/_src/distributions/one_hot_categorical_test.py +++ b/distrax/_src/distributions/one_hot_categorical_test.py @@ -23,6 +23,7 @@ from distrax._src.utils import equivalence from distrax._src.utils import math import jax +import jax.experimental import jax.numpy as jnp import numpy as np import scipy @@ -178,11 +179,12 @@ def test_sample_and_log_prob(self, distr_params, sample_shape): ('float32', jnp.float32), ('float64', jnp.float64)) def test_sample_dtype(self, dtype): - dist_params = {'logits': self.logits, 'dtype': dtype} - dist = self.distrax_cls(**dist_params) - samples = self.variant(dist.sample)(seed=self.key) - self.assertEqual(samples.dtype, dist.dtype) - chex.assert_type(samples, dtype) + with jax.experimental.enable_x64(dtype.dtype.itemsize == 8): + dist_params = {'logits': self.logits, 'dtype': dtype} + dist = self.distrax_cls(**dist_params) + samples = self.variant(dist.sample)(seed=self.key) + self.assertEqual(samples.dtype, dist.dtype) + chex.assert_type(samples, dtype) @chex.all_variants @parameterized.named_parameters( diff --git a/distrax/_src/distributions/softmax_test.py b/distrax/_src/distributions/softmax_test.py index bc1cea4..1580b05 100644 --- a/distrax/_src/distributions/softmax_test.py +++ b/distrax/_src/distributions/softmax_test.py @@ -71,11 +71,12 @@ def test_parameters(self): ('float32', jnp.float32), ('float64', jnp.float64)) def test_sample_dtype(self, dtype): - dist = self.distrax_cls( - logits=self.logits, temperature=self.temperature, dtype=dtype) - samples = self.variant(dist.sample)(seed=self.key) - self.assertEqual(samples.dtype, dist.dtype) - chex.assert_type(samples, dtype) + with jax.experimental.enable_x64(dtype.dtype.itemsize == 8): + dist = self.distrax_cls( + logits=self.logits, temperature=self.temperature, dtype=dtype) + samples = self.variant(dist.sample)(seed=self.key) + self.assertEqual(samples.dtype, dist.dtype) + chex.assert_type(samples, dtype) def test_jittable(self): super()._test_jittable((np.array([2., 4., 1., 3.]),))