Skip to content

Commit

Permalink
Merge pull request #321 from DylanPerdigao/master
Browse files Browse the repository at this point in the history
Support of Population Coding for ce_rate_loss
  • Loading branch information
jeshraghian authored May 12, 2024
2 parents 47b3176 + d540f8a commit e0968b2
Showing 1 changed file with 18 additions and 4 deletions.
22 changes: 18 additions & 4 deletions snntorch/functional/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,17 +94,31 @@ class ce_rate_loss(LossFunctions):
"""

def __init__(self, reduction='mean', weight=None):
def __init__(self, population_code=False, num_classes=False, reduction='mean', weight=None):
super().__init__(reduction=reduction, weight=weight)
self.population_code = population_code
self.num_classes = num_classes
self.__name__ = "ce_rate_loss"

def _compute_loss(self, spk_out, targets):
device, num_steps, _ = self._prediction_check(spk_out)
device, num_steps, num_outputs = self._prediction_check(spk_out)
log_softmax_fn = nn.LogSoftmax(dim=-1)
loss_fn = nn.NLLLoss(reduction=self._intermediate_reduction(), weight=self.weight)

if self.population_code:
for idx in range(self.num_classes):
spk_out[
:,
:,
int(num_outputs * idx / self.num_classes) : int(
num_outputs * (idx + 1) / self.num_classes
),
]
weights = torch.Tensor([self.weight[0] if i < int(num_outputs/self.num_classes) else self.weight[1] for i in range(num_outputs) ]).to(device)
loss_fn = nn.NLLLoss(reduction=self._intermediate_reduction(), weight=weights)
else:
loss_fn = nn.NLLLoss(reduction=self._intermediate_reduction(), weight=self.weight)

log_p_y = log_softmax_fn(spk_out)

loss_shape = (spk_out.size(1)) if self._intermediate_reduction() == 'none' else (1)
loss = torch.zeros(loss_shape, dtype=dtype, device=device)

Expand Down

0 comments on commit e0968b2

Please sign in to comment.