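The installation command should be pip install ."[dev]"; the quotes are needed because zsh, the default shell on macOS, otherwise treats the square brackets as a glob pattern and aborts with "no matches found".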
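The following small Python sketch shows why the quotes matter. It uses fnmatch, which shares the shell's bracket syntax, to show what an unquoted .[dev] would match as a glob. It also uses shlex, which follows POSIX shell quoting rules, to show that the quoted form reaches pip as the literal argument .[dev]. This only approximates zsh's behaviour and does not run a real shell.

```python
import fnmatch
import shlex

# Unquoted, "[dev]" is a glob character class, so the pattern ".[dev]"
# matches two-character names such as .d, .e or .v, not the literal text.
# fnmatchcase uses the same bracket syntax (and is case-sensitive everywhere).
candidates = [".d", ".e", ".v", ".[dev]"]
matches = [name for name in candidates if fnmatch.fnmatchcase(name, ".[dev]")]
print(matches)  # → ['.d', '.e', '.v']

# Quoted, the shell strips the quotes and pip receives ".[dev]" verbatim.
argv = shlex.split('pip install ."[dev]"')
print(argv)  # → ['pip', 'install', '.[dev]']
```

When no file such as .d exists, zsh reports "no matches found" instead of passing the pattern through, which is the failure the quotes avoid. Bash passes unmatched patterns through unchanged by default, which is why the unquoted form often works there.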
The second example, Parallel parameter space exploration, is platform-dependent. On my 12-core M2 Pro it fails with a size mismatch: the 512 parameter sets cannot be split evenly across 12 cores. I tried changing the shape to make it divisible by 12, but that introduced new problems. Details are below.
As you mentioned, a third of the test cases fail on jax-metal. I tried again with the latest jax-metal 0.0.4 and got the same result: many tests failed. So I switched to the CPU. On the Mac the default backend is the GPU; to disable it, set the environment variable JAX_PLATFORMS=cpu.
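To force the CPU from inside the notebook, the variable can also be set from Python, as long as that happens before jax is first imported in the kernel. This is a minimal sketch; the device check at the end is only there to confirm that no Metal/GPU device was picked up.

```python
import os

# Must run before the first "import jax" in this process; setting it after
# JAX has initialised its backends has no effect.
os.environ["JAX_PLATFORMS"] = "cpu"

import jax

# Every visible device should now report the "cpu" platform.
devices = jax.devices()
print(devices)
```

If jax was already imported in the running kernel, restart the kernel first, or export JAX_PLATFORMS=cpu in the shell before launching Jupyter.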
These are the small issues listed above. The full traceback for the second one (the reshape error) is:
InconclusiveDimensionOperation Traceback (most recent call last)
/Users/cjgou/cjgou/vbjax-example/examples.ipynb Cell 5 line 6
4 log_ks, etas = np.mgrid[-9.0:0.0:16j, -4.0:-6.0:32j]
5 pars = np.c_[np.exp(log_ks.ravel()),np.ones(512)*sig_i, etas.ravel()].T.copy()
----> 6 pars = pars.reshape((3, vb.cores))
7 result = run_batches(pars)
8 pl.imshow(result.reshape((16, 32)), vmin=0.2, vmax=0.7)
File ~/.virtualenvs/vbjax/lib/python3.10/site-packages/jax/_src/numpy/array_methods.py:145, in _reshape(a, order, *args)
143 newshape = _compute_newshape(a, args[0] if len(args) == 1 else args)
144 if order == "C":
--> 145 return lax.reshape(a, newshape, None)
146 elif order == "F":
147 dims = list(range(a.ndim)[::-1])
File ~/.virtualenvs/vbjax/lib/python3.10/site-packages/jax/_src/lax/lax.py:857, in reshape(operand, new_sizes, dimensions)
854 else:
855 dyn_shape, static_new_sizes = _extract_tracers_dyn_shape(new_sizes)
--> 857 return reshape_p.bind(
858 operand, *dyn_shape, new_sizes=tuple(static_new_sizes),
859 dimensions=None if dims is None or same_dims else dims)
File ~/.virtualenvs/vbjax/lib/python3.10/site-packages/jax/_src/core.py:380, in Primitive.bind(self, *args, **params)
377 def bind(self, *args, **params):
...
1850 if sz1 % sz2:
-> 1851 raise InconclusiveDimensionOperation(f"Cannot divide evenly the sizes of shapes {tuple(s1)} and {tuple(s2)}")
1852 return sz1 // sz2
InconclusiveDimensionOperation: Cannot divide evenly the sizes of shapes (3, 512) and (3, 12)
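One possible workaround, sketched here under assumptions, is to pad the parameter sets up to the next multiple of the core count and discard the extra results afterwards. The sketch assumes run_batches expects an array of shape (3, n_cores, points_per_core), as the reshape with vb.cores suggests. pad_to_cores, unpad and the value of sig_i are hypothetical names and values for illustration only, not part of vbjax. It uses plain NumPy; in the notebook, n_cores would be vb.cores.

```python
import numpy as np

def pad_to_cores(pars, n_cores):
    """Reshape (n_params, n_points) to (n_params, n_cores, per_core), padding if needed.

    Padding repeats the last parameter set, so every padded entry is still a
    valid simulation input; the extra results are dropped again by unpad().
    """
    n_params, n_points = pars.shape
    per_core = -(-n_points // n_cores)             # ceiling division
    n_pad = per_core * n_cores - n_points
    pad = np.repeat(pars[:, -1:], n_pad, axis=1)   # shape (n_params, n_pad)
    padded = np.concatenate([pars, pad], axis=1)
    return padded.reshape((n_params, n_cores, per_core)), n_points

def unpad(result, n_points):
    """Flatten per-core results and keep one value per original parameter set."""
    return np.asarray(result).reshape(-1)[:n_points]

# Same 16 x 32 grid as the notebook; sig_i is a placeholder value here.
sig_i = 0.1
log_ks, etas = np.mgrid[-9.0:0.0:16j, -4.0:-6.0:32j]
pars = np.c_[np.exp(log_ks.ravel()), np.ones(512) * sig_i, etas.ravel()].T.copy()

batched, n_points = pad_to_cores(pars, n_cores=12)
print(batched.shape)  # → (3, 12, 43)

# Stand-in for run_batches: one value per parameter set, shaped (12, 43).
fake_result = batched[2]
grid = unpad(fake_result, n_points).reshape((16, 32))
```

Padding keeps the 16 x 32 grid unchanged on any core count, at the cost of a few redundant simulations (4 extra out of 516 on 12 cores). The alternative is to choose grid dimensions whose product is divisible by the core count.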