Skip to content

Commit

Permalink
Merge pull request #1 from SCAI-BIO/mpnet-implementation
Browse files Browse the repository at this point in the history
MPNet implementation
  • Loading branch information
tiadams authored Feb 15, 2024
2 parents c708854 + 3456fea commit b1f500a
Show file tree
Hide file tree
Showing 13 changed files with 1,274 additions and 349 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -161,4 +161,5 @@ cython_debug/
#.idea/

gptstew/.env!/gptstew/resources/
.idea
.idea
.vscode
31 changes: 25 additions & 6 deletions index/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,29 @@
BIOFIND_DICT_SRC = "resources/dictionaries/pd/biofind.csv"
BIOFIND_EMBEDDINGS_SRC = "resources/embeddings/biofind.csv"

COLORS_AD = {'adni': '#d62728', 'aibl': '#ff7f0e', 'emif': '#8c564b', 'jadni': '#7f7f7f',
'a4': '#aec7e8', 'dod-adni': '#ffbb78', 'prevent-ad': '#98df8a', 'arwibo': '#ff9896',
'i-adni': '#c5b0d5', 'edsd': '#c49c94', 'pharmacog': '#c7c7c7',
'vita': '#bcbd22', 'abvib': '#e0d9e2', 'ad-mapper': '#800000'}
COLORS_AD = {
"adni": "#d62728",
"aibl": "#ff7f0e",
"emif": "#8c564b",
"jadni": "#7f7f7f",
"a4": "#aec7e8",
"dod-adni": "#ffbb78",
"prevent-ad": "#98df8a",
"arwibo": "#ff9896",
"i-adni": "#c5b0d5",
"edsd": "#c49c94",
"pharmacog": "#c7c7c7",
"vita": "#bcbd22",
"abvib": "#e0d9e2",
"ad-mapper": "#800000",
}

COLORS_PD = {'opdc': '#1f77b4', 'tpd': '#e377c2', 'biofind': '#9edae5', 'lrrk2': '#f7b6d2', 'luxpark': '#2ca02c',
'ppmi': '#9467bd', 'passionate': '#00ff00'}
COLORS_PD = {
"opdc": "#1f77b4",
"tpd": "#e377c2",
"biofind": "#9edae5",
"lrrk2": "#f7b6d2",
"luxpark": "#2ca02c",
"ppmi": "#9467bd",
"passionate": "#00ff00",
}
9 changes: 7 additions & 2 deletions index/db/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@ def to_dataframe(self):

class Variable:

def __init__(self, name: str, description: str, source: str, embedding: Embedding = None):
def __init__(
self, name: str, description: str, source: str, embedding: Embedding = None
):
self.name = name
self.description = description
self.source = source
Expand All @@ -42,7 +44,10 @@ def __init__(self, concept: Concept, variable: Variable, source: str):
self.source = source

def __eq__(self, other):
return self.concept.identifier == other.concept.identifier and self.variable.name == other.variable.name
return (
self.concept.identifier == other.concept.identifier
and self.variable.name == other.variable.name
)

def __hash__(self):
return hash((self.concept.identifier, self.variable.name))
Expand Down
30 changes: 25 additions & 5 deletions index/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
from abc import ABC
import numpy as np
import openai
from sentence_transformers import SentenceTransformer


class EmbeddingModel(ABC):

def get_embedding(self, text: str) -> [float]:
pass

Expand All @@ -14,7 +14,6 @@ def get_embeddings(self, messages: [str]) -> [[float]]:


class GPT4Adapter(EmbeddingModel):

def __init__(self, api_key: str):
self.api_key = api_key
openai.api_key = api_key
Expand All @@ -28,19 +27,40 @@ def get_embedding(self, text: str, model="text-embedding-ada-002"):
return None
if isinstance(text, str):
text = text.replace("\n", " ")
return openai.Embedding.create(input=[text], model=model)['data'][0]['embedding']
return openai.Embedding.create(input=[text], model=model)["data"][0]["embedding"]
except Exception as e:
logging.error(f"Error getting embedding for {text}: {e}")
return None

def get_embeddings(self, messages: [str], model="text-embedding-ada-002"):
# store index of nan entries
response = openai.Embedding.create(input=messages, model=model)
return [item['embedding'] for item in response['data']]
return [item["embedding"] for item in response["data"]]


class TextEmbedding:
class MPNetAdapter(EmbeddingModel):
    """Embedding adapter that encodes text with a sentence-transformers model.

    Loaded ``SentenceTransformer`` instances are cached per model name on the
    adapter, so repeated calls (and ``get_embeddings``) do not re-instantiate
    the model for every single text.
    """

    def __init__(self):
        # Kept from the original behaviour: raise the root logger to INFO so
        # the progress messages emitted below are visible.
        logging.getLogger().setLevel(logging.INFO)
        # Maps model name -> loaded SentenceTransformer instance.
        self._models = {}

    def _load_model(self, model: str):
        """Return the cached model for *model*, instantiating it on first use."""
        # setdefault keeps this working even if __init__ was bypassed.
        cache = self.__dict__.setdefault("_models", {})
        if model not in cache:
            cache[model] = SentenceTransformer(model)
        return cache[model]

    def get_embedding(self, text: str, model="sentence-transformers/all-mpnet-base-v2"):
        """Return the embedding of *text*, or ``None`` if it cannot be computed.

        :param text: Text to embed; newlines are replaced by spaces first.
        :param model: Name passed to ``SentenceTransformer`` to select the model.
        :return: Whatever ``SentenceTransformer.encode`` returns for the text
            (presumably a numpy vector -- confirm against the library docs),
            or ``None`` for missing/empty text or when encoding fails.
        :raises Exception: Errors raised while instantiating the model are not
            caught here and propagate to the caller.
        """
        logging.info("Getting embedding for %s", text)
        # Treat None, "" and any float NaN (not only the np.nan singleton, e.g.
        # values coming out of pandas) as missing text.
        is_nan = isinstance(text, float) and np.isnan(text)
        if text is None or is_nan or (isinstance(text, str) and text == ""):
            logging.warning("Empty text passed to get_embedding")
            return None
        # Loaded outside the try block on purpose: a broken model name should
        # fail loudly instead of silently yielding None for every text.
        mpnet_model = self._load_model(model)
        try:
            if isinstance(text, str):
                text = text.replace("\n", " ")
            return mpnet_model.encode(text)
        except Exception as e:
            logging.error(f"Error getting embedding for {text}: {e}")
            return None

    def get_embeddings(
        self, messages: [str], model="sentence-transformers/all-mpnet-base-v2"
    ) -> [[float]]:
        """Return one embedding per entry of *messages*, in the same order.

        Entries that are missing/empty or fail to encode yield ``None`` at
        their position (see ``get_embedding``).

        :param messages: Texts to embed.
        :param model: Name passed to ``SentenceTransformer`` to select the model.
        """
        return [self.get_embedding(msg, model=model) for msg in messages]


class TextEmbedding:
def __init__(self, text: str, embedding: [float]):
self.text = text
self.embedding = embedding
Loading

0 comments on commit b1f500a

Please sign in to comment.