Basic tests #15

Merged: 3 commits, Dec 6, 2024
10 changes: 10 additions & 0 deletions codecov.yaml
@@ -0,0 +1,10 @@
coverage:
  status:
    patch:
      default:
        threshold: 5
        target: 0%
    project:
      default:
        threshold: 5
        target: 60%
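Read with the usual Codecov semantics (target is the required coverage, threshold the allowed margin before the status fails), this appears to make the patch status effectively always pass (0% target on changed lines) while holding overall project coverage at roughly 60%, with a 5% margin on both checks. That reading is an interpretation of the config, not something stated in the PR.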
8 changes: 5 additions & 3 deletions libertai_agents/libertai_agents/agents.py
@@ -179,7 +179,7 @@
         async for message in self.generate_answer(
             messages, only_final_answer=only_final_answer
         ):
-            yield json.dumps(message.dict(), indent=4)
+            yield json.dumps(message.model_dump(), indent=4)

     async def __call_model(self, session: ClientSession, prompt: str) -> str | None:
         """
@@ -189,9 +189,11 @@
         :param prompt: Prompt to give to the model
         :return: String response (if no error)
        """
-        params = LlamaCppParams(prompt=prompt, **self.llamacpp_params.dict())
+        params = LlamaCppParams(prompt=prompt, **self.llamacpp_params.model_dump())

-        async with session.post(self.model.vm_url, json=params.dict()) as response:
+        async with session.post(
+            self.model.vm_url, json=params.model_dump()
+        ) as response:
             # TODO: handle errors and retries
             if response.status == HTTPStatus.OK:
                 response_data = await response.json()

Codecov (codecov/patch) warning: added lines 182, 192 and 194 of agents.py are not covered by tests.
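The recurring change in this PR is the migration from Pydantic v1's .dict() to the v2 .model_dump() serialization method. A minimal sketch of the difference, using an illustrative model (the field names below are not the library's actual LlamaCppParams definition):

from pydantic import BaseModel


class ExampleParams(BaseModel):
    # Illustrative fields only, not the real LlamaCppParams schema
    prompt: str
    temperature: float = 0.7


params = ExampleParams(prompt="Hello")
print(params.model_dump())  # Pydantic v2: {'prompt': 'Hello', 'temperature': 0.7}
# params.dict() is the Pydantic v1 spelling; under v2 it still works but emits a deprecation warning.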
2 changes: 1 addition & 1 deletion libertai_agents/libertai_agents/models/base.py
@@ -74,7 +74,7 @@
             if self.include_system_message and system_prompt is not None
             else []
         )
-        raw_messages = [x.dict() for x in messages]
+        raw_messages = [x.model_dump() for x in messages]

         for i in range(len(raw_messages)):
             included_messages: list = system_messages + raw_messages[i:]

Codecov (codecov/patch) warning: added line 77 of base.py is not covered by tests.
2 changes: 1 addition & 1 deletion libertai_agents/libertai_agents/models/models.py
@@ -62,5 +62,5 @@ def get_model(
     )

     return full_config.constructor(
-        model_id=model_id, **configuration.dict(exclude={"constructor"})
+        model_id=model_id, **configuration.model_dump(exclude={"constructor"})
     )
1 change: 1 addition & 0 deletions libertai_agents/tests/conftest.py
@@ -0,0 +1 @@
from tests.fixtures.fixtures_tools import * # noqa: F401, F403
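The wildcard import re-exports every fixture defined in tests/fixtures/fixtures_tools.py, so pytest can inject them into any test in the suite without explicit imports; the noqa codes silence the unused-import (F401) and star-import (F403) warnings this pattern would otherwise trigger.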
1 change: 1 addition & 0 deletions libertai_agents/tests/fixtures/__init__.py
@@ -0,0 +1 @@

20 changes: 20 additions & 0 deletions libertai_agents/tests/fixtures/fixtures_tools.py
@@ -0,0 +1,20 @@
from typing import Callable

import pytest


@pytest.fixture()
def basic_function_for_tool() -> Callable:
    def get_current_temperature(location: str, unit: str) -> float:
        """
        Get the current temperature at a location.

        Args:
            location: The location to get the temperature for, in the format "City, Country"
            unit: The unit to return the temperature in. (choices: ["celsius", "fahrenheit"])
        Returns:
            The current temperature at the specified location in the specified units, as a float.
        """
        return 22.0  # A real function should probably actually get the temperature!

    return get_current_temperature
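The carefully formatted Args/Returns docstring (including the "(choices: [...])" hint) is likely deliberate: Hugging Face transformers' tool utilities parse exactly this docstring style when deriving a JSON schema from a Python function, and the TODO in test_function_tools.py below points at a transformers pull request, which suggests that parser is what ultimately consumes this fixture. This is an inference from the tests, not something stated in the PR.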
35 changes: 35 additions & 0 deletions libertai_agents/tests/test_agents.py
@@ -0,0 +1,35 @@
from fastapi import FastAPI

from libertai_agents.agents import ChatAgent
from libertai_agents.interfaces.tools import Tool
from libertai_agents.models import get_model
from libertai_agents.models.base import ModelId
from libertai_agents.models.models import ModelConfiguration

MODEL_ID: ModelId = "NousResearch/Hermes-3-Llama-3.1-8B"


def test_create_chat_agent_minimal():
    agent = ChatAgent(model=get_model(MODEL_ID))

    assert len(agent.tools) == 0
    assert agent.model.model_id == MODEL_ID
    assert isinstance(agent.app, FastAPI)


def test_create_chat_agent_with_config(basic_function_for_tool):
    context_length = 42

    agent = ChatAgent(
        model=get_model(
            MODEL_ID,
            custom_configuration=ModelConfiguration(
                vm_url="https://example.org", context_length=context_length
            ),
        ),
        tools=[Tool.from_function(basic_function_for_tool)],
        expose_api=False,
    )
    assert agent.model.context_length == context_length
    assert not hasattr(agent, "app")
    assert len(agent.tools) == 1
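Taken together, the two tests above double as a minimal usage example for the library. The sketch below reuses only the constructors and arguments exercised in these tests (the temperature tool mirrors the fixture); it is not taken from the project's documentation:

from libertai_agents.agents import ChatAgent
from libertai_agents.interfaces.tools import Tool
from libertai_agents.models import get_model


def get_current_temperature(location: str, unit: str) -> float:
    """
    Get the current temperature at a location.

    Args:
        location: The location to get the temperature for, in the format "City, Country"
        unit: The unit to return the temperature in. (choices: ["celsius", "fahrenheit"])
    Returns:
        The current temperature at the specified location in the specified units, as a float.
    """
    return 22.0


# Build an agent backed by a supported model id, with one tool and no exposed FastAPI app.
agent = ChatAgent(
    model=get_model("NousResearch/Hermes-3-Llama-3.1-8B"),
    tools=[Tool.from_function(get_current_temperature)],
    expose_api=False,
)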
14 changes: 14 additions & 0 deletions libertai_agents/tests/test_models.py
@@ -0,0 +1,14 @@
import pytest

from libertai_agents.models import Model, get_model


def test_get_model_basic():
    model = get_model("NousResearch/Hermes-3-Llama-3.1-8B")

    assert isinstance(model, Model)


def test_get_model_invalid_id():
    with pytest.raises(ValueError):
        _model = get_model(model_id="random-string")  # type: ignore
9 changes: 9 additions & 0 deletions libertai_agents/tests/tools/test_function_tools.py
@@ -0,0 +1,9 @@
from libertai_agents.interfaces.tools import Tool


def test_function_example_tool(basic_function_for_tool):
    libertai_tool = Tool.from_function(basic_function_for_tool)
    assert libertai_tool.name == basic_function_for_tool.__name__


# TODO: add test with Python 3.10+ union style when https://github.com/huggingface/transformers/pull/35103 merged + new release
6 changes: 6 additions & 0 deletions libertai_agents/tests/tools/test_langchain_tools.py
@@ -8,6 +8,7 @@
 from libertai_agents.interfaces.tools import Tool
 from libertai_agents.utils import find

+# TODO: uncomment when https://huggingface.co/spaces/lysandre/hf-model-downloads/discussions/1 is merged
 # def test_langchain_huggingface_hub_tool():
 #     # https://python.langchain.com/docs/integrations/tools/huggingface_tools/
 #     tool = load_huggingface_tool("lysandre/hf-model-downloads")
@@ -73,3 +74,8 @@ def test_langchain_requests_tools():
     libertai_tools = [Tool.from_langchain(t) for t in tools]
     get_tool = find(lambda t: t.name == "requests_get", libertai_tools)
     assert get_tool is not None
+
+
+# TODO: add tests for the following tools:
+# https://python.langchain.com/docs/integrations/tools/nasa/
+# https://python.langchain.com/docs/integrations/tools/openweathermap/