Skip to content

Commit

Permalink
fix: model selection
Browse files (browse the repository at this point in the history)
  • Loading branch information
mehmetcanay committed Sep 16, 2024
1 parent 651326d commit 08d20cd
Showing 1 changed file with 19 additions and 10 deletions.
29 changes: 19 additions & 10 deletions api/routes.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -174,10 +174,19 @@ async def create_mapping(concept_id: str, text: str):


@app.post("/mappings", tags=["mappings"])
async def get_closest_mappings_for_text(text: str, terminology_name: str = "SNOMED CT",
sentence_embedder: str = "sentence-transformers/all-mpnet-base-v2", limit: int = 5):
async def get_closest_mappings_for_text(text: str,
terminology_name: str = "SNOMED CT",
model: str = "sentence-transformers/all-mpnet-base-v2",
limit: int = 5):
if model == "text-embedding-ada-002":
embedding_model = GPT4Adapter(model)
elif model == "sentence-transformers/all-mpnet-base-v2":
embedding_model = MPNetAdapter(model)
else:
raise HTTPException(status_code=400, detail="Unsupported embedding model.")

embedding = embedding_model.get_embedding(text).tolist()
closest_mappings = repository.get_terminology_and_model_specific_closest_mappings(embedding, terminology_name, sentence_embedder, limit)
closest_mappings = repository.get_terminology_and_model_specific_closest_mappings(embedding, terminology_name, model, limit)
mappings = []
for mapping, similarity in closest_mappings:
concept = mapping.concept
Expand All @@ -201,15 +210,15 @@ async def get_closest_mappings_for_text(text: str, terminology_name: str = "SNOM
# Endpoint to get mappings for a data dictionary source
@app.post("/mappings/dict", tags=["mappings"], description="Get mappings for a data dictionary source.")
async def get_closest_mappings_for_dictionary(file: UploadFile = File(...),
selected_model: str = "sentence-transformers/all-mpnet-base-v2",
selected_terminology: str = "SNOMED CT",
model: str = "sentence-transformers/all-mpnet-base-v2",
terminology_name: str = "SNOMED CT",
variable_field: str = "variable",
description_field: str = "description"):
try:
if selected_model == "text-embedding-ada-002":
embedding_model = GPT4Adapter(selected_model)
elif selected_model == "sentence-transformers/all-mpnet-base-v2":
embedding_model = MPNetAdapter(selected_model)
if model == "text-embedding-ada-002":
embedding_model = GPT4Adapter(model)
elif model == "sentence-transformers/all-mpnet-base-v2":
embedding_model = MPNetAdapter(model)
else:
raise HTTPException(status_code=400, detail="Unsupported embedding model.")

Expand All @@ -234,7 +243,7 @@ async def get_closest_mappings_for_dictionary(file: UploadFile = File(...),
description = row['description']
embedding = embedding_model.get_embedding(description)
closest_mappings = repository.get_terminology_and_model_specific_closest_mappings(
embedding, selected_terminology, selected_model, limit=5
embedding, terminology_name, model, limit=5
)
mappings_list = []
for mapping, similarity in closest_mappings:
Expand Down

0 comments on commit 08d20cd

Please sign in to comment.