Created a TypedPredictorSignature class that builds a signature from Pydantic models - this signature is optimized for use with TypedPredictor and TypedChainOfThought #1655

Open
wants to merge 4 commits into main

Conversation

@drawal1 (Contributor) commented Oct 20, 2024

📝 Changes Description

This MR/PR contains the following changes:

Created a TypedPredictorSignature class with a single function, create, that takes the Pydantic classes defining the input and output fields and builds a signature optimized for use with dspy.TypedPredictor to extract structured information from the input.

The advantages of this implementation are:

  1. It significantly reduces the chances of the dreaded "retries..." exception
  2. If a default value is specified, the predictor will now return the default value when the information cannot be extracted
  3. For fields with constraints, an invalid value can be specified along with a validator function (mode="wrap"), and the predictor will return that invalid value if the extracted information does not satisfy the constraints
  4. For fields defined as Optional, the predictor will return null if the information cannot be extracted
  5. If field descriptions and example values are specified in the Pydantic model, they will be used to construct the signature for better prediction
  6. Prefix text can optionally be specified, and it will be used to construct the prompt
  7. Using enum fields in Pydantic models will no longer generate parse errors

Here is example usage:

    from typing import Annotated, Optional

    import dspy
    from pydantic import BaseModel, Field, ValidationError, field_validator
    # TypedPredictorSignature is the new class added in this PR

    class CommandExtractionInput(BaseModel):
        command: str

    class OutputParamsSchema(BaseModel):
        @field_validator('name', mode='wrap')
        @staticmethod
        def validate_name(name, handler):
            try:
                return handler(name)
            except ValidationError:
                return 'INVALID_NAME'

        @field_validator('age1', 'age2', 'age3', 'age4', 'age5', 'age6', mode='wrap')
        @staticmethod
        def validate_age(age, handler):
            try:
                return handler(age)
            except ValidationError:
                return -8888

        @field_validator('email1', 'email2', mode='wrap')
        @staticmethod
        def validate_email(email, handler):
            try:
                return handler(email)
            except ValidationError:
                return 'INVALID_EMAIL'

        name: Annotated[str,
                        Field(default='NOT_FOUND', max_length=15,
                            title='Name', description='The name of the person',
                            examples=['John Doe', 'Jane Doe'], json_schema_extra={'invalid_value': 'INVALID_NAME'})
                    ]
        age1: Annotated[int, 
                       Field(gt=0, lt=150, default=-999, json_schema_extra={'invalid_value': '-8888'})]
        age2: Annotated[int, 
                       Field(gt=0, lt=150, json_schema_extra={'invalid_value': '-8888'})] = -999
        age3: Optional[Annotated[int, 
                       Field(gt=0, lt=150, json_schema_extra={'invalid_value': '-8888'})]]

        age4: Annotated[int, 
                       Field(gt=0, lt=150, default=-999, json_schema_extra={'invalid_value': '-8888'})]
        age5: Annotated[int, 
                       Field(gt=0, lt=150, json_schema_extra={'invalid_value': '-8888'})] = -999
        age6: Optional[Annotated[int, 
                       Field(gt=0, lt=150, json_schema_extra={'invalid_value': '-8888'})]]

        email1: Annotated[str, 
                         Field(default='NOT_FOUND', 
                            pattern=r'^[a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+$',
                            json_schema_extra={'invalid_value': 'INVALID_EMAIL'})
                    ]
        email2: Annotated[str, 
                         Field(default='NOT_FOUND', 
                            pattern=r'^[a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+$',
                            json_schema_extra={'invalid_value': 'INVALID_EMAIL'})
                    ]

    dspy_signature_class = TypedPredictorSignature.create(
        CommandExtractionInput, OutputParamsSchema)

    lm = dspy.LM('openai/gpt-3.5-turbo')
    with dspy.context(lm=lm):
        extract_cmd_params = dspy.TypedChainOfThought(
            dspy_signature_class)

        input_for_parameter_extraction = CommandExtractionInput(
            # command = "A random command."
            # command = "My name is kjhd and I am 200 years old. My email is 9236"
            command = "Hello, my name is John Doe and I am 25 years old. My email is [email protected]."
        )
        prediction = extract_cmd_params(**input_for_parameter_extraction.model_dump())

        dspy.inspect_history(n=1)
        print(prediction)

...

✅ Contributor Checklist

  • Pre-Commit checks are passing (locally and remotely)
  • Title of your PR / MR corresponds to the required format
  • [ ] Commit message follows required format {label}(dspy): {message}

⚠️ Warnings

Anything we should be aware of?

…utput classes. It takes examples, constraints, defaults and invalid value specifications into account when constructing the signature
…g invalid value. Specifying a default of null for optional age field
…ing field.default in such cases if its not specified
@drawal1 drawal1 changed the title Merge branch 'main' of https://github.com/stanfordnlp/dspy Created a TypedPredictorSignature class that builds a signature from Pydantic models - this signature is optimized for use with TypedPredictor and TypedChainOfThought Oct 20, 2024
@okhat (Collaborator) commented Oct 20, 2024

Amazing. Having some discussions on Discord at https://discord.com/channels/1161519468141355160/1294140517764042794/1297707179054207028

@okhat okhat self-requested a review October 21, 2024 17:28
    if field.default and 'typing.Annotated' in str(field.default):
        raise ValueError(f"Field '{field_name}' is annotated incorrectly. See 'Constraints on compound types' in https://docs.pydantic.dev/latest/concepts/fields/")

    is_default_value_specified, is_marked_as_optional, inner_field = cls._process_field(field)
@okhat (Collaborator) commented Oct 31, 2024

Interesting to see special support for optional fields. Hmm, do we need special handling for optional fields, or are Optional[.] types enough? I see that you have this for input fields here and for output fields below.

@drawal1 (Contributor, Author)

Pydantic allows multiple ways to specify a field schema, including the Optional property. It can appear inside the annotation on the field spec or outside it. IMO, we should not make users guess which Pydantic field metadata we do and don't support. If we can support it, we should.
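
For reference, a minimal sketch of the two spellings being discussed (the model and field names here are illustrative, not from the PR):

    from typing import Annotated, Optional

    from pydantic import BaseModel, Field

    class AgeExample(BaseModel):
        # Optional wrapping the constrained, annotated type (as in the PR's example)
        age_outer: Optional[Annotated[int, Field(gt=0, lt=150)]] = None
        # Optional declared inside the annotation instead
        age_inner: Annotated[Optional[int], Field(description='Age in years')] = None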

    if field.default is None or field.default is PydanticUndefined:
        field.default = 'null'
    field.description = inner_field.description
    field.examples = inner_field.examples
@okhat (Collaborator)

Interesting to see examples and metadata per field.

@drawal1 (Contributor, Author)

These are definitely valuable signals. We should leverage Pydantic field metadata instead of duplicating it in DSPy signature fields.

@drawal1 (Contributor, Author)

Update: I am no longer using examples. It turns out LLMs will incorrectly return an example as the parameter value instead of "NOT_FOUND".


    if field.default is PydanticUndefined:
        raise ValueError(
            f"Field '{field_name}' has no default value. Required fields must have a default value. "
@okhat (Collaborator)

Ah, interesting concept here: requiring a default value for every required output field. Why?

@drawal1 (Contributor, Author)

Users may provide input without supplying the value of a required parameter. So how do we indicate that a required value was missing from the input, without throwing an exception and without forcing the LM to hallucinate the missing value?

Luckily, you can specify that a Pydantic default value should not be subject to validation. This means that regardless of the field type (str, int, float, ...), you can specify a default value of "NOT_FOUND" and this signature will correctly detect and return it without hallucinating.

No wrap validators are necessary, which in any case don't distinguish between invalid and missing values.
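
As a minimal sketch of that behaviour (the model name below is illustrative): Pydantic v2 does not validate defaults unless validate_default=True, so a sentinel of a different type passes through untouched.

    from pydantic import BaseModel, Field

    class AgeOutput(BaseModel):
        # Defaults are not validated in Pydantic v2 (validate_default is False by default),
        # so a string sentinel can stand in for a missing required int.
        age: int = Field(default='NOT_FOUND', gt=0, lt=150)

    print(AgeOutput())        # age='NOT_FOUND' -- the sentinel is returned as-is
    print(AgeOutput(age=25))  # age=25 -- supplied values are still validated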

@drawal1 (Contributor, Author)

The unique nature of agentic applications is that we cannot assume user input has been validated in the UI for required fields. Combine this with the LM's tendency to hallucinate when forced to produce a value...

    dspy_fields[field_name] = (field.annotation, output_field)

    instructions += f"When extracting '{field_name}':\n"
    instructions += f"If it is not mentioned in the input fields, return: '{field.default}'. "
@okhat (Collaborator)

Interesting choice of words. "When extracting" and "return". Something to keep in mind.

@drawal1 (Contributor, Author)

Seems to work. Any pitfalls here? Are you thinking of cases where the field name is "extracted_name" or field.default is "return"? Not quite sure how to get around that.


    examples = field.examples
    if examples:
        quoted_examples = [f"'{example}'" for example in examples]
@okhat (Collaborator) commented Oct 31, 2024

Naive quotes will fail on complex values? Same with the naive join below.

@drawal1 (Contributor, Author)

The only issue I have seen is the LM sometimes returning string values enclosed in quotes, for example returning "'John Doe'" instead of "John Doe". But this is easy to strip. It correctly extracts numbers without putting them in quotes, but I guess the prompt could be enhanced to detect whether the example values should be quoted. Any other suggestions on how this could be improved?

Also, re: complex types - isn't that what I am testing with the complex Input and Output BaseModels? So maybe the code for pulling apart complex types and building a custom signature is unavoidable. Thoughts?
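
For illustration, stripping that extra layer of quotes could look something like this (a hypothetical helper, not part of the PR):

    def strip_wrapping_quotes(value: str) -> str:
        # Remove one layer of surrounding quotes, e.g. "'John Doe'" -> "John Doe".
        value = value.strip()
        if len(value) >= 2 and value[0] == value[-1] and value[0] in ("'", '"'):
            return value[1:-1]
        return value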

    if field.metadata:
        constraints = [meta for meta in field.metadata if 'Validator' not in str(meta)]
        if field.json_schema_extra and 'invalid_value' in field.json_schema_extra:
            instructions += f"If the extracted value does not conform to: {constraints}, return: '{field.json_schema_extra['invalid_value']}'."
@okhat (Collaborator)

Very interesting, asking the model to signal bad/hard fields

@drawal1 (Contributor, Author)

Detecting missing and invalid values is a "must-have" requirement for production-quality apps. Here, I am basically trying to avoid the multiple retries that the framework makes when the extracted value does not match the constraints specified in the Pydantic BaseModel.

The LM may be incorrectly extracting the parameter, or the user may have specified an invalid value. In both cases, the system should not hallucinate the closest valid value to the actual input. Rather, it should flag the invalid field and let the application handle the error with a proper error-correction workflow ("The provided value was invalid. It must be... Here are some examples... Did you mean...?")
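
As a rough, hypothetical sketch of that application-side handling (assuming the output fields are exposed on the prediction by name, as in the usage example above):

    MISSING, INVALID = 'NOT_FOUND', 'INVALID_EMAIL'

    # prediction comes from the extract_cmd_params call in the usage example
    if prediction.email1 == MISSING:
        print("No email address was found in the command. Please provide one.")
    elif prediction.email1 == INVALID:
        print("The email address provided looks invalid. Did you mean to type it differently?")
    else:
        print(f"Email: {prediction.email1}")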

                return 'INVALID_NAME'

        name: Annotated[str,
                        Field(default='NOT_FOUND', max_length=15,
@okhat (Collaborator)

Why use Annotated instead of assignment, i.e. = Field(...)?

@drawal1 (Contributor, Author)

From the Pydantic docs: "In case you use field constraints with compound types, an error can happen in some cases. To avoid potential issues, you can use Annotated."

See https://docs.pydantic.dev/latest/concepts/fields/#validation-alias
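
For context, roughly the pattern that section of the docs recommends (field names here are illustrative):

    from typing import Annotated, Optional

    from pydantic import BaseModel, Field

    class Foo(BaseModel):
        # Recommended: attach the constraint to the inner type via Annotated
        positive: Optional[Annotated[int, Field(gt=0)]] = None
        # Can error in some cases, because the constraint targets the compound/Optional type:
        # non_negative: Optional[int] = Field(ge=0)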
