Online Training Through Time (OTTT) Training Function Error #593

Open
1 of 4 tasks
tsumme1 opened this issue Nov 23, 2024 · 1 comment

tsumme1 commented Nov 23, 2024

Issue type

  • Bug Report
  • Feature Request
  • Help wanted
  • Other

SpikingJelly version

0.0.0.0.15

Description

Training a network with functional.ottt_online_training raises a RuntimeError, seemingly because the neuron state is not detached from the computational graph after each time step. Adding a functional.detach_net(model) call inside functional.ottt_online_training eliminated the error for me.
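To illustrate where such a detach call would sit in a per-time-step update loop, here is a minimal sketch in plain PyTorch. It does not reproduce SpikingJelly's implementation: `ToyLeakyNode`, `detach_states`, and `train_online` are hypothetical stand-ins for a stateful neuron, for `functional.detach_net`, and for the online training loop, respectively.

```python
import torch
from torch import nn
from torch.nn import functional as F


class ToyLeakyNode(nn.Module):
    """Hypothetical stateful layer: keeps a leaky running sum across calls."""

    def __init__(self, decay: float = 0.5):
        super().__init__()
        self.decay = decay
        self.v = None  # state carried from one time step to the next

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        self.v = x if self.v is None else self.decay * self.v + x
        return self.v

    def detach(self) -> None:
        # Keep the state's value but cut it off from earlier graphs.
        if self.v is not None:
            self.v = self.v.detach()

    def reset(self) -> None:
        self.v = None


def detach_states(net: nn.Module) -> None:
    """Hypothetical helper: detach the state of every ToyLeakyNode in net."""
    for m in net.modules():
        if isinstance(m, ToyLeakyNode):
            m.detach()


def train_online(net, optimizer, x_seq, target_seq, detach: bool) -> int:
    """Update once per time step; return the number of steps completed."""
    steps = 0
    for t in range(x_seq.shape[1]):  # x_seq has shape [N, T, ...]
        optimizer.zero_grad()
        loss = F.mse_loss(net(x_seq[:, t]), target_seq[:, t])
        loss.backward()
        optimizer.step()
        if detach:
            detach_states(net)  # assumed placement: after each step's update
        steps += 1
    return steps


toy_net = nn.Sequential(nn.Linear(8, 2), ToyLeakyNode())
toy_opt = torch.optim.SGD(toy_net.parameters(), lr=0.1)
toy_x = torch.rand(2, 4, 8)
toy_y = torch.rand(2, 4, 2)
completed = train_online(toy_net, toy_opt, toy_x, toy_y, detach=True)
```

With `detach=False`, the second time step of this toy loop fails in `loss.backward()`, because the carried state still points at the graph freed by the first step.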

Code to reproduce the error

import torch
from torch import nn
from torch.nn import functional as F

from spikingjelly.activation_based import neuron, layer, functional

net = layer.OTTTSequential(
    nn.Linear(8, 4),
    neuron.OTTTLIFNode(),
    nn.Linear(4, 2),
    neuron.LIFNode()
)

optimizer = torch.optim.SGD(net.parameters(), lr=0.1)

T = 4
N = 2
online = True
for epoch in range(2):

    x_seq = torch.rand([N, T, 8])
    target_seq = torch.rand([N, T, 2])

    functional.ottt_online_training(model=net, optimizer=optimizer, x_seq=x_seq, target_seq=target_seq, f_loss_t=F.mse_loss, online=online)
    functional.reset_net(net)

Error

RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.
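
For reference, the same RuntimeError can be triggered without SpikingJelly whenever a tensor carried across iterations keeps its autograd history and backward() is called once per iteration. The following is a minimal sketch of that mechanism, not of the library's code; `backward_each_step` and its toy weight and state are hypothetical names chosen for illustration.

```python
import torch


def backward_each_step(detach_state: bool, steps: int = 4) -> bool:
    """Return True if loss.backward() succeeds at every time step."""
    w = torch.ones(3, requires_grad=True)  # toy weight
    v = torch.zeros(3)                     # toy state carried across steps
    for _ in range(steps):
        x = torch.rand(3)
        v = 0.5 * v + w * x                # new state depends on the old one
        loss = (v ** 2).sum()
        try:
            loss.backward()                # frees this step's saved tensors
        except RuntimeError:
            return False                   # hit the freed graph of an earlier step
        if detach_state:
            v = v.detach()                 # keep the value, drop the history
    return True


ok_without_detach = backward_each_step(detach_state=False)
ok_with_detach = backward_each_step(detach_state=True)
```

In this toy setting only the detached variant completes all steps, which is consistent with the observation that detaching the neuron state after each time step removes the error.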

Met4physics (Contributor) commented Dec 2, 2024

net = layer.OTTTSequential(
    nn.Linear(8, 4),
    neuron.OTTTLIFNode(),
    nn.Linear(4, 2),
    neuron.OTTTLIFNode()
)

What about using neuron.OTTTLIFNode() in the last layer?
