diff --git a/libs/core/langchain_core/prompts/few_shot.py b/libs/core/langchain_core/prompts/few_shot.py index efd3fd1417f82..d750851f468f4 100644 --- a/libs/core/langchain_core/prompts/few_shot.py +++ b/libs/core/langchain_core/prompts/few_shot.py @@ -18,7 +18,7 @@ check_valid_template, get_template_variables, ) -from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator +from langchain_core.pydantic_v1 import BaseModel, Extra, Field, root_validator class _FewShotPromptTemplateMixin(BaseModel): @@ -135,6 +135,12 @@ def is_lc_serializable(cls) -> bool: template_format: Literal["f-string", "jinja2"] = "f-string" """The format of the prompt template. Options are: 'f-string', 'jinja2'.""" + def __init__(self, **kwargs: Any) -> None: + """Initialize the few shot prompt template.""" + if "input_variables" not in kwargs and "example_prompt" in kwargs: + kwargs["input_variables"] = kwargs["example_prompt"].input_variables + super().__init__(**kwargs) + @root_validator(pre=False, skip_on_failure=True) def template_is_valid(cls, values: Dict) -> Dict: """Check that prefix, suffix, and input variables are consistent.""" @@ -351,14 +357,18 @@ class FewShotChatMessagePromptTemplate( chain.invoke({"input": "What's 3+3?"}) """ + input_variables: List[str] = Field(default_factory=list) + """A list of the names of the variables the prompt template will use + to pass to the example_selector, if provided.""" + + example_prompt: Union[BaseMessagePromptTemplate, BaseChatPromptTemplate] + """The class to format each example.""" + @classmethod def is_lc_serializable(cls) -> bool: """Return whether or not the class is serializable.""" return False - example_prompt: Union[BaseMessagePromptTemplate, BaseChatPromptTemplate] - """The class to format each example.""" - class Config: """Configuration for this pydantic object.""" diff --git a/libs/core/tests/unit_tests/prompts/test_few_shot.py b/libs/core/tests/unit_tests/prompts/test_few_shot.py index 722a5ee5be34f..8c3cc523c23cf 100644 --- a/libs/core/tests/unit_tests/prompts/test_few_shot.py +++ b/libs/core/tests/unit_tests/prompts/test_few_shot.py @@ -58,6 +58,17 @@ def test_suffix_only() -> None: assert output == expected_output +def test_auto_infer_input_variables() -> None: + """Test prompt works with just a suffix.""" + suffix = "This is a {foo} test." + prompt = FewShotPromptTemplate( + suffix=suffix, + examples=[], + example_prompt=EXAMPLE_PROMPT, + ) + assert prompt.input_variables == ["foo"] + + def test_prompt_missing_input_variables() -> None: """Test error is raised when input variables are not provided.""" # Test when missing in suffix @@ -422,6 +433,30 @@ def test_few_shot_chat_message_prompt_template_with_selector() -> None: assert messages == expected +def test_few_shot_chat_message_prompt_template_infer_input_variables() -> None: + """Check that it can infer input variables if not provided.""" + examples = [ + {"input": "2+2", "output": "4"}, + {"input": "2+3", "output": "5"}, + ] + example_selector = AsIsSelector(examples) + example_prompt = ChatPromptTemplate.from_messages( + [ + HumanMessagePromptTemplate.from_template("{input}"), + AIMessagePromptTemplate.from_template("{output}"), + ] + ) + + few_shot_prompt = FewShotChatMessagePromptTemplate( + example_prompt=example_prompt, + example_selector=example_selector, + ) + + # The prompt template does not have any inputs! They + # have already been filled in. + assert few_shot_prompt.input_variables == [] + + class AsyncAsIsSelector(BaseExampleSelector): """An example selector for testing purposes.