diff --git a/snntorch/_neurons/alpha.py b/snntorch/_neurons/alpha.py index 2d339b87..4b0b0c91 100644 --- a/snntorch/_neurons/alpha.py +++ b/snntorch/_neurons/alpha.py @@ -1,7 +1,7 @@ import torch import torch.nn as nn -from .neurons import _SpikeTensor, _SpikeTorchConv, LIF +from .neurons import LIF class Alpha(LIF): @@ -119,118 +119,86 @@ def __init__( self._alpha_register_buffer(alpha, learn_alpha) self._alpha_cases() - if self.init_hidden: - self.syn_exc, self.syn_inh, self.mem = self.init_alpha() + self._init_mem() - # if reset_mechanism == "subtract": - # self.mem_residual = False + if self.reset_mechanism_val == 0: # reset by subtraction + self.state_function = self._base_sub + elif self.reset_mechanism_val == 1: # reset to zero + self.state_function = self._base_zero + elif self.reset_mechanism_val == 2: # no reset, pure integration + self.state_function = self._base_int - def forward(self, input_, syn_exc=False, syn_inh=False, mem=False): + def _init_mem(self): + syn_exc = torch.zeros(1) + syn_inh = torch.zeros(1) + mem = torch.zeros(1) - if ( - hasattr(syn_exc, "init_flag") - or hasattr(syn_inh, "init_flag") - or hasattr(mem, "init_flag") - ): # only triggered on first-pass - syn_exc, syn_inh, mem = _SpikeTorchConv( - syn_exc, syn_inh, mem, input_=input_ - ) - elif mem is False and hasattr( - self.mem, "init_flag" - ): # init_hidden case - self.syn_exc, self.syn_inh, self.mem = _SpikeTorchConv( - self.syn_exc, self.syn_inh, self.mem, input_=input_ - ) + self.register_buffer("syn_exc", syn_exc) + self.register_buffer("syn_inh", syn_inh) + self.register_buffer("mem", mem) - # if hidden states are passed externally - if not self.init_hidden: - self.reset = self.mem_reset(mem) - syn_exc, syn_inh, mem = self._build_state_function( - input_, syn_exc, syn_inh, mem - ) + def reset_mem(self): + self.syn_exc = torch.zeros_like( + self.syn_exc, device=self.syn_exc.device + ) + self.syn_inh = torch.zeros_like( + self.syn_inh, device=self.syn_inh.device + ) + self.mem = torch.zeros_like(self.mem, device=self.mem.device) - if self.state_quant: - syn_exc = self.state_quant(syn_exc) - syn_inh = self.state_quant(syn_inh) - mem = self.state_quant(mem) + def init_alpha(self): + """Deprecated, use :class:`Alpha.reset_mem` instead""" + self.reset_mem() + return self.syn_exc, self.syn_inh, self.mem - if self.inhibition: - spk = self.fire_inhibition(mem.size(0), mem) + def forward(self, input_, syn_exc=None, syn_inh=None, mem=None): - else: - spk = self.fire(mem) + if not syn_exc == None: + self.syn_exc = syn_exc - return spk, syn_exc, syn_inh, mem + if not syn_inh == None: + self.syn_inh = syn_inh - # if hidden states and outputs are instance variables - if self.init_hidden: - self._alpha_forward_cases(mem, syn_exc, syn_inh) + if not mem == None: + self.mem = mem - self.reset = self.mem_reset(self.mem) - ( - self.syn_exc, - self.syn_inh, - self.mem, - ) = self._build_state_function_hidden(input_) + if self.init_hidden and ( + not mem == None or not syn_exc == None or not syn_inh == None + ): + raise TypeError( + "When `init_hidden=True`, Alpha expects 1 input argument." 
+ ) - if self.state_quant: - self.syn_exc = self.state_quant(self.syn_exc) - self.syn_inh = self.state_quant(self.syn_inh) - self.mem = self.state_quant(self.mem) + if not self.syn_exc.shape == input_.shape: + self.syn_exc = torch.zeros_like(input_, device=self.syn_exc.device) - if self.inhibition: - self.spk = self.fire_inhibition(self.mem.size(0), self.mem) + if not self.syn_inh.shape == input_.shape: + self.syn_inh = torch.zeros_like(input_, device=self.syn_inh.device) - else: - self.spk = self.fire(self.mem) + if not self.mem.shape == input_.shape: + self.mem = torch.zeros_like(input_, device=self.mem.device) - if self.output: - return self.spk, self.syn_exc, self.syn_inh, self.mem - else: - return self.spk + self.reset = self.mem_reset(self.mem) + self.syn_exc, self.syn_inh, self.mem = self.state_function(input_) - def _base_state_function(self, input_, syn_exc, syn_inh, mem): - base_fn_syn_exc = self.alpha.clamp(0, 1) * syn_exc + input_ - base_fn_syn_inh = self.beta.clamp(0, 1) * syn_inh - input_ - tau_alpha = ( - torch.log(self.alpha.clamp(0, 1)) - / ( - torch.log(self.beta.clamp(0, 1)) - - torch.log(self.alpha.clamp(0, 1)) - ) - + 1 - ) - base_fn_mem = tau_alpha * (base_fn_syn_exc + base_fn_syn_inh) - return base_fn_syn_exc, base_fn_syn_inh, base_fn_mem + if self.state_quant: + self.syn_exc = self.state_quant(self.syn_exc) + self.syn_inh = self.state_quant(self.syn_inh) + self.mem = self.state_quant(self.mem) - def _base_state_reset_sub_function(self, input_, syn_inh): - syn_exc_reset = self.threshold - syn_inh_reset = self.beta.clamp(0, 1) * syn_inh - input_ - mem_reset = 0 - return syn_exc_reset, syn_inh_reset, mem_reset + if self.inhibition: + spk = self.fire_inhibition(self.mem.size(0), self.mem) + else: + spk = self.fire(self.mem) - def _build_state_function(self, input_, syn_exc, syn_inh, mem): - if self.reset_mechanism_val == 0: - state_fn = tuple( - map( - lambda x, y: x - self.reset * y, - self._base_state_function(input_, syn_exc, syn_inh, mem), - self._base_state_reset_sub_function(input_, syn_inh), - ) - ) - elif self.reset_mechanism_val == 1: # reset to zero - state_fn = tuple( - map( - lambda x, y: x - self.reset * y, - self._base_state_function(input_, syn_exc, syn_inh, mem), - self._base_state_function(input_, syn_exc, syn_inh, mem), - ) - ) - elif self.reset_mechanism_val == 2: # no reset, pure integration - state_fn = self._base_state_function(input_, syn_exc, syn_inh, mem) - return state_fn + if self.output: + return spk, self.syn_exc, self.syn_inh, self.mem + elif self.init_hidden: + return spk + else: + return spk, self.syn_exc, self.syn_inh, self.mem - def _base_state_function_hidden(self, input_): + def _base_state_function(self, input_): base_fn_syn_exc = self.alpha.clamp(0, 1) * self.syn_exc + input_ base_fn_syn_inh = self.beta.clamp(0, 1) * self.syn_inh - input_ tau_alpha = ( @@ -244,32 +212,34 @@ def _base_state_function_hidden(self, input_): base_fn_mem = tau_alpha * (base_fn_syn_exc + base_fn_syn_inh) return base_fn_syn_exc, base_fn_syn_inh, base_fn_mem - def _base_state_reset_sub_function_hidden(self, input_): + def _base_state_reset_sub_function(self, input_): syn_exc_reset = self.threshold syn_inh_reset = self.beta.clamp(0, 1) * self.syn_inh - input_ mem_reset = -self.syn_inh return syn_exc_reset, syn_inh_reset, mem_reset - def _build_state_function_hidden(self, input_): - if self.reset_mechanism_val == 0: # reset by subtraction - state_fn = tuple( - map( - lambda x, y: x - self.reset * y, - self._base_state_function_hidden(input_), - 
self._base_state_reset_sub_function_hidden(input_), - ) - ) - elif self.reset_mechanism_val == 1: # reset to zero - state_fn = tuple( - map( - lambda x, y: x - self.reset * y, - self._base_state_function_hidden(input_), - self._base_state_function_hidden(input_), - ) - ) - elif self.reset_mechanism_val == 2: # no reset, pure integration - state_fn = self._base_state_function_hidden(input_) - return state_fn + def _base_sub(self, input_): + syn_exec, syn_inh, mem = self._base_state_function(input_) + syn_exec2, syn_inh2, mem2 = self._base_state_reset_sub_function(input_) + + syn_exec -= syn_exec2 * self.reset + syn_inh -= syn_inh2 * self.reset + mem -= mem2 * self.reset + + return syn_exec, syn_inh, mem + + def _base_zero(self, input_): + syn_exec, syn_inh, mem = self._base_state_function(input_) + syn_exec2, syn_inh2, mem2 = self._base_state_function(input_) + + syn_exec -= syn_exec2 * self.reset + syn_inh -= syn_inh2 * self.reset + mem -= mem2 * self.reset + + return syn_exec, syn_inh, mem + + def _base_int(self, input_): + return self._base_state_function(input_) def _alpha_register_buffer(self, alpha, learn_alpha): if not isinstance(alpha, torch.Tensor): @@ -291,12 +261,6 @@ def _alpha_cases(self): "tau_alpha = log(alpha)/log(beta) - log(alpha) + 1" ) - def _alpha_forward_cases(self, mem, syn_exc, syn_inh): - if mem is not False or syn_exc is not False or syn_inh is not False: - raise TypeError( - "When `init_hidden=True`, Alpha expects 1 input argument." - ) - @classmethod def detach_hidden(cls): """Used to detach hidden states from the current graph. @@ -315,6 +279,15 @@ def reset_hidden(cls): variables.""" for layer in range(len(cls.instances)): if isinstance(cls.instances[layer], Alpha): - cls.instances[layer].syn_exc = _SpikeTensor(init_flag=False) - cls.instances[layer].syn_inh = _SpikeTensor(init_flag=False) - cls.instances[layer].mem = _SpikeTensor(init_flag=False) + cls.instances[layer].syn_exc = torch.zeros_like( + cls.instances[layer].syn_exc, + device=cls.instances[layer].syn_exc.device, + ) + cls.instances[layer].syn_inh = torch.zeros_like( + cls.instances[layer].syn_inh, + device=cls.instances[layer].syn_inh.device, + ) + cls.instances[layer].mem = torch.zeros_like( + cls.instances[layer].mem, + device=cls.instances[layer].mem.device, + ) diff --git a/snntorch/_neurons/lapicque.py b/snntorch/_neurons/lapicque.py index bacbfc52..cacec197 100644 --- a/snntorch/_neurons/lapicque.py +++ b/snntorch/_neurons/lapicque.py @@ -1,5 +1,5 @@ import torch -from .neurons import _SpikeTensor, _SpikeTorchConv, LIF +from .neurons import LIF class Lapicque(LIF): @@ -216,93 +216,77 @@ def __init__( self._lapicque_cases(time_step, beta, R, C) - if self.init_hidden: - self.mem = self.init_lapicque() + self._init_mem() - def forward(self, input_, mem=False): + if self.reset_mechanism_val == 0: # reset by subtraction + self.state_function = self._base_sub + elif self.reset_mechanism_val == 1: # reset to zero + self.state_function = self._base_zero + elif self.reset_mechanism_val == 2: # no reset, pure integration + self.state_function = self._base_int - if hasattr(mem, "init_flag"): # only triggered on first-pass - mem = _SpikeTorchConv(mem, input_=input_) - elif mem is False and hasattr( - self.mem, "init_flag" - ): # init_hidden case - self.mem = _SpikeTorchConv(self.mem, input_=input_) + def _init_mem(self): + mem = torch.zeros(1) + self.register_buffer("mem", mem) - if not self.init_hidden: - self.reset = self.mem_reset(mem) - mem = self._build_state_function(input_, mem) + def 
reset_mem(self): + self.mem = torch.zeros_like(self.mem, device=self.mem.device) - if self.state_quant: - mem = self.state_quant(mem) + def init_lapicque(self): + """Deprecated, use :class:`Lapicque.reset_mem` instead""" + self.reset_mem() + return self.mem - if self.inhibition: - spk = self.fire_inhibition(mem.size(0), mem) - else: - spk = self.fire(mem) + def forward(self, input_, mem=None): - return spk, mem + if not mem == None: + self.mem = mem - # intended for truncated-BPTT where instance variables are hidden - # states - if self.init_hidden: - self._lapicque_forward_cases(mem) - self.reset = self.mem_reset(self.mem) - self.mem = self._build_state_function_hidden(input_) + if self.init_hidden and not mem == None: + raise TypeError( + "`mem` should not be passed as an argument while `init_hidden=True`" + ) - if self.state_quant: - self.mem = self.state_quant(self.mem) + if not self.mem.shape == input_.shape: + self.mem = torch.zeros_like(input_, device=self.mem.device) - if self.inhibition: - self.spk = self.fire_inhibition(self.mem.size(0), self.mem) - else: - self.spk = self.fire(self.mem) + self.reset = self.mem_reset(self.mem) + self.mem = self.state_function(input_) - if self.output: - return self.spk, self.mem - else: - return self.spk + if self.state_quant: + self.mem = self.state_quant(self.mem) - def _base_state_function(self, input_, mem): - base_fn = ( - input_ * self.R * (1 / (self.R * self.C)) * self.time_step - + (1 - (self.time_step / (self.R * self.C))) * mem - ) - return base_fn + if self.inhibition: + spk = self.fire_inhibition( + self.mem.size(0), self.mem + ) # batch_size + else: + spk = self.fire(self.mem) - def _build_state_function(self, input_, mem): - if self.reset_mechanism_val == 0: # reset by subtraction - state_fn = ( - self._base_state_function(input_, mem) - - self.reset * self.threshold - ) - elif self.reset_mechanism_val == 1: # reset to zero - state_fn = self._base_state_function( - input_, mem - ) - self.reset * self._base_state_function(input_, mem) - elif self.reset_mechanism_val == 2: # no reset, pure integration - state_fn = self._base_state_function(input_, mem) - return state_fn + if self.output: + return spk, self.mem + elif self.init_hidden: + return spk + else: + return spk, self.mem - def _base_state_function_hidden(self, input_): + def _base_state_function(self, input_): base_fn = ( input_ * self.R * (1 / (self.R * self.C)) * self.time_step + (1 - (self.time_step / (self.R * self.C))) * self.mem ) return base_fn - def _build_state_function_hidden(self, input_): - if self.reset_mechanism_val == 0: # reset by subtraction - state_fn = ( - self._base_state_function_hidden(input_) - - self.reset * self.threshold - ) - elif self.reset_mechanism_val == 1: # reset to zero - state_fn = self._base_state_function_hidden( - input_ - ) - self.reset * self._base_state_function_hidden(input_) - elif self.reset_mechanism_val == 2: # no reset, pure integration - state_fn = self._base_state_function_hidden(input_) - return state_fn + def _base_sub(self, input_): + return self._base_state_function(input_) - self.reset * self.threshold + + def _base_zero(self, input_): + return self._base_state_function( + input_ + ) - self.reset * self._base_state_function(input_) + + def _base_int(self, input_): + return self._base_state_function(input_) def _lapicque_cases(self, time_step, beta, R, C): if not isinstance(time_step, torch.Tensor): @@ -357,12 +341,6 @@ def _lapicque_cases(self, time_step, beta, R, C): R = self.time_step / (C * torch.log(1 / self.beta)) 
self.register_buffer("R", R) - def _lapicque_forward_cases(self, mem): - if mem is not False: - raise TypeError( - "When `init_hidden=True`, Lapicque expects 1 input argument." - ) - @classmethod def detach_hidden(cls): """Returns the hidden states, detached from the current graph. @@ -381,4 +359,7 @@ def reset_hidden(cls): for layer in range(len(cls.instances)): if isinstance(cls.instances[layer], Lapicque): - cls.instances[layer].mem = _SpikeTensor(init_flag=False) + cls.instances[layer].mem = torch.zeros_like( + cls.instances[layer].mem, + device=cls.instances[layer].mem.device, + ) diff --git a/snntorch/_neurons/leaky.py b/snntorch/_neurons/leaky.py index 7406fa82..75fe4fd3 100644 --- a/snntorch/_neurons/leaky.py +++ b/snntorch/_neurons/leaky.py @@ -2,6 +2,7 @@ import torch from torch import nn + class Leaky(LIF): """ First-order leaky integrate-and-fire neuron model. @@ -170,7 +171,6 @@ def __init__( ) self._init_mem() - self.init_hidden = init_hidden if self.reset_mechanism_val == 0: # reset by subtraction self.state_function = self._base_sub @@ -178,12 +178,8 @@ def __init__( self.state_function = self._base_zero elif self.reset_mechanism_val == 2: # no reset, pure integration self.state_function = self._base_int - - self.reset_delay = reset_delay - - if not self.reset_delay and self.init_hidden: - raise NotImplementedError("`reset_delay=True` is only supported for `init_hidden=False`") + self.reset_delay = reset_delay def _init_mem(self): mem = torch.zeros(1) @@ -196,12 +192,12 @@ def init_leaky(self): """Deprecated, use :class:`Leaky.reset_mem` instead""" self.reset_mem() return self.mem - + def forward(self, input_, mem=None): if not mem == None: self.mem = mem - + if self.init_hidden and not mem == None: raise TypeError( "`mem` should not be passed as an argument while `init_hidden=True`" @@ -217,12 +213,16 @@ def forward(self, input_, mem=None): self.mem = self.state_quant(self.mem) if self.inhibition: - spk = self.fire_inhibition(self.mem.size(0), self.mem) # batch_size + spk = self.fire_inhibition( + self.mem.size(0), self.mem + ) # batch_size else: spk = self.fire(self.mem) - + if not self.reset_delay: - do_reset = spk / self.graded_spikes_factor - self.reset # avoid double reset + do_reset = ( + spk / self.graded_spikes_factor - self.reset + ) # avoid double reset if self.reset_mechanism_val == 0: # reset by subtraction self.mem = self.mem - do_reset * self.threshold elif self.reset_mechanism_val == 1: # reset to zero diff --git a/snntorch/_neurons/neurons.py b/snntorch/_neurons/neurons.py index ee58079c..18fa4216 100644 --- a/snntorch/_neurons/neurons.py +++ b/snntorch/_neurons/neurons.py @@ -1,6 +1,5 @@ -import inspect from warnings import warn -from snntorch.surrogate import StraightThroughEstimator, atan, straight_through_estimator +from snntorch.surrogate import atan import torch import torch.nn as nn @@ -8,8 +7,6 @@ __all__ = [ "SpikingNeuron", "LIF", - "_SpikeTensor", - "_SpikeTorchConv", ] dtype = torch.float @@ -234,6 +231,7 @@ def zeros(*args): def _surrogate_bypass(input_): return (input_ > 0).float() + class LIF(SpikingNeuron): """Parent class for leaky integrate and fire neuron models.""" @@ -298,107 +296,3 @@ def _V_register_buffer(self, V, learn_V): self.V = nn.Parameter(V) else: self.register_buffer("V", V) - - @staticmethod - def init_rleaky(): - """ - Used to initialize spk and mem as an empty SpikeTensor. - ``init_flag`` is used as an attribute in the forward pass to convert - the hidden states to the same as the input. 
- """ - spk = _SpikeTensor(init_flag=False) - mem = _SpikeTensor(init_flag=False) - - return spk, mem - - @staticmethod - def init_synaptic(): - """Used to initialize syn and mem as an empty SpikeTensor. - ``init_flag`` is used as an attribute in the forward pass to convert - the hidden states to the same as the input. - """ - - syn = _SpikeTensor(init_flag=False) - mem = _SpikeTensor(init_flag=False) - - return syn, mem - - @staticmethod - def init_rsynaptic(): - """ - Used to initialize spk, syn and mem as an empty SpikeTensor. - ``init_flag`` is used as an attribute in the forward pass to convert - the hidden states to the same as the input. - """ - spk = _SpikeTensor(init_flag=False) - syn = _SpikeTensor(init_flag=False) - mem = _SpikeTensor(init_flag=False) - - return spk, syn, mem - - @staticmethod - def init_lapicque(): - """ - Used to initialize mem as an empty SpikeTensor. - ``init_flag`` is used as an attribute in the forward pass to convert - the hidden states to the same as the input. - """ - - mem = _SpikeTensor(init_flag=False) - - return mem - - @staticmethod - def init_alpha(): - """Used to initialize syn_exc, syn_inh and mem as an empty SpikeTensor. - ``init_flag`` is used as an attribute in the forward pass to convert - the hidden states to the same as the input. - """ - syn_exc = _SpikeTensor(init_flag=False) - syn_inh = _SpikeTensor(init_flag=False) - mem = _SpikeTensor(init_flag=False) - - return syn_exc, syn_inh, mem - - -class _SpikeTensor(torch.Tensor): - """Inherits from torch.Tensor with additional attributes. - ``init_flag`` is set at the time of initialization. - When called in the forward function of any neuron, they are parsed and - replaced with a torch.Tensor variable. - """ - - @staticmethod - def __new__(cls, *args, init_flag=False, **kwargs): - return super().__new__(cls, *args, **kwargs) - - def __init__( - self, - *args, - init_flag=True, - ): - # super().__init__() # optional - self.init_flag = init_flag - - -def _SpikeTorchConv(*args, input_): - """Convert SpikeTensor to torch.Tensor of the same size as ``input_``.""" - - states = [] - # if len(input_.size()) == 0: - # _batch_size = 1 # assume batch_size=1 if 1D input - # else: - # _batch_size = input_.size(0) - if ( - len(args) == 1 and type(args) is not tuple - ): # if only one hidden state, make it iterable - args = (args,) - for arg in args: - arg = arg.to("cpu") - arg = torch.Tensor(arg) # wash away the SpikeTensor class - arg = torch.zeros_like(input_, requires_grad=True) - states.append(arg) - if len(states) == 1: # otherwise, list isn't unpacked - return states[0] - - return states diff --git a/snntorch/_neurons/rleaky.py b/snntorch/_neurons/rleaky.py index 48383ada..66046bc0 100644 --- a/snntorch/_neurons/rleaky.py +++ b/snntorch/_neurons/rleaky.py @@ -2,7 +2,7 @@ import torch.nn as nn # from torch import functional as F -from .neurons import _SpikeTensor, _SpikeTorchConv, LIF +from .neurons import LIF class RLeaky(LIF): @@ -280,72 +280,81 @@ def __init__( if not learn_recurrent: self._disable_recurrent_grad() + self._init_mem() + + if self.reset_mechanism_val == 0: # reset by subtraction + self.state_function = self._base_sub + elif self.reset_mechanism_val == 1: # reset to zero + self.state_function = self._base_zero + elif self.reset_mechanism_val == 2: # no reset, pure integration + self.state_function = self._base_int + self.reset_delay = reset_delay - if not self.reset_delay and self.init_hidden: - raise NotImplementedError('no reset_delay only supported for init_hidden=False') - - if 
self.init_hidden: - self.spk, self.mem = self.init_rleaky() - # self.state_fn = self._build_state_function_hidden - # else: - # self.state_fn = self._build_state_function - - def forward(self, input_, spk=False, mem=False): - if hasattr(spk, "init_flag") or hasattr( - mem, "init_flag" - ): # only triggered on first-pass - spk, mem = _SpikeTorchConv(spk, mem, input_=input_) - # init_hidden case - elif mem is False and hasattr(self.mem, "init_flag"): - self.spk, self.mem = _SpikeTorchConv( - self.spk, self.mem, input_=input_ - ) + def _init_mem(self): + spk = torch.zeros(1) + mem = torch.zeros(1) + self.register_buffer("spk", spk) + self.register_buffer("mem", mem) - # TO-DO: alternatively, we could do torch.exp(-1 / - # self.beta.clamp_min(0)), giving actual time constants instead of - # values in [0, 1] as initial beta beta = self.beta.clamp(0, 1) + def reset_mem(self): + self.spk = torch.zeros_like(self.spk, device=self.spk.device) + self.mem = torch.zeros_like(self.mem, device=self.mem.device) - if not self.init_hidden: - self.reset = self.mem_reset(mem) - mem = self._build_state_function(input_, spk, mem) + def init_rleaky(self): + """Deprecated, use :class:`RLeaky.reset_mem` instead""" + self.reset_mem() + return self.spk, self.mem - if self.state_quant: - mem = self.state_quant(mem) + def forward(self, input_, spk=None, mem=None): - if self.inhibition: - spk = self.fire_inhibition(mem.size(0), mem) # batch_size - else: - spk = self.fire(mem) + if not spk == None: + self.spk = spk - if not self.reset_delay: - do_reset = spk / self.graded_spikes_factor - self.reset # avoid double reset - if self.reset_mechanism_val == 0: # reset by subtraction - mem = mem - do_reset * self.threshold - elif self.reset_mechanism_val == 1: # reset to zero - mem = mem - do_reset * mem + if not mem == None: + self.mem = mem - return spk, mem + if self.init_hidden and (not mem == None or not spk == None): + raise TypeError( + "When `init_hidden=True`," "RLeaky expects 1 input argument." 
+ ) + + if not self.spk.shape == input_.shape: + self.spk = torch.zeros_like(input_, device=self.spk.device) - # intended for truncated-BPTT where instance variables are hidden - # states - if self.init_hidden: - self._rleaky_forward_cases(spk, mem) - self.reset = self.mem_reset(self.mem) - self.mem = self._build_state_function_hidden(input_) + if not self.mem.shape == input_.shape: + self.mem = torch.zeros_like(input_, device=self.mem.device) + + # TO-DO: alternatively, we could do torch.exp(-1 / + # self.beta.clamp_min(0)), giving actual time constants instead of + # values in [0, 1] as initial beta beta = self.beta.clamp(0, 1) - if self.state_quant: - self.mem = self.state_quant(self.mem) + self.reset = self.mem_reset(self.mem) + self.mem = self.state_function(input_) - if self.inhibition: - self.spk = self.fire_inhibition(self.mem.size(0), self.mem) - else: - self.spk = self.fire(self.mem) + if self.state_quant: + self.mem = self.state_quant(self.mem) - if self.output: # read-out layer returns output+states - return self.spk, self.mem - else: # hidden layer e.g., in nn.Sequential, only returns output - return self.spk + if self.inhibition: + self.spk = self.fire_inhibition(self.mem.size(0), self.mem) + else: + self.spk = self.fire(self.mem) + + if not self.reset_delay: + do_reset = ( + self.spk / self.graded_spikes_factor - self.reset + ) # avoid double reset + if self.reset_mechanism_val == 0: # reset by subtraction + self.mem = self.mem - do_reset * self.threshold + elif self.reset_mechanism_val == 1: # reset to zero + self.mem = self.mem - do_reset * self.mem + + if self.output: + return self.spk, self.mem + elif self.init_hidden: + return self.spk + else: + return self.spk, self.mem def _init_recurrent_net(self): if self.all_to_all: @@ -381,24 +390,7 @@ def _disable_recurrent_grad(self): for param in self.recurrent.parameters(): param.requires_grad = False - def _base_state_function(self, input_, spk, mem): - base_fn = self.beta.clamp(0, 1) * mem + input_ + self.recurrent(spk) - return base_fn - - def _build_state_function(self, input_, spk, mem): - if self.reset_mechanism_val == 0: # reset by subtraction - state_fn = self._base_state_function( - input_, spk, mem - self.reset * self.threshold - ) - elif self.reset_mechanism_val == 1: # reset to zero - state_fn = self._base_state_function( - input_, spk, mem - ) - self.reset * self._base_state_function(input_, spk, mem) - elif self.reset_mechanism_val == 2: # no reset, pure integration - state_fn = self._base_state_function(input_, spk, mem) - return state_fn - - def _base_state_function_hidden(self, input_): + def _base_state_function(self, input_): base_fn = ( self.beta.clamp(0, 1) * self.mem + input_ @@ -406,25 +398,16 @@ def _base_state_function_hidden(self, input_): ) return base_fn - def _build_state_function_hidden(self, input_): - if self.reset_mechanism_val == 0: # reset by subtraction - state_fn = ( - self._base_state_function_hidden(input_) - - self.reset * self.threshold - ) - elif self.reset_mechanism_val == 1: # reset to zero - state_fn = self._base_state_function_hidden( - input_ - ) - self.reset * self._base_state_function_hidden(input_) - elif self.reset_mechanism_val == 2: # no reset, pure integration - state_fn = self._base_state_function_hidden(input_) - return state_fn + def _base_sub(self, input_): + return self._base_state_function(input_) - self.reset * self.threshold - def _rleaky_forward_cases(self, spk, mem): - if mem is not False or spk is not False: - raise TypeError( - "When `init_hidden=True`," "RLeaky 
expects 1 input argument." - ) + def _base_zero(self, input_): + return self._base_state_function( + input_ + ) - self.reset * self._base_state_function(input_) + + def _base_int(self, input_): + return self._base_state_function(input_) def _rleaky_init_cases(self): all_to_all_bool = bool(self.all_to_all) diff --git a/snntorch/_neurons/rsynaptic.py b/snntorch/_neurons/rsynaptic.py index 654fd3d2..0020a3ff 100644 --- a/snntorch/_neurons/rsynaptic.py +++ b/snntorch/_neurons/rsynaptic.py @@ -1,6 +1,6 @@ import torch import torch.nn as nn -from .neurons import _SpikeTensor, _SpikeTorchConv, LIF +from .neurons import LIF class RSynaptic(LIF): @@ -295,72 +295,89 @@ def __init__( self._alpha_register_buffer(alpha, learn_alpha) + self._init_mem() + + if self.reset_mechanism_val == 0: # reset by subtraction + self.state_function = self._base_sub + elif self.reset_mechanism_val == 1: # reset to zero + self.state_function = self._base_zero + elif self.reset_mechanism_val == 2: # no reset, pure integration + self.state_function = self._base_int + self.reset_delay = reset_delay - if not reset_delay and self.init_hidden: - raise NotImplementedError('no reset_delay only supported for init_hidden=False') - - if self.init_hidden: - self.spk, self.syn, self.mem = self.init_rsynaptic() - - def forward(self, input_, spk=False, syn=False, mem=False): - if ( - hasattr(spk, "init_flag") - or hasattr(syn, "init_flag") - or hasattr(mem, "init_flag") - ): # only triggered on first-pass - spk, syn, mem = _SpikeTorchConv(spk, syn, mem, input_=input_) - elif mem is False and hasattr( - self.mem, "init_flag" - ): # init_hidden case - self.spk, self.syn, self.mem = _SpikeTorchConv( - self.spk, self.syn, self.mem, input_=input_ + def _init_mem(self): + spk = torch.zeros(1) + syn = torch.zeros(1) + mem = torch.zeros(1) + + self.register_buffer("spk", spk) + self.register_buffer("syn", syn) + self.register_buffer("mem", mem) + + def reset_mem(self): + self.spk = torch.zeros_like(self.spk, device=self.spk.device) + self.syn = torch.zeros_like(self.syn, device=self.syn.device) + self.mem = torch.zeros_like(self.mem, device=self.mem.device) + + def init_rsynaptic(self): + """Deprecated, use :class:`RSynaptic.reset_mem` instead""" + self.reset_mem() + return self.spk, self.syn, self.mem + + def forward(self, input_, spk=None, syn=None, mem=None): + if not spk == None: + self.spk = spk + + if not syn == None: + self.syn = syn + + if not mem == None: + self.mem = mem + + if self.init_hidden and (not spk == None or not syn == None or not mem == None): + raise TypeError( + "When `init_hidden=True`, RSynaptic expects 1 input argument." 
) - if not self.init_hidden: - self.reset = self.mem_reset(mem) - syn, mem = self._build_state_function(input_, spk, syn, mem) - - if self.state_quant: - syn = self.state_quant(syn) - mem = self.state_quant(mem) - - if self.inhibition: - spk = self.fire_inhibition(mem.size(0), mem) - else: - spk = self.fire(mem) - - if not self.reset_delay: - # reset membrane potential _right_ after spike - do_reset = spk / self.graded_spikes_factor - self.reset # avoid double reset - if self.reset_mechanism_val == 0: # reset by subtraction - mem = mem - do_reset * self.threshold - elif self.reset_mechanism_val == 1: # reset to zero - # mem -= do_reset * mem - mem = mem - do_reset * mem - - return spk, syn, mem - - # intended for truncated-BPTT where instance variables are hidden - # states - if self.init_hidden: - self._rsynaptic_forward_cases(spk, mem, syn) - self.reset = self.mem_reset(self.mem) - self.syn, self.mem = self._build_state_function_hidden(input_) - - if self.state_quant: - self.syn = self.state_quant(self.syn) - self.mem = self.state_quant(self.mem) - - if self.inhibition: - self.spk = self.fire_inhibition(self.mem.size(0), self.mem) - else: - self.spk = self.fire(self.mem) - - if self.output: - return self.spk, self.syn, self.mem - else: - return self.spk + if not self.spk.shape == input_.shape: + self.spk = torch.zeros_like(input_, device=self.spk.device) + + if not self.syn.shape == input_.shape: + self.syn = torch.zeros_like(input_, device=self.syn.device) + + if not self.mem.shape == input_.shape: + self.mem = torch.zeros_like(input_, device=self.mem.device) + + self.reset = self.mem_reset(self.mem) + self.syn, self.mem = self.state_function(input_) + + if self.state_quant: + self.syn = self.state_quant(self.syn) + self.mem = self.state_quant(self.mem) + + if self.inhibition: + self.spk = self.fire_inhibition(self.mem.size(0), self.mem) + else: + self.spk = self.fire(self.mem) + + if not self.reset_delay: + # reset membrane potential _right_ after spike + do_reset = ( + spk / self.graded_spikes_factor - self.reset + ) # avoid double reset + if self.reset_mechanism_val == 0: # reset by subtraction + mem = mem - do_reset * self.threshold + elif self.reset_mechanism_val == 1: # reset to zero + # mem -= do_reset * mem + mem = mem - do_reset * mem + + if self.output: + return self.spk, self.syn, self.mem + elif self.init_hidden: + return self.spk + else: + return self.spk, self.syn, self.mem def _init_recurrent_net(self): if self.all_to_all: @@ -396,42 +413,7 @@ def _disable_recurrent_grad(self): for param in self.recurrent.parameters(): param.requires_grad = False - def _base_state_function(self, input_, spk, syn, mem): - base_fn_syn = ( - self.alpha.clamp(0, 1) * syn + input_ + self.recurrent(spk) - ) - base_fn_mem = self.beta.clamp(0, 1) * mem + base_fn_syn - return base_fn_syn, base_fn_mem - - def _base_state_reset_zero(self, input_, spk, syn, mem): - base_fn_syn = ( - self.alpha.clamp(0, 1) * syn + input_ + self.recurrent(spk) - ) - base_fn_mem = self.beta.clamp(0, 1) * mem + base_fn_syn - return 0, base_fn_mem - - def _build_state_function(self, input_, spk, syn, mem): - if self.reset_mechanism_val == 0: # reset by subtraction - state_fn = tuple( - map( - lambda x, y: x - y, - self._base_state_function(input_, spk, syn, mem), - (0, self.reset * self.threshold), - ) - ) - elif self.reset_mechanism_val == 1: # reset to zero - state_fn = tuple( - map( - lambda x, y: x - self.reset * y, - self._base_state_function(input_, spk, syn, mem), - self._base_state_reset_zero(input_, spk, syn, 
mem), - ) - ) - elif self.reset_mechanism_val == 2: # no reset, pure integration - state_fn = self._base_state_function(input_, spk, syn, mem) - return state_fn - - def _base_state_function_hidden(self, input_): + def _base_state_function(self, input_): base_fn_syn = ( self.alpha.clamp(0, 1) * self.syn + input_ @@ -440,7 +422,7 @@ def _base_state_function_hidden(self, input_): base_fn_mem = self.beta.clamp(0, 1) * self.mem + base_fn_syn return base_fn_syn, base_fn_mem - def _base_state_reset_zero_hidden(self, input_): + def _base_state_reset_zero(self, input_): base_fn_syn = ( self.alpha.clamp(0, 1) * self.syn + input_ @@ -449,26 +431,22 @@ def _base_state_reset_zero_hidden(self, input_): base_fn_mem = self.beta.clamp(0, 1) * self.mem + base_fn_syn return 0, base_fn_mem - def _build_state_function_hidden(self, input_): - if self.reset_mechanism_val == 0: # reset by subtraction - state_fn = tuple( - map( - lambda x, y: x - y, - self._base_state_function_hidden(input_), - (0, self.reset * self.threshold), - ) - ) - elif self.reset_mechanism_val == 1: # reset to zero - state_fn = tuple( - map( - lambda x, y: x - self.reset * y, - self._base_state_function_hidden(input_), - self._base_state_function_hidden(input_), - ) - ) - elif self.reset_mechanism_val == 2: # no reset, pure integration - state_fn = self._base_state_function_hidden(input_) - return state_fn + def _base_sub(self, input_): + syn, mem = self._base_state_function(input_) + mem -= self.reset * self.threshold + return syn, mem + + def _base_zero(self, input_): + syn, mem = self._base_state_function(input_) + syn2, mem2 = self._base_state_reset_zero(input_) + syn2 *= self.reset + mem2 *= self.reset + syn -= syn2 + mem -= mem2 + return syn, mem + + def _base_int(self, input_): + return self._base_state_function(input_) def _alpha_register_buffer(self, alpha, learn_alpha): if not isinstance(alpha, torch.Tensor): @@ -478,12 +456,6 @@ def _alpha_register_buffer(self, alpha, learn_alpha): else: self.register_buffer("alpha", alpha) - def _rsynaptic_forward_cases(self, spk, mem, syn): - if mem is not False or syn is not False or spk is not False: - raise TypeError( - "When `init_hidden=True`, RSynaptic expects 1 input argument." - ) - def _rsynaptic_init_cases(self): all_to_all_bool = bool(self.all_to_all) linear_features_bool = self.linear_features @@ -545,8 +517,14 @@ def reset_hidden(cls): for layer in range(len(cls.instances)): if isinstance(cls.instances[layer], RSynaptic): - cls.instances[layer].syn = _SpikeTensor(init_flag=False) - cls.instances[layer].mem = _SpikeTensor(init_flag=False) + cls.instances[layer].syn = torch.zeros_like( + cls.instances[layer].syn, + device=cls.instances[layer].syn.device, + ) + cls.instances[layer].mem = torch.zeros_like( + cls.instances[layer].mem, + device=cls.instances[layer].mem.device, + ) class RecurrentOneToOne(nn.Module): diff --git a/snntorch/_neurons/sconv2dlstm.py b/snntorch/_neurons/sconv2dlstm.py index 73d91a7d..a38e4b03 100644 --- a/snntorch/_neurons/sconv2dlstm.py +++ b/snntorch/_neurons/sconv2dlstm.py @@ -1,12 +1,10 @@ import torch -from torch._C import Value import torch.nn as nn import torch.nn.functional as F -from .neurons import _SpikeTensor, _SpikeTorchConv, SpikingNeuron +from .neurons import SpikingNeuron class SConv2dLSTM(SpikingNeuron): - """ A spiking 2d convolutional long short-term memory cell. 
Hidden states are membrane potential and synaptic current @@ -240,8 +238,14 @@ def __init__( output, ) - if self.init_hidden: - self.syn, self.mem = self.init_sconv2dlstm() + self._init_mem() + + if self.reset_mechanism_val == 0: # reset by subtraction + self.state_function = self._base_sub + elif self.reset_mechanism_val == 1: # reset to zero + self.state_function = self._base_zero + elif self.reset_mechanism_val == 2: # no reset, pure integration + self.state_function = self._base_int self.in_channels = in_channels self.out_channels = out_channels @@ -268,117 +272,63 @@ def __init__( bias=self.bias, ) - def forward(self, input_, syn=False, mem=False): - if hasattr(mem, "init_flag") or hasattr( - syn, "init_flag" - ): # only triggered on first-pass - - syn, mem = _SpikeTorchConv( - syn, mem, input_=self._reshape_input(input_) - ) - elif mem is False and hasattr( - self.mem, "init_flag" - ): # init_hidden case - self.syn, self.mem = _SpikeTorchConv( - self.syn, self.mem, input_=self._reshape_input(input_) - ) + def _init_mem(self): + syn = torch.zeros(1) + mem = torch.zeros(1) + self.register_buffer("syn", syn) + self.register_buffer("mem", mem) - if not self.init_hidden: - self.reset = self.mem_reset(mem) - syn, mem = self._build_state_function(input_, syn, mem) - - if self.state_quant: - syn = self.state_quant(syn) - mem = self.state_quant(mem) - - if self.max_pool: - spk = self.fire(F.max_pool2d(mem, self.max_pool)) - elif self.avg_pool: - spk = self.fire(F.avg_pool2d(mem, self.avg_pool)) - else: - spk = self.fire(mem) - return spk, syn, mem - - if self.init_hidden: - # self._sconv2dlstm_forward_cases(mem, c) - self.reset = self.mem_reset(self.mem) - self.syn, self.mem = self._build_state_function_hidden(input_) - - if self.state_quant: - self.syn = self.state_quant(self.syn) - self.mem = self.state_quant(self.mem) - - if self.max_pool: - self.spk = self.fire(F.max_pool2d(self.mem, self.max_pool)) - elif self.avg_pool: - self.spk = self.fire(F.avg_pool2d(self.mem, self.avg_pool)) - else: - self.spk = self.fire(self.mem) - - if self.output: - return self.spk, self.syn, self.mem - else: - return self.spk - - def _base_state_function(self, input_, syn, mem): + def reset_mem(self): + self.syn = torch.zeros_like(self.syn, device=self.syn.device) + self.mem = torch.zeros_like(self.mem, device=self.mem.device) - combined = torch.cat( - [input_, mem], dim=1 - ) # concatenate along channel axis (BxCxHxW) - combined_conv = self.conv(combined) - cc_i, cc_f, cc_o, cc_g = torch.split( - combined_conv, self.out_channels, dim=1 - ) - i = torch.sigmoid(cc_i) - f = torch.sigmoid(cc_f) - o = torch.sigmoid(cc_o) - g = torch.tanh(cc_g) - - base_fn_syn = f * syn + i * g - base_fn_mem = o * torch.tanh(base_fn_syn) - - return base_fn_syn, base_fn_mem - - def _base_state_reset_zero(self, input_, syn, mem): - combined = torch.cat( - [input_, mem], dim=1 - ) # concatenate along channel axis - combined_conv = self.conv(combined) - cc_i, cc_f, cc_o, cc_g = torch.split( - combined_conv, self.out_channels, dim=1 - ) - i = torch.sigmoid(cc_i) - f = torch.sigmoid(cc_f) - o = torch.sigmoid(cc_o) - g = torch.tanh(cc_g) + def init_sconv2dlstm(self): + """Deprecated, use :class:`SConv2dLSTM.reset_mem` instead""" + self.reset_mem() + return self.syn, self.mem - base_fn_syn = f * syn + i * g - base_fn_mem = o * torch.tanh(base_fn_syn) + def forward(self, input_, syn=None, mem=None): + if not syn == None: + self.syn = syn - return 0, base_fn_mem + if not mem == None: + self.mem = mem - def _build_state_function(self, input_, 
syn, mem): - if self.reset_mechanism_val == 0: # reset by subtraction - state_fn = tuple( - map( - lambda x, y: x - y, - self._base_state_function(input_, syn, mem), - (0, self.reset * self.threshold), - ) + if self.init_hidden and (not mem == None or not syn == None): + raise TypeError( + "`mem` or `syn` should not be passed as an argument while `init_hidden=True`" ) - elif self.reset_mechanism_val == 1: # reset to zero - state_fn = tuple( - map( - lambda x, y: x - self.reset * y, - self._base_state_function(input_, syn, mem), - self._base_state_reset_zero(input_, syn, mem), - ) - ) - elif self.reset_mechanism_val == 2: # no reset, pure integration - state_fn = self._base_state_function(input_, syn, mem) - return state_fn + + size = input_.size() + correct_shape = (size[0], self.out_channels, size[2], size[3]) + if not self.syn.shape == correct_shape: + self.syn = torch.zeros(correct_shape, device=self.syn.device) + + if not self.mem.shape == correct_shape: + self.mem = torch.zeros(correct_shape, device=self.mem.device) + + self.reset = self.mem_reset(self.mem) + self.syn, self.mem = self.state_function(input_) + + if self.state_quant: + self.syn = self.state_quant(self.syn) + self.mem = self.state_quant(self.mem) + + if self.max_pool: + self.spk = self.fire(F.max_pool2d(self.mem, self.max_pool)) + elif self.avg_pool: + self.spk = self.fire(F.avg_pool2d(self.mem, self.avg_pool)) + else: + self.spk = self.fire(self.mem) - def _base_state_function_hidden(self, input_): + if self.output: + return self.spk, self.syn, self.mem + elif self.init_hidden: + return self.spk + else: + return self.spk, self.syn, self.mem + + def _base_state_function(self, input_): combined = torch.cat( [input_, self.mem], dim=1 ) # concatenate along channel axis @@ -396,7 +346,7 @@ def _base_state_function_hidden(self, input_): return base_fn_syn, base_fn_mem - def _base_state_reset_zero_hidden(self, input_): + def _base_state_reset_zero(self, input_): combined = torch.cat( [input_, self.mem], dim=1 ) # concatenate along channel axis @@ -414,43 +364,22 @@ def _base_state_reset_zero_hidden(self, input_): return 0, base_fn_mem - def _build_state_function_hidden(self, input_): - if self.reset_mechanism_val == 0: # reset by subtraction - state_fn = tuple( - map( - lambda x, y: x - y, - self._base_state_function_hidden(input_), - (0, self.reset * self.threshold), - ) - ) - elif self.reset_mechanism_val == 1: # reset to zero - state_fn = tuple( - map( - lambda x, y: x - self.reset * y, - self._base_state_function_hidden(input_), - self._base_state_reset_zero_hidden(input_), - ) - ) - elif self.reset_mechanism_val == 2: # no reset, pure integration - state_fn = self._base_state_function_hidden(input_) - return state_fn - - @staticmethod - def init_sconv2dlstm(): - """ - Used to initialize h and c as an empty SpikeTensor. - ``init_flag`` is used as an attribute in the forward pass to convert - the hidden states to the same as the input. 
- """ - mem = _SpikeTensor(init_flag=False) - syn = _SpikeTensor(init_flag=False) - - return mem, syn - - def _reshape_input(self, input_): - device = input_.device - b, _, h, w = input_.size() - return torch.zeros(b, self.out_channels, h, w).to(device) + def _base_sub(self, input_): + syn, mem = self._base_state_function(input_) + mem -= self.reset * self.threshold + return syn, mem + + def _base_zero(self, input_): + syn, mem = self._base_state_function(input_) + syn2, mem2 = self._base_state_reset_zero(input_) + syn2 *= self.reset + mem2 *= self.reset + syn -= syn2 + mem -= mem2 + return syn, mem + + def _base_int(self, input_): + return self._base_state_function(input_) def _sconv2dlstm_cases(self): if self.max_pool and self.avg_pool: @@ -478,5 +407,11 @@ def reset_hidden(cls): for layer in range(len(cls.instances)): if isinstance(cls.instances[layer], SConv2dLSTM): - cls.instances[layer].syn = _SpikeTensor(init_flag=False) - cls.instances[layer].mem = _SpikeTensor(init_flag=False) + cls.instances[layer].syn = torch.zeros_like( + cls.instances[layer].syn, + device=cls.instances[layer].syn.device, + ) + cls.instances[layer].mem = torch.zeros_like( + cls.instances[layer].mem, + device=cls.instances[layer].mem.device, + ) diff --git a/snntorch/_neurons/slstm.py b/snntorch/_neurons/slstm.py index dd29ba6b..dceff42f 100644 --- a/snntorch/_neurons/slstm.py +++ b/snntorch/_neurons/slstm.py @@ -1,12 +1,9 @@ import torch -from torch._C import Value import torch.nn as nn -import torch.nn.functional as F -from .neurons import _SpikeTensor, _SpikeTorchConv, SpikingNeuron +from .neurons import SpikingNeuron class SLSTM(SpikingNeuron): - """ A spiking long short-term memory cell. Hidden states are membrane potential and synaptic current @@ -188,8 +185,14 @@ def __init__( output, ) - if self.init_hidden: - self.syn, self.mem = self.init_slstm() + self._init_mem() + + if self.reset_mechanism_val == 0: # reset by subtraction + self.state_function = self._base_sub + elif self.reset_mechanism_val == 1: # reset to zero + self.state_function = self._base_zero + elif self.reset_mechanism_val == 2: # no reset, pure integration + self.state_function = self._base_int self.input_size = input_size self.hidden_size = hidden_size @@ -199,122 +202,82 @@ def __init__( self.input_size, self.hidden_size, bias=self.bias ) - def forward(self, input_, syn=False, mem=False): - if hasattr(mem, "init_flag") or hasattr( - syn, "init_flag" - ): # only triggered on first-pass + def _init_mem(self): + syn = torch.zeros(1) + mem = torch.zeros(1) + self.register_buffer("syn", syn) + self.register_buffer("mem", mem) - syn, mem = _SpikeTorchConv( - syn, mem, input_=self._reshape_input(input_) - ) - elif mem is False and hasattr( - self.mem, "init_flag" - ): # init_hidden case - self.syn, self.mem = _SpikeTorchConv( - self.syn, self.mem, input_=self._reshape_input(input_) - ) + def reset_mem(self): + self.syn = torch.zeros_like(self.syn, device=self.syn.device) + self.mem = torch.zeros_like(self.mem, device=self.mem.device) - if not self.init_hidden: - self.reset = self.mem_reset(mem) - syn, mem = self._build_state_function(input_, syn, mem) + def init_slstm(self): + """Deprecated, use :class:`SLSTM.reset_mem` instead""" + self.reset_mem() + return self.syn, self.mem - if self.state_quant: - syn = self.state_quant(syn) - mem = self.state_quant(mem) + def forward(self, input_, syn=None, mem=None): + if not syn == None: + self.syn = syn - spk = self.fire(mem) - return spk, syn, mem + if not mem == None: + self.mem = mem - if 
self.init_hidden:
-            # self._slstm_forward_cases(mem, syn)
-            self.reset = self.mem_reset(self.mem)
-            self.syn, self.mem = self._build_state_function_hidden(input_)
+        if self.init_hidden and (not mem == None or not syn == None):
+            raise TypeError(
+                "`mem` or `syn` should not be passed as an argument while `init_hidden=True`"
+            )
-            if self.state_quant:
-                self.syn = self.state_quant(self.syn)
-                self.mem = self.state_quant(self.mem)
+        size = input_.size()
+        correct_shape = (size[0], self.hidden_size)
-            self.spk = self.fire(self.mem)
+        if not self.syn.shape == correct_shape:
+            self.syn = torch.zeros(correct_shape, device=self.syn.device)
-            if self.output:
-                return self.spk, self.syn, self.mem
-            else:
-                return self.spk
+        if not self.mem.shape == correct_shape:
+            self.mem = torch.zeros(correct_shape, device=self.mem.device)
-    def _base_state_function(self, input_, syn, mem):
-        base_fn_mem, base_fn_syn = self.lstm_cell(input_, (mem, syn))
-        return base_fn_syn, base_fn_mem
+        self.reset = self.mem_reset(self.mem)
+        self.syn, self.mem = self.state_function(input_)
-    def _base_state_reset_zero(self, input_, syn, mem):
-        base_fn_mem, _ = self.lstm_cell(input_, (mem, syn))
-        return 0, base_fn_mem
+        if self.state_quant:
+            self.syn = self.state_quant(self.syn)
+            self.mem = self.state_quant(self.mem)
-    def _build_state_function(self, input_, syn, mem):
-        if self.reset_mechanism_val == 0:  # reset by subtraction
-            state_fn = tuple(
-                map(
-                    lambda x, y: x - y,
-                    self._base_state_function(input_, syn, mem),
-                    (0, self.reset * self.threshold),
-                )
-            )
-        elif self.reset_mechanism_val == 1:  # reset to zero
-            state_fn = tuple(
-                map(
-                    lambda x, y: x - self.reset * y,
-                    self._base_state_function(input_, syn, mem),
-                    self._base_state_reset_zero(input_, syn, mem),
-                )
-            )
-        elif self.reset_mechanism_val == 2:  # no reset, pure integration
-            state_fn = self._base_state_function(input_, syn, mem)
-        return state_fn
+        self.spk = self.fire(self.mem)
-    def _base_state_function_hidden(self, input_):
+        if self.output:
+            return self.spk, self.syn, self.mem
+        elif self.init_hidden:
+            return self.spk
+        else:
+            return self.spk, self.syn, self.mem
+
+    def _base_state_function(self, input_):
         base_fn_mem, base_fn_syn = self.lstm_cell(input_, (self.mem, self.syn))
         return base_fn_syn, base_fn_mem

-    def _base_state_reset_zero_hidden(self, input_):
+    def _base_state_reset_zero(self, input_):
         base_fn_mem, _ = self.lstm_cell(input_, (self.mem, self.syn))
         return 0, base_fn_mem

-    def _build_state_function_hidden(self, input_):
-        if self.reset_mechanism_val == 0:  # reset by subtraction
-            state_fn = tuple(
-                map(
-                    lambda x, y: x - y,
-                    self._base_state_function_hidden(input_),
-                    (0, self.reset * self.threshold),
-                )
-            )
-        elif self.reset_mechanism_val == 1:  # reset to zero
-            state_fn = tuple(
-                map(
-                    lambda x, y: x - self.reset * y,
-                    self._base_state_function_hidden(input_),
-                    self._base_state_reset_zero_hidden(input_),
-                )
-            )
-        elif self.reset_mechanism_val == 2:  # no reset, pure integration
-            state_fn = self._base_state_function_hidden(input_)
-        return state_fn
-
-    def _reshape_input(self, input_):
-        device = input_.device
-        b, _ = input_.size()
-        return torch.zeros(b, self.hidden_size).to(device)
-
-    @staticmethod
-    def init_slstm():
-        """
-        Used to initialize mem and syn as an empty SpikeTensor.
-        ``init_flag`` is used as an attribute in the forward pass to convert
-        the hidden states to the same as the input.
- """ - mem = _SpikeTensor(init_flag=False) - syn = _SpikeTensor(init_flag=False) - - return mem, syn + def _base_sub(self, input_): + syn, mem = self._base_state_function(input_) + mem -= self.reset * self.threshold + return syn, mem + + def _base_zero(self, input_): + syn, mem = self._base_state_function(input_) + syn2, mem2 = self._base_state_reset_zero(input_) + syn2 *= self.reset + mem2 *= self.reset + syn -= syn2 + mem -= mem2 + return syn, mem + + def _base_int(self, input_): + return self._base_state_function(input_) @classmethod def detach_hidden(cls): @@ -335,5 +298,11 @@ def reset_hidden(cls): for layer in range(len(cls.instances)): if isinstance(cls.instances[layer], SLSTM): - cls.instances[layer].syn = _SpikeTensor(init_flag=False) - cls.instances[layer].mem = _SpikeTensor(init_flag=False) + cls.instances[layer].syn = torch.zeros_like( + cls.instances[layer].syn, + device=cls.instances[layer].syn.device, + ) + cls.instances[layer].mem = torch.zeros_like( + cls.instances[layer].mem, + device=cls.instances[layer].mem.device, + ) diff --git a/snntorch/_neurons/synaptic.py b/snntorch/_neurons/synaptic.py index 4e3be032..6209ca32 100644 --- a/snntorch/_neurons/synaptic.py +++ b/snntorch/_neurons/synaptic.py @@ -1,6 +1,6 @@ import torch import torch.nn as nn -from .neurons import _SpikeTensor, _SpikeTorchConv, LIF +from .neurons import LIF class Synaptic(LIF): @@ -186,132 +186,106 @@ def __init__( self._alpha_register_buffer(alpha, learn_alpha) + self._init_mem() + + if self.reset_mechanism_val == 0: # reset by subtraction + self.state_function = self._base_sub + elif self.reset_mechanism_val == 1: # reset to zero + self.state_function = self._base_zero + elif self.reset_mechanism_val == 2: # no reset, pure integration + self.state_function = self._base_int + self.reset_delay = reset_delay - if not reset_delay and self.init_hidden: - raise NotImplementedError('no reset_delay only supported for init_hidden=False') + def _init_mem(self): + syn = torch.zeros(1) + mem = torch.zeros(1) + self.register_buffer("syn", syn) + self.register_buffer("mem", mem) - if self.init_hidden: - self.syn, self.mem = self.init_synaptic() + def reset_mem(self): + self.syn = torch.zeros_like(self.syn, device=self.syn.device) + self.mem = torch.zeros_like(self.mem, device=self.mem.device) - def forward(self, input_, syn=False, mem=False): + def init_synaptic(self): + """Deprecated, use :class:`Synaptic.reset_mem` instead""" + self.reset_mem() + return self.syn, self.mem - if hasattr(syn, "init_flag") or hasattr( - mem, "init_flag" - ): # only triggered on first-pass - syn, mem = _SpikeTorchConv(syn, mem, input_=input_) - elif mem is False and hasattr( - self.mem, "init_flag" - ): # init_hidden case - self.syn, self.mem = _SpikeTorchConv( - self.syn, self.mem, input_=input_ - ) + def forward(self, input_, syn=None, mem=None): - if not self.init_hidden: - self.reset = self.mem_reset(mem) - syn, mem = self._build_state_function(input_, syn, mem) - - if self.state_quant: - syn = self.state_quant(syn) - mem = self.state_quant(mem) - - if self.inhibition: - spk = self.fire_inhibition(mem.size(0), mem) - else: - spk = self.fire(mem) - - if not self.reset_delay: - # reset membrane potential _right_ after spike - do_reset = spk / self.graded_spikes_factor - self.reset # avoid double reset - if self.reset_mechanism_val == 0: # reset by subtraction - mem = mem - do_reset * self.threshold - elif self.reset_mechanism_val == 1: # reset to zero - mem = mem - do_reset * mem - - return spk, syn, mem - - # intended for 
truncated-BPTT where instance variables are
-        # hidden states
-        if self.init_hidden:
-            self._synaptic_forward_cases(mem, syn)
-            self.reset = self.mem_reset(self.mem)
-            self.syn, self.mem = self._build_state_function_hidden(input_)
-
-            if self.state_quant:
-                self.syn = self.state_quant(self.syn)
-                self.mem = self.state_quant(self.mem)
-
-            if self.inhibition:
-                self.spk = self.fire_inhibition(self.mem.size(0), self.mem)
-            else:
-                self.spk = self.fire(self.mem)
-
-            if self.output:
-                return self.spk, self.syn, self.mem
-            else:
-                return self.spk
-
-    def _base_state_function(self, input_, syn, mem):
-        base_fn_syn = self.alpha.clamp(0, 1) * syn + input_
-        base_fn_mem = self.beta.clamp(0, 1) * mem + base_fn_syn
-        return base_fn_syn, base_fn_mem
+        if not syn == None:
+            self.syn = syn
-    def _base_state_reset_zero(self, input_, syn, mem):
-        base_fn_syn = self.alpha.clamp(0, 1) * syn + input_
-        base_fn_mem = self.beta.clamp(0, 1) * mem + base_fn_syn
-        return 0, base_fn_mem
+        if not mem == None:
+            self.mem = mem
-    def _build_state_function(self, input_, syn, mem):
-        if self.reset_mechanism_val == 0:  # reset by subtraction
-            state_fn = tuple(
-                map(
-                    lambda x, y: x - y,
-                    self._base_state_function(input_, syn, mem),
-                    (0, self.reset * self.threshold),
-                )
-            )
-        elif self.reset_mechanism_val == 1:  # reset to zero
-            state_fn = tuple(
-                map(
-                    lambda x, y: x - self.reset * y,
-                    self._base_state_function(input_, syn, mem),
-                    self._base_state_reset_zero(input_, syn, mem),
-                )
+        if self.init_hidden and (not mem == None or not syn == None):
+            raise TypeError(
+                "`mem` or `syn` should not be passed as an argument while `init_hidden=True`"
             )
-        elif self.reset_mechanism_val == 2:  # no reset, pure integration
-            state_fn = self._base_state_function(input_, syn, mem)
-        return state_fn
-
-    def _base_state_function_hidden(self, input_):
+        if not self.syn.shape == input_.shape:
+            self.syn = torch.zeros_like(input_, device=self.syn.device)
+
+        if not self.mem.shape == input_.shape:
+            self.mem = torch.zeros_like(input_, device=self.mem.device)
+
+        self.reset = self.mem_reset(self.mem)
+        self.syn, self.mem = self.state_function(input_)
+
+        if self.state_quant:
+            self.mem = self.state_quant(self.mem)
+            self.syn = self.state_quant(self.syn)
+
+        if self.inhibition:
+            spk = self.fire_inhibition(
+                self.mem.size(0), self.mem
+            )  # batch_size
+        else:
+            spk = self.fire(self.mem)
+
+        if not self.reset_delay:
+            # reset membrane potential _right_ after spike
+            do_reset = (
+                spk / self.graded_spikes_factor - self.reset
+            )  # avoid double reset
+            if self.reset_mechanism_val == 0:  # reset by subtraction
+                self.mem = self.mem - do_reset * self.threshold
+            elif self.reset_mechanism_val == 1:  # reset to zero
+                self.mem = self.mem - do_reset * self.mem
+
+        if self.output:
+            return spk, self.syn, self.mem
+        elif self.init_hidden:
+            return spk
+        else:
+            return spk, self.syn, self.mem
+
+    def _base_state_function(self, input_):
         base_fn_syn = self.alpha.clamp(0, 1) * self.syn + input_
         base_fn_mem = self.beta.clamp(0, 1) * self.mem + base_fn_syn
         return base_fn_syn, base_fn_mem

-    def _base_state_reset_zero_hidden(self, input_):
+    def _base_state_reset_zero(self, input_):
         base_fn_syn = self.alpha.clamp(0, 1) * self.syn + input_
         base_fn_mem = self.beta.clamp(0, 1) * self.mem + base_fn_syn
         return 0, base_fn_mem

-    def _build_state_function_hidden(self, input_):
-        if self.reset_mechanism_val == 0:  # reset by subtraction
-            state_fn = tuple(
-                map(
-                    lambda x, y: x - y,
-                    self._base_state_function_hidden(input_),
-                    (0, self.reset * self.threshold),
-                )
-            )
-        elif 
self.reset_mechanism_val == 1: # reset to zero - state_fn = tuple( - map( - lambda x, y: x - self.reset * y, - self._base_state_function_hidden(input_), - self._base_state_reset_zero_hidden(input_), - ) - ) - elif self.reset_mechanism_val == 2: # no reset, pure integration - state_fn = self._base_state_function_hidden(input_) - return state_fn + def _base_sub(self, input_): + syn, mem = self._base_state_function(input_) + mem = mem - self.reset * self.threshold + return syn, mem + + def _base_zero(self, input_): + syn, mem = self._base_state_function(input_) + syn2, mem2 = self._base_state_reset_zero(input_) + syn -= syn2 * self.reset + mem -= mem2 * self.reset + return syn, mem + + def _base_int(self, input_): + return self._base_state_function(input_) def _alpha_register_buffer(self, alpha, learn_alpha): if not isinstance(alpha, torch.Tensor): @@ -321,12 +295,6 @@ def _alpha_register_buffer(self, alpha, learn_alpha): else: self.register_buffer("alpha", alpha) - def _synaptic_forward_cases(self, mem, syn): - if mem is not False or syn is not False: - raise TypeError( - "When `init_hidden=True`, Synaptic expects 1 input argument." - ) - @classmethod def detach_hidden(cls): """Returns the hidden states, detached from the current graph. @@ -346,5 +314,11 @@ def reset_hidden(cls): for layer in range(len(cls.instances)): if isinstance(cls.instances[layer], Synaptic): - cls.instances[layer].syn = _SpikeTensor(init_flag=False) - cls.instances[layer].mem = _SpikeTensor(init_flag=False) + cls.instances[layer].syn = torch.zeros_like( + cls.instances[layer].syn, + device=cls.instances[layer].syn.device, + ) + cls.instances[layer].mem = torch.zeros_like( + cls.instances[layer].mem, + device=cls.instances[layer].mem.device, + ) diff --git a/snntorch/surrogate.py b/snntorch/surrogate.py index 390eb1bd..31914e93 100644 --- a/snntorch/surrogate.py +++ b/snntorch/surrogate.py @@ -197,7 +197,7 @@ def backward(ctx, grad_output): grad = ( ctx.alpha / 2 - / (1 + (math.pi / 2 * ctx.alpha * input_).pow_(2)) + / (1 + (torch.pi / 2 * ctx.alpha * input_).pow_(2)) * grad_input ) return grad, None diff --git a/tests/test_snntorch/test_alpha.py b/tests/test_snntorch/test_alpha.py index e96f8c4c..29e59935 100644 --- a/tests/test_snntorch/test_alpha.py +++ b/tests/test_snntorch/test_alpha.py @@ -5,6 +5,7 @@ import pytest import snntorch as snn import torch +import torch._dynamo as dynamo @pytest.fixture(scope="module") @@ -16,6 +17,10 @@ def input_(): def alpha_instance(): return snn.Alpha(alpha=0.6, beta=0.5, reset_mechanism="subtract") +@pytest.fixture(scope="module") +def alpha_instance_surrogate(): + return snn.Alpha(alpha=0.6, beta=0.5, reset_mechanism="subtract", surrogate_disable=True) + @pytest.fixture(scope="module") def alpha_reset_zero_instance(): @@ -136,3 +141,9 @@ def test_alpha_init_hidden_reset_none( def test_alpha_cases(self, alpha_hidden_instance, input_): with pytest.raises(TypeError): alpha_hidden_instance(input_, input_) + + + def test_alpha_compile_fullgraph(self, alpha_instance_surrogate, input_): + explanation = dynamo.explain(alpha_instance_surrogate)(input_[0]) + + assert explanation.graph_break_count == 0 \ No newline at end of file diff --git a/tests/test_snntorch/test_lapicque.py b/tests/test_snntorch/test_lapicque.py index 36345975..f76dd232 100644 --- a/tests/test_snntorch/test_lapicque.py +++ b/tests/test_snntorch/test_lapicque.py @@ -5,6 +5,7 @@ import pytest import snntorch as snn import torch +import torch._dynamo as dynamo @pytest.fixture(scope="module") @@ -17,6 +18,11 @@ def 
lapicque_instance(): return snn.Lapicque(beta=0.5) +@pytest.fixture(scope="module") +def lapicque_instance_surrogate(): + return snn.Lapicque(beta=0.5, surrogate_disable=True) + + @pytest.fixture(scope="module") def lapicque_reset_zero_instance(): return snn.Lapicque(beta=0.5, reset_mechanism="zero") @@ -128,3 +134,8 @@ def test_lapicque_init_hidden_reset_none( def test_lapicque_cases(self, lapicque_hidden_instance, input_): with pytest.raises(TypeError): lapicque_hidden_instance(input_, input_) + + def test_lapicque_compile_fullgraph(self, lapicque_instance_surrogate, input_): + explanation = dynamo.explain(lapicque_instance_surrogate)(input_[0]) + + assert explanation.graph_break_count == 0 diff --git a/tests/test_snntorch/test_leaky.py b/tests/test_snntorch/test_leaky.py index a04ca744..2bfc26ca 100644 --- a/tests/test_snntorch/test_leaky.py +++ b/tests/test_snntorch/test_leaky.py @@ -5,6 +5,7 @@ import pytest import snntorch as snn import torch +import torch._dynamo as dynamo @pytest.fixture(scope="module") @@ -17,6 +18,11 @@ def leaky_instance(): return snn.Leaky(beta=0.5) +@pytest.fixture(scope="module") +def leaky_instance_surrogate(): + return snn.Leaky(beta=0.5, surrogate_disable=True) + + @pytest.fixture(scope="module") def leaky_reset_zero_instance(): return snn.Leaky(beta=0.5, reset_mechanism="zero") @@ -126,8 +132,13 @@ def test_leaky_cases(self, leaky_hidden_instance, input_): leaky_hidden_instance(input_, input_) def test_leaky_hidden_learn_graded_instance( - self, leaky_hidden_learn_graded_instance + self, leaky_hidden_learn_graded_instance ): factor = leaky_hidden_learn_graded_instance.graded_spikes_factor assert factor.requires_grad + + def test_leaky_compile_fullgraph(self, leaky_instance_surrogate, input_): + explanation = dynamo.explain(leaky_instance_surrogate)(input_[0]) + + assert explanation.graph_break_count == 0 diff --git a/tests/test_snntorch/test_rleaky.py b/tests/test_snntorch/test_rleaky.py index 5e336372..617488d4 100644 --- a/tests/test_snntorch/test_rleaky.py +++ b/tests/test_snntorch/test_rleaky.py @@ -5,6 +5,7 @@ import pytest import snntorch as snn import torch +import torch._dynamo as dynamo @pytest.fixture(scope="module") @@ -17,6 +18,13 @@ def rleaky_instance(): return snn.RLeaky(beta=0.5, V=0.5, all_to_all=False) +@pytest.fixture(scope="module") +def rleaky_instance_surrogate(): + return snn.RLeaky( + beta=0.5, V=0.5, all_to_all=False, surrogate_disable=True + ) + + @pytest.fixture(scope="module") def rleaky_reset_zero_instance(): return snn.RLeaky( @@ -133,3 +141,10 @@ def test_rleaky_init_hidden_reset_none( def test_lreaky_cases(self, rleaky_hidden_instance, input_): with pytest.raises(TypeError): rleaky_hidden_instance(input_, input_, input_) + + def test_rleaky_compile_fullgraph( + self, rleaky_instance_surrogate, input_ + ): + explanation = dynamo.explain(rleaky_instance_surrogate)(input_[0]) + + assert explanation.graph_break_count == 0 diff --git a/tests/test_snntorch/test_rsynaptic.py b/tests/test_snntorch/test_rsynaptic.py index 54e981d6..9bd386b6 100644 --- a/tests/test_snntorch/test_rsynaptic.py +++ b/tests/test_snntorch/test_rsynaptic.py @@ -5,6 +5,7 @@ import pytest import snntorch as snn import torch +import torch._dynamo as dynamo @pytest.fixture(scope="module") @@ -22,6 +23,13 @@ def rsynaptic_instance(): ) +@pytest.fixture(scope="module") +def rsynaptic_instance_surrogate(): + return snn.RSynaptic( + alpha=0.5, beta=0.5, V=0.5, all_to_all=False, surrogate_disable=True + ) + + @pytest.fixture(scope="module") def 
rsynaptic_reset_zero_instance(): return snn.RSynaptic( @@ -144,3 +152,10 @@ def test_rsynaptic_init_hidden_reset_none( def test_rsynaptic_cases(self, rsynaptic_hidden_instance, input_): with pytest.raises(TypeError): rsynaptic_hidden_instance(input_, input_) + + def test_rsynaptic_compile_fullgraph( + self, rsynaptic_instance_surrogate, input_ + ): + explanation = dynamo.explain(rsynaptic_instance_surrogate)(input_[0]) + + assert explanation.graph_break_count == 0 diff --git a/tests/test_snntorch/test_sconv2dlstm.py b/tests/test_snntorch/test_sconv2dlstm.py index 8a504923..7c8b96f4 100644 --- a/tests/test_snntorch/test_sconv2dlstm.py +++ b/tests/test_snntorch/test_sconv2dlstm.py @@ -5,6 +5,7 @@ import pytest import snntorch as snn import torch +import torch._dynamo as dynamo # TO-DO: add avg/max-pooling tests @@ -20,6 +21,11 @@ def sconv2dlstm_instance(): return snn.SConv2dLSTM(1, 8, 3) +@pytest.fixture(scope="module") +def sconv2dlstm_instance_surrogate(): + return snn.SConv2dLSTM(1, 8, 3, surrogate_disable=True) + + @pytest.fixture(scope="module") def sconv2dlstm_reset_zero_instance(): return snn.SConv2dLSTM(1, 8, 3, reset_mechanism="zero") @@ -124,3 +130,10 @@ def test_sconv2dlstm_init_hidden_reset_subtract( spk_rec.append(spk) assert spk_rec[0].size() == (1, 8, 4, 4) + + def test_sconv2dlstm_compile_fullgraph( + self, sconv2dlstm_instance_surrogate, input_ + ): + explanation = dynamo.explain(sconv2dlstm_instance_surrogate)(input_[0]) + + assert explanation.graph_break_count == 0 diff --git a/tests/test_snntorch/test_slstm.py b/tests/test_snntorch/test_slstm.py index 32007172..8aa46b79 100644 --- a/tests/test_snntorch/test_slstm.py +++ b/tests/test_snntorch/test_slstm.py @@ -5,6 +5,7 @@ import pytest import snntorch as snn import torch +import torch._dynamo as dynamo # TO-DO: add avg/max-pooling tests @@ -20,6 +21,11 @@ def slstm_instance(): return snn.SLSTM(1, 2) +@pytest.fixture(scope="module") +def slstm_instance_surrogate(): + return snn.SLSTM(1, 2, surrogate_disable=True) + + @pytest.fixture(scope="module") def slstm_reset_zero_instance(): return snn.SLSTM(1, 2, reset_mechanism="zero") @@ -46,7 +52,7 @@ def slstm_hidden_reset_subtract_instance(): class TestSLSTM: - def test_sconv2dlstm(self, slstm_instance, input_): + def test_slstm(self, slstm_instance, input_): c, h = slstm_instance.init_slstm() h_rec = [] @@ -64,7 +70,7 @@ def test_sconv2dlstm(self, slstm_instance, input_): assert h.size() == (1, 2) assert spk.size() == (1, 2) - def test_sconv2dlstm_reset( + def test_slstm_reset( self, slstm_instance, slstm_reset_zero_instance, @@ -87,7 +93,7 @@ def test_sconv2dlstm_reset( assert lif2.reset_mechanism_val == 1 assert lif3.reset_mechanism_val == 0 - def test_sconv2dlstm_init_hidden(self, slstm_hidden_instance, input_): + def test_slstm_init_hidden(self, slstm_hidden_instance, input_): spk_rec = [] @@ -97,7 +103,7 @@ def test_sconv2dlstm_init_hidden(self, slstm_hidden_instance, input_): assert spk_rec[0].size() == (1, 2) - def test_sconv2dlstm_init_hidden_reset_zero( + def test_slstm_init_hidden_reset_zero( self, slstm_hidden_reset_zero_instance, input_ ): @@ -109,7 +115,7 @@ def test_sconv2dlstm_init_hidden_reset_zero( assert spk_rec[0].size() == (1, 2) - def test_sconv2dlstm_init_hidden_reset_subtract( + def test_slstm_init_hidden_reset_subtract( self, slstm_hidden_reset_subtract_instance, input_ ): @@ -120,3 +126,10 @@ def test_sconv2dlstm_init_hidden_reset_subtract( spk_rec.append(spk) assert spk_rec[0].size() == (1, 2) + + def test_slstm_compile_fullgraph( + self, 
slstm_instance_surrogate, input_ + ): + explanation = dynamo.explain(slstm_instance_surrogate)(input_[0]) + + assert explanation.graph_break_count == 0 diff --git a/tests/test_snntorch/test_synaptic.py b/tests/test_snntorch/test_synaptic.py index 6ece23cc..262f5ca8 100644 --- a/tests/test_snntorch/test_synaptic.py +++ b/tests/test_snntorch/test_synaptic.py @@ -5,6 +5,7 @@ import pytest import snntorch as snn import torch +import torch._dynamo as dynamo @pytest.fixture(scope="module") @@ -16,6 +17,10 @@ def input_(): def synaptic_instance(): return snn.Synaptic(alpha=0.5, beta=0.5) +@pytest.fixture(scope="module") +def synaptic_instance_surrogate(): + return snn.Synaptic(alpha=0.5, beta=0.5, surrogate_disable=True) + @pytest.fixture(scope="module") def synaptic_reset_zero_instance(): @@ -123,3 +128,8 @@ def test_synaptic_init_hidden_reset_none( def test_synaptic_cases(self, synaptic_hidden_instance, input_): with pytest.raises(TypeError): synaptic_hidden_instance(input_, input_) + + def test_synaptic_compile_fullgraph(self, synaptic_instance_surrogate, input_): + explanation = dynamo.explain(synaptic_instance_surrogate)(input_[0]) + + assert explanation.graph_break_count == 0 \ No newline at end of file
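
The compile tests added above all follow the same pattern: construct a neuron with surrogate_disable=True, trace its forward pass with torch._dynamo.explain, and assert that no graph breaks occur. A minimal standalone sketch of that workflow follows; it is illustrative only (not part of the patch) and assumes snn.Leaky exposes the same surrogate_disable flag and a buffer-backed reset_mem(), matching the refactored neurons in this change.

# Sketch (not part of the patch): check for graph breaks, then compile a
# single snnTorch neuron end-to-end. Assumes the buffer-based state handling
# introduced above; the tensor shapes below are arbitrary examples.
import torch
import torch._dynamo as dynamo
import snntorch as snn

lif = snn.Leaky(beta=0.5, surrogate_disable=True)

# dynamo.explain reports graph breaks without requiring a full compile.
explanation = dynamo.explain(lif)(torch.rand(1, 10))
assert explanation.graph_break_count == 0

# With no graph breaks, fullgraph compilation can proceed. Hidden state lives
# in registered buffers, so only the input is passed at each time step.
compiled_lif = torch.compile(lif, fullgraph=True)
lif.reset_mem()  # zero the membrane potential between sequences
for _ in range(5):
    spk, mem = compiled_lif(torch.rand(1, 10))

The intent reflected in these tests is that keeping hidden state in registered buffers, rather than in the removed _SpikeTensor placeholders, lets dynamo trace the forward pass as a single graph.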