Skip to content

Commit

Permalink
Create embedding model dynamically
Browse files Browse the repository at this point in the history
  • Loading branch information
tiadams committed Sep 6, 2024
1 parent e33976b commit 8eb88a5
Showing 1 changed file with 38 additions and 14 deletions.
52 changes: 38 additions & 14 deletions api/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,14 +54,13 @@ def connect_to_remote_weaviate_repository():
try:
return WeaviateRepository(mode="remote", path=weaviate_url)
except Exception as e:
logger.info(f"Attempt {i + 1} failed: {e}")
logger.info(f"Attempt {i + 1} to connect to {weaviate_url} failed with error: {e}")
time.sleep(5)
raise ConnectionError("Could not connect to Weaviate after multiple attempts.")


logger = logging.getLogger("uvicorn.info")
repository = connect_to_remote_weaviate_repository()
embedding_model = MPNetAdapter()
db_plot_html = None

origins = ["*"]
Expand Down Expand Up @@ -119,7 +118,7 @@ async def create_terminology(id: str, name: str):
return {"message": f"Terminology {id} created successfully"}
except Exception as e:
raise HTTPException(status_code=400, detail=f"Failed to create terminology: {str(e)}")


@app.get("/models", tags=["models"])
async def get_all_models():
Expand Down Expand Up @@ -151,11 +150,16 @@ async def get_all_mappings():


@app.put("/concepts/{id}/mappings", tags=["concepts", "mappings"])
async def create_concept_and_attach_mapping(id: str, concept_name: str, terminology_name, text: str):
async def create_concept_and_attach_mapping(id: str,
concept_name: str,
terminology_name,
text: str,
model: str = "sentence-transformers/all-mpnet-base-v2"):
try:
terminology = repository.get_terminology(terminology_name)
concept = Concept(terminology=terminology, pref_label=concept_name, concept_identifier=id)
repository.store(concept)
embedding_model = MPNetAdapter(model)
embedding = embedding_model.get_embedding(text)
model_name = embedding_model.get_model_name()
mapping = Mapping(concept=concept, text=text, embedding=embedding, sentence_embedder=model_name)
Expand All @@ -166,9 +170,12 @@ async def create_concept_and_attach_mapping(id: str, concept_name: str, termino


@app.put("/mappings/", tags=["mappings"])
async def create_mapping(concept_id: str, text: str):
async def create_mapping(concept_id: str,
text: str,
model: str = "sentence-transformers/all-mpnet-base-v2"):
try:
concept = repository.get_concept(concept_id)
embedding_model = MPNetAdapter(model)
embedding = embedding_model.get_embedding(text)
model_name = embedding_model.get_model_name()
mapping = Mapping(concept=concept, text=text, embedding=embedding, sentence_embedder=model_name)
Expand All @@ -179,10 +186,14 @@ async def create_mapping(concept_id: str, text: str):


@app.post("/mappings", tags=["mappings"])
async def get_closest_mappings_for_text(text: str, terminology_name: str = "SNOMED CT",
sentence_embedder: str = "sentence-transformers/all-mpnet-base-v2", limit: int = 5):
async def get_closest_mappings_for_text(text: str,
terminology_name: str = "SNOMED CT",
model: str = "sentence-transformers/all-mpnet-base-v2",
limit: int = 5):
embedding_model = MPNetAdapter(model)
embedding = embedding_model.get_embedding(text).tolist()
closest_mappings = repository.get_terminology_and_model_specific_closest_mappings(embedding, terminology_name, sentence_embedder, limit)
closest_mappings = repository.get_terminology_and_model_specific_closest_mappings(embedding, terminology_name,
model, limit)
mappings = []
for mapping, similarity in closest_mappings:
concept = mapping.concept
Expand All @@ -205,8 +216,10 @@ async def get_closest_mappings_for_text(text: str, terminology_name: str = "SNOM

# Endpoint to get mappings for a data dictionary source
@app.post("/mappings/dict", tags=["mappings"], description="Get mappings for a data dictionary source.")
async def get_closest_mappings_for_dictionary(file: UploadFile = File(...), variable_field: str = 'variable',
description_field: str = 'description'):
async def get_closest_mappings_for_dictionary(file: UploadFile = File(...),
variable_field: str = 'variable',
description_field: str = 'description',
model: str = "sentence-transformers/all-mpnet-base-v2"):
try:
# Determine file extension and create a temporary file with the correct extension
_, file_extension = os.path.splitext(file.filename)
Expand All @@ -223,6 +236,7 @@ async def get_closest_mappings_for_dictionary(file: UploadFile = File(...), vari
for _, row in df.iterrows():
variable = row['variable']
description = row['description']
embedding_model = MPNetAdapter(model)
embedding = embedding_model.get_embedding(description)
closest_mappings, similarities = repository.get_closest_mappings(embedding, limit=5)
mappings_list = []
Expand Down Expand Up @@ -255,14 +269,24 @@ async def get_closest_mappings_for_dictionary(file: UploadFile = File(...), vari
raise HTTPException(status_code=500, detail=str(e))


def import_snomed_ct_task():
def import_snomed_ct_task(model: str = "sentence-transformers/all-mpnet-base-v2"):
    """Run the SNOMED CT import using an embedder built for ``model``.

    Builds an ``MPNetAdapter`` from the given model name, passes it together
    with the module-level ``repository`` to an ``OLSTerminologyImportTask``
    and calls the task's ``process()``.

    :param model: model name handed unchanged to ``MPNetAdapter``.
    """
    embedder = MPNetAdapter(model)
    # NOTE(review): "SNOMED CT" / "snomed" look like the terminology's display
    # name and its OLS identifier -- confirm against OLSTerminologyImportTask.
    OLSTerminologyImportTask(repository, embedder, "SNOMED CT", "snomed").process()


@app.put("/import/terminology/snomed", description="Import whole SNOMED CT from OLS.", tags=["import", "tasks"])
async def import_snomed_ct(background_tasks: BackgroundTasks):
background_tasks.add_task(import_snomed_ct_task)
def import_ols_terminology_task(terminology_id, model: str = "sentence-transformers/all-mpnet-base-v2"):
    """Run an OLS terminology import for ``terminology_id``.

    Builds an ``MPNetAdapter`` from the given model name, passes it together
    with the module-level ``repository`` to an ``OLSTerminologyImportTask``
    and calls the task's ``process()``.

    :param terminology_id: value forwarded to the import task; it is used for
        both of the task's terminology arguments.
    :param model: model name handed unchanged to ``MPNetAdapter``.
    """
    embedder = MPNetAdapter(model)
    # The same identifier fills both terminology slots of the task.
    # NOTE(review): presumably these are the stored name and the OLS id --
    # confirm against OLSTerminologyImportTask.
    OLSTerminologyImportTask(
        repository,
        embedder,
        terminology_id,
        terminology_id,
    ).process()


# PUT /import/terminology/snomed
# Queues import_snomed_ct_task on the request's BackgroundTasks and answers
# with a confirmation message; the import itself is not awaited here.
@app.put(
    "/import/terminology/snomed",
    description="Import whole SNOMED CT from OLS.",
    tags=["import", "tasks"],
)
async def import_snomed_ct(
    background_tasks: BackgroundTasks,
    model: str = "sentence-transformers/all-mpnet-base-v2",
):
    # `model` is forwarded by keyword so the task builds its embedder from the
    # name the caller supplied.
    background_tasks.add_task(import_snomed_ct_task, model=model)
    response = {"message": "SNOMED CT import started in the background"}
    return response


Expand Down

0 comments on commit 8eb88a5

Please sign in to comment.