Skip to content

Commit

Permalink
Update environment.py
Browse files Browse the repository at this point in the history
  • Loading branch information
ztqakita committed Oct 31, 2023
1 parent 0e35936 commit 9f0c90b
Showing 1 changed file with 22 additions and 22 deletions.
44 changes: 22 additions & 22 deletions brainpy/_src/math/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ class environment(_DecoratorContextManager):
def __init__(
self,
mode: modes.Mode = None,
scaling: scales.Scaling = None,
membrane_scaling: scales.Scaling = None,
dt: float = None,
x64: bool = None,
complex_: type = None,
Expand All @@ -174,9 +174,9 @@ def __init__(
assert isinstance(mode, modes.Mode), f'"mode" must a {modes.Mode}.'
self.old_mode = get_mode()

if scaling is not None:
assert isinstance(scaling, scales.Scaling), f'"membrane_scaling" must a {scales.Scaling}.'
self.old_scaling = get_membrane_scaling()
if membrane_scaling is not None:
assert isinstance(membrane_scaling, scales.Scaling), f'"membrane_scaling" must a {scales.Scaling}.'
self.old_membrane_scaling = get_membrane_scaling()

if x64 is not None:
assert isinstance(x64, bool), f'"x64" must be a bool.'
Expand All @@ -200,7 +200,7 @@ def __init__(

self.dt = dt
self.mode = mode
self.scaling = scaling
self.membrane_scaling = membrane_scaling
self.x64 = x64
self.complex_ = complex_
self.float_ = float_
Expand All @@ -210,7 +210,7 @@ def __init__(
def __enter__(self) -> 'environment':
if self.dt is not None: set_dt(self.dt)
if self.mode is not None: set_mode(self.mode)
if self.scaling is not None: set_membrane_scaling(self.scaling)
if self.membrane_scaling is not None: set_membrane_scaling(self.membrane_scaling)
if self.x64 is not None: set_x64(self.x64)
if self.float_ is not None: set_float(self.float_)
if self.int_ is not None: set_int(self.int_)
Expand All @@ -221,7 +221,7 @@ def __enter__(self) -> 'environment':
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
if self.dt is not None: set_dt(self.old_dt)
if self.mode is not None: set_mode(self.old_mode)
if self.scaling is not None: set_membrane_scaling(self.old_scaling)
if self.membrane_scaling is not None: set_membrane_scaling(self.old_membrane_scaling)
if self.x64 is not None: set_x64(self.old_x64)
if self.int_ is not None: set_int(self.old_int)
if self.float_ is not None: set_float(self.old_float)
Expand All @@ -231,7 +231,7 @@ def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
def clone(self):
return self.__class__(dt=self.dt,
mode=self.mode,
scaling=self.scaling,
membrane_scaling=self.membrane_scaling,
x64=self.x64,
bool_=self.bool_,
complex_=self.complex_,
Expand Down Expand Up @@ -263,15 +263,15 @@ def __init__(
int_: type = None,
bool_: type = None,
batch_size: int = 1,
scaling: scales.Scaling = None,
membrane_scaling: scales.Scaling = None,
):
super().__init__(dt=dt,
x64=x64,
complex_=complex_,
float_=float_,
int_=int_,
bool_=bool_,
scaling=scaling,
membrane_scaling=membrane_scaling,
mode=modes.TrainingMode(batch_size))


Expand All @@ -297,7 +297,7 @@ def __init__(
int_: type = None,
bool_: type = None,
batch_size: int = 1,
scaling: scales.Scaling = None,
membrane_scaling: scales.Scaling = None,
):
super().__init__(dt=dt,
x64=x64,
Expand All @@ -306,12 +306,12 @@ def __init__(
int_=int_,
bool_=bool_,
mode=modes.BatchingMode(batch_size),
scaling=scaling)
membrane_scaling=membrane_scaling)


def set(
mode: modes.Mode = None,
scaling: scales.Scaling = None,
membrane_scaling: scales.Scaling = None,
dt: float = None,
x64: bool = None,
complex_: type = None,
Expand All @@ -325,8 +325,8 @@ def set(
----------
mode: Mode
The computing mode.
scaling: Scaling
The numerical scaling.
membrane_scaling: Scaling
The numerical membrane_scaling.
dt: float
The numerical integration precision.
x64: bool
Expand All @@ -348,9 +348,9 @@ def set(
assert isinstance(mode, modes.Mode), f'"mode" must a {modes.Mode}.'
set_mode(mode)

if scaling is not None:
assert isinstance(scaling, scales.Scaling), f'"membrane_scaling" must a {scales.Scaling}.'
set_membrane_scaling(scaling)
if membrane_scaling is not None:
assert isinstance(membrane_scaling, scales.Scaling), f'"membrane_scaling" must a {scales.Scaling}.'
set_membrane_scaling(membrane_scaling)

if x64 is not None:
assert isinstance(x64, bool), f'"x64" must be a bool.'
Expand Down Expand Up @@ -575,20 +575,20 @@ def get_mode() -> modes.Mode:
return bm.mode


def set_membrane_scaling(scaling: scales.Scaling):
def set_membrane_scaling(membrane_scaling: scales.Scaling):
"""Set the default computing membrane_scaling.
Parameters
----------
scaling: Scaling
The instance of :py:class:`~.Scaling`.
"""
if not isinstance(scaling, scales.Scaling):
if not isinstance(membrane_scaling, scales.Scaling):
raise TypeError(f'Must be instance of brainpy.math.Scaling. '
f'But we got {type(scaling)}: {scaling}')
f'But we got {type(membrane_scaling)}: {membrane_scaling}')
global bm
if bm is None: from brainpy import math as bm
bm.__dict__['membrane_scaling'] = scaling
bm.__dict__['membrane_scaling'] = membrane_scaling


def get_membrane_scaling() -> scales.Scaling:
Expand Down

0 comments on commit 9f0c90b

Please sign in to comment.