Merge pull request #292 from gekkom/fullgraph2
Add torch.compile(fullgraph=True) support for more neuron models
jeshraghian authored Mar 9, 2024
2 parents 21d8586 + 1d600aa commit 46afb8b
Showing 18 changed files with 704 additions and 918 deletions.
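
The refactor replaces the `_SpikeTensor` / `_SpikeTorchConv` first-pass initialization with hidden states held in registered buffers and a reset-mechanism dispatch resolved once in `__init__`, so the forward pass no longer branches on Python sentinel values, which is the kind of control flow that tends to break `torch.compile(fullgraph=True)`. A minimal sketch of the usage this commit is meant to enable; it assumes the public snntorch API, and the layer sizes, hyperparameters, and data are made up for illustration:

    import torch
    import torch.nn as nn
    import snntorch as snn

    # Two-layer SNN with Alpha neurons holding their own state (init_hidden=True).
    net = nn.Sequential(
        nn.Linear(784, 128),
        snn.Alpha(alpha=0.9, beta=0.8, init_hidden=True),
        nn.Linear(128, 10),
        snn.Alpha(alpha=0.9, beta=0.8, init_hidden=True, output=True),
    )

    # fullgraph=True asks torch.compile to raise on any graph break rather than
    # silently falling back to eager execution.
    compiled_net = torch.compile(net, fullgraph=True)

    x = torch.rand(32, 784)  # made-up batch of flattened inputs
    spk, syn_exc, syn_inh, mem = compiled_net(x)  # output layer returns spike and all states
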
snntorch/_neurons/alpha.py: 223 changes (98 additions and 125 deletions)
@@ -1,7 +1,7 @@
import torch
import torch.nn as nn

-from .neurons import _SpikeTensor, _SpikeTorchConv, LIF
+from .neurons import LIF


class Alpha(LIF):
@@ -119,118 +119,86 @@ def __init__(
self._alpha_register_buffer(alpha, learn_alpha)
self._alpha_cases()

-        if self.init_hidden:
-            self.syn_exc, self.syn_inh, self.mem = self.init_alpha()
+        self._init_mem()

# if reset_mechanism == "subtract":
# self.mem_residual = False
+        if self.reset_mechanism_val == 0:  # reset by subtraction
+            self.state_function = self._base_sub
+        elif self.reset_mechanism_val == 1:  # reset to zero
+            self.state_function = self._base_zero
+        elif self.reset_mechanism_val == 2:  # no reset, pure integration
+            self.state_function = self._base_int

-    def forward(self, input_, syn_exc=False, syn_inh=False, mem=False):
+    def _init_mem(self):
+        syn_exc = torch.zeros(1)
+        syn_inh = torch.zeros(1)
+        mem = torch.zeros(1)

-        if (
-            hasattr(syn_exc, "init_flag")
-            or hasattr(syn_inh, "init_flag")
-            or hasattr(mem, "init_flag")
-        ):  # only triggered on first-pass
-            syn_exc, syn_inh, mem = _SpikeTorchConv(
-                syn_exc, syn_inh, mem, input_=input_
-            )
-        elif mem is False and hasattr(
-            self.mem, "init_flag"
-        ):  # init_hidden case
-            self.syn_exc, self.syn_inh, self.mem = _SpikeTorchConv(
-                self.syn_exc, self.syn_inh, self.mem, input_=input_
-            )
+        self.register_buffer("syn_exc", syn_exc)
+        self.register_buffer("syn_inh", syn_inh)
+        self.register_buffer("mem", mem)

-        # if hidden states are passed externally
-        if not self.init_hidden:
-            self.reset = self.mem_reset(mem)
-            syn_exc, syn_inh, mem = self._build_state_function(
-                input_, syn_exc, syn_inh, mem
-            )
+    def reset_mem(self):
+        self.syn_exc = torch.zeros_like(
+            self.syn_exc, device=self.syn_exc.device
+        )
+        self.syn_inh = torch.zeros_like(
+            self.syn_inh, device=self.syn_inh.device
+        )
+        self.mem = torch.zeros_like(self.mem, device=self.mem.device)

-            if self.state_quant:
-                syn_exc = self.state_quant(syn_exc)
-                syn_inh = self.state_quant(syn_inh)
-                mem = self.state_quant(mem)
+    def init_alpha(self):
+        """Deprecated, use :class:`Alpha.reset_mem` instead"""
+        self.reset_mem()
+        return self.syn_exc, self.syn_inh, self.mem

-            if self.inhibition:
-                spk = self.fire_inhibition(mem.size(0), mem)
+    def forward(self, input_, syn_exc=None, syn_inh=None, mem=None):

-            else:
-                spk = self.fire(mem)
+        if not syn_exc == None:
+            self.syn_exc = syn_exc

-            return spk, syn_exc, syn_inh, mem
+        if not syn_inh == None:
+            self.syn_inh = syn_inh

-        # if hidden states and outputs are instance variables
-        if self.init_hidden:
-            self._alpha_forward_cases(mem, syn_exc, syn_inh)
+        if not mem == None:
+            self.mem = mem

-            self.reset = self.mem_reset(self.mem)
-            (
-                self.syn_exc,
-                self.syn_inh,
-                self.mem,
-            ) = self._build_state_function_hidden(input_)
+        if self.init_hidden and (
+            not mem == None or not syn_exc == None or not syn_inh == None
+        ):
+            raise TypeError(
+                "When `init_hidden=True`, Alpha expects 1 input argument."
+            )

-            if self.state_quant:
-                self.syn_exc = self.state_quant(self.syn_exc)
-                self.syn_inh = self.state_quant(self.syn_inh)
-                self.mem = self.state_quant(self.mem)
+        if not self.syn_exc.shape == input_.shape:
+            self.syn_exc = torch.zeros_like(input_, device=self.syn_exc.device)

-            if self.inhibition:
-                self.spk = self.fire_inhibition(self.mem.size(0), self.mem)
+        if not self.syn_inh.shape == input_.shape:
+            self.syn_inh = torch.zeros_like(input_, device=self.syn_inh.device)

-            else:
-                self.spk = self.fire(self.mem)
+        if not self.mem.shape == input_.shape:
+            self.mem = torch.zeros_like(input_, device=self.mem.device)

-            if self.output:
-                return self.spk, self.syn_exc, self.syn_inh, self.mem
-            else:
-                return self.spk
+        self.reset = self.mem_reset(self.mem)
+        self.syn_exc, self.syn_inh, self.mem = self.state_function(input_)

-    def _base_state_function(self, input_, syn_exc, syn_inh, mem):
-        base_fn_syn_exc = self.alpha.clamp(0, 1) * syn_exc + input_
-        base_fn_syn_inh = self.beta.clamp(0, 1) * syn_inh - input_
-        tau_alpha = (
-            torch.log(self.alpha.clamp(0, 1))
-            / (
-                torch.log(self.beta.clamp(0, 1))
-                - torch.log(self.alpha.clamp(0, 1))
-            )
-            + 1
-        )
-        base_fn_mem = tau_alpha * (base_fn_syn_exc + base_fn_syn_inh)
-        return base_fn_syn_exc, base_fn_syn_inh, base_fn_mem
+        if self.state_quant:
+            self.syn_exc = self.state_quant(self.syn_exc)
+            self.syn_inh = self.state_quant(self.syn_inh)
+            self.mem = self.state_quant(self.mem)

-    def _base_state_reset_sub_function(self, input_, syn_inh):
-        syn_exc_reset = self.threshold
-        syn_inh_reset = self.beta.clamp(0, 1) * syn_inh - input_
-        mem_reset = 0
-        return syn_exc_reset, syn_inh_reset, mem_reset
+        if self.inhibition:
+            spk = self.fire_inhibition(self.mem.size(0), self.mem)
+        else:
+            spk = self.fire(self.mem)

-    def _build_state_function(self, input_, syn_exc, syn_inh, mem):
-        if self.reset_mechanism_val == 0:
-            state_fn = tuple(
-                map(
-                    lambda x, y: x - self.reset * y,
-                    self._base_state_function(input_, syn_exc, syn_inh, mem),
-                    self._base_state_reset_sub_function(input_, syn_inh),
-                )
-            )
-        elif self.reset_mechanism_val == 1:  # reset to zero
-            state_fn = tuple(
-                map(
-                    lambda x, y: x - self.reset * y,
-                    self._base_state_function(input_, syn_exc, syn_inh, mem),
-                    self._base_state_function(input_, syn_exc, syn_inh, mem),
-                )
-            )
-        elif self.reset_mechanism_val == 2:  # no reset, pure integration
-            state_fn = self._base_state_function(input_, syn_exc, syn_inh, mem)
-        return state_fn
+        if self.output:
+            return spk, self.syn_exc, self.syn_inh, self.mem
+        elif self.init_hidden:
+            return spk
+        else:
+            return spk, self.syn_exc, self.syn_inh, self.mem

-    def _base_state_function_hidden(self, input_):
+    def _base_state_function(self, input_):
base_fn_syn_exc = self.alpha.clamp(0, 1) * self.syn_exc + input_
base_fn_syn_inh = self.beta.clamp(0, 1) * self.syn_inh - input_
tau_alpha = (
@@ -244,32 +212,34 @@ def _base_state_function_hidden(self, input_):
base_fn_mem = tau_alpha * (base_fn_syn_exc + base_fn_syn_inh)
return base_fn_syn_exc, base_fn_syn_inh, base_fn_mem

-    def _base_state_reset_sub_function_hidden(self, input_):
+    def _base_state_reset_sub_function(self, input_):
syn_exc_reset = self.threshold
syn_inh_reset = self.beta.clamp(0, 1) * self.syn_inh - input_
mem_reset = -self.syn_inh
return syn_exc_reset, syn_inh_reset, mem_reset

-    def _build_state_function_hidden(self, input_):
-        if self.reset_mechanism_val == 0:  # reset by subtraction
-            state_fn = tuple(
-                map(
-                    lambda x, y: x - self.reset * y,
-                    self._base_state_function_hidden(input_),
-                    self._base_state_reset_sub_function_hidden(input_),
-                )
-            )
-        elif self.reset_mechanism_val == 1:  # reset to zero
-            state_fn = tuple(
-                map(
-                    lambda x, y: x - self.reset * y,
-                    self._base_state_function_hidden(input_),
-                    self._base_state_function_hidden(input_),
-                )
-            )
-        elif self.reset_mechanism_val == 2:  # no reset, pure integration
-            state_fn = self._base_state_function_hidden(input_)
-        return state_fn
+    def _base_sub(self, input_):
+        syn_exec, syn_inh, mem = self._base_state_function(input_)
+        syn_exec2, syn_inh2, mem2 = self._base_state_reset_sub_function(input_)
+
+        syn_exec -= syn_exec2 * self.reset
+        syn_inh -= syn_inh2 * self.reset
+        mem -= mem2 * self.reset
+
+        return syn_exec, syn_inh, mem
+
+    def _base_zero(self, input_):
+        syn_exec, syn_inh, mem = self._base_state_function(input_)
+        syn_exec2, syn_inh2, mem2 = self._base_state_function(input_)
+
+        syn_exec -= syn_exec2 * self.reset
+        syn_inh -= syn_inh2 * self.reset
+        mem -= mem2 * self.reset
+
+        return syn_exec, syn_inh, mem
+
+    def _base_int(self, input_):
+        return self._base_state_function(input_)

def _alpha_register_buffer(self, alpha, learn_alpha):
if not isinstance(alpha, torch.Tensor):
@@ -291,12 +261,6 @@ def _alpha_cases(self):
"tau_alpha = log(alpha)/log(beta) - log(alpha) + 1"
)

-    def _alpha_forward_cases(self, mem, syn_exc, syn_inh):
-        if mem is not False or syn_exc is not False or syn_inh is not False:
-            raise TypeError(
-                "When `init_hidden=True`, Alpha expects 1 input argument."
-            )

@classmethod
def detach_hidden(cls):
"""Used to detach hidden states from the current graph.
@@ -315,6 +279,15 @@ def reset_hidden(cls):
variables."""
for layer in range(len(cls.instances)):
if isinstance(cls.instances[layer], Alpha):
-                cls.instances[layer].syn_exc = _SpikeTensor(init_flag=False)
-                cls.instances[layer].syn_inh = _SpikeTensor(init_flag=False)
-                cls.instances[layer].mem = _SpikeTensor(init_flag=False)
+                cls.instances[layer].syn_exc = torch.zeros_like(
+                    cls.instances[layer].syn_exc,
+                    device=cls.instances[layer].syn_exc.device,
+                )
+                cls.instances[layer].syn_inh = torch.zeros_like(
+                    cls.instances[layer].syn_inh,
+                    device=cls.instances[layer].syn_inh.device,
+                )
+                cls.instances[layer].mem = torch.zeros_like(
+                    cls.instances[layer].mem,
+                    device=cls.instances[layer].mem.device,
+                )
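
For reference, a rough sketch of how the reworked interface is driven when hidden states are handled explicitly rather than through `init_hidden`: `reset_mem()` is the new reset entry point and `init_alpha()` is kept only as a deprecated alias. Shapes, hyperparameters, and data below are illustrative assumptions, not taken from the commit:

    import torch
    import snntorch as snn

    lif = snn.Alpha(alpha=0.9, beta=0.8)  # states passed in and returned explicitly

    # reset_mem() zeroes the registered state buffers; init_alpha() now just wraps it.
    syn_exc, syn_inh, mem = lif.init_alpha()

    x = torch.rand(5, 32, 100)  # (time, batch, features), made up for illustration
    spk_rec = []
    for step in range(x.size(0)):
        spk, syn_exc, syn_inh, mem = lif(x[step], syn_exc, syn_inh, mem)
        spk_rec.append(spk)
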