From d795517491e85aa082cb202b003ac15f88b6e72a Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Sun, 12 May 2024 16:32:29 +0800 Subject: [PATCH] Revert "fix issue #661 (#662)" This reverts commit 4bd18980c0aa011c024024653405f6376bc5262a. --- .../object_transform/tests/test_autograd.py | 94 +++++++++---------- .../math/object_transform/tests/test_base.py | 21 +++-- requirements-dev.txt | 2 +- 3 files changed, 62 insertions(+), 55 deletions(-) diff --git a/brainpy/_src/math/object_transform/tests/test_autograd.py b/brainpy/_src/math/object_transform/tests/test_autograd.py index 90829d80..1cd7c7cd 100644 --- a/brainpy/_src/math/object_transform/tests/test_autograd.py +++ b/brainpy/_src/math/object_transform/tests/test_autograd.py @@ -1172,52 +1172,52 @@ def f(a, b): -class TestHessian(unittest.TestCase): - def test_hessian5(self): - bm.set_mode(bm.training_mode) - - class RNN(bp.DynamicalSystem): - def __init__(self, num_in, num_hidden): - super(RNN, self).__init__() - self.rnn = bp.dyn.RNNCell(num_in, num_hidden, train_state=True) - self.out = bp.dnn.Dense(num_hidden, 1) - - def update(self, x): - return self.out(self.rnn(x)) - - # define the loss function - def lossfunc(inputs, targets): - runner = bp.DSTrainer(model, progress_bar=False, numpy_mon_after_run=False) - predicts = runner.predict(inputs) - loss = bp.losses.mean_squared_error(predicts, targets) - return loss - - model = RNN(1, 2) - data_x = bm.random.rand(1, 1000, 1) - data_y = data_x + bm.random.randn(1, 1000, 1) - - bp.reset_state(model, 1) - losshess = bm.hessian(lossfunc, grad_vars=model.train_vars()) - hess_matrix = losshess(data_x, data_y) - - weights = model.train_vars().unique() - - # define the loss function - def loss_func_for_jax(weight_vals, inputs, targets): - for k, v in weight_vals.items(): - weights[k].value = v - runner = bp.DSTrainer(model, progress_bar=False, numpy_mon_after_run=False) - predicts = runner.predict(inputs) - loss = bp.losses.mean_squared_error(predicts, targets) - return loss - - bp.reset_state(model, 1) - jax_hessian = jax.hessian(loss_func_for_jax, argnums=0)({k: v.value for k, v in weights.items()}, data_x, data_y) - - for k, v in hess_matrix.items(): - for kk, vv in v.items(): - self.assertTrue(bm.allclose(vv, jax_hessian[k][kk], atol=1e-4)) - - bm.clear_buffer_memory() +# class TestHessian(unittest.TestCase): +# def test_hessian5(self): +# bm.set_mode(bm.training_mode) +# +# class RNN(bp.DynamicalSystem): +# def __init__(self, num_in, num_hidden): +# super(RNN, self).__init__() +# self.rnn = bp.dyn.RNNCell(num_in, num_hidden, train_state=True) +# self.out = bp.dnn.Dense(num_hidden, 1) +# +# def update(self, x): +# return self.out(self.rnn(x)) +# +# # define the loss function +# def lossfunc(inputs, targets): +# runner = bp.DSTrainer(model, progress_bar=False, numpy_mon_after_run=False) +# predicts = runner.predict(inputs) +# loss = bp.losses.mean_squared_error(predicts, targets) +# return loss +# +# model = RNN(1, 2) +# data_x = bm.random.rand(1, 1000, 1) +# data_y = data_x + bm.random.randn(1, 1000, 1) +# +# bp.reset_state(model, 1) +# losshess = bm.hessian(lossfunc, grad_vars=model.train_vars()) +# hess_matrix = losshess(data_x, data_y) +# +# weights = model.train_vars().unique() +# +# # define the loss function +# def loss_func_for_jax(weight_vals, inputs, targets): +# for k, v in weight_vals.items(): +# weights[k].value = v +# runner = bp.DSTrainer(model, progress_bar=False, numpy_mon_after_run=False) +# predicts = runner.predict(inputs) +# loss = bp.losses.mean_squared_error(predicts, targets) +# return loss +# +# bp.reset_state(model, 1) +# jax_hessian = jax.hessian(loss_func_for_jax, argnums=0)({k: v.value for k, v in weights.items()}, data_x, data_y) +# +# for k, v in hess_matrix.items(): +# for kk, vv in v.items(): +# self.assertTrue(bm.allclose(vv, jax_hessian[k][kk], atol=1e-4)) +# +# bm.clear_buffer_memory() diff --git a/brainpy/_src/math/object_transform/tests/test_base.py b/brainpy/_src/math/object_transform/tests/test_base.py index 4e1923e9..a790945f 100644 --- a/brainpy/_src/math/object_transform/tests/test_base.py +++ b/brainpy/_src/math/object_transform/tests/test_base.py @@ -237,12 +237,15 @@ def test1(self): hh = bp.dyn.HH(1) hh.reset() - tree = jax.tree.structure(hh) - leaves = jax.tree.leaves(hh) + tree = jax.tree_structure(hh) + leaves = jax.tree_leaves(hh) + # tree = jax.tree.structure(hh) + # leaves = jax.tree.leaves(hh) print(tree) print(leaves) - print(jax.tree.unflatten(tree, leaves)) + print(jax.tree_unflatten(tree, leaves)) + # print(jax.tree.unflatten(tree, leaves)) print() @@ -281,13 +284,17 @@ def not_close(x, y): def all_close(x, y): assert bm.allclose(x, y) - jax.tree.map(all_close, all_states, variables, is_leaf=bm.is_bp_array) + jax.tree_map(all_close, all_states, variables, is_leaf=bm.is_bp_array) + # jax.tree.map(all_close, all_states, variables, is_leaf=bm.is_bp_array) - random_state = jax.tree.map(bm.random.rand_like, all_states, is_leaf=bm.is_bp_array) - jax.tree.map(not_close, random_state, variables, is_leaf=bm.is_bp_array) + random_state = jax.tree_map(bm.random.rand_like, all_states, is_leaf=bm.is_bp_array) + jax.tree_map(not_close, random_state, variables, is_leaf=bm.is_bp_array) + # random_state = jax.tree.map(bm.random.rand_like, all_states, is_leaf=bm.is_bp_array) + # jax.tree.map(not_close, random_state, variables, is_leaf=bm.is_bp_array) obj.load_state_dict(random_state) - jax.tree.map(all_close, random_state, variables, is_leaf=bm.is_bp_array) + jax.tree_map(all_close, random_state, variables, is_leaf=bm.is_bp_array) + # jax.tree.map(all_close, random_state, variables, is_leaf=bm.is_bp_array) diff --git a/requirements-dev.txt b/requirements-dev.txt index 641f99fd..754073f4 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -6,7 +6,7 @@ matplotlib msgpack tqdm pathos -taichi +taichi==1.7.0 numba braincore braintools