From 4e3151e8ce36b803404b1e87b02309649817e516 Mon Sep 17 00:00:00 2001
From: He Sichao <1310722434@qq.com>
Date: Fri, 10 May 2024 18:45:00 +0800
Subject: [PATCH] Update autograd.py

---
 .../_src/math/object_transform/autograd.py | 35 ++++++++++++++-----
 1 file changed, 27 insertions(+), 8 deletions(-)

diff --git a/brainpy/_src/math/object_transform/autograd.py b/brainpy/_src/math/object_transform/autograd.py
index 2e5e103c..4016cf29 100644
--- a/brainpy/_src/math/object_transform/autograd.py
+++ b/brainpy/_src/math/object_transform/autograd.py
@@ -884,8 +884,12 @@ def hessian(
     func: Callable,
     grad_vars: Optional[Union[Variable, Sequence[Variable], Dict[str, Variable]]] = None,
     argnums: Optional[Union[int, Sequence[int]]] = None,
-    has_aux: Optional[bool] = None,
+    return_value: bool = False,
     holomorphic=False,
+
+    # deprecated
+    dyn_vars: Optional[Union[Variable, Sequence[Variable], Dict[str, Variable]]] = None,
+    child_objs: Optional[Union[BrainPyObject, Sequence[BrainPyObject], Dict[str, BrainPyObject]]] = None,
 ) -> ObjectTransform:
   """Hessian of ``func`` as a dense array.
 
@@ -912,14 +916,29 @@
   obj: ObjectTransform
     The transformed object.
   """
+  child_objs = check.is_all_objs(child_objs, out_as='dict')
+  dyn_vars = check.is_all_vars(dyn_vars, out_as='dict')
 
-  return GradientTransformPreserveTree(target=func,
-                                       transform=jax.hessian,
-                                       grad_vars=grad_vars,
-                                       argnums=argnums,
-                                       has_aux=False if has_aux is None else has_aux,
-                                       transform_setting=dict(holomorphic=holomorphic),
-                                       return_value=False)
+  return jacfwd(jacrev(func,
+                       dyn_vars=dyn_vars,
+                       child_objs=child_objs,
+                       grad_vars=grad_vars,
+                       argnums=argnums,
+                       holomorphic=holomorphic),
+                dyn_vars=dyn_vars,
+                child_objs=child_objs,
+                grad_vars=grad_vars,
+                argnums=argnums,
+                holomorphic=holomorphic,
+                return_value=return_value)
+
+  # return GradientTransformPreserveTree(target=func,
+  #                                      transform=jax.hessian,
+  #                                      grad_vars=grad_vars,
+  #                                      argnums=argnums,
+  #                                      has_aux=False if has_aux is None else has_aux,
+  #                                      transform_setting=dict(holomorphic=holomorphic),
+  #                                      return_value=False)
 
 
 def functional_vector_grad(func, argnums=0, return_value=False, has_aux=False):
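
Note (not part of the patch): the replacement builds the dense Hessian by composing a forward-mode Jacobian over a reverse-mode Jacobian. A minimal pure-JAX sketch of that forward-over-reverse composition is below; the function `f` and the input `x` are illustrative choices, not taken from the BrainPy code.

import jax
import jax.numpy as jnp

# Illustrative scalar-valued function; any R^n -> R function works here.
def f(x):
    return jnp.sum(jnp.sin(x) ** 2)

# Forward-over-reverse: outer forward-mode Jacobian of the inner
# reverse-mode gradient, which yields the dense Hessian matrix.
hess_fn = jax.jacfwd(jax.jacrev(f))

x = jnp.arange(3.0)
print(hess_fn(x))           # (3, 3) dense Hessian
print(jax.hessian(f)(x))    # same result via jax.hessian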