From 6af992e73f6ad0935b2d7ef1d677e37a5f99c79a Mon Sep 17 00:00:00 2001 From: Chaoming Wang Date: Sun, 18 Feb 2024 21:38:51 +0800 Subject: [PATCH] enable brainpy object as pytree --- brainpy/_src/math/defaults.py | 3 +++ brainpy/_src/math/environment.py | 25 ++++++++++++++++--- brainpy/_src/math/object_transform/base.py | 15 +++++++---- .../math/object_transform/tests/test_base.py | 19 ++++++++++++++ 4 files changed, 54 insertions(+), 8 deletions(-) diff --git a/brainpy/_src/math/defaults.py b/brainpy/_src/math/defaults.py index 19aca92cf..20cae197e 100644 --- a/brainpy/_src/math/defaults.py +++ b/brainpy/_src/math/defaults.py @@ -36,3 +36,6 @@ # '''Default complex data type.''' complex_ = jnp.complex128 if config.read('jax_enable_x64') else jnp.complex64 +# register brainpy object as pytree +bp_object_as_pytree = False + diff --git a/brainpy/_src/math/environment.py b/brainpy/_src/math/environment.py index 1c8b98a3b..e91eca0b8 100644 --- a/brainpy/_src/math/environment.py +++ b/brainpy/_src/math/environment.py @@ -166,6 +166,7 @@ def __init__( float_: type = None, int_: type = None, bool_: type = None, + bp_object_as_pytree: bool = None, ) -> None: super().__init__() @@ -201,6 +202,10 @@ def __init__( assert isinstance(complex_, type), '"complex_" must a type.' self.old_complex = get_complex() + if bp_object_as_pytree is not None: + assert isinstance(bp_object_as_pytree, bool), '"bp_object_as_pytree" must be a bool.' + self.old_bp_object_as_pytree = defaults.bp_object_as_pytree + self.dt = dt self.mode = mode self.membrane_scaling = membrane_scaling @@ -209,6 +214,7 @@ def __init__( self.float_ = float_ self.int_ = int_ self.bool_ = bool_ + self.bp_object_as_pytree = bp_object_as_pytree def __enter__(self) -> 'environment': if self.dt is not None: set_dt(self.dt) @@ -219,6 +225,7 @@ def __enter__(self) -> 'environment': if self.int_ is not None: set_int(self.int_) if self.complex_ is not None: set_complex(self.complex_) if self.bool_ is not None: set_bool(self.bool_) + if self.bp_object_as_pytree is not None: defaults.__dict__['bp_object_as_pytree'] = self.bp_object_as_pytree return self def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: @@ -230,6 +237,7 @@ def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: if self.float_ is not None: set_float(self.old_float) if self.complex_ is not None: set_complex(self.old_complex) if self.bool_ is not None: set_bool(self.old_bool) + if self.bp_object_as_pytree is not None: defaults.__dict__['bp_object_as_pytree'] = self.old_bp_object_as_pytree def clone(self): return self.__class__(dt=self.dt, @@ -239,7 +247,8 @@ def clone(self): bool_=self.bool_, complex_=self.complex_, float_=self.float_, - int_=self.int_) + int_=self.int_, + bp_object_as_pytree=self.bp_object_as_pytree) def __eq__(self, other): return id(self) == id(other) @@ -267,6 +276,7 @@ def __init__( bool_: type = None, batch_size: int = 1, membrane_scaling: scales.Scaling = None, + bp_object_as_pytree: bool = None, ): super().__init__(dt=dt, x64=x64, @@ -275,7 +285,8 @@ def __init__( int_=int_, bool_=bool_, membrane_scaling=membrane_scaling, - mode=modes.TrainingMode(batch_size)) + mode=modes.TrainingMode(batch_size), + bp_object_as_pytree=bp_object_as_pytree) class batching_environment(environment): @@ -301,6 +312,7 @@ def __init__( bool_: type = None, batch_size: int = 1, membrane_scaling: scales.Scaling = None, + bp_object_as_pytree: bool = None, ): super().__init__(dt=dt, x64=x64, @@ -309,7 +321,8 @@ def __init__( int_=int_, bool_=bool_, mode=modes.BatchingMode(batch_size), - membrane_scaling=membrane_scaling) + membrane_scaling=membrane_scaling, + bp_object_as_pytree=bp_object_as_pytree) def set( @@ -321,6 +334,7 @@ def set( float_: type = None, int_: type = None, bool_: type = None, + bp_object_as_pytree: bool = None, ): """Set the default computation environment. @@ -342,6 +356,8 @@ def set( The integer data type. bool_ The bool data type. + bp_object_as_pytree: bool + Whether to register brainpy object as pytree. """ if dt is not None: assert isinstance(dt, float), '"dt" must a float.' @@ -375,6 +391,9 @@ def set( assert isinstance(complex_, type), '"complex_" must a type.' set_complex(complex_) + if bp_object_as_pytree is not None: + defaults.__dict__['bp_object_as_pytree'] = bp_object_as_pytree + set_environment = set diff --git a/brainpy/_src/math/object_transform/base.py b/brainpy/_src/math/object_transform/base.py index aaf053ae7..53346a7d1 100644 --- a/brainpy/_src/math/object_transform/base.py +++ b/brainpy/_src/math/object_transform/base.py @@ -6,23 +6,24 @@ """ import numbers -import os import warnings from collections import namedtuple from typing import Any, Tuple, Callable, Sequence, Dict, Union, Optional import jax import numpy as np +from jax._src.tree_util import _registry +from jax.tree_util import register_pytree_node_class -from brainpy import errors +from brainpy._src.math.modes import Mode from brainpy._src.math.ndarray import (Array, ) from brainpy._src.math.object_transform.collectors import (ArrayCollector, Collector) from brainpy._src.math.object_transform.naming import (get_unique_name, check_name_uniqueness) from brainpy._src.math.object_transform.variables import (Variable, VariableView, TrainVar, VarList, VarDict) -from brainpy._src.math.modes import Mode from brainpy._src.math.sharding import BATCH_AXIS +from brainpy._src.math import defaults variable_ = None StateLoadResult = namedtuple('StateLoadResult', ['missing_keys', 'unexpected_keys']) @@ -89,6 +90,10 @@ class BrainPyObject(object): def __init__(self, name=None): super().__init__() + if defaults.bp_object_as_pytree: + if self.__class__ not in _registry: + register_pytree_node_class(self.__class__) + # check whether the object has a unique name. self._name = None self._name = self.unique_name(name=name) @@ -217,8 +222,8 @@ def tree_flatten(self): static_names = [] static_values = [] for k, v in self.__dict__.items(): - # if isinstance(v, (BrainPyObject, Variable, NodeList, NodeDict, VarList, VarDict)): - if isinstance(v, (BrainPyObject, Variable)): + if isinstance(v, (BrainPyObject, Variable, NodeList, NodeDict, VarList, VarDict)): + # if isinstance(v, (BrainPyObject, Variable)): dynamic_names.append(k) dynamic_values.append(v) else: diff --git a/brainpy/_src/math/object_transform/tests/test_base.py b/brainpy/_src/math/object_transform/tests/test_base.py index 2d640b3b5..c6f8f90d4 100644 --- a/brainpy/_src/math/object_transform/tests/test_base.py +++ b/brainpy/_src/math/object_transform/tests/test_base.py @@ -231,3 +231,22 @@ def f1(): self.assertTrue(obj.vs['b'] == 12.) self.assertTrue(bm.allclose(obj.vs['c'], bm.ones(10) * 11.)) + +class TestRegisterBPObjectAsPyTree(unittest.TestCase): + def test1(self): + bm.set(bp_object_as_pytree=True) + + hh = bp.dyn.HH(1) + hh.reset() + + tree = jax.tree_structure(hh) + leaves = jax.tree_leaves(hh) + + print(tree) + print(leaves) + print(jax.tree_unflatten(tree, leaves)) + print() + + + +