Skip to content

Commit

Permalink
Merge pull request #259 from mims-harvard/neurips_benchmarks
Browse files Browse the repository at this point in the history
Neurips benchmarks -- evaluate_many on scdti
  • Loading branch information
amva13 authored May 8, 2024
2 parents 5fbd717 + 5d8e4c0 commit 4c79d17
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 0 deletions.
1 change: 1 addition & 0 deletions run_tests.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import unittest
import sys


if __name__ == '__main__':
loader = unittest.TestLoader()
start_dir = 'tdc/test'
Expand Down
17 changes: 17 additions & 0 deletions tdc/benchmark_group/scdti_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,20 @@ def evaluate(self, y_pred):
accuracy = accuracy_score(y_true, y_pred)
f1 = f1_score(y_true, y_pred)
return [precision, recall, accuracy, f1]

def evaluate_many(self, preds):
from numpy import mean, std
if len(preds) < 5:
raise Exception(
"Run your model on at least 5 seeds to compare results and provide your outputs in preds."
)
out = dict()
preds = [self.evaluate(p) for p in preds]
out["precision"] = (mean([x[0] for x in preds]),
std([x[0] for x in preds]))
out["recall"] = (mean([x[1] for x in preds]), std([x[1] for x in preds
]))
out["accuracy"] = (mean([x[2] for x in preds]),
std([x[2] for x in preds]))
out["f1"] = (mean([x[3] for x in preds]), std([x[3] for x in preds]))
return out
4 changes: 4 additions & 0 deletions tdc/test/test_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,10 @@ def test_SCDTI_benchmark(self):
zero_pred = [0] * len(y_true)
results = group.evaluate(zero_pred)
assert results[-1] != 1.0 # should not be perfect F1 score
many_results = group.evaluate_many([y_true] * 5)
assert "f1" in many_results
assert len(many_results["f1"]
) == 2 # should include mean and standard deviation


if __name__ == "__main__":
Expand Down

0 comments on commit 4c79d17

Please sign in to comment.