Merge pull request #16 from krflorian/feature/chroma_search
Feature/chroma search
krflorian authored Jul 27, 2024
2 parents 7a0459d + 5fb0556 commit 8cca44e
Showing 20 changed files with 2,878 additions and 380 deletions.
21 changes: 1 addition & 20 deletions README.md
@@ -15,7 +15,7 @@ http://127.0.0.1:8000/docs

## SETUP

-At the moment there are two Vector Databases that have to be filled before the dataservice can start workting.
+At the moment there are two Vector Databases that have to be filled before the dataservice can start working.

1. Rules DB
- includes data relevant for understanding the game:
@@ -29,20 +29,6 @@ At the moment there are two Vector Databases that have to be filled before the dataservice

To fill the databases there are scripts in the folder `src/etl`; every script beginning with `create_` creates data as JSON files that can then be vectorized and inserted into the corresponding database. For vectorizing the data we currently use the open-source model [gte-large](https://huggingface.co/thenlper/gte-large) from Hugging Face. A minimal sketch of this step follows.
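A minimal sketch of the vectorize-and-insert step, assuming the card JSON files carry at least a `name` field (as `app.py` implies); the paths and collection name come from the config in this PR, but the loop itself is illustrative, not the exact `create_card_db.py` logic:

```python
# Hypothetical sketch, not the exact create_card_db.py logic:
# embed each card with gte-large and insert it into a persistent Chroma collection.
import json
from pathlib import Path

import chromadb
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("thenlper/gte-large")  # or the local copy in data/models
client = chromadb.PersistentClient(path="../data/chromadb")
collection = client.get_or_create_collection("cards")

for file in Path("../data/etl/processed/cards").iterdir():
    card = json.loads(file.read_text())
    collection.add(
        ids=[card["name"]],
        embeddings=[model.encode(card["name"]).tolist()],
        metadatas=[{"name": card["name"]}],
    )
```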

To speed up vectorization, the model can be placed on a GPU.

https://docs.rapids.ai/install
```shell
conda create --solver=libmamba -n rapids-24.04 -c rapidsai-nightly -c conda-forge -c nvidia \
python=3.11 cuda-version=12.0 \
pytorch
conda init
conda activate rapids-24.04

poetry shell
python src/etl/create_card_db.py
```

## Development

@@ -51,8 +37,3 @@
```shell
poetry install
poetry shell
```

To build the container, download the three necessary Hugging Face models to `data/models` (a possible download snippet follows the list):
- gte-large
- hallucination_evaluation_model
- bart-large-mnli
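One possible way to fetch them, sketched with `huggingface_hub`; the first two repo ids appear in this repo's code and config, while `facebook/bart-large-mnli` is assumed from the model's usual Hub name:

```python
# Hypothetical download helper; bart-large-mnli's repo id is assumed from its
# usual Hugging Face name, the other two ids appear in this repo.
from huggingface_hub import snapshot_download

for repo_id in [
    "thenlper/gte-large",
    "vectara/hallucination_evaluation_model",
    "facebook/bart-large-mnli",
]:
    snapshot_download(repo_id, local_dir=f"data/models/{repo_id.split('/')[-1]}")
```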

210 changes: 124 additions & 86 deletions app.py
@@ -1,40 +1,68 @@
import logging
import uvicorn
from pathlib import Path
from typing import Optional
import difflib

from fastapi import FastAPI
from pydantic import BaseModel, Field
from sentence_transformers import SentenceTransformer, CrossEncoder

from mtg.objects import Card, Document
from mtg.vector_db import VectorDB
from mtg.util import load_config
from mtg.hallucination import validate_answer
from mtg.chroma.config import ChromaConfig
from mtg.chroma.chroma_db import ChromaDB, CollectionType
from mtg.util import load_config, read_json_file
from mtg.url_parsing import parse_card_names

config: dict = load_config(Path("configs/config.yaml"))
vector_model: SentenceTransformer = SentenceTransformer(
    config.get("vector_model_name", "thenlper/gte-large")
)
hallucination_model: CrossEncoder = CrossEncoder(
    config.get("halucination_model_name", "vectara/hallucination_evaluation_model")
)

db: dict[str, VectorDB] = {
    "card": VectorDB.load(config.get("cards_db_file", None)),
    "rule": VectorDB.load(config.get("rules_db_file", None)),
}

chroma_config = ChromaConfig(**config["CHROMA"])
db = ChromaDB(chroma_config)

cards_collection = db.get_collection(CollectionType.CARDS)
documents_collection = db.get_collection(CollectionType.DOCUMENTS)


# load card data
# TODO -> make class CardDB and load from config
cards_folder = Path("../data/etl/processed/cards")
data = []
for file in cards_folder.iterdir():
    data.append(read_json_file(file))

cards = [Card(**d) for d in data]
card_name_2_card = {card.name: card for card in cards}

all_keywords, all_legalities = set(), set()

for card in data:
    for keyword in card["keywords"]:
        all_keywords.add(keyword)
    for legality in card["legalities"]:
        if card["legalities"][legality] == "legal":
            all_legalities.add(legality)

all_keywords = list(all_keywords)
all_legalities = list(all_legalities)
all_color_identities = {"W", "U", "B", "R", "G"}

# load rules data
docs_folder = Path("../data/etl/processed/documents")
documents = []
for file in docs_folder.iterdir():
    data = read_json_file(file)
    for doc in data:
        documents.append(Document(**doc))
document_name_2_document = {doc.name: doc for doc in documents}

# app
app = FastAPI()


# Interface
## Rules


class RulesRequest(BaseModel):
    text: str
    k: int = Field(default=5)
    threshold: float = Field(default=0.2)
    lasso_threshold: float = Field(default=0.02)


class GetRulesResponse(BaseModel):
@@ -43,102 +71,112 @@ class GetRulesResponse(BaseModel):


## Cards


class CardsRequest(BaseModel):
    text: str
    k: int = Field(default=5)
    keywords: list[str] = Field(default_factory=list)
    color_identity: list[str] = Field(default_factory=list)
    legality: Optional[str] = Field(default=None)
    k: int = Field(default=10)
    threshold: float = Field(default=0.4)
    lasso_threshold: float = Field(default=0.1)
    sample_results: bool = Field(default=False)


class GetCardsResponse(BaseModel):
    card: Card
    distance: float


## Halucination
class CardNameRequest(BaseModel):
    card_name: str


class HalucinationRequest(BaseModel):
class CardParseRequest(BaseModel):
    text: str
    chunks: list[str]


class HalucinationResponse(BaseModel):
    chunk: str
    score: float


## NLI
class CardParseResponse(BaseModel):
    text: str


class NLIClassificationRequest(BaseModel):
    text: str
# Routes
@app.get("/card_name/{card_name}")
async def search_card(card_name: str) -> GetCardsResponse:
    card = card_name_2_card.get(card_name, None)
    return GetCardsResponse(card=card, distance=0.0)

class NLIClassificationResponse(BaseModel):
    intent: str
    score: float

@app.post("/parse_card_urls/")
async def parse_card_urls(request: CardParseRequest) -> CardParseResponse:
    text = parse_card_names(request.text, card_name_2_card=card_name_2_card)
    return CardParseResponse(text=text)


@app.post("/cards/")
async def get_cards(request: CardsRequest) -> list[GetCardsResponse]:
    # when sampling retrieve more cards
    if request.sample_results:
        k = request.k * 2
    else:
        k = request.k

    # query database
    query_result = db["card"].query(
        text=request.text,
        k=k,
        threshold=request.threshold,
        lasso_threshold=request.lasso_threshold,
        model=vector_model,
    )
    if request.sample_results:
        query_result = db["card"].sample_results(query_result, request.k)

    return [
        GetCardsResponse(card=result[0], distance=result[1]) for result in query_result
    ]
    # TODO sampling
    # TODO code should be in class CardDB
    # create query
    query = {"query_texts": [request.text], "n_results": request.k, "where": {}}

    # keywords
    for search_term in request.keywords:
        matches = difflib.get_close_matches(search_term, all_keywords, n=1)
        if matches:
            query["where"][f"keyword_{matches[0]}"] = True
        else:
            logging.info(f"did not find keyword: {search_term}")

    # legalities
    if request.legality is not None:
        matches = difflib.get_close_matches(request.legality, all_legalities, n=1)
        if matches:
            query["where"][f"{matches[0]}_legal"] = True
        else:
            logging.info(f"did not find legality: {request.legality}")

    # color identity
    for color in request.color_identity:
        if color.upper() in all_color_identities:
            query["where"][f"color_identity_{color.upper()}"] = True

    if len(query["where"]) > 1:
        query["where"] = {
            "$and": [{key: value} for key, value in query["where"].items()]
        }

    # query
    results = cards_collection.query(**query)

    response = []
    for distance, metadata in zip(results["distances"][0], results["metadatas"][0]):
        if distance <= request.threshold:
            response.append(
                GetCardsResponse(
                    card=card_name_2_card.get(metadata["name"], None), distance=distance
                )
            )

    return response
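Chroma accepts only a single top-level operator in a `where` filter, which is why the route wraps multiple metadata conditions in `$and`. As an illustration, for a request with `keywords=["flying"]`, `legality="modern"`, and `color_identity=["u"]`, the filter built above would look roughly like this (the exact keyword and legality strings depend on what the card data contains):

```python
# Illustrative filter only; the exact keys depend on the keywords and
# legalities present in the card data.
where = {
    "$and": [
        {"keyword_Flying": True},    # "flying" fuzzy-matched against all_keywords
        {"modern_legal": True},      # "modern" fuzzy-matched against all_legalities
        {"color_identity_U": True},  # upper-cased and checked against WUBRG
    ]
}
# cards_collection.query(query_texts=["counter target spell"], n_results=10, where=where)
```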


@app.post("/rules/")
async def get_rules(request: RulesRequest) -> list[GetRulesResponse]:
    # query database
    query_result = db["rule"].query(
        text=request.text,
        k=request.k,
        threshold=request.threshold,
        lasso_threshold=request.lasso_threshold,
        model=vector_model,
    )
    # filter unique documents
    documents = []
    for doc, distance in query_result:
        if doc not in documents:
            documents.append((doc, distance))
    return [
        GetRulesResponse(document=doc, distance=distance) for doc, distance in documents
    ]


@app.post("/hallucination/")
async def validate_rag_chunks(
    request: HalucinationRequest,
) -> list[HalucinationResponse]:
    scores = validate_answer(request.text, request.chunks, model=hallucination_model)
    return [
        HalucinationResponse(chunk=chunk, score=score)
        for chunk, score in zip(request.chunks, scores)
    ]

    query = {"query_texts": [request.text], "n_results": request.k}
    results = documents_collection.query(**query)

    response = []
    for distance, metadata in zip(results["distances"][0], results["metadatas"][0]):
        if distance <= request.threshold:
            response.append(
                GetRulesResponse(
                    document=document_name_2_document.get(metadata["name"], None),
                    distance=distance,
                )
            )

    return response
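For a quick smoke test of the two query routes once the Chroma collections are filled, something like this should work (the query strings are made up):

```python
# Hypothetical smoke test; assumes the app module is importable and the
# Chroma collections have been filled by the ETL scripts.
from fastapi.testclient import TestClient

from app import app

client = TestClient(app)

cards = client.post(
    "/cards/", json={"text": "destroy all creatures", "legality": "commander", "k": 3}
)
for hit in cards.json():
    print(hit["card"]["name"], hit["distance"])

rules = client.post("/rules/", json={"text": "how does deathtouch work?", "k": 3})
for hit in rules.json():
    print(hit["document"]["name"], hit["distance"])
```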


if __name__ == "__main__":
12 changes: 7 additions & 5 deletions configs/config.yaml
@@ -1,7 +1,9 @@


cards_db_file: "../data/artifacts/card_db_gte.p"
rules_db_file: "../data/artifacts/rules_db_gte.p"
vector_model_name: "../data/models/gte-large"
halucination_model_name: "../data/models/hallucination_evaluation_model"
nli_classifier: "../data/models/deberta-v3-large-zeroshot-v1.1-all-33"
CHROMA:
  host: "../data/chromadb"
  port: "8000"
  embedding_model: "../data/models/gte-large"
  embedding_device: "cpu"
  collection_name_documents: "documents"
  collection_name_cards: "cards"
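A plausible shape for `ChromaConfig`, inferred only from these YAML keys and from the `ChromaConfig(**config["CHROMA"])` call in `app.py`; the actual class in `mtg/chroma/config.py` may differ:

```python
# Hypothetical reconstruction of mtg/chroma/config.py based on the YAML keys above;
# the real class may name or type these fields differently.
from pydantic import BaseModel


class ChromaConfig(BaseModel):
    host: str  # path for a persistent client, or hostname for a Chroma server
    port: str
    embedding_model: str  # local path or Hub id of the sentence-transformer
    embedding_device: str = "cpu"
    collection_name_documents: str = "documents"
    collection_name_cards: str = "cards"
```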
3 changes: 3 additions & 0 deletions mtg/chroma/__init__.py
@@ -0,0 +1,3 @@
from .chroma_db import ChromaDB
from .document import ChromaDocument
from .config import ChromaConfig
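Taken together with `app.py`, the intended entry point looks roughly like this; only the imports above and the calls visible in `app.py` are confirmed, the rest is a sketch:

```python
# Usage sketch based on the calls visible in app.py; note that CollectionType
# lives in mtg.chroma.chroma_db and is not re-exported by the package __init__.
from pathlib import Path

from mtg.chroma import ChromaConfig, ChromaDB
from mtg.chroma.chroma_db import CollectionType
from mtg.util import load_config

config = load_config(Path("configs/config.yaml"))
db = ChromaDB(ChromaConfig(**config["CHROMA"]))

cards = db.get_collection(CollectionType.CARDS)
print(cards.query(query_texts=["draw two cards"], n_results=3))
```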