This repository has been archived by the owner on Jan 7, 2025. It is now read-only.

Commit 08f172d: fix tests
PCSwingle committed Apr 11, 2024
1 parent faa8979

Showing 11 changed files with 18 additions and 67 deletions.

8 changes: 4 additions & 4 deletions benchmarks/benchmark_runner.py
@@ -13,6 +13,7 @@

 from openai.types.chat.completion_create_params import ResponseFormat
 from spice import SpiceMessage
+from spice.spice import get_model_from_name

 from benchmarks.arg_parser import common_benchmark_parser
 from benchmarks.benchmark_result import BenchmarkResult
@@ -22,7 +23,6 @@
 from benchmarks.swe_bench_runner import SWE_BENCH_SAMPLES_DIR, get_swe_samples
 from mentat.config import Config
 from mentat.git_handler import get_git_diff, get_mentat_branch, get_mentat_hexsha
-from mentat.llm_api_handler import model_context_size, prompt_tokens
 from mentat.sampler.sample import Sample
 from mentat.sampler.utils import setup_repo
 from mentat.session_context import SESSION_CONTEXT
@@ -45,20 +45,20 @@ def git_diff_from_comparison_commit(sample: Sample, comparison_commit: str) -> s

 async def grade(to_grade, prompt, model="gpt-4-1106-preview"):
     try:
+        llm_api_handler = SESSION_CONTEXT.get().llm_api_handler
         messages: List[SpiceMessage] = [
             {"role": "system", "content": prompt},
             {"role": "user", "content": to_grade},
         ]
-        tokens = prompt_tokens(messages, model)
-        max_tokens = model_context_size(model) - 1000  # Response buffer
+        tokens = llm_api_handler.spice.count_prompt_tokens(messages, model)
+        max_tokens = get_model_from_name(model).context_length - 1000  # Response buffer
         if tokens > max_tokens:
             print("Prompt too long! Truncating... (this may affect results)")
             tokens_to_remove = tokens - max_tokens
             chars_per_token = len(str(messages)) / tokens
             chars_to_remove = int(chars_per_token * tokens_to_remove)
             messages[1]["content"] = messages[1]["content"][:-chars_to_remove]

-        llm_api_handler = SESSION_CONTEXT.get().llm_api_handler
         llm_grade = await llm_api_handler.call_llm_api(messages, model, None, False, ResponseFormat(type="json_object"))
         content = llm_grade.text
         return json.loads(content)
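
Note: the grading helper now budgets tokens through Spice rather than Mentat's own helpers. Below is a minimal standalone sketch of the same pattern, assuming only the calls used in the diff above (`Spice()`, `count_prompt_tokens`, and `get_model_from_name(...).context_length`); the messages and model name are illustrative, not part of this commit:

```python
from spice import Spice
from spice.spice import get_model_from_name

spice = Spice()
model = "gpt-4-1106-preview"
messages = [
    {"role": "system", "content": "You are a grader."},     # illustrative prompt
    {"role": "user", "content": "Submission to grade..."},  # illustrative content
]

# Count tokens for the whole chat prompt, then budget against the model's
# context window, leaving a 1000-token buffer for the response as above.
tokens = spice.count_prompt_tokens(messages, model)
max_tokens = get_model_from_name(model).context_length - 1000
if tokens > max_tokens:
    print("Prompt too long! Truncating... (this may affect results)")
```
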
5 changes: 2 additions & 3 deletions benchmarks/exercism_practice.py
@@ -114,13 +114,12 @@ async def run_exercise(problem_dir, language="python", max_iterations=2):
     messages = client.get_conversation().literal_messages
     await client.shutdown()
     passed = exercise_runner.passed()
-    cost_tracker = SESSION_CONTEXT.get().cost_tracker
     result = BenchmarkResult(
         iterations=iterations,
         passed=passed,
         name=exercise_runner.name,
-        tokens=cost_tracker.total_tokens,
-        cost=cost_tracker.total_cost,
+        tokens=None,
+        cost=SESSION_CONTEXT.get().llm_api_handler.spice.total_cost / 100,
         transcript={"id": problem_dir, "messages": messages},
     )
     if had_error:
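
Note: the per-session CostTracker is gone; the running cost now comes from the Spice client on the session's LLM API handler, and token totals are no longer tracked here (hence `tokens=None`). A small sketch of the read, assuming, as the `/ 100` above implies, that `total_cost` is tracked in cents:

```python
# Hypothetical sketch: read the aggregate cost (in dollars) after a run.
session_context = SESSION_CONTEXT.get()
cost_in_dollars = session_context.llm_api_handler.spice.total_cost / 100
```
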
5 changes: 2 additions & 3 deletions benchmarks/run_sample.py
@@ -69,7 +69,6 @@ async def run_sample(sample: Sample, cwd: Path | str | None = None, config: Conf
     await mentat.startup()
     session_context = SESSION_CONTEXT.get()
     conversation = session_context.conversation
-    cost_tracker = session_context.cost_tracker
     for msg in sample.message_history:
         if msg["role"] == "user":
             conversation.add_user_message(msg["content"])
@@ -118,8 +117,8 @@ async def run_sample(sample: Sample, cwd: Path | str | None = None, config: Conf
         "id": sample.id,
         "message_eval": message_eval,
         "diff_eval": diff_eval,
-        "cost": cost_tracker.total_cost,
-        "tokens": cost_tracker.total_tokens,
+        "cost": session_context.llm_api_handler.spice.total_cost / 100,
+        "tokens": None,
         "transcript": {
             "id": sample.id,
             "messages": transcript_messages,
8 changes: 0 additions & 8 deletions docs/source/developer/mentat.rst
@@ -98,14 +98,6 @@ mentat.conversation module
    :undoc-members:
    :show-inheritance:

-mentat.cost\_tracker module
----------------------------
-
-.. automodule:: mentat.cost_tracker
-   :members:
-   :undoc-members:
-   :show-inheritance:
-
 mentat.ctags module
 -------------------

4 changes: 1 addition & 3 deletions scripts/git_log_to_transcripts.py
@@ -15,7 +15,6 @@
 from mentat.code_context import CodeContext
 from mentat.code_file_manager import CodeFileManager
 from mentat.config import Config
-from mentat.llm_api import CostTracker, count_tokens
 from mentat.parsers.git_parser import GitParser
 from mentat.sampler.utils import clone_repo
 from mentat.session_context import SESSION_CONTEXT, SessionContext
@@ -126,7 +125,7 @@ async def translate_commits_to_transcripts(repo, count=10):
         # Necessary for CodeContext to work
         repo.git.checkout(commit.parents[0].hexsha)
         shown = subprocess.check_output(["git", "show", sha, "-m", "--first-parent"]).decode("utf-8")
-        if count_tokens(shown, "gpt-4") > 6000:
+        if session_context.llm_api_handler.spice.count_tokens(shown, "gpt-4") > 6000:
             print("Skipping because too long")
             continue

@@ -219,7 +218,6 @@ async def translate_commits_to_transcripts(repo, count=10):
     code_context = CodeContext(stream, os.getcwd())
     session_context = SessionContext(
         stream,
-        CostTracker(),
         Path.cwd(),
         config,
         code_context,
8 changes: 4 additions & 4 deletions scripts/sampler/__main__.py
@@ -13,9 +13,9 @@
 from add_context import add_context
 from finetune import generate_finetune
 from remove_context import remove_context
+from spice import Spice
 from validate import validate_sample

-from mentat.llm_api_handler import count_tokens, prompt_tokens
 from mentat.sampler.sample import Sample
 from mentat.utils import mentat_dir_path

@@ -94,11 +94,11 @@ async def main():
         elif args.finetune:
             try:
                 example = await generate_finetune(sample)
-                # Toktoken only includes encoding for openAI models, so this isn't always correct
+                spice = Spice()
                 if "messages" in example:
-                    tokens = prompt_tokens(example["messages"], "gpt-4")
+                    tokens = spice.count_prompt_tokens(example["messages"], "gpt-4")
                 elif "text" in example:
-                    tokens = count_tokens(example["text"], "gpt-4", full_message=False)
+                    tokens = spice.count_tokens(example["text"], "gpt-4", is_message=False)
                 example["tokens"] = tokens
                 print("Generated finetune example" f" {sample.id[:8]} ({example['tokens']} tokens)")
                 logs.append(example)
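
Note: both Spice counting helpers appear in this hunk: `count_prompt_tokens` for chat-style message lists and `count_tokens` for raw text, with `is_message=False` presumably skipping chat-message framing overhead. A brief sketch under those assumptions, with illustrative inputs:

```python
from spice import Spice

spice = Spice()

# Chat-format example: count tokens across a whole prompt.
chat_tokens = spice.count_prompt_tokens(
    [{"role": "user", "content": "Refactor this function."}], "gpt-4"
)

# Raw-text example: count tokens of a plain string, without message framing.
text_tokens = spice.count_tokens("def add(a, b):\n    return a + b\n", "gpt-4", is_message=False)
```
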
4 changes: 2 additions & 2 deletions scripts/select_git_transcripts.py
@@ -4,7 +4,7 @@
 import os
 from pathlib import Path

-from mentat.llm_api import count_tokens, model_context_size
+from spice.spice import get_model_from_name


 def select_transcripts(
@@ -33,7 +33,7 @@ def select_transcripts(
             continue
         if skip_config and info["configuration"]:
             continue
-        if count_tokens(json.dumps(info["mocked_conversation"]), model) > model_context_size(model):
+        if count_tokens(json.dumps(info["mocked_conversation"]), model) > get_model_from_name(model).context_length:
             continue
         transcripts.append(info["mocked_conversation"])

3 changes: 1 addition & 2 deletions tests/code_context_test.py
@@ -12,7 +12,6 @@
 from mentat.git_handler import get_non_gitignored_files
 from mentat.include_files import is_file_text_encoded
 from mentat.interval import Interval
-from mentat.llm_api_handler import count_tokens
 from tests.conftest import run_git_command


@@ -211,7 +210,7 @@ def func_4(string):

     async def _count_max_tokens_where(tokens_used: int) -> int:
         code_message = await code_context.get_code_message(tokens_used, prompt="prompt")
-        return count_tokens(code_message, "gpt-4", full_message=True)
+        return mock_session_context.llm_api_handler.spice.count_tokens(code_message, "gpt-4", is_message=True)

     assert await _count_max_tokens_where(0) == 89  # Code

4 changes: 2 additions & 2 deletions tests/commands_test.py
@@ -435,15 +435,15 @@ async def test_screenshot_command(mocker):
     stream = session_context.stream
     conversation = session_context.conversation

-    assert config.model != "gpt-4-vision-preview"
+    assert config.model != "gpt-4-turbo"

     mock_vision_manager.screenshot.return_value = "fake_image_data"

     screenshot_command = Command.create_command("screenshot")
     await screenshot_command.apply("fake_path")

     mock_vision_manager.screenshot.assert_called_once_with("fake_path")
-    assert config.model == "gpt-4-vision-preview"
+    assert config.model == "gpt-4-turbo"
     assert stream.messages[-1].data == "Screenshot taken for: fake_path."
     assert conversation._messages[-1] == {
         "role": "user",
3 changes: 0 additions & 3 deletions tests/conftest.py
@@ -180,8 +180,6 @@ def mock_session_context(temp_testbed):
     """
     stream = SessionStream()

-    cost_tracker = CostTracker()
-
     config = Config()

     llm_api_handler = LlmApiHandler()
@@ -203,7 +201,6 @@
         Path.cwd(),
         stream,
         llm_api_handler,
-        cost_tracker,
         config,
         code_context,
         code_file_manager,
33 changes: 0 additions & 33 deletions tests/llm_api_handler_test.py

This file was deleted.
