Skip to content

Commit

Permalink
fix numpy_func_return setting
Browse files Browse the repository at this point in the history
  • Loading branch information
chaoming0625 committed Mar 3, 2024
1 parent 8b122c1 commit 805a69f
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 7 deletions.
17 changes: 10 additions & 7 deletions brainpy/_src/math/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ def __init__(
int_: type = None,
bool_: type = None,
bp_object_as_pytree: bool = None,
numpy_func_return: bool = None,
numpy_func_return: str = None,
) -> None:
super().__init__()

Expand Down Expand Up @@ -210,7 +210,9 @@ def __init__(
self.old_bp_object_as_pytree = defaults.bp_object_as_pytree

if numpy_func_return is not None:
assert isinstance(numpy_func_return, bool), '"numpy_func_return" must be a bool.'
assert isinstance(numpy_func_return, str), '"numpy_func_return" must be a string.'
assert numpy_func_return in ['bp_array', 'jax_array'], \
f'"numpy_func_return" must be "bp_array" or "jax_array". Got {numpy_func_return}.'
self.old_numpy_func_return = defaults.numpy_func_return

self.dt = dt
Expand Down Expand Up @@ -288,7 +290,7 @@ def __init__(
batch_size: int = 1,
membrane_scaling: scales.Scaling = None,
bp_object_as_pytree: bool = None,
numpy_func_return: bool = None,
numpy_func_return: str = None,
):
super().__init__(dt=dt,
x64=x64,
Expand Down Expand Up @@ -326,7 +328,7 @@ def __init__(
batch_size: int = 1,
membrane_scaling: scales.Scaling = None,
bp_object_as_pytree: bool = None,
numpy_func_return: bool = None,
numpy_func_return: str = None,
):
super().__init__(dt=dt,
x64=x64,
Expand All @@ -350,7 +352,7 @@ def set(
int_: type = None,
bool_: type = None,
bp_object_as_pytree: bool = None,
numpy_func_return: bool = None,
numpy_func_return: str = None,
):
"""Set the default computation environment.
Expand All @@ -374,8 +376,8 @@ def set(
The bool data type.
bp_object_as_pytree: bool
Whether to register brainpy object as pytree.
numpy_func_return: bool
Whether to return brainpy array in all numpy functions.
numpy_func_return: str
The array to return in all numpy functions. Support 'bp_array' and 'jax_array'.
"""
if dt is not None:
assert isinstance(dt, float), '"dt" must a float.'
Expand Down Expand Up @@ -413,6 +415,7 @@ def set(
defaults.__dict__['bp_object_as_pytree'] = bp_object_as_pytree

if numpy_func_return is not None:
assert numpy_func_return in ['bp_array', 'jax_array'], f'"numpy_func_return" must be "bp_array" or "jax_array".'
defaults.__dict__['numpy_func_return'] = numpy_func_return


Expand Down
15 changes: 15 additions & 0 deletions brainpy/_src/math/tests/test_environment.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import unittest

import jax

import brainpy.math as bm


class TestEnvironment(unittest.TestCase):
def test_numpy_func_return(self):
with bm.environment(numpy_func_return='jax_array'):
a = bm.random.randn(3, 3)
self.assertTrue(isinstance(a, jax.Array))
with bm.environment(numpy_func_return='bp_array'):
a = bm.random.randn(3, 3)
self.assertTrue(isinstance(a, bm.Array))

0 comments on commit 805a69f

Please sign in to comment.