Skip to content

Commit

Permalink
enable brainpy object as pytree
Browse files Browse the repository at this point in the history
  • Loading branch information
chaoming0625 committed Feb 18, 2024
1 parent fb08523 commit 6af992e
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 8 deletions.
3 changes: 3 additions & 0 deletions brainpy/_src/math/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,6 @@
# '''Default complex data type.'''
complex_ = jnp.complex128 if config.read('jax_enable_x64') else jnp.complex64

# register brainpy object as pytree
bp_object_as_pytree = False

25 changes: 22 additions & 3 deletions brainpy/_src/math/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ def __init__(
float_: type = None,
int_: type = None,
bool_: type = None,
bp_object_as_pytree: bool = None,
) -> None:
super().__init__()

Expand Down Expand Up @@ -201,6 +202,10 @@ def __init__(
assert isinstance(complex_, type), '"complex_" must a type.'
self.old_complex = get_complex()

if bp_object_as_pytree is not None:
assert isinstance(bp_object_as_pytree, bool), '"bp_object_as_pytree" must be a bool.'
self.old_bp_object_as_pytree = defaults.bp_object_as_pytree

self.dt = dt
self.mode = mode
self.membrane_scaling = membrane_scaling
Expand All @@ -209,6 +214,7 @@ def __init__(
self.float_ = float_
self.int_ = int_
self.bool_ = bool_
self.bp_object_as_pytree = bp_object_as_pytree

def __enter__(self) -> 'environment':
if self.dt is not None: set_dt(self.dt)
Expand All @@ -219,6 +225,7 @@ def __enter__(self) -> 'environment':
if self.int_ is not None: set_int(self.int_)
if self.complex_ is not None: set_complex(self.complex_)
if self.bool_ is not None: set_bool(self.bool_)
if self.bp_object_as_pytree is not None: defaults.__dict__['bp_object_as_pytree'] = self.bp_object_as_pytree
return self

def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
Expand All @@ -230,6 +237,7 @@ def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
if self.float_ is not None: set_float(self.old_float)
if self.complex_ is not None: set_complex(self.old_complex)
if self.bool_ is not None: set_bool(self.old_bool)
if self.bp_object_as_pytree is not None: defaults.__dict__['bp_object_as_pytree'] = self.old_bp_object_as_pytree

def clone(self):
return self.__class__(dt=self.dt,
Expand All @@ -239,7 +247,8 @@ def clone(self):
bool_=self.bool_,
complex_=self.complex_,
float_=self.float_,
int_=self.int_)
int_=self.int_,
bp_object_as_pytree=self.bp_object_as_pytree)

def __eq__(self, other):
return id(self) == id(other)
Expand Down Expand Up @@ -267,6 +276,7 @@ def __init__(
bool_: type = None,
batch_size: int = 1,
membrane_scaling: scales.Scaling = None,
bp_object_as_pytree: bool = None,
):
super().__init__(dt=dt,
x64=x64,
Expand All @@ -275,7 +285,8 @@ def __init__(
int_=int_,
bool_=bool_,
membrane_scaling=membrane_scaling,
mode=modes.TrainingMode(batch_size))
mode=modes.TrainingMode(batch_size),
bp_object_as_pytree=bp_object_as_pytree)


class batching_environment(environment):
Expand All @@ -301,6 +312,7 @@ def __init__(
bool_: type = None,
batch_size: int = 1,
membrane_scaling: scales.Scaling = None,
bp_object_as_pytree: bool = None,
):
super().__init__(dt=dt,
x64=x64,
Expand All @@ -309,7 +321,8 @@ def __init__(
int_=int_,
bool_=bool_,
mode=modes.BatchingMode(batch_size),
membrane_scaling=membrane_scaling)
membrane_scaling=membrane_scaling,
bp_object_as_pytree=bp_object_as_pytree)


def set(
Expand All @@ -321,6 +334,7 @@ def set(
float_: type = None,
int_: type = None,
bool_: type = None,
bp_object_as_pytree: bool = None,
):
"""Set the default computation environment.
Expand All @@ -342,6 +356,8 @@ def set(
The integer data type.
bool_
The bool data type.
bp_object_as_pytree: bool
Whether to register brainpy object as pytree.
"""
if dt is not None:
assert isinstance(dt, float), '"dt" must a float.'
Expand Down Expand Up @@ -375,6 +391,9 @@ def set(
assert isinstance(complex_, type), '"complex_" must a type.'
set_complex(complex_)

if bp_object_as_pytree is not None:
defaults.__dict__['bp_object_as_pytree'] = bp_object_as_pytree


set_environment = set

Expand Down
15 changes: 10 additions & 5 deletions brainpy/_src/math/object_transform/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,23 +6,24 @@
"""

import numbers
import os
import warnings
from collections import namedtuple
from typing import Any, Tuple, Callable, Sequence, Dict, Union, Optional

import jax
import numpy as np
from jax._src.tree_util import _registry
from jax.tree_util import register_pytree_node_class

from brainpy import errors
from brainpy._src.math.modes import Mode
from brainpy._src.math.ndarray import (Array, )
from brainpy._src.math.object_transform.collectors import (ArrayCollector, Collector)
from brainpy._src.math.object_transform.naming import (get_unique_name,
check_name_uniqueness)
from brainpy._src.math.object_transform.variables import (Variable, VariableView, TrainVar,
VarList, VarDict)
from brainpy._src.math.modes import Mode
from brainpy._src.math.sharding import BATCH_AXIS
from brainpy._src.math import defaults

variable_ = None
StateLoadResult = namedtuple('StateLoadResult', ['missing_keys', 'unexpected_keys'])
Expand Down Expand Up @@ -89,6 +90,10 @@ class BrainPyObject(object):
def __init__(self, name=None):
super().__init__()

if defaults.bp_object_as_pytree:
if self.__class__ not in _registry:
register_pytree_node_class(self.__class__)

# check whether the object has a unique name.
self._name = None
self._name = self.unique_name(name=name)
Expand Down Expand Up @@ -217,8 +222,8 @@ def tree_flatten(self):
static_names = []
static_values = []
for k, v in self.__dict__.items():
# if isinstance(v, (BrainPyObject, Variable, NodeList, NodeDict, VarList, VarDict)):
if isinstance(v, (BrainPyObject, Variable)):
if isinstance(v, (BrainPyObject, Variable, NodeList, NodeDict, VarList, VarDict)):
# if isinstance(v, (BrainPyObject, Variable)):
dynamic_names.append(k)
dynamic_values.append(v)
else:
Expand Down
19 changes: 19 additions & 0 deletions brainpy/_src/math/object_transform/tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,3 +231,22 @@ def f1():
self.assertTrue(obj.vs['b'] == 12.)
self.assertTrue(bm.allclose(obj.vs['c'], bm.ones(10) * 11.))


class TestRegisterBPObjectAsPyTree(unittest.TestCase):
def test1(self):
bm.set(bp_object_as_pytree=True)

hh = bp.dyn.HH(1)
hh.reset()

tree = jax.tree_structure(hh)
leaves = jax.tree_leaves(hh)

print(tree)
print(leaves)
print(jax.tree_unflatten(tree, leaves))
print()




0 comments on commit 6af992e

Please sign in to comment.