Skip to content

Commit

Permalink
Add multilabel result aggregation helper
Browse files Browse the repository at this point in the history
  • Loading branch information
asistradition committed Oct 19, 2023
1 parent d4ec480 commit 7c62187
Show file tree
Hide file tree
Showing 4 changed files with 109 additions and 2 deletions.
8 changes: 6 additions & 2 deletions supirfactor_dynamical/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,10 @@

from .biophysical_model import SupirFactorBiophysical
from .decay_model import DecayModule
from .chromatin_model import ChromatinAwareModel
from .chromatin_model import (
ChromatinModule,
ChromatinAwareModel
)

# Standard mixins
from ._base_velocity_model import (
Expand All @@ -29,7 +32,8 @@
TFLSTMDecoder.type_name: TFLSTMDecoder,
SupirFactorBiophysical.type_name: SupirFactorBiophysical,
DecayModule.type_name: DecayModule,
ChromatinAwareModel.type_name: ChromatinAwareModel
ChromatinAwareModel.type_name: ChromatinAwareModel,
ChromatinModule.type_name: ChromatinModule
}

_not_velocity = [
Expand Down
23 changes: 23 additions & 0 deletions supirfactor_dynamical/models/_base_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -654,6 +654,29 @@ def predict(
**kwargs
)

@torch.inference_mode()
def score(
self,
dataloader,
loss_function=torch.nn.MSELoss(),
**kwargs
):

if dataloader is None:
return None

_score = 0

with torch.no_grad():
for data in dataloader:
_score += loss_function(
self._slice_data_and_forward(data),
self.output_data(data),
**kwargs
)

return _score


def _shuffle_time_data(dl):
try:
Expand Down
38 changes: 38 additions & 0 deletions supirfactor_dynamical/postprocessing/results.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import pandas as pd
import numpy as np
import torcheval.metrics
import torch

from inferelator.postprocessing import ResultsProcessor

Expand Down Expand Up @@ -143,6 +145,42 @@ def process_results_to_dataframes(
return results, loss_df, time_dependent_loss


def add_classification_metrics_to_dataframe(
result_df,
model_object,
training_dataloader,
validation_dataloader=None,
column_prefix=None
):
if column_prefix is None:
column_prefix = "training_"

result_df[column_prefix + "accuracy"] = model_object.score(
training_dataloader,
loss_function=torcheval.metrics.functional.multilabel_accuracy
)

result_df[column_prefix + "auprc"] = model_object.score(
training_dataloader,
loss_function=torcheval.metrics.functional.multilabel_auprc
)

result_df[column_prefix + "cross_entropy"] = model_object.score(
training_dataloader,
loss_function=torch.nn.BCELoss()
)

if validation_dataloader is not None:
result_df = add_classification_metrics_to_dataframe(
result_df,
model_object,
validation_dataloader,
column_prefix="validation_"
)

return result_df


def process_combined_results(
results,
gold_standard=None,
Expand Down
42 changes: 42 additions & 0 deletions supirfactor_dynamical/tests/test_chromatin.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,11 @@
ChromatinAwareModel
)

from supirfactor_dynamical.postprocessing.results import (
process_results_to_dataframes,
add_classification_metrics_to_dataframe
)


class TestChromatinTraining(unittest.TestCase):

Expand Down Expand Up @@ -56,6 +61,43 @@ def test_train(self):
10
)

def test_classification_results(self):

model = ChromatinModule(
4,
25,
k=10
)

model.train_model(
self.dataloader,
10
)

results, losses, _ = process_results_to_dataframes(
model,
None,
model_type='chromatin',
leader_columns=["Name"],
leader_values=["Value"]
)

self.assertEqual(
results.shape,
(1, 4)
)

results = add_classification_metrics_to_dataframe(
results,
model,
self.dataloader
)

self.assertEqual(
results.shape,
(1, 7)
)


class TestChromatinTrainingSparse(TestChromatinTraining):

Expand Down

0 comments on commit 7c62187

Please sign in to comment.