Merge pull request #6 from ChasmNetwork/0.0.6-release
0.0.6 Release: Scout Recompute & Exponential Retries
rekttdoteth authored Aug 29, 2024
2 parents 2b12d3d + 6cc6a86 commit 2acb6c9
Showing 15 changed files with 494 additions and 146 deletions.
14 changes: 11 additions & 3 deletions dispute/benchmark.py
@@ -1,20 +1,26 @@
import dotenv

dotenv.load_dotenv()

import asyncio
from util.chasm import ChasmConnection
from strategy import analyze_text


async def main():
connection = ChasmConnection()
histories = connection.get_benchmark_test()
i = 0
for history in histories:
print(f"--- {i} ---")
result = await analyze_text(history["input"], history["output"])
result = await analyze_text(
history["input"], history["output"], 0, "openai", "gpt-3"
)
print(f"Result: {result}")
print(f"Score: {result['confidence_score']}")
print(f"Assert Check: {'✅' if result['correct'] == history['answer'] else '❌'}")
print(
f"Assert Check: {'✅' if result['correct'] == history['answer'] else '❌'}"
)
print(f"Dispute: {result['dispute']}")
i += 1
pass
@@ -23,5 +29,7 @@ async def main():
if __name__ == "__main__":
# set TOKENIZERS_PARALLELISM
import os

os.environ["TOKENIZERS_PARALLELISM"] = "true"
asyncio.run(main())
asyncio.run(main())
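
A note on the new call: the benchmark now threads a seed, provider, and model through to analyze_text. strategy.py itself is not among the files shown on this page, so the signature below is only inferred from the two visible call sites (this benchmark and dispute/main.py further down); treat every parameter name as a guess.

# Hypothetical signature inferred from the call sites; the real
# strategy.py is not shown in this diff.
from typing import List, Union
from util.chasm import Message

async def analyze_text(
    input: Union[str, List[Message]],  # history["input"] here; history["messages"] in main.py
    output: str,
    seed: int,       # 0 here; history["seed"] in main.py
    provider: str,   # e.g. "openai"
    model: str,      # e.g. "gpt-3"
) -> dict:           # appears to carry confidence_score, correct, dispute, rs_output
    ...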

3 changes: 3 additions & 0 deletions dispute/config.py
@@ -5,11 +5,14 @@

LLM_BASE_URL = os.getenv("LLM_BASE_URL")
LLM_API_KEY = os.getenv("LLM_API_KEY")
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
OPENROUTER_API_KEY = os.getenv("OPENROUTER_API_KEY")
MODELS = os.getenv("MODELS", "gemma2-9b-it").split(",")
SIMULATION_MODEL = os.getenv("SIMULATION_MODEL", "gemma2-9b-it")
ORCHESTRATOR_URL = os.getenv("ORCHESTRATOR_URL")
WEBHOOK_API_KEY = os.getenv("WEBHOOK_API_KEY")
MIN_CONFIDENCE_SCORE = float(os.getenv("MIN_CONFIDENCE_SCORE", 0.5))
MIN_RESPONSE_DIFFERENCE = float(os.getenv("MIN_RESPONSE_DIFFERENCE", 0.8))
LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO")

print(f"LOG_LEVEL: {LOG_LEVEL}")
33 changes: 22 additions & 11 deletions dispute/main.py
@@ -7,52 +7,63 @@
import logging

logging.basicConfig(
level=LOG_LEVEL,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S'
level=LOG_LEVEL,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)

chasm = ChasmConnection()
PROCESSED_HISTORIES_FILE = "processed_histories.json"


def load_processed_histories():
if os.path.exists(PROCESSED_HISTORIES_FILE):
with open(PROCESSED_HISTORIES_FILE, "r") as f:
return set(json.load(f))
return set()


def save_processed_histories(histories):
with open(PROCESSED_HISTORIES_FILE, "w") as f:
json.dump(list(histories), f)


processed_histories = load_processed_histories()


async def process_histories():
histories = chasm.get_prompt_history()
logging.info(f"Histories: {len(histories)}")
for history in histories:
if history["_id"] in processed_histories:
logging.debug(f"Skipping already processed history: {history['_id']}")
continue
input = map(lambda x: x["content"], history["messages"])
input = "\n".join(input)
output = history["result"]["choices"][0]["message"]["content"]
result = await analyze_text(input, output)
result = await analyze_text(
history["messages"],
output,
history["seed"],
history["result"]["scout"]["provider"],
history["result"]["scout"]["model"],
)
logging.debug(f"Result: {result}")
logging.debug(f"Score: {result['confidence_score']}")

if result["confidence_score"] > MIN_CONFIDENCE_SCORE:
rs_output = result.get("rs_output")
assert rs_output is not None, "rs_output is not generated"
response = chasm.file_dispute(
history["_id"],
history["messages"],
history["result"]["choices"][0]["message"],
{"role": "assistant", "content": rs_output},
)
logging.info("Dispute filed: ", response)

if response is not None:
logging.info(f"Dispute filed: {response['result']}")

# Cache history
processed_histories.add(history["_id"])
save_processed_histories(processed_histories)




async def main():
while True:
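
The "Exponential Retries" half of the release title does not surface in the files rendered on this page, so purely as orientation, here is the generic shape such a change usually takes. This is a hypothetical sketch, not code from this commit:

# Hypothetical exponential-backoff helper; NOT taken from this commit.
import asyncio
import logging
import random

async def with_retries(fn, attempts=5, base_delay=1.0, factor=2.0):
    for attempt in range(attempts):
        try:
            return await fn()
        except Exception as e:
            if attempt == attempts - 1:
                raise  # out of attempts: surface the last error
            # Delay doubles each round; jitter avoids retrying in lockstep.
            delay = base_delay * (factor**attempt) + random.uniform(0, 1)
            logging.warning(f"Attempt {attempt + 1} failed ({e}); retrying in {delay:.1f}s")
            await asyncio.sleep(delay)
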
23 changes: 18 additions & 5 deletions dispute/strategies/LLMQuality.py
@@ -1,12 +1,19 @@
import re
import logging
from typing import List
from util.chasm import Message
from util.poll import PollAlgo
from strategies.tokenizer import Tokenizer


class LLMQualityStrategy:
def __init__(self, models):
"""
LLM Quality Strategy
Higher score means the model is more likely to be good.
Lower score means the model is more likely to be bad.
"""

def __init__(self, models):
# llama3 tokenizer
self.tokenizer = Tokenizer()
self.poll = PollAlgo(models)
@@ -15,8 +22,11 @@ def format_text_limit(self, text: str, limit: int):
encoded_result = self.tokenizer.encode(text, bos=False, eos=False)
return self.tokenizer.decode(encoded_result[:limit])

async def analyze(self, input: str, output: str):
input = self.format_text_limit(input, 3000)
async def analyze(self, input: List[Message], output: str) -> float:
text_input = map(lambda x: x["content"], input)
text_input = "\n".join(text_input)

formatted_input = self.format_text_limit(text_input, 3000)
output = self.format_text_limit(output, 3000)

# Prompt from https://arxiv.org/abs/2306.05685v4
@@ -48,7 +58,10 @@ async def analyze(self, input: str, output: str):
""",
},
{"role": "assistant", "content": "False"},
{"role": "user", "content": f"Question: {input}\nAnswer: {output}"},
{
"role": "user",
"content": f"Question: {formatted_input}\nAnswer: {output}",
},
],
)
logging.debug(results)
@@ -73,4 +86,4 @@ async def analyze(self, input: str, output: str):
llm_results.append(0.5)
logging.debug(f"LLM Results: {llm_results}")
score = sum(llm_results) / len(llm_results)
return score
return 1 - score
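
The inversion on the last line is the behavioral change in this file: judge replies are parsed to numeric votes (0.5 apparently serving as the fallback for unparseable replies, per the line above), averaged, and then flipped before being returned. A worked example with illustrative votes:

# Arithmetic of the final two lines above (votes are illustrative):
llm_results = [1.0, 1.0, 0.0]                  # three parsed judge votes
score = sum(llm_results) / len(llm_results)    # 0.667
print(1 - score)                               # 0.333, the value analyze() now returns
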
52 changes: 52 additions & 0 deletions dispute/strategies/ResponseRecompute.py
@@ -0,0 +1,52 @@
from util.chasm import Message
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
import asyncio
from typing import List, Tuple
from util.llm import LLMProviders


class ResponseRecomputeAnalysis:
def __init__(self, model: str):
self.model = model
self.llm_providers = LLMProviders()

async def analyse(
self,
input: List[Message],
output: str,
seed: int,
provider: str = "ollama",
) -> Tuple[float, str]:
loop = asyncio.get_event_loop()
score, output = await loop.run_in_executor(
None, self._sync_analyse, input, output, seed, provider
)
return score, output

def _sync_analyse(
self,
input: List[Message],
output: str,
seed: int,
provider: str,
) -> Tuple[float, str]:
llm = self.llm_providers[provider]
recompute_output = llm.chat.completions.create(
messages=input,
model=self.model,
temperature=0,
seed=seed,
)
output_response = recompute_output.choices[0].message.content
assert output_response is not None, "Recompute output is None"
assert isinstance(output_response, str), "Recompute output is not a string"
score = self.similarity(output_response, output)
return score, output_response

def similarity(self, text1: str, text2: str) -> float:
vectorizer = TfidfVectorizer().fit_transform([text1, text2])
vectors = vectorizer.toarray()
cosine_sim = cosine_similarity(vectors[0:1], vectors[1:])
score = cosine_sim[0][0]
return 1 - score
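
Note that similarity, despite its name, returns 1 minus the TF-IDF cosine similarity, i.e. a divergence score: 0.0 when the recomputed output matches the original exactly, approaching 1.0 when the two share no vocabulary. A standalone check of that measure, runnable as-is:

# Standalone check of the divergence measure used above.
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity

def divergence(text1: str, text2: str) -> float:
    vectors = TfidfVectorizer().fit_transform([text1, text2]).toarray()
    return 1 - cosine_similarity(vectors[0:1], vectors[1:])[0][0]

print(divergence("Paris is the capital of France",
                 "Paris is the capital of France"))   # 0.0 (identical)
print(divergence("Paris is the capital of France",
                 "import antigravity"))               # 1.0 (no shared terms)
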
30 changes: 18 additions & 12 deletions dispute/strategies/ResponseSimilarity.py
@@ -1,31 +1,37 @@
from util.chasm import Message
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
import asyncio
from util.llm import llm
from typing import List, Tuple
from util.llm import LLMProviders


class ResponseSimilarityAnalysis:
def __init__(self, model: str):
self.model = model
self.provider = LLMProviders()

async def analyze(self, input: str, output: str):
async def analyze(self, input: List[Message], output: str) -> Tuple[float, str]:
loop = asyncio.get_event_loop()
score = await loop.run_in_executor(None, self._sync_analyse, input, output)
return score

def _sync_analyse(self, input: str, output: str):
score, output = await loop.run_in_executor(
None, self._sync_analyse, input, output
)
return score, output

def _sync_analyse(self, input: List[Message], output: str):
llm = self.provider["ollama"]
simulated_output = llm.chat.completions.create(
messages=[
{"role": "system", "content": "You are a helpful assistant"},
{"role": "user", "content": input},
],
messages=input,
model=self.model,
temperature=0.8,
temperature=0,
)
simulated_output_response = simulated_output.choices[0].message.content
assert simulated_output_response is not None, "Simulated output is None"
assert isinstance(
simulated_output_response, str
), "Simulated output is not a string"
score = self.similarity(simulated_output_response, output)
return score
return score, simulated_output_response

def similarity(self, text1: str, text2: str) -> float:
vectorizer = TfidfVectorizer().fit_transform([text1, text2])
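
Two behavioral changes ride along with the signature change in this file: the simulation now replays the caller's message list verbatim instead of wrapping a flattened string in a generic "You are a helpful assistant" prompt, and temperature drops from 0.8 to 0, so the replay is deterministic and directly comparable with the recompute strategy above. The method also returns the simulated text alongside the score, which is presumably what lets main.py attach an rs_output to a dispute (the aggregation in strategy.py is not shown here).
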
24 changes: 16 additions & 8 deletions dispute/strategies/__init__.py
@@ -1,27 +1,31 @@
from typing import List


if __name__ == "__main__":
import sys
import os
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

os.environ["LOG_LEVEL"] = "DEBUG"

import asyncio
from LLMQuality import LLMQualityStrategy
from ResponseSimilarity import ResponseSimilarityAnalysis
from SemanticSimilarity import SemanticSimilarityAnalysis
from ResponseRecompute import ResponseRecomputeAnalysis
from StaticTextAnalysis import StaticTextAnalysisStrategy
from util.chasm import Message

from config import MODELS, SIMULATION_MODEL

input = "What is the capital of France?"
input: List[Message] = [
{"role": "user", "content": "What is the capital of France?"}
]
output = "Paris"

# LLM Quality
lq = LLMQualityStrategy(
models=MODELS
)
lq = LLMQualityStrategy(models=MODELS)
lq_result = asyncio.run(lq.analyze(input, output))
print("LLM Quality: ", lq_result)

@@ -32,13 +36,17 @@

# Semantic Similarity
ss = SemanticSimilarityAnalysis()
ss_result = asyncio.run(ss.analyze(input, output))
text_input = map(lambda x: x["content"], input)
text_input = "\n".join(text_input)
ss_result = asyncio.run(ss.analyze(text_input, output))
print("Semantic Similarity:", ss_result)

# Static Text Analysis
sta = StaticTextAnalysisStrategy()
sta_result = sta.analyze(output)
print("Static Text Analysis:", sta_result)



# Response Recompute Analysis
rra = ResponseRecomputeAnalysis(model=SIMULATION_MODEL)
rra_result = asyncio.run(rra.analyse(input, output, 42, "groq"))
print("Response Recompute Analysis:", rra_result)
(The remaining 8 of the 15 changed files are not rendered in this view.)