Skip to content

How to record the firing time of each neuron in a spiking neural network? #189

Discussion options

You must be logged in to vote
import torch
import torch.nn as nn
from spikingjelly.clock_driven import neuron, functional, layer


# Topology: 8 inputs -> Linear -> 4 IF neurons -> Linear -> 2 IF neurons.
# Position 1 is the first integrate-and-fire layer; it is the one the
# firing-time hook is attached to further down.
net = nn.Sequential(
    *[
        nn.Linear(8, 4),
        neuron.IFNode(),
        nn.Linear(4, 2),
        neuron.IFNode(),
    ]
)

def firing_time_hook(m, x, y):
    """Forward hook that logs the time step at which units of ``m`` spike.

    Meant for ``register_forward_hook``, so PyTorch calls it after every
    forward pass of ``m``.

    Args:
        m: the hooked module. It must already carry a step counter ``m.t``
            and a list ``m.t_f``, both of which this hook mutates.
        x: the forward inputs (unused).
        y: the module's output. Entries equal to ``1.`` are treated as
            spikes. This assumes a 0/1 spike tensor, as produced by
            ``IFNode``; TODO confirm for other neuron types.

    Side effects:
        Increments ``m.t``, so the first call records step 1. Appends a 1-D
        tensor to ``m.t_f`` with one entry, equal to the current step, per
        spiking element. The tensor is empty when nothing fired. Which
        neuron fired is not recorded.
    """
    m.t += 1
    with torch.no_grad():
        spiked = y == 1.0
        # Every selected entry is exactly 1, so scaling after selection
        # yields the current step for each spike.
        m.t_f.append(y[spiked] * m.t)

# Attach two stateful memories to the first IF layer: a step counter `t`
# and a list `t_f` that collects per-step firing times. They are registered
# via register_memory rather than set as plain attributes so that, per
# SpikingJelly's reset mechanism, resetting the layer restores these
# initial values.
first_if = net[1]
first_if.register_memory('t', 0)
first_if.register_memory('t_f', [])
first_if.register_forward_hook(firing_time_hook)

T = 16  # number of simulation time steps
N = 1   # batch size
x = torch.rand([T, N, 8])  # one random [N, 8] input frame per time step

# Step the stateful network through time. The hook runs after every
# forward pass of the first IF layer and records that step's spikes.
with torch.no_grad():
    for t, x_t in enumerate(x):
        net(x_t)

print(first_if.t_f)
[tensor([]), tensor([]), tenso…

Replies: 2 comments 1 reply

Comment options

You must be logged in to vote
1 reply
@xiaoxiaochaochao
Comment options

Comment options

You must be logged in to vote
0 replies
Answer selected by xiaoxiaochaochao
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Category
Q&A
Labels
None yet
2 participants