
MultivariateNormalDiag vmap issue #276

Open
haydn-jones opened this issue Jul 15, 2024 · 0 comments

haydn-jones commented Jul 15, 2024

It's unclear to me why the following code does not work, since `MultivariateNormalDiag` supports batch dimensions for `loc` and `scale`:

```python
import distrax as dx
import jax
import jax.numpy as jnp
from jax import vmap


@jax.jit
def build():
    def single(i):
        return dx.MultivariateNormalDiag(jnp.zeros(10), jnp.ones(10))

    x = vmap(single)(jnp.arange(10))
    return x


dist = build()
dist.loc
```

produces the following error:

```
Traceback (most recent call last):
  File ".../test.py", line 17, in <module>
    dist.loc
  File ".../python3.12/site-packages/distrax/_src/distributions/mvn_from_bijector.py", line 103, in loc
    return jnp.broadcast_to(self._loc, shape=shape)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../python3.12/site-packages/jax/_src/numpy/lax_numpy.py", line 2087, in broadcast_to
    return util._broadcast_to(array, shape)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../python3.12/site-packages/jax/_src/numpy/util.py", line 422, in _broadcast_to
    raise ValueError(f"Cannot broadcast to shape with fewer dimensions: {arr_shape=} {shape=}")
ValueError: Cannot broadcast to shape with fewer dimensions: arr_shape=(10, 10) shape=(10,)
```
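One plausible reading of the traceback: the target shape `(10,)` passed to `broadcast_to` was computed when `single` built the distribution from a `(10,)` `loc`, and it was carried through `vmap` as static (non-array) pytree metadata, while the `loc` array itself was stacked to `(10, 10)`. The sketch below reproduces that mismatch in plain JAX. `ToyDiagNormal` is a hypothetical stand-in, not distrax's actual class, and the assumption that distrax caches shapes this way is inferred from the traceback rather than taken from distrax's source.

```python
import jax.numpy as jnp
from jax import vmap
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class ToyDiagNormal:
    """Hypothetical stand-in for a distribution that caches its shape.

    `_loc` is an array leaf, so vmap stacks it; `_shape` is a plain tuple
    kept as static aux data, so vmap passes it through unchanged.
    """

    def __init__(self, loc, shape):
        self._loc = loc
        self._shape = shape

    @classmethod
    def create(cls, loc):
        # The shape is recorded once, from the unbatched loc.
        return cls(loc, tuple(loc.shape))

    @property
    def loc(self):
        # Same pattern as the failing line in the traceback.
        return jnp.broadcast_to(self._loc, shape=self._shape)

    def tree_flatten(self):
        return (self._loc,), self._shape

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(children[0], aux_data)


def single(i):
    return ToyDiagNormal.create(jnp.zeros(10))


# vmap stacks the array leaf to (10, 10) but keeps the cached (10,) shape.
bad = vmap(single)(jnp.arange(10))
print(bad._loc.shape, bad._shape)  # (10, 10) (10,)

failed = False
try:
    bad.loc
except ValueError as err:
    failed = True
    print("ValueError:", err)
```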

This seems similar to #239.

Ah, I see in the README that this distribution is specifically called out for being problematic with vmap.
