-
Notifications
You must be signed in to change notification settings - Fork 1.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Adapters: Support JSON serialization of all pydantic types (e.g. datetimes, enums, etc.) #1853
base: main
Are you sure you want to change the base?
Conversation
Signed-off-by: dbczumar <[email protected]>
Signed-off-by: dbczumar <[email protected]>
Signed-off-by: dbczumar <[email protected]>
@@ -115,86 +115,6 @@ def format_fields(self, signature, values, role): | |||
return format_fields(fields_with_values) | |||
|
|||
|
|||
def format_blob(blob): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Moved these utilities into a shared adapters/utils.py
module, since they're identical but repeated across JSONAdapter and ChatAdapter.
(Except that JSONAdapter doesn't supported images yet and throws an "unsupported" exception. This behavior is preserved by wrapping utils.format_field_value()
with aif field_info.annotation is Image: throw
block)
@@ -231,7 +151,7 @@ def parse_value(value, annotation): | |||
parsed_value = value | |||
|
|||
if isinstance(annotation, enum.EnumMeta): | |||
parsed_value = annotation[value] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This enum handling logic was incorrect.
Given an enum like
from enum import Enum
class Status(Enum):
PENDING = "pending"
IN_PROGRESS = "in_progress"
COMPLETED = "completed"
Serializing this enum to JSON with Pydantic produces in_progress
(an enum field value), not IN_PROGRESS
(an enum field *name). Status('in_progress') is the correct way to restore the enum field from its value, while Status['in_progress'] throws:
python3.10/enum.py:440, in EnumMeta.__getitem__(cls, name)
439 def __getitem__(cls, name):
--> 440 return cls._member_map_[name]
KeyError: 'in_progress'
I've added test coverage to confirm that the new behavior is correct (see test_predict
and reliability/complex_types/generated/test_many_types_1
)
import json | ||
import textwrap | ||
import json_repair | ||
|
||
from typing import Any, Dict, KeysView, Literal, NamedTuple, get_args, get_origin | ||
|
||
import json_repair | ||
import litellm | ||
import pydantic | ||
from pydantic import TypeAdapter | ||
from pydantic.fields import FieldInfo | ||
from typing import Any, Dict, KeysView, List, Literal, NamedTuple, get_args, get_origin | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is all just the linter reordering imports
|
||
|
||
|
||
try: | ||
provider = lm.model.split('/', 1)[0] or "openai" | ||
if 'response_format' in litellm.get_supported_openai_params(model=lm.model, custom_llm_provider=provider): | ||
outputs = lm(**inputs, **lm_kwargs, response_format={ "type": "json_object" }) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just the linter - no material change here
assert set(value.keys()) == set( | ||
signature.output_fields.keys() | ||
), f"Expected {signature.output_fields.keys()} but got {value.keys()}" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just the linter - no material change here
parsed_value = value | ||
|
||
if isinstance(annotation, enum.EnumMeta): | ||
parsed_value = annotation[value] | ||
parsed_value = annotation(value) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
def format_blob(blob): | ||
if "\n" not in blob and "«" not in blob and "»" not in blob: | ||
return f"«{blob}»" | ||
|
||
modified_blob = blob.replace("\n", "\n ") | ||
return f"«««\n {modified_blob}\n»»»" | ||
|
||
|
||
def format_input_list_field_value(value: List[Any]) -> str: | ||
""" | ||
Formats the value of an input field of type List[Any]. | ||
|
||
Args: | ||
value: The value of the list-type input field. | ||
Returns: | ||
A string representation of the input field's list value. | ||
""" | ||
if len(value) == 0: | ||
return "N/A" | ||
if len(value) == 1: | ||
return format_blob(value[0]) | ||
|
||
return "\n".join([f"[{idx+1}] {format_blob(txt)}" for idx, txt in enumerate(value)]) | ||
return TypeAdapter(annotation).validate_python(parsed_value) | ||
|
||
|
||
def _serialize_for_json(value): | ||
if isinstance(value, pydantic.BaseModel): | ||
return value.model_dump() | ||
elif isinstance(value, list): | ||
return [_serialize_for_json(item) for item in value] | ||
elif isinstance(value, dict): | ||
return {key: _serialize_for_json(val) for key, val in value.items()} | ||
else: | ||
return value | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Deduplicate code with chat_adapter by moving these into a utils
file - see https://github.com/stanfordnlp/dspy/pull/1853/files#r1856167223
return ( | ||
f" (must be formatted as a valid Python {get_annotation_name(v.annotation)})" | ||
if v.annotation is not str | ||
else "" | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This (and everything else below it in this file) is just a linter change
"The 'processedTupleField' should be a tuple containing a string and a number.", | ||
"The 'processedEnumField' should be one of the allowed enum values: 'option1', 'option2', or 'option3'.", | ||
"The 'processedDatetimeField' should be a date-time", | ||
"The 'processedLiteralField' should be exactly 'literalValue'.", | ||
"The 'processedObjectField' should contain 'subField1' (string), 'subField2' (number), and an additional boolean field 'additionalField'.", | ||
"The 'processedNestedObjectField' should contain 'tupleField' (which is actually a list with a string and a number - the name is misleading), 'enumField' (one of the allowed enum values), 'datetimeField' (string formatted as date-time), 'literalField' (exactly 'literalValue'), and an additional boolean field 'additionalField'." | ||
"The 'processedNestedObjectField' should contain 'tupleField' as a tuple with a string and float, 'enumField' (one of the allowed enum values), 'datetimeField' (string formatted as date-time), 'literalField' (exactly 'literalValue'), and an additional boolean field 'additionalField'." |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I confirmed that these generated reliability tests for enum inputs / outputs pass with ChatAdapter / JSONAdapter (see PR description). I did notice some flakiness in the LLM judge with the version of the generated program on main
: the program contains a field whose name contains tuple
but whose type is actually list
. The judge gets confused and thinks there's a problem. This PR fixes that by changing the type.
Generating test cases with field names that don't match their types is probably a good idea, but that should be added separately (it wasn't the goal of this case)
def _serialize_for_json(value): | ||
if isinstance(value, pydantic.BaseModel): | ||
return value.model_dump() | ||
elif isinstance(value, list): | ||
return [_serialize_for_json(item) for item in value] | ||
elif isinstance(value, dict): | ||
return {key: _serialize_for_json(val) for key, val in value.items()} | ||
else: | ||
return value |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The core change in this PR is the replacement of this method with the new version in utils.py
:
def serialize_for_json(value: Any) -> Any:
"""
Formats the specified value so that it can be serialized as a JSON string.
Args:
value: The value to format as a JSON string.
Returns:
The formatted value, which is serializable as a JSON string.
"""
# Attempt to format the value as a JSON-compatible object using pydantic, falling back to
# a string representation of the value if that fails (e.g. if the value contains an object
# that pydantic doesn't recognize or can't serialize)
try:
return TypeAdapter(type(value)).dump_python(value, mode="json")
except Exception:
return str(value)
This reuses pydantic's JSON serialization, ensuring we don't miss certain types.
Signed-off-by: dbczumar <[email protected]>
Looks great, thank you @dbczumar ! I will do some testing then merge |
Seems OK in my initial testing except for: from enum import Enum
Color = Enum('Color', ['RED', 'GREEN', 'BLUE'])
class Colorful(dspy.Signature):
text: str = dspy.InputField()
color: Color = dspy.OutputField()
dspy.ChainOfThought(Colorful)(text="The sky is blue.") That throws; it's unable to parse the "BLUE" that the model generates, with 4o-mini. though a reminder to self that I haven't ran end-to-end notebooks yet. |
Fixes #1826
In addition to the unit tests introduced in this PR, I verified that generated integration tests for a program with enum inputs and enum outputs pass as well with
ChatAdapter
andJSONAdapter