From 48abe712d8cd760fc25130a5ffde96d60004f654 Mon Sep 17 00:00:00 2001 From: TimAdams84 Date: Thu, 22 Feb 2024 15:06:00 +0100 Subject: [PATCH] Feat: Extend API and add Unit Test --- README.md | 2 +- index/api/routes.py | 49 ++++++- index/db/model.py | 3 + index/main.py | 281 ------------------------------------ index/repository/sqllite.py | 2 +- tests/test_system.py | 39 +++++ 6 files changed, 91 insertions(+), 285 deletions(-) delete mode 100644 index/main.py create mode 100644 tests/test_system.py diff --git a/README.md b/README.md index 828b705..3e981bf 100644 --- a/README.md +++ b/README.md @@ -31,7 +31,7 @@ pip install -r requirements.txt Run the Backend API on port 5000: ```bash -uvicorn main:app --reload --port 5000 +uvicorn index.api.routes:app --reload --port 5000 ``` ### Run the Backend via Docker diff --git a/index/api/routes.py b/index/api/routes.py index 75aaa59..9c9b0e1 100644 --- a/index/api/routes.py +++ b/index/api/routes.py @@ -1,9 +1,12 @@ +import json import logging +from typing import Dict -from fastapi import FastAPI +from fastapi import FastAPI, HTTPException from starlette.middleware.cors import CORSMiddleware from starlette.responses import RedirectResponse +from index.db.model import Terminology, Concept, Mapping from index.repository.sqllite import SQLLiteRepository from index.embedding import MPNetAdapter @@ -48,9 +51,51 @@ def get_current_version(): return app.version +@app.put("/terminologies/{id}", tags=["terminologies"]) +async def create_or_update_terminology(id: str, name: str): + try: + terminology = Terminology(name=name, id=id) + repository.store(terminology) + return {"message": f"Terminology {id} created or updated successfully"} + except Exception as e: + raise HTTPException(status_code=400, detail=f"Failed to create or update terminology: {str(e)}") + + +@app.put("/concepts/{id}", tags=["concepts"]) +async def create_or_update_concept(id: str, terminology_id: str, name: str): + try: + terminology = 
repository.session.query(Terminology).filter(Terminology.id == terminology_id).first() + if not terminology: + raise HTTPException(status_code=404, detail=f"Terminology with id {terminology_id} not found") + + concept = Concept(terminology=terminology, name=name, id=id) + repository.store(concept) + return {"message": f"Concept {id} created or updated successfully"} + except Exception as e: + raise HTTPException(status_code=400, detail=f"Failed to create or update concept: {str(e)}") + + +@app.put("/mappings/{id}", tags=["mappings"]) +async def create_or_update_mapping(id: str, concept_id: str, text: str): + try: + concept = repository.session.query(Concept).filter(Concept.id == concept_id).first() + if not concept: + raise HTTPException(status_code=404, detail=f"Concept with id {concept_id} not found") + embedding = embedding_model.get_embedding(text) + # Convert embedding from numpy array to list + embedding_list = embedding.tolist() + print(embedding_list) + mapping = Mapping(concept=concept, text=text, embedding=embedding_list) + repository.store(mapping) + return {"message": f"Mapping {id} created or updated successfully"} + except Exception as e: + raise HTTPException(status_code=400, detail=f"Failed to create or update mapping: {str(e)}") + + @app.post("/mappings", tags=["mappings"]) async def get_closest_mappings_for_text(text: str): - embedding = embedding_model.get_embedding(text) + embedding = embedding_model.get_embedding(text).tolist() + print(embedding) closest_mappings, similarities = repository.get_closest_mappings(embedding) response_data = [] for mapping, similarity in zip(closest_mappings, similarities): diff --git a/index/db/model.py b/index/db/model.py index 074baec..4a89d27 100644 --- a/index/db/model.py +++ b/index/db/model.py @@ -1,5 +1,6 @@ import json +import numpy as np from sqlalchemy import Column, ForeignKey, Integer, String, Text from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import 
relationship @@ -41,6 +42,8 @@ class Mapping(Base): def __init__(self, concept: Concept, text: str, embedding: list): self.concept = concept self.text = text + if isinstance(embedding, np.ndarray): + embedding = embedding.tolist() self.embedding_json = json.dumps(embedding) # Store embedding as JSON @property diff --git a/index/main.py b/index/main.py deleted file mode 100644 index c5bcf35..0000000 --- a/index/main.py +++ /dev/null @@ -1,281 +0,0 @@ -import os -import sys - -sys.path.append("../") -import pandas as pd - -from index import evaluation -from index.conf import PD_CDM_SRC, PPMI_DICT_SRC, LUXPARK_DICT_SRC, BIOFIND_DICT_SRC, AD_CDM_SRC -from index.embedding import GPT4Adapter, MPNetAdapter -from index.evaluation import match_closest_descriptions, MatchingMethod, enrichment_analysis, evaluate -from index.mapping import MappingTable -from index.parsing import MappingSource, DataDictionarySource -from dotenv import load_dotenv - -from index.visualisation import scatter_plot_two_distributions, enrichment_plot, scatter_plot_all_cohorts, \ - bar_chart_average_acc_two_distributions - -EVAL_PD = True -EVAL_AD = True - -load_dotenv() -gpt4 = GPT4Adapter(api_key=os.getenv("GPT_KEY")) # type: ignore -mpnet = MPNetAdapter() - -# PD Mappings - -if EVAL_PD: - cdm_pd_gpt = MappingTable(MappingSource(PD_CDM_SRC, "Feature", "CURIE")) - cdm_pd_gpt.joined_mapping_table["identifier"].to_csv("resources/cdm_curie.csv", index=False) - cdm_pd_gpt.add_descriptions(DataDictionarySource(PD_CDM_SRC, "Feature", "Definition")) - cdm_pd_gpt.compute_embeddings(gpt4) - - cdm_pd_mpnet = MappingTable(MappingSource(PD_CDM_SRC, "Feature", "CURIE")) - cdm_pd_mpnet.joined_mapping_table["identifier"].to_csv("resources/cdm_curie.csv", index=False) - cdm_pd_mpnet.add_descriptions(DataDictionarySource(PD_CDM_SRC, "Feature", "Definition")) - cdm_pd_mpnet.compute_embeddings(mpnet) - - ppmi_gpt = MappingTable(MappingSource(PD_CDM_SRC, "PPMI", "CURIE")) - 
ppmi_gpt.add_descriptions(DataDictionarySource(PPMI_DICT_SRC, "ITM_NAME", "DSCR")) - ppmi_gpt.compute_embeddings(gpt4) - - ppmi_mpnet = MappingTable(MappingSource(PD_CDM_SRC, "PPMI", "CURIE")) - ppmi_mpnet.add_descriptions(DataDictionarySource(PPMI_DICT_SRC, "ITM_NAME", "DSCR")) - ppmi_mpnet.compute_embeddings(mpnet) - - luxpark_gpt = MappingTable(MappingSource(PD_CDM_SRC, "LuxPARK", "CURIE")) - luxpark_gpt.add_descriptions(DataDictionarySource(LUXPARK_DICT_SRC, "Variable / Field Name", "Field Label")) - luxpark_gpt.compute_embeddings(gpt4) - - luxpark_mpnet = MappingTable(MappingSource(PD_CDM_SRC, "LuxPARK", "CURIE")) - luxpark_mpnet.add_descriptions(DataDictionarySource(LUXPARK_DICT_SRC, "Variable / Field Name", "Field Label")) - luxpark_mpnet.compute_embeddings(mpnet) - - biofind_gpt = MappingTable(MappingSource(PD_CDM_SRC, "BIOFIND", "CURIE")) - biofind_gpt.add_descriptions(DataDictionarySource(BIOFIND_DICT_SRC, "ITM_NAME", "DSCR")) - biofind_gpt.compute_embeddings(gpt4) - - biofind_mpnet = MappingTable(MappingSource(PD_CDM_SRC, "BIOFIND", "CURIE")) - biofind_mpnet.add_descriptions(DataDictionarySource(BIOFIND_DICT_SRC, "ITM_NAME", "DSCR")) - biofind_mpnet.compute_embeddings(mpnet) - - lrrk2_gpt = MappingTable(MappingSource(PD_CDM_SRC, "LRRK2", "CURIE")) - lrrk2_gpt.add_descriptions(DataDictionarySource("resources/dictionaries/pd/LRRK2.xlsx", "Variable", "Label")) - lrrk2_gpt.compute_embeddings(gpt4) - - lrrk2_mpnet = MappingTable(MappingSource(PD_CDM_SRC, "LRRK2", "CURIE")) - lrrk2_mpnet.add_descriptions(DataDictionarySource("resources/dictionaries/pd/LRRK2.xlsx", "Variable", "Label")) - lrrk2_mpnet.compute_embeddings(mpnet) - - opdc_gpt = MappingTable(MappingSource(PD_CDM_SRC, "OPDC", "CURIE")) - opdc_gpt.add_descriptions(DataDictionarySource("resources/dictionaries/pd/OPDC.csv", "Variable Name", "Variable description")) - opdc_gpt.compute_embeddings(gpt4) - - opdc_mpnet = MappingTable(MappingSource(PD_CDM_SRC, "OPDC", "CURIE")) - 
opdc_mpnet.add_descriptions(DataDictionarySource("resources/dictionaries/pd/OPDC.csv", "Variable Name", "Variable description")) - opdc_mpnet.compute_embeddings(mpnet) - - tpd_gpt = MappingTable(MappingSource(PD_CDM_SRC, "TPD", "CURIE")) - tpd_gpt.add_descriptions(DataDictionarySource("resources/dictionaries/pd/TPD.csv", "Variable Name", "Variable description")) - tpd_gpt.compute_embeddings(gpt4) - - tpd_mpnet = MappingTable(MappingSource(PD_CDM_SRC, "TPD", "CURIE")) - tpd_mpnet.add_descriptions(DataDictionarySource("resources/dictionaries/pd/TPD.csv", "Variable Name", "Variable description")) - tpd_mpnet.compute_embeddings(mpnet) - - pd_datasets_gpt = [opdc_gpt, tpd_gpt, biofind_gpt, lrrk2_gpt, luxpark_gpt, ppmi_gpt, cdm_pd_gpt] - pd_datasets_mpnet = [opdc_mpnet, tpd_mpnet, biofind_mpnet, lrrk2_mpnet, luxpark_mpnet, ppmi_mpnet, cdm_pd_mpnet] - pd_datasets_labels = ["OPDC", "PRoBaND", "BIOFIND", "LCC", "LuxPARK", "PPMI", "PASSIONATE"] - - # enrichment analysis - luxpark_passionate_enrichment_gpt = enrichment_analysis(luxpark_gpt, cdm_pd_gpt, 20, MatchingMethod.EUCLIDEAN_EMBEDDING_DISTANCE) - luxpark_passionate_enrichment_mpnet = enrichment_analysis(luxpark_mpnet, cdm_pd_mpnet, 20, MatchingMethod.COSINE_EMBEDDING_DISTANCE) - luxpark_passionate_enrichment_fuzzy = enrichment_analysis(luxpark_gpt, cdm_pd_gpt, 20, MatchingMethod.FUZZY_STRING_MATCHING) - label1 = "Enrichment Plot LuxPARK to CDM" - ppmi_passionate_enrichment_gpt = enrichment_analysis(ppmi_gpt, cdm_pd_gpt, 20, MatchingMethod.EUCLIDEAN_EMBEDDING_DISTANCE) - ppmi_passionate_enrichment_mpnet = enrichment_analysis(ppmi_mpnet, cdm_pd_mpnet, 20, MatchingMethod.COSINE_EMBEDDING_DISTANCE) - ppmi_passionate_enrichment_fuzzy = enrichment_analysis(ppmi_gpt, cdm_pd_gpt, 20, MatchingMethod.FUZZY_STRING_MATCHING) - label2 = "Enrichment Plot PPMI to CDM" - ppmi_luxpark_enrichment_gpt = enrichment_analysis(ppmi_gpt, luxpark_gpt, 20, MatchingMethod.EUCLIDEAN_EMBEDDING_DISTANCE) - ppmi_luxpark_enrichment_mpnet = 
enrichment_analysis(ppmi_mpnet, luxpark_mpnet, 20, MatchingMethod.COSINE_EMBEDDING_DISTANCE) - ppmi_luxpark_enrichment_fuzzy = enrichment_analysis(ppmi_gpt, luxpark_gpt, 20, MatchingMethod.FUZZY_STRING_MATCHING) - label3 = "Enrichment Plot PPMI to LuxPARK" - enrichment_plot(luxpark_passionate_enrichment_gpt, luxpark_passionate_enrichment_mpnet, luxpark_passionate_enrichment_fuzzy, label1, save_plot=True) - enrichment_plot(ppmi_passionate_enrichment_gpt, ppmi_passionate_enrichment_mpnet, ppmi_passionate_enrichment_fuzzy, label2, save_plot=True) - enrichment_plot( ppmi_luxpark_enrichment_gpt, ppmi_luxpark_enrichment_mpnet, ppmi_luxpark_enrichment_fuzzy, label3, save_plot=True) - print(luxpark_passionate_enrichment_gpt) - print(luxpark_passionate_enrichment_mpnet) - print(luxpark_passionate_enrichment_fuzzy) - print(ppmi_passionate_enrichment_gpt) - print(ppmi_passionate_enrichment_mpnet) - print(ppmi_passionate_enrichment_fuzzy) - print(ppmi_luxpark_enrichment_gpt) - print(ppmi_luxpark_enrichment_mpnet) - print(ppmi_luxpark_enrichment_fuzzy) - - gpt_table1 = evaluate(pd_datasets_gpt, pd_datasets_labels, store_results=True) - fuzzy_table1 = evaluate(pd_datasets_gpt, pd_datasets_labels, store_results=True, model="fuzzy") - mpnet_table1 = evaluate(pd_datasets_mpnet, pd_datasets_labels, store_results=True, model="mpnet") - - print("PD RESULTS:") - print("GPT") - print("-----------") - print(gpt_table1) - print("-----------") - print("MPNet") - print("-----------") - print(mpnet_table1) - print("-----------") - print("Fuzzy") - print("-----------") - print(fuzzy_table1) - print("-----------") - -# AD Mappings - -if EVAL_AD: - cdm_ad_gpt = cdm_pd_gpt = MappingTable(MappingSource(AD_CDM_SRC, "Feature", "CURIE")) - cdm_ad_gpt.add_descriptions(DataDictionarySource(PD_CDM_SRC, "Feature", "Definition")) - cdm_ad_gpt.compute_embeddings(gpt4) - - cdm_ad_mpnet = cdm_pd_gpt = MappingTable(MappingSource(AD_CDM_SRC, "Feature", "CURIE")) - 
cdm_ad_mpnet.add_descriptions(DataDictionarySource(PD_CDM_SRC, "Feature", "Definition")) - cdm_ad_mpnet.compute_embeddings(mpnet) - - a4_gpt = MappingTable(MappingSource(AD_CDM_SRC, "A4", "CURIE")) - a4_gpt.add_descriptions(DataDictionarySource("resources/dictionaries/ad/a4.csv", "FLDNAME", "TEXT")) - a4_gpt.compute_embeddings(gpt4) - - a4_mpnet = MappingTable(MappingSource(AD_CDM_SRC, "A4", "CURIE")) - a4_mpnet.add_descriptions(DataDictionarySource("resources/dictionaries/ad/a4.csv", "FLDNAME", "TEXT")) - a4_mpnet.compute_embeddings(mpnet) - - abvib_gpt = MappingTable(MappingSource(AD_CDM_SRC, "ABVIB", "CURIE")) - abvib_gpt.add_descriptions(DataDictionarySource("resources/dictionaries/ad/abvib.csv", "variable_name", "description")) - abvib_gpt.compute_embeddings(gpt4) - - abvib_mpnet = MappingTable(MappingSource(AD_CDM_SRC, "ABVIB", "CURIE")) - abvib_mpnet.add_descriptions( DataDictionarySource("resources/dictionaries/ad/abvib.csv", "variable_name", "description")) - abvib_mpnet.compute_embeddings(mpnet) - - adni_gpt = MappingTable(MappingSource(AD_CDM_SRC, "ADNI", "CURIE")) - adni_gpt.add_descriptions(DataDictionarySource("resources/dictionaries/ad/ADNIMERGE_DICT_27Nov2023 2.csv", "FLDNAME", "TEXT")) - adni_gpt.compute_embeddings(gpt4) - - adni_mpnet = MappingTable(MappingSource(AD_CDM_SRC, "ADNI", "CURIE")) - adni_mpnet.add_descriptions(DataDictionarySource("resources/dictionaries/ad/ADNIMERGE_DICT_27Nov2023 2.csv", "FLDNAME", "TEXT")) - adni_mpnet.compute_embeddings(mpnet) - - aibl_gpt = MappingTable(MappingSource(AD_CDM_SRC, "AIBL", "CURIE")) - aibl_gpt.add_descriptions(DataDictionarySource("resources/dictionaries/ad/aibl.csv", "Name", "Description")) - aibl_gpt.compute_embeddings(gpt4) - - aibl_mpnet = MappingTable(MappingSource(AD_CDM_SRC, "AIBL", "CURIE")) - aibl_mpnet.add_descriptions(DataDictionarySource("resources/dictionaries/ad/aibl.csv", "Name", "Description")) - aibl_mpnet.compute_embeddings(mpnet) - - arwibo_gpt = 
MappingTable(MappingSource(AD_CDM_SRC, "ARWIBO", "CURIE")) - arwibo_gpt.add_descriptions(DataDictionarySource("resources/dictionaries/ad/arwibo.csv", "Variable_Name", "Element_description")) - arwibo_gpt.compute_embeddings(gpt4) - - arwibo_mpnet = MappingTable(MappingSource(AD_CDM_SRC, "ARWIBO", "CURIE")) - arwibo_mpnet.add_descriptions(DataDictionarySource("resources/dictionaries/ad/arwibo.csv", "Variable_Name", "Element_description")) - arwibo_mpnet.compute_embeddings(mpnet) - - dod_adni_gpt = MappingTable(MappingSource(AD_CDM_SRC, "DOD-ADNI", "CURIE")) - # TODO most descriptions missing - dod_adni_gpt.add_descriptions(DataDictionarySource("resources/dictionaries/ad/dod-adni.csv", "FLDNAME", "TEXT")) - dod_adni_gpt.compute_embeddings(gpt4) - - dod_adni_mpnet = MappingTable(MappingSource(AD_CDM_SRC, "DOD-ADNI", "CURIE")) - # TODO most descriptions missing - dod_adni_mpnet.add_descriptions(DataDictionarySource("resources/dictionaries/ad/dod-adni.csv", "FLDNAME", "TEXT")) - dod_adni_mpnet.compute_embeddings(mpnet) - - edsd_gpt = MappingTable(MappingSource(AD_CDM_SRC, "EDSD", "CURIE")) - edsd_gpt.add_descriptions(DataDictionarySource("resources/dictionaries/ad/edsd.xlsx", "Variable_Name", "Element_description")) - edsd_gpt.compute_embeddings(gpt4) - - edsd_mpnet = MappingTable(MappingSource(AD_CDM_SRC, "EDSD", "CURIE")) - edsd_mpnet.add_descriptions(DataDictionarySource("resources/dictionaries/ad/edsd.xlsx", "Variable_Name", "Element_description")) - edsd_mpnet.compute_embeddings(mpnet) - - emif_gpt = MappingTable(MappingSource(AD_CDM_SRC, "EMIF", "CURIE")) - emif_gpt.add_descriptions(DataDictionarySource("resources/dictionaries/ad/emif.xlsx", "Variable", "Description")) - emif_gpt.compute_embeddings(gpt4) - - emif_mpnet = MappingTable(MappingSource(AD_CDM_SRC, "EMIF", "CURIE")) - emif_mpnet.add_descriptions(DataDictionarySource("resources/dictionaries/ad/emif.xlsx", "Variable", "Description")) - emif_mpnet.compute_embeddings(mpnet) - - i_adni_gpt = 
MappingTable(MappingSource(AD_CDM_SRC, "I-ADNI", "CURIE")) - # TODO about half of descriptions missing - i_adni_gpt.add_descriptions(DataDictionarySource("resources/dictionaries/ad/i-adni.csv", "acronym", "variable")) - i_adni_gpt.compute_embeddings(gpt4) - - i_adni_mpnet = MappingTable(MappingSource(AD_CDM_SRC, "I-ADNI", "CURIE")) - # TODO about half of descriptions missing - i_adni_mpnet.add_descriptions(DataDictionarySource("resources/dictionaries/ad/i-adni.csv", "acronym", "variable")) - i_adni_mpnet.compute_embeddings(mpnet) - - jadni_gpt = MappingTable(MappingSource(AD_CDM_SRC, "JADNI", "CURIE")) - jadni_gpt.add_descriptions(DataDictionarySource("resources/dictionaries/ad/jadni.tsv", "FLDNAME", "TEXT")) - jadni_gpt.compute_embeddings(gpt4) - - jadni_mpnet = MappingTable(MappingSource(AD_CDM_SRC, "JADNI", "CURIE")) - jadni_mpnet.add_descriptions(DataDictionarySource("resources/dictionaries/ad/jadni.tsv", "FLDNAME", "TEXT")) - jadni_mpnet.compute_embeddings(mpnet) - - pharmacog_gpt = MappingTable(MappingSource(AD_CDM_SRC, "PharmaCog", "CURIE")) - pharmacog_gpt.add_descriptions(DataDictionarySource("resources/dictionaries/ad/pharmacog.csv", "Variable_Name", "Element_description")) - pharmacog_gpt.compute_embeddings(gpt4) - - pharmacog_mpnet = MappingTable(MappingSource(AD_CDM_SRC, "PharmaCog", "CURIE")) - pharmacog_mpnet.add_descriptions(DataDictionarySource("resources/dictionaries/ad/pharmacog.csv", "Variable_Name", "Element_description")) - pharmacog_mpnet.compute_embeddings(mpnet) - - prevent_ad_gpt = MappingTable(MappingSource(AD_CDM_SRC, "PREVENT-AD", "CURIE")) - prevent_ad_gpt.add_descriptions(DataDictionarySource("resources/dictionaries/ad/prevent-ad.csv", "variable", "description")) - prevent_ad_gpt.compute_embeddings(gpt4) - - prevent_ad_mpnet = MappingTable(MappingSource(AD_CDM_SRC, "PREVENT-AD", "CURIE")) - prevent_ad_mpnet.add_descriptions(DataDictionarySource("resources/dictionaries/ad/prevent-ad.csv", "variable", "description")) - 
prevent_ad_mpnet.compute_embeddings(mpnet) - - vita_gpt = MappingTable(MappingSource(AD_CDM_SRC, "VITA", "CURIE")) - vita_gpt.add_descriptions(DataDictionarySource("resources/dictionaries/ad/vita.csv", "Variable_Name", "Element_description")) - vita_gpt.compute_embeddings(gpt4) - - vita_mpnet = MappingTable(MappingSource(AD_CDM_SRC, "VITA", "CURIE")) - vita_mpnet.add_descriptions(DataDictionarySource("resources/dictionaries/ad/vita.csv", "Variable_Name", "Element_description")) - vita_mpnet.compute_embeddings(mpnet) - - wmh_ad = MappingTable(MappingSource(AD_CDM_SRC, "VITA", "CURIE")) - - ad_datasets_gpt = [a4_gpt, abvib_gpt, adni_gpt, aibl_gpt, arwibo_gpt, dod_adni_gpt, edsd_gpt, emif_gpt, i_adni_gpt, jadni_gpt, - pharmacog_gpt, prevent_ad_gpt, vita_gpt, cdm_ad_gpt] - ad_datasets_mpnet = [a4_mpnet, abvib_mpnet, adni_mpnet, aibl_mpnet, arwibo_mpnet, dod_adni_mpnet, edsd_mpnet, emif_mpnet, - i_adni_mpnet, jadni_mpnet, pharmacog_mpnet, prevent_ad_mpnet, vita_mpnet, cdm_ad_mpnet] - ad_datasets_labels = ["A4", "ABVIB", "ADNI", "AIBL", "ARWIBO", "DOD-ADNI", "EDSD", "EMIF", "I-ADNI", "JADNI", "PharmaCog", - "PREVENT-AD", "VITA", "AD-Mapper"] - gpt_table2 = evaluate(ad_datasets_gpt, ad_datasets_labels, store_results=True, results_root_dir="resources/results/ad") - fuzzy_table2 = evaluate(ad_datasets_gpt, ad_datasets_labels, store_results=True, model="fuzzy", results_root_dir="resources/results/ad") - mpnet_table2 = evaluate(ad_datasets_mpnet, ad_datasets_labels, store_results=True, model="mpnet", results_root_dir="resources/results/ad") - - print("AD RESULTS:") - print("GPT") - print("-----------") - print(gpt_table2.to_string()) - print("-----------") - print("MPNet") - print("-----------") - print(mpnet_table2.to_string()) - print("-----------") - print("Fuzzy") - print("-----------") - print(fuzzy_table2.to_string()) - print("-----------") - -# embedding distribution -scatter_plot_two_distributions(pd_datasets_gpt, ad_datasets_gpt, "PD", "AD") 
-scatter_plot_all_cohorts(pd_datasets_gpt, ad_datasets_gpt, pd_datasets_labels, ad_datasets_labels) diff --git a/index/repository/sqllite.py b/index/repository/sqllite.py index aa6cd8f..6892029 100644 --- a/index/repository/sqllite.py +++ b/index/repository/sqllite.py @@ -44,7 +44,7 @@ def get_closest_mappings(self, embedding, limit=5): ) distances_and_mappings = distances_and_mappings_query.all() - + print(distances_and_mappings) # Sort results based on distances sorted_results = sorted(distances_and_mappings, key=lambda x: x[1]) diff --git a/tests/test_system.py b/tests/test_system.py new file mode 100644 index 0000000..519a955 --- /dev/null +++ b/tests/test_system.py @@ -0,0 +1,39 @@ +import unittest + +from index.db.model import Terminology, Concept, Mapping +from index.embedding import MPNetAdapter +from index.repository.sqllite import SQLLiteRepository + + +class TestGetClosestEmbedding(unittest.TestCase): + + def setUp(self): + self.repository = SQLLiteRepository(mode="memory") + self.embedding_model = MPNetAdapter() + + def tearDown(self): + self.repository.shut_down() + + def test_mapping_storage_and_closest_retrieval(self): + # preset knowledge + terminology = Terminology("test", "test") + concept1 = Concept(terminology, "depression", "TEST:1") + concept1_description = "A heavy fog obscures joy, suffocating hope in an endless struggle." + mapping1 = Mapping(concept1, concept1_description, self.embedding_model.get_embedding(concept1_description)) + concept2 = Concept(terminology, "euphoria", "TEST:2") + concept2_description = "An intense state of joy and elation, engulfing the senses." + mapping2 = Mapping(concept2, concept2_description, self.embedding_model.get_embedding(concept2_description)) + self.repository.store_all([terminology, concept1, mapping1, concept2, mapping2]) + # test new mappings + text1 = "Trapped in glass, distant from life's vibrancy, feeling isolated and disconnected." 
+ text1_embedding = self.embedding_model.get_embedding(text1) + text2 = "A profound feeling of happiness, exuberance, and boundless positivity." + text2_embedding = self.embedding_model.get_embedding(text2) + mappings1, distances1 = self.repository.get_closest_mappings(text1_embedding, limit=2) + mappings2, distances2 = self.repository.get_closest_mappings(text2_embedding, limit=2) + self.assertEqual(len(mappings1), 2) + self.assertEqual(len(mappings2), 2) + self.assertEqual(concept1_description, mappings1[0].text) + self.assertEqual(concept2_description, mappings2[0].text) + +