[CHA-99] Scout Outlier Detection (#3)
* update strategy flow

* inverse score logic for llm quality

* add retry to dispute request
cloudre01 authored Aug 23, 2024
1 parent 912bd49 commit f0e385b
Showing 8 changed files with 112 additions and 60 deletions.
27 changes: 16 additions & 11 deletions dispute/main.py
@@ -7,52 +7,57 @@
 import logging
 
 logging.basicConfig(
-    level=LOG_LEVEL,
-    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
-    datefmt='%Y-%m-%d %H:%M:%S'
+    level=LOG_LEVEL,
+    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
+    datefmt="%Y-%m-%d %H:%M:%S",
 )
 
 chasm = ChasmConnection()
 PROCESSED_HISTORIES_FILE = "processed_histories.json"
 
 
 def load_processed_histories():
     if os.path.exists(PROCESSED_HISTORIES_FILE):
         with open(PROCESSED_HISTORIES_FILE, "r") as f:
             return set(json.load(f))
     return set()
 
 
 def save_processed_histories(histories):
     with open(PROCESSED_HISTORIES_FILE, "w") as f:
         json.dump(list(histories), f)
 
 
 processed_histories = load_processed_histories()
 
 
 async def process_histories():
     histories = chasm.get_prompt_history()
     logging.info(f"Histories: {len(histories)}")
     for history in histories:
         if history["_id"] in processed_histories:
             logging.debug(f"Skipping already processed history: {history['_id']}")
             continue
-        input = map(lambda x: x["content"], history["messages"])
-        input = "\n".join(input)
         output = history["result"]["choices"][0]["message"]["content"]
-        result = await analyze_text(input, output)
+        result = await analyze_text(history["messages"], output)
         logging.debug(f"Result: {result}")
         logging.debug(f"Score: {result['confidence_score']}")
 
         if result["confidence_score"] > MIN_CONFIDENCE_SCORE:
+            rs_output = result.get("rs_output")
+            assert rs_output is not None, "rs_output is not generated"
             response = chasm.file_dispute(
                 history["_id"],
                 history["messages"],
-                history["result"]["choices"][0]["message"],
+                {"role": "assistant", "content": rs_output},
             )
-            logging.info("Dispute filed: ", response)
+
+            if response is not None:
+                logging.info(f"Dispute filed: {response['result']}")
 
         # Cache history
         processed_histories.add(history["_id"])
         save_processed_histories(processed_histories)
 
 
 async def main():
     while True:
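
The shape of the records returned by get_prompt_history() is not part of this diff; the sketch below is inferred purely from the fields process_histories() reads above, so any field names beyond those are unknown and the values are made up.

# Hypothetical prompt-history record, inferred only from the accesses above.
history = {
    "_id": "some-unique-id",  # used as the cache key
    "messages": [{"role": "user", "content": "What is the capital of France?"}],
    "result": {"choices": [{"message": {"role": "assistant", "content": "Paris"}}]},
}

# processed_histories.json is just a JSON array of already-handled ids, e.g.
# ["id-1", "id-2"], since save_processed_histories() dumps list(histories).
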
23 changes: 18 additions & 5 deletions dispute/strategies/LLMQuality.py
@@ -1,12 +1,19 @@
 import re
 import logging
+from typing import List
+from util.chasm import Message
 from util.poll import PollAlgo
 from strategies.tokenizer import Tokenizer
 
 
 class LLMQualityStrategy:
-    def __init__(self, models):
+    """
+    LLM Quality Strategy
+    Higher score means the model is more likely to be good.
+    Lower score means the model is more likely to be bad.
+    """
+
+    def __init__(self, models):
         # llama3 tokenizer
         self.tokenizer = Tokenizer()
         self.poll = PollAlgo(models)
@@ -15,8 +22,11 @@ def format_text_limit(self, text: str, limit: int):
         encoded_result = self.tokenizer.encode(text, bos=False, eos=False)
         return self.tokenizer.decode(encoded_result[:limit])
 
-    async def analyze(self, input: str, output: str):
-        input = self.format_text_limit(input, 3000)
+    async def analyze(self, input: List[Message], output: str) -> float:
+        text_input = map(lambda x: x["content"], input)
+        text_input = "\n".join(text_input)
+
+        formatted_input = self.format_text_limit(text_input, 3000)
         output = self.format_text_limit(output, 3000)
 
         # Prompt from https://arxiv.org/abs/2306.05685v4
@@ -48,7 +58,10 @@ async def analyze(self, input: str, output: str):
                     """,
                 },
                 {"role": "assistant", "content": "False"},
-                {"role": "user", "content": f"Question: {input}\nAnswer: {output}"},
+                {
+                    "role": "user",
+                    "content": f"Question: {formatted_input}\nAnswer: {output}",
+                },
             ],
         )
         logging.debug(results)
@@ -73,4 +86,4 @@ async def analyze(self, input: str, output: str):
                 llm_results.append(0.5)
         logging.debug(f"LLM Results: {llm_results}")
         score = sum(llm_results) / len(llm_results)
-        return score
+        return 1 - score
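
The "inverse score logic" mentioned in the commit message is the final line above: the averaged poll result is now flipped before analyze() returns. A minimal illustration with hypothetical poll values (not taken from a real run):

# Hypothetical parsed judgements from the polled models (values are made up).
llm_results = [1, 1, 0, 0.5]
score = sum(llm_results) / len(llm_results)  # 0.625
print(1 - score)  # analyze() now returns 0.375 instead of 0.625
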
26 changes: 15 additions & 11 deletions dispute/strategies/ResponseSimilarity.py
@@ -1,31 +1,35 @@
+from util.chasm import Message
 from sklearn.feature_extraction.text import TfidfVectorizer
 from sklearn.metrics.pairwise import cosine_similarity
 import asyncio
+from typing import List, Tuple
 from util.llm import llm
 
 
 class ResponseSimilarityAnalysis:
     def __init__(self, model: str):
         self.model = model
 
-    async def analyze(self, input: str, output: str):
+    async def analyze(self, input: List[Message], output: str) -> Tuple[float, str]:
         loop = asyncio.get_event_loop()
-        score = await loop.run_in_executor(None, self._sync_analyse, input, output)
-        return score
-
-    def _sync_analyse(self, input: str, output: str):
+        score, output = await loop.run_in_executor(
+            None, self._sync_analyse, input, output
+        )
+        return score, output
+
+    def _sync_analyse(self, input: List[Message], output: str):
         simulated_output = llm.chat.completions.create(
-            messages=[
-                {"role": "system", "content": "You are a helpful assistant"},
-                {"role": "user", "content": input},
-            ],
+            messages=input,
             model=self.model,
-            temperature=0.8,
+            temperature=0,
         )
         simulated_output_response = simulated_output.choices[0].message.content
         assert simulated_output_response is not None, "Simulated output is None"
         assert isinstance(
             simulated_output_response, str
         ), "Simulated output is not a string"
         score = self.similarity(simulated_output_response, output)
-        return score
+        return score, simulated_output_response
 
     def similarity(self, text1: str, text2: str) -> float:
         vectorizer = TfidfVectorizer().fit_transform([text1, text2])
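
The rest of similarity() is collapsed in this view. Given the TfidfVectorizer and cosine_similarity imports at the top of the file, a typical completion would look like the following sketch; this is an assumption, not the exact committed code.

from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity

def similarity(text1: str, text2: str) -> float:
    # Vectorize both texts with TF-IDF, then compare them with cosine similarity.
    vectors = TfidfVectorizer().fit_transform([text1, text2])
    return float(cosine_similarity(vectors[0:1], vectors[1:2])[0][0])
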
20 changes: 11 additions & 9 deletions dispute/strategies/__init__.py
@@ -1,9 +1,12 @@
+from typing import List
+from util.chasm import Message
+
+
 if __name__ == "__main__":
     import sys
     import os
-    sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
+
+    sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
 
     os.environ["LOG_LEVEL"] = "DEBUG"
 
@@ -15,13 +18,13 @@
 
     from config import MODELS, SIMULATION_MODEL
 
-    input = "What is the capital of France?"
+    input: List[Message] = [
+        {"role": "user", "content": "What is the capital of France?"}
+    ]
     output = "Paris"
 
     # LLM Quality
-    lq = LLMQualityStrategy(
-        models=MODELS
-    )
+    lq = LLMQualityStrategy(models=MODELS)
     lq_result = asyncio.run(lq.analyze(input, output))
     print("LLM Quality: ", lq_result)
 
@@ -32,13 +35,12 @@
 
     # Semantic Similarity
     ss = SemanticSimilarityAnalysis()
-    ss_result = asyncio.run(ss.analyze(input, output))
+    text_input = map(lambda x: x["content"], input)
+    text_input = "\n".join(text_input)
+    ss_result = asyncio.run(ss.analyze(text_input, output))
     print("Semantic Similarity:", ss_result)
 
     # Static Text Analysis
     sta = StaticTextAnalysisStrategy()
     sta_result = sta.analyze(output)
     print("Static Text Analysis:", sta_result)
-
-
-
50 changes: 28 additions & 22 deletions dispute/strategy.py
@@ -1,24 +1,23 @@
 import asyncio
 from datetime import datetime
-from typing import TypedDict
+from typing import List, Optional, TypedDict
 import logging
-from config import SIMULATION_MODEL, MODELS
+from config import MIN_CONFIDENCE_SCORE, SIMULATION_MODEL, MODELS
+from util.chasm import Message
 from strategies.ResponseSimilarity import ResponseSimilarityAnalysis
 from strategies.LLMQuality import LLMQualityStrategy
 from strategies.StaticTextAnalysis import StaticTextAnalysisStrategy
 from strategies.SemanticSimilarity import SemanticSimilarityAnalysis
 
 
 class TextAnalysisResult(TypedDict):
-    ss_score: float
     llm_score: float
     rs_score: float
     confidence_score: float
     correct: bool
     dispute: bool
+    rs_output: Optional[str]
 
 
-async def analyze_text(input: str, output: str) -> TextAnalysisResult:
+async def analyze_text(input: List[Message], output: str) -> TextAnalysisResult:
     """
     Analyze the given input and output texts to determine scores and dispute status.
@@ -41,40 +40,47 @@ async def analyze_text(input: List[Message], output: str) -> TextAnalysisResult:
                 "confidence_score": 1.0,
                 "dispute": True,
                 "correct": False,
-                "ss_score": 1.0,
                 "llm_score": 1.0,
+                "rs_output": None,
                 "rs_score": 1.0,
             }
 
-        # 2. Semantic Similarity
-        ss_strategy = SemanticSimilarityAnalysis()
 
-        # 3. LLM Analysis
+        # 2. LLM Analysis
         llm_strategy = LLMQualityStrategy(models=MODELS)
+        llm_score = await llm_strategy.analyze(input, output)
 
-        # 4. Response
+        print(f"LLM Score: {llm_score}")
+        if llm_score > 0.5:
+            return {
+                "confidence_score": 0.0,
+                "dispute": False,
+                "correct": True,
+                "llm_score": llm_score,
+                "rs_output": None,
+                "rs_score": 0.0,
+            }
+
+        # 3. Response
        rs_strategy = ResponseSimilarityAnalysis(model=SIMULATION_MODEL)
 
-        # Gather scores concurrently
-        scores = await asyncio.gather(
-            ss_strategy.analyze(input, output),
-            llm_strategy.analyze(input, output),
-            rs_strategy.analyze(input, output),
-        )
-        ss_score, llm_score, rs_score = scores
+        rs_result = await rs_strategy.analyze(input, output)
+        rs_score, rs_output = rs_result
+
+        print(f"RS Score: {rs_score}")
 
         # Final Score Calculation
-        confidence_score = (llm_score * 0.5) + (ss_score * 0.1) + (rs_score * 0.4)
+        confidence_score = (llm_score * 0.5) + (rs_score * 0.5)
 
         time_diff = datetime.now() - start_time
         print(f"Time taken: {time_diff}")
 
-        dispute = confidence_score > 0.5
+        dispute = confidence_score > MIN_CONFIDENCE_SCORE
 
         return {
-            "ss_score": ss_score,
             "llm_score": llm_score,
             "rs_score": rs_score,
+            "rs_output": rs_output,
            "confidence_score": confidence_score,
            "dispute": dispute,
            "correct": not dispute,
@@ -84,9 +90,9 @@ async def analyze_text(input: List[Message], output: str) -> TextAnalysisResult:
         logging.error(f"Error in analyze_text: {e}")
         # Default to no dispute
         return {
-            "ss_score": 0.0,
             "llm_score": 0.0,
             "rs_score": 0.0,
+            "rs_output": None,
             "confidence_score": 0.0,
             "dispute": False,
             "correct": True,
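
As a quick sanity check of the new weighting, here is a worked example with made-up scores; MIN_CONFIDENCE_SCORE comes from config and its value is not shown in this diff.

# Hypothetical values, for illustration only.
llm_score = 0.3
rs_score = 0.9
confidence_score = (llm_score * 0.5) + (rs_score * 0.5)  # 0.15 + 0.45 = 0.6
dispute = confidence_score > 0.5  # True if MIN_CONFIDENCE_SCORE were 0.5
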
5 changes: 4 additions & 1 deletion dispute/util/chasm.py
@@ -3,6 +3,7 @@
 import requests
 from urllib.parse import urljoin
 from config import ORCHESTRATOR_URL, WEBHOOK_API_KEY
+from util.fetch import request_with_backoff
 
 
 class Message(TypedDict):
@@ -47,7 +48,9 @@ def file_dispute(self, id: str, input: List[Message], output: Message):
         print(f"Dispute payload: {payload}")
 
         try:
-            response = requests.post(url, json=payload, headers=headers)
+            response = request_with_backoff(
+                lambda: requests.post(url, json=payload, headers=headers)
+            )
             response.raise_for_status()
             return response.json()
         except requests.exceptions.HTTPError as http_err:
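
The fields of the Message TypedDict are collapsed in this view. Judging from how messages are built elsewhere in the diff ({"role": ..., "content": ...}), it presumably looks roughly like the sketch below; this is an assumption, not the committed definition.

from typing import TypedDict

class Message(TypedDict):
    role: str  # e.g. "system", "user", "assistant"
    content: str
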
19 changes: 19 additions & 0 deletions dispute/util/fetch.py
@@ -0,0 +1,19 @@
+import time
+import random
+from typing import Callable
+import requests
+
+
+def request_with_backoff(request_func: Callable[[], requests.Response], max_retries=3):
+    retry_delay = 1  # Initial delay in seconds
+    for _ in range(max_retries):
+        try:
+            response = request_func()
+            response.raise_for_status()
+            return response.json()
+        except requests.RequestException:
+            time.sleep(retry_delay)
+            retry_delay *= 2  # Double the delay for the next attempt
+            retry_delay += random.uniform(0, 1)  # Add jitter
+
+    raise Exception("Maximum retry attempts reached")
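
Note that request_with_backoff as defined here returns the parsed JSON body (it calls response.json() internally), so callers pass a zero-argument callable and get a dict back. A minimal usage sketch with a placeholder URL:

import requests
from util.fetch import request_with_backoff

# Placeholder URL for illustration; any callable returning a requests.Response works.
data = request_with_backoff(lambda: requests.get("https://example.com/api/health"))
print(data)  # parsed JSON from the first successful attempt (up to 3 tries)
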
2 changes: 1 addition & 1 deletion package.json
@@ -4,7 +4,7 @@
   "description": "",
   "main": "index.js",
   "scripts": {
-    "dev": "nodemon src/server/express.ts",
+    "dev": "nodemon --watch src src/server/express.ts",
     "build": "tsc",
     "prepare": "husky",
     "prettier": "prettier . --check",
