modify scdti group to use pinnacle single-cell network dataset for ib… #313

Merged (1 commit) on Sep 13, 2024
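A minimal usage sketch of the API this change introduces, pieced together from the diff below; the import path follows the package layout shown in the diff, and since no model is trained here, ground truth is copied into the "preds" column the same way the updated test does:

from tdc.benchmark_group import scdti_group

group = scdti_group.SCDTIGroup()                    # backed by the PINNACLE resource (self.p = PINNACLE() in __init__)
splits = group.get_train_valid_split(seed=1)        # {"train": df, "val": df}
train_df, val_df = splits["train"], splits["val"]

test_df = group.get_test(seed=1)["test"]            # columns include "y", "cell_type_label", "disease"

# evaluate() asserts that the prediction dataframe carries "preds",
# "cell_type_label", and "disease"; copying ground truth into "preds"
# gives a perfect-score baseline.
preds_df = test_df.copy()
preds_df["preds"] = preds_df["y"]

results = group.evaluate(preds_df, k=5, top_k=20)   # {disease: mean AP@k over the top_k cell types}
many_results = group.evaluate_many([preds_df] * 5)  # {disease: [mean, std]} across >= 5 seeds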
148 changes: 108 additions & 40 deletions tdc/benchmark_group/scdti_group.py
@@ -4,6 +4,7 @@
import os

from .base_group import BenchmarkGroup
from ..resource.pinnacle import PINNACLE


class SCDTIGroup(BenchmarkGroup):
@@ -15,52 +16,119 @@ class SCDTIGroup(BenchmarkGroup):

def __init__(self, path="./data", file_format="csv"):
"""Create an SCDTI benchmark group class."""
# super().__init__(name="SCDTI_Group", path=path)
self.name = "SCDTI_Group"
self.path = os.path.join(path, self.name)
# self.datasets = ["opentargets_dti"]
self.dataset_names = ["opentargets_dti"]
self.file_format = file_format
self.split = None
self.p = PINNACLE()

def get_train_valid_split(self):
def precision_recall_at_k(self, y, preds, k: int = 5):
"""
Calculate recall@k, precision@k, accuracy@k, and AP@k for binary classification.
"""
import numpy as np
import pandas as pd
from sklearn.metrics import accuracy_score, average_precision_score
assert preds.shape[0] == y.shape[0]
assert k > 0
if k > preds.shape[0]:
return -1, -1, -1, -1

# Sort the scores and the labels by the scores
sorted_indices = np.argsort(preds.flatten())[::-1]
sorted_preds = preds[sorted_indices]
sorted_y = y[sorted_indices]

# Get the scores of the k highest predictions
topk_preds = sorted_preds[:k]
topk_y = sorted_y[:k]

# Calculate the recall@k and precision@k
recall_k = np.sum(topk_y) / np.sum(y)
precision_k = np.sum(topk_y) / k

# Calculate the accuracy@k
accuracy_k = accuracy_score(topk_y, topk_preds > 0.5)

# Calculate the AP@k
ap_k = average_precision_score(topk_y, topk_preds)

return recall_k, precision_k, accuracy_k, ap_k

def get_train_valid_split(self, seed=1):
"""parameters included for compatibility. this benchmark has a fixed train/test split."""
from ..resource.dataloader import DataLoader
if self.split is None:
dl = DataLoader(name="opentargets_dti")
self.split = dl.get_split()
return self.split["train"], self.split["dev"]

def get_test(self):
from ..resource.dataloader import DataLoader
if self.split is None:
dl = DataLoader(name="opentargets_dti")
self.split = dl.get_split()
return self.split["test"]

def evaluate(self, y_pred):
from sklearn.metrics import precision_score, recall_score, accuracy_score, f1_score
y_true = self.get_test()["Y"]
# Calculate metrics
precision = precision_score(y_true, y_pred)
recall = recall_score(y_true, y_pred)
accuracy = accuracy_score(y_true, y_pred)
f1 = f1_score(y_true, y_pred)
return [precision, recall, accuracy, f1]

def evaluate_many(self, preds):
train = self.p.get_exp_data(seed=seed, split="train")
val = self.p.get_exp_data(seed=seed, split="val")
return {"train": train, "val": val}

def get_test(self, seed=1):
return {"test": self.p.get_exp_data(seed=seed, split="test")}

def evaluate(self, y_pred, k=5, top_k=20):
from numpy import mean
from sklearn.metrics import roc_auc_score
y_true = self.get_test()["test"]
assert "preds" in y_pred.columns, "require 'preds' prediction label in input df"
assert "cell_type_label" in y_pred.columns, "require cell_type_label in input df"
assert "disease" in y_pred.columns, "require 'disease' in input df"
cells = y_true["cell_type_label"].unique()
diseases = y_true["disease"].unique()
assert len(cells) == len(
y_pred["cell_type_label"].unique()
), "number of cell types in input df and test df do not match. expected {}".format(
len(cells))
assert len(diseases) == len(
y_pred["disease"].unique()
), "number of diseases in input df do not match test df. expected {}".format(
len(diseases))
results = {d: [] for d in diseases}
for disease in diseases:
for cell in cells:
preds = y_pred[(y_pred["disease"] == disease) &
(y_pred["cell_type_label"] == cell)]
yt = y_true[(y_true["disease"] == disease) &
(y_true["cell_type_label"] == cell)]
assert len(preds) == len(
yt
), "mismatch in length of predictions and results for a specific disease {} and cell type {}".format(
disease, cell)
if len(yt) == 0:
continue
auc = roc_auc_score(yt["y"], preds["preds"])
recall_k, precision_k, accuracy_k, ap_k = self.precision_recall_at_k(
yt["y"].values, preds["preds"].values, k=k)
results[disease].append({
"auc": auc,
"recall": recall_k,
"precision": precision_k,
"accuracy": accuracy_k,
"ap": ap_k
})
# for now, we benchmark using only AP@k over the top 20 cell types
for d, scores in results.items():
assert type(
scores
) == list, "scores should be a list. got {} with value {}".format(
scores, type(scores))
assert type(scores[0]
) == dict, "scores should contain dictionary of metrics"
assert "ap" in scores[0], "scores should include 'ap'"
topk_cells = [
x["ap"] for x in sorted(scores, key=lambda s: s["ap"])[-top_k:]
]
results[d] = mean(topk_cells)
return results

def evaluate_many(self, preds: list):
from numpy import mean, std
assert type(
preds
) == list, "expected preds to be a list containing prediction dataframes for multiple seeds"
if len(preds) < 5:
raise Exception(
"Run your model on at least 5 seeds to compare results and provide your outputs in preds."
)
out = dict()
preds = [self.evaluate(p) for p in preds]
out["precision"] = (mean([x[0] for x in preds]),
std([x[0] for x in preds]))
out["recall"] = (mean([x[1] for x in preds]), std([x[1] for x in preds
]))
out["accuracy"] = (mean([x[2] for x in preds]),
std([x[2] for x in preds]))
out["f1"] = (mean([x[3] for x in preds]), std([x[3] for x in preds]))
return out
evals = [self.evaluate(x) for x in preds]
diseases = preds[0]["disease"].unique()
return {
d: [mean([x[d] for x in evals]),
std([x[d] for x in evals])] for d in diseases
}
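For reference, the ranking metrics computed by precision_recall_at_k can be illustrated with a standalone toy example, re-stated here so it runs without instantiating SCDTIGroup or downloading PINNACLE data; the arrays are made-up values:

import numpy as np
from sklearn.metrics import accuracy_score, average_precision_score

y = np.array([1, 0, 1, 0, 0, 1])                    # binary ground truth (3 positives)
preds = np.array([0.9, 0.8, 0.7, 0.4, 0.3, 0.2])    # predicted scores
k = 3

order = np.argsort(preds)[::-1]                     # rank by score, highest first
topk_y, topk_preds = y[order][:k], preds[order][:k]

recall_k = topk_y.sum() / y.sum()                   # 2 of 3 positives retrieved -> 0.667
precision_k = topk_y.sum() / k                      # 2 of the top 3 are positive -> 0.667
accuracy_k = accuracy_score(topk_y, topk_preds > 0.5)
ap_k = average_precision_score(topk_y, topk_preds)
print(recall_k, precision_k, accuracy_k, ap_k)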
10 changes: 8 additions & 2 deletions tdc/resource/pinnacle.py
@@ -83,7 +83,10 @@ def get_exp_data(self, seed=1, split="train"):
# clean data directory
file_list = os.listdir("./data")
for file in file_list:
os.remove(os.path.join("./data", file))
try:
os.remove(os.path.join("./data", file))
except:
continue
print("downloading pinancle zip data...")
zip_data_download_wrapper(
filename, "./data",
@@ -94,7 +97,10 @@
f for f in os.listdir("./data") if not f.endswith(".csv")
]
for x in non_csv_files:
os.remove("./data/{}".format(x))
try:
os.remove("./data/{}".format(x))
except:
continue
# Get a list of all CSV files in the unzipped folder
csv_files = [f for f in os.listdir("./data") if f.endswith(".csv")]
if not csv_files:
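The bare except clauses above swallow every exception raised during cleanup. One possible tightening (a suggestion, not what this PR does) is to suppress only the OS-level errors that file removal can raise, so unrelated bugs still surface:

import contextlib
import os

# Remove stale files from the data directory, ignoring only OS-level failures
# such as missing files or permission errors.
for file in os.listdir("./data"):
    with contextlib.suppress(OSError):
        os.remove(os.path.join("./data", file))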
45 changes: 26 additions & 19 deletions tdc/test/test_benchmark.py
@@ -63,27 +63,34 @@ def test_ADME_evaluate_many(self):
self.assertTrue(my_group["name"] in results)

def test_SCDTI_benchmark(self):
from tdc.resource.dataloader import DataLoader

data = DataLoader(name="opentargets_dti")
group = scdti_group.SCDTIGroup()
train, val = group.get_train_valid_split()
assert len(val) == 0 # this benchmark has no validation set
# test simple preds
y_true = group.get_test()["Y"]
results = group.evaluate(y_true)
assert results[-1] == 1.0 # should be perfect F1 score
# assert it matches the opentargets official test scores
tst = data.get_split()["test"]["Y"]
train_val = group.get_train_valid_split()
assert "train" in train_val, "no training set"
assert "val" in train_val, "no validation set"
assert len(train_val["train"]) > 0, "no entries in training set"
tst = group.get_test()["test"]
tst["preds"] = tst["y"] # switch predictions to ground truth
results = group.evaluate(tst)
assert results[-1] == 1.0
zero_pred = [0] * len(y_true)
results = group.evaluate(zero_pred)
assert results[-1] != 1.0 # should not be perfect F1 score
many_results = group.evaluate_many([y_true] * 5)
assert "f1" in many_results
assert len(many_results["f1"]
) == 2 # should include mean and standard deviation
assert "IBD" in results, "missing ibd from diseases. got {}".format(
results.keys())
assert "RA" in results, "missing ra from diseases. got {}".format(
results.keys())
assert results["IBD"] == results[
"RA"], "both should be perfect scores but got IBD {} vs RA {}".format(
results["IBD"], results["RA"]) # both should be perfect scores
assert results["IBD"] - 1.0 < 0.000001 # should be a perfect score
many_results = group.evaluate_many([tst] * 5)
assert "IBD" in many_results, "missing ibd from diseases in evaluate many. got {}".format(
many_results.keys())
assert "RA" in many_results, "missing ra from diseases in evaluate many. got {}".format(
many_results.keys())
assert len(many_results["IBD"]) == len(
many_results["RA"]
), "both diseases should include mean and standard deviation"
assert len(many_results["IBD"]
) == 2, "results should include mean and standard deviation"
assert many_results["IBD"][
0] - 1.0 < 0.000001, "should get perfect score"

@unittest.skip(
"counterfactual test is taking up too much memory"
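For reference, a small sketch of how the evaluate_many output checked above might be consumed; the {disease: [mean, std]} shape follows the test's assertions, and the numbers are purely illustrative:

# Illustrative values only; real results come from group.evaluate_many(...).
many_results = {"IBD": [0.98, 0.01], "RA": [0.97, 0.02]}
for disease, (mean_ap, std_ap) in many_results.items():
    print(f"{disease}: mean AP@k = {mean_ap:.3f} (std {std_ap:.3f})")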