Skip to content

Commit

Permalink
Fix
Browse files Browse the repository at this point in the history
  • Loading branch information
Routhleck committed May 12, 2024
1 parent 5a1a11a commit 211e974
Show file tree
Hide file tree
Showing 26 changed files with 422 additions and 1,895 deletions.
2 changes: 1 addition & 1 deletion brainpy/_src/analysis/highdim/slow_points.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,7 @@ def find_fps_with_gd_method(
"""
# optimization settings
if optimizer is None:
optimizer = optim.Adam(lr=optim.ExponentialDecay(0.2, 1, 0.9999),
optimizer = optim.Adam(lr=optim.ExponentialDecayLR(0.2, 1, 0.9999),
beta1=0.9, beta2=0.999, eps=1e-8)
else:
if not isinstance(optimizer, optim.Optimizer):
Expand Down
4 changes: 2 additions & 2 deletions brainpy/_src/analysis/lowdim/tests/test_bifurcation.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@ def dw(w, t, V, a=0.7, b=0.8):
self.int_V = bp.odeint(dV, method=method)
self.int_w = bp.odeint(dw, method=method)

def update(self, tdi):
t, dt = tdi['t'], tdi['dt']
def update(self):
t, dt = bp.share['t'], bp.share['dt']
self.V.value = self.int_V(self.V, t, self.w, self.Iext, dt)
self.w.value = self.int_w(self.w, t, self.V, self.a, self.b, dt)
self.Iext[:] = 0.
Expand Down
10 changes: 7 additions & 3 deletions brainpy/_src/dependency_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,13 @@ def import_taichi(error_if_not_found=True):

if taichi is None:
return None
if taichi.__version__ != _minimal_taichi_version:
raise RuntimeError(taichi_install_info)
return taichi
taichi_version = taichi.__version__[0] * 10000 + taichi.__version__[1] * 100 + taichi.__version__[2]
minimal_taichi_version = _minimal_taichi_version[0] * 10000 + _minimal_taichi_version[1] * 100 + \
_minimal_taichi_version[2]
if taichi_version >= minimal_taichi_version:
return taichi
else:
raise ModuleNotFoundError(taichi_install_info)


def raise_taichi_not_found(*args, **kwargs):
Expand Down
2 changes: 1 addition & 1 deletion brainpy/_src/dyn/rates/tests/test_nvar.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ class Test_NVAR(parameterized.TestCase):
def test_NVAR(self,mode):
bm.random.seed()
input=bm.random.randn(1,5)
layer=bp.dnn.NVAR(num_in=5,
layer=bp.dyn.NVAR(num_in=5,
delay=10,
mode=mode)
if mode in [bm.NonBatchingMode()]:
Expand Down
4 changes: 2 additions & 2 deletions brainpy/_src/initialize/tests/test_decay_inits.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
# visualization
def mat_visualize(matrix, cmap=None):
if cmap is None:
cmap = plt.cm.get_cmap('coolwarm')
plt.cm.get_cmap('coolwarm')
cmap = plt.colormaps.get_cmap('coolwarm')
plt.colormaps.get_cmap('coolwarm')
im = plt.matshow(matrix, cmap=cmap)
plt.colorbar(mappable=im, shrink=0.8, aspect=15)
plt.show()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,8 @@ def dV(self, V, t, h, n, Iext):

return dVdt

def update(self, tdi):
t, dt = tdi.t, tdi.dt
def update(self):
t, dt = bp.share['t'], bp.share['dt']
V, h, n = self.integral(self.V, self.h, self.n, t, self.input, dt=dt)
self.spike.value = bm.logical_and(self.V < self.V_th, V >= self.V_th)
self.V.value = V
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def test_nodes():
A.pre = B
B.pre = A

net = bp.dyn.Network(A, B)
net = bp.Network(A, B)
abs_nodes = net.nodes(method='absolute')
rel_nodes = net.nodes(method='relative')
print()
Expand Down
4 changes: 2 additions & 2 deletions brainpy/_src/math/object_transform/tests/test_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import brainpy as bp


class GABAa_without_Variable(bp.TwoEndConn):
class GABAa_without_Variable(bp.synapses.TwoEndConn):
def __init__(self, pre, post, conn, delay=0., g_max=0.1, E=-75.,
alpha=12., beta=0.1, T=1.0, T_duration=1.0, **kwargs):
super(GABAa_without_Variable, self).__init__(pre=pre, post=post, **kwargs)
Expand Down Expand Up @@ -192,7 +192,7 @@ def test_neu_nodes_1():
assert len(neu.nodes(method='relative', include_self=False)) == 1


class GABAa_with_Variable(bp.TwoEndConn):
class GABAa_with_Variable(bp.synapses.TwoEndConn):
def __init__(self, pre, post, conn, delay=0., g_max=0.1, E=-75.,
alpha=12., beta=0.1, T=1.0, T_duration=1.0, **kwargs):
super(GABAa_with_Variable, self).__init__(pre=pre, post=post, **kwargs)
Expand Down
2 changes: 1 addition & 1 deletion brainpy/_src/optimizers/tests/test_ModifyLr.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def train_data():
class RNN(bp.DynamicalSystem):
def __init__(self, num_in, num_hidden):
super(RNN, self).__init__()
self.rnn = bp.dnn.RNNCell(num_in, num_hidden, train_state=True)
self.rnn = bp.dyn.RNNCell(num_in, num_hidden, train_state=True)
self.out = bp.dnn.Dense(num_hidden, 1)

def update(self, x):
Expand Down
2 changes: 1 addition & 1 deletion brainpy/_src/train/back_propagation.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def __init__(

# optimizer
if optimizer is None:
lr = optim.ExponentialDecay(lr=0.025, decay_steps=1, decay_rate=0.99975)
lr = optim.ExponentialDecayLR(lr=0.025, decay_steps=1, decay_rate=0.99975)
optimizer = optim.Adam(lr=lr)
self.optimizer: optim.Optimizer = optimizer
if len(self.optimizer.vars_to_train) == 0:
Expand Down
296 changes: 53 additions & 243 deletions docs/quickstart/analysis.ipynb

Large diffs are not rendered by default.

144 changes: 24 additions & 120 deletions docs/tutorial_advanced/advanced_lowdim_analysis.ipynb

Large diffs are not rendered by default.

48 changes: 24 additions & 24 deletions docs/tutorial_advanced/integrate_bp_convlstm_into_flax.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import jax\n",
"import jax.numpy as jnp\n",
Expand All @@ -35,33 +34,33 @@
"\n",
"import brainpy as bp\n",
"import brainpy.math as bm"
]
],
"outputs": []
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"bm.set(mode=bm.training_mode, dt=1.)\n",
"\n",
"bp.__version__"
]
],
"outputs": []
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"num_time = 10"
]
],
"outputs": []
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"# the recurrent cell with trainable parameters\n",
"cell1 = bp.dnn.ToFlaxRNNCell(bp.dyn.Conv2dLSTMCell((28, 28),\n",
Expand All @@ -72,13 +71,13 @@
" in_channels=32,\n",
" out_channels=64,\n",
" kernel_size=(3, 3)))"
]
],
"outputs": []
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"class CNN(nn.Module):\n",
" \"\"\"A simple CNN model.\"\"\"\n",
Expand All @@ -94,13 +93,13 @@
" x = nn.relu(x)\n",
" x = nn.Dense(features=10)(x)\n",
" return x"
]
],
"outputs": []
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"@jax.jit\n",
"def apply_model(state, images, labels):\n",
Expand All @@ -119,24 +118,24 @@
" (loss, logits), grads = grad_fn(state.params)\n",
" accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)\n",
" return grads, loss, accuracy"
]
],
"outputs": []
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"@jax.jit\n",
"def update_model(state, grads):\n",
" return state.apply_gradients(grads=grads)"
]
],
"outputs": []
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"def train_epoch(state, train_ds, batch_size, rng):\n",
" \"\"\"Train for a single epoch.\"\"\"\n",
Expand All @@ -160,13 +159,13 @@
" train_loss = np.mean(epoch_loss)\n",
" train_accuracy = np.mean(epoch_accuracy)\n",
" return state, train_loss, train_accuracy"
]
],
"outputs": []
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"def get_datasets():\n",
" \"\"\"Load MNIST train and test datasets into memory.\"\"\"\n",
Expand All @@ -177,27 +176,27 @@
" train_ds['image'] = jnp.asarray(train_ds['image']) / 255.\n",
" test_ds['image'] = jnp.asarray(test_ds['image']) / 255.\n",
" return train_ds, test_ds"
]
],
"outputs": []
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"def create_train_state(rng, config):\n",
" \"\"\"Creates initial `TrainState`.\"\"\"\n",
" cnn = CNN()\n",
" params = cnn.init(rng, jnp.ones([1, num_time, 28, 28, 1]))['params']\n",
" tx = optax.sgd(config.learning_rate, config.momentum)\n",
" return train_state.TrainState.create(apply_fn=cnn.apply, params=params, tx=tx)"
]
],
"outputs": []
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"def train_and_evaluate(config: ml_collections.ConfigDict,\n",
" workdir: str) -> train_state.TrainState:\n",
Expand Down Expand Up @@ -247,13 +246,13 @@
"\n",
" summary_writer.flush()\n",
" return state"
]
],
"outputs": []
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"config = ml_collections.ConfigDict()\n",
"\n",
Expand All @@ -263,7 +262,8 @@
"config.num_epochs = 10\n",
"\n",
"train_and_evaluate(config, './ckpt')"
]
],
"outputs": []
}
],
"metadata": {
Expand Down
Loading

0 comments on commit 211e974

Please sign in to comment.