Skip to content

Commit

Permalink
add a test mocking ollama server
Browse files Browse the repository at this point in the history
  • Loading branch information
ricopinazo committed May 6, 2024
1 parent 37cacc7 commit e0a0d26
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 30 deletions.
2 changes: 1 addition & 1 deletion src/comai/chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def attatch_history(

def create_chain_stream(settings: Settings, context: Context):
prompt = create_prompt(context)
model = ChatOllama(model=settings.model)
model = ChatOllama(model=settings.model, temperature=0)
base_chain = prompt | model

if context.session_id is not None:
Expand Down
40 changes: 11 additions & 29 deletions tests/test_comai.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import os
from typer.testing import CliRunner

from langchain_core.messages import AIMessage, AIMessageChunk
from langchain_community.chat_models import ChatOllama

# from dotenv import load_dotenv
from comai import cli, context, __version__

Expand All @@ -12,26 +15,16 @@
runner = CliRunner()


# def test_invalid_api_key():
# config.delete_api_key()

# result = runner.invoke(cli.app, ["show", "files"], input="bad-api-key\n")
# assert result.exit_code != 0
# assert "API key not valid" in result.stdout


# def test_installation_flow():
# config.delete_api_key()
def test_normal_flow(monkeypatch):
def mock_stream(*args, **kwargs):
for token in ["COMMAND", " ls", " END"]:
yield AIMessageChunk(content=token)

# result = runner.invoke(cli.app, ["show", "files"], input=f"{api_key}\n\n")
# assert result.exit_code == 0
# assert "Please enter your OpenAI API key:" in result.stdout
# assert "ls" in result.stdout
monkeypatch.setattr(ChatOllama, "stream", mock_stream)

# result = runner.invoke(cli.app, ["show", "files"], input="\n")
# assert result.exit_code == 0
# assert "Please enter your OpenAI API key:" not in result.stdout
# assert "ls" in result.stdout
result = runner.invoke(cli.app, ["show", "files"])
assert result.exit_code == 0
assert "ls" in result.stdout


def test_version():
Expand All @@ -43,14 +36,3 @@ def test_version():
def test_missing_instruction():
result = runner.invoke(cli.app, [])
assert result.exit_code != 0


# TODO: create a mock ollama server that always returns "COMMAND ls END"
# def test_translation():
# ctx = context.get_context()
# history = History.create_local()
# history.append_user_message("show files")

# answer = translation.request_command(history, ctx, api_key)
# command = translation.filter_assistant_message(answer)
# assert "".join(command) == "ls"

0 comments on commit e0a0d26

Please sign in to comment.