Skip to content

Commit

Permalink
Merge pull request #12 from andrewyng/add-tests-for-mistral-provider
Browse files Browse the repository at this point in the history
  • Loading branch information
ksolo authored Jul 16, 2024
2 parents 5527a07 + 4b1e40a commit 5521ac6
Showing 1 changed file with 50 additions and 0 deletions.
50 changes: 50 additions & 0 deletions tests/providers/test_mistral_interface.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import pytest
from unittest.mock import patch, MagicMock

from mistralai.models.chat_completion import ChatMessage

from aimodels.providers.mistral_interface import MistralInterface


@pytest.fixture(autouse=True)
def set_api_key_env_var(monkeypatch):
"""Fixture to set environment variables for tests."""
monkeypatch.setenv("MISTRAL_API_KEY", "test-api-key")


def test_mistral_interface():
"""High-level test that the interface is initialized and chat completions are requested successfully."""

user_greeting = "Hello!"
message_history = [{"role": "user", "content": user_greeting}]
selected_model = "our-favorite-model"
chosen_temperature = 0.75
response_text_content = "mocked-text-response-from-model"

interface = MistralInterface()
mock_response = MagicMock()
mock_response.choices = [MagicMock()]
mock_response.choices[0].message = MagicMock()
mock_response.choices[0].message.content = response_text_content

with patch.object(
interface.mistral_client, "chat", return_value=mock_response
) as mock_create:
response = interface.chat_completion_create(
messages=message_history,
model=selected_model,
temperature=chosen_temperature,
)

transformed_message_history = [
ChatMessage(role=message["role"], content=message["content"])
for message in message_history
]

mock_create.assert_called_with(
messages=transformed_message_history,
model=selected_model,
temperature=chosen_temperature,
)

assert response.choices[0].message.content == response_text_content

0 comments on commit 5521ac6

Please sign in to comment.