Skip to content

Commit

Permalink
Merge pull request #25 from krflorian/feature/cards_query
Browse files Browse the repository at this point in the history
upgraded cards query
  • Loading branch information
krflorian authored Dec 7, 2024
2 parents fd3e075 + 5e877c1 commit 7a6194a
Showing 1 changed file with 42 additions and 16 deletions.
58 changes: 42 additions & 16 deletions app.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,9 @@ class GetRulesResponse(BaseModel):
class CardsRequest(BaseModel):
text: str
keywords: list[str] = Field(default_factory=list)
color_identity: list[str] = Field(default_factory=list)
color_identity: list[str] = Field(
default_factory=list, description="Can be a single character of WUBRG"
)
legality: Optional[str] = Field(default=None)
k: int = Field(default=20)
threshold: float = Field(default=0.4)
Expand Down Expand Up @@ -185,37 +187,61 @@ async def parse_cards(request: CardParseRequest) -> CardParseResponse:

@app.post("/cards", tags=["Cards"])
async def get_cards(request: CardsRequest) -> list[GetCardsResponse]:
# TODO sampling
# TODO code should be in class CardDB
# create query
query = {"query_texts": [request.text], "n_results": request.k, "where": {}}
query = {"query_texts": [request.text], "n_results": request.k, "where": []}

# keywords
# adding keywords as or combinations
keyword_query = {}
for search_term in request.keywords:
matches = difflib.get_close_matches(search_term, app.all_keywords, n=1)
matches = difflib.get_close_matches(
search_term, app.all_keywords, n=1, cutoff=0.8
)
print("keyword matches", matches)
if matches:
query["where"][f"keyword_{matches[0]}"] = True
keyword_query[f"keyword_{matches[0]}"] = True
else:
logging.info(f"did not find keyword: {query}")

# legalities
if len(keyword_query) > 1:
query["where"].append(
{"$or": [{key: value} for key, value in keyword_query.items()]}
)
elif keyword_query:
query["where"].append(keyword_query)

# adding legalities as and (must be legal in...)
if request.legality is not None:
matches = difflib.get_close_matches(request.legality, app.all_legalities, n=1)
if matches:
query["where"][f"{matches[0]}_legal"] = True
query["where"].append({f"{matches[0]}_legal": True})
else:
logging.info(f"did not find legality: {query}")

# color identity
# color identity (must have color identity...)
color_mapper = {"black": "B", "green": "G", "blue": "U", "white": "W", "red": "R"}
colors = []
for color in request.color_identity:
if color.upper() in app.all_color_identities:
query["where"][f"color_identity_{color.upper()}"] = True

colors.append(color.upper())
else:
color = color_mapper.get(color)
if color:
colors.append(color)
if len(colors) > 1:
query["where"].append(
{"$or": [{f"color_identity_{color}": True} for color in colors]}
)
elif colors:
query["where"].append({f"color_identity_{colors[0]}": True})

# combining where statements
if len(query["where"]) > 1:
query["where"] = {
"$and": [{key: value} for key, value in query["where"].items()]
}
query["where"] = {"$and": query["where"]}
elif len(query["where"]) == 1:
query["where"] = query["where"][0]
else:
del query["where"]

print("query", query)
# query
cards_collection = app.db.get_collection(CollectionType.CARDS)
results = cards_collection.query(**query)
Expand Down

0 comments on commit 7a6194a

Please sign in to comment.