From 805a69fa03418c77cc1376fbcd5b6d73890280e4 Mon Sep 17 00:00:00 2001 From: Chaoming Wang Date: Sun, 3 Mar 2024 10:42:51 +0800 Subject: [PATCH] fix `numpy_func_return` setting --- brainpy/_src/math/environment.py | 17 ++++++++++------- brainpy/_src/math/tests/test_environment.py | 15 +++++++++++++++ 2 files changed, 25 insertions(+), 7 deletions(-) create mode 100644 brainpy/_src/math/tests/test_environment.py diff --git a/brainpy/_src/math/environment.py b/brainpy/_src/math/environment.py index 1948f4a7..ebbb8b6a 100644 --- a/brainpy/_src/math/environment.py +++ b/brainpy/_src/math/environment.py @@ -169,7 +169,7 @@ def __init__( int_: type = None, bool_: type = None, bp_object_as_pytree: bool = None, - numpy_func_return: bool = None, + numpy_func_return: str = None, ) -> None: super().__init__() @@ -210,7 +210,9 @@ def __init__( self.old_bp_object_as_pytree = defaults.bp_object_as_pytree if numpy_func_return is not None: - assert isinstance(numpy_func_return, bool), '"numpy_func_return" must be a bool.' + assert isinstance(numpy_func_return, str), '"numpy_func_return" must be a string.' + assert numpy_func_return in ['bp_array', 'jax_array'], \ + f'"numpy_func_return" must be "bp_array" or "jax_array". Got {numpy_func_return}.' self.old_numpy_func_return = defaults.numpy_func_return self.dt = dt @@ -288,7 +290,7 @@ def __init__( batch_size: int = 1, membrane_scaling: scales.Scaling = None, bp_object_as_pytree: bool = None, - numpy_func_return: bool = None, + numpy_func_return: str = None, ): super().__init__(dt=dt, x64=x64, @@ -326,7 +328,7 @@ def __init__( batch_size: int = 1, membrane_scaling: scales.Scaling = None, bp_object_as_pytree: bool = None, - numpy_func_return: bool = None, + numpy_func_return: str = None, ): super().__init__(dt=dt, x64=x64, @@ -350,7 +352,7 @@ def set( int_: type = None, bool_: type = None, bp_object_as_pytree: bool = None, - numpy_func_return: bool = None, + numpy_func_return: str = None, ): """Set the default computation environment. @@ -374,8 +376,8 @@ def set( The bool data type. bp_object_as_pytree: bool Whether to register brainpy object as pytree. - numpy_func_return: bool - Whether to return brainpy array in all numpy functions. + numpy_func_return: str + The array to return in all numpy functions. Support 'bp_array' and 'jax_array'. """ if dt is not None: assert isinstance(dt, float), '"dt" must a float.' @@ -413,6 +415,7 @@ def set( defaults.__dict__['bp_object_as_pytree'] = bp_object_as_pytree if numpy_func_return is not None: + assert numpy_func_return in ['bp_array', 'jax_array'], f'"numpy_func_return" must be "bp_array" or "jax_array".' defaults.__dict__['numpy_func_return'] = numpy_func_return diff --git a/brainpy/_src/math/tests/test_environment.py b/brainpy/_src/math/tests/test_environment.py new file mode 100644 index 00000000..83315899 --- /dev/null +++ b/brainpy/_src/math/tests/test_environment.py @@ -0,0 +1,15 @@ +import unittest + +import jax + +import brainpy.math as bm + + +class TestEnvironment(unittest.TestCase): + def test_numpy_func_return(self): + with bm.environment(numpy_func_return='jax_array'): + a = bm.random.randn(3, 3) + self.assertTrue(isinstance(a, jax.Array)) + with bm.environment(numpy_func_return='bp_array'): + a = bm.random.randn(3, 3) + self.assertTrue(isinstance(a, bm.Array))