From 8eb88a506de1eb3aa59b7d8104cee37cf74187c4 Mon Sep 17 00:00:00 2001 From: Tim Adams Date: Fri, 6 Sep 2024 16:04:23 +0200 Subject: [PATCH] Create embedding model dynamically --- api/routes.py | 52 +++++++++++++++++++++++++++++++++++++-------------- 1 file changed, 38 insertions(+), 14 deletions(-) diff --git a/api/routes.py b/api/routes.py index eba19b1..2ae117f 100644 --- a/api/routes.py +++ b/api/routes.py @@ -54,14 +54,13 @@ def connect_to_remote_weaviate_repository(): try: return WeaviateRepository(mode="remote", path=weaviate_url) except Exception as e: - logger.info(f"Attempt {i + 1} failed: {e}") + logger.info(f"Attempt {i + 1} to connect to {weaviate_url} failed with error: {e}") time.sleep(5) raise ConnectionError("Could not connect to Weaviate after multiple attempts.") logger = logging.getLogger("uvicorn.info") repository = connect_to_remote_weaviate_repository() -embedding_model = MPNetAdapter() db_plot_html = None origins = ["*"] @@ -119,7 +118,7 @@ async def create_terminology(id: str, name: str): return {"message": f"Terminology {id} created successfully"} except Exception as e: raise HTTPException(status_code=400, detail=f"Failed to create terminology: {str(e)}") - + @app.get("/models", tags=["models"]) async def get_all_models(): @@ -151,11 +150,16 @@ async def get_all_mappings(): @app.put("/concepts/{id}/mappings", tags=["concepts", "mappings"]) -async def create_concept_and_attach_mapping(id: str, concept_name: str, terminology_name, text: str): +async def create_concept_and_attach_mapping(id: str, + concept_name: str, + terminology_name, + text: str, + model: str = "sentence-transformers/all-mpnet-base-v2"): try: terminology = repository.get_terminology(terminology_name) concept = Concept(terminology=terminology, pref_label=concept_name, concept_identifier=id) repository.store(concept) + embedding_model = MPNetAdapter(model) embedding = embedding_model.get_embedding(text) model_name = embedding_model.get_model_name() mapping = Mapping(concept=concept, text=text, embedding=embedding, sentence_embedder=model_name) @@ -166,9 +170,12 @@ async def create_concept_and_attach_mapping(id: str, concept_name: str, termino @app.put("/mappings/", tags=["mappings"]) -async def create_mapping(concept_id: str, text: str): +async def create_mapping(concept_id: str, + text: str, + model: str = "sentence-transformers/all-mpnet-base-v2"): try: concept = repository.get_concept(concept_id) + embedding_model = MPNetAdapter(model) embedding = embedding_model.get_embedding(text) model_name = embedding_model.get_model_name() mapping = Mapping(concept=concept, text=text, embedding=embedding, sentence_embedder=model_name) @@ -179,10 +186,14 @@ async def create_mapping(concept_id: str, text: str): @app.post("/mappings", tags=["mappings"]) -async def get_closest_mappings_for_text(text: str, terminology_name: str = "SNOMED CT", - sentence_embedder: str = "sentence-transformers/all-mpnet-base-v2", limit: int = 5): +async def get_closest_mappings_for_text(text: str, + terminology_name: str = "SNOMED CT", + model: str = "sentence-transformers/all-mpnet-base-v2", + limit: int = 5): + embedding_model = MPNetAdapter(model) embedding = embedding_model.get_embedding(text).tolist() - closest_mappings = repository.get_terminology_and_model_specific_closest_mappings(embedding, terminology_name, sentence_embedder, limit) + closest_mappings = repository.get_terminology_and_model_specific_closest_mappings(embedding, terminology_name, + model, limit) mappings = [] for mapping, similarity in closest_mappings: concept = mapping.concept @@ -205,8 +216,10 @@ async def get_closest_mappings_for_text(text: str, terminology_name: str = "SNOM # Endpoint to get mappings for a data dictionary source @app.post("/mappings/dict", tags=["mappings"], description="Get mappings for a data dictionary source.") -async def get_closest_mappings_for_dictionary(file: UploadFile = File(...), variable_field: str = 'variable', - description_field: str = 'description'): +async def get_closest_mappings_for_dictionary(file: UploadFile = File(...), + variable_field: str = 'variable', + description_field: str = 'description', + model: str = "sentence-transformers/all-mpnet-base-v2"): try: # Determine file extension and create a temporary file with the correct extension _, file_extension = os.path.splitext(file.filename) @@ -223,6 +236,7 @@ async def get_closest_mappings_for_dictionary(file: UploadFile = File(...), vari for _, row in df.iterrows(): variable = row['variable'] description = row['description'] + embedding_model = MPNetAdapter(model) embedding = embedding_model.get_embedding(description) closest_mappings, similarities = repository.get_closest_mappings(embedding, limit=5) mappings_list = [] @@ -255,14 +269,24 @@ async def get_closest_mappings_for_dictionary(file: UploadFile = File(...), vari raise HTTPException(status_code=500, detail=str(e)) -def import_snomed_ct_task(): +def import_snomed_ct_task(model: str = "sentence-transformers/all-mpnet-base-v2"): + embedding_model = MPNetAdapter(model) task = OLSTerminologyImportTask(repository, embedding_model, "SNOMED CT", "snomed") task.process() -@app.put("/import/terminology/snomed", description="Import whole SNOMED CT from OLS.", tags=["import", "tasks"]) -async def import_snomed_ct(background_tasks: BackgroundTasks): - background_tasks.add_task(import_snomed_ct_task) +def import_ols_terminology_task(terminology_id, model: str = "sentence-transformers/all-mpnet-base-v2"): + embedding_model = MPNetAdapter(model) + task = OLSTerminologyImportTask(repository, embedding_model, terminology_id, terminology_id) + task.process() + + +@app.put("/import/terminology/snomed", + description="Import whole SNOMED CT from OLS.", + tags=["import", "tasks"]) +async def import_snomed_ct(background_tasks: BackgroundTasks, + model: str = "sentence-transformers/all-mpnet-base-v2"): + background_tasks.add_task(import_snomed_ct_task, model=model) return {"message": "SNOMED CT import started in the background"}