How to record the firing time of each neuron in a spiking neural network? #189
Answered by fangwei123456
xiaoxiaochaochao asked this question in Q&A
You can use PyTorch's forward hook to save the firing time of each neuron.
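As a minimal sketch of the hook mechanism itself: PyTorch calls a hook registered with register_forward_hook after every forward pass, passing the module, a tuple of its inputs, and its output. ThresholdSpike below is a hypothetical, stateless stand-in for a spiking layer (it is not part of SpikingJelly). The hook keeps every output, and the firing times are read off afterwards with nonzero(), which returns one [t, sample, neuron] row per spike.

```python
import torch
import torch.nn as nn


class ThresholdSpike(nn.Module):
    """Hypothetical stateless spiking layer: emits 1.0 wherever the input reaches 1.0."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return (x >= 1.0).float()


spike_layer = ThresholdSpike()
recorded = []  # one output tensor per time step


def save_output(module, inputs, output):
    # Called by PyTorch after every spike_layer(...) call with (module, inputs tuple, output)
    recorded.append(output.detach().clone())


handle = spike_layer.register_forward_hook(save_output)

# Input of shape [T, N, features] = [3, 1, 2], fed one time step at a time
x = torch.tensor([[[1.5, 0.2]],
                  [[0.1, 1.0]],
                  [[2.0, 3.0]]])
with torch.no_grad():
    for t in range(x.shape[0]):
        spike_layer(x[t])
handle.remove()  # stop recording once the simulation is done

spikes = torch.stack(recorded)      # shape [T, N, features]
events = spikes.nonzero().tolist()  # each row is [t, sample, neuron]
print(events)  # [[0, 0, 0], [1, 0, 1], [2, 0, 0], [2, 0, 1]]
```

Keeping every output costs memory proportional to T, but it preserves exactly which neuron fired at which step; clear recorded before running the next sample.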
import torch
import torch.nn as nn
from spikingjelly.clock_driven import neuron, functional

net = nn.Sequential(
    nn.Linear(8, 4),
    neuron.IFNode(),
    nn.Linear(4, 2),
    neuron.IFNode()
)

def firing_time_hook(m, x, y):
    # m: the hooked IFNode, x: tuple of its inputs, y: its output spikes
    m.t += 1  # current time step, counted from 1
    with torch.no_grad():
        mask = (y == 1.)
        # store the current time step once for every neuron that fired at this step
        m.t_f.append((y * m.t)[mask])

net[1].register_memory('t', 0)
net[1].register_memory('t_f', [])
# after register_memory, t and t_f are restored to these initial values whenever the network is reset
net[1].register_forward_hook(firing_time_hook)

T = 16
N = 1
x = torch.rand([T, N, 8])
with torch.no_grad():
    for t in range(x.shape[0]):
        net(x[t])
print(net[1].t_f)
# reset membrane potentials, t and t_f before feeding the next sample
functional.reset_net(net)
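The hook above stores (y * m.t)[mask], which keeps the firing times but not which neuron produced each one. A variant that keeps the indices is sketched below under stated assumptions: SimpleIF is a hypothetical integrate-and-fire stand-in so the snippet runs without SpikingJelly, reset_recorder is a hypothetical helper playing the role that register_memory plus functional.reset_net play above, and the step counter is 0-based (the hook above counts from 1).

```python
import torch
import torch.nn as nn


class SimpleIF(nn.Module):
    """Hypothetical integrate-and-fire stand-in (NOT spikingjelly's IFNode).

    Adds the input to a membrane potential v, spikes (1.0) where v >= v_threshold,
    and hard-resets v to 0 for the neurons that spiked.
    """

    def __init__(self, v_threshold: float = 1.0):
        super().__init__()
        self.v_threshold = v_threshold
        self.v = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.v is None:
            self.v = torch.zeros_like(x)
        self.v = self.v + x
        spike = (self.v >= self.v_threshold).float()
        self.v = self.v * (1.0 - spike)  # hard reset only where a spike occurred
        return spike


def firing_time_hook(m, inputs, output):
    # Log (time_step, sample, neuron) for every element that spiked on this step
    for sample, neuron in output.nonzero().tolist():
        m.firing_events.append((m.step, sample, neuron))
    m.step += 1  # 0-based counter, matching the loop index t below


def reset_recorder(m):
    # Hypothetical helper: clear the neuron state and the log before a new input sequence
    m.v = None
    m.step = 0
    m.firing_events = []


if_neuron = SimpleIF()
reset_recorder(if_neuron)
handle = if_neuron.register_forward_hook(firing_time_hook)

T = 8
x = torch.tensor([0.5, 0.25]).expand(T, 1, 2)  # constant input, shape [T, N=1, 2]
with torch.no_grad():
    for t in range(T):
        if_neuron(x[t])
handle.remove()

# Group the events into one list of firing times per (sample, neuron)
firing_times = {}
for step, sample, neuron in if_neuron.firing_events:
    firing_times.setdefault((sample, neuron), []).append(step)
print(firing_times)  # {(0, 0): [1, 3, 5, 7], (0, 1): [3, 7]}
```

With a real neuron.IFNode, one could keep the register_memory calls from the reply and append (m.t, sample, neuron) tuples inside firing_time_hook in the same way.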
Answer selected by xiaoxiaochaochao
Sir,
How can I record the firing time of each neuron in a spiking neural network?