diff --git a/brainpy/_src/math/jitconn/_event_matvec.py b/brainpy/_src/math/jitconn/_event_matvec.py index 3671755a..0401151a 100644 --- a/brainpy/_src/math/jitconn/_event_matvec.py +++ b/brainpy/_src/math/jitconn/_event_matvec.py @@ -792,6 +792,8 @@ def event_mv_prob_homo_taichi( """ events = as_jax(events) if isinstance(weight, float): weight = as_jax(weight) + print(weight.shape) + print(weight) weight = jnp.atleast_1d(as_jax(weight)) conn_len = jnp.ceil(1 / conn_prob) * 2 - 1 conn_len = jnp.asarray(jnp.atleast_1d(conn_len), dtype=jnp.int32)