JointDistributionCoroutineAutoBatched.sample_distributions errors using jax substrate #1976

Open
jeffpollock9 opened this issue Dec 5, 2024 · 0 comments

@jeffpollock9
Contributor

Hi, I found that JointDistributionCoroutineAutoBatched.sample_distributions can raise an error when using the JAX substrate. By contrast, sample works on the same model, the TensorFlow substrate appears to work, and the non-auto-batched JointDistributionCoroutine also works.

Here is a small example:

from functools import partial

import jax
from tensorflow_probability.substrates import jax as tfp

tfd = tfp.distributions

tfp.__version__
# 0.25.0

@partial(tfd.JointDistributionCoroutine, batch_ndims=0)
def joint_dist():
    x = yield tfd.Gamma(2.0, 10.0, name="x")
    y = yield tfd.Gamma(x, 10.0, name="y")


seed = jax.random.key(123)

# ok
dists, samples = joint_dist.sample_distributions(x=[1.0, 2.0], seed=seed)

# samples:
# StructTuple(
#   x=Array([1., 2.], dtype=float32),
#   y=Array([0.10665689, 0.21802416], dtype=float32)
# )


@tfd.JointDistributionCoroutineAutoBatched
def joint_dist():
    x = yield tfd.Gamma(2.0, 10.0, name="x")
    y = yield tfd.Gamma(x, 10.0, name="y")


# ok
samples = joint_dist.sample(x=[1.0, 2.0], seed=seed)

# samples:
# StructTuple(
#   x=Array([1., 2.], dtype=float32),
#   y=Array([0.05508393, 0.14792603], dtype=float32)
# )

# ValueError: Attempt to convert a value (<object object at 0x717aa57398a0>) with an unsupported type (<class 'object'>) to a Tensor.
dists, samples = joint_dist.sample_distributions(x=[1.0, 2.0], seed=seed)
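
For comparison, the model in the reproducer can be sketched in plain JAX without TFP: x ~ Gamma(concentration=2, rate=10) and y | x ~ Gamma(concentration=x, rate=10), so pinning x=[1.0, 2.0] should give a batch of two conditional Gamma distributions for y. This is a minimal sketch, assuming that reading of the model; sample_pinned is a hypothetical helper (not a TFP API), it returns y's conditional log density as a stand-in for the distribution objects, and it does not reproduce TFP's exact random draws.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import gamma as gamma_stats


def sample_pinned(seed, x=None):
    """Hypothetical helper: sample the two-variable Gamma model, optionally pinning x."""
    key_x, key_y = jax.random.split(seed)
    if x is None:
        # jax.random.gamma draws Gamma(a, rate=1); dividing by 10 gives rate=10.
        x = jax.random.gamma(key_x, 2.0) / 10.0
    x = jnp.asarray(x, dtype=jnp.float32)
    # y's concentration is x, so y takes the same shape as the pinned x.
    y = jax.random.gamma(key_y, x) / 10.0
    # Conditional log density of y given x, Gamma(concentration=x, rate=10),
    # using jax.scipy's scale parameterisation (scale = 1 / rate).
    log_prob_y = gamma_stats.logpdf(y, x, scale=0.1)
    return x, y, log_prob_y


seed = jax.random.key(123)
x, y, lp = sample_pinned(seed, x=[1.0, 2.0])
print(x.shape, y.shape)  # (2,) (2,)
```

With x pinned to two values, y and its log density come out with shape (2,), which matches the batch shape seen in the working sample and non-auto-batched sample_distributions calls above.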

Thanks!
