Skip to content

Commit

Permalink
feat(search): similarity search using situation
Browse files Browse the repository at this point in the history
  • Loading branch information
eksno committed Nov 18, 2024
1 parent 4f426a2 commit e2a762d
Show file tree
Hide file tree
Showing 2 changed files with 107 additions and 33 deletions.
15 changes: 10 additions & 5 deletions services/api/src/lib/few_shot.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
from typing import Literal
from langchain_core.prompts import FewShotPromptTemplate, PromptTemplate
from langchain_core.example_selectors import SemanticSimilarityExampleSelector
from langchain_openai import OpenAIEmbeddings
Expand All @@ -16,8 +17,14 @@ class StrippedCritique(BaseModel):
instructions: str


# Name of the critique field that similarity search runs against: it is passed
# as the example selector's input key and as the key of the lookup dict.
SimilarityKey = Literal["query"] | Literal["situation"] | Literal["context"]


def find_relevant_critiques(
critiques: list[StrippedCritique], query: str, k: int = 4
critiques: list[StrippedCritique],
similarity: str,
k: int = 4,
similarity_key: SimilarityKey = "query",
) -> list[StrippedCritique]:
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
if OPENAI_API_KEY is None:
Expand All @@ -26,18 +33,16 @@ def find_relevant_critiques(
OPENAI_API_KEY = SecretStr(OPENAI_API_KEY)

embeddings = OpenAIEmbeddings(api_key=OPENAI_API_KEY)
print("1")

example_selector = SemanticSimilarityExampleSelector.from_examples(
[critique.model_dump() for critique in critiques],
embeddings,
InMemoryVectorStore,
k=k,
input_keys=["query"],
input_keys=[similarity_key],
)
print("2")

return [
StrippedCritique(**critique)
for critique in example_selector.select_examples({"query": query})
for critique in example_selector.select_examples({similarity_key: similarity})
]
125 changes: 97 additions & 28 deletions services/api/src/routers/critiques.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import traceback
import logging
from functools import wraps
from typing import Annotated, cast
from typing import Annotated, Literal, cast
import urllib.parse
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
from langchain_openai.chat_models import ChatOpenAI
Expand All @@ -16,6 +16,7 @@
from src.lib import auth, validators as vd

from src.lib.few_shot import (
SimilarityKey,
find_relevant_critiques,
StrippedCritique,
)
Expand Down Expand Up @@ -59,6 +60,38 @@ async def wrapper(*args, **kwargs):
return wrapper


def generate_situation(model: ChatOpenAI, context: str) -> str:
    """Ask ``model`` to summarize ``context`` as a short, generic situation.

    The model is bound to a structured-output schema with a single
    ``situation`` string field, the context is sent wrapped in
    ``<context>`` tags, and the ``situation`` field of the reply is
    logged at INFO level and returned.

    Args:
        model: Chat model used to produce the structured reply.
        context: Text the situation is deduced from.

    Returns:
        The ``situation`` string from the model's structured reply.
    """

    class Situation(BaseModel):
        # Class name, field name and description are kept verbatim because
        # this class is what gets handed to `with_structured_output`.
        situation: str = Field(
            description="A ~10 word description of the situation from the context and query. The situation should be generic such that it's similarly worded to others since it's used for similarity search."
        )

    # The message is already fully rendered, so the template is invoked
    # with no variables.
    message = HumanMessage(
        content=(
            "\n"
            "<context>\n"
            f"{context}\n"
            "</context>\n"
            "Please deduce the situation from the context provided.\n"
        )
    )
    prompt_value = ChatPromptTemplate([message]).invoke({})

    structured_model = model.with_structured_output(Situation)
    reply = cast(Situation, structured_model.invoke(prompt_value))

    logging.info(f"critiques: generate_situation: {reply.situation}")

    return reply.situation


@router.get("/ids")
def get_critique_ids() -> list[str]:
supabase = db.client()
Expand All @@ -71,8 +104,10 @@ class GetCritiquesQuery(BaseModel):
environment_name: str
workflow_name: str | None = None
agent_name: str | None = None
context: str | None = None
query: str | None = None
k: int | None = None
similarity_key: SimilarityKey = "query"


class GetCritiquesResult(BaseModel):
Expand All @@ -85,6 +120,7 @@ class GetCritiquesResult(BaseModel):
async def list_critiques(
x_critino_key: Annotated[str, Header()],
query: Annotated[GetCritiquesQuery, Depends(GetCritiquesQuery)],
x_openrouter_api_key: Annotated[str | None, Header()],
) -> GetCritiquesResult:
logging.info(f"list_critiques: x_critino_key: {x_critino_key} - params: {query}")

Expand Down Expand Up @@ -154,7 +190,36 @@ async def list_critiques(
for critique in response.data
]

relevant_critiques = find_relevant_critiques(critiques, query.query, k=query.k)
if query.similarity_key == "situation":
model = (
llm.chat_open_router(
model="anthropic/claude-3-5-haiku-20241022:beta",
api_key=x_openrouter_api_key,
)
if x_openrouter_api_key
else None
)

if not model:
raise HTTPException(
status_code=400,
detail="'similarity_key' is set to 'situation' but no model is available to generate the situation.",
)

# Both halves must be parenthesized: a conditional expression binds looser
# than `+`, so the unparenthesized form dropped the query whenever a context
# was given, and raised TypeError when the context was empty and the query
# was None.
context = (query.context + "\n" if query.context else "") + (
    query.query if query.query else ""
)
situation = generate_situation(model, context)

relevant_critiques = find_relevant_critiques(
critiques, situation, k=query.k, similarity_key=query.similarity_key
)

return GetCritiquesResult(
data=relevant_critiques, count=len(relevant_critiques)
)

relevant_critiques = find_relevant_critiques(
critiques, query.query, k=query.k, similarity_key=query.similarity_key
)

return GetCritiquesResult(data=relevant_critiques, count=len(relevant_critiques))

Expand Down Expand Up @@ -191,10 +256,20 @@ def generate_Fields(
attempts: int = 3,
messages: list[BaseMessage] = [],
) -> FilledBody:
context = (body.context + "\n" if body.context else "") + (
body.query if body.query else ""
)
situation = generate_situation(model, context)
filled_body = FilledBody(
query=body.query,
context=body.context,
response=body.response,
optimal=body.optimal,
instructions=body.instructions,
situation=situation,
)

class Populate(BaseModel):
situation: str = Field(
description="A ~10 word description of the situation from the context and query. The situation should be generic such that it's similarly worded to others since it's used for similarity search."
)
chain_of_thought: str = Field(
description="This is your reasoning, use it to evaluate the current information given. Especially the context and original response 'response'. Evaluate how the response was optimized 'optimal'. Always start this field with `Let's think step by step. `"
)
Expand All @@ -217,7 +292,7 @@ class Populate(BaseModel):
HumanMessage(
content=f"""
Fields and context:
{body.model_dump_json(indent=4)}
{filled_body.model_dump_json(indent=4)}
Please deduce the missing fields.
Do NOT change the fields already present.
Expand All @@ -230,23 +305,15 @@ class Populate(BaseModel):

agent = model.with_structured_output(Populate)

filled_body = FilledBody(
query=body.query,
context=body.context,
response=body.response,
optimal=body.optimal,
instructions=body.instructions,
situation="",
)

for attempt in range(attempts):
result = cast(
Populate,
agent.invoke(prompt.invoke({"msgs": messages})),
)
if (body.instructions != "" and result.instructions != body.instructions) or (
body.optimal != "" and result.optimal != body.optimal
):
if (
filled_body.instructions != ""
and result.instructions != filled_body.instructions
) or (filled_body.optimal != "" and result.optimal != filled_body.optimal):
messages.append(
AIMessage(name="populator", content=result.model_dump_json(indent=4))
)
Expand All @@ -258,22 +325,24 @@ class Populate(BaseModel):
continue

if query.populate_missing:
body.instructions = (
result.instructions if result.instructions else body.instructions
filled_body.instructions = (
result.instructions if result.instructions else filled_body.instructions
)
# When the model returns no `optimal`, keep the existing optimal text; the
# previous fallback copied the instructions into `optimal` by mistake.
filled_body.optimal = (
    result.optimal if result.optimal else filled_body.optimal
)
body.optimal = result.optimal if result.optimal else body.instructions

filled_body = FilledBody(
query=body.query,
context=body.context,
response=body.response,
optimal=body.optimal,
instructions=body.instructions,
situation=result.situation,
query=filled_body.query,
context=filled_body.context,
response=filled_body.response,
optimal=filled_body.optimal,
instructions=filled_body.instructions,
situation=filled_body.situation,
)

logging.info(
f"critiques: generate_fields: attempt {attempt + 1}: (situation: {result.situation}, instructions: {result.instructions}, optimal: {result.optimal})"
f"critiques: generate_fields: attempt {attempt + 1}: (instructions: {result.instructions}, optimal: {result.optimal})"
)

return filled_body
Expand Down

0 comments on commit e2a762d

Please sign in to comment.