From eafe8a41f962bf99554672bf54139e4c5bd09355 Mon Sep 17 00:00:00 2001 From: Mehmet Can Ay Date: Fri, 23 Feb 2024 15:45:11 +0100 Subject: [PATCH 1/2] refactor: reduce redundancy --- index/embedding.py | 8 +- index/evaluation.py | 10 +- index/main.py | 450 +++++++++++++++--------------------- index/visualisation.py | 20 +- tests/test_embedding.py | 12 +- tests/test_visualisation.py | 8 +- 6 files changed, 220 insertions(+), 288 deletions(-) diff --git a/index/embedding.py b/index/embedding.py index 379c803..7b5dfae 100644 --- a/index/embedding.py +++ b/index/embedding.py @@ -38,10 +38,10 @@ def get_embeddings(self, messages: [str], model="text-embedding-ada-002"): return [item["embedding"] for item in response["data"]] -class MPNetAdapter(EmbeddingModel): +class SentenceTransformerAdapter(EmbeddingModel): def __init__(self, model="sentence-transformers/all-mpnet-base-v2"): logging.getLogger().setLevel(logging.INFO) - self.mpnet_model = SentenceTransformer(model) + self.model = SentenceTransformer(model) def get_embedding(self, text: str): logging.info(f"Getting embedding for {text}") @@ -51,13 +51,13 @@ def get_embedding(self, text: str): return None if isinstance(text, str): text = text.replace("\n", " ") - return self.mpnet_model.encode(text) + return self.model.encode(text) except Exception as e: logging.error(f"Error getting embedding for {text}: {e}") return None def get_embeddings(self, messages: [str]) -> [[float]]: - embeddings = self.mpnet_model.encode(messages) + embeddings = self.model.encode(messages) flattened_embeddings = [[float(element) for element in row] for row in embeddings] return flattened_embeddings diff --git a/index/evaluation.py b/index/evaluation.py index 5da9711..0602383 100644 --- a/index/evaluation.py +++ b/index/evaluation.py @@ -197,19 +197,19 @@ def score_mappings(matches: pd.DataFrame) -> float: return accuracy -def evaluate(datasets, labels, store_results=False, model="gpt", results_root_dir="resources/results/pd"): +def evaluate(datasets, labels, model: str, matching_method="euclidean", store_results=False, results_root_dir="resources/results/pd"): data = {} for idx, source in enumerate(datasets): acc = [] for idy, target in enumerate(datasets): - if model == "gpt": + if matching_method == "euclidean": map = match_closest_descriptions(source, target) - elif model == "mpnet": + elif matching_method == "cosine": map = match_closest_descriptions(source,target, matching_method=MatchingMethod.COSINE_EMBEDDING_DISTANCE) - elif model == "fuzzy": + elif matching_method == "fuzzy": map = match_closest_descriptions(source, target, matching_method=MatchingMethod.FUZZY_STRING_MATCHING) else: - raise NotImplementedError("Specified model is not implemented!") + raise NotImplementedError("Matching method is not implemented!") if store_results: map.to_excel(results_root_dir + f"/{model}_" + f"{labels[idx]}_to_{labels[idy]}.xlsx") acc.append(round(score_mappings(map), 2)) diff --git a/index/main.py b/index/main.py index c5bcf35..ce6e0e8 100644 --- a/index/main.py +++ b/index/main.py @@ -6,7 +6,7 @@ from index import evaluation from index.conf import PD_CDM_SRC, PPMI_DICT_SRC, LUXPARK_DICT_SRC, BIOFIND_DICT_SRC, AD_CDM_SRC -from index.embedding import GPT4Adapter, MPNetAdapter +from index.embedding import EmbeddingModel, GPT4Adapter, SentenceTransformerAdapter from index.evaluation import match_closest_descriptions, MatchingMethod, enrichment_analysis, evaluate from index.mapping import MappingTable from index.parsing import MappingSource, DataDictionarySource @@ -15,267 +15,197 @@ from index.visualisation import scatter_plot_two_distributions, enrichment_plot, scatter_plot_all_cohorts, \ bar_chart_average_acc_two_distributions -EVAL_PD = True -EVAL_AD = True + +def create_datasets(model: EmbeddingModel, EVAL_PD = True, EVAL_AD = True): + results_pd = {} + results_ad = {} + if EVAL_PD: + opdc = MappingTable(MappingSource(PD_CDM_SRC, "OPDC", "CURIE")) + opdc.add_descriptions(DataDictionarySource("resources/dictionaries/pd/OPDC.csv", "Variable Name", "Variable description")) + opdc.compute_embeddings(model) + results_pd["OPDC"] = opdc + + proband = MappingTable(MappingSource(PD_CDM_SRC, "TPD", "CURIE")) + proband.add_descriptions(DataDictionarySource("resources/dictionaries/pd/TPD.csv", "Variable Name", "Variable description")) + proband.compute_embeddings(model) + results_pd["PRoBaND"] = proband + + biofind = MappingTable(MappingSource(PD_CDM_SRC, "BIOFIND", "CURIE")) + biofind.add_descriptions(DataDictionarySource(BIOFIND_DICT_SRC, "ITM_NAME", "DSCR")) + biofind.compute_embeddings(model) + results_pd["BIOFIND"] = biofind + + lcc = MappingTable(MappingSource(PD_CDM_SRC, "LRRK2", "CURIE")) + lcc.add_descriptions(DataDictionarySource("resources/dictionaries/pd/LRRK2.xlsx", "Variable", "Label")) + lcc.compute_embeddings(model) + results_pd["LCC"] = lcc + + luxpark = MappingTable(MappingSource(PD_CDM_SRC, "LuxPARK", "CURIE")) + luxpark.add_descriptions(DataDictionarySource(LUXPARK_DICT_SRC, "Variable / Field Name", "Field Label")) + luxpark.compute_embeddings(model) + results_pd["LuxPARK"] = luxpark + + ppmi = MappingTable(MappingSource(PD_CDM_SRC, "PPMI", "CURIE")) + ppmi.add_descriptions(DataDictionarySource(PPMI_DICT_SRC, "ITM_NAME", "DSCR")) + ppmi.compute_embeddings(model) + results_pd["PPMI"] = ppmi + + cdm_pd = MappingTable(MappingSource(PD_CDM_SRC, "Feature", "CURIE")) + cdm_pd.joined_mapping_table["identifier"].to_csv("resources/cdm_curie.csv", index=False) + cdm_pd.add_descriptions(DataDictionarySource(PD_CDM_SRC, "Feature", "Definition")) + cdm_pd.compute_embeddings(model) + results_pd["PASSIONATE"] = cdm_pd + + if EVAL_AD: + cdm_ad = cdm_pd = MappingTable(MappingSource(AD_CDM_SRC, "Feature", "CURIE")) + cdm_ad.add_descriptions(DataDictionarySource(PD_CDM_SRC, "Feature", "Definition")) + cdm_ad.compute_embeddings(model) + results_ad["AD-Mapper"] = cdm_ad + + a4 = MappingTable(MappingSource(AD_CDM_SRC, "A4", "CURIE")) + a4.add_descriptions(DataDictionarySource("resources/dictionaries/ad/a4.csv", "FLDNAME", "TEXT")) + a4.compute_embeddings(model) + results_ad["A4"] = a4 + + abvib = MappingTable(MappingSource(AD_CDM_SRC, "ABVIB", "CURIE")) + abvib.add_descriptions(DataDictionarySource("resources/dictionaries/ad/abvib.csv", "variable_name", "description")) + abvib.compute_embeddings(model) + results_ad["ABVIB"] = abvib + + adni = MappingTable(MappingSource(AD_CDM_SRC, "ADNI", "CURIE")) + adni.add_descriptions(DataDictionarySource("resources/dictionaries/ad/ADNIMERGE_DICT_27Nov2023 2.csv", "FLDNAME", "TEXT")) + adni.compute_embeddings(model) + results_ad["ADNI"] = adni + + aibl = MappingTable(MappingSource(AD_CDM_SRC, "AIBL", "CURIE")) + aibl.add_descriptions(DataDictionarySource("resources/dictionaries/ad/aibl.csv", "Name", "Description")) + aibl.compute_embeddings(model) + results_ad["AIBL"] = aibl + + arwibo = MappingTable(MappingSource(AD_CDM_SRC, "ARWIBO", "CURIE")) + arwibo.add_descriptions(DataDictionarySource("resources/dictionaries/ad/arwibo.csv", "Variable_Name", "Element_description")) + arwibo.compute_embeddings(model) + results_ad["ARWIBO"] = arwibo + + dod_adni = MappingTable(MappingSource(AD_CDM_SRC, "DOD-ADNI", "CURIE")) + dod_adni.add_descriptions(DataDictionarySource("resources/dictionaries/ad/dod-adni.csv", "FLDNAME", "TEXT")) + dod_adni.compute_embeddings(model) + results_ad["DOD-ADNI"] = dod_adni + + edsd = MappingTable(MappingSource(AD_CDM_SRC, "EDSD", "CURIE")) + edsd.add_descriptions(DataDictionarySource("resources/dictionaries/ad/edsd.xlsx", "Variable_Name", "Element_description")) + edsd.compute_embeddings(model) + results_ad["EDSD"] = edsd + + emif = MappingTable(MappingSource(AD_CDM_SRC, "EMIF", "CURIE")) + emif.add_descriptions(DataDictionarySource("resources/dictionaries/ad/emif.xlsx", "Variable", "Description")) + emif.compute_embeddings(model) + results_ad["EMIF"] = emif + + i_adni = MappingTable(MappingSource(AD_CDM_SRC, "I-ADNI", "CURIE")) + i_adni.add_descriptions(DataDictionarySource("resources/dictionaries/ad/i-adni.csv", "acronym", "variable")) + i_adni.compute_embeddings(model) + results_ad["I-ADNI"] = i_adni + + jadni = MappingTable(MappingSource(AD_CDM_SRC, "JADNI", "CURIE")) + jadni.add_descriptions(DataDictionarySource("resources/dictionaries/ad/jadni.tsv", "FLDNAME", "TEXT")) + jadni.compute_embeddings(model) + results_ad["JADNI"] = jadni + + pharmacog = MappingTable(MappingSource(AD_CDM_SRC, "PharmaCog", "CURIE")) + pharmacog.add_descriptions(DataDictionarySource("resources/dictionaries/ad/pharmacog.csv", "Variable_Name", "Element_description")) + pharmacog.compute_embeddings(model) + results_ad["PharmaCog"] = pharmacog + + prevent_ad = MappingTable(MappingSource(AD_CDM_SRC, "PREVENT-AD", "CURIE")) + prevent_ad.add_descriptions(DataDictionarySource("resources/dictionaries/ad/prevent-ad.csv", "variable", "description")) + prevent_ad.compute_embeddings(model) + results_ad["PREVENT-AD"] = prevent_ad + + vita = MappingTable(MappingSource(AD_CDM_SRC, "VITA", "CURIE")) + vita.add_descriptions(DataDictionarySource("resources/dictionaries/ad/vita.csv", "Variable_Name", "Element_description")) + vita.compute_embeddings(model) + results_ad["VITA"] = vita + + return results_pd, results_ad + load_dotenv() gpt4 = GPT4Adapter(api_key=os.getenv("GPT_KEY")) # type: ignore -mpnet = MPNetAdapter() - -# PD Mappings - -if EVAL_PD: - cdm_pd_gpt = MappingTable(MappingSource(PD_CDM_SRC, "Feature", "CURIE")) - cdm_pd_gpt.joined_mapping_table["identifier"].to_csv("resources/cdm_curie.csv", index=False) - cdm_pd_gpt.add_descriptions(DataDictionarySource(PD_CDM_SRC, "Feature", "Definition")) - cdm_pd_gpt.compute_embeddings(gpt4) - - cdm_pd_mpnet = MappingTable(MappingSource(PD_CDM_SRC, "Feature", "CURIE")) - cdm_pd_mpnet.joined_mapping_table["identifier"].to_csv("resources/cdm_curie.csv", index=False) - cdm_pd_mpnet.add_descriptions(DataDictionarySource(PD_CDM_SRC, "Feature", "Definition")) - cdm_pd_mpnet.compute_embeddings(mpnet) - - ppmi_gpt = MappingTable(MappingSource(PD_CDM_SRC, "PPMI", "CURIE")) - ppmi_gpt.add_descriptions(DataDictionarySource(PPMI_DICT_SRC, "ITM_NAME", "DSCR")) - ppmi_gpt.compute_embeddings(gpt4) - - ppmi_mpnet = MappingTable(MappingSource(PD_CDM_SRC, "PPMI", "CURIE")) - ppmi_mpnet.add_descriptions(DataDictionarySource(PPMI_DICT_SRC, "ITM_NAME", "DSCR")) - ppmi_mpnet.compute_embeddings(mpnet) - - luxpark_gpt = MappingTable(MappingSource(PD_CDM_SRC, "LuxPARK", "CURIE")) - luxpark_gpt.add_descriptions(DataDictionarySource(LUXPARK_DICT_SRC, "Variable / Field Name", "Field Label")) - luxpark_gpt.compute_embeddings(gpt4) - - luxpark_mpnet = MappingTable(MappingSource(PD_CDM_SRC, "LuxPARK", "CURIE")) - luxpark_mpnet.add_descriptions(DataDictionarySource(LUXPARK_DICT_SRC, "Variable / Field Name", "Field Label")) - luxpark_mpnet.compute_embeddings(mpnet) - - biofind_gpt = MappingTable(MappingSource(PD_CDM_SRC, "BIOFIND", "CURIE")) - biofind_gpt.add_descriptions(DataDictionarySource(BIOFIND_DICT_SRC, "ITM_NAME", "DSCR")) - biofind_gpt.compute_embeddings(gpt4) - - biofind_mpnet = MappingTable(MappingSource(PD_CDM_SRC, "BIOFIND", "CURIE")) - biofind_mpnet.add_descriptions(DataDictionarySource(BIOFIND_DICT_SRC, "ITM_NAME", "DSCR")) - biofind_mpnet.compute_embeddings(mpnet) - - lrrk2_gpt = MappingTable(MappingSource(PD_CDM_SRC, "LRRK2", "CURIE")) - lrrk2_gpt.add_descriptions(DataDictionarySource("resources/dictionaries/pd/LRRK2.xlsx", "Variable", "Label")) - lrrk2_gpt.compute_embeddings(gpt4) - - lrrk2_mpnet = MappingTable(MappingSource(PD_CDM_SRC, "LRRK2", "CURIE")) - lrrk2_mpnet.add_descriptions(DataDictionarySource("resources/dictionaries/pd/LRRK2.xlsx", "Variable", "Label")) - lrrk2_mpnet.compute_embeddings(mpnet) - - opdc_gpt = MappingTable(MappingSource(PD_CDM_SRC, "OPDC", "CURIE")) - opdc_gpt.add_descriptions(DataDictionarySource("resources/dictionaries/pd/OPDC.csv", "Variable Name", "Variable description")) - opdc_gpt.compute_embeddings(gpt4) - - opdc_mpnet = MappingTable(MappingSource(PD_CDM_SRC, "OPDC", "CURIE")) - opdc_mpnet.add_descriptions(DataDictionarySource("resources/dictionaries/pd/OPDC.csv", "Variable Name", "Variable description")) - opdc_mpnet.compute_embeddings(mpnet) - - tpd_gpt = MappingTable(MappingSource(PD_CDM_SRC, "TPD", "CURIE")) - tpd_gpt.add_descriptions(DataDictionarySource("resources/dictionaries/pd/TPD.csv", "Variable Name", "Variable description")) - tpd_gpt.compute_embeddings(gpt4) - - tpd_mpnet = MappingTable(MappingSource(PD_CDM_SRC, "TPD", "CURIE")) - tpd_mpnet.add_descriptions(DataDictionarySource("resources/dictionaries/pd/TPD.csv", "Variable Name", "Variable description")) - tpd_mpnet.compute_embeddings(mpnet) - - pd_datasets_gpt = [opdc_gpt, tpd_gpt, biofind_gpt, lrrk2_gpt, luxpark_gpt, ppmi_gpt, cdm_pd_gpt] - pd_datasets_mpnet = [opdc_mpnet, tpd_mpnet, biofind_mpnet, lrrk2_mpnet, luxpark_mpnet, ppmi_mpnet, cdm_pd_mpnet] - pd_datasets_labels = ["OPDC", "PRoBaND", "BIOFIND", "LCC", "LuxPARK", "PPMI", "PASSIONATE"] - - # enrichment analysis - luxpark_passionate_enrichment_gpt = enrichment_analysis(luxpark_gpt, cdm_pd_gpt, 20, MatchingMethod.EUCLIDEAN_EMBEDDING_DISTANCE) - luxpark_passionate_enrichment_mpnet = enrichment_analysis(luxpark_mpnet, cdm_pd_mpnet, 20, MatchingMethod.COSINE_EMBEDDING_DISTANCE) - luxpark_passionate_enrichment_fuzzy = enrichment_analysis(luxpark_gpt, cdm_pd_gpt, 20, MatchingMethod.FUZZY_STRING_MATCHING) - label1 = "Enrichment Plot LuxPARK to CDM" - ppmi_passionate_enrichment_gpt = enrichment_analysis(ppmi_gpt, cdm_pd_gpt, 20, MatchingMethod.EUCLIDEAN_EMBEDDING_DISTANCE) - ppmi_passionate_enrichment_mpnet = enrichment_analysis(ppmi_mpnet, cdm_pd_mpnet, 20, MatchingMethod.COSINE_EMBEDDING_DISTANCE) - ppmi_passionate_enrichment_fuzzy = enrichment_analysis(ppmi_gpt, cdm_pd_gpt, 20, MatchingMethod.FUZZY_STRING_MATCHING) - label2 = "Enrichment Plot PPMI to CDM" - ppmi_luxpark_enrichment_gpt = enrichment_analysis(ppmi_gpt, luxpark_gpt, 20, MatchingMethod.EUCLIDEAN_EMBEDDING_DISTANCE) - ppmi_luxpark_enrichment_mpnet = enrichment_analysis(ppmi_mpnet, luxpark_mpnet, 20, MatchingMethod.COSINE_EMBEDDING_DISTANCE) - ppmi_luxpark_enrichment_fuzzy = enrichment_analysis(ppmi_gpt, luxpark_gpt, 20, MatchingMethod.FUZZY_STRING_MATCHING) - label3 = "Enrichment Plot PPMI to LuxPARK" - enrichment_plot(luxpark_passionate_enrichment_gpt, luxpark_passionate_enrichment_mpnet, luxpark_passionate_enrichment_fuzzy, label1, save_plot=True) - enrichment_plot(ppmi_passionate_enrichment_gpt, ppmi_passionate_enrichment_mpnet, ppmi_passionate_enrichment_fuzzy, label2, save_plot=True) - enrichment_plot( ppmi_luxpark_enrichment_gpt, ppmi_luxpark_enrichment_mpnet, ppmi_luxpark_enrichment_fuzzy, label3, save_plot=True) - print(luxpark_passionate_enrichment_gpt) - print(luxpark_passionate_enrichment_mpnet) - print(luxpark_passionate_enrichment_fuzzy) - print(ppmi_passionate_enrichment_gpt) - print(ppmi_passionate_enrichment_mpnet) - print(ppmi_passionate_enrichment_fuzzy) - print(ppmi_luxpark_enrichment_gpt) - print(ppmi_luxpark_enrichment_mpnet) - print(ppmi_luxpark_enrichment_fuzzy) - - gpt_table1 = evaluate(pd_datasets_gpt, pd_datasets_labels, store_results=True) - fuzzy_table1 = evaluate(pd_datasets_gpt, pd_datasets_labels, store_results=True, model="fuzzy") - mpnet_table1 = evaluate(pd_datasets_mpnet, pd_datasets_labels, store_results=True, model="mpnet") - - print("PD RESULTS:") - print("GPT") - print("-----------") - print(gpt_table1) - print("-----------") - print("MPNet") - print("-----------") - print(mpnet_table1) - print("-----------") - print("Fuzzy") - print("-----------") - print(fuzzy_table1) - print("-----------") - -# AD Mappings - -if EVAL_AD: - cdm_ad_gpt = cdm_pd_gpt = MappingTable(MappingSource(AD_CDM_SRC, "Feature", "CURIE")) - cdm_ad_gpt.add_descriptions(DataDictionarySource(PD_CDM_SRC, "Feature", "Definition")) - cdm_ad_gpt.compute_embeddings(gpt4) - - cdm_ad_mpnet = cdm_pd_gpt = MappingTable(MappingSource(AD_CDM_SRC, "Feature", "CURIE")) - cdm_ad_mpnet.add_descriptions(DataDictionarySource(PD_CDM_SRC, "Feature", "Definition")) - cdm_ad_mpnet.compute_embeddings(mpnet) - - a4_gpt = MappingTable(MappingSource(AD_CDM_SRC, "A4", "CURIE")) - a4_gpt.add_descriptions(DataDictionarySource("resources/dictionaries/ad/a4.csv", "FLDNAME", "TEXT")) - a4_gpt.compute_embeddings(gpt4) - - a4_mpnet = MappingTable(MappingSource(AD_CDM_SRC, "A4", "CURIE")) - a4_mpnet.add_descriptions(DataDictionarySource("resources/dictionaries/ad/a4.csv", "FLDNAME", "TEXT")) - a4_mpnet.compute_embeddings(mpnet) - - abvib_gpt = MappingTable(MappingSource(AD_CDM_SRC, "ABVIB", "CURIE")) - abvib_gpt.add_descriptions(DataDictionarySource("resources/dictionaries/ad/abvib.csv", "variable_name", "description")) - abvib_gpt.compute_embeddings(gpt4) - - abvib_mpnet = MappingTable(MappingSource(AD_CDM_SRC, "ABVIB", "CURIE")) - abvib_mpnet.add_descriptions( DataDictionarySource("resources/dictionaries/ad/abvib.csv", "variable_name", "description")) - abvib_mpnet.compute_embeddings(mpnet) - - adni_gpt = MappingTable(MappingSource(AD_CDM_SRC, "ADNI", "CURIE")) - adni_gpt.add_descriptions(DataDictionarySource("resources/dictionaries/ad/ADNIMERGE_DICT_27Nov2023 2.csv", "FLDNAME", "TEXT")) - adni_gpt.compute_embeddings(gpt4) - - adni_mpnet = MappingTable(MappingSource(AD_CDM_SRC, "ADNI", "CURIE")) - adni_mpnet.add_descriptions(DataDictionarySource("resources/dictionaries/ad/ADNIMERGE_DICT_27Nov2023 2.csv", "FLDNAME", "TEXT")) - adni_mpnet.compute_embeddings(mpnet) - - aibl_gpt = MappingTable(MappingSource(AD_CDM_SRC, "AIBL", "CURIE")) - aibl_gpt.add_descriptions(DataDictionarySource("resources/dictionaries/ad/aibl.csv", "Name", "Description")) - aibl_gpt.compute_embeddings(gpt4) - - aibl_mpnet = MappingTable(MappingSource(AD_CDM_SRC, "AIBL", "CURIE")) - aibl_mpnet.add_descriptions(DataDictionarySource("resources/dictionaries/ad/aibl.csv", "Name", "Description")) - aibl_mpnet.compute_embeddings(mpnet) - - arwibo_gpt = MappingTable(MappingSource(AD_CDM_SRC, "ARWIBO", "CURIE")) - arwibo_gpt.add_descriptions(DataDictionarySource("resources/dictionaries/ad/arwibo.csv", "Variable_Name", "Element_description")) - arwibo_gpt.compute_embeddings(gpt4) - - arwibo_mpnet = MappingTable(MappingSource(AD_CDM_SRC, "ARWIBO", "CURIE")) - arwibo_mpnet.add_descriptions(DataDictionarySource("resources/dictionaries/ad/arwibo.csv", "Variable_Name", "Element_description")) - arwibo_mpnet.compute_embeddings(mpnet) - - dod_adni_gpt = MappingTable(MappingSource(AD_CDM_SRC, "DOD-ADNI", "CURIE")) - # TODO most descriptions missing - dod_adni_gpt.add_descriptions(DataDictionarySource("resources/dictionaries/ad/dod-adni.csv", "FLDNAME", "TEXT")) - dod_adni_gpt.compute_embeddings(gpt4) - - dod_adni_mpnet = MappingTable(MappingSource(AD_CDM_SRC, "DOD-ADNI", "CURIE")) - # TODO most descriptions missing - dod_adni_mpnet.add_descriptions(DataDictionarySource("resources/dictionaries/ad/dod-adni.csv", "FLDNAME", "TEXT")) - dod_adni_mpnet.compute_embeddings(mpnet) - - edsd_gpt = MappingTable(MappingSource(AD_CDM_SRC, "EDSD", "CURIE")) - edsd_gpt.add_descriptions(DataDictionarySource("resources/dictionaries/ad/edsd.xlsx", "Variable_Name", "Element_description")) - edsd_gpt.compute_embeddings(gpt4) - - edsd_mpnet = MappingTable(MappingSource(AD_CDM_SRC, "EDSD", "CURIE")) - edsd_mpnet.add_descriptions(DataDictionarySource("resources/dictionaries/ad/edsd.xlsx", "Variable_Name", "Element_description")) - edsd_mpnet.compute_embeddings(mpnet) - - emif_gpt = MappingTable(MappingSource(AD_CDM_SRC, "EMIF", "CURIE")) - emif_gpt.add_descriptions(DataDictionarySource("resources/dictionaries/ad/emif.xlsx", "Variable", "Description")) - emif_gpt.compute_embeddings(gpt4) - - emif_mpnet = MappingTable(MappingSource(AD_CDM_SRC, "EMIF", "CURIE")) - emif_mpnet.add_descriptions(DataDictionarySource("resources/dictionaries/ad/emif.xlsx", "Variable", "Description")) - emif_mpnet.compute_embeddings(mpnet) - - i_adni_gpt = MappingTable(MappingSource(AD_CDM_SRC, "I-ADNI", "CURIE")) - # TODO about half of descriptions missing - i_adni_gpt.add_descriptions(DataDictionarySource("resources/dictionaries/ad/i-adni.csv", "acronym", "variable")) - i_adni_gpt.compute_embeddings(gpt4) - - i_adni_mpnet = MappingTable(MappingSource(AD_CDM_SRC, "I-ADNI", "CURIE")) - # TODO about half of descriptions missing - i_adni_mpnet.add_descriptions(DataDictionarySource("resources/dictionaries/ad/i-adni.csv", "acronym", "variable")) - i_adni_mpnet.compute_embeddings(mpnet) - - jadni_gpt = MappingTable(MappingSource(AD_CDM_SRC, "JADNI", "CURIE")) - jadni_gpt.add_descriptions(DataDictionarySource("resources/dictionaries/ad/jadni.tsv", "FLDNAME", "TEXT")) - jadni_gpt.compute_embeddings(gpt4) - - jadni_mpnet = MappingTable(MappingSource(AD_CDM_SRC, "JADNI", "CURIE")) - jadni_mpnet.add_descriptions(DataDictionarySource("resources/dictionaries/ad/jadni.tsv", "FLDNAME", "TEXT")) - jadni_mpnet.compute_embeddings(mpnet) - - pharmacog_gpt = MappingTable(MappingSource(AD_CDM_SRC, "PharmaCog", "CURIE")) - pharmacog_gpt.add_descriptions(DataDictionarySource("resources/dictionaries/ad/pharmacog.csv", "Variable_Name", "Element_description")) - pharmacog_gpt.compute_embeddings(gpt4) - - pharmacog_mpnet = MappingTable(MappingSource(AD_CDM_SRC, "PharmaCog", "CURIE")) - pharmacog_mpnet.add_descriptions(DataDictionarySource("resources/dictionaries/ad/pharmacog.csv", "Variable_Name", "Element_description")) - pharmacog_mpnet.compute_embeddings(mpnet) - - prevent_ad_gpt = MappingTable(MappingSource(AD_CDM_SRC, "PREVENT-AD", "CURIE")) - prevent_ad_gpt.add_descriptions(DataDictionarySource("resources/dictionaries/ad/prevent-ad.csv", "variable", "description")) - prevent_ad_gpt.compute_embeddings(gpt4) - - prevent_ad_mpnet = MappingTable(MappingSource(AD_CDM_SRC, "PREVENT-AD", "CURIE")) - prevent_ad_mpnet.add_descriptions(DataDictionarySource("resources/dictionaries/ad/prevent-ad.csv", "variable", "description")) - prevent_ad_mpnet.compute_embeddings(mpnet) - - vita_gpt = MappingTable(MappingSource(AD_CDM_SRC, "VITA", "CURIE")) - vita_gpt.add_descriptions(DataDictionarySource("resources/dictionaries/ad/vita.csv", "Variable_Name", "Element_description")) - vita_gpt.compute_embeddings(gpt4) - - vita_mpnet = MappingTable(MappingSource(AD_CDM_SRC, "VITA", "CURIE")) - vita_mpnet.add_descriptions(DataDictionarySource("resources/dictionaries/ad/vita.csv", "Variable_Name", "Element_description")) - vita_mpnet.compute_embeddings(mpnet) - - wmh_ad = MappingTable(MappingSource(AD_CDM_SRC, "VITA", "CURIE")) - - ad_datasets_gpt = [a4_gpt, abvib_gpt, adni_gpt, aibl_gpt, arwibo_gpt, dod_adni_gpt, edsd_gpt, emif_gpt, i_adni_gpt, jadni_gpt, - pharmacog_gpt, prevent_ad_gpt, vita_gpt, cdm_ad_gpt] - ad_datasets_mpnet = [a4_mpnet, abvib_mpnet, adni_mpnet, aibl_mpnet, arwibo_mpnet, dod_adni_mpnet, edsd_mpnet, emif_mpnet, - i_adni_mpnet, jadni_mpnet, pharmacog_mpnet, prevent_ad_mpnet, vita_mpnet, cdm_ad_mpnet] - ad_datasets_labels = ["A4", "ABVIB", "ADNI", "AIBL", "ARWIBO", "DOD-ADNI", "EDSD", "EMIF", "I-ADNI", "JADNI", "PharmaCog", - "PREVENT-AD", "VITA", "AD-Mapper"] - gpt_table2 = evaluate(ad_datasets_gpt, ad_datasets_labels, store_results=True, results_root_dir="resources/results/ad") - fuzzy_table2 = evaluate(ad_datasets_gpt, ad_datasets_labels, store_results=True, model="fuzzy", results_root_dir="resources/results/ad") - mpnet_table2 = evaluate(ad_datasets_mpnet, ad_datasets_labels, store_results=True, model="mpnet", results_root_dir="resources/results/ad") - - print("AD RESULTS:") - print("GPT") - print("-----------") - print(gpt_table2.to_string()) - print("-----------") - print("MPNet") - print("-----------") - print(mpnet_table2.to_string()) - print("-----------") - print("Fuzzy") - print("-----------") - print(fuzzy_table2.to_string()) - print("-----------") +mpnet = SentenceTransformerAdapter(model="sentence-transformers/all-mpnet-base-v2") +pd_gpt, ad_gpt = create_datasets(gpt4) +pd_mpnet, ad_mpnet = create_datasets(mpnet) + +# PD Mapping Analyses +# enrichment analysis +luxpark_passionate_enrichment_gpt = enrichment_analysis(pd_gpt["LuxPARK"], pd_gpt["PASSIONATE"], 20) +luxpark_passionate_enrichment_mpnet = enrichment_analysis(pd_mpnet["LuxPARK"], pd_mpnet["PASSIONATE"], 20, MatchingMethod.COSINE_EMBEDDING_DISTANCE) +luxpark_passionate_enrichment_fuzzy = enrichment_analysis(pd_gpt["LuxPARK"], pd_gpt["PASSIONATE"], 20, MatchingMethod.FUZZY_STRING_MATCHING) +label1 = "Enrichment Plot LuxPARK to CDM" +ppmi_passionate_enrichment_gpt = enrichment_analysis(pd_gpt["PPMI"], pd_gpt["PASSIONATE"], 20) +ppmi_passionate_enrichment_mpnet = enrichment_analysis(pd_mpnet["PPMI"], pd_mpnet["PASSIONATE"], 20, MatchingMethod.COSINE_EMBEDDING_DISTANCE) +ppmi_passionate_enrichment_fuzzy = enrichment_analysis(pd_gpt["PPMI"], pd_gpt["PASSIONATE"], 20, MatchingMethod.FUZZY_STRING_MATCHING) +label2 = "Enrichment Plot PPMI to CDM" +ppmi_luxpark_enrichment_gpt = enrichment_analysis(pd_gpt["PPMI"], pd_gpt["LuxPARK"], 20) +ppmi_luxpark_enrichment_mpnet = enrichment_analysis(pd_mpnet["PPMI"], pd_mpnet["LuxPARK"], 20, MatchingMethod.COSINE_EMBEDDING_DISTANCE) +ppmi_luxpark_enrichment_fuzzy = enrichment_analysis(pd_gpt["PPMI"], pd_gpt["LuxPARK"], 20, MatchingMethod.FUZZY_STRING_MATCHING) +label3 = "Enrichment Plot PPMI to LuxPARK" + +luxpark_passionate_accuracies = {"GPT": luxpark_passionate_enrichment_gpt, + "MPNet": luxpark_passionate_enrichment_mpnet, + "Fuzzy": luxpark_passionate_enrichment_fuzzy} + +ppmi_passionate_accuracies = {"GPT": ppmi_passionate_enrichment_gpt, + "MPNet": ppmi_passionate_enrichment_mpnet, + "Fuzzy": ppmi_passionate_enrichment_fuzzy} + +ppmi_luxpark_accuracies = {"GPT": ppmi_luxpark_enrichment_gpt, + "MPNet": ppmi_luxpark_enrichment_mpnet, + "Fuzzy": ppmi_luxpark_enrichment_fuzzy} + +enrichment_plot(luxpark_passionate_accuracies, label1, save_plot=False) +enrichment_plot(ppmi_passionate_accuracies, label2, save_plot=False) +enrichment_plot( ppmi_luxpark_accuracies, label3, save_plot=False) + +gpt_table1 = evaluate(list(pd_gpt.values()), list(pd_gpt.keys()), model="gpt", store_results=False) +fuzzy_table1 = evaluate(list(pd_gpt.values()), list(pd_mpnet.keys()), model="fuzzy", matching_method="fuzzy", store_results=False) +mpnet_table1 = evaluate(list(pd_mpnet.values()), list(pd_gpt.keys()), model="mpnet", matching_method="cosine", store_results=False) + +print("\n PD RESULTS: \n") +print("GPT") +print("-----------") +print(gpt_table1) +print("-----------") +print("MPNet") +print("-----------") +print(mpnet_table1) +print("-----------") +print("Fuzzy") +print("-----------") +print(fuzzy_table1) +print("-----------") + +# AD Mapping Analyses +gpt_table2 = evaluate(list(ad_gpt.values()), list(ad_gpt.keys()), model="gpt", store_results=False, results_root_dir="resources/results/ad") +fuzzy_table2 = evaluate(list(ad_gpt.values()), list(ad_gpt.keys()), model="fuzzy", matching_method="fuzzy", + store_results=False, results_root_dir="resources/results/ad") +mpnet_table2 = evaluate(list(ad_mpnet.values()), list(ad_mpnet.keys()), model="mpnet", matching_method="cosine", + store_results=False, results_root_dir="resources/results/ad") + +print("\n AD RESULTS: \n") +print("GPT") +print("-----------") +print(gpt_table2.to_string()) +print("-----------") +print("MPNet") +print("-----------") +print(mpnet_table2.to_string()) +print("-----------") +print("Fuzzy") +print("-----------") +print(fuzzy_table2.to_string()) +print("-----------") # embedding distribution -scatter_plot_two_distributions(pd_datasets_gpt, ad_datasets_gpt, "PD", "AD") -scatter_plot_all_cohorts(pd_datasets_gpt, ad_datasets_gpt, pd_datasets_labels, ad_datasets_labels) +scatter_plot_two_distributions(list(pd_gpt.values()), list(ad_gpt.values()), "PD", "AD") +scatter_plot_all_cohorts(list(pd_gpt.values()), list(ad_gpt.values()), list(pd_gpt.keys()), list(ad_gpt.keys())) diff --git a/index/visualisation.py b/index/visualisation.py index 71c90f2..c1c2427 100644 --- a/index/visualisation.py +++ b/index/visualisation.py @@ -1,3 +1,4 @@ +from typing import Dict from enum import Enum import numpy as np @@ -35,20 +36,21 @@ def get_cohort_specific_color_code(cohort_name: str): return None -def enrichment_plot(acc_gpt, acc_mpnet, acc_fuzzy, title, save_plot=False, save_dir="resources/results/plots"): - if not (len(acc_gpt) == len(acc_fuzzy) == len(acc_mpnet)): - raise ValueError("acc_gpt, acc_mpnet and acc_fuzzy should be of the same length!") - data = {"Maximum Considered Rank": list(range(1, len(acc_gpt) + 1)), "GPT": acc_gpt, - "MPNet": acc_mpnet, "Fuzzy": acc_fuzzy} +def enrichment_plot(accuracies: Dict[str, [float]], title, save_plot=False, save_dir="resources/results/plots"): + lengths = set(len(lst) for lst in accuracies.values()) + if len(lengths) != 1: + raise ValueError("Accuracy scores of models should be of the same length!") + + data = copy.deepcopy(accuracies) + data["Maximum Considered Rank"] = list(range(1, list(lengths)[0] + 1)) df = pd.DataFrame(data) sns.set(style="whitegrid") - sns.lineplot(data=df, x="Maximum Considered Rank", y="GPT", label="GPT") - sns.lineplot(data=df, x="Maximum Considered Rank", y="MPNet", label="MPNet") - sns.lineplot(data=df, x="Maximum Considered Rank", y="Fuzzy", label="Fuzzy String Matching") + for model, _ in accuracies.items(): + sns.lineplot(data=df, x="Maximum Considered Rank", y=model, label=model) sns.set(style="whitegrid") plt.xlabel("Maximum Considered Rank") plt.ylabel("Accuracy") - plt.xticks(range(1, len(acc_gpt) + 1), labels=range(1, len(acc_gpt) + 1)) + plt.xticks(range(1, list(lengths)[0] + 1), labels=range(1, list(lengths)[0] + 1)) plt.yticks([i / 10 for i in range(11)]) plt.gca().set_yticklabels([f"{i:.1f}" for i in plt.gca().get_yticks()]) plt.title(title) diff --git a/tests/test_embedding.py b/tests/test_embedding.py index 207056f..802056c 100644 --- a/tests/test_embedding.py +++ b/tests/test_embedding.py @@ -1,21 +1,21 @@ import unittest -from index.embedding import MPNetAdapter, TextEmbedding +from index.embedding import SentenceTransformerAdapter, TextEmbedding import numpy as np class TestEmbedding(unittest.TestCase): def setUp(self): - self.mpnet_adapter = MPNetAdapter(model="sentence-transformers/all-mpnet-base-v2") + self.sentence_transformer_adapter = SentenceTransformerAdapter(model="sentence-transformers/all-mpnet-base-v2") - def test_mpnet_adapter_get_embedding(self): + def test_sentence_transformer_adapter_get_embedding(self): text = "This is a test sentence." - embedding = self.mpnet_adapter.get_embedding(text) + embedding = self.sentence_transformer_adapter.get_embedding(text) self.assertIsInstance(embedding, np.ndarray) self.assertEqual(len(embedding), 768) - def test_mpnet_adapter_get_embeddings(self): + def test_sentence_transformer_adapter_get_embeddings(self): messages = ["This is message 1.", "This is message 2."] - embeddings = self.mpnet_adapter.get_embeddings(messages) + embeddings = self.sentence_transformer_adapter.get_embeddings(messages) self.assertIsInstance(embeddings, list) self.assertEqual(len(embeddings), len(messages)) self.assertEqual(len(embeddings[0]), 768) diff --git a/tests/test_visualisation.py b/tests/test_visualisation.py index 66b4ac6..32f8af6 100644 --- a/tests/test_visualisation.py +++ b/tests/test_visualisation.py @@ -108,11 +108,11 @@ def test_scatter_plot_all_cohorts(self): ["A1", "A2"], ["B1", "B2"], store_html=False) def test_enrichment_plot(self): - acc_gpt = [0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.0, 1.0, 1.0] - acc_mpnet = [0.3, 0.4, 0.5, 0.6, 0.8, 0.9, 1.0, 1.0, 1.0, 1.0] - acc_fuzzy = [0.2, 0.3, 0.4, 0.5, 0.8, 0.9, 1.0, 1.0, 1.0, 1.0] + acc = {"M1": [0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.0, 1.0, 1.0], + "M2": [0.3, 0.4, 0.5, 0.6, 0.8, 0.9, 1.0, 1.0, 1.0, 1.0], + "M3": [0.2, 0.3, 0.4, 0.5, 0.8, 0.9, 1.0, 1.0, 1.0, 1.0]} title = "Test" - enrichment_plot(acc_gpt, acc_mpnet, acc_fuzzy, title, save_plot=False) + enrichment_plot(acc, title, save_plot=False) def test_bar_chart_average_acc_two_distributions(self): labels = ["M1", "M2", "M3"] From db831567438fa6a1c30882f7d29b2f57ff4ec7a7 Mon Sep 17 00:00:00 2001 From: Mehmet Can Ay Date: Fri, 23 Feb 2024 15:52:19 +0100 Subject: [PATCH 2/2] fix: typing error --- index/visualisation.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/index/visualisation.py b/index/visualisation.py index c1c2427..60dc6b9 100644 --- a/index/visualisation.py +++ b/index/visualisation.py @@ -1,4 +1,4 @@ -from typing import Dict +from typing import Dict, List from enum import Enum import numpy as np @@ -36,7 +36,7 @@ def get_cohort_specific_color_code(cohort_name: str): return None -def enrichment_plot(accuracies: Dict[str, [float]], title, save_plot=False, save_dir="resources/results/plots"): +def enrichment_plot(accuracies: Dict[str, List[float]], title, save_plot=False, save_dir="resources/results/plots"): lengths = set(len(lst) for lst in accuracies.values()) if len(lengths) != 1: raise ValueError("Accuracy scores of models should be of the same length!")