Help: RuntimeError when executing loss.backward() #194
Answered by fangwei123456
Maninnlack asked this question in Q&A
I built a network with the SpikingJelly framework, but when training reaches the second batch, executing loss.backward() raises a RuntimeError. The code is:

    import torch
    import torch.nn as nn
    from utils import neuron, surrogate


    class Net(nn.Module):
        def __init__(self) -> None:
            super(Net, self).__init__()
            self.conv1 = nn.Sequential(
                nn.Conv2d(3, 8, kernel_size=3, padding=1),
                nn.BatchNorm2d(8),
                neuron.IFNode(surrogate_function=surrogate.ATan()),
                nn.Conv2d(8, 8, kernel_size=3, padding=1),
                nn.BatchNorm2d(8),
                neuron.IFNode(surrogate_function=surrogate.ATan())
            )
            self.fc1 = nn.Sequential(
                nn.Flatten(),
                nn.Linear(8 * 16 * 16, 1, bias=False),
                neuron.LIFNode(tau=2.0, surrogate_function=surrogate.ATan())
            )

        def forward(self, x):
            c1 = self.conv1(x)
            output = self.fc1(c1)
            return output


    def main():
        net = Net()
        optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)
        criterion = nn.MSELoss()
        for i in range(5):
            x = torch.rand([3, 3, 16, 16])
            y = torch.rand([3, 1])
            output = net(x)
            loss = criterion(output, y)
            print(loss.item())
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()


    if __name__ == '__main__':
        main()

After printing two loss values, the network raises the error.
I really don't know what to do, so I'm asking for help here.
Answered by fangwei123456 on Mar 31, 2022
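Add a reset after optimizer.step(). The SJ (SpikingJelly) framework stores neuron state inside the network, so the state left over from the previous batch must be cleared before the next batch is computed. See spikingjelly/spikingjelly/clock_driven/examples/lif_fc_mnist.py, line 138 in 0550c0e.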
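
To make the reason concrete, here is a minimal sketch that uses only PyTorch. ToyIFNode and train are hypothetical names invented for this illustration and are not SpikingJelly's implementation; the sigmoid is just a smooth stand-in for a spike with a surrogate gradient. The neuron keeps its membrane potential between forward calls, so without a reset the second batch's graph still points into the first batch's graph, which backward() has already freed. In SpikingJelly itself, the clock_driven package provides functional.reset_net(net) for clearing this state.

```python
import torch
import torch.nn as nn


class ToyIFNode(nn.Module):
    """Hypothetical stateful neuron; NOT SpikingJelly's IFNode."""

    def __init__(self, v_threshold: float = 1.0):
        super().__init__()
        self.v_threshold = v_threshold
        self.v = 0.0  # membrane potential, kept between forward calls

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The stored potential becomes part of this batch's autograd graph.
        self.v = self.v + x
        # Smooth stand-in for "spike + surrogate gradient" (an assumption).
        return torch.sigmoid(self.v - self.v_threshold)

    def reset(self) -> None:
        # Drop the old potential and, with it, the link to the old graph.
        self.v = 0.0


def train(reset_between_batches: bool, steps: int = 3) -> int:
    """Run a few training steps and return how many completed."""
    torch.manual_seed(0)
    net = nn.Sequential(nn.Linear(4, 1, bias=False), ToyIFNode())
    optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)
    criterion = nn.MSELoss()
    completed = 0
    for _ in range(steps):
        x = torch.rand(3, 4)
        y = torch.rand(3, 1)
        loss = criterion(net(x), y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if reset_between_batches:
            # Clear the state of every stateful neuron after the update.
            for module in net.modules():
                if isinstance(module, ToyIFNode):
                    module.reset()
        completed += 1
    return completed


# Without a reset, the second backward() reaches the freed first-batch graph.
try:
    train(reset_between_batches=False)
    failed_without_reset = False
except RuntimeError:
    failed_without_reset = True

# With a reset after each optimizer.step(), all steps complete.
completed_with_reset = train(reset_between_batches=True)
```

In the question's training loop, the equivalent change would be a single reset call placed right after optimizer.step() inside the for loop.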