Skip to content

Commit

Permalink
core[patch]: Fix regression requiring input_variables in few chat pro…
Browse files Browse the repository at this point in the history
…mpt templates (langchain-ai#24360)

* Fix regression that requires users passing input_variables=[].

* Regression introduced by my own changes to this PR:
langchain-ai#22851
  • Loading branch information
eyurtsev authored Jul 17, 2024
1 parent 034a8c7 commit 96bac8e
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 4 deletions.
18 changes: 14 additions & 4 deletions libs/core/langchain_core/prompts/few_shot.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
check_valid_template,
get_template_variables,
)
from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator
from langchain_core.pydantic_v1 import BaseModel, Extra, Field, root_validator


class _FewShotPromptTemplateMixin(BaseModel):
Expand Down Expand Up @@ -135,6 +135,12 @@ def is_lc_serializable(cls) -> bool:
template_format: Literal["f-string", "jinja2"] = "f-string"
"""The format of the prompt template. Options are: 'f-string', 'jinja2'."""

def __init__(self, **kwargs: Any) -> None:
"""Initialize the few shot prompt template."""
if "input_variables" not in kwargs and "example_prompt" in kwargs:
kwargs["input_variables"] = kwargs["example_prompt"].input_variables
super().__init__(**kwargs)

@root_validator(pre=False, skip_on_failure=True)
def template_is_valid(cls, values: Dict) -> Dict:
"""Check that prefix, suffix, and input variables are consistent."""
Expand Down Expand Up @@ -351,14 +357,18 @@ class FewShotChatMessagePromptTemplate(
chain.invoke({"input": "What's 3+3?"})
"""

input_variables: List[str] = Field(default_factory=list)
"""A list of the names of the variables the prompt template will use
to pass to the example_selector, if provided."""

example_prompt: Union[BaseMessagePromptTemplate, BaseChatPromptTemplate]
"""The class to format each example."""

@classmethod
def is_lc_serializable(cls) -> bool:
"""Return whether or not the class is serializable."""
return False

example_prompt: Union[BaseMessagePromptTemplate, BaseChatPromptTemplate]
"""The class to format each example."""

class Config:
"""Configuration for this pydantic object."""

Expand Down
35 changes: 35 additions & 0 deletions libs/core/tests/unit_tests/prompts/test_few_shot.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,17 @@ def test_suffix_only() -> None:
assert output == expected_output


def test_auto_infer_input_variables() -> None:
"""Test prompt works with just a suffix."""
suffix = "This is a {foo} test."
prompt = FewShotPromptTemplate(
suffix=suffix,
examples=[],
example_prompt=EXAMPLE_PROMPT,
)
assert prompt.input_variables == ["foo"]


def test_prompt_missing_input_variables() -> None:
"""Test error is raised when input variables are not provided."""
# Test when missing in suffix
Expand Down Expand Up @@ -422,6 +433,30 @@ def test_few_shot_chat_message_prompt_template_with_selector() -> None:
assert messages == expected


def test_few_shot_chat_message_prompt_template_infer_input_variables() -> None:
"""Check that it can infer input variables if not provided."""
examples = [
{"input": "2+2", "output": "4"},
{"input": "2+3", "output": "5"},
]
example_selector = AsIsSelector(examples)
example_prompt = ChatPromptTemplate.from_messages(
[
HumanMessagePromptTemplate.from_template("{input}"),
AIMessagePromptTemplate.from_template("{output}"),
]
)

few_shot_prompt = FewShotChatMessagePromptTemplate(
example_prompt=example_prompt,
example_selector=example_selector,
)

# The prompt template does not have any inputs! They
# have already been filled in.
assert few_shot_prompt.input_variables == []


class AsyncAsIsSelector(BaseExampleSelector):
"""An example selector for testing purposes.
Expand Down

0 comments on commit 96bac8e

Please sign in to comment.