
The second example, "Parallel parameter space exploration", is platform-dependent #46

Open
gouchangjiang opened this issue Nov 26, 2023 · 2 comments

Comments

@gouchangjiang

There are some small issues:

  1. The installation command should be `pip install ."[dev]"`; we need the quotes.
  2. The second example, Parallel parameter space exploration, is platform-dependent. I am running it on a 12-core M2 Pro, and it fails because the array sizes do not match. I tried changing the shape to make it divisible by 12, but that causes new problems; details are below. By the way, to disable the GPU on Mac, we can set the environment variable `JAX_PLATFORMS=cpu`.

```
InconclusiveDimensionOperation Traceback (most recent call last)
/Users/cjgou/cjgou/vbjax-example/examples.ipynb Cell 5 line 6
4 log_ks, etas = np.mgrid[-9.0:0.0:16j, -4.0:-6.0:32j]
5 pars = np.c_[np.exp(log_ks.ravel()),np.ones(512)*sig_i, etas.ravel()].T.copy()
----> 6 pars = pars.reshape((3, vb.cores))
7 result = run_batches(pars)
8 pl.imshow(result.reshape((16, 32)), vmin=0.2, vmax=0.7)

File ~/.virtualenvs/vbjax/lib/python3.10/site-packages/jax/_src/numpy/array_methods.py:145, in _reshape(a, order, *args)
143 newshape = _compute_newshape(a, args[0] if len(args) == 1 else args)
144 if order == "C":
--> 145 return lax.reshape(a, newshape, None)
146 elif order == "F":
147 dims = list(range(a.ndim)[::-1])

File ~/.virtualenvs/vbjax/lib/python3.10/site-packages/jax/_src/lax/lax.py:857, in reshape(operand, new_sizes, dimensions)
854 else:
855 dyn_shape, static_new_sizes = _extract_tracers_dyn_shape(new_sizes)
--> 857 return reshape_p.bind(
858 operand, *dyn_shape, new_sizes=tuple(static_new_sizes),
859 dimensions=None if dims is None or same_dims else dims)

File ~/.virtualenvs/vbjax/lib/python3.10/site-packages/jax/_src/core.py:380, in Primitive.bind(self, *args, **params)
377 def bind(self, *args, **params):
...
1850 if sz1 % sz2:
-> 1851 raise InconclusiveDimensionOperation(f"Cannot divide evenly the sizes of shapes {tuple(s1)} and {tuple(s2)}")
1852 return sz1 // sz2

InconclusiveDimensionOperation: Cannot divide evenly the sizes of shapes (3, 512) and (3, 12)
```
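One way to make the batching independent of the core count, sketched here as an assumption and not necessarily how vbjax handles it, is to pad the parameter array up to the next multiple of the core count, reshape, and drop the padded results afterwards. The grid is the one from the traceback (16 x 32 = 512 points). `sig_i` gets a hypothetical value because the issue does not show it, `n_cores` stands in for `vb.cores`, and the notebook's `run_batches` is not reproduced, so the first parameter row stands in for its output. `pad_to_multiple` is a hypothetical helper.

```python
import numpy as np


def pad_to_multiple(pars, n_cores):
    """Hypothetical helper: pad pars (shape (n_params, n_points)) along its last
    axis so that n_points becomes a multiple of n_cores. Padding repeats the
    last column, so every padded point is still a valid parameter set.
    Returns the padded array and the original number of points."""
    n_points = pars.shape[1]
    n_pad = (-n_points) % n_cores  # columns needed to reach the next multiple
    if n_pad:
        extra = np.repeat(pars[:, -1:], n_pad, axis=1)
        pars = np.concatenate([pars, extra], axis=1)
    return pars, n_points


# Parameter grid from the issue: 16 x 32 = 512 points, 3 parameters per point.
sig_i = 0.1  # hypothetical value; the issue does not show sig_i
log_ks, etas = np.mgrid[-9.0:0.0:16j, -4.0:-6.0:32j]
pars = np.c_[np.exp(log_ks.ravel()), np.ones(512) * sig_i, etas.ravel()].T.copy()

n_cores = 12  # stands in for vb.cores on a 12-core M2 Pro
padded, n_points = pad_to_multiple(pars, n_cores)
batched = padded.reshape((3, n_cores, -1))
print(padded.shape, batched.shape)  # (3, 516) (3, 12, 43)

# After running the batches, flatten the per-core results, drop the padded
# tail and restore the grid. The first parameter row stands in for the
# output of the notebook's run_batches, which is not reproduced here.
stand_in_result = batched[0].reshape(-1)[:n_points]
grid = stand_in_result.reshape((16, 32))
print(grid.shape)  # (16, 32)
```

The cost is a few redundant simulations (4 extra points out of 516 on 12 cores), in exchange for keeping every core busy.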

@maedoc
Member

maedoc commented Nov 27, 2023

Hi! Thanks for trying out the examples.

> pip install ."[dev]", we need the quotes

Yeah, I guess only Bash is OK with the unquoted form; neither zsh (the default on macOS) nor cmd on Windows accepts it.
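The quotes matter because the brackets are glob characters: by default zsh reports `no matches found` for an unquoted `.[dev]` that matches no file, while Bash passes the unmatched pattern through unchanged. Any of the usual quoting styles hands pip the same literal argument. Python's standard `shlex` module, which does POSIX-style tokenizing only and no globbing, can illustrate this as a quick sketch:

```python
import shlex

# Tokenize each command line the way a POSIX shell would (quote removal only,
# no glob expansion); all three deliver the same argument, .[dev], to pip.
commands = ['pip install ."[dev]"', 'pip install ".[dev]"', "pip install '.[dev]'"]
for cmd in commands:
    print(shlex.split(cmd))  # ['pip', 'install', '.[dev]']

# When building the command programmatically, shlex.quote adds quotes because
# '[' and ']' are not in its set of shell-safe characters.
print("pip install " + shlex.quote(".[dev]"))  # pip install '.[dev]'
```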

> InconclusiveDimensionOperation: Cannot divide evenly the sizes of shapes (3, 512) and (3, 12)

Indeed, the example code depends on the number of cores present. I'll fix this shortly.
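A simple way to avoid the core-count dependency, sketched here as an assumption rather than the fix that was actually made, is to choose the number of batches as the largest divisor of the number of points that does not exceed the core count. `largest_divisor_at_most` is a hypothetical helper, and the zero array is only a placeholder for the real parameters.

```python
import numpy as np


def largest_divisor_at_most(n_points, n_cores):
    """Hypothetical helper: return the largest d <= n_cores that divides n_points evenly."""
    for d in range(min(n_cores, n_points), 0, -1):
        if n_points % d == 0:
            return d
    raise ValueError("n_points and n_cores must be positive")


n_points = 16 * 32  # grid size from the issue
for n_cores in (8, 10, 12, 16):
    n_batches = largest_divisor_at_most(n_points, n_cores)
    pars = np.zeros((3, n_points))  # placeholder parameter array
    # e.g. prints "12 8 (3, 8, 64)" on a 12-core machine
    print(n_cores, n_batches, pars.reshape((3, n_batches, -1)).shape)
```

Unlike padding, no simulation is wasted on extra points, but some cores may sit idle: 512 is a power of two, so a 12-core machine would use only 8 batches.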

> GPU on Mac, we can set the environment variable JAX_PLATFORMS=cpu

Thanks. Are you using the experimental Metal backend for JAX on your Mac?

@gouchangjiang
Author

As you mentioned, a third of the test cases failed on jax-metal. I tried again with the latest jax-metal 0.0.4 and got the same result: many tests failed. So I switched to the CPU. Since the default is the GPU, disabling it requires setting the environment variable `JAX_PLATFORMS=cpu`.
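The variable can also be set from inside a notebook. This is a minimal sketch, assuming it runs before `jax` is first imported in the kernel; if JAX was already imported, the kernel probably needs a restart for the setting to take effect. `jax.devices()` and `jax.default_backend()` are standard JAX calls, and the fallback branch only exists so the snippet also runs where JAX is not installed.

```python
import os

# Set before importing jax, since JAX reads JAX_PLATFORMS when it initializes.
os.environ["JAX_PLATFORMS"] = "cpu"

try:
    import jax

    print(jax.devices())          # should list only CPU devices
    print(jax.default_backend())  # should report "cpu"
except ImportError:
    print("jax is not installed; JAX_PLATFORMS =", os.environ["JAX_PLATFORMS"])
```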
