Skip to content

Commit

Permalink
Merge pull request #55 from SCAI-BIO/add-python-tests
Browse files Browse the repository at this point in the history
add simple python test to check if app starts
  • Loading branch information
tiadams authored Aug 1, 2024
2 parents 565919d + b4c9b11 commit daa39bf
Show file tree
Hide file tree
Showing 4 changed files with 84 additions and 38 deletions.
13 changes: 5 additions & 8 deletions .github/workflows/python-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,27 +8,24 @@ on:

jobs:
build:

runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
python-version: ["3.9", "3.10", "3.11"]

steps:
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
- name: Set up Python 3.11
uses: actions/setup-python@v3
with:
python-version: ${{ matrix.python-version }}
python-version: "3.11"
- name: Install dependencies
run: |
cd api
python -m pip install --upgrade pip
python -m pip install flake8 pytest
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
- name: Lint with flake8
run: |
cd api
# stop the build if there are Python syntax errors or undefined names
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
27 changes: 17 additions & 10 deletions api/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,11 +1,18 @@
tzdata==2023.3
wheel==0.37.1
aiofiles~=0.7.0
uvicorn>=0.15.0
python-multipart
fastapi~=0.87.0
starlette~=0.21.0
pytest~=8.2.2
httpx~=0.27.0
uvicorn~=0.30.1
fastapi~=0.111.0
starlette~=0.37.2
datastew~=0.3.3
numpy~=1.25.2
pandas~=2.1.0
requests~=2.31.0
pydantic~=1.10.14
fastapi_keycloak
datastew==0.2.0
scikit-learn~=1.3.2
SQLAlchemy~=2.0.31
openai~=0.28.1
thefuzz~=0.20.0
scipy~=1.11.4
seaborn~=0.13.2
matplotlib~=3.8.4
plotly~=5.17.0
packaging~=24.1
65 changes: 47 additions & 18 deletions api/routes.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,21 @@
import logging
import os
import tempfile
import time

import uvicorn
from datastew import DataDictionarySource, BaseRepository
from datastew.process.ols import OLSTerminologyImportTask
from datastew.repository import WeaviateRepository

from fastapi import FastAPI, HTTPException, File, UploadFile
from starlette.background import BackgroundTasks
from starlette.middleware.cors import CORSMiddleware
from starlette.responses import RedirectResponse, HTMLResponse

from datastew import DataDictionarySource
from datastew.repository.model import Terminology, Concept, Mapping
from datastew.embedding import MPNetAdapter
from datastew.repository.sqllite import SQLLiteRepository
from datastew.visualisation import get_html_plot_for_current_database_state

logger = logging.getLogger("uvicorn.info")
repository = SQLLiteRepository(mode="disk", path="snomed.db")
embedding_model = MPNetAdapter()
db_plot_html = None
from datastew.visualisation import get_plot_for_current_database_state

app = FastAPI(
title="INDEX",
Expand Down Expand Up @@ -50,6 +48,23 @@
},
)


def connect_to_remote_weaviate_repository():
retries = 5
for i in range(retries):
try:
return WeaviateRepository(mode="remote", path="http://weaviate:8080")
except Exception as e:
logger.info(f"Attempt {i + 1} failed: {e}")
time.sleep(5)
raise ConnectionError("Could not connect to Weaviate after multiple attempts.")


logger = logging.getLogger("uvicorn.info")
repository = connect_to_remote_weaviate_repository()
embedding_model = MPNetAdapter()
db_plot_html = None

origins = ["*"]

app.add_middleware(
Expand All @@ -75,14 +90,14 @@ def get_current_version():
def serve_visualization():
global db_plot_html
if not db_plot_html:
db_plot_html = get_html_plot_for_current_database_state(repository)
db_plot_html = get_plot_for_current_database_state(repository)
return db_plot_html


@app.patch("/visualization", tags=["visualization"])
def update_visualization():
global db_plot_html
db_plot_html = get_html_plot_for_current_database_state(repository)
db_plot_html = get_plot_for_current_database_state(repository)
return {"message": "DB visualization plot has been updated successfully"}


Expand Down Expand Up @@ -161,16 +176,15 @@ async def create_or_update_mapping(concept_id: str, text: str):
@app.post("/mappings", tags=["mappings"])
async def get_closest_mappings_for_text(text: str, limit: int = 5):
embedding = embedding_model.get_embedding(text).tolist()
print(embedding)
closest_mappings, similarities = repository.get_closest_mappings(embedding, limit)
closest_mappings = repository.get_closest_mappings_with_similarities(embedding, limit)
mappings = []
for mapping, similarity in zip(closest_mappings, similarities):
for mapping, similarity in closest_mappings:
concept = mapping.concept
terminology = concept.terminology
mappings.append({
"concept": {
"id": concept.id,
"name": concept.name,
"id": concept.concept_identifier,
"name": concept.pref_label,
"terminology": {
"id": terminology.id,
"name": terminology.name
Expand All @@ -179,12 +193,14 @@ async def get_closest_mappings_for_text(text: str, limit: int = 5):
"text": mapping.text,
"similarity": similarity
})

return mappings


# Endpoint to get mappings for a data dictionary source
@app.post("/mappings/dict", tags=["mappings"], description="Get mappings for a data dictionary source.")
async def get_closest_mappings_for_dictionary(file: UploadFile = File(...), variable_field: str = 'variable', description_field: str = 'description'):
async def get_closest_mappings_for_dictionary(file: UploadFile = File(...), variable_field: str = 'variable',
description_field: str = 'description'):
try:
# Determine file extension and create a temporary file with the correct extension
_, file_extension = os.path.splitext(file.filename)
Expand All @@ -193,7 +209,8 @@ async def get_closest_mappings_for_dictionary(file: UploadFile = File(...), vari
tmp_file_path = tmp_file.name

# Initialize DataDictionarySource with the temporary file path
data_dict_source = DataDictionarySource(file_path=tmp_file_path, variable_field=variable_field, description_field=description_field)
data_dict_source = DataDictionarySource(file_path=tmp_file_path, variable_field=variable_field,
description_field=description_field)
df = data_dict_source.to_dataframe()

response = []
Expand All @@ -208,7 +225,7 @@ async def get_closest_mappings_for_dictionary(file: UploadFile = File(...), vari
terminology = concept.terminology
mappings_list.append({
"concept": {
"id": concept.id,
"id": concept.concept_id,
"name": concept.name,
"terminology": {
"id": terminology.id,
Expand All @@ -231,5 +248,17 @@ async def get_closest_mappings_for_dictionary(file: UploadFile = File(...), vari
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))


def import_snomed_ct_task():
task = OLSTerminologyImportTask(repository, embedding_model, "SNOMED CT", "snomed")
task.process()


@app.put("/import/terminology/snomed", description="Import whole SNOMED CT from OLS.", tags=["import", "tasks"])
async def import_snomed_ct(background_tasks: BackgroundTasks):
background_tasks.add_task(import_snomed_ct_task)
return {"message": "SNOMED CT import started in the background"}


if __name__ == "__main__":
uvicorn.run(app, host="0.0.0.0", port=5000)
17 changes: 15 additions & 2 deletions docker-compose.local.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
version: "3.12"

services:

frontend:
image: index-client
build:
Expand All @@ -18,4 +17,18 @@ services:
context: ./api
dockerfile: Dockerfile
ports:
- "5000:80"
- "5000:80"
depends_on:
- weaviate

weaviate:
image: cr.weaviate.io/semitechnologies/weaviate:1.24.20
ports:
- "8080:8080"
volumes:
- weaviate_data:/var/lib/weaviate

volumes:

weaviate_data:
driver: local

0 comments on commit daa39bf

Please sign in to comment.