Merge pull request #534 from chaoming0625/updates
[math] new abstract function for XLACustomOp, fix its bugs
chaoming0625 authored Nov 5, 2023
2 parents 5a1e2d9 + 48f77ee commit c014976
Showing 7 changed files with 81 additions and 66 deletions.
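The part of the diff shown below renames the spk_type argument/attribute of the spiking-neuron classes to spk_dtype and gives brainpy.math.Array a high __array_priority__. As a rough usage sketch of the rename (the class name and call signature are assumptions inferred from the changed files, not taken from the diff itself):

import brainpy as bp
import brainpy.math as bm

# Hypothetical example: a LIF-style cell accepting the renamed keyword.
neu = bp.dyn.Lif(10, spk_dtype=bm.bool_)   # spike buffer stored as booleans
# Before this commit the same intent was spelled spk_type=bm.bool_.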
10 changes: 5 additions & 5 deletions brainpy/_src/dyn/neurons/base.py
@@ -29,7 +29,7 @@ def __init__(
scaling: Optional[bm.Scaling] = None,

spk_fun: Callable = bm.surrogate.InvSquareGrad(),
spk_type: Any = None,
spk_dtype: Any = None,
spk_reset: str = 'soft',
detach_spk: bool = False,
):
@@ -43,18 +43,18 @@ def __init__(
self.spk_reset = spk_reset
self.spk_fun = is_callable(spk_fun)
self.detach_spk = detach_spk
self._spk_type = spk_type
self._spk_dtype = spk_dtype
if scaling is None:
self.scaling = bm.get_membrane_scaling()
else:
self.scaling = scaling

@property
def spk_type(self):
if self._spk_type is None:
def spk_dtype(self):
if self._spk_dtype is None:
return bm.float_ if isinstance(self.mode, bm.TrainingMode) else bm.bool_
else:
return self._spk_type
return self._spk_dtype

def offset_scaling(self, x, bias=None, scale=None):
s = self.scaling.offset_scaling(x, bias=bias, scale=scale)
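The property added above decides the dtype of the spike variable: if the user passes no spk_dtype, a training mode gets a float dtype (so surrogate gradients can flow through spikes), and every other mode gets a boolean dtype. A minimal, self-contained sketch of that rule, using brainpy.math names that appear in the diff (NonBatchingMode in the comment is an assumption):

import brainpy.math as bm

def resolve_spk_dtype(spk_dtype, mode):
    # Mirrors the spk_dtype property in the base neuron class above.
    if spk_dtype is None:
        return bm.float_ if isinstance(mode, bm.TrainingMode) else bm.bool_
    return spk_dtype

# resolve_spk_dtype(None, bm.TrainingMode())     -> bm.float_ (differentiable spikes)
# resolve_spk_dtype(None, bm.NonBatchingMode())  -> bm.bool_  (plain boolean spikes)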
76 changes: 38 additions & 38 deletions brainpy/_src/dyn/neurons/lif.py
@@ -77,7 +77,7 @@ def __init__(
mode: Optional[bm.Mode] = None,
name: Optional[str] = None,
spk_fun: Callable = bm.surrogate.InvSquareGrad(),
spk_type: Any = None,
spk_dtype: Any = None,
spk_reset: str = 'soft',
detach_spk: bool = False,
method: str = 'exp_auto',
@@ -99,7 +99,7 @@ def __init__(
spk_fun=spk_fun,
detach_spk=detach_spk,
method=method,
spk_type=spk_type,
spk_dtype=spk_dtype,
spk_reset=spk_reset,
scaling=scaling)

@@ -124,7 +124,7 @@ def derivative(self, V, t, I):

def reset_state(self, batch_size=None, **kwargs):
self.V = self.offset_scaling(self.init_variable(self._V_initializer, batch_size))
self.spike = self.init_variable(partial(bm.zeros, dtype=self.spk_type), batch_size)
self.spike = self.init_variable(partial(bm.zeros, dtype=self.spk_dtype), batch_size)

def update(self, x=None):
t = share.load('t')
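The reset_state lines above use functools.partial to turn bm.zeros into a shape-only initializer with the dtype baked in, so init_variable only has to supply the shape. A small sketch of that pattern (the shape and dtype here are arbitrary illustration values):

from functools import partial
import brainpy.math as bm

init_spike = partial(bm.zeros, dtype=bm.bool_)  # dtype fixed, shape left open
spikes = init_spike((4, 10))                    # a (4, 10) array of False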
@@ -206,7 +206,7 @@ def __init__(
mode: Optional[bm.Mode] = None,
name: Optional[str] = None,
spk_fun: Callable = bm.surrogate.InvSquareGrad(),
spk_type: Any = None,
spk_dtype: Any = None,
spk_reset: str = 'soft',
detach_spk: bool = False,
method: str = 'exp_auto',
@@ -230,7 +230,7 @@ def __init__(
spk_fun=spk_fun,
detach_spk=detach_spk,
method=method,
spk_type=spk_type,
spk_dtype=spk_dtype,
spk_reset=spk_reset,
scaling=scaling)

@@ -257,7 +257,7 @@ def derivative(self, V, t, I):

def reset_state(self, batch_size=None, **kwargs):
self.V = self.offset_scaling(self.init_variable(self._V_initializer, batch_size))
self.spike = self.init_variable(partial(bm.zeros, dtype=self.spk_type), batch_size)
self.spike = self.init_variable(partial(bm.zeros, dtype=self.spk_dtype), batch_size)

def update(self, x=None):
t = share.load('t')
@@ -399,7 +399,7 @@ def __init__(
keep_size: bool = False,
mode: Optional[bm.Mode] = None,
spk_fun: Callable = bm.surrogate.InvSquareGrad(),
spk_type: Any = None,
spk_dtype: Any = None,
detach_spk: bool = False,
spk_reset: str = 'soft',
method: str = 'exp_auto',
@@ -429,7 +429,7 @@ def __init__(
sharding=sharding,
spk_fun=spk_fun,
detach_spk=detach_spk,
spk_type=spk_type,
spk_dtype=spk_dtype,
spk_reset=spk_reset,

init_var=False,
@@ -673,7 +673,7 @@ def __init__(
mode: Optional[bm.Mode] = None,
name: Optional[str] = None,
spk_fun: Callable = bm.surrogate.InvSquareGrad(),
spk_type: Any = None,
spk_dtype: Any = None,
spk_reset: str = 'soft',
detach_spk: bool = False,
method: str = 'exp_auto',
@@ -699,7 +699,7 @@ def __init__(
spk_fun=spk_fun,
detach_spk=detach_spk,
method=method,
spk_type=spk_type,
spk_dtype=spk_dtype,
spk_reset=spk_reset,
scaling=scaling)

@@ -730,7 +730,7 @@ def derivative(self, V, t, I):

def reset_state(self, batch_size=None, **kwargs):
self.V = self.offset_scaling(self.init_variable(self._V_initializer, batch_size))
self.spike = self.init_variable(partial(bm.zeros, dtype=self.spk_type), batch_size)
self.spike = self.init_variable(partial(bm.zeros, dtype=self.spk_dtype), batch_size)

def update(self, x=None):
t = share.load('t')
@@ -1001,7 +1001,7 @@ def __init__(
keep_size: bool = False,
mode: Optional[bm.Mode] = None,
spk_fun: Callable = bm.surrogate.InvSquareGrad(),
spk_type: Any = None,
spk_dtype: Any = None,
detach_spk: bool = False,
spk_reset: str = 'soft',
method: str = 'exp_auto',
@@ -1033,7 +1033,7 @@ def __init__(
sharding=sharding,
spk_fun=spk_fun,
detach_spk=detach_spk,
spk_type=spk_type,
spk_dtype=spk_dtype,
spk_reset=spk_reset,

init_var=False,
@@ -1343,7 +1343,7 @@ def __init__(
mode: Optional[bm.Mode] = None,
name: Optional[str] = None,
spk_fun: Callable = bm.surrogate.InvSquareGrad(),
spk_type: Any = None,
spk_dtype: Any = None,
spk_reset: str = 'soft',
detach_spk: bool = False,
method: str = 'exp_auto',
@@ -1373,7 +1373,7 @@ def __init__(
spk_fun=spk_fun,
detach_spk=detach_spk,
method=method,
spk_type=spk_type,
spk_dtype=spk_dtype,
spk_reset=spk_reset,
scaling=scaling)
# parameters
@@ -1416,7 +1416,7 @@ def derivative(self):
def reset_state(self, batch_size=None, **kwargs):
self.V = self.offset_scaling(self.init_variable(self._V_initializer, batch_size))
self.w = self.std_scaling(self.init_variable(self._w_initializer, batch_size))
self.spike = self.init_variable(partial(bm.zeros, dtype=self.spk_type), batch_size)
self.spike = self.init_variable(partial(bm.zeros, dtype=self.spk_dtype), batch_size)

def update(self, x=None):
t = share.load('t')
@@ -1672,7 +1672,7 @@ def __init__(
keep_size: bool = False,
mode: Optional[bm.Mode] = None,
spk_fun: Callable = bm.surrogate.InvSquareGrad(),
spk_type: Any = None,
spk_dtype: Any = None,
spk_reset: str = 'soft',
detach_spk: bool = False,
method: str = 'exp_auto',
@@ -1708,7 +1708,7 @@ def __init__(
sharding=sharding,
spk_fun=spk_fun,
detach_spk=detach_spk,
spk_type=spk_type,
spk_dtype=spk_dtype,
spk_reset=spk_reset,

init_var=False,
@@ -1991,7 +1991,7 @@ def __init__(
mode: Optional[bm.Mode] = None,
name: Optional[str] = None,
spk_fun: Callable = bm.surrogate.InvSquareGrad(),
spk_type: Any = None,
spk_dtype: Any = None,
spk_reset: str = 'soft',
detach_spk: bool = False,
method: str = 'exp_auto',
@@ -2017,7 +2017,7 @@ def __init__(
spk_fun=spk_fun,
detach_spk=detach_spk,
method=method,
spk_type=spk_type,
spk_dtype=spk_dtype,
spk_reset=spk_reset,
scaling=scaling)
# parameters
@@ -2046,7 +2046,7 @@ def derivative(self, V, t, I):

def reset_state(self, batch_size=None, **kwargs):
self.V = self.offset_scaling(self.init_variable(self._V_initializer, batch_size))
self.spike = self.init_variable(partial(bm.zeros, dtype=self.spk_type), batch_size)
self.spike = self.init_variable(partial(bm.zeros, dtype=self.spk_dtype), batch_size)

def update(self, x=None):
t = share.load('t')
@@ -2255,7 +2255,7 @@ def __init__(
keep_size: bool = False,
mode: Optional[bm.Mode] = None,
spk_fun: Callable = bm.surrogate.InvSquareGrad(),
spk_type: Any = None,
spk_dtype: Any = None,
spk_reset: str = 'soft',
detach_spk: bool = False,
method: str = 'exp_auto',
@@ -2287,7 +2287,7 @@ def __init__(
sharding=sharding,
spk_fun=spk_fun,
detach_spk=detach_spk,
spk_type=spk_type,
spk_dtype=spk_dtype,
spk_reset=spk_reset,

init_var=False,
@@ -2554,7 +2554,7 @@ def __init__(
mode: Optional[bm.Mode] = None,
name: Optional[str] = None,
spk_fun: Callable = bm.surrogate.InvSquareGrad(),
spk_type: Any = None,
spk_dtype: Any = None,
spk_reset: str = 'soft',
detach_spk: bool = False,
method: str = 'exp_auto',
@@ -2583,7 +2583,7 @@ def __init__(
spk_fun=spk_fun,
detach_spk=detach_spk,
method=method,
spk_type=spk_type,
spk_dtype=spk_dtype,
spk_reset=spk_reset,
scaling=scaling)
# parameters
@@ -2624,7 +2624,7 @@ def derivative(self):
def reset_state(self, batch_size=None, **kwargs):
self.V = self.offset_scaling(self.init_variable(self._V_initializer, batch_size))
self.w = self.std_scaling(self.init_variable(self._w_initializer, batch_size))
self.spike = self.init_variable(partial(bm.zeros, dtype=self.spk_type), batch_size)
self.spike = self.init_variable(partial(bm.zeros, dtype=self.spk_dtype), batch_size)

def update(self, x=None):
t = share.load('t')
@@ -2856,7 +2856,7 @@ def __init__(
keep_size: bool = False,
mode: Optional[bm.Mode] = None,
spk_fun: Callable = bm.surrogate.InvSquareGrad(),
spk_type: Any = None,
spk_dtype: Any = None,
spk_reset: str = 'soft',
detach_spk: bool = False,
method: str = 'exp_auto',
@@ -2891,7 +2891,7 @@ def __init__(
sharding=sharding,
spk_fun=spk_fun,
detach_spk=detach_spk,
spk_type=spk_type,
spk_dtype=spk_dtype,
spk_reset=spk_reset,

init_var=False,
@@ -3201,7 +3201,7 @@ def __init__(
mode: Optional[bm.Mode] = None,
name: Optional[str] = None,
spk_fun: Callable = bm.surrogate.InvSquareGrad(),
spk_type: Any = None,
spk_dtype: Any = None,
spk_reset: str = 'soft',
detach_spk: bool = False,
method: str = 'exp_auto',
@@ -3237,7 +3237,7 @@ def __init__(
spk_fun=spk_fun,
detach_spk=detach_spk,
method=method,
spk_type=spk_type,
spk_dtype=spk_dtype,
spk_reset=spk_reset,
scaling=scaling)
# parameters
@@ -3291,7 +3291,7 @@ def reset_state(self, batch_size=None, **kwargs):
self.V_th = self.offset_scaling(self.init_variable(self._Vth_initializer, batch_size))
self.I1 = self.std_scaling(self.init_variable(self._I1_initializer, batch_size))
self.I2 = self.std_scaling(self.init_variable(self._I2_initializer, batch_size))
self.spike = self.init_variable(partial(bm.zeros, dtype=self.spk_type), batch_size)
self.spike = self.init_variable(partial(bm.zeros, dtype=self.spk_dtype), batch_size)

def update(self, x=None):
t = share.load('t')
@@ -3581,7 +3581,7 @@ def __init__(
keep_size: bool = False,
mode: Optional[bm.Mode] = None,
spk_fun: Callable = bm.surrogate.InvSquareGrad(),
spk_type: Any = None,
spk_dtype: Any = None,
spk_reset: str = 'soft',
detach_spk: bool = False,
method: str = 'exp_auto',
@@ -3623,7 +3623,7 @@ def __init__(
sharding=sharding,
spk_fun=spk_fun,
detach_spk=detach_spk,
spk_type=spk_type,
spk_dtype=spk_dtype,
spk_reset=spk_reset,

init_var=False,
@@ -3952,7 +3952,7 @@ def __init__(
mode: Optional[bm.Mode] = None,
name: Optional[str] = None,
spk_fun: Callable = bm.surrogate.InvSquareGrad(),
spk_type: Any = None,
spk_dtype: Any = None,
spk_reset: str = 'soft',
detach_spk: bool = False,
method: str = 'exp_auto',
@@ -3982,7 +3982,7 @@ def __init__(
spk_fun=spk_fun,
detach_spk=detach_spk,
method=method,
spk_type=spk_type,
spk_dtype=spk_dtype,
spk_reset=spk_reset,
scaling=scaling)
# parameters
@@ -4031,7 +4031,7 @@ def reset_state(self, batch_size=None, **kwargs):
self.V = self.offset_scaling(self.V)
self.u = self.offset_scaling(self.init_variable(self._u_initializer, batch_size), bias=self.b * self.scaling.bias,
scale=self.scaling.scale)
self.spike = self.init_variable(partial(bm.zeros, dtype=self.spk_type), batch_size)
self.spike = self.init_variable(partial(bm.zeros, dtype=self.spk_dtype), batch_size)

def update(self, x=None):
t = share.load('t')
@@ -4266,7 +4266,7 @@ def __init__(
keep_size: bool = False,
mode: Optional[bm.Mode] = None,
spk_fun: Callable = bm.surrogate.InvSquareGrad(),
spk_type: Any = None,
spk_dtype: Any = None,
spk_reset: str = 'soft',
detach_spk: bool = False,
method: str = 'exp_auto',
@@ -4302,7 +4302,7 @@ def __init__(
sharding=sharding,
spk_fun=spk_fun,
detach_spk=detach_spk,
spk_type=spk_type,
spk_dtype=spk_dtype,
spk_reset=spk_reset,

init_var=False,
2 changes: 2 additions & 0 deletions brainpy/_src/math/ndarray.py
@@ -1518,6 +1518,8 @@ def float(self): return jnp.asarray(self.value, dtype=jnp.float32)
def double(self): return jnp.asarray(self.value, dtype=jnp.float64)


setattr(Array, "__array_priority__", 100)

JaxArray = Array
ndarray = Array

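The single functional line added to ndarray.py gives BrainPy's Array a high __array_priority__. Under NumPy's binary-operator protocol, when an ndarray is combined with a non-ndarray operand that has a higher priority and defines the reflected method, the ndarray returns NotImplemented and Python dispatches to that operand, so mixed expressions such as numpy_array + brainpy_array stay under the Array class's control. A generic, BrainPy-independent sketch of that mechanism (the Wrapped class is purely illustrative):

import numpy as np

class Wrapped:
    __array_priority__ = 100   # higher than ndarray's default priority

    def __init__(self, value):
        self.value = np.asarray(value)

    def __radd__(self, other):
        # ndarray.__add__ gives up for a higher-priority operand that
        # defines __radd__, so Python dispatches the addition here.
        return Wrapped(np.asarray(other) + self.value)

    def __repr__(self):
        return f"Wrapped({self.value!r})"

print(np.arange(3) + Wrapped([10, 10, 10]))   # Wrapped(array([10, 11, 12]))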