From 44935ff9618a2ef542bf630de6fdbe96bd11d971 Mon Sep 17 00:00:00 2001
From: NotBioWaste905
Date: Wed, 29 Jan 2025 12:34:58 +0300
Subject: [PATCH] Added docstring for get_langchain_context, lint

---
 chatsky/llm/filters.py           |  4 +++-
 chatsky/llm/langchain_context.py | 33 ++++++++++++++------------------
 chatsky/llm/llm_api.py           |  4 ++--
 chatsky/responses/llm.py         |  6 +++---
 chatsky/slots/llm.py             |  1 +
 5 files changed, 23 insertions(+), 25 deletions(-)

diff --git a/chatsky/llm/filters.py b/chatsky/llm/filters.py
index 435d42398..d584dd244 100644
--- a/chatsky/llm/filters.py
+++ b/chatsky/llm/filters.py
@@ -29,7 +29,9 @@ class BaseHistoryFilter(BaseModel, abc.ABC):
     """
 
     @abc.abstractmethod
-    def call(self, ctx: Context, request: Optional[Message], response: Optional[Message], llm_model_name: str) -> Union[Return, int]:
+    def call(
+        self, ctx: Context, request: Optional[Message], response: Optional[Message], llm_model_name: str
+    ) -> Union[Return, int]:
         """
         :param ctx: Context object.
         :param request: Request message.
diff --git a/chatsky/llm/langchain_context.py b/chatsky/llm/langchain_context.py
index deda73f43..de458b270 100644
--- a/chatsky/llm/langchain_context.py
+++ b/chatsky/llm/langchain_context.py
@@ -7,10 +7,8 @@
 import re
 import logging
 from typing import Literal, Union
-from pydantic import validate_call
 
-from chatsky.core import AnyResponse, Context, Message
-from chatsky.core.script_function import ConstResponse
+from chatsky.core import Context, Message
 from chatsky.llm._langchain_imports import HumanMessage, SystemMessage, AIMessage, check_langchain_available
 from chatsky.llm.filters import BaseHistoryFilter
 from chatsky.llm.prompt import Prompt, PositionConfig
@@ -18,6 +16,7 @@
 logger = logging.getLogger(__name__)
 logger.debug("Loaded LLM Utils logger.")
 
+
 async def message_to_langchain(
     message: Message, ctx: Context, source: Literal["human", "ai", "system"] = "human", max_size: int = 1000
 ) -> Union[HumanMessage, AIMessage, SystemMessage]:
@@ -64,22 +63,7 @@ async def context_to_history(
     :return: List of Langchain message objects.
     """
     history = []
-
-    # pairs = zip(
-    #     [ctx.requests[x] for x in range(1, len(ctx.requests) + 1)],
-    #     [ctx.responses[x] for x in range(1, len(ctx.responses) + 1)],
-    # )
-    # pairs_list = list(pairs)
-    # filtered_pairs = filter(
-    #     lambda x: filter_func(ctx, x[0], x[1], llm_model_name), pairs_list[-length:] if length != -1 else pairs_list
-    # )
-
-    # for req, resp in filtered_pairs:
-    #     logger.debug(f"This pair is valid: {req, resp}")
-    #     history.append(await message_to_langchain(req, ctx=ctx, max_size=max_size))
-    #     history.append(await message_to_langchain(resp, ctx=ctx, source="ai", max_size=max_size))
-
-    indices = range(1, min(max([*ctx.requests.keys(), 0]), max([*ctx.responses.keys(), 0]))+1)
+    indices = range(1, min(max([*ctx.requests.keys(), 0]), max([*ctx.responses.keys(), 0])) + 1)
 
     if length == 0:
         return []
@@ -112,6 +96,17 @@ async def get_langchain_context(
 ) -> list[HumanMessage | AIMessage | SystemMessage]:
     """
     Get a list of Langchain messages using the context and prompts.
+
+    :param system_prompt: System message to be included in the context.
+    :param ctx: Current dialog context.
+    :param call_prompt: Prompt to be used for the current call.
+    :param prompt_misc_filter: Regex pattern to filter miscellaneous prompts from context.
+        Defaults to r"prompt".
+    :param position_config: Configuration for positioning different parts of the context.
+        Defaults to the default PositionConfig().
+    :param history_args: Additional arguments to be passed to the context_to_history function.
+
+    :return: List of Langchain message objects ordered by their position values.
     """
     logger.debug(f"History args: {history_args}")
 
diff --git a/chatsky/llm/llm_api.py b/chatsky/llm/llm_api.py
index 4261cc4fb..171392ee8 100644
--- a/chatsky/llm/llm_api.py
+++ b/chatsky/llm/llm_api.py
@@ -4,7 +4,7 @@
 Wrapper around langchain.
 """
 
-from typing import Union, Type, Optional
+from typing import Union, Type
 from pydantic import BaseModel, TypeAdapter
 import logging
 from chatsky.core.message import Message
@@ -26,7 +26,7 @@ def __init__(
         self,
         model: BaseChatModel,
         system_prompt: Union[AnyResponse, MessageInitTypes] = "",
-        position_config: PositionConfig = None
+        position_config: PositionConfig = None,
     ) -> None:
         """
         :param model: Model object
diff --git a/chatsky/responses/llm.py b/chatsky/responses/llm.py
index 5bd7357d5..8c5f6da01 100644
--- a/chatsky/responses/llm.py
+++ b/chatsky/responses/llm.py
@@ -11,11 +11,11 @@
 
 from chatsky.core.message import Message
 from chatsky.core.context import Context
-from chatsky.llm.langchain_context import message_to_langchain, context_to_history, get_langchain_context
+from chatsky.llm.langchain_context import get_langchain_context
 from chatsky.llm._langchain_imports import check_langchain_available
 from chatsky.llm.filters import BaseHistoryFilter, DefaultFilter
 from chatsky.llm.prompt import Prompt, PositionConfig
-from chatsky.core.script_function import BaseResponse, AnyResponse
+from chatsky.core.script_function import BaseResponse
 
 
 class LLMResponse(BaseResponse):
@@ -77,7 +77,7 @@ async def call(self, ctx: Context) -> Message:
                 max_size=self.max_size,
             )
         )
-        
+
         logging.debug(f"History: {history_messages}")
         result = await model.respond(history_messages, message_schema=self.message_schema)
 
diff --git a/chatsky/slots/llm.py b/chatsky/slots/llm.py
index 3cda99d47..0c9a1352a 100644
--- a/chatsky/slots/llm.py
+++ b/chatsky/slots/llm.py
@@ -26,6 +26,7 @@ class LLMSlot(ValueSlot, frozen=True):
     LLMSlot is a slot type that extract information described in `caption` parameter using LLM.
     """
 
+    # TODO:
     # add history (and overall update the class)
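
Usage note (not part of the patch): a minimal sketch of how the newly documented get_langchain_context might be called. The keyword arguments length, filter_func, llm_model_name and max_size are the ones context_to_history is shown to accept in this diff and are forwarded through history_args; the Message(text=...) and Prompt(message=...) constructors and the DefaultFilter default are assumptions drawn from the imports touched by this patch, not a verbatim API reference.

# Hedged usage sketch; constructor field names are assumptions, see note above.
from chatsky.core import Context, Message
from chatsky.llm.filters import DefaultFilter
from chatsky.llm.langchain_context import get_langchain_context
from chatsky.llm.prompt import Prompt, PositionConfig


async def build_llm_context(ctx: Context):
    # Returns HumanMessage/AIMessage/SystemMessage objects ordered by their
    # position values, as described in the new docstring.
    return await get_langchain_context(
        system_prompt=Message(text="You are a helpful assistant."),
        ctx=ctx,
        call_prompt=Prompt(message=Message(text="Answer the user's last request.")),
        prompt_misc_filter=r"prompt",      # default regex per the new docstring
        position_config=PositionConfig(),  # default position values
        # history_args, forwarded to context_to_history:
        length=-1,                         # -1 = full history, per context_to_history
        filter_func=DefaultFilter(),
        llm_model_name="",
        max_size=1000,
    )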