from huggingface_hub import create_repo
from huggingface_hub import HfApi, snapshot_download, hf_hub_download
import os

# TDC repos that host DeepPurpose-format checkpoints (a model.pt weights file
# plus a config.pkl); only these can be loaded via load_deeppurpose().
deeppurpose_repo = [
    'hERG_Karim-Morgan',
    'hERG_Karim-CNN',
    'hERG_Karim-AttentiveFP',
    'BBB_Martins-AttentiveFP',
    'BBB_Martins-Morgan',
    'BBB_Martins-CNN',
    'CYP3A4_Veith-Morgan',
    'CYP3A4_Veith-CNN',
    'CYP3A4_Veith-AttentiveFP',
]

# Models loadable through the transformers Auto* classes in load().
model_hub = ["Geneformer", "scGPT"]


class tdc_hf_interface:
    '''
    Thin wrapper around the Hugging Face Hub for TDC-hosted models.

    Example use cases:
        # initialize an interface object with HF repo name
        tdc_hf_herg = tdc_hf_interface("hERG_Karim-Morgan")
        # upload folder/files to this repo
        tdc_hf_herg.upload('./Morgan_herg_karim_optimal')
        # load deeppurpose model from this repo
        dp_model = tdc_hf_herg.load_deeppurpose('./data')
        dp_model.predict(XXX)
    '''

    def __init__(self, repo_name):
        """Bind this interface to the Hub repo ``tdc/<repo_name>``.

        Repo names follow ``<dataset>-<encoding>``; the part after the first
        '-' names the model/encoding. A name without '-' is used as-is.
        """
        self.repo_id = "tdc/" + repo_name
        try:
            self.model_name = repo_name.split('-')[1]
        except IndexError:
            # No '-' separator present: the whole repo name is the model name.
            self.model_name = repo_name

    def upload(self, folder_path):
        """Create the Hub repo and upload *folder_path* under ``model/``.

        Raises if the repo already exists (create_repo is not idempotent).
        """
        create_repo(repo_id=self.repo_id)
        api = HfApi()
        api.upload_folder(folder_path=folder_path,
                          path_in_repo="model",
                          repo_id=self.repo_id,
                          repo_type="model")

    def file_download(self, save_path, filename):
        """Download one file from the repo into the *save_path* cache dir.

        Returns the local path of the downloaded file (previously computed
        but discarded; returning it is backward compatible).
        """
        model_ckpt = hf_hub_download(repo_id=self.repo_id,
                                     filename=filename,
                                     cache_dir=save_path)
        return model_ckpt

    def repo_download(self, save_path):
        """Download the entire repo snapshot into the *save_path* cache dir."""
        snapshot_download(repo_id=self.repo_id, cache_dir=save_path)

    def load(self):
        """Load a supported transformers model from the TDC model hub.

        Raises:
            Exception: if the model is not listed in ``model_hub`` or no
                loading branch has been implemented for it yet.
        """
        if self.model_name not in model_hub:
            raise Exception("this model is not in the TDC model hub GH repo.")
        if self.model_name == "Geneformer":
            from transformers import AutoModelForMaskedLM
            return AutoModelForMaskedLM.from_pretrained(
                "ctheodoris/Geneformer")
        if self.model_name == "scGPT":
            from transformers import AutoModel
            return AutoModel.from_pretrained("tdc/scGPT")
        # Defensive guard: only reachable if model_hub grows without a
        # matching branch above.
        raise Exception("Not implemented yet!")

    def load_deeppurpose(self, save_path):
        """Download and instantiate a DeepPurpose model hosted in this repo.

        Fetches ``model/model.pt`` and ``model/config.pkl`` into the HF cache
        under *save_path*, renames the content-addressed blobs to the names
        DeepPurpose expects, and loads the model from that directory.

        Raises:
            ValueError: if this repo is not a DeepPurpose repo, or if the
                DeepPurpose package is not installed.
        """
        repo_name = self.repo_id[4:]  # strip the "tdc/" owner prefix
        if repo_name not in deeppurpose_repo:
            raise ValueError("This repo does not host a DeepPurpose model!")

        save_path = os.path.join(save_path, repo_name)
        # makedirs(exist_ok=True) also creates missing parents, where the
        # original os.mkdir would fail.
        os.makedirs(save_path, exist_ok=True)
        self.file_download(save_path, "model/model.pt")
        self.file_download(save_path, "model/config.pkl")

        # hf_hub_download caches files as content-addressed blobs under
        # <cache_dir>/models--<org>--<name>/blobs/.
        blob_dir = os.path.join(save_path, 'models--tdc--' + repo_name,
                                'blobs')
        blobs = [os.path.join(blob_dir, f) for f in os.listdir(blob_dir)]
        # Blob names are hashes, so tell the two files apart by size: the
        # model weights are assumed to always outweigh the pickled config.
        # NOTE(review): assumes exactly two blobs in the cache — TODO confirm
        # a stale cache cannot leave extras here.
        blobs.sort(key=os.path.getsize, reverse=True)
        model_file, config_file = blobs[0], blobs[1]

        os.rename(model_file, os.path.join(blob_dir, 'model.pt'))
        os.rename(config_file, os.path.join(blob_dir, 'config.pkl'))

        try:
            from DeepPurpose import CompoundPred
        except ImportError:
            raise ValueError(
                "Please install DeepPurpose package following https://github.com/kexinhuang12345/DeepPurpose#installation"
            )
        net = CompoundPred.model_pretrained(path_dir=blob_dir)
        return net

    def predict_deeppurpose(self, model, drugs):
        """Run a DeepPurpose *model* on a list of drug SMILES strings.

        Returns the first element of the model's prediction output.

        Raises:
            ValueError: if the DeepPurpose package is not installed.
        """
        try:
            from DeepPurpose import utils
        except ImportError:
            raise ValueError(
                "Please install DeepPurpose package following https://github.com/kexinhuang12345/DeepPurpose#installation"
            )
        if self.model_name == 'AttentiveFP':
            # DeepPurpose registers this encoder under its DGL-backed name.
            self.model_name = 'DGL_' + self.model_name
        X_pred = utils.data_process(X_drug=drugs,
                                    y=[0] * len(drugs),
                                    drug_encoding=self.model_name,
                                    split_method='no_split')
        y_pred = model.predict(X_pred)[0]
        return y_pred