
Commit

Added docstring for get_langchain_context, lint
NotBioWaste905 committed Jan 29, 2025
1 parent 3c0fe22 commit 44935ff
Showing 5 changed files with 23 additions and 25 deletions.
chatsky/llm/filters.py (4 changes: 3 additions & 1 deletion)
@@ -29,7 +29,9 @@ class BaseHistoryFilter(BaseModel, abc.ABC):
"""

@abc.abstractmethod
def call(self, ctx: Context, request: Optional[Message], response: Optional[Message], llm_model_name: str) -> Union[Return, int]:
def call(
self, ctx: Context, request: Optional[Message], response: Optional[Message], llm_model_name: str
) -> Union[Return, int]:
"""
:param ctx: Context object.
:param request: Request message.
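The reformatted signature above is the abstract method that every history filter implements. A minimal subclass sketch follows; it is illustrative only and assumes that chatsky.llm.filters also exposes a Return enum with Turn and NoReturn members, which this diff does not show.

# Illustrative sketch, not part of this commit. The Return members used here
# (Turn, NoReturn) are assumptions; check chatsky.llm.filters for the real names.
from typing import Optional, Union

from chatsky.core import Context, Message
from chatsky.llm.filters import BaseHistoryFilter, Return


class ShortTurnsOnly(BaseHistoryFilter):
    """Keep only dialog turns whose request and response are reasonably short."""

    max_chars: int = 500

    def call(
        self, ctx: Context, request: Optional[Message], response: Optional[Message], llm_model_name: str
    ) -> Union[Return, int]:
        request_ok = request is None or len(request.text or "") <= self.max_chars
        response_ok = response is None or len(response.text or "") <= self.max_chars
        # Include the whole turn when both sides are short enough, otherwise drop it.
        return Return.Turn if request_ok and response_ok else Return.NoReturn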
chatsky/llm/langchain_context.py (33 changes: 14 additions & 19 deletions)
@@ -7,17 +7,16 @@
 import re
 import logging
 from typing import Literal, Union
 from pydantic import validate_call
 
-from chatsky.core import AnyResponse, Context, Message
-from chatsky.core.script_function import ConstResponse
+from chatsky.core import Context, Message
 from chatsky.llm._langchain_imports import HumanMessage, SystemMessage, AIMessage, check_langchain_available
 from chatsky.llm.filters import BaseHistoryFilter
 from chatsky.llm.prompt import Prompt, PositionConfig
 
 logger = logging.getLogger(__name__)
 logger.debug("Loaded LLM Utils logger.")
 
 
 async def message_to_langchain(
     message: Message, ctx: Context, source: Literal["human", "ai", "system"] = "human", max_size: int = 1000
 ) -> Union[HumanMessage, AIMessage, SystemMessage]:
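For orientation, the helper whose signature is shown above converts a single chatsky Message into the corresponding Langchain message class. A hypothetical call, using only names visible in this diff:

# Illustrative only; max_size is assumed to cap the message length, as the
# parameter name suggests.
from chatsky.core import Context, Message
from chatsky.llm._langchain_imports import AIMessage
from chatsky.llm.langchain_context import message_to_langchain


async def to_ai_message(ctx: Context) -> AIMessage:
    reply = Message(text="Here is the summary you asked for.")
    return await message_to_langchain(reply, ctx=ctx, source="ai", max_size=1000)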
@@ -64,22 +63,7 @@ async def context_to_history(
     :return: List of Langchain message objects.
     """
     history = []
 
-    # pairs = zip(
-    #     [ctx.requests[x] for x in range(1, len(ctx.requests) + 1)],
-    #     [ctx.responses[x] for x in range(1, len(ctx.responses) + 1)],
-    # )
-    # pairs_list = list(pairs)
-    # filtered_pairs = filter(
-    #     lambda x: filter_func(ctx, x[0], x[1], llm_model_name), pairs_list[-length:] if length != -1 else pairs_list
-    # )
-
-    # for req, resp in filtered_pairs:
-    #     logger.debug(f"This pair is valid: {req, resp}")
-    #     history.append(await message_to_langchain(req, ctx=ctx, max_size=max_size))
-    #     history.append(await message_to_langchain(resp, ctx=ctx, source="ai", max_size=max_size))
-
-    indices = range(1, min(max([*ctx.requests.keys(), 0]), max([*ctx.responses.keys(), 0]))+1)
+    indices = range(1, min(max([*ctx.requests.keys(), 0]), max([*ctx.responses.keys(), 0])) + 1)
 
     if length == 0:
         return []
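The linted indices line above selects the turn numbers for which both a request and a response exist. A standalone illustration of the same expression with plain dicts (not chatsky code):

# Standalone illustration of the index computation above.
requests = {1: "hi", 2: "what is the weather?"}  # turn id -> user request
responses = {1: "hello!"}                        # turn id -> bot response (turn 2 not answered yet)

indices = range(1, min(max([*requests.keys(), 0]), max([*responses.keys(), 0])) + 1)
print(list(indices))  # [1] -- only turns present on both sides are considered for the history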
@@ -112,6 +96,17 @@ async def get_langchain_context(
 ) -> list[HumanMessage | AIMessage | SystemMessage]:
     """
     Get a list of Langchain messages using the context and prompts.
+    :param system_prompt: System message to be included in the context.
+    :param ctx: Current dialog context.
+    :param call_prompt: Prompt to be used for the current call.
+    :param prompt_misc_filter: Regex pattern to filter miscellaneous prompts from context.
+        Defaults to r"prompt".
+    :param position_config: Configuration for positioning different parts of the context.
+        Defaults to default PositionConfig().
+    :param history_args: Additional arguments to be passed to context_to_history function.
+    :return: List of Langchain message objects ordered by their position values.
     """
     logger.debug(f"History args: {history_args}")
 
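To make the new docstring concrete, a hypothetical call to get_langchain_context is sketched below. The keyword names follow the docstring added in this commit; the Message-based system prompt, the Prompt constructor, and the length keyword forwarded through history_args are assumptions, and context_to_history may require further keywords that are not shown here.

# Hypothetical usage sketch; argument types and the keywords forwarded via
# history_args (here: length) are assumptions made for illustration.
from chatsky.core import Context, Message
from chatsky.llm.langchain_context import get_langchain_context
from chatsky.llm.prompt import Prompt, PositionConfig


async def build_context(ctx: Context):
    return await get_langchain_context(
        system_prompt=Message(text="You are a helpful assistant."),
        ctx=ctx,
        call_prompt=Prompt(message=Message(text="Answer the user's last message.")),
        prompt_misc_filter=r"prompt",      # default regex per the docstring
        position_config=PositionConfig(),  # default ordering of context parts
        length=5,                          # forwarded to context_to_history via history_args
    )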
chatsky/llm/llm_api.py (4 changes: 2 additions & 2 deletions)
@@ -4,7 +4,7 @@
 Wrapper around langchain.
 """
 
-from typing import Union, Type, Optional
+from typing import Union, Type
 from pydantic import BaseModel, TypeAdapter
 import logging
 from chatsky.core.message import Message
@@ -26,7 +26,7 @@ def __init__(
         self,
         model: BaseChatModel,
         system_prompt: Union[AnyResponse, MessageInitTypes] = "",
-        position_config: PositionConfig = None
+        position_config: PositionConfig = None,
     ) -> None:
         """
         :param model: Model object
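Only the trailing comma changes in this constructor. For reference, a hypothetical instantiation matching the signature above; the LLM_API class name and the ChatOpenAI backend are assumptions, since neither appears in the visible diff.

# Hypothetical instantiation; the class name LLM_API and the langchain_openai
# backend are assumptions, the keyword names come from the __init__ signature above.
from langchain_openai import ChatOpenAI

from chatsky.llm import LLM_API
from chatsky.llm.prompt import PositionConfig

assistant_model = LLM_API(
    model=ChatOpenAI(model="gpt-4o-mini"),
    system_prompt="You are a polite support assistant.",  # str is accepted via MessageInitTypes
    position_config=PositionConfig(),                     # defaults to None in the signature above
)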
chatsky/responses/llm.py (6 changes: 3 additions & 3 deletions)
@@ -11,11 +11,11 @@

 from chatsky.core.message import Message
 from chatsky.core.context import Context
-from chatsky.llm.langchain_context import message_to_langchain, context_to_history, get_langchain_context
+from chatsky.llm.langchain_context import get_langchain_context
 from chatsky.llm._langchain_imports import check_langchain_available
 from chatsky.llm.filters import BaseHistoryFilter, DefaultFilter
 from chatsky.llm.prompt import Prompt, PositionConfig
-from chatsky.core.script_function import BaseResponse, AnyResponse
+from chatsky.core.script_function import BaseResponse
 
 
 class LLMResponse(BaseResponse):
@@ -77,7 +77,7 @@ async def call(self, ctx: Context) -> Message:
                 max_size=self.max_size,
             )
         )
-
+        logging.debug(f"History: {history_messages}")
         result = await model.respond(history_messages, message_schema=self.message_schema)
 
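Since this file only loses two unused imports and swaps a blank line for a debug log, a short sketch of how LLMResponse might be configured follows; every field name except the classes visible in the imports above is an assumption.

# Hypothetical configuration; the field names (llm_model_name, history, filter_func)
# are assumptions about LLMResponse and are not confirmed by this diff.
from chatsky.responses.llm import LLMResponse
from chatsky.llm.filters import DefaultFilter

chat_response = LLMResponse(
    llm_model_name="assistant_model",  # key of a model registered on the Pipeline (assumed)
    history=5,                         # number of past turns to include (assumed)
    filter_func=DefaultFilter(),       # history filter; DefaultFilter comes from the imports above
)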
chatsky/slots/llm.py (1 change: 1 addition & 0 deletions)
@@ -26,6 +26,7 @@ class LLMSlot(ValueSlot, frozen=True):
     LLMSlot is a slot type that extract information described in
     `caption` parameter using LLM.
     """
+
     # TODO:
     # add history (and overall update the class)
 
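The docstring above describes extraction driven by a caption parameter. A hypothetical slot definition using it follows; the GroupSlot wrapper is an assumption, and wiring the slot to an actual LLM model is not shown in this hunk.

# Hypothetical slot definition; only the `caption` parameter is suggested by the
# docstring above. A model reference is likely also required but is not shown here.
from chatsky.slots import GroupSlot
from chatsky.slots.llm import LLMSlot

person = GroupSlot(
    name=LLMSlot(caption="user's full name"),
    email=LLMSlot(caption="user's email address"),
)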
