
Merge pull request #129 from codelion/fix-parsing-tagged-conv-bug
Fix parsing tagged conv bug
codelion authored Jan 11, 2025
2 parents 7a23694 + f96b435 commit 6b803e2
Showing 3 changed files with 67 additions and 17 deletions.
29 changes: 22 additions & 7 deletions optillm.py
@@ -416,16 +416,23 @@ def parse_conversation(messages):

def tagged_conversation_to_messages(response_text):
"""Convert a tagged conversation string or list of strings into a list of messages.
If the input doesn't contain User:/Assistant: tags, return it as is.
Args:
response_text: Either a string containing "User:" and "Assistant:" tags,
or a list of such strings.
Returns:
If input is a string: A list of message dictionaries.
If input is a list: A list of lists of message dictionaries.
If input has tags: A list of message dictionaries.
If input has no tags: The original input.
"""
def has_conversation_tags(text):
return "User:" in text or "Assistant:" in text

def process_single_response(text):
if not has_conversation_tags(text):
return text

messages = []
# Split on "User:" or "Assistant:" while keeping the delimiter
parts = re.split(r'(?=(User:|Assistant:))', text.strip())
@@ -447,7 +454,11 @@ def process_single_response(text):
return messages

if isinstance(response_text, list):
return [process_single_response(text) for text in response_text]
processed = [process_single_response(text) for text in response_text]
# If none of the responses had tags, return original list
if all(isinstance(p, str) for p in processed):
return response_text
return processed
else:
return process_single_response(response_text)
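
For illustration only (it is not part of the diff), a minimal, self-contained sketch of the tag-parsing behaviour the new has_conversation_tags guard preserves. The helper name to_messages and the sample strings are hypothetical; the real function is tagged_conversation_to_messages above.

    import re

    def to_messages(text):
        # No tags: return the input unchanged, which is what the fix relies on.
        if "User:" not in text and "Assistant:" not in text:
            return text
        messages = []
        # Split before each "User:"/"Assistant:" tag while keeping the tag itself.
        for part in re.split(r'(?=User:|Assistant:)', text.strip()):
            part = part.strip()
            if part.startswith("User:"):
                messages.append({"role": "user", "content": part[len("User:"):].strip()})
            elif part.startswith("Assistant:"):
                messages.append({"role": "assistant", "content": part[len("Assistant:"):].strip()})
        return messages

    to_messages("User: Hi\nAssistant: Hello!")
    # -> [{'role': 'user', 'content': 'Hi'}, {'role': 'assistant', 'content': 'Hello!'}]
    to_messages("Just a plain completion")
    # -> 'Just a plain completion' (returned as-is rather than as an empty message list)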

@@ -555,14 +566,18 @@ def proxy():
except Exception as e:
logger.error(f"Error processing request: {str(e)}")
return jsonify({"error": str(e)}), 500

# Convert tagged conversation to messages format if needed
if isinstance(response, list):
response = [msg[-1]['content'] if isinstance(msg, list) and msg else msg
for msg in tagged_conversation_to_messages(response)]
processed_response = tagged_conversation_to_messages(response)
# If processed_response is a list of message lists, extract last message content
if processed_response != response: # Only process if format changed
response = [msg[-1]['content'] if isinstance(msg, list) and msg else msg
for msg in processed_response]
# Otherwise keep original response
else:
messages = tagged_conversation_to_messages(response)
if messages: # Only take the last message if we have any
if isinstance(messages, list) and messages: # Only process if format changed
response = messages[-1]['content']

if stream:
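
For illustration only (not part of the diff), a sketch of the unwrapping guard this hunk adds to proxy(): the response is reduced to the last turn's content only when the conversion helper actually changed the format, otherwise it is passed through untouched. unwrap_last_turn and the stand-in convert lambda are hypothetical; the real helper is tagged_conversation_to_messages.

    def unwrap_last_turn(response, convert):
        processed = convert(response)
        if isinstance(processed, list) and processed and isinstance(processed[-1], dict):
            # Tags were found and the format changed: keep only the final turn.
            return processed[-1]["content"]
        # No tags: keep the original response.
        return response

    # Hypothetical stand-in for tagged_conversation_to_messages:
    convert = lambda t: ([{"role": "assistant", "content": t.split("Assistant:")[-1].strip()}]
                         if "Assistant:" in t else t)
    unwrap_last_turn("User: 2+2?\nAssistant: 4", convert)   # -> '4'
    unwrap_last_turn("The answer is 4", convert)            # -> 'The answer is 4'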
53 changes: 44 additions & 9 deletions scripts/eval_aime_benchmark.py
@@ -4,7 +4,7 @@
import logging
import re
import time
from typing import List, Dict, Tuple, Optional
from typing import List, Dict, Tuple, Optional, Union
from datetime import datetime
from openai import OpenAI
from datasets import load_dataset
@@ -89,9 +89,17 @@ def extract_answer(response: str) -> Optional[int]:

return None

def get_llm_response(problem: str, model: str) -> str:
def get_llm_response(problem: str, model: str) -> Union[str, List[Dict]]:
"""
Get response from the LLM for a given problem.
If multiple choices are returned, formats them as attempt dictionaries.
Args:
problem (str): The problem text
model (str): The model identifier
Returns:
Union[str, List[Dict]]: Either a string response or list of attempt dictionaries
"""
try:
response = client.with_options(timeout=1000.0).chat.completions.create(
@@ -101,7 +109,23 @@ def get_llm_response(problem: str, model: str) -> str:
],
max_tokens=8192,
)

# If there's more than one choice, format as attempts
if len(response.choices) > 1:
attempts = []
for i, choice in enumerate(response.choices):
response_text = choice.message.content.strip()
predicted_answer = extract_answer(response_text)
attempts.append({
"attempt_number": i + 1,
"response": response_text,
"predicted_answer": predicted_answer
})
return attempts

# If single choice, return as before
return response.choices[0].message.content.strip()

except Exception as e:
logger.error(f"Error getting LLM response: {e}")
return ""
@@ -119,14 +143,25 @@ def make_n_attempts(problem: str, model: str, n: int) -> List[Dict]:
List[Dict]: List of dictionaries containing response and predicted answer for each attempt
"""
attempts = []
for i in range(n):
remaining_attempts = n

while remaining_attempts > 0:
response = get_llm_response(problem, model)
predicted_answer = extract_answer(response)
attempts.append({
"attempt_number": i + 1,
"response": response,
"predicted_answer": predicted_answer
})

# If response is already formatted as attempts
if isinstance(response, list):
attempts.extend(response)
remaining_attempts = n - len(attempts)
else:
# Process single response as before
predicted_answer = extract_answer(response)
attempts.append({
"attempt_number": len(attempts) + 1,
"response": response,
"predicted_answer": predicted_answer
})
remaining_attempts -= 1

return attempts

def evaluate_pass_at_n(attempts: List[Dict], correct_answer: int) -> Tuple[bool, Optional[int]]:
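
For illustration only (not part of the diff), the shape of the new attempt-collection loop, assuming the response callable may return either a plain string or a ready-made list of attempt dictionaries (as get_llm_response now does for multi-choice replies). collect_attempts and the fake iterator are hypothetical.

    def collect_attempts(get_response, extract, n):
        attempts = []
        remaining = n
        while remaining > 0:
            response = get_response()
            if isinstance(response, list):
                # Already formatted as attempts: take them all and recount.
                attempts.extend(response)
                remaining = n - len(attempts)
            else:
                attempts.append({"attempt_number": len(attempts) + 1,
                                 "response": response,
                                 "predicted_answer": extract(response)})
                remaining -= 1
        return attempts

    fake = iter([[{"attempt_number": 1, "response": "204", "predicted_answer": 204},
                  {"attempt_number": 2, "response": "190", "predicted_answer": 190}],
                 "Answer: 204"])
    collect_attempts(lambda: next(fake), lambda t: 204, 3)   # yields 3 attempts in total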
2 changes: 1 addition & 1 deletion setup.py
@@ -2,7 +2,7 @@

setup(
name="optillm",
version="0.0.23",
version="0.0.24",
packages=find_packages(),
py_modules=['optillm'],
package_data={
