diff --git a/.github/workflows/python-tests.yml b/.github/workflows/python-tests.yml index be6fc71..7fdf23f 100644 --- a/.github/workflows/python-tests.yml +++ b/.github/workflows/python-tests.yml @@ -8,27 +8,24 @@ on: jobs: build: - runs-on: ubuntu-latest - strategy: - fail-fast: false - matrix: - python-version: ["3.9", "3.10", "3.11"] steps: - uses: actions/checkout@v3 - - name: Set up Python ${{ matrix.python-version }} + - name: Set up Python 3.11 uses: actions/setup-python@v3 with: - python-version: ${{ matrix.python-version }} + python-version: "3.11" - name: Install dependencies run: | + cd api python -m pip install --upgrade pip python -m pip install flake8 pytest if [ -f requirements.txt ]; then pip install -r requirements.txt; fi - name: Lint with flake8 run: | + cd api # stop the build if there are Python syntax errors or undefined names flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide - flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics + flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics \ No newline at end of file diff --git a/api/requirements.txt b/api/requirements.txt index 8140a5f..23e4b01 100644 --- a/api/requirements.txt +++ b/api/requirements.txt @@ -1,11 +1,18 @@ -tzdata==2023.3 -wheel==0.37.1 -aiofiles~=0.7.0 -uvicorn>=0.15.0 -python-multipart -fastapi~=0.87.0 -starlette~=0.21.0 +pytest~=8.2.2 +httpx~=0.27.0 +uvicorn~=0.30.1 +fastapi~=0.111.0 +starlette~=0.37.2 +datastew~=0.3.3 +numpy~=1.25.2 +pandas~=2.1.0 requests~=2.31.0 -pydantic~=1.10.14 -fastapi_keycloak -datastew==0.2.0 \ No newline at end of file +scikit-learn~=1.3.2 +SQLAlchemy~=2.0.31 +openai~=0.28.1 +thefuzz~=0.20.0 +scipy~=1.11.4 +seaborn~=0.13.2 +matplotlib~=3.8.4 +plotly~=5.17.0 +packaging~=24.1 \ No newline at end of file diff --git a/api/routes.py b/api/routes.py index 1934c2c..6713d50 100644 --- a/api/routes.py +++ b/api/routes.py @@ -1,23 +1,21 @@ import logging import os import tempfile +import time import uvicorn +from datastew import DataDictionarySource, BaseRepository +from datastew.process.ols import OLSTerminologyImportTask +from datastew.repository import WeaviateRepository from fastapi import FastAPI, HTTPException, File, UploadFile +from starlette.background import BackgroundTasks from starlette.middleware.cors import CORSMiddleware from starlette.responses import RedirectResponse, HTMLResponse -from datastew import DataDictionarySource from datastew.repository.model import Terminology, Concept, Mapping from datastew.embedding import MPNetAdapter -from datastew.repository.sqllite import SQLLiteRepository -from datastew.visualisation import get_html_plot_for_current_database_state - -logger = logging.getLogger("uvicorn.info") -repository = SQLLiteRepository(mode="disk", path="snomed.db") -embedding_model = MPNetAdapter() -db_plot_html = None +from datastew.visualisation import get_plot_for_current_database_state app = FastAPI( title="INDEX", @@ -50,6 +48,23 @@ }, ) + +def connect_to_remote_weaviate_repository(): + retries = 5 + for i in range(retries): + try: + return WeaviateRepository(mode="remote", path="http://weaviate:8080") + except Exception as e: + logger.info(f"Attempt {i + 1} failed: {e}") + time.sleep(5) + raise ConnectionError("Could not connect to Weaviate after multiple attempts.") + + +logger = logging.getLogger("uvicorn.info") +repository = connect_to_remote_weaviate_repository() +embedding_model = MPNetAdapter() +db_plot_html = None + origins = ["*"] app.add_middleware( @@ -75,14 +90,14 @@ def get_current_version(): def serve_visualization(): global db_plot_html if not db_plot_html: - db_plot_html = get_html_plot_for_current_database_state(repository) + db_plot_html = get_plot_for_current_database_state(repository) return db_plot_html @app.patch("/visualization", tags=["visualization"]) def update_visualization(): global db_plot_html - db_plot_html = get_html_plot_for_current_database_state(repository) + db_plot_html = get_plot_for_current_database_state(repository) return {"message": "DB visualization plot has been updated successfully"} @@ -161,16 +176,15 @@ async def create_or_update_mapping(concept_id: str, text: str): @app.post("/mappings", tags=["mappings"]) async def get_closest_mappings_for_text(text: str, limit: int = 5): embedding = embedding_model.get_embedding(text).tolist() - print(embedding) - closest_mappings, similarities = repository.get_closest_mappings(embedding, limit) + closest_mappings = repository.get_closest_mappings_with_similarities(embedding, limit) mappings = [] - for mapping, similarity in zip(closest_mappings, similarities): + for mapping, similarity in closest_mappings: concept = mapping.concept terminology = concept.terminology mappings.append({ "concept": { - "id": concept.id, - "name": concept.name, + "id": concept.concept_identifier, + "name": concept.pref_label, "terminology": { "id": terminology.id, "name": terminology.name @@ -179,12 +193,14 @@ async def get_closest_mappings_for_text(text: str, limit: int = 5): "text": mapping.text, "similarity": similarity }) + return mappings # Endpoint to get mappings for a data dictionary source @app.post("/mappings/dict", tags=["mappings"], description="Get mappings for a data dictionary source.") -async def get_closest_mappings_for_dictionary(file: UploadFile = File(...), variable_field: str = 'variable', description_field: str = 'description'): +async def get_closest_mappings_for_dictionary(file: UploadFile = File(...), variable_field: str = 'variable', + description_field: str = 'description'): try: # Determine file extension and create a temporary file with the correct extension _, file_extension = os.path.splitext(file.filename) @@ -193,7 +209,8 @@ async def get_closest_mappings_for_dictionary(file: UploadFile = File(...), vari tmp_file_path = tmp_file.name # Initialize DataDictionarySource with the temporary file path - data_dict_source = DataDictionarySource(file_path=tmp_file_path, variable_field=variable_field, description_field=description_field) + data_dict_source = DataDictionarySource(file_path=tmp_file_path, variable_field=variable_field, + description_field=description_field) df = data_dict_source.to_dataframe() response = [] @@ -208,7 +225,7 @@ async def get_closest_mappings_for_dictionary(file: UploadFile = File(...), vari terminology = concept.terminology mappings_list.append({ "concept": { - "id": concept.id, + "id": concept.concept_id, "name": concept.name, "terminology": { "id": terminology.id, @@ -231,5 +248,17 @@ async def get_closest_mappings_for_dictionary(file: UploadFile = File(...), vari except Exception as e: raise HTTPException(status_code=500, detail=str(e)) + +def import_snomed_ct_task(): + task = OLSTerminologyImportTask(repository, embedding_model, "SNOMED CT", "snomed") + task.process() + + +@app.put("/import/terminology/snomed", description="Import whole SNOMED CT from OLS.", tags=["import", "tasks"]) +async def import_snomed_ct(background_tasks: BackgroundTasks): + background_tasks.add_task(import_snomed_ct_task) + return {"message": "SNOMED CT import started in the background"} + + if __name__ == "__main__": uvicorn.run(app, host="0.0.0.0", port=5000) diff --git a/docker-compose.local.yaml b/docker-compose.local.yaml index 09d1ff2..9b442b1 100644 --- a/docker-compose.local.yaml +++ b/docker-compose.local.yaml @@ -1,7 +1,6 @@ version: "3.12" services: - frontend: image: index-client build: @@ -18,4 +17,18 @@ services: context: ./api dockerfile: Dockerfile ports: - - "5000:80" \ No newline at end of file + - "5000:80" + depends_on: + - weaviate + + weaviate: + image: cr.weaviate.io/semitechnologies/weaviate:1.24.20 + ports: + - "8080:8080" + volumes: + - weaviate_data:/var/lib/weaviate + +volumes: + + weaviate_data: + driver: local \ No newline at end of file