Skip to content

Commit

Permalink
Merge pull request #47 from SCAI-BIO/add-file-upload
Browse files Browse the repository at this point in the history
Add file upload
  • Loading branch information
tiadams authored Jul 4, 2024
2 parents 519999d + 7ece780 commit f708538
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 46 deletions.
43 changes: 0 additions & 43 deletions .github/workflows/python-publish.yaml

This file was deleted.

54 changes: 51 additions & 3 deletions datastew/api/routes.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import logging
import os
import tempfile

import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi import FastAPI, HTTPException, File, UploadFile
from starlette.middleware.cors import CORSMiddleware
from starlette.responses import RedirectResponse, HTMLResponse

from datastew import DataDictionarySource
from datastew.repository.model import Terminology, Concept, Mapping
from datastew.embedding import MPNetAdapter
from datastew.repository.sqllite import SQLLiteRepository
Expand Down Expand Up @@ -155,10 +158,10 @@ async def create_or_update_mapping(concept_id: str, text: str):


@app.post("/mappings", tags=["mappings"])
async def get_closest_mappings_for_text(text: str):
async def get_closest_mappings_for_text(text: str, limit: int = 5):
embedding = embedding_model.get_embedding(text).tolist()
print(embedding)
closest_mappings, similarities = repository.get_closest_mappings(embedding)
closest_mappings, similarities = repository.get_closest_mappings(embedding, limit)
response_data = []
for mapping, similarity in zip(closest_mappings, similarities):
concept = mapping.concept
Expand All @@ -178,5 +181,50 @@ async def get_closest_mappings_for_text(text: str):
return response_data


# Endpoint to get mappings for a data dictionary source
@app.post("/mappings/dict", tags=["mappings"], description="Get mappings for a data dictionary source.")
async def get_closest_mappings_for_dictionary(file: UploadFile = File(...), variable_field: str = 'variable', description_field: str = 'description'):
    """Return the closest stored mappings for each variable of an uploaded data dictionary.

    The upload is spooled to a temporary file (keeping its original extension)
    so that ``DataDictionarySource`` can read it from a path. For every row of
    the resulting dataframe the description is embedded and the repository is
    queried with ``limit=5``.

    Args:
        file: Uploaded data dictionary file.
        variable_field: Name of the column holding the variable names.
        description_field: Name of the column holding the descriptions.

    Returns:
        A dict mapping each variable to a list of entries with the keys
        ``concept`` (id, name, terminology), ``text`` and ``similarity``.

    Raises:
        HTTPException: With status 500 and the error message as detail if
            reading, parsing, embedding or querying fails.
    """
    tmp_file_path = None
    try:
        # The client may omit the filename; fall back to "no extension"
        # instead of failing inside os.path.splitext.
        _, file_extension = os.path.splitext(file.filename or "")
        # delete=False because the file is re-opened by path after this
        # handle is closed; it is removed explicitly in the finally clause.
        with tempfile.NamedTemporaryFile(delete=False, suffix=file_extension) as tmp_file:
            tmp_file_path = tmp_file.name
            tmp_file.write(await file.read())

        # Initialize DataDictionarySource with the temporary file path
        data_dict_source = DataDictionarySource(file_path=tmp_file_path, variable_field=variable_field, description_field=description_field)
        df = data_dict_source.to_dataframe()

        response = {}
        for _, row in df.iterrows():
            # NOTE(review): the columns are read as 'variable'/'description'
            # regardless of variable_field/description_field; presumably
            # to_dataframe() normalizes the column names - confirm.
            variable = row['variable']
            description = row['description']
            embedding = embedding_model.get_embedding(description)
            closest_mappings, similarities = repository.get_closest_mappings(embedding, limit=5)
            mappings_list = []
            for mapping, similarity in zip(closest_mappings, similarities):
                concept = mapping.concept
                terminology = concept.terminology
                mappings_list.append({
                    "concept": {
                        "id": concept.id,
                        "name": concept.name,
                        "terminology": {
                            "id": terminology.id,
                            "name": terminology.name
                        }
                    },
                    "text": mapping.text,
                    "similarity": similarity
                })
            response[variable] = mappings_list

        return response
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        # Always clean up the temporary file, including on failure; the
        # cleanup itself is best-effort so it never masks the real outcome.
        if tmp_file_path is not None:
            try:
                os.remove(tmp_file_path)
            except OSError:
                logging.warning("Could not remove temporary file %s", tmp_file_path)

# Script entry point: when this module is executed directly, serve the app
# with uvicorn on all interfaces (0.0.0.0) at port 5000.
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=5000)

0 comments on commit f708538

Please sign in to comment.