Fixed appending empty strings + wrong prompt positions in tests
NotBioWaste905 committed Jan 31, 2025
1 parent 32eae7d commit 6440eaf
Showing 4 changed files with 14 additions and 7 deletions.
4 changes: 4 additions & 0 deletions chatsky/conditions/llm.py
@@ -36,6 +36,10 @@ class LLMCondition(BaseCondition):
"""
Filter function used to select the messages that will go into the model's context.
"""
+prompt_misc_filter: str = Field(default=r"prompt")
+"""
+Regex pattern used to find prompts by key names in the MISC dictionary.
+"""
max_size: int = 1000
"""
Maximum size of any message in the chat, in symbols. If a message exceeds the limit, a ValueError is raised.
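For context, the new `prompt_misc_filter` field appears to be a regex applied to the keys of a node's MISC dictionary to pick out prompt entries. A minimal sketch of that filtering, assuming a plain dict of MISC values (the `collect_prompts` helper and `node_misc` variable are hypothetical, not part of chatsky):

```python
import re
from typing import Dict

def collect_prompts(node_misc: Dict[str, str], prompt_misc_filter: str = r"prompt") -> Dict[str, str]:
    """Return only the MISC entries whose keys match the filter regex."""
    pattern = re.compile(prompt_misc_filter)
    return {key: value for key, value in node_misc.items() if pattern.search(key)}

# Usage: only keys matching "prompt" are kept.
node_misc = {"prompt": "You are a helpful assistant.", "extra_prompt": "Be brief.", "theme": "dark"}
print(collect_prompts(node_misc))  # {'prompt': 'You are a helpful assistant.', 'extra_prompt': 'Be brief.'}
```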
8 changes: 5 additions & 3 deletions chatsky/llm/langchain_context.py
@@ -113,9 +113,10 @@ async def get_langchain_context(
history = await context_to_history(ctx, **history_args)
logger.debug(f"Position config: {position_config}")
prompts: list[tuple[list[Union[HumanMessage, AIMessage, SystemMessage]], float]] = [
-([await message_to_langchain(system_prompt, ctx, source="system")], position_config.system_prompt),
(history, position_config.history),
]
+if system_prompt.text != '':
+    prompts.append(([await message_to_langchain(system_prompt, ctx, source="system")], position_config.system_prompt))

logger.debug(f"System prompt: {prompts[0]}")

@@ -130,8 +131,9 @@
prompts.append(([prompt_langchain_message], prompt.position))

call_prompt_text = await call_prompt.message(ctx)
-call_prompt_message = await message_to_langchain(call_prompt_text, ctx, source="human")
-prompts.append(([call_prompt_message], call_prompt.position or position_config.call_prompt))
+if call_prompt_text.text != '':
+    call_prompt_message = await message_to_langchain(call_prompt_text, ctx, source="human")
+    prompts.append(([call_prompt_message], call_prompt.position or position_config.call_prompt))

prompts.append(([await message_to_langchain(ctx.last_request, ctx, source="human")], position_config.last_request))

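Both hunks apply the same guard: a prompt group is only appended when its text is non-empty, so blank system or call prompts never reach the model. A minimal sketch of the pattern with plain strings standing in for langchain messages (the `append_if_nonempty` helper is illustrative, not chatsky API):

```python
from typing import List, Tuple

def append_if_nonempty(prompts: List[Tuple[List[str], float]], text: str, position: float) -> None:
    """Append a (messages, position) group only when the prompt text is non-empty."""
    if text != '':
        prompts.append(([text], position))

# Usage: the empty system prompt is skipped, so no blank message ends up in the context.
prompts: List[Tuple[List[str], float]] = [(["Request 1", "Response 1"], 1.0)]
append_if_nonempty(prompts, "", 0.0)              # skipped
append_if_nonempty(prompts, "Last request", 3.0)  # appended
print(prompts)  # [(['Request 1', 'Response 1'], 1.0), (['Last request'], 3.0)]
```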
1 change: 1 addition & 0 deletions chatsky/slots/llm.py
@@ -79,6 +79,7 @@ async def get_value(self, ctx: Context) -> ExtractedGroupSlot:
result_json = result.model_dump()
logger.debug(f"Result JSON: {result_json}")

+# TODO: un-flatten the dict with child.names.like.this
res = {
name: ExtractedValueSlot.model_construct(is_slot_extracted=True, extracted_value=result_json[name])
for name in result_json
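The TODO added above refers to restoring nested structure from flat dotted keys in the model's output. A minimal sketch of such un-flattening, assuming a plain dict result (the `unflatten` helper is hypothetical, not part of chatsky):

```python
from typing import Any, Dict

def unflatten(flat: Dict[str, Any]) -> Dict[str, Any]:
    """Turn {'person.name': 'Bob', 'person.age': 42} into {'person': {'name': 'Bob', 'age': 42}}."""
    nested: Dict[str, Any] = {}
    for dotted_key, value in flat.items():
        parts = dotted_key.split(".")
        node = nested
        for part in parts[:-1]:
            node = node.setdefault(part, {})
        node[parts[-1]] = value
    return nested

print(unflatten({"person.name": "Bob", "person.age": 42, "city": "Oslo"}))
# {'person': {'name': 'Bob', 'age': 42}, 'city': 'Oslo'}
```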
8 changes: 4 additions & 4 deletions tests/llm/test_llm.py
@@ -207,17 +207,17 @@ async def test_message_to_langchain(context):
[
(
2,
"Mock response with history: ['prompt', 'Request 1', 'Response 1', "
"'Request 2', 'Response 2', 'Last request']",
"Mock response with history: ['Request 1', 'Response 1', "
"'Request 2', 'Response 2', 'prompt', 'Last request']",
),
(
0,
"Mock response with history: ['prompt', 'Last request']",
),
(
4,
"Mock response with history: ['prompt', 'Request 0', 'Response 0', "
"'Request 1', 'Response 1', 'Request 2', 'Response 2', 'Last request']",
"Mock response with history: ['Request 0', 'Response 0', "
"'Request 1', 'Response 1', 'Request 2', 'Response 2', 'prompt', 'Last request']",
),
],
)
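The updated expectations reflect that prompt groups are assembled as (messages, position) pairs and ordered by position before being flattened, so the node prompt now lands after the history and before the last request instead of at the front. A minimal sketch of that ordering, with plain strings in place of langchain messages and illustrative position values (the `order_prompts` helper and the concrete numbers are assumptions, not chatsky's actual defaults):

```python
from typing import List, Tuple

def order_prompts(prompt_groups: List[Tuple[List[str], float]]) -> List[str]:
    """Sort prompt groups by their position value and flatten them into one message list."""
    ordered_groups = sorted(prompt_groups, key=lambda group: group[1])
    return [message for messages, _position in ordered_groups for message in messages]

# With history at position 1, the node prompt at position 2, and the last request at position 3,
# the prompt ends up between the history and the last request, as the updated tests expect.
groups = [
    (["Request 1", "Response 1", "Request 2", "Response 2"], 1.0),
    (["prompt"], 2.0),
    (["Last request"], 3.0),
]
print(order_prompts(groups))
# ['Request 1', 'Response 1', 'Request 2', 'Response 2', 'prompt', 'Last request']
```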
