0.0.6 Release: Scout Recompute & Exponential Retries #6

Merged 7 commits on Aug 29, 2024
14 changes: 11 additions & 3 deletions dispute/benchmark.py
@@ -1,20 +1,26 @@
import dotenv

dotenv.load_dotenv()

import asyncio
from util.chasm import ChasmConnection
from strategy import analyze_text


async def main():
connection = ChasmConnection()
histories = connection.get_benchmark_test()
i = 0
for history in histories:
print(f"--- {i} ---")
result = await analyze_text(history["input"], history["output"])
result = await analyze_text(
history["input"], history["output"], 0, "openai", "gpt-3"
)
print(f"Result: {result}")
print(f"Score: {result['confidence_score']}")
print(f"Assert Check: {'✅' if result['correct'] == history['answer'] else '❌'}")
print(
f"Assert Check: {'✅' if result['correct'] == history['answer'] else '❌'}"
)
print(f"Dispute: {result['dispute']}")
i += 1
pass
@@ -23,5 +29,7 @@ async def main():
if __name__ == "__main__":
# set TOKENIZERS_PARALLELISM
import os

os.environ["TOKENIZERS_PARALLELISM"] = "true"
asyncio.run(main())
asyncio.run(main())
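
For context: the extra arguments in the updated call imply that strategy.analyze_text now takes a seed plus the scout's provider and model on top of the input/output pair. strategy.py is not part of this diff (and benchmark.py passes history["input"] where main.py passes the full message list), so the following is only a sketch of the presumed signature, inferred from the call sites:

```python
# Presumed signature of strategy.analyze_text after this release; inferred from the
# call sites in benchmark.py and dispute/main.py. strategy.py itself is not in this diff.
from typing import List
from util.chasm import Message

async def analyze_text(
    input: List[Message],
    output: str,
    seed: int,
    provider: str,
    model: str,
) -> dict:
    """Runs the dispute strategies and returns (at least) confidence_score,
    correct, dispute, and rs_output (the recomputed response)."""
    ...
```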

3 changes: 3 additions & 0 deletions dispute/config.py
@@ -5,11 +5,14 @@

LLM_BASE_URL = os.getenv("LLM_BASE_URL")
LLM_API_KEY = os.getenv("LLM_API_KEY")
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
OPENROUTER_API_KEY = os.getenv("OPENROUTER_API_KEY")
MODELS = os.getenv("MODELS", "gemma2-9b-it").split(",")
SIMULATION_MODEL = os.getenv("SIMULATION_MODEL", "gemma2-9b-it")
ORCHESTRATOR_URL = os.getenv("ORCHESTRATOR_URL")
WEBHOOK_API_KEY = os.getenv("WEBHOOK_API_KEY")
MIN_CONFIDENCE_SCORE = float(os.getenv("MIN_CONFIDENCE_SCORE", 0.5))
MIN_RESPONSE_DIFFERENCE = float(os.getenv("MIN_RESPONSE_DIFFERENCE", 0.8))
LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO")

print(f"LOG_LEVEL: {LOG_LEVEL}")
33 changes: 22 additions & 11 deletions dispute/main.py
@@ -7,52 +7,63 @@
import logging

logging.basicConfig(
level=LOG_LEVEL,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S'
level=LOG_LEVEL,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)

chasm = ChasmConnection()
PROCESSED_HISTORIES_FILE = "processed_histories.json"


def load_processed_histories():
if os.path.exists(PROCESSED_HISTORIES_FILE):
with open(PROCESSED_HISTORIES_FILE, "r") as f:
return set(json.load(f))
return set()


def save_processed_histories(histories):
with open(PROCESSED_HISTORIES_FILE, "w") as f:
json.dump(list(histories), f)


processed_histories = load_processed_histories()


async def process_histories():
histories = chasm.get_prompt_history()
logging.info(f"Histories: {len(histories)}")
for history in histories:
if history["_id"] in processed_histories:
logging.debug(f"Skipping already processed history: {history['_id']}")
continue
input = map(lambda x: x["content"], history["messages"])
input = "\n".join(input)
output = history["result"]["choices"][0]["message"]["content"]
result = await analyze_text(input, output)
result = await analyze_text(
history["messages"],
output,
history["seed"],
history["result"]["scout"]["provider"],
history["result"]["scout"]["model"],
)
logging.debug(f"Result: {result}")
logging.debug(f"Score: {result['confidence_score']}")

if result["confidence_score"] > MIN_CONFIDENCE_SCORE:
rs_output = result.get("rs_output")
assert rs_output is not None, "rs_output is not generated"
response = chasm.file_dispute(
history["_id"],
history["messages"],
history["result"]["choices"][0]["message"],
{"role": "assistant", "content": rs_output},
)
logging.info("Dispute filed: ", response)

if response is not None:
logging.info(f"Dispute filed: {response['result']}")

# Cache history
processed_histories.add(history["_id"])
save_processed_histories(processed_histories)




async def main():
while True:
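
The release title also mentions exponential retries, but none of the hunks shown here contain that change; it presumably sits in the LLM or orchestrator call paths. Purely as an illustration of the idea, a generic async backoff wrapper looks roughly like this (names, delays, and placement are assumptions, not the repository's code):

```python
# Illustrative exponential-backoff helper, shown only to explain the "Exponential
# Retries" part of the release title. Everything here is an assumption.
import asyncio
import logging
import random

async def with_retries(make_call, max_attempts: int = 5, base_delay: float = 1.0):
    for attempt in range(max_attempts):
        try:
            return await make_call()
        except Exception as exc:
            if attempt == max_attempts - 1:
                raise
            # Double the delay on each attempt and add a small jitter.
            delay = base_delay * (2 ** attempt) + random.uniform(0, 0.5)
            logging.warning(f"Attempt {attempt + 1} failed ({exc}); retrying in {delay:.1f}s")
            await asyncio.sleep(delay)
```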
23 changes: 18 additions & 5 deletions dispute/strategies/LLMQuality.py
@@ -1,12 +1,19 @@
import re
import logging
from typing import List
from util.chasm import Message
from util.poll import PollAlgo
from strategies.tokenizer import Tokenizer


class LLMQualityStrategy:
def __init__(self, models):
"""
LLM Quality Strategy
Higher score means the model is more likely to be good.
Lower score means the model is more likely to be bad.
"""

def __init__(self, models):
# llama3 tokenizer
self.tokenizer = Tokenizer()
self.poll = PollAlgo(models)
@@ -15,8 +22,11 @@ def format_text_limit(self, text: str, limit: int):
encoded_result = self.tokenizer.encode(text, bos=False, eos=False)
return self.tokenizer.decode(encoded_result[:limit])

async def analyze(self, input: str, output: str):
input = self.format_text_limit(input, 3000)
async def analyze(self, input: List[Message], output: str) -> float:
text_input = map(lambda x: x["content"], input)
text_input = "\n".join(text_input)

formatted_input = self.format_text_limit(text_input, 3000)
output = self.format_text_limit(output, 3000)

# Prompt from https://arxiv.org/abs/2306.05685v4
@@ -48,7 +58,10 @@ async def analyze(self, input: str, output: str):
""",
},
{"role": "assistant", "content": "False"},
{"role": "user", "content": f"Question: {input}\nAnswer: {output}"},
{
"role": "user",
"content": f"Question: {formatted_input}\nAnswer: {output}",
},
],
)
logging.debug(results)
@@ -73,4 +86,4 @@ async def analyze(self, input: str, output: str):
llm_results.append(0.5)
logging.debug(f"LLM Results: {llm_results}")
score = sum(llm_results) / len(llm_results)
return score
return 1 - score
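
Message is imported from util.chasm, which is outside this diff. Since the strategies read x["content"] from each entry and pass the list straight into chat.completions.create, it is presumably an OpenAI-style chat message, along the lines of:

```python
# Assumed shape of util.chasm.Message (not shown in this PR); inferred from usage only.
from typing import TypedDict

class Message(TypedDict):
    role: str      # e.g. "system", "user", or "assistant"
    content: str
```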
52 changes: 52 additions & 0 deletions dispute/strategies/ResponseRecompute.py
@@ -0,0 +1,52 @@
from util.chasm import Message
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
import asyncio
from typing import List, Tuple
from util.llm import LLMProviders


class ResponseRecomputeAnalysis:
def __init__(self, model: str):
self.model = model
self.llm_providers = LLMProviders()

async def analyse(
self,
input: List[Message],
output: str,
seed: int,
provider: str = "ollama",
) -> Tuple[float, str]:
loop = asyncio.get_event_loop()
score, output = await loop.run_in_executor(
None, self._sync_analyse, input, output, seed, provider
)
return score, output

def _sync_analyse(
self,
input: List[Message],
output: str,
seed: int,
provider: str,
) -> Tuple[float, str]:
llm = self.llm_providers[provider]
recompute_output = llm.chat.completions.create(
messages=input,
model=self.model,
temperature=0,
seed=seed,
)
output_response = recompute_output.choices[0].message.content
assert output_response is not None, "Recompute output is None"
assert isinstance(output_response, str), "Recompute output is not a string"
score = self.similarity(output_response, output)
return score, output_response

def similarity(self, text1: str, text2: str) -> float:
vectorizer = TfidfVectorizer().fit_transform([text1, text2])
vectors = vectorizer.toarray()
cosine_sim = cosine_similarity(vectors[0:1], vectors[1:])
score = cosine_sim[0][0]
return 1 - score
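
Both the recompute and similarity strategies score the gap between two responses as 1 minus the TF-IDF cosine similarity, so 0 means effectively identical text and values near 1 mean little lexical overlap. A self-contained check of that metric, with no LLM call involved:

```python
# Standalone illustration of the TF-IDF dissimilarity used by ResponseRecomputeAnalysis.
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity

def dissimilarity(text1: str, text2: str) -> float:
    vectors = TfidfVectorizer().fit_transform([text1, text2]).toarray()
    return 1 - cosine_similarity(vectors[0:1], vectors[1:])[0][0]

print(dissimilarity("Paris is the capital of France.", "Paris is the capital of France."))   # ~0.0
print(dissimilarity("Paris is the capital of France.", "Berlin is the capital of Germany."))  # noticeably higher
```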
30 changes: 18 additions & 12 deletions dispute/strategies/ResponseSimilarity.py
@@ -1,31 +1,37 @@
from util.chasm import Message
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
import asyncio
from util.llm import llm
from typing import List, Tuple
from util.llm import LLMProviders


class ResponseSimilarityAnalysis:
def __init__(self, model: str):
self.model = model
self.provider = LLMProviders()

async def analyze(self, input: str, output: str):
async def analyze(self, input: List[Message], output: str) -> Tuple[float, str]:
loop = asyncio.get_event_loop()
score = await loop.run_in_executor(None, self._sync_analyse, input, output)
return score

def _sync_analyse(self, input: str, output: str):
score, output = await loop.run_in_executor(
None, self._sync_analyse, input, output
)
return score, output

def _sync_analyse(self, input: List[Message], output: str):
llm = self.provider["ollama"]
simulated_output = llm.chat.completions.create(
messages=[
{"role": "system", "content": "You are a helpful assistant"},
{"role": "user", "content": input},
],
messages=input,
model=self.model,
temperature=0.8,
temperature=0,
)
simulated_output_response = simulated_output.choices[0].message.content
assert simulated_output_response is not None, "Simulated output is None"
assert isinstance(
simulated_output_response, str
), "Simulated output is not a string"
score = self.similarity(simulated_output_response, output)
return score
return score, simulated_output_response

def similarity(self, text1: str, text2: str) -> float:
vectorizer = TfidfVectorizer().fit_transform([text1, text2])
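
LLMProviders comes from util.llm, which is also not part of this PR. The strategies index it by provider name ("ollama", "groq") and call an OpenAI-compatible chat.completions.create on the result, so it presumably maps provider names to pre-configured clients, roughly like this (base URLs and which env vars feed which client are assumptions):

```python
# Hypothetical sketch of util.llm.LLMProviders, inferred from how the strategies use it.
import os
from openai import OpenAI

class LLMProviders:
    def __init__(self):
        self._clients = {
            "ollama": OpenAI(base_url=os.getenv("LLM_BASE_URL"), api_key=os.getenv("LLM_API_KEY")),
            "groq": OpenAI(base_url="https://api.groq.com/openai/v1", api_key=os.getenv("GROQ_API_KEY")),
            "openrouter": OpenAI(base_url="https://openrouter.ai/api/v1", api_key=os.getenv("OPENROUTER_API_KEY")),
        }

    def __getitem__(self, provider: str) -> OpenAI:
        return self._clients[provider]
```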
24 changes: 16 additions & 8 deletions dispute/strategies/__init__.py
@@ -1,27 +1,31 @@
from typing import List


if __name__ == "__main__":
import sys
import os
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

os.environ["LOG_LEVEL"] = "DEBUG"

import asyncio
from LLMQuality import LLMQualityStrategy
from ResponseSimilarity import ResponseSimilarityAnalysis
from SemanticSimilarity import SemanticSimilarityAnalysis
from ResponseRecompute import ResponseRecomputeAnalysis
from StaticTextAnalysis import StaticTextAnalysisStrategy
from util.chasm import Message

from config import MODELS, SIMULATION_MODEL

input = "What is the capital of France?"
input: List[Message] = [
{"role": "user", "content": "What is the capital of France?"}
]
output = "Paris"

# LLM Quality
lq = LLMQualityStrategy(
models=MODELS
)
lq = LLMQualityStrategy(models=MODELS)
lq_result = asyncio.run(lq.analyze(input, output))
print("LLM Quality: ", lq_result)

@@ -32,13 +36,17 @@

# Semantic Similarity
ss = SemanticSimilarityAnalysis()
ss_result = asyncio.run(ss.analyze(input, output))
text_input = map(lambda x: x["content"], input)
text_input = "\n".join(text_input)
ss_result = asyncio.run(ss.analyze(text_input, output))
print("Semantic Similarity:", ss_result)

# Static Text Analysis
sta = StaticTextAnalysisStrategy()
sta_result = sta.analyze(output)
print("Static Text Analysis:", sta_result)



# Response Recompute Analysis
rra = ResponseRecomputeAnalysis(model=SIMULATION_MODEL)
rra_result = asyncio.run(rra.analyse(input, output, 42, "groq"))
print("Response Recompute Analysis:", rra_result)