fix issue #661 (#662)

Merged
merged 3 commits on Apr 14, 2024

Changes from all commits
247 changes: 213 additions & 34 deletions brainpy/_src/math/object_transform/autograd.py
@@ -679,16 +679,213 @@ def jacfwd(
transform_setting=dict(holomorphic=holomorphic))


def _functional_hessian(
    fun: Callable,
    argnums: Optional[Union[int, Sequence[int]]] = None,
    has_aux: bool = False,
    holomorphic: bool = False,
):
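  # Forward-over-reverse composition: the Hessian of ``fun`` is the forward-mode
  # Jacobian of its reverse-mode Jacobian (gradient).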
  return _jacfwd(
    _jacrev(fun, argnums, has_aux=has_aux, holomorphic=holomorphic),
    argnums, has_aux=has_aux, holomorphic=holomorphic
  )


class GradientTransformPreserveTree(ObjectTransform):
  """
  Object-oriented automatic differentiation transformation in BrainPy that
  preserves the PyTree structure of the given gradient variables.
  """

  def __init__(
      self,
      target: Callable,
      transform: Callable,

      # variables and nodes
      grad_vars: Dict[str, Variable],

      # gradient setting
      argnums: Optional[Union[int, Sequence[int]]],
      return_value: bool,
      has_aux: bool,
      transform_setting: Optional[Dict[str, Any]] = None,

      # other
      name: str = None,
  ):
    super().__init__(name=name)

    # gradient variables
    if grad_vars is None:
      grad_vars = dict()
    assert isinstance(grad_vars, dict), 'grad_vars should be a dict'
    new_grad_vars = {}
    for k, v in grad_vars.items():
      assert isinstance(v, Variable)
      new_grad_vars[k] = v
    self._grad_vars = new_grad_vars

    # parameters
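    # NOTE: the wrapped function receives two extra leading positional arguments
    # (the dict of grad-variable values and the dict of dynamical-variable values),
    # so user-supplied ``argnums`` are shifted by 2, and position 0 selects the
    # grad variables whenever any are given.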
    if argnums is None and len(self._grad_vars) == 0:
      argnums = 0
    if argnums is None:
      assert len(self._grad_vars) > 0
      _argnums = 0
    elif isinstance(argnums, int):
      _argnums = (0, argnums + 2) if len(self._grad_vars) > 0 else (argnums + 2)
    else:
      _argnums = check.is_sequence(argnums, elem_type=int, allow_none=False)
      _argnums = tuple(a + 2 for a in _argnums)
      if len(self._grad_vars) > 0:
        _argnums = (0,) + _argnums
    self._nonvar_argnums = argnums
    self._argnums = _argnums
    self._return_value = return_value
    self._has_aux = has_aux

    # target
    self.target = target

    # transform
    self._eval_dyn_vars = False
    self._grad_transform = transform
    self._dyn_vars = VariableStack()
    self._transform = None
    self._grad_setting = dict() if transform_setting is None else transform_setting
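    # The underlying JAX transform is always built with ``has_aux=True`` because the
    # wrappers below return ``(value, (outputs, grad_values, dyn_values))`` so that
    # updated variable states can be written back after the transform runs.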
    if self._has_aux:
      self._transform = self._grad_transform(
        self._f_grad_with_aux_to_transform,
        argnums=self._argnums,
        has_aux=True,
        **self._grad_setting
      )
    else:
      self._transform = self._grad_transform(
        self._f_grad_without_aux_to_transform,
        argnums=self._argnums,
        has_aux=True,
        **self._grad_setting
      )

  def _f_grad_with_aux_to_transform(self,
                                    grad_values: dict,
                                    dyn_values: dict,
                                    *args,
                                    **kwargs):
    for k in dyn_values.keys():
      self._dyn_vars[k]._value = dyn_values[k]
    for k, v in grad_values.items():
      self._grad_vars[k]._value = v
    # Users should return the auxiliary data like::
    # >>> # 1. example of return one data
    # >>> return scalar_loss, data
    # >>> # 2. example of return multiple data
    # >>> return scalar_loss, (data1, data2, ...)
    outputs = self.target(*args, **kwargs)
    # outputs: [0] is the value for gradient,
    #          [1] is other values for return
    output0 = tree_map(lambda a: (a.value if isinstance(a, Array) else a), outputs[0])
    return output0, (outputs, {k: v for k, v in self._grad_vars.items()}, self._dyn_vars.dict_data())

  def _f_grad_without_aux_to_transform(self,
                                       grad_values: dict,
                                       dyn_values: dict,
                                       *args,
                                       **kwargs):
    for k in dyn_values.keys():
      self._dyn_vars[k].value = dyn_values[k]
    for k, v in grad_values.items():
      self._grad_vars[k].value = v
    # Users should return the scalar value like this::
    # >>> return scalar_loss
    output = self.target(*args, **kwargs)
    output0 = tree_map(lambda a: (a.value if isinstance(a, Array) else a), output)
    return output0, (output, {k: v.value for k, v in self._grad_vars.items()}, self._dyn_vars.dict_data())

  def __repr__(self):
    name = self.__class__.__name__
    f = tools.repr_object(self.target)
    f = tools.repr_context(f, " " * (len(name) + 6))
    format_ref = (f'{name}({self.name}, target={f}, \n' +
                  f'{" " * len(name)} num_of_grad_vars={len(self._grad_vars)}, \n'
                  f'{" " * len(name)} num_of_dyn_vars={len(self._dyn_vars)})')
    return format_ref

  def _return(self, rets):
    grads, (outputs, new_grad_vs, new_dyn_vs) = rets
    for k, v in new_grad_vs.items():
      self._grad_vars[k].value = v
    for k in new_dyn_vs.keys():
      self._dyn_vars[k].value = new_dyn_vs[k]

    # check returned grads
    if len(self._grad_vars) > 0:
      if self._nonvar_argnums is None:
        pass
      else:
        arg_grads = grads[1] if isinstance(self._nonvar_argnums, int) else grads[1:]
        grads = (grads[0], arg_grads)

    # check returned value
    if self._return_value:
      # check aux
      if self._has_aux:
        return grads, outputs[0], outputs[1]
      else:
        return grads, outputs
    else:
      # check aux
      if self._has_aux:
        return grads, outputs[1]
      else:
        return grads

  def __call__(self, *args, **kwargs):
    if jax.config.jax_disable_jit:  # disable JIT
      rets = self._transform(
        {k: v.value for k, v in self._grad_vars.items()},  # variables for gradients
        self._dyn_vars.dict_data(),  # dynamical variables
        *args,
        **kwargs
      )
      return self._return(rets)

    elif not self._eval_dyn_vars:  # evaluate dynamical variables
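      # Trace the target once with ``eval_shape`` (abstract values only) to discover
      # which Variables it touches; the collected VariableStack is cached per target.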
      stack = get_stack_cache(self.target)
      if stack is None:
        with VariableStack() as stack:
          rets = eval_shape(
            self._transform,
            {k: v.value for k, v in self._grad_vars.items()},  # variables for gradients
            {},  # dynamical variables
            *args,
            **kwargs
          )
        cache_stack(self.target, stack)

      self._dyn_vars = stack
      self._dyn_vars.remove_by_id(*[id(v) for v in self._grad_vars.values()])
      self._eval_dyn_vars = True

      # if not the outermost transformation
      if not stack.is_first_stack():
        return self._return(rets)

    rets = self._transform(
      {k: v.value for k, v in self._grad_vars.items()},  # variables for gradients
      self._dyn_vars.dict_data(),  # dynamical variables
      *args,
      **kwargs
    )
    return self._return(rets)


def hessian(
    func: Callable,
    grad_vars: Optional[Union[Variable, Sequence[Variable], Dict[str, Variable]]] = None,
    argnums: Optional[Union[int, Sequence[int]]] = None,
    return_value: bool = False,
    has_aux: Optional[bool] = None,
    holomorphic=False,

    # deprecated
    dyn_vars: Optional[Union[Variable, Sequence[Variable], Dict[str, Variable]]] = None,
    child_objs: Optional[Union[BrainPyObject, Sequence[BrainPyObject], Dict[str, BrainPyObject]]] = None,
) -> ObjectTransform:
  """Hessian of ``func`` as a dense array.

@@ -705,42 +902,24 @@ def hessian(
    Specifies which positional argument(s) to differentiate with respect to (default ``0``).
  holomorphic : bool
    Indicates whether ``fun`` is promised to be holomorphic. Default False.
  return_value : bool
    Whether to return the value of ``func`` together with the Hessian.
  dyn_vars : optional, ArrayType, sequence of ArrayType, dict
    The dynamically changed variables used in ``func``.

    .. deprecated:: 2.4.0
       No longer need to provide ``dyn_vars``. This function is capable of automatically
       collecting the dynamical variables used in the target ``func``.
  child_objs : optional, BrainPyObject, sequence, dict

    .. versionadded:: 2.3.1

    .. deprecated:: 2.4.0
       No longer need to provide ``child_objs``. This function is capable of automatically
       collecting the children objects used in the target ``func``.
  has_aux : bool, optional
    Indicates whether ``fun`` returns a pair where the first element is
    considered the output of the mathematical function to be differentiated
    and the second element is auxiliary data. Default False.

  Returns
  -------
  obj: ObjectTransform
    The transformed object.
  """
  child_objs = check.is_all_objs(child_objs, out_as='dict')
  dyn_vars = check.is_all_vars(dyn_vars, out_as='dict')

  return jacfwd(jacrev(func,
                       dyn_vars=dyn_vars,
                       child_objs=child_objs,
                       grad_vars=grad_vars,
                       argnums=argnums,
                       holomorphic=holomorphic),
                dyn_vars=dyn_vars,
                child_objs=child_objs,
                grad_vars=grad_vars,
                argnums=argnums,
                holomorphic=holomorphic,
                return_value=return_value)
  return GradientTransformPreserveTree(target=func,
                                       transform=jax.hessian,
                                       grad_vars=grad_vars,
                                       argnums=argnums,
                                       has_aux=False if has_aux is None else has_aux,
                                       transform_setting=dict(holomorphic=holomorphic),
                                       return_value=False)


def functional_vector_grad(func, argnums=0, return_value=False, has_aux=False):
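For context, a minimal usage sketch of the new grad_vars-aware Hessian transform (the names model, lossfunc, data_x, and data_y are placeholders; the pattern mirrors the test added below):

    import brainpy.math as bm

    # Hessian of a scalar loss with respect to the model's trainable variables.
    # The result keeps the dict structure of model.train_vars(): hess[k1][k2]
    # holds the second derivatives coupling variables k1 and k2.
    losshess = bm.hessian(lossfunc, grad_vars=model.train_vars())
    hess = losshess(data_x, data_y)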
50 changes: 50 additions & 0 deletions brainpy/_src/math/object_transform/tests/test_autograd.py
@@ -1171,3 +1171,53 @@ def f(a, b):
    self.assertTrue(file.read().strip() == expect_res.strip())



class TestHessian(unittest.TestCase):
  def test_hessian5(self):
    bm.set_mode(bm.training_mode)

    class RNN(bp.DynamicalSystem):
      def __init__(self, num_in, num_hidden):
        super(RNN, self).__init__()
        self.rnn = bp.dyn.RNNCell(num_in, num_hidden, train_state=True)
        self.out = bp.dnn.Dense(num_hidden, 1)

      def update(self, x):
        return self.out(self.rnn(x))

    # define the loss function
    def lossfunc(inputs, targets):
      runner = bp.DSTrainer(model, progress_bar=False, numpy_mon_after_run=False)
      predicts = runner.predict(inputs)
      loss = bp.losses.mean_squared_error(predicts, targets)
      return loss

    model = RNN(1, 2)
    data_x = bm.random.rand(1, 1000, 1)
    data_y = data_x + bm.random.randn(1, 1000, 1)

    bp.reset_state(model, 1)
    losshess = bm.hessian(lossfunc, grad_vars=model.train_vars())
    hess_matrix = losshess(data_x, data_y)

    weights = model.train_vars().unique()

    # define the same loss as a pure function of the weight values, for the JAX reference
    def loss_func_for_jax(weight_vals, inputs, targets):
      for k, v in weight_vals.items():
        weights[k].value = v
      runner = bp.DSTrainer(model, progress_bar=False, numpy_mon_after_run=False)
      predicts = runner.predict(inputs)
      loss = bp.losses.mean_squared_error(predicts, targets)
      return loss

    bp.reset_state(model, 1)
    jax_hessian = jax.hessian(loss_func_for_jax, argnums=0)({k: v.value for k, v in weights.items()}, data_x, data_y)

    for k, v in hess_matrix.items():
      for kk, vv in v.items():
        self.assertTrue(bm.allclose(vv, jax_hessian[k][kk], atol=1e-4))

    bm.clear_buffer_memory()


14 changes: 7 additions & 7 deletions brainpy/_src/math/object_transform/tests/test_base.py
@@ -237,12 +237,12 @@ def test1(self):
    hh = bp.dyn.HH(1)
    hh.reset()

    tree = jax.tree_structure(hh)
    leaves = jax.tree_leaves(hh)
    tree = jax.tree.structure(hh)
    leaves = jax.tree.leaves(hh)

    print(tree)
    print(leaves)
    print(jax.tree_unflatten(tree, leaves))
    print(jax.tree.unflatten(tree, leaves))
    print()


@@ -281,13 +281,13 @@ def not_close(x, y):
    def all_close(x, y):
      assert bm.allclose(x, y)

    jax.tree_map(all_close, all_states, variables, is_leaf=bm.is_bp_array)
    jax.tree.map(all_close, all_states, variables, is_leaf=bm.is_bp_array)

    random_state = jax.tree_map(bm.random.rand_like, all_states, is_leaf=bm.is_bp_array)
    jax.tree_map(not_close, random_state, variables, is_leaf=bm.is_bp_array)
    random_state = jax.tree.map(bm.random.rand_like, all_states, is_leaf=bm.is_bp_array)
    jax.tree.map(not_close, random_state, variables, is_leaf=bm.is_bp_array)

    obj.load_state_dict(random_state)
    jax.tree_map(all_close, random_state, variables, is_leaf=bm.is_bp_array)
    jax.tree.map(all_close, random_state, variables, is_leaf=bm.is_bp_array)



2 changes: 2 additions & 0 deletions requirements-dev.txt
@@ -8,6 +8,8 @@ tqdm
pathos
taichi
numba
braincore
braintools


# test requirements