From d540f8abada0b437b9b9f5f2fb776511bfe7e978 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dylan=20Perdig=C3=A3o?= Date: Sat, 11 May 2024 16:16:34 +0100 Subject: [PATCH] Update loss.py --- snntorch/functional/loss.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/snntorch/functional/loss.py b/snntorch/functional/loss.py index 693802af..f1e19da9 100644 --- a/snntorch/functional/loss.py +++ b/snntorch/functional/loss.py @@ -102,6 +102,7 @@ def __init__(self, population_code=False, num_classes=False, reduction='mean', w def _compute_loss(self, spk_out, targets): device, num_steps, num_outputs = self._prediction_check(spk_out) + log_softmax_fn = nn.LogSoftmax(dim=-1) if self.population_code: for idx in range(self.num_classes): @@ -113,10 +114,10 @@ def _compute_loss(self, spk_out, targets): ), ] weights = torch.Tensor([self.weight[0] if i < int(num_outputs/self.num_classes) else self.weight[1] for i in range(num_outputs) ]).to(device) - - log_softmax_fn = nn.LogSoftmax(dim=-1) - loss_fn = nn.NLLLoss(reduction=self._intermediate_reduction(), weight=weights) - + loss_fn = nn.NLLLoss(reduction=self._intermediate_reduction(), weight=weights) + else: + loss_fn = nn.NLLLoss(reduction=self._intermediate_reduction(), weight=self.weight) + log_p_y = log_softmax_fn(spk_out) loss_shape = (spk_out.size(1)) if self._intermediate_reduction() == 'none' else (1) loss = torch.zeros(loss_shape, dtype=dtype, device=device)