Skip to content

Commit

Permalink
Feat: Add tsne plot to API displaying current db state
Browse files Browse the repository at this point in the history
  • Loading branch information
tiadams committed Mar 11, 2024
1 parent 2841678 commit 5bd5694
Show file tree
Hide file tree
Showing 5 changed files with 144 additions and 11 deletions.
42 changes: 33 additions & 9 deletions index/api/routes.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,31 @@
import json
import logging
from typing import Dict

import uvicorn
from fastapi import FastAPI, HTTPException
from starlette.middleware.cors import CORSMiddleware
from starlette.responses import RedirectResponse
from starlette.responses import RedirectResponse, HTMLResponse

from index.db.model import Terminology, Concept, Mapping
from index.repository.sqllite import SQLLiteRepository
from index.embedding import MPNetAdapter
from index.visualisation import get_html_plot_for_current_database_state

# Module-level singletons shared by all route handlers in this file.
# Logger obtained under the name "uvicorn.info".
logger = logging.getLogger("uvicorn.info")
# Repository opened on a relative database path.
# NOTE(review): a relative path is presumably resolved against the process
# working directory - confirm how the API is launched.
repository = SQLLiteRepository(path="../db/index.db")
embedding_model = MPNetAdapter()
# Cached HTML of the database plot; stays None until serve_visualization()
# fills it on the first request.
db_plot_html = None

app = FastAPI(
title="INDEX",
description="Intelligent data steward toolbox using Large Language Model embeddings "
"for automated Data-Harmonization .",
description="<div id=info-text><h1>Introduction</h1>"
"INDEX uses vector embeddings from variable descriptions to suggest mappings for datasets based on "
"their semantic similarity. Mappings are stored with their vector representations in a knowledge "
"base, where they can be used for subsequent harmonisation tasks, potentially improving the following "
"suggestions with each iteration. Models for the computation as well as databases for storage are "
"meant to be configurable and extendable to adapt the tool for specific use-cases.</div>"
"<div id=db-plot><h1>Current DB state</h1>"
"<p>Showing 2D Visualization of DB entries up to a limit of 1000 entries</p>"
'<a href="/visualization">Click here to view visualization</a></div>',
version="0.0.1",
terms_of_service="https://www.scai.fraunhofer.de/",
contact={
Expand All @@ -25,6 +36,14 @@
"name": "Apache 2.0",
"url": "https://www.apache.org/licenses/LICENSE-2.0.html",
},
openapi_extra={
"info": {
"x-logo": {
"url": "https://example.com/logo.png",
"altText": "Your API Logo"
}
}
},
)

origins = ["*"]
Expand All @@ -37,10 +56,6 @@
allow_headers=["*"],
)

logger = logging.getLogger("uvicorn.info")
repository = SQLLiteRepository()
embedding_model = MPNetAdapter()


@app.get("/", include_in_schema=False)
def swagger_redirect():
Expand All @@ -52,6 +67,14 @@ def get_current_version():
return app.version


@app.get("/visualization", response_class=HTMLResponse, tags=["visualization"])
def serve_visualization():
    """Serve the HTML plot of the database, building it on first use.

    The rendered HTML is kept in the module-level ``db_plot_html`` so the
    plot is only computed while that variable is still empty.

    NOTE(review): nothing in this handler ever clears the cached value, so
    the page keeps showing the plot produced by the first request.
    """
    global db_plot_html
    # Fast path: a previously rendered plot is reused as-is.
    if db_plot_html:
        return db_plot_html
    db_plot_html = get_html_plot_for_current_database_state(repository)
    return db_plot_html


@app.put("/terminologies/{id}", tags=["terminologies"])
async def create_or_update_terminology(id: str, name: str):
try:
Expand Down Expand Up @@ -113,5 +136,6 @@ async def get_closest_mappings_for_text(text: str):
})
return response_data


# Script entry point: when the module is executed directly, serve the app
# with uvicorn on all interfaces (0.0.0.0), port 5000.
if __name__ == "__main__":
uvicorn.run(app, host="0.0.0.0", port=5000)
7 changes: 7 additions & 0 deletions index/repository/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from abc import ABC, abstractmethod

from index.db.model import Mapping


class BaseRepository(ABC):
@abstractmethod
Expand All @@ -12,6 +14,11 @@ def store_all(self, model_object_instances):
"""Store multiple model object instances."""
pass

@abstractmethod
def get_all_mappings(self, limit=1000) -> [Mapping]:
"""Get stored mappings, up to a limit.

:param limit: upper bound on the number of mappings requested
(defaults to 1000).
:return: the mappings, as declared by the return annotation.
"""
pass

@abstractmethod
def get_closest_mappings(self, embedding, limit=5):
"""Get the closest mappings based on embedding."""
Expand Down
14 changes: 12 additions & 2 deletions index/repository/sqllite.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import random
import numpy as np

from typing import Union, List
Expand All @@ -10,9 +11,9 @@

class SQLLiteRepository(BaseRepository):

def __init__(self, mode="disk", name="index"):
def __init__(self, mode="disk", path="index/db/index.db"):
if mode == "disk":
self.engine = create_engine(f'sqlite:///{name}.db')
self.engine = create_engine(f'sqlite:///{path}')
# for tests
elif mode == "memory":
self.engine = create_engine('sqlite:///:memory:')
Expand All @@ -30,6 +31,15 @@ def store_all(self, model_object_instances: List[Union[Terminology, Concept, Map
self.session.add_all(model_object_instances)
self.session.commit()

def get_all_mappings(self, limit=1000):
    """Return a random sample of at most ``limit`` stored mappings.

    The sample is drawn by the database (``ORDER BY random() LIMIT n``)
    instead of guessing primary keys in Python. Picking ids from
    ``range(total_count)`` is only correct when the ids are exactly
    ``0..count-1``; with keys that start at 1 or have gaps (e.g. after
    deletions) it silently returns fewer rows than are available.

    :param limit: maximum number of mappings to return; must not be negative.
    :return: list of ``Mapping`` objects, in random order.
    :raises ValueError: if ``limit`` is negative (same exception type the
        previous ``random.sample`` based version raised).
    """
    if limit < 0:
        raise ValueError("limit must be non-negative")
    # One query: the database shuffles the rows and truncates to `limit`,
    # so the result never depends on how primary keys are numbered.
    return (
        self.session.query(Mapping)
        .order_by(func.random())
        .limit(limit)
        .all()
    )

def get_closest_mappings(self, embedding: List[float], limit=5):
mappings = self.session.query(Mapping).all()
all_embeddings = np.array([mapping.embedding for mapping in mappings])
Expand Down
52 changes: 52 additions & 0 deletions index/scripts/fill_db_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
from index.db.model import Terminology, Concept, Mapping
from index.embedding import MPNetAdapter
from index.repository.sqllite import SQLLiteRepository

# Example script: fills the default repository with one terminology and ten
# concepts, each accompanied by a mapping that carries the embedding of the
# concept text.
repository = SQLLiteRepository()
embedding_model = MPNetAdapter()

terminology = Terminology("snomed CT", "SNOMED")

# (text, string passed as the third Concept argument), in storage order.
EXAMPLE_CONCEPTS = (
    ("Diabetes mellitus (disorder)", "Concept ID: 11893007"),
    ("Hypertension (disorder)", "Concept ID: 73211009"),
    ("Asthma", "Concept ID: 195967001"),
    ("Heart attack", "Concept ID: 22298006"),
    ("Common cold", "Concept ID: 13260007"),
    ("Stroke", "Concept ID: 422504002"),
    ("Migraine", "Concept ID: 386098009"),
    ("Influenza", "Concept ID: 57386000"),
    ("Osteoarthritis", "Concept ID: 399206004"),
    ("Depression", "Concept ID: 386584008"),
)

# The terminology goes first, followed by concept/mapping pairs.
entries = [terminology]
for text, concept_ref in EXAMPLE_CONCEPTS:
    concept = Concept(terminology, text, concept_ref)
    mapping = Mapping(concept, text, embedding_model.get_embedding(text))
    entries.extend((concept, mapping))

repository.store_all(entries)
40 changes: 40 additions & 0 deletions index/visualisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from index.conf import COLORS_AD, COLORS_PD
from index.mapping import MappingTable
from index.repository.base import BaseRepository


class PlotSide(Enum):
Expand Down Expand Up @@ -173,3 +174,42 @@ def scatter_plot_all_cohorts(tables1: [MappingTable], tables2: [MappingTable], l
if store_html:
fig.write_html(store_base_dir + "/tsne_all_cohorts.html")
fig.show()


def get_html_plot_for_current_database_state(repository: BaseRepository, perplexity: int = 5) -> str:
    """Render a 2D t-SNE scatter plot of the repository's mappings as HTML.

    :param repository: source of the mappings; queried through
        ``get_all_mappings()`` with its default limit.
    :param perplexity: t-SNE perplexity used for small data sets; it is
        replaced by 30 as soon as more than 30 embeddings are available.
    :return: an HTML fragment (not a full page) containing the Plotly
        figure, or a short notice when the number of embeddings does not
        exceed the perplexity.
    """
    db_mappings = repository.get_all_mappings()
    vectors = np.array([entry.embedding for entry in db_mappings])
    n_samples = vectors.shape[0]

    # With enough data points, switch to a perplexity of 30.
    if n_samples > 30:
        perplexity = 30

    # The projection is only computed when there are more samples than the
    # perplexity; otherwise a notice is returned instead of a plot.
    if n_samples <= perplexity:
        return '<b>Too few database entries to visualize</b>'

    projected = TSNE(n_components=2, perplexity=perplexity).fit_transform(vectors)

    hover_labels = [str(entry) for entry in db_mappings]
    marker_style = dict(size=8, color='blue', opacity=0.5)
    points = go.Scatter(
        x=projected[:, 0],
        y=projected[:, 1],
        mode='markers',
        marker=marker_style,
        text=hover_labels,
        hoverinfo='text',
    )
    plot_layout = go.Layout(
        title='t-SNE Embeddings of Database Mappings',
        xaxis=dict(title='t-SNE Component 1'),
        yaxis=dict(title='t-SNE Component 2'),
    )
    figure = go.Figure(data=[points], layout=plot_layout)
    # full_html=False yields an embeddable fragment rather than a whole page.
    return figure.to_html(full_html=False)


0 comments on commit 5bd5694

Please sign in to comment.