Hi, I found that `sample_distributions` can raise an error when using the JAX substrate. `sample` seems to work, the TF substrate appears to work, and the non-auto-batched joint distribution also works.
Here is a small example:
from functools import partial

import jax
from tensorflow_probability.substrates import jax as tfp

tfd = tfp.distributions
tfp.__version__  # 0.25.0


@partial(tfd.JointDistributionCoroutine, batch_ndims=0)
def joint_dist():
    x = yield tfd.Gamma(2.0, 10.0, name="x")
    y = yield tfd.Gamma(x, 10.0, name="y")


seed = jax.random.key(123)

# ok
dists, samples = joint_dist.sample_distributions(x=[1.0, 2.0], seed=seed)
# samples:
# StructTuple(
#   x=Array([1., 2.], dtype=float32),
#   y=Array([0.10665689, 0.21802416], dtype=float32)
# )


@tfd.JointDistributionCoroutineAutoBatched
def joint_dist():
    x = yield tfd.Gamma(2.0, 10.0, name="x")
    y = yield tfd.Gamma(x, 10.0, name="y")


# ok
samples = joint_dist.sample(x=[1.0, 2.0], seed=seed)
# samples:
# StructTuple(
#   x=Array([1., 2.], dtype=float32),
#   y=Array([0.05508393, 0.14792603], dtype=float32)
# )

# ValueError: Attempt to convert a value (<object object at 0x717aa57398a0>) with an unsupported type (<class 'object'>) to a Tensor.
dists, samples = joint_dist.sample_distributions(x=[1.0, 2.0], seed=seed)
Thanks!