Skip to content

Commit

Permalink
fix: Pydantic downgraded to v1
Browse files Browse the repository at this point in the history
  • Loading branch information
RezaRahemtola committed Oct 18, 2024
1 parent 1e84b5e commit ab33714
Show file tree
Hide file tree
Showing 8 changed files with 1,112 additions and 1,591 deletions.
7 changes: 2 additions & 5 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ jobs:
run: pipx install poetry
- uses: actions/setup-python@v5
with:
python-version: '3.12'
python-version: '3.11'
cache: 'poetry'
- name: Install dependencies
run: poetry install
Expand All @@ -33,11 +33,8 @@ jobs:
run: pipx install poetry
- uses: actions/setup-python@v5
with:
python-version: '3.12'
python-version: '3.11'
- name: Install dependencies
run: pip install ruff
- name: Run Ruff
run: ruff check --output-format=github



9 changes: 5 additions & 4 deletions libertai_agents/agents.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import json
from http import HTTPStatus
from typing import Callable, Awaitable, Any, AsyncIterable

Expand Down Expand Up @@ -112,7 +113,7 @@ async def __api_generate_answer(self, messages: list[Message], stream: bool = Fa
only_final_answer: bool = True):
"""
Generate an answer based on an existing conversation.
The response messages can be streamed or sent in a single block
The response messages can be streamed or sent in a single block.
"""
if stream:
return StreamingResponse(
Expand All @@ -134,7 +135,7 @@ async def __dump_api_generate_streamed_answer(self, messages: list[Message], onl
:return: Iterable of each messages from generate_answer dumped to JSON
"""
async for message in self.generate_answer(messages, only_final_answer=only_final_answer):
yield message.model_dump_json()
yield json.dumps(message.dict(), indent=4)

async def __call_model(self, session: ClientSession, prompt: str) -> str | None:
"""
Expand All @@ -144,9 +145,9 @@ async def __call_model(self, session: ClientSession, prompt: str) -> str | None:
:param prompt: Prompt to give to the model
:return: String response (if no error)
"""
params = LlamaCppParams(prompt=prompt, **self.llamacpp_params.model_dump())
params = LlamaCppParams(prompt=prompt, **self.llamacpp_params.dict())

async with session.post(self.model.vm_url, json=params.model_dump()) as response:
async with session.post(self.model.vm_url, json=params.dict()) as response:
# TODO: handle errors and retries
if response.status == HTTPStatus.OK:
response_data = await response.json()
Expand Down
4 changes: 1 addition & 3 deletions libertai_agents/interfaces/models.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
from pydantic import BaseModel

from libertai_agents.models.base import ModelId


class ModelInformation(BaseModel):
id: ModelId
id: str
context_length: int
2 changes: 1 addition & 1 deletion libertai_agents/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def generate_prompt(self, messages: list[Message], tools: list, system_prompt: s
"""
system_messages = [Message(role=MessageRoleEnum.system,
content=system_prompt)] if self.include_system_message and system_prompt is not None else []
raw_messages = list(map(lambda x: x.model_dump(), messages))
raw_messages = list(map(lambda x: x.dict(), messages))

for i in range(len(raw_messages)):
included_messages: list = system_messages + raw_messages[i:]
Expand Down
2 changes: 1 addition & 1 deletion libertai_agents/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,4 +47,4 @@ def get_model(model_id: ModelId, hf_token: str | None = None) -> Model:
if hf_token is not None:
login(hf_token)

return model_configuration.constructor(model_id=model_id, **model_configuration.model_dump(exclude={'constructor'}))
return model_configuration.constructor(model_id=model_id, **model_configuration.dict(exclude={'constructor'}))
20 changes: 11 additions & 9 deletions main.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
import asyncio

from libertai_agents.agents import ChatAgent
from libertai_agents.interfaces.messages import Message, MessageRoleEnum
from libertai_agents.models import get_model


Expand All @@ -17,14 +20,13 @@ async def get_current_temperature(location: str, unit: str) -> float:

agent = ChatAgent(model=get_model("NousResearch/Hermes-2-Pro-Llama-3-8B"),
system_prompt="You are a helpful assistant",
tools=[get_current_temperature])
tools=[get_current_temperature], expose_api=False)


async def main():
async for message in agent.generate_answer(
[Message(role=MessageRoleEnum.user, content="What is the temperature in Paris and in Lyon?")]):
print(message)

app = agent.app

# async def main():
# async for message in agent.generate_answer(
# [Message(role=MessageRoleEnum.user, content="What is the temperature in Paris and in Lyon?")]):
# print(message)
#
#
# asyncio.run(main())
asyncio.run(main())
2,643 changes: 1,082 additions & 1,561 deletions poetry.lock

Large diffs are not rendered by default.

16 changes: 9 additions & 7 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
[tool.poetry]
name = "libertai-agents"
version = "0.0.6"
version = "0.0.17"
description = ""
authors = ["Reza Rahemtola <[email protected]>"]
authors = ["LibertAI.io team <[email protected]>"]
readme = "README.md"
homepage = "https://libertai.io"
repository = "https://github.com/LibertAI/libertai-agents"
Expand All @@ -11,15 +11,17 @@ classifiers = [
"Operating System :: OS Independent",
"Intended Audience :: Developers",
"Development Status :: 2 - Pre-Alpha",
"Topic :: Scientific/Engineering :: Artificial Intelligence"
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"Programming Language :: Python :: 3.11"
]

[tool.poetry.dependencies]
python = "^3.12"
python = "~3.11"
transformers = "^4.44.0"
pydantic = "^2.8.2"
aiohttp = "^3.10.3"
fastapi = { extras = ["standard"], version = "^0.112.2" }
pydantic = "^1.10"
aiohttp = "^3.10"
fastapi = "^0.112"
jinja2 = "^3.1.4"

[tool.poetry.group.dev.dependencies]
mypy = "^1.11.1"
Expand Down

0 comments on commit ab33714

Please sign in to comment.