Skip to content

Commit

Permalink
population coding for ce_rate_loss
Browse files Browse the repository at this point in the history
  • Loading branch information
dylanperdigao committed May 11, 2024
1 parent 47b3176 commit fa80296
Showing 1 changed file with 17 additions and 4 deletions.
21 changes: 17 additions & 4 deletions snntorch/functional/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,17 +94,30 @@ class ce_rate_loss(LossFunctions):
"""

def __init__(self, reduction='mean', weight=None):
def __init__(self, population_code=False, num_classes=False, reduction='mean', weight=None):
super().__init__(reduction=reduction, weight=weight)
self.population_code = population_code
self.num_classes = num_classes
self.__name__ = "ce_rate_loss"

def _compute_loss(self, spk_out, targets):
device, num_steps, _ = self._prediction_check(spk_out)
device, num_steps, num_outputs = self._prediction_check(spk_out)

if self.population_code:
for idx in range(self.num_classes):
spk_out[
:,
:,
int(num_outputs * idx / self.num_classes) : int(
num_outputs * (idx + 1) / self.num_classes
),
]
weights = torch.Tensor([self.weight[0] if i < int(num_outputs/self.num_classes) else self.weight[1] for i in range(num_outputs) ]).to(device)

log_softmax_fn = nn.LogSoftmax(dim=-1)
loss_fn = nn.NLLLoss(reduction=self._intermediate_reduction(), weight=self.weight)
loss_fn = nn.NLLLoss(reduction=self._intermediate_reduction(), weight=weights)

log_p_y = log_softmax_fn(spk_out)

loss_shape = (spk_out.size(1)) if self._intermediate_reduction() == 'none' else (1)
loss = torch.zeros(loss_shape, dtype=dtype, device=device)

Expand Down

0 comments on commit fa80296

Please sign in to comment.