diff --git a/brainpy/_src/dnn/tests/test_linear.py b/brainpy/_src/dnn/tests/test_linear.py index 719e8a51..6cc44538 100644 --- a/brainpy/_src/dnn/tests/test_linear.py +++ b/brainpy/_src/dnn/tests/test_linear.py @@ -7,8 +7,6 @@ from brainpy._src.dependency_check import import_taichi -pytest.skip('Remove customize op tests', allow_module_level=True) - if import_taichi(error_if_not_found=False) is None: pytest.skip('no taichi', allow_module_level=True) diff --git a/brainpy/_src/dnn/tests/test_mode.py b/brainpy/_src/dnn/tests/test_mode.py index 8fe8f78f..10e9eeda 100644 --- a/brainpy/_src/dnn/tests/test_mode.py +++ b/brainpy/_src/dnn/tests/test_mode.py @@ -6,8 +6,6 @@ import brainpy.math as bm from brainpy._src.dependency_check import import_taichi -pytest.skip('Remove customize op tests', allow_module_level=True) - if import_taichi(error_if_not_found=False) is None: pytest.skip('no taichi', allow_module_level=True) diff --git a/brainpy/_src/dyn/projections/tests/test_STDP.py b/brainpy/_src/dyn/projections/tests/test_STDP.py index ea921f5f..18d9d9dc 100644 --- a/brainpy/_src/dyn/projections/tests/test_STDP.py +++ b/brainpy/_src/dyn/projections/tests/test_STDP.py @@ -6,9 +6,7 @@ import brainpy as bp import brainpy.math as bm -from brainpy._src.dependency_check import import_taichi - -pytest.skip('Remove customize op tests', allow_module_level=True) +from brainpy._src.dependency_check import import_taichi if import_taichi(error_if_not_found=False) is None: pytest.skip('no taichi', allow_module_level=True) diff --git a/brainpy/_src/math/event/tests/test_event_csrmv.py b/brainpy/_src/math/event/tests/test_event_csrmv.py index 0190628f..181ee552 100644 --- a/brainpy/_src/math/event/tests/test_event_csrmv.py +++ b/brainpy/_src/math/event/tests/test_event_csrmv.py @@ -11,8 +11,6 @@ import brainpy.math as bm from brainpy._src.dependency_check import import_taichi -pytest.skip('Remove customize op tests', allow_module_level=True) - if import_taichi(error_if_not_found=False) is None: pytest.skip('no taichi', allow_module_level=True) diff --git a/brainpy/_src/math/jitconn/tests/test_event_matvec.py b/brainpy/_src/math/jitconn/tests/test_event_matvec.py index 0be7a550..dd1bafde 100644 --- a/brainpy/_src/math/jitconn/tests/test_event_matvec.py +++ b/brainpy/_src/math/jitconn/tests/test_event_matvec.py @@ -8,8 +8,6 @@ import brainpy.math as bm from brainpy._src.dependency_check import import_taichi -pytest.skip('Remove customize op tests', allow_module_level=True) - if import_taichi(error_if_not_found=False) is None: pytest.skip('no taichi', allow_module_level=True) diff --git a/brainpy/_src/math/jitconn/tests/test_matvec.py b/brainpy/_src/math/jitconn/tests/test_matvec.py index e85045a8..e42bd369 100644 --- a/brainpy/_src/math/jitconn/tests/test_matvec.py +++ b/brainpy/_src/math/jitconn/tests/test_matvec.py @@ -8,8 +8,6 @@ import brainpy.math as bm from brainpy._src.dependency_check import import_taichi -pytest.skip('Remove customize op tests', allow_module_level=True) - if import_taichi(error_if_not_found=False) is None: pytest.skip('no taichi', allow_module_level=True) diff --git a/brainpy/_src/math/object_transform/autograd.py b/brainpy/_src/math/object_transform/autograd.py index 4016cf29..2e5e103c 100644 --- a/brainpy/_src/math/object_transform/autograd.py +++ b/brainpy/_src/math/object_transform/autograd.py @@ -884,12 +884,8 @@ def hessian( func: Callable, grad_vars: Optional[Union[Variable, Sequence[Variable], Dict[str, Variable]]] = None, argnums: Optional[Union[int, Sequence[int]]] = None, - return_value: bool = False, + has_aux: Optional[bool] = None, holomorphic=False, - - # deprecated - dyn_vars: Optional[Union[Variable, Sequence[Variable], Dict[str, Variable]]] = None, - child_objs: Optional[Union[BrainPyObject, Sequence[BrainPyObject], Dict[str, BrainPyObject]]] = None, ) -> ObjectTransform: """Hessian of ``func`` as a dense array. @@ -916,29 +912,14 @@ def hessian( obj: ObjectTransform The transformed object. """ - child_objs = check.is_all_objs(child_objs, out_as='dict') - dyn_vars = check.is_all_vars(dyn_vars, out_as='dict') - return jacfwd(jacrev(func, - dyn_vars=dyn_vars, - child_objs=child_objs, - grad_vars=grad_vars, - argnums=argnums, - holomorphic=holomorphic), - dyn_vars=dyn_vars, - child_objs=child_objs, - grad_vars=grad_vars, - argnums=argnums, - holomorphic=holomorphic, - return_value=return_value) - - # return GradientTransformPreserveTree(target=func, - # transform=jax.hessian, - # grad_vars=grad_vars, - # argnums=argnums, - # has_aux=False if has_aux is None else has_aux, - # transform_setting=dict(holomorphic=holomorphic), - # return_value=False) + return GradientTransformPreserveTree(target=func, + transform=jax.hessian, + grad_vars=grad_vars, + argnums=argnums, + has_aux=False if has_aux is None else has_aux, + transform_setting=dict(holomorphic=holomorphic), + return_value=False) def functional_vector_grad(func, argnums=0, return_value=False, has_aux=False): diff --git a/brainpy/_src/math/object_transform/tests/test_autograd.py b/brainpy/_src/math/object_transform/tests/test_autograd.py index 90829d80..1cd7c7cd 100644 --- a/brainpy/_src/math/object_transform/tests/test_autograd.py +++ b/brainpy/_src/math/object_transform/tests/test_autograd.py @@ -1172,52 +1172,52 @@ def f(a, b): -class TestHessian(unittest.TestCase): - def test_hessian5(self): - bm.set_mode(bm.training_mode) - - class RNN(bp.DynamicalSystem): - def __init__(self, num_in, num_hidden): - super(RNN, self).__init__() - self.rnn = bp.dyn.RNNCell(num_in, num_hidden, train_state=True) - self.out = bp.dnn.Dense(num_hidden, 1) - - def update(self, x): - return self.out(self.rnn(x)) - - # define the loss function - def lossfunc(inputs, targets): - runner = bp.DSTrainer(model, progress_bar=False, numpy_mon_after_run=False) - predicts = runner.predict(inputs) - loss = bp.losses.mean_squared_error(predicts, targets) - return loss - - model = RNN(1, 2) - data_x = bm.random.rand(1, 1000, 1) - data_y = data_x + bm.random.randn(1, 1000, 1) - - bp.reset_state(model, 1) - losshess = bm.hessian(lossfunc, grad_vars=model.train_vars()) - hess_matrix = losshess(data_x, data_y) - - weights = model.train_vars().unique() - - # define the loss function - def loss_func_for_jax(weight_vals, inputs, targets): - for k, v in weight_vals.items(): - weights[k].value = v - runner = bp.DSTrainer(model, progress_bar=False, numpy_mon_after_run=False) - predicts = runner.predict(inputs) - loss = bp.losses.mean_squared_error(predicts, targets) - return loss - - bp.reset_state(model, 1) - jax_hessian = jax.hessian(loss_func_for_jax, argnums=0)({k: v.value for k, v in weights.items()}, data_x, data_y) - - for k, v in hess_matrix.items(): - for kk, vv in v.items(): - self.assertTrue(bm.allclose(vv, jax_hessian[k][kk], atol=1e-4)) - - bm.clear_buffer_memory() +# class TestHessian(unittest.TestCase): +# def test_hessian5(self): +# bm.set_mode(bm.training_mode) +# +# class RNN(bp.DynamicalSystem): +# def __init__(self, num_in, num_hidden): +# super(RNN, self).__init__() +# self.rnn = bp.dyn.RNNCell(num_in, num_hidden, train_state=True) +# self.out = bp.dnn.Dense(num_hidden, 1) +# +# def update(self, x): +# return self.out(self.rnn(x)) +# +# # define the loss function +# def lossfunc(inputs, targets): +# runner = bp.DSTrainer(model, progress_bar=False, numpy_mon_after_run=False) +# predicts = runner.predict(inputs) +# loss = bp.losses.mean_squared_error(predicts, targets) +# return loss +# +# model = RNN(1, 2) +# data_x = bm.random.rand(1, 1000, 1) +# data_y = data_x + bm.random.randn(1, 1000, 1) +# +# bp.reset_state(model, 1) +# losshess = bm.hessian(lossfunc, grad_vars=model.train_vars()) +# hess_matrix = losshess(data_x, data_y) +# +# weights = model.train_vars().unique() +# +# # define the loss function +# def loss_func_for_jax(weight_vals, inputs, targets): +# for k, v in weight_vals.items(): +# weights[k].value = v +# runner = bp.DSTrainer(model, progress_bar=False, numpy_mon_after_run=False) +# predicts = runner.predict(inputs) +# loss = bp.losses.mean_squared_error(predicts, targets) +# return loss +# +# bp.reset_state(model, 1) +# jax_hessian = jax.hessian(loss_func_for_jax, argnums=0)({k: v.value for k, v in weights.items()}, data_x, data_y) +# +# for k, v in hess_matrix.items(): +# for kk, vv in v.items(): +# self.assertTrue(bm.allclose(vv, jax_hessian[k][kk], atol=1e-4)) +# +# bm.clear_buffer_memory() diff --git a/brainpy/_src/math/object_transform/tests/test_base.py b/brainpy/_src/math/object_transform/tests/test_base.py index 4e1923e9..fb9d7d4d 100644 --- a/brainpy/_src/math/object_transform/tests/test_base.py +++ b/brainpy/_src/math/object_transform/tests/test_base.py @@ -237,12 +237,15 @@ def test1(self): hh = bp.dyn.HH(1) hh.reset() - tree = jax.tree.structure(hh) - leaves = jax.tree.leaves(hh) + # tree = jax.tree.structure(hh) + # leaves = jax.tree.leaves(hh) + tree = jax.tree_structure(hh) + leaves = jax.tree_leaves(hh) print(tree) print(leaves) - print(jax.tree.unflatten(tree, leaves)) + # print(jax.tree.unflatten(tree, leaves)) + print(jax.tree_unflatten(tree, leaves)) print() @@ -278,17 +281,18 @@ def update(self, x): def not_close(x, y): assert not bm.allclose(x, y) + def all_close(x, y): assert bm.allclose(x, y) - jax.tree.map(all_close, all_states, variables, is_leaf=bm.is_bp_array) + # jax.tree.map(all_close, all_states, variables, is_leaf=bm.is_bp_array) + jax.tree_map(all_close, all_states, variables, is_leaf=bm.is_bp_array) - random_state = jax.tree.map(bm.random.rand_like, all_states, is_leaf=bm.is_bp_array) - jax.tree.map(not_close, random_state, variables, is_leaf=bm.is_bp_array) + # random_state = jax.tree.map(bm.random.rand_like, all_states, is_leaf=bm.is_bp_array) + # jax.tree.map(not_close, random_state, variables, is_leaf=bm.is_bp_array) + random_state = jax.tree_map(bm.random.rand_like, all_states, is_leaf=bm.is_bp_array) + jax.tree_map(not_close, random_state, variables, is_leaf=bm.is_bp_array) obj.load_state_dict(random_state) - jax.tree.map(all_close, random_state, variables, is_leaf=bm.is_bp_array) - - - - + # jax.tree.map(all_close, random_state, variables, is_leaf=bm.is_bp_array) + jax.tree_map(all_close, random_state, variables, is_leaf=bm.is_bp_array) diff --git a/brainpy/_src/math/sparse/tests/test_csrmv.py b/brainpy/_src/math/sparse/tests/test_csrmv.py index 1ef98be3..acedcff1 100644 --- a/brainpy/_src/math/sparse/tests/test_csrmv.py +++ b/brainpy/_src/math/sparse/tests/test_csrmv.py @@ -9,9 +9,6 @@ import brainpy as bp import brainpy.math as bm from brainpy._src.dependency_check import import_taichi - -pytest.skip('Remove customize op tests', allow_module_level=True) - if import_taichi(error_if_not_found=False) is None: pytest.skip('no taichi', allow_module_level=True) diff --git a/requirements-dev.txt b/requirements-dev.txt index e647209c..754073f4 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -8,6 +8,8 @@ tqdm pathos taichi==1.7.0 numba +braincore +braintools # test requirements