Hi, I'm trying to sample from the predictive distribution of a trained Gaussian process with two outputs (`dataset.out_dim = 2`). I'm not specifying an output kernel, so I assume that a multi-output GP with a shared kernel is trained under the hood. Training runs without issues, but sampling from the predictive distribution fails, although the same code works fine in a one-dimensional setting.
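For context, if the two outputs share one kernel and are treated as independent (an assumption about what happens under the hood, not something confirmed here), each output has the same predictive covariance over the test inputs, so joint samples can be drawn with a single Cholesky factor. The sketch below does not use GPJax at all: it assumes the predictive mean is available as an `(n, 2)` array and the shared covariance as an `(n, n)` array, and the function name `sample_shared_kernel_outputs` and the toy RBF covariance are hypothetical.

```python
import jax
import jax.numpy as jnp


def sample_shared_kernel_outputs(key, mean, cov, num_samples, jitter=1e-6):
    """Draw joint samples for several outputs that share one covariance.

    mean: (n, d) predictive means, one column per output (assumed layout).
    cov:  (n, n) predictive covariance, assumed identical for every output
          (outputs treated as independent given the shared kernel).
    Returns an array of shape (num_samples, n, d).
    """
    n, d = mean.shape
    # Cholesky factor of the shared covariance; jitter keeps it numerically PD.
    chol = jnp.linalg.cholesky(cov + jitter * jnp.eye(n, dtype=cov.dtype))
    # Independent standard normals for every sample, input point and output.
    z = jax.random.normal(key, (num_samples, n, d), dtype=mean.dtype)
    # f = mean + L z, applying the same L to each output column of each sample.
    return mean + jnp.einsum("ij,sjd->sid", chol, z)


# Toy usage with a hypothetical RBF covariance on 5 test inputs and 2 outputs.
key = jax.random.PRNGKey(0)
x = jnp.linspace(0.0, 1.0, 5)
cov = jnp.exp(-0.5 * (x[:, None] - x[None, :]) ** 2 / 0.3**2)
mean = jnp.stack([jnp.sin(x), jnp.cos(x)], axis=1)  # shape (5, 2)
samples = sample_shared_kernel_outputs(key, mean, cov, num_samples=3)
print(samples.shape)  # (3, 5, 2)
```

Applying one factor `L` per output column is equivalent to factorizing the block-diagonal covariance `kron(I_2, K)` of the stacked outputs, but avoids building that larger matrix.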
I get the following error:
```
    sample = predictive_dist.sample(sample_shape=(num_samples,), seed=rng_key)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/softs/conda/auto/envs/jax-cpu/lib/python3.11/site-packages/gpjax/distributions.py", line 282, in sample
    sample = self._distribution.sample(seed, sample_shape)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/softs/conda/auto/envs/jax-cpu/lib/python3.11/site-packages/tensorflow_probability/substrates/jax/distributions/distribution.py", line 1205, in sample
    return self._call_sample_n(sample_shape, seed, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/softs/conda/auto/envs/jax-cpu/lib/python3.11/site-packages/tensorflow_probability/substrates/jax/distributions/transformed_distribution.py", line 333, in _call_sample_n
    x = self._maybe_broadcast_distribution_batch_shape().sample(
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/softs/conda/auto/envs/jax-cpu/lib/python3.11/site-packages/tensorflow_probability/substrates/jax/distributions/distribution.py", line 1205, in sample
    return self._call_sample_n(sample_shape, seed, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/softs/conda/auto/envs/jax-cpu/lib/python3.11/site-packages/tensorflow_probability/substrates/jax/distributions/distribution.py", line 1182, in _call_sample_n
    samples = self._sample_n(
              ^^^^^^^^^^^^^^^
  File "/softs/conda/auto/envs/jax-cpu/lib/python3.11/site-packages/tensorflow_probability/substrates/jax/distributions/sample.py", line 224, in _sample_n
    x = self.distribution.sample(ps.concat([[n], sample_shape], axis=0),
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/softs/conda/auto/envs/jax-cpu/lib/python3.11/site-packages/tensorflow_probability/substrates/jax/distributions/distribution.py", line 1205, in sample
    return self._call_sample_n(sample_shape, seed, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/softs/conda/auto/envs/jax-cpu/lib/python3.11/site-packages/tensorflow_probability/substrates/jax/distributions/distribution.py", line 1182, in _call_sample_n
    samples = self._sample_n(
              ^^^^^^^^^^^^^^^
  File "/softs/conda/auto/envs/jax-cpu/lib/python3.11/site-packages/tensorflow_probability/substrates/jax/distributions/normal.py", line 179, in _sample_n
    sampled = samplers.normal(
              ^^^^^^^^^^^^^^^^
  File "/softs/conda/auto/envs/jax-cpu/lib/python3.11/site-packages/tensorflow_probability/substrates/jax/internal/samplers.py", line 271, in normal
    return tf.random.stateless_normal(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/softs/conda/auto/envs/jax-cpu/lib/python3.11/site-packages/tensorflow_probability/python/internal/backend/jax/random_generators.py", line 145, in _normal_jax
    shape = _bcast_shape(shape, [mean, stddev])
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/softs/conda/auto/envs/jax-cpu/lib/python3.11/site-packages/tensorflow_probability/python/internal/backend/jax/random_generators.py", line 65, in _bcast_shape
    bcast_shape = ops.broadcast_shape(bcast_shape, np.asarray(arg).shape)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/softs/conda/auto/envs/jax-cpu/lib/python3.11/site-packages/tensorflow_probability/python/internal/backend/jax/ops.py", line 86, in _broadcast_static_shape
    if (tensor_shape.TensorShape(shape_x).ndims is None or
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/softs/conda/auto/envs/jax-cpu/lib/python3.11/site-packages/tensorflow_probability/python/internal/backend/jax/gen/tensor_shape.py", line 880, in __init__
    self._dims = tuple(as_dimension(d).value for d in dims)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/softs/conda/auto/envs/jax-cpu/lib/python3.11/site-packages/tensorflow_probability/python/internal/backend/jax/gen/tensor_shape.py", line 880, in <genexpr>
    self._dims = tuple(as_dimension(d).value for d in dims)
                       ^^^^^^^^^^^^^^^
  File "/softs/conda/auto/envs/jax-cpu/lib/python3.11/site-packages/tensorflow_probability/python/internal/backend/jax/gen/tensor_shape.py", line 794, in as_dimension
    return Dimension(value)
           ^^^^^^^^^^^^^^^^
  File "/softs/conda/auto/envs/jax-cpu/lib/python3.11/site-packages/tensorflow_probability/python/internal/backend/jax/gen/tensor_shape.py", line 258, in __init__
    raise ValueError("Dimension %d must be >= 0" % value)
ValueError: Dimension -1910093456 must be >= 0
```
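The negative "dimension" in the last line may be a clue. A quick check shows that -1910093456 is exactly the bit pattern of the unsigned 32-bit value 2384873840 read as a signed int32, and legacy JAX PRNG keys are `uint32` arrays. One hypothesis worth checking (an assumption, not confirmed by the traceback alone, and it does not by itself explain why the one-dimensional case works) is that the key ends up being used as a shape: the GPJax frame calls `self._distribution.sample(seed, sample_shape)` positionally, while TensorFlow Probability's `Distribution.sample` takes `sample_shape` first and `seed` second.

```python
import jax
import jax.numpy as jnp

# An unsigned 32-bit value whose bits, reinterpreted as int32, give the
# negative "dimension" reported in the traceback.
raw = jnp.array(2384873840, dtype=jnp.uint32)
as_int32 = jax.lax.bitcast_convert_type(raw, jnp.int32)
print(int(as_int32))  # -1910093456

# Legacy JAX PRNG keys are uint32 arrays, so a key that is accidentally
# treated as a shape can produce large values of this kind.
key = jax.random.PRNGKey(0)
print(key.dtype, key.shape)
```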
This error can be reproduced with the following script:
Is this a bug, or am I doing something wrong?
Thank you!