Skip to content

Commit

Permalink
Update loss.py
Browse files Browse the repository at this point in the history
  • Loading branch information
dylanperdigao committed May 11, 2024
1 parent fa80296 commit d540f8a
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions snntorch/functional/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ def __init__(self, population_code=False, num_classes=False, reduction='mean', w

def _compute_loss(self, spk_out, targets):
device, num_steps, num_outputs = self._prediction_check(spk_out)
log_softmax_fn = nn.LogSoftmax(dim=-1)

if self.population_code:
for idx in range(self.num_classes):
Expand All @@ -113,10 +114,10 @@ def _compute_loss(self, spk_out, targets):
),
]
weights = torch.Tensor([self.weight[0] if i < int(num_outputs/self.num_classes) else self.weight[1] for i in range(num_outputs) ]).to(device)

log_softmax_fn = nn.LogSoftmax(dim=-1)
loss_fn = nn.NLLLoss(reduction=self._intermediate_reduction(), weight=weights)

loss_fn = nn.NLLLoss(reduction=self._intermediate_reduction(), weight=weights)
else:
loss_fn = nn.NLLLoss(reduction=self._intermediate_reduction(), weight=self.weight)
log_p_y = log_softmax_fn(spk_out)
loss_shape = (spk_out.size(1)) if self._intermediate_reduction() == 'none' else (1)
loss = torch.zeros(loss_shape, dtype=dtype, device=device)
Expand Down

0 comments on commit d540f8a

Please sign in to comment.