Commit
Fix master branch ci
Routhleck committed May 12, 2024
1 parent b06d80a commit 5a1a11a
Showing 3 changed files with 270 additions and 46 deletions.
23 changes: 12 additions & 11 deletions brainpy/_src/dnn/interoperation_flax.py
@@ -1,7 +1,7 @@
 
 import jax
 import dataclasses
-from typing import Dict
+from typing import Dict, Tuple
 from jax.tree_util import tree_flatten, tree_map, tree_unflatten
 
 from brainpy import math as bm
@@ -77,16 +77,16 @@ class ToFlaxRNNCell(RNNCellBase):
     model: DynamicalSystem
     train_params: Dict[str, jax.Array] = dataclasses.field(init=False)
 
-    def initialize_carry(self, rng, batch_dims, size=None, init_fn=None):
-      if len(batch_dims) == 0:
-        batch_dims = 1
-      elif len(batch_dims) == 1:
+    def initialize_carry(self, rng, input_shape: Tuple[int, ...]):
+      batch_dims = input_shape[:-1]
+      if len(batch_dims) == 1:
         batch_dims = batch_dims[0]
+      elif len(batch_dims) == 0:
+        batch_dims = None
       else:
-        raise NotImplementedError
-
+        raise ValueError(f'Invalid input shape: {input_shape}')
       _state_vars = self.model.vars().unique().not_subset(bm.TrainVar)
-      self.model.reset(batch_size=batch_dims)
+      self.model.reset(batch_dims)
       return [_state_vars.dict(), 0, 0.]
 
     def setup(self):
@@ -131,6 +131,9 @@ def __call__(self, carry, *inputs):
       # carry and output
       return [_state_vars.dict(), i + 1, t + share.dt], out
 
+    @property
+    def num_feature_axes(self) -> int:
+      return 1
 
 else:
   class ToFlaxRNNCell(object):
@@ -140,6 +143,4 @@ def __init__(self, *args, **kwargs):
       raise ModuleNotFoundError('"flax" is not installed, or importing "flax" has errors. Please check.')
 
 
-ToFlax = ToFlaxRNNCell
-
-
+ToFlax = ToFlaxRNNCell
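
Editor's note: the changes above track the newer Flax ``RNNCellBase`` API (flax >= 0.7), in which ``nn.RNN`` calls ``cell.initialize_carry(rng, input_shape)`` with ``input_shape = (*batch_dims, features)`` and consults ``cell.num_feature_axes`` to locate the feature axes. A minimal sketch of that protocol (illustrative only, not part of this commit):

import flax.linen as nn
import jax.numpy as jnp

class MinimalCell(nn.RNNCellBase):
  features: int = 4

  @nn.compact
  def __call__(self, carry, x):
    h = nn.Dense(self.features)(x) + carry  # toy recurrence
    return h, h

  def initialize_carry(self, rng, input_shape):
    # input_shape is (*batch_dims, features); the carry is allocated from
    # the batch dims plus this cell's own feature size
    return jnp.zeros(input_shape[:-1] + (self.features,))

  @property
  def num_feature_axes(self) -> int:
    return 1

Wrapped in ``nn.RNN(MinimalCell())``, such a cell is scanned over the time axis, which is exactly how ``ToFlaxRNNCell`` is exercised in the new test below.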
44 changes: 44 additions & 0 deletions brainpy/_src/dnn/tests/test_flax.py
@@ -0,0 +1,44 @@
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================


import pytest

pytest.skip('Skip this test because it is not implemented yet.',
            allow_module_level=True)

import jax
import jax.numpy as jnp
import flax.linen as nn

import brainpy as bp
import brainpy.math as bm

bm.set_platform('cpu')
bm.set_mode(bm.training_mode)

cell = bp.dnn.ToFlaxRNNCell(bp.dyn.RNNCell(num_in=1, num_out=1, ))


class myRNN(nn.Module):
  @nn.compact
  def __call__(self, x):  # x: (batch, time, features)
    x = nn.RNN(cell)(x)  # use nn.RNN to unroll the recurrent cell
    return x


def test_init():
  model = myRNN()
  model.init(jax.random.PRNGKey(0), jnp.ones([1, 10, 1]))  # (batch, time, feature)
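
Editor's note: the module-level ``pytest.skip`` above means nothing in this file runs in CI yet. A hypothetical follow-up test (illustrative only, not part of this commit) would also apply the initialized model:

def test_apply():
  model = myRNN()
  x = jnp.ones([1, 10, 1])  # (batch, time, features)
  variables = model.init(jax.random.PRNGKey(0), x)
  y = model.apply(variables, x)  # nn.RNN unrolls the cell over the time axis
  assert y.shape == (1, 10, 1)  # num_out == 1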
249 changes: 214 additions & 35 deletions brainpy/_src/math/object_transform/autograd.py
@@ -679,16 +679,213 @@ def jacfwd(
                      transform_setting=dict(holomorphic=holomorphic))


def _functional_hessian(
    fun: Callable,
    argnums: Optional[Union[int, Sequence[int]]] = None,
    has_aux: bool = False,
    holomorphic: bool = False,
):
  return _jacfwd(
    _jacrev(fun, argnums, has_aux=has_aux, holomorphic=holomorphic),
    argnums, has_aux=has_aux, holomorphic=holomorphic
  )
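
# Editor's note (not part of this commit): composing reverse-mode inside
# forward-mode is the standard "forward-over-reverse" recipe for Hessians.
# A self-contained illustration in plain JAX:
#
#   import jax
#   import jax.numpy as jnp
#
#   def f(x):
#     return jnp.sum(x ** 3)
#
#   hess = jax.jacfwd(jax.jacrev(f))(jnp.array([1.0, 2.0]))
#   # d^2/dx^2 sum(x_i ** 3) = diag(6 * x) -> [[6., 0.], [0., 12.]]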


class GradientTransformPreserveTree(ObjectTransform):
  """
  Object-oriented automatic differentiation transformation in BrainPy that
  preserves the pytree structure of the given ``grad_vars``.
  """

  def __init__(
      self,
      target: Callable,
      transform: Callable,

      # variables and nodes
      grad_vars: Dict[str, Variable],

      # gradient setting
      argnums: Optional[Union[int, Sequence[int]]],
      return_value: bool,
      has_aux: bool,
      transform_setting: Optional[Dict[str, Any]] = None,

      # other
      name: str = None,
  ):
    super().__init__(name=name)

    # gradient variables
    if grad_vars is None:
      grad_vars = dict()
    assert isinstance(grad_vars, dict), 'grad_vars should be a dict'
    new_grad_vars = {}
    for k, v in grad_vars.items():
      assert isinstance(v, Variable)
      new_grad_vars[k] = v
    self._grad_vars = new_grad_vars

    # parameters
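    # NOTE (editor, not part of this commit): the "+ 2" below shifts the
    # user-facing argnums past the two leading parameters of the wrapped
    # function, (grad_values, dyn_values, *args); position 0 is reserved
    # for differentiating with respect to grad_vars whenever any are given.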
    if argnums is None and len(self._grad_vars) == 0:
      argnums = 0
    if argnums is None:
      assert len(self._grad_vars) > 0
      _argnums = 0
    elif isinstance(argnums, int):
      _argnums = (0, argnums + 2) if len(self._grad_vars) > 0 else (argnums + 2)
    else:
      _argnums = check.is_sequence(argnums, elem_type=int, allow_none=False)
      _argnums = tuple(a + 2 for a in _argnums)
      if len(self._grad_vars) > 0:
        _argnums = (0,) + _argnums
    self._nonvar_argnums = argnums
    self._argnums = _argnums
    self._return_value = return_value
    self._has_aux = has_aux

    # target
    self.target = target

    # transform
    self._eval_dyn_vars = False
    self._grad_transform = transform
    self._dyn_vars = VariableStack()
    self._transform = None
    self._grad_setting = dict() if transform_setting is None else transform_setting
    if self._has_aux:
      self._transform = self._grad_transform(
        self._f_grad_with_aux_to_transform,
        argnums=self._argnums,
        has_aux=True,
        **self._grad_setting
      )
    else:
      self._transform = self._grad_transform(
        self._f_grad_without_aux_to_transform,
        argnums=self._argnums,
        has_aux=True,
        **self._grad_setting
      )

  def _f_grad_with_aux_to_transform(self,
                                    grad_values: dict,
                                    dyn_values: dict,
                                    *args,
                                    **kwargs):
    for k in dyn_values.keys():
      self._dyn_vars[k]._value = dyn_values[k]
    for k, v in grad_values.items():
      self._grad_vars[k]._value = v
    # Users should return the auxiliary data like::
    # >>> # 1. example of returning one datum
    # >>> return scalar_loss, data
    # >>> # 2. example of returning multiple data
    # >>> return scalar_loss, (data1, data2, ...)
    outputs = self.target(*args, **kwargs)
    # outputs[0] is the value for the gradient,
    # outputs[1] holds the other values to return
    output0 = tree_map(lambda a: (a.value if isinstance(a, Array) else a), outputs[0])
    return output0, (outputs, {k: v for k, v in self._grad_vars.items()}, self._dyn_vars.dict_data())

  def _f_grad_without_aux_to_transform(self,
                                       grad_values: dict,
                                       dyn_values: dict,
                                       *args,
                                       **kwargs):
    for k in dyn_values.keys():
      self._dyn_vars[k].value = dyn_values[k]
    for k, v in grad_values.items():
      self._grad_vars[k].value = v
    # Users should return the scalar value like this::
    # >>> return scalar_loss
    output = self.target(*args, **kwargs)
    output0 = tree_map(lambda a: (a.value if isinstance(a, Array) else a), output)
    return output0, (output, {k: v.value for k, v in self._grad_vars.items()}, self._dyn_vars.dict_data())

  def __repr__(self):
    name = self.__class__.__name__
    f = tools.repr_object(self.target)
    f = tools.repr_context(f, " " * (len(name) + 6))
    format_ref = (f'{name}({self.name}, target={f}, \n' +
                  f'{" " * len(name)} num_of_grad_vars={len(self._grad_vars)}, \n'
                  f'{" " * len(name)} num_of_dyn_vars={len(self._dyn_vars)})')
    return format_ref

  def _return(self, rets):
    grads, (outputs, new_grad_vs, new_dyn_vs) = rets
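    # NOTE (editor, not part of this commit): ``rets`` comes from a JAX
    # transform created with has_aux=True; the aux slot carries the raw
    # outputs together with the updated values of grad_vars and of the
    # dynamical variables, which are written back into the Variables below
    # before anything is returned to the user.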
    for k, v in new_grad_vs.items():
      self._grad_vars[k].value = v
    for k in new_dyn_vs.keys():
      self._dyn_vars[k].value = new_dyn_vs[k]

    # check returned grads
    if len(self._grad_vars) > 0:
      if self._nonvar_argnums is None:
        pass
      else:
        arg_grads = grads[1] if isinstance(self._nonvar_argnums, int) else grads[1:]
        grads = (grads[0], arg_grads)

    # check returned value
    if self._return_value:
      # check aux
      if self._has_aux:
        return grads, outputs[0], outputs[1]
      else:
        return grads, outputs
    else:
      # check aux
      if self._has_aux:
        return grads, outputs[1]
      else:
        return grads

  def __call__(self, *args, **kwargs):
    if jax.config.jax_disable_jit:  # disable JIT
      rets = self._transform(
        {k: v.value for k, v in self._grad_vars.items()},  # variables for gradients
        self._dyn_vars.dict_data(),  # dynamical variables
        *args,
        **kwargs
      )
      return self._return(rets)

    elif not self._eval_dyn_vars:  # evaluate dynamical variables
      stack = get_stack_cache(self.target)
      if stack is None:
        with VariableStack() as stack:
          rets = eval_shape(
            self._transform,
            {k: v.value for k, v in self._grad_vars.items()},  # variables for gradients
            {},  # dynamical variables
            *args,
            **kwargs
          )
        cache_stack(self.target, stack)

      self._dyn_vars = stack
      self._dyn_vars.remove_by_id(*[id(v) for v in self._grad_vars.values()])
      self._eval_dyn_vars = True

      # if not the outermost transformation
      if not stack.is_first_stack():
        return self._return(rets)

    rets = self._transform(
      {k: v.value for k, v in self._grad_vars.items()},  # variables for gradients
      self._dyn_vars.dict_data(),  # dynamical variables
      *args,
      **kwargs
    )
    return self._return(rets)
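
# NOTE (editor, not part of this commit): ``__call__`` above runs in two
# phases. On the first concrete call, the target is traced with
# ``eval_shape`` (abstract values, no real compute) inside a fresh
# ``VariableStack`` so that every Variable the target reads is recorded;
# on subsequent calls, the collected stack (minus the grad_vars themselves)
# is fed to the transform explicitly as the ``dyn_values`` argument.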


def hessian(
    func: Callable,
    grad_vars: Optional[Union[Variable, Sequence[Variable], Dict[str, Variable]]] = None,
    argnums: Optional[Union[int, Sequence[int]]] = None,
    return_value: bool = False,
    has_aux: Optional[bool] = None,
    holomorphic=False,

    # deprecated
    dyn_vars: Optional[Union[Variable, Sequence[Variable], Dict[str, Variable]]] = None,
    child_objs: Optional[Union[BrainPyObject, Sequence[BrainPyObject], Dict[str, BrainPyObject]]] = None,
) -> ObjectTransform:
"""Hessian of ``func`` as a dense array.
Expand All @@ -705,42 +902,24 @@ def hessian(
    Specifies which positional argument(s) to differentiate with respect to (default ``0``).
  holomorphic : bool
    Indicates whether ``fun`` is promised to be holomorphic. Default False.
  return_value : bool
    Whether to return the hessian values as well.
  dyn_vars : optional, ArrayType, sequence of ArrayType, dict
    The dynamically changed variables used in ``func``.

    .. deprecated:: 2.4.0
       No longer need to provide ``dyn_vars``. This function is capable of automatically
       collecting the dynamical variables used in the target ``func``.
  child_objs : optional, BrainPyObject, sequence, dict
    .. versionadded:: 2.3.1

    .. deprecated:: 2.4.0
       No longer need to provide ``child_objs``. This function is capable of automatically
       collecting the children objects used in the target ``func``.
  has_aux : bool, optional
    Indicates whether ``fun`` returns a pair where the first element is
    considered the output of the mathematical function to be differentiated
    and the second element is auxiliary data. Default False.

  Returns
  -------
  obj : ObjectTransform
    The transformed object.
  """
-  child_objs = check.is_all_objs(child_objs, out_as='dict')
-  dyn_vars = check.is_all_vars(dyn_vars, out_as='dict')
-
-  return jacfwd(jacrev(func,
-                       dyn_vars=dyn_vars,
-                       child_objs=child_objs,
-                       grad_vars=grad_vars,
-                       argnums=argnums,
-                       holomorphic=holomorphic),
-                dyn_vars=dyn_vars,
-                child_objs=child_objs,
-                grad_vars=grad_vars,
-                argnums=argnums,
-                holomorphic=holomorphic,
-                return_value=return_value)
+  return GradientTransformPreserveTree(target=func,
+                                       transform=jax.hessian,
+                                       grad_vars=grad_vars,
+                                       argnums=argnums,
+                                       has_aux=False if has_aux is None else has_aux,
+                                       transform_setting=dict(holomorphic=holomorphic),
+                                       return_value=False)
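
# Editor's note (not part of this commit): a hedged usage sketch of the
# rewritten ``hessian``, assuming it is re-exported as ``brainpy.math.hessian``:
#
#   import brainpy.math as bm
#
#   w = bm.Variable(bm.ones(3))
#
#   def loss():
#     return bm.sum(w ** 3)
#
#   h = bm.hessian(loss, grad_vars={'w': w})
#   # with dict-structured grad_vars the result presumably keeps the dict
#   # pytree: h['w']['w'] is the (3, 3) Hessian block, here diag(6 * w)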


def functional_vector_grad(func, argnums=0, return_value=False, has_aux=False):
@@ -960,4 +1139,4 @@ def _check_input_dtype_jacfwd(holomorphic: bool, x: Any) -> None:
                      f"a sub-dtype of np.floating), but got {aval.dtype.name}. "
                      "For holomorphic differentiation, pass holomorphic=True. "
                      "For differentiation of non-holomorphic functions involving "
-                     "complex inputs or integer inputs, use jax.jvp directly.")
\ No newline at end of file
+                     "complex inputs or integer inputs, use jax.jvp directly.")
