diff --git a/supirfactor_dynamical/models/__init__.py b/supirfactor_dynamical/models/__init__.py index 987b154..9ab1385 100644 --- a/supirfactor_dynamical/models/__init__.py +++ b/supirfactor_dynamical/models/__init__.py @@ -11,7 +11,10 @@ from .biophysical_model import SupirFactorBiophysical from .decay_model import DecayModule -from .chromatin_model import ChromatinAwareModel +from .chromatin_model import ( + ChromatinModule, + ChromatinAwareModel +) # Standard mixins from ._base_velocity_model import ( @@ -29,7 +32,8 @@ TFLSTMDecoder.type_name: TFLSTMDecoder, SupirFactorBiophysical.type_name: SupirFactorBiophysical, DecayModule.type_name: DecayModule, - ChromatinAwareModel.type_name: ChromatinAwareModel + ChromatinAwareModel.type_name: ChromatinAwareModel, + ChromatinModule.type_name: ChromatinModule } _not_velocity = [ diff --git a/supirfactor_dynamical/models/_base_trainer.py b/supirfactor_dynamical/models/_base_trainer.py index 0f8f4a9..2434e90 100644 --- a/supirfactor_dynamical/models/_base_trainer.py +++ b/supirfactor_dynamical/models/_base_trainer.py @@ -654,6 +654,29 @@ def predict( **kwargs ) + @torch.inference_mode() + def score( + self, + dataloader, + loss_function=torch.nn.MSELoss(), + **kwargs + ): + + if dataloader is None: + return None + + _score = 0 + + with torch.no_grad(): + for data in dataloader: + _score += loss_function( + self._slice_data_and_forward(data), + self.output_data(data), + **kwargs + ) + + return _score + def _shuffle_time_data(dl): try: diff --git a/supirfactor_dynamical/postprocessing/results.py b/supirfactor_dynamical/postprocessing/results.py index c588657..9df3422 100644 --- a/supirfactor_dynamical/postprocessing/results.py +++ b/supirfactor_dynamical/postprocessing/results.py @@ -1,5 +1,7 @@ import pandas as pd import numpy as np +import torcheval.metrics +import torch from inferelator.postprocessing import ResultsProcessor @@ -143,6 +145,42 @@ def process_results_to_dataframes( return results, loss_df, time_dependent_loss +def add_classification_metrics_to_dataframe( + result_df, + model_object, + training_dataloader, + validation_dataloader=None, + column_prefix=None +): + if column_prefix is None: + column_prefix = "training_" + + result_df[column_prefix + "accuracy"] = model_object.score( + training_dataloader, + loss_function=torcheval.metrics.functional.multilabel_accuracy + ) + + result_df[column_prefix + "auprc"] = model_object.score( + training_dataloader, + loss_function=torcheval.metrics.functional.multilabel_auprc + ) + + result_df[column_prefix + "cross_entropy"] = model_object.score( + training_dataloader, + loss_function=torch.nn.BCELoss() + ) + + if validation_dataloader is not None: + result_df = add_classification_metrics_to_dataframe( + result_df, + model_object, + validation_dataloader, + column_prefix="validation_" + ) + + return result_df + + def process_combined_results( results, gold_standard=None, diff --git a/supirfactor_dynamical/tests/test_chromatin.py b/supirfactor_dynamical/tests/test_chromatin.py index d0a928f..6fd6f8c 100644 --- a/supirfactor_dynamical/tests/test_chromatin.py +++ b/supirfactor_dynamical/tests/test_chromatin.py @@ -21,6 +21,11 @@ ChromatinAwareModel ) +from supirfactor_dynamical.postprocessing.results import ( + process_results_to_dataframes, + add_classification_metrics_to_dataframe +) + class TestChromatinTraining(unittest.TestCase): @@ -56,6 +61,43 @@ def test_train(self): 10 ) + def test_classification_results(self): + + model = ChromatinModule( + 4, + 25, + k=10 + ) + + model.train_model( + self.dataloader, + 10 + ) + + results, losses, _ = process_results_to_dataframes( + model, + None, + model_type='chromatin', + leader_columns=["Name"], + leader_values=["Value"] + ) + + self.assertEqual( + results.shape, + (1, 4) + ) + + results = add_classification_metrics_to_dataframe( + results, + model, + self.dataloader + ) + + self.assertEqual( + results.shape, + (1, 7) + ) + class TestChromatinTrainingSparse(TestChromatinTraining):