From 70255a078dff1a7e93aefcddca8e1aae090c4c09 Mon Sep 17 00:00:00 2001 From: Reza Rahemtola Date: Fri, 6 Dec 2024 16:20:50 +0900 Subject: [PATCH] style: fix ruff issues and enforce in CI --- libertai_agents/libertai_agents/__init__.py | 2 +- libertai_agents/libertai_agents/agents.py | 4 ++-- libertai_agents/libertai_agents/interfaces/tools.py | 2 +- libertai_agents/libertai_agents/models/base.py | 6 +++--- libertai_agents/libertai_agents/utils.py | 2 +- libertai_agents/pyproject.toml | 4 ++++ 6 files changed, 12 insertions(+), 8 deletions(-) diff --git a/libertai_agents/libertai_agents/__init__.py b/libertai_agents/libertai_agents/__init__.py index 25bcfc6..5d8b595 100644 --- a/libertai_agents/libertai_agents/__init__.py +++ b/libertai_agents/libertai_agents/__init__.py @@ -1,4 +1,4 @@ import logging # Disables the error about frameworks not installed -logging.getLogger("transformers").disabled = True \ No newline at end of file +logging.getLogger("transformers").disabled = True diff --git a/libertai_agents/libertai_agents/agents.py b/libertai_agents/libertai_agents/agents.py index 27876d7..2ac94c7 100644 --- a/libertai_agents/libertai_agents/agents.py +++ b/libertai_agents/libertai_agents/agents.py @@ -2,7 +2,7 @@ import inspect import json from http import HTTPStatus -from typing import Awaitable, Any, AsyncIterable +from typing import Any, AsyncIterable, Awaitable import aiohttp from aiohttp import ClientSession @@ -56,7 +56,7 @@ def __init__( if tools is None: tools = [] - if len(set(map(lambda x: x.name, tools))) != len(tools): + if len({x.name for x in tools}) != len(tools): raise ValueError("Tool functions must have different names") self.model = model self.system_prompt = system_prompt diff --git a/libertai_agents/libertai_agents/interfaces/tools.py b/libertai_agents/libertai_agents/interfaces/tools.py index 7791732..a5105e6 100644 --- a/libertai_agents/libertai_agents/interfaces/tools.py +++ b/libertai_agents/libertai_agents/interfaces/tools.py @@ -1,4 +1,4 @@ -from typing import Callable, Any, TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Callable from pydantic.json_schema import GenerateJsonSchema, JsonSchemaMode, JsonSchemaValue from pydantic.v1 import BaseModel diff --git a/libertai_agents/libertai_agents/models/base.py b/libertai_agents/libertai_agents/models/base.py index bdc747c..4b721d1 100644 --- a/libertai_agents/libertai_agents/models/base.py +++ b/libertai_agents/libertai_agents/models/base.py @@ -3,8 +3,8 @@ from libertai_agents.interfaces.messages import ( Message, - ToolCallFunction, MessageRoleEnum, + ToolCallFunction, ) from libertai_agents.interfaces.tools import Tool @@ -74,13 +74,13 @@ def generate_prompt( if self.include_system_message and system_prompt is not None else [] ) - raw_messages = list(map(lambda x: x.dict(), messages)) + raw_messages = [x.dict() for x in messages] for i in range(len(raw_messages)): included_messages: list = system_messages + raw_messages[i:] prompt = self.tokenizer.apply_chat_template( conversation=included_messages, - tools=list(map(lambda x: x.args_schema, tools)), + tools=[x.args_schema for x in tools], tokenize=False, add_generation_prompt=True, ) diff --git a/libertai_agents/libertai_agents/utils.py b/libertai_agents/libertai_agents/utils.py index 4661912..19a89c9 100644 --- a/libertai_agents/libertai_agents/utils.py +++ b/libertai_agents/libertai_agents/utils.py @@ -1,4 +1,4 @@ -from typing import TypeVar, Callable +from typing import Callable, TypeVar T = TypeVar("T") diff --git a/libertai_agents/pyproject.toml b/libertai_agents/pyproject.toml index 21cb027..4e44d10 100644 --- a/libertai_agents/pyproject.toml +++ b/libertai_agents/pyproject.toml @@ -33,6 +33,10 @@ pytest-cov = "^6.0.0" [tool.poetry.extras] langchain = ["langchain-community"] +[tool.ruff] +lint.select = ["C", "E", "F", "I", "W"] +lint.ignore = ["E501"] + [tool.pytest.ini_options] addopts = "--cov=libertai_agents" testpaths = ["tests"]