Skip to content

Commit

Permalink
Parameterize score aggregation function
Browse files Browse the repository at this point in the history
  • Loading branch information
asistradition committed Oct 19, 2023
1 parent 84fb413 commit eee0735
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 7 deletions.
26 changes: 21 additions & 5 deletions supirfactor_dynamical/models/_base_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -659,22 +659,38 @@ def score(
self,
dataloader,
loss_function=torch.nn.MSELoss(),
reduction='sum',
**kwargs
):

if dataloader is None:
return None

_score = 0
_score = []

with torch.no_grad():
for data in dataloader:
_score += loss_function(
self._slice_data_and_forward(data),
self.output_data(data),
**kwargs
_score.append(
loss_function(
self._slice_data_and_forward(data),
self.output_data(data),
**kwargs
)
)

_score = torch.Tensor(_score)

if reduction == 'mean':
_score = torch.mean(_score)
elif reduction == 'sum':
_score = torch.sum(_score)
elif reduction is None:
pass
else:
raise ValueError(
f'reduction must be `mean`, `sum` or None; {reduction} passed'
)

return _score


Expand Down
6 changes: 4 additions & 2 deletions supirfactor_dynamical/postprocessing/results.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,12 +158,14 @@ def add_classification_metrics_to_dataframe(
result_df[column_prefix + "accuracy"] = model_object.score(
training_dataloader,
loss_function=torcheval.metrics.functional.multilabel_accuracy,
criteria='hamming'
criteria='hamming',
reduction='mean'
).item()

result_df[column_prefix + "auprc"] = model_object.score(
training_dataloader,
loss_function=torcheval.metrics.functional.multilabel_auprc
loss_function=torcheval.metrics.functional.multilabel_auprc,
reduction='mean'
).item()

result_df[column_prefix + "cross_entropy"] = model_object.score(
Expand Down

0 comments on commit eee0735

Please sign in to comment.