Skip to content

Commit

Permalink
Misc fixes & cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
gekkom committed Feb 22, 2024
1 parent 90ad720 commit 1d600aa
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 69 deletions.
62 changes: 2 additions & 60 deletions snntorch/_neurons/neurons.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,12 @@
import inspect
from warnings import warn
from snntorch.surrogate import StraightThroughEstimator, atan, straight_through_estimator
from snntorch.surrogate import atan
import torch
import torch.nn as nn


__all__ = [
"SpikingNeuron",
"LIF",
"_SpikeTensor",
"_SpikeTorchConv",
]

dtype = torch.float
Expand Down Expand Up @@ -234,6 +231,7 @@ def zeros(*args):
def _surrogate_bypass(input_):
return (input_ > 0).float()


class LIF(SpikingNeuron):
"""Parent class for leaky integrate and fire neuron models."""

Expand Down Expand Up @@ -298,59 +296,3 @@ def _V_register_buffer(self, V, learn_V):
self.V = nn.Parameter(V)
else:
self.register_buffer("V", V)

@staticmethod
def init_rsynaptic():
"""
Used to initialize spk, syn and mem as an empty SpikeTensor.
``init_flag`` is used as an attribute in the forward pass to convert
the hidden states to the same as the input.
"""
spk = _SpikeTensor(init_flag=False)
syn = _SpikeTensor(init_flag=False)
mem = _SpikeTensor(init_flag=False)

return spk, syn, mem


class _SpikeTensor(torch.Tensor):
"""Inherits from torch.Tensor with additional attributes.
``init_flag`` is set at the time of initialization.
When called in the forward function of any neuron, they are parsed and
replaced with a torch.Tensor variable.
"""

@staticmethod
def __new__(cls, *args, init_flag=False, **kwargs):
return super().__new__(cls, *args, **kwargs)

def __init__(
self,
*args,
init_flag=True,
):
# super().__init__() # optional
self.init_flag = init_flag


def _SpikeTorchConv(*args, input_):
"""Convert SpikeTensor to torch.Tensor of the same size as ``input_``."""

states = []
# if len(input_.size()) == 0:
# _batch_size = 1 # assume batch_size=1 if 1D input
# else:
# _batch_size = input_.size(0)
if (
len(args) == 1 and type(args) is not tuple
): # if only one hidden state, make it iterable
args = (args,)
for arg in args:
arg = arg.to("cpu")
arg = torch.Tensor(arg) # wash away the SpikeTensor class
arg = torch.zeros_like(input_, requires_grad=True)
states.append(arg)
if len(states) == 1: # otherwise, list isn't unpacked
return states[0]

return states
5 changes: 1 addition & 4 deletions tests/test_snntorch/test_leaky.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,9 +139,6 @@ def test_leaky_hidden_learn_graded_instance(
assert factor.requires_grad

def test_leaky_compile_fullgraph(self, leaky_instance_surrogate, input_):
# net = nn.Sequential(
# snn.Leaky(beta=0.5, init_hidden=True, surrogate_disable=True),
# )

explanation = dynamo.explain(leaky_instance_surrogate)(input_[0])

assert explanation.graph_break_count == 0
10 changes: 5 additions & 5 deletions tests/test_snntorch/test_slstm.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def slstm_hidden_reset_subtract_instance():


class TestSLSTM:
def test_sconv2dlstm(self, slstm_instance, input_):
def test_slstm(self, slstm_instance, input_):
c, h = slstm_instance.init_slstm()

h_rec = []
Expand All @@ -70,7 +70,7 @@ def test_sconv2dlstm(self, slstm_instance, input_):
assert h.size() == (1, 2)
assert spk.size() == (1, 2)

def test_sconv2dlstm_reset(
def test_slstm_reset(
self,
slstm_instance,
slstm_reset_zero_instance,
Expand All @@ -93,7 +93,7 @@ def test_sconv2dlstm_reset(
assert lif2.reset_mechanism_val == 1
assert lif3.reset_mechanism_val == 0

def test_sconv2dlstm_init_hidden(self, slstm_hidden_instance, input_):
def test_slstm_init_hidden(self, slstm_hidden_instance, input_):

spk_rec = []

Expand All @@ -103,7 +103,7 @@ def test_sconv2dlstm_init_hidden(self, slstm_hidden_instance, input_):

assert spk_rec[0].size() == (1, 2)

def test_sconv2dlstm_init_hidden_reset_zero(
def test_slstm_init_hidden_reset_zero(
self, slstm_hidden_reset_zero_instance, input_
):

Expand All @@ -115,7 +115,7 @@ def test_sconv2dlstm_init_hidden_reset_zero(

assert spk_rec[0].size() == (1, 2)

def test_sconv2dlstm_init_hidden_reset_subtract(
def test_slstm_init_hidden_reset_subtract(
self, slstm_hidden_reset_subtract_instance, input_
):

Expand Down

0 comments on commit 1d600aa

Please sign in to comment.