
Commit 0ab9f8c
Merge pull request #87 from codelion/feat-update-router-training
Feat update router training
codelion authored Nov 7, 2024
2 parents 0378fd4 + c5a4e95 commit 0ab9f8c
Showing 5 changed files with 183 additions and 9 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -180,6 +180,7 @@ or your own code where you want to use the results from optillm. You can use it

| Plugin | Slug | Description |
| ----------------------- | ------------------ | ---------------------------------------------------------------------------------------------- |
| Router | `router` | Uses the [optillm-bert-uncased](https://huggingface.co/codelion/optillm-bert-uncased) model to route requests to different approaches based on the user prompt |
| Memory | `memory` | Implements a short-term memory layer that enables unbounded context length with any LLM |
| Privacy | `privacy` | Anonymizes PII in the request and deanonymizes it back to the original values in the response |
| Read URLs | `readurls` | Reads all URLs found in the request, fetches the content at each URL, and adds it to the context |
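As a usage sketch (not part of this diff): the router plugin is selected by prefixing the model name with its slug, assuming an optillm proxy running on `localhost:8000`, the same endpoint and `{approach}-{model}` naming convention used by `scripts/gen_optillm_ground_truth_dataset.py` below.

```python
from openai import OpenAI

# Hypothetical invocation: point the OpenAI client at the optillm proxy
# and let the router plugin pick an approach for the prompt.
client = OpenAI(api_key="none", base_url="http://localhost:8000/v1")
response = client.chat.completions.create(
    model="router-gpt-4o-mini",  # "<slug>-<model>" selects the plugin/approach
    messages=[{"role": "user", "content": "Determine if 1027 is prime"}],
)
print(response.choices[0].message.content)
```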
5 changes: 2 additions & 3 deletions optillm.py
@@ -12,9 +12,6 @@
import re
from concurrent.futures import ThreadPoolExecutor

-# Import the LiteLLM wrapper
-from optillm.litellm_wrapper import LiteLLMWrapper

# Import approach modules
from optillm.mcts import chat_with_mcts
from optillm.bon import best_of_n_sampling
@@ -74,6 +71,8 @@ def get_config():
            azure_ad_token_provider=token_provider
        )
    else:
+        # Import the LiteLLM wrapper
+        from optillm.litellm_wrapper import LiteLLMWrapper
        default_client = LiteLLMWrapper()
    return default_client, API_KEY

130 changes: 130 additions & 0 deletions scripts/gen_optillm_ground_truth_dataset.py
@@ -0,0 +1,130 @@
import os
import json
import argparse
import asyncio
from tqdm import tqdm
from datasets import load_dataset
from openai import AsyncOpenAI
from typing import List, Dict, Any, Tuple
import random

# OptILM approaches remain the same as in original script
APPROACHES = ["none", "mcts", "bon", "moa", "rto", "z3", "self_consistency", "pvg", "rstar", "cot_reflection", "plansearch", "leap", "re2"]

# Dataset configurations
DATASET_CONFIGS = [
    ("MixEval", "free_form"),
    ("MixEval", "multiple_choice"),
    ("MixEval_Hard", "free_form"),
    ("MixEval_Hard", "multiple_choice")
]

def construct_prompt(sample: Dict[str, Any], split_type: str) -> str:
    """Construct prompt based on split type."""
    context = sample.get("context", "")
    prompt = sample["prompt"]

    if split_type == "multiple_choice":
        options = sample["options"]
        options_text = "\nOptions:\n" + "\n".join([f"{i+1}. {opt}" for i, opt in enumerate(options)])
        return f"Context: {context}\n\nQuestion: {prompt}{options_text}\n\nProvide the correct answer from the options above."
    else:
        return f"Context: {context}\n\nQuestion: {prompt}\n\nProvide your answer."

def is_correct_response(response: str, targets: List[str]) -> bool:
    """Check if response matches any of the target answers."""
    response = response.strip().lower()
    return any(target.strip().lower() == response for target in targets)

async def generate_response(prompt: str, approach: str) -> Dict[str, Any]:
    """Generate a response using the specified approach."""
    if approach == "none":
        client = AsyncOpenAI()
        response = await client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[{"role": "user", "content": prompt}],
        )
        return {
            "content": response.choices[0].message.content,
            "tokens": response.usage.completion_tokens,
        }
    else:
        client = AsyncOpenAI(api_key="none", base_url="http://localhost:8000/v1")
        response = await client.chat.completions.create(
            model=f"{approach}-gpt-4o-mini",
            messages=[{"role": "user", "content": prompt}],
        )
        return {
            "content": response.choices[0].message.content,
            "tokens": response.usage.completion_tokens,
        }

def rank_responses(responses: List[Dict[str, Any]], targets: List[str]) -> List[int]:
    """Rank responses based on correctness and token efficiency."""
    # Create tuples of (index, is_correct, tokens) for sorting
    ranked_data = []
    for i, response in enumerate(responses):
        is_correct = is_correct_response(response["content"], targets)
        ranked_data.append((i, is_correct, response["tokens"]))

    # Sort by correctness (True first) and then by tokens (ascending)
    ranked_data.sort(key=lambda x: (-int(x[1]), x[2]))

    # Extract indices for final ranking
    return [idx for idx, _, _ in ranked_data]
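    # Worked example: with targets ["4"] and responses of ("4", 120 tokens),
    # ("5", 80 tokens), and ("4", 60 tokens), the sort keys are
    # (-1, 120), (0, 80), and (-1, 60), so the returned ranking is [2, 0, 1].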

async def process_sample(sample: Dict[str, Any], split_type: str) -> Dict[str, Any]:
    """Process a single sample from the dataset."""
    prompt = construct_prompt(sample, split_type)
    results = []

    # Generate responses for each approach
    for approach in APPROACHES:
        response = await generate_response(prompt, approach)
        results.append({"approach": approach, **response})

    # Rank the responses based on correctness and token efficiency
    rankings = rank_responses(results, sample["target"])

    # Add rankings to results
    for rank, idx in enumerate(rankings):
        results[idx]["rank"] = rank

    return {
        "prompt": prompt,
        "results": results,
    }

async def generate_dataset(num_samples: int, output_file: str):
    """Generate the dataset and save it to a JSONL file."""
    with open(output_file, "w") as f:
        for config, split_type in DATASET_CONFIGS:
            print(f"Processing {config} - {split_type}")
            dataset = load_dataset("MixEval/MixEval", config, split=split_type)

            # Calculate samples per configuration
            samples_per_config = max(1, num_samples // len(DATASET_CONFIGS))

            for sample in tqdm(dataset.select(range(samples_per_config)),
                               total=samples_per_config,
                               desc=f"{config}-{split_type}"):
                try:
                    result = await process_sample(sample, split_type)
                    f.write(json.dumps(result) + "\n")
                except Exception as e:
                    print(f"Error processing sample: {str(e)}")

def main():
    parser = argparse.ArgumentParser(description="Generate OptILM Ground Truth dataset")
    parser.add_argument("--num_samples", type=int, default=100,
                        help="Total number of samples to process (divided among configurations)")
    parser.add_argument("--output_file", type=str,
                        default="optillm_ground_truth_dataset.jsonl",
                        help="Output file path")
    args = parser.parse_args()

    asyncio.run(generate_dataset(args.num_samples, args.output_file))
    print(f"Dataset generated and saved to {args.output_file}")

if __name__ == "__main__":
    main()
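A minimal sketch of consuming the generated file (record fields as written by `process_sample` above; the default output path is assumed):

```python
import json

# Each record: {"prompt": ..., "results": [{"approach", "content", "tokens", "rank"}, ...]};
# rank 0 marks the preferred approach (correct, fewest completion tokens).
with open("optillm_ground_truth_dataset.jsonl") as f:
    for line in f:
        record = json.loads(line)
        best = min(record["results"], key=lambda r: r["rank"])
        print(best["approach"], best["tokens"])
```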
54 changes: 49 additions & 5 deletions scripts/train_optillm_classifier.py
@@ -57,7 +57,7 @@ def __getitem__(self, idx):
        }

def load_and_preprocess_data(tokenizer):
-    dataset = load_dataset('json', data_files='optillm_dataset.jsonl')
+    dataset = load_dataset('json', data_files='optillm_combined_dataset.jsonl')

    data_items = []

@@ -290,11 +290,54 @@ def main(args):
    best_model.eval()

    test_prompts = [
        # Linear Programming (likely MCTS or Z3)
        "Maximize x + y subject to: x + 2y <= 10, x >= 0, y >= 0",
        # Graph Theory (likely MCTS or RTO)
        "Find the shortest path between nodes A and B in the given graph",
        # Recursive Problem (likely MOA or COT)
        "Solve the Tower of Hanoi problem with 4 disks",
        # Number Theory (likely NONE or Z3)
        "Determine if the given number is prime",
        # Combinatorics (likely MCTS or BON)
        "Find all possible combinations of coins that sum up to $1",
        # Symbolic Mathematics (likely Z3 or LEAP)
        "Solve the equation: 2x^3 - 5x^2 + 3x - 7 = 0",
        # Natural Language Processing (likely PVG or SELF_CONSISTENCY)
        "Summarize the main points of the given article in three sentences",
        # Computer Vision (likely RSTAR or PVG)
        "Describe the contents of the image, including any text present",
        # Game Theory (likely MCTS or BON)
        "Find the Nash equilibrium for the prisoner's dilemma game",
        # Constraint Satisfaction (likely Z3 or PLANSEARCH)
        "Solve the Sudoku puzzle given the following initial configuration",
        # Optimization (likely MCTS or RSTAR)
        "Find the optimal route for a salesperson visiting 10 cities",
        # Logical Reasoning (likely COT_REFLECTION or SELF_CONSISTENCY)
        "If all A are B, and some B are C, what can we conclude about A and C?",
        # Time Series Analysis (likely RSTAR or PVG)
        "Predict the stock price for the next week given the past year's data",
        # Robotics (likely MCTS or RTO)
        "Plan a path for a robot to navigate through a room with obstacles",
        # Natural Language Understanding (likely PVG or LEAP)
        "Identify the sentiment and main topics in the following customer review",
        # Theorem Proving (likely Z3 or COT_REFLECTION)
        "Prove that the square root of 2 is irrational",
        # Reinforcement Learning (likely MCTS or RSTAR)
        "Design a policy for an agent to maximize its score in a given game environment",
        # Information Retrieval (likely PVG or SELF_CONSISTENCY)
        "Find the most relevant documents in the corpus for the given query",
        # Cryptography (likely Z3 or LEAP)
        "Decrypt the following message encrypted with a simple substitution cipher",
        # Quantum Computing (likely NONE or Z3)
        "Simulate a quantum circuit with 3 qubits and measure the output",
        # Computer Graphics (likely RSTAR or PVG)
        "Generate a 3D model of a house based on the given floor plan",
        # Bioinformatics (likely Z3 or LEAP)
        "Find potential binding sites for a given protein sequence in a DNA strand",
        # Automated Reasoning (likely COT_REFLECTION or Z3)
        "Given a set of logical statements, determine if the conclusion follows",
        # Natural Language Generation (likely PVG or SELF_CONSISTENCY)
        "Write a short story in the style of Edgar Allan Poe about a haunted lighthouse"
    ]

    effort_levels = [0.0, 0.2, 0.5, 0.8, 1.0]
@@ -310,13 +353,14 @@
    parser = argparse.ArgumentParser(description="Train OptILM classifier")
    parser.add_argument("--model_name", type=str, default="google-bert/bert-large-uncased", help="Pretrained model name")
    parser.add_argument("--batch_size", type=int, default=4, help="Batch size for training")
-    parser.add_argument("--learning_rate", type=float, default=1e-6, help="Learning rate")
-    parser.add_argument("--num_epochs", type=int, default=10, help="Maximum number of training epochs")
+    parser.add_argument("--learning_rate", type=float, default=5e-7, help="Learning rate")
+    parser.add_argument("--num_epochs", type=int, default=20, help="Maximum number of training epochs")
    parser.add_argument("--push_to_hub", action="store_true", help="Push model to Hugging Face Hub")
    parser.add_argument("--hub_model_id", type=str, help="Model ID for Hugging Face Hub")
    parser.add_argument("--k_folds", type=int, default=5, help="Number of folds for cross-validation")
    parser.add_argument("--patience", type=int, default=3, help="Number of epochs to wait for improvement before early stopping")
+    parser.add_argument("--clip_value", type=float, default=1.0, help="Gradient clipping value")

    args = parser.parse_args()
    main(args)
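With these defaults, a typical run might look like `python scripts/train_optillm_classifier.py --push_to_hub --hub_model_id codelion/optillm-bert-uncased` (only flags defined above; the Hub ID is assumed from the model referenced in the README).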

2 changes: 1 addition & 1 deletion setup.py
@@ -2,7 +2,7 @@

setup(
    name="optillm",
-    version="0.0.8",
+    version="0.0.9",
    packages=find_packages(),
    py_modules=['optillm'],
    package_data={
