Skip to content

Commit

Permalink
Scout Recompute (#5)
Browse files Browse the repository at this point in the history
* update strategy flow

* invert score logic for LLM quality

* add retry to dispute request

* add ability to recompute if possible
  • Loading branch information
cloudre01 authored Aug 28, 2024
1 parent f0e385b commit 12a5520
Show file tree
Hide file tree
Showing 10 changed files with 285 additions and 39 deletions.
14 changes: 11 additions & 3 deletions dispute/benchmark.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,26 @@
import dotenv

dotenv.load_dotenv()

import asyncio
from util.chasm import ChasmConnection
from strategy import analyze_text


async def main():
connection = ChasmConnection()
histories = connection.get_benchmark_test()
i = 0
for history in histories:
print(f"--- {i} ---")
result = await analyze_text(history["input"], history["output"])
result = await analyze_text(
history["input"], history["output"], 0, "openai", "gpt-3"
)
print(f"Result: {result}")
print(f"Score: {result['confidence_score']}")
print(f"Assert Check: {'✅' if result['correct'] == history['answer'] else '❌'}")
print(
f"Assert Check: {'✅' if result['correct'] == history['answer'] else '❌'}"
)
print(f"Dispute: {result['dispute']}")
i += 1
pass
Expand All @@ -23,5 +29,7 @@ async def main():
if __name__ == "__main__":
# set TOKENIZERS_PARALLELISM
import os

os.environ["TOKENIZERS_PARALLELISM"] = "true"
asyncio.run(main())
asyncio.run(main())

3 changes: 3 additions & 0 deletions dispute/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,14 @@

LLM_BASE_URL = os.getenv("LLM_BASE_URL")
LLM_API_KEY = os.getenv("LLM_API_KEY")
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
OPENROUTER_API_KEY = os.getenv("OPENROUTER_API_KEY")
MODELS = os.getenv("MODELS", "gemma2-9b-it").split(",")
SIMULATION_MODEL = os.getenv("SIMULATION_MODEL", "gemma2-9b-it")
ORCHESTRATOR_URL = os.getenv("ORCHESTRATOR_URL")
WEBHOOK_API_KEY = os.getenv("WEBHOOK_API_KEY")
MIN_CONFIDENCE_SCORE = float(os.getenv("MIN_CONFIDENCE_SCORE", 0.5))
MIN_RESPONSE_DIFFERENCE = float(os.getenv("MIN_RESPONSE_DIFFERENCE", 0.8))
LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO")

print(f"LOG_LEVEL: {LOG_LEVEL}")
Expand Down
8 changes: 7 additions & 1 deletion dispute/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,13 @@ async def process_histories():
logging.debug(f"Skipping already processed history: {history['_id']}")
continue
output = history["result"]["choices"][0]["message"]["content"]
result = await analyze_text(history["messages"], output)
result = await analyze_text(
history["messages"],
output,
history["seed"],
history["result"]["scout"]["provider"],
history["result"]["scout"]["model"],
)
logging.debug(f"Result: {result}")
logging.debug(f"Score: {result['confidence_score']}")

Expand Down
52 changes: 52 additions & 0 deletions dispute/strategies/ResponseRecompute.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
from util.chasm import Message
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
import asyncio
from typing import List, Tuple
from util.llm import LLMProviders


class ResponseRecomputeAnalysis:
    """Re-run a chat completion and score how far a claimed output is from it.

    The completion is requested again with ``temperature=0`` and the caller's
    ``seed``; the score is ``1 - cosine_similarity`` of the TF-IDF vectors of
    the recomputed text and the claimed text, so 0.0 means the two texts are
    indistinguishable and 1.0 means they share no weighted terms.
    """

    def __init__(self, model: str):
        """Remember the model name and build the provider registry.

        Args:
            model: Model identifier forwarded to the provider on every call.
        """
        self.model = model
        # NOTE(review): LLMProviders is defined in util.llm (not visible here);
        # it is only used via ``self.llm_providers[provider]`` below.
        self.llm_providers = LLMProviders()

    async def analyse(
        self,
        input: List[Message],
        output: str,
        seed: int,
        provider: str = "ollama",
    ) -> Tuple[float, str]:
        """Recompute the response without blocking the event loop.

        Args:
            input: Chat messages that produced ``output``.
            output: The response text being verified.
            seed: Sampling seed to reuse for the recomputation.
            provider: Key used to look up the client in the provider registry.

        Returns:
            ``(difference_score, recomputed_text)`` as produced by
            ``_sync_analyse``.
        """
        # We are always inside a coroutine here, so the running loop is the
        # right one to ask; get_event_loop() is the legacy spelling.
        loop = asyncio.get_running_loop()
        # The provider client call is synchronous, so push it to the default
        # executor instead of stalling other tasks.
        return await loop.run_in_executor(
            None, self._sync_analyse, input, output, seed, provider
        )

    def _sync_analyse(
        self,
        input: List[Message],
        output: str,
        seed: int,
        provider: str,
    ) -> Tuple[float, str]:
        """Blocking implementation of :meth:`analyse`.

        Args:
            input: Chat messages that produced ``output``.
            output: The response text being verified.
            seed: Sampling seed to reuse for the recomputation.
            provider: Key used to look up the client in the provider registry.

        Returns:
            ``(difference_score, recomputed_text)``.

        Raises:
            ValueError: If the provider returned no message content.
            TypeError: If the returned message content is not a string.
        """
        llm = self.llm_providers[provider]
        recompute_output = llm.chat.completions.create(
            messages=input,
            model=self.model,
            temperature=0,
            seed=seed,
        )
        output_response = recompute_output.choices[0].message.content
        # Explicit raises instead of ``assert`` so the checks survive ``-O``.
        if output_response is None:
            raise ValueError("Recompute output is None")
        if not isinstance(output_response, str):
            raise TypeError("Recompute output is not a string")
        score = self.similarity(output_response, output)
        return score, output_response

    def similarity(self, text1: str, text2: str) -> float:
        """Return the TF-IDF cosine *difference* between two texts.

        Despite the name, the value is ``1 - cosine_similarity``: 0.0 for
        texts with identical TF-IDF vectors, 1.0 for texts with nothing in
        common.

        Args:
            text1: First text.
            text2: Second text.

        Returns:
            A plain ``float`` clamped to ``[0.0, 1.0]``.
        """
        try:
            tfidf = TfidfVectorizer().fit_transform([text1, text2])
        except ValueError:
            # Presumably the "empty vocabulary" error: the default tokenizer
            # only keeps tokens of two or more word characters, so answers
            # such as "4" or "A" (or empty strings) leave nothing to
            # vectorize. Fall back to an exact comparison rather than
            # crashing the dispute flow.
            return 0.0 if text1.strip() == text2.strip() else 1.0

        # cosine_similarity accepts the sparse rows directly; no need to
        # densify the matrix first.
        cosine_sim = cosine_similarity(tfidf[0:1], tfidf[1:2])
        difference = 1.0 - float(cosine_sim[0][0])
        # Floating-point rounding can land a hair outside the valid range.
        return min(1.0, max(0.0, difference))
4 changes: 3 additions & 1 deletion dispute/strategies/ResponseSimilarity.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@
from sklearn.metrics.pairwise import cosine_similarity
import asyncio
from typing import List, Tuple
from util.llm import llm
from util.llm import LLMProviders


class ResponseSimilarityAnalysis:
def __init__(self, model: str):
self.model = model
self.provider = LLMProviders()

async def analyze(self, input: List[Message], output: str) -> Tuple[float, str]:
loop = asyncio.get_event_loop()
Expand All @@ -18,6 +19,7 @@ async def analyze(self, input: List[Message], output: str) -> Tuple[float, str]:
return score, output

def _sync_analyse(self, input: List[Message], output: str):
llm = self.provider["ollama"]
simulated_output = llm.chat.completions.create(
messages=input,
model=self.model,
Expand Down
8 changes: 7 additions & 1 deletion dispute/strategies/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from typing import List
from util.chasm import Message


if __name__ == "__main__":
Expand All @@ -14,7 +13,9 @@
from LLMQuality import LLMQualityStrategy
from ResponseSimilarity import ResponseSimilarityAnalysis
from SemanticSimilarity import SemanticSimilarityAnalysis
from ResponseRecompute import ResponseRecomputeAnalysis
from StaticTextAnalysis import StaticTextAnalysisStrategy
from util.chasm import Message

from config import MODELS, SIMULATION_MODEL

Expand Down Expand Up @@ -44,3 +45,8 @@
sta = StaticTextAnalysisStrategy()
sta_result = sta.analyze(output)
print("Static Text Analysis:", sta_result)

# Response Recompute Analysis
rra = ResponseRecomputeAnalysis(model=SIMULATION_MODEL)
rra_result = asyncio.run(rra.analyse(input, output, 42, "groq"))
print("Response Recompute Analysis:", rra_result)
40 changes: 36 additions & 4 deletions dispute/strategy.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,20 @@
from datetime import datetime
from typing import List, Optional, TypedDict
import logging
from config import MIN_CONFIDENCE_SCORE, SIMULATION_MODEL, MODELS

from config import (
GROQ_API_KEY,
MIN_CONFIDENCE_SCORE,
MIN_RESPONSE_DIFFERENCE,
OPENROUTER_API_KEY,
SIMULATION_MODEL,
MODELS,
)
from util.chasm import Message
from strategies.ResponseSimilarity import ResponseSimilarityAnalysis
from strategies.LLMQuality import LLMQualityStrategy
from strategies.StaticTextAnalysis import StaticTextAnalysisStrategy
from strategies.ResponseRecompute import ResponseRecomputeAnalysis


class TextAnalysisResult(TypedDict):
Expand All @@ -17,7 +26,13 @@ class TextAnalysisResult(TypedDict):
rs_output: Optional[str]


async def analyze_text(input: List[Message], output: str) -> TextAnalysisResult:
async def analyze_text(
input: List[Message],
output: str,
seed: int,
provider: str,
model: str,
) -> TextAnalysisResult:
"""
Analyze the given input and output texts to determine scores and dispute status.
Expand Down Expand Up @@ -62,8 +77,6 @@ async def analyze_text(input: List[Message], output: str) -> TextAnalysisResult:

# 3. Response
rs_strategy = ResponseSimilarityAnalysis(model=SIMULATION_MODEL)

# Gather scores concurrently
rs_result = await rs_strategy.analyze(input, output)
rs_score, rs_output = rs_result

Expand All @@ -77,6 +90,25 @@ async def analyze_text(input: List[Message], output: str) -> TextAnalysisResult:

dispute = confidence_score > MIN_CONFIDENCE_SCORE

recomputable_provider = True if provider in ["groq", "openrouter"] else False

# Try to recompute the output
if GROQ_API_KEY and OPENROUTER_API_KEY and dispute and recomputable_provider:
rr_strategy = ResponseRecomputeAnalysis(model=model)
print("Dispute detected. Trying to recompute...")
recompute_result = await rr_strategy.analyse(input, output, seed, provider)
if recompute_result:
rs_score, _ = recompute_result
if rs_score < MIN_RESPONSE_DIFFERENCE:
return {
"llm_score": llm_score,
"rs_score": rs_score,
"rs_output": rs_output,
"confidence_score": confidence_score,
"dispute": dispute,
"correct": False,
}

return {
"llm_score": llm_score,
"rs_score": rs_score,
Expand Down
Loading

0 comments on commit 12a5520

Please sign in to comment.