From f0e385b20708bb1d9470f79ebbeacd1b8cf087b2 Mon Sep 17 00:00:00 2001 From: cloudre01 <43627940+cloudre01@users.noreply.github.com> Date: Fri, 23 Aug 2024 10:36:05 +0800 Subject: [PATCH] [CHA-99] Scout Outlier Detection (#3) * update strategy flow * inverse score logic for llm quality * add retry to dispute request --- dispute/main.py | 27 +++++++------ dispute/strategies/LLMQuality.py | 23 ++++++++--- dispute/strategies/ResponseSimilarity.py | 26 ++++++------ dispute/strategies/__init__.py | 20 +++++----- dispute/strategy.py | 50 +++++++++++++----------- dispute/util/chasm.py | 5 ++- dispute/util/fetch.py | 19 +++++++++ package.json | 2 +- 8 files changed, 112 insertions(+), 60 deletions(-) create mode 100644 dispute/util/fetch.py diff --git a/dispute/main.py b/dispute/main.py index 63f6a34..af6a8bf 100644 --- a/dispute/main.py +++ b/dispute/main.py @@ -7,25 +7,30 @@ import logging logging.basicConfig( - level=LOG_LEVEL, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', - datefmt='%Y-%m-%d %H:%M:%S' + level=LOG_LEVEL, + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", ) chasm = ChasmConnection() PROCESSED_HISTORIES_FILE = "processed_histories.json" + def load_processed_histories(): if os.path.exists(PROCESSED_HISTORIES_FILE): with open(PROCESSED_HISTORIES_FILE, "r") as f: return set(json.load(f)) return set() + def save_processed_histories(histories): with open(PROCESSED_HISTORIES_FILE, "w") as f: json.dump(list(histories), f) + + processed_histories = load_processed_histories() + async def process_histories(): histories = chasm.get_prompt_history() logging.info(f"Histories: {len(histories)}") @@ -33,26 +38,26 @@ async def process_histories(): if history["_id"] in processed_histories: logging.debug(f"Skipping already processed history: {history['_id']}") continue - input = map(lambda x: x["content"], history["messages"]) - input = "\n".join(input) output = history["result"]["choices"][0]["message"]["content"] - result = await analyze_text(input, output) + result = await analyze_text(history["messages"], output) logging.debug(f"Result: {result}") logging.debug(f"Score: {result['confidence_score']}") if result["confidence_score"] > MIN_CONFIDENCE_SCORE: + rs_output = result.get("rs_output") + assert rs_output is not None, "rs_output is not generated" response = chasm.file_dispute( history["_id"], history["messages"], - history["result"]["choices"][0]["message"], + {"role": "assistant", "content": rs_output}, ) - logging.info("Dispute filed: ", response) - + if response is not None: + logging.info(f"Dispute filed: {response['result']}") + # Cache history processed_histories.add(history["_id"]) save_processed_histories(processed_histories) - - + async def main(): while True: diff --git a/dispute/strategies/LLMQuality.py b/dispute/strategies/LLMQuality.py index 82ac340..732b2a1 100644 --- a/dispute/strategies/LLMQuality.py +++ b/dispute/strategies/LLMQuality.py @@ -1,12 +1,19 @@ import re import logging +from typing import List +from util.chasm import Message from util.poll import PollAlgo from strategies.tokenizer import Tokenizer class LLMQualityStrategy: - def __init__(self, models): + """ + LLM Quality Strategy + Higher score means the model is more likely to be good. + Lower score means the model is more likely to be bad. + """ + def __init__(self, models): # llama3 tokenizer self.tokenizer = Tokenizer() self.poll = PollAlgo(models) @@ -15,8 +22,11 @@ def format_text_limit(self, text: str, limit: int): encoded_result = self.tokenizer.encode(text, bos=False, eos=False) return self.tokenizer.decode(encoded_result[:limit]) - async def analyze(self, input: str, output: str): - input = self.format_text_limit(input, 3000) + async def analyze(self, input: List[Message], output: str) -> float: + text_input = map(lambda x: x["content"], input) + text_input = "\n".join(text_input) + + formatted_input = self.format_text_limit(text_input, 3000) output = self.format_text_limit(output, 3000) # Prompt from https://arxiv.org/abs/2306.05685v4 @@ -48,7 +58,10 @@ async def analyze(self, input: str, output: str): """, }, {"role": "assistant", "content": "False"}, - {"role": "user", "content": f"Question: {input}\nAnswer: {output}"}, + { + "role": "user", + "content": f"Question: {formatted_input}\nAnswer: {output}", + }, ], ) logging.debug(results) @@ -73,4 +86,4 @@ async def analyze(self, input: str, output: str): llm_results.append(0.5) logging.debug(f"LLM Results: {llm_results}") score = sum(llm_results) / len(llm_results) - return score + return 1 - score diff --git a/dispute/strategies/ResponseSimilarity.py b/dispute/strategies/ResponseSimilarity.py index 51900ab..0cfec6b 100644 --- a/dispute/strategies/ResponseSimilarity.py +++ b/dispute/strategies/ResponseSimilarity.py @@ -1,31 +1,35 @@ +from util.chasm import Message from sklearn.feature_extraction.text import TfidfVectorizer from sklearn.metrics.pairwise import cosine_similarity import asyncio +from typing import List, Tuple from util.llm import llm + class ResponseSimilarityAnalysis: def __init__(self, model: str): self.model = model - async def analyze(self, input: str, output: str): + async def analyze(self, input: List[Message], output: str) -> Tuple[float, str]: loop = asyncio.get_event_loop() - score = await loop.run_in_executor(None, self._sync_analyse, input, output) - return score - - def _sync_analyse(self, input: str, output: str): + score, output = await loop.run_in_executor( + None, self._sync_analyse, input, output + ) + return score, output + def _sync_analyse(self, input: List[Message], output: str): simulated_output = llm.chat.completions.create( - messages=[ - {"role": "system", "content": "You are a helpful assistant"}, - {"role": "user", "content": input}, - ], + messages=input, model=self.model, - temperature=0.8, + temperature=0, ) simulated_output_response = simulated_output.choices[0].message.content assert simulated_output_response is not None, "Simulated output is None" + assert isinstance( + simulated_output_response, str + ), "Simulated output is not a string" score = self.similarity(simulated_output_response, output) - return score + return score, simulated_output_response def similarity(self, text1: str, text2: str) -> float: vectorizer = TfidfVectorizer().fit_transform([text1, text2]) diff --git a/dispute/strategies/__init__.py b/dispute/strategies/__init__.py index 425daef..1693e11 100644 --- a/dispute/strategies/__init__.py +++ b/dispute/strategies/__init__.py @@ -1,9 +1,12 @@ +from typing import List +from util.chasm import Message if __name__ == "__main__": import sys import os - sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) + + sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) os.environ["LOG_LEVEL"] = "DEBUG" @@ -15,13 +18,13 @@ from config import MODELS, SIMULATION_MODEL - input = "What is the capital of France?" + input: List[Message] = [ + {"role": "user", "content": "What is the capital of France?"} + ] output = "Paris" # LLM Quality - lq = LLMQualityStrategy( - models=MODELS - ) + lq = LLMQualityStrategy(models=MODELS) lq_result = asyncio.run(lq.analyze(input, output)) print("LLM Quality: ", lq_result) @@ -32,13 +35,12 @@ # Semantic Similarity ss = SemanticSimilarityAnalysis() - ss_result = asyncio.run(ss.analyze(input, output)) + text_input = map(lambda x: x["content"], input) + text_input = "\n".join(text_input) + ss_result = asyncio.run(ss.analyze(text_input, output)) print("Semantic Similarity:", ss_result) # Static Text Analysis sta = StaticTextAnalysisStrategy() sta_result = sta.analyze(output) print("Static Text Analysis:", sta_result) - - - diff --git a/dispute/strategy.py b/dispute/strategy.py index 6599625..db6f6f3 100644 --- a/dispute/strategy.py +++ b/dispute/strategy.py @@ -1,24 +1,23 @@ -import asyncio from datetime import datetime -from typing import TypedDict +from typing import List, Optional, TypedDict import logging -from config import SIMULATION_MODEL, MODELS +from config import MIN_CONFIDENCE_SCORE, SIMULATION_MODEL, MODELS +from util.chasm import Message from strategies.ResponseSimilarity import ResponseSimilarityAnalysis from strategies.LLMQuality import LLMQualityStrategy from strategies.StaticTextAnalysis import StaticTextAnalysisStrategy -from strategies.SemanticSimilarity import SemanticSimilarityAnalysis class TextAnalysisResult(TypedDict): - ss_score: float llm_score: float rs_score: float confidence_score: float correct: bool dispute: bool + rs_output: Optional[str] -async def analyze_text(input: str, output: str) -> TextAnalysisResult: +async def analyze_text(input: List[Message], output: str) -> TextAnalysisResult: """ Analyze the given input and output texts to determine scores and dispute status. @@ -41,40 +40,47 @@ async def analyze_text(input: str, output: str) -> TextAnalysisResult: "confidence_score": 1.0, "dispute": True, "correct": False, - "ss_score": 1.0, "llm_score": 1.0, + "rs_output": None, "rs_score": 1.0, } - # 2. Semantic Similarity - ss_strategy = SemanticSimilarityAnalysis() - - # 3. LLM Analysis + # 2. LLM Analysis llm_strategy = LLMQualityStrategy(models=MODELS) + llm_score = await llm_strategy.analyze(input, output) - # 4. Response + print(f"LLM Score: {llm_score}") + if llm_score > 0.5: + return { + "confidence_score": 0.0, + "dispute": False, + "correct": True, + "llm_score": llm_score, + "rs_output": None, + "rs_score": 0.0, + } + + # 3. Response rs_strategy = ResponseSimilarityAnalysis(model=SIMULATION_MODEL) # Gather scores concurrently - scores = await asyncio.gather( - ss_strategy.analyze(input, output), - llm_strategy.analyze(input, output), - rs_strategy.analyze(input, output), - ) - ss_score, llm_score, rs_score = scores + rs_result = await rs_strategy.analyze(input, output) + rs_score, rs_output = rs_result + + print(f"RS Score: {rs_score}") # Final Score Calculation - confidence_score = (llm_score * 0.5) + (ss_score * 0.1) + (rs_score * 0.4) + confidence_score = (llm_score * 0.5) + (rs_score * 0.5) time_diff = datetime.now() - start_time print(f"Time taken: {time_diff}") - dispute = confidence_score > 0.5 + dispute = confidence_score > MIN_CONFIDENCE_SCORE return { - "ss_score": ss_score, "llm_score": llm_score, "rs_score": rs_score, + "rs_output": rs_output, "confidence_score": confidence_score, "dispute": dispute, "correct": not dispute, @@ -84,9 +90,9 @@ async def analyze_text(input: str, output: str) -> TextAnalysisResult: logging.error(f"Error in analyze_text: {e}") # Default to no dispute return { - "ss_score": 0.0, "llm_score": 0.0, "rs_score": 0.0, + "rs_output": None, "confidence_score": 0.0, "dispute": False, "correct": True, diff --git a/dispute/util/chasm.py b/dispute/util/chasm.py index b43d001..7933328 100644 --- a/dispute/util/chasm.py +++ b/dispute/util/chasm.py @@ -3,6 +3,7 @@ import requests from urllib.parse import urljoin from config import ORCHESTRATOR_URL, WEBHOOK_API_KEY +from util.fetch import request_with_backoff class Message(TypedDict): @@ -47,7 +48,9 @@ def file_dispute(self, id: str, input: List[Message], output: Message): print(f"Dispute payload: {payload}") try: - response = requests.post(url, json=payload, headers=headers) + response = request_with_backoff( + lambda: requests.post(url, json=payload, headers=headers) + ) response.raise_for_status() return response.json() except requests.exceptions.HTTPError as http_err: diff --git a/dispute/util/fetch.py b/dispute/util/fetch.py new file mode 100644 index 0000000..ed465ab --- /dev/null +++ b/dispute/util/fetch.py @@ -0,0 +1,19 @@ +import time +import random +from typing import Callable +import requests + + +def request_with_backoff(request_func: Callable[[], requests.Response], max_retries=3): + retry_delay = 1 # Initial delay in seconds + for _ in range(max_retries): + try: + response = request_func() + response.raise_for_status() + return response.json() + except requests.RequestException: + time.sleep(retry_delay) + retry_delay *= 2 # Double the delay for the next attempt + retry_delay += random.uniform(0, 1) # Add jitter + + raise Exception("Maximum retry attempts reached") diff --git a/package.json b/package.json index b7cfd90..6c778df 100644 --- a/package.json +++ b/package.json @@ -4,7 +4,7 @@ "description": "", "main": "index.js", "scripts": { - "dev": "nodemon src/server/express.ts", + "dev": "nodemon --watch src src/server/express.ts", "build": "tsc", "prepare": "husky", "prettier": "prettier . --check",