Skip to content

Commit

Permalink
enable to set numpy_func_return = True/False
Browse files Browse the repository at this point in the history
  • Loading branch information
chaoming0625 committed Mar 2, 2024
1 parent 1a2deae commit fa6e787
Show file tree
Hide file tree
Showing 8 changed files with 92 additions and 312 deletions.
2 changes: 1 addition & 1 deletion brainpy/_src/_delay.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def register_entry(
delay_type = 'homo'
else:
delay_type = 'heter'
delay_step = bm.Array(delay_step)
delay_step = delay_step
elif callable(delay_step):
delay_step = delay_step(self.delay_target_shape)
delay_type = 'heter'
Expand Down
28 changes: 16 additions & 12 deletions brainpy/_src/math/compat_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,10 @@
_max = max


def _return(a):
return Array(a)


def fill_diagonal(a, val, inplace=True):
if a.ndim < 2:
raise ValueError(f'Only support tensor has dimension >= 2, but got {a.shape}')
Expand All @@ -120,30 +124,30 @@ def fill_diagonal(a, val, inplace=True):


def zeros(shape, dtype=None):
return Array(jnp.zeros(shape, dtype=dtype))
return _return(jnp.zeros(shape, dtype=dtype))


def ones(shape, dtype=None):
return Array(jnp.ones(shape, dtype=dtype))
return _return(jnp.ones(shape, dtype=dtype))


def empty(shape, dtype=None):
return Array(jnp.zeros(shape, dtype=dtype))
return _return(jnp.zeros(shape, dtype=dtype))


def zeros_like(a, dtype=None, shape=None):
a = _as_jax_array_(a)
return Array(jnp.zeros_like(a, dtype=dtype, shape=shape))
return _return(jnp.zeros_like(a, dtype=dtype, shape=shape))


def ones_like(a, dtype=None, shape=None):
a = _as_jax_array_(a)
return Array(jnp.ones_like(a, dtype=dtype, shape=shape))
return _return(jnp.ones_like(a, dtype=dtype, shape=shape))


def empty_like(a, dtype=None, shape=None):
a = _as_jax_array_(a)
return Array(jnp.zeros_like(a, dtype=dtype, shape=shape))
return _return(jnp.zeros_like(a, dtype=dtype, shape=shape))


def array(a, dtype=None, copy=True, order="K", ndmin=0) -> Array:
Expand All @@ -155,7 +159,7 @@ def array(a, dtype=None, copy=True, order="K", ndmin=0) -> Array:
leaves = [_as_jax_array_(l) for l in leaves]
a = tree_unflatten(tree, leaves)
res = jnp.array(a, dtype=dtype, copy=copy, order=order, ndmin=ndmin)
return Array(res)
return _return(res)


def asarray(a, dtype=None, order=None):
Expand All @@ -167,29 +171,29 @@ def asarray(a, dtype=None, order=None):
leaves = [_as_jax_array_(l) for l in leaves]
arrays = tree_unflatten(tree, leaves)
res = jnp.asarray(a=arrays, dtype=dtype, order=order)
return Array(res)
return _return(res)


def arange(*args, **kwargs):
args = [_as_jax_array_(a) for a in args]
kwargs = {k: _as_jax_array_(v) for k, v in kwargs.items()}
return Array(jnp.arange(*args, **kwargs))
return _return(jnp.arange(*args, **kwargs))


def linspace(*args, **kwargs):
args = [_as_jax_array_(a) for a in args]
kwargs = {k: _as_jax_array_(v) for k, v in kwargs.items()}
res = jnp.linspace(*args, **kwargs)
if isinstance(res, tuple):
return Array(res[0]), res[1]
return _return(res[0]), res[1]
else:
return Array(res)
return _return(res)


def logspace(*args, **kwargs):
args = [_as_jax_array_(a) for a in args]
kwargs = {k: _as_jax_array_(v) for k, v in kwargs.items()}
return Array(jnp.logspace(*args, **kwargs))
return _return(jnp.logspace(*args, **kwargs))


def asanyarray(a, dtype=None, order=None):
Expand Down
56 changes: 28 additions & 28 deletions brainpy/_src/math/compat_tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,13 +259,13 @@ def segment_sum(data: Union[Array, jnp.ndarray],
An array with shape :code:`(num_segments,) + data.shape[1:]` representing the
segment sums.
"""
return Array(jax.ops.segment_sum(as_jax(data),
as_jax(segment_ids),
num_segments,
indices_are_sorted,
unique_indices,
bucket_size,
mode))
return _return(jax.ops.segment_sum(as_jax(data),
as_jax(segment_ids),
num_segments,
indices_are_sorted,
unique_indices,
bucket_size,
mode))


def segment_prod(data: Union[Array, jnp.ndarray],
Expand Down Expand Up @@ -311,13 +311,13 @@ def segment_prod(data: Union[Array, jnp.ndarray],
An array with shape :code:`(num_segments,) + data.shape[1:]` representing the
segment sums.
"""
return Array(jax.ops.segment_prod(as_jax(data),
as_jax(segment_ids),
num_segments,
indices_are_sorted,
unique_indices,
bucket_size,
mode))
return _return(jax.ops.segment_prod(as_jax(data),
as_jax(segment_ids),
num_segments,
indices_are_sorted,
unique_indices,
bucket_size,
mode))


def segment_max(data: Union[Array, jnp.ndarray],
Expand Down Expand Up @@ -363,13 +363,13 @@ def segment_max(data: Union[Array, jnp.ndarray],
An array with shape :code:`(num_segments,) + data.shape[1:]` representing the
segment sums.
"""
return Array(jax.ops.segment_max(as_jax(data),
as_jax(segment_ids),
num_segments,
indices_are_sorted,
unique_indices,
bucket_size,
mode))
return _return(jax.ops.segment_max(as_jax(data),
as_jax(segment_ids),
num_segments,
indices_are_sorted,
unique_indices,
bucket_size,
mode))


def segment_min(data: Union[Array, jnp.ndarray],
Expand Down Expand Up @@ -415,13 +415,13 @@ def segment_min(data: Union[Array, jnp.ndarray],
An array with shape :code:`(num_segments,) + data.shape[1:]` representing the
segment sums.
"""
return Array(jax.ops.segment_min(as_jax(data),
as_jax(segment_ids),
num_segments,
indices_are_sorted,
unique_indices,
bucket_size,
mode))
return _return(jax.ops.segment_min(as_jax(data),
as_jax(segment_ids),
num_segments,
indices_are_sorted,
unique_indices,
bucket_size,
mode))


def cast(x, dtype):
Expand Down
21 changes: 13 additions & 8 deletions brainpy/_src/math/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,32 +12,37 @@
# Default computation mode.
mode = NonBatchingMode()

# '''Default computation mode.'''
# Default computation mode.
membrane_scaling = IdScaling()

# '''Default time step.'''
# Default time step.
dt = 0.1

# '''Default bool data type.'''
# Default bool data type.
bool_ = jnp.bool_

# '''Default integer data type.'''
# Default integer data type.
int_ = jnp.int64 if config.read('jax_enable_x64') else jnp.int32

# '''Default float data type.'''
# Default float data type.
float_ = jnp.float64 if config.read('jax_enable_x64') else jnp.float32

# '''Default complex data type.'''
# Default complex data type.
complex_ = jnp.complex128 if config.read('jax_enable_x64') else jnp.complex64

# register brainpy object as pytree
bp_object_as_pytree = False


# default return array type
numpy_func_return = 'bp_array' # 'bp_array','jax_array'


if ti is not None:
# '''Default integer data type in Taichi.'''
# Default integer data type in Taichi.
ti_int = ti.int64 if config.read('jax_enable_x64') else ti.int32

# '''Default float data type in Taichi.'''
# Default float data type in Taichi.
ti_float = ti.float64 if config.read('jax_enable_x64') else ti.float32

else:
Expand Down
25 changes: 22 additions & 3 deletions brainpy/_src/math/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@ def __init__(
int_: type = None,
bool_: type = None,
bp_object_as_pytree: bool = None,
numpy_func_return: bool = None,
) -> None:
super().__init__()

Expand Down Expand Up @@ -208,6 +209,10 @@ def __init__(
assert isinstance(bp_object_as_pytree, bool), '"bp_object_as_pytree" must be a bool.'
self.old_bp_object_as_pytree = defaults.bp_object_as_pytree

if numpy_func_return is not None:
assert isinstance(numpy_func_return, bool), '"numpy_func_return" must be a bool.'
self.old_numpy_func_return = defaults.numpy_func_return

self.dt = dt
self.mode = mode
self.membrane_scaling = membrane_scaling
Expand All @@ -217,6 +222,7 @@ def __init__(
self.int_ = int_
self.bool_ = bool_
self.bp_object_as_pytree = bp_object_as_pytree
self.numpy_func_return = numpy_func_return

def __enter__(self) -> 'environment':
if self.dt is not None: set_dt(self.dt)
Expand All @@ -228,6 +234,7 @@ def __enter__(self) -> 'environment':
if self.complex_ is not None: set_complex(self.complex_)
if self.bool_ is not None: set_bool(self.bool_)
if self.bp_object_as_pytree is not None: defaults.__dict__['bp_object_as_pytree'] = self.bp_object_as_pytree
if self.numpy_func_return is not None: defaults.__dict__['numpy_func_return'] = self.numpy_func_return
return self

def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
Expand All @@ -240,6 +247,7 @@ def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
if self.complex_ is not None: set_complex(self.old_complex)
if self.bool_ is not None: set_bool(self.old_bool)
if self.bp_object_as_pytree is not None: defaults.__dict__['bp_object_as_pytree'] = self.old_bp_object_as_pytree
if self.numpy_func_return is not None: defaults.__dict__['numpy_func_return'] = self.old_numpy_func_return

def clone(self):
return self.__class__(dt=self.dt,
Expand All @@ -250,7 +258,8 @@ def clone(self):
complex_=self.complex_,
float_=self.float_,
int_=self.int_,
bp_object_as_pytree=self.bp_object_as_pytree)
bp_object_as_pytree=self.bp_object_as_pytree,
numpy_func_return=self.numpy_func_return)

def __eq__(self, other):
return id(self) == id(other)
Expand Down Expand Up @@ -279,6 +288,7 @@ def __init__(
batch_size: int = 1,
membrane_scaling: scales.Scaling = None,
bp_object_as_pytree: bool = None,
numpy_func_return: bool = None,
):
super().__init__(dt=dt,
x64=x64,
Expand All @@ -288,7 +298,8 @@ def __init__(
bool_=bool_,
membrane_scaling=membrane_scaling,
mode=modes.TrainingMode(batch_size),
bp_object_as_pytree=bp_object_as_pytree)
bp_object_as_pytree=bp_object_as_pytree,
numpy_func_return=numpy_func_return)


class batching_environment(environment):
Expand All @@ -315,6 +326,7 @@ def __init__(
batch_size: int = 1,
membrane_scaling: scales.Scaling = None,
bp_object_as_pytree: bool = None,
numpy_func_return: bool = None,
):
super().__init__(dt=dt,
x64=x64,
Expand All @@ -324,7 +336,8 @@ def __init__(
bool_=bool_,
mode=modes.BatchingMode(batch_size),
membrane_scaling=membrane_scaling,
bp_object_as_pytree=bp_object_as_pytree)
bp_object_as_pytree=bp_object_as_pytree,
numpy_func_return=numpy_func_return)


def set(
Expand All @@ -337,6 +350,7 @@ def set(
int_: type = None,
bool_: type = None,
bp_object_as_pytree: bool = None,
numpy_func_return: bool = None,
):
"""Set the default computation environment.
Expand All @@ -360,6 +374,8 @@ def set(
The bool data type.
bp_object_as_pytree: bool
Whether to register brainpy object as pytree.
numpy_func_return: bool
Whether to return brainpy array in all numpy functions.
"""
if dt is not None:
assert isinstance(dt, float), '"dt" must a float.'
Expand Down Expand Up @@ -396,6 +412,9 @@ def set(
if bp_object_as_pytree is not None:
defaults.__dict__['bp_object_as_pytree'] = bp_object_as_pytree

if numpy_func_return is not None:
defaults.__dict__['numpy_func_return'] = numpy_func_return


set_environment = set

Expand Down
Loading

0 comments on commit fa6e787

Please sign in to comment.