Skip to content

Commit

Permalink
Add Pydantic Support (#35)
Browse files Browse the repository at this point in the history
* Add pydantic support [wip]

* Temporary comment for field

* Complete Pydantic integration with reasking

* Add pydantic dep

* add integration test for pydantic

---------

Co-authored-by: Shreya Rajpal <[email protected]>
  • Loading branch information
krandiash and ShreyaR authored Mar 20, 2023
1 parent a31e2ea commit 9715037
Show file tree
Hide file tree
Showing 22 changed files with 1,441 additions and 37 deletions.
569 changes: 569 additions & 0 deletions docs/integrations/pydantic_validation.ipynb

Large diffs are not rendered by default.

14 changes: 14 additions & 0 deletions guardrails/constants.xml
Original file line number Diff line number Diff line change
Expand Up @@ -49,5 +49,19 @@ Here are examples of simple (XML, JSON) pairs that show the expected behavior:

JSON Object:</complete_json_suffix>

<complete_json_suffix_v2>
Given below is XML that describes the information to extract from this document and the tags to extract it into.

{output_schema}

ONLY return a valid JSON object (no other text is necessary), where the key of the field in JSON is the `name` attribute of the corresponding XML, and the value is of the type specified by the corresponding XML's tag. The JSON MUST conform to the XML format, including any types and format requests e.g. requests for lists, objects and specific types. Be correct and concise.

Here are examples of simple (XML, JSON) pairs that show the expected behavior:
- `<![CDATA[<string name='foo' format='two-words lower-case' />]]>` => `{{{{'foo': 'example one'}}}}`
- `<![CDATA[<list name='bar'><string format='upper-case' /></list>]]>` => `{{{{"bar": ['STRING ONE', 'STRING TWO', etc.]}}}}`
- `<![CDATA[<object name='baz'><string name="foo" format="capitalize two-words" /><integer name="index" format="1-indexed" /></object>]]>` => `{{{{'baz': {{{{'foo': 'Some String', 'index': 1}}}}}}}}`

JSON Object:</complete_json_suffix_v2>


</constants>
181 changes: 158 additions & 23 deletions guardrails/datatypes.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import datetime
from types import SimpleNamespace
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Tuple, Union
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Tuple, Type, Union

from lxml import etree as ET
from pydantic import BaseModel

if TYPE_CHECKING:
from guardrails.schema import FormatAttr
Expand Down Expand Up @@ -38,14 +39,38 @@ def __iter__(self) -> Generator[Tuple[str, "DataType", ET._Element], None, None]
assert len(self._children) == 1, "Must have exactly one child."
yield None, list(self._children.values())[0], el_child

def iter(
    self, element: ET._Element
) -> Generator[Tuple[Union[str, None], "DataType", ET._Element], None, None]:
    """Yield ``(name, child_data_type, child_element)`` for each child.

    A child of ``element`` that carries a ``name`` attribute is matched to
    the data type stored under that name in ``self._children``. A child
    without a ``name`` attribute is only accepted when this data type has
    exactly one child data type; that single data type is then yielded with
    ``None`` as the name.

    Args:
        element: The XML element whose direct children are walked.

    Yields:
        Tuples of ``(name or None, child data type, child element)``.

    Raises:
        KeyError: If a named child has no entry in ``self._children``.
        AssertionError: If an unnamed child is found while this data type
            does not have exactly one child data type.
    """
    # Snapshot the children up front: a consumer may replace the yielded
    # element in its parent while this generator is suspended, and the walk
    # must still visit every original child exactly once.
    for el_child in list(element):
        if "name" in el_child.attrib:
            name: str = el_child.attrib["name"]
            child_data_type: DataType = self._children[name]
            yield name, child_data_type, el_child
        else:
            assert len(self._children) == 1, "Must have exactly one child."
            yield None, list(self._children.values())[0], el_child

@classmethod
def from_str(cls, s: str) -> "DataType":
"""Create a DataType from a string."""
raise NotImplementedError("Abstract method.")
"""Create a DataType from a string.
Note: ScalarTypes like int, float, bool, etc. will override this method.
Other ScalarTypes like string, email, url, etc. will not override this
"""
return s

def validate(self, key: str, value: Any, schema: Dict) -> Dict:
"""Validate a value."""
raise NotImplementedError("Abstract method.")

value = self.from_str(value)

for validator in self.validators:
schema = validator.validate_with_correction(key, value, schema)

return schema

def set_children(self, element: ET._Element):
raise NotImplementedError("Abstract method.")
Expand Down Expand Up @@ -83,29 +108,10 @@ def decorator(cls: type):


class ScalarType(DataType):
def validate(self, key: str, value: Any, schema: Dict) -> Dict:
"""Validate a value."""

value = self.from_str(value)

for validator in self.validators:
schema = validator.validate_with_correction(key, value, schema)

return schema

def set_children(self, element: ET._Element):
for _ in element:
raise ValueError("ScalarType data type must not have any children.")

@classmethod
def from_str(cls, s: str) -> "ScalarType":
"""Create a ScalarType from a string.
Note: ScalarTypes like int, float, bool, etc. will override this method.
Other ScalarTypes like string, email, url, etc. will not override this
"""
return s


class NonScalarType(DataType):
pass
Expand Down Expand Up @@ -276,6 +282,135 @@ def set_children(self, element: ET._Element):
self._children[child.attrib["name"]] = child_data_type.from_xml(child)


@register_type("pydantic")
class Pydantic(NonScalarType):
    """Element tag: `<pydantic>`

    Data type backed by a Pydantic model. The model is looked up by the
    element's ``model`` attribute (see ``from_xml``), validation is delegated
    to the Pydantic validator from ``guardrails.validators`` and, for
    prompting, the element can be rendered as an equivalent ``<object>``
    element (see ``to_object_element``).
    """

    def __init__(
        self,
        model: Type[BaseModel],
        children: Dict[str, Any],
        format_attr: "FormatAttr",
        element: ET._Element,
    ) -> None:
        """Create the data type.

        Args:
            model: The Pydantic model class (a subclass of ``BaseModel``).
            children: Child data types, forwarded to the parent constructor.
            format_attr: Format attribute of the element; must be empty.
            element: The ``<pydantic>`` XML element.

        Raises:
            AssertionError: If ``format_attr`` is not empty or ``model`` is
                not a ``BaseModel`` subclass.
        """
        super().__init__(children, format_attr, element)
        assert (
            format_attr.empty
        ), "The <pydantic /> data type does not support the `format` attribute."
        assert isinstance(model, type) and issubclass(
            model, BaseModel
        ), "The `model` argument must be a Pydantic model."

        self.model = model

    @property
    def validators(self) -> List:
        """Return a single-item list holding the Pydantic validator for the model."""
        from guardrails.validators import Pydantic as PydanticValidator

        # The failure policy comes from the `on-fail-pydantic` attribute of
        # the <pydantic /> element; when it is absent, None is passed on.
        on_fail = self.element.attrib.get("on-fail-pydantic")
        return [PydanticValidator(self.model, on_fail=on_fail)]

    def set_children(self, element: ET._Element):
        """Register a child data type for every child of ``element``.

        Each child's tag selects the data type class from ``registry`` and the
        child's ``name`` attribute is the key under which it is stored.
        """
        for child in element:
            child_data_type = registry[child.tag]
            self._children[child.attrib["name"]] = child_data_type.from_xml(child)

    @classmethod
    def from_xml(cls, element: ET._Element, strict: bool = False) -> "DataType":
        """Build the data type from a ``<pydantic>`` element.

        Args:
            element: The XML element; its ``model`` attribute names the model.
            strict: Accepted for signature compatibility; not used here.

        Raises:
            KeyError: If the element has no ``model`` attribute.
            ValueError: If no model is registered under that name.
        """
        from guardrails.schema import FormatAttr
        from guardrails.utils.pydantic_utils import pydantic_models

        model_name = element.attrib["model"]
        model = pydantic_models.get(model_name, None)

        if model is None:
            raise ValueError(f"Invalid Pydantic model: {model_name}")

        data_type = cls(model, {}, FormatAttr(), element)
        data_type.set_children(element)
        return data_type

    def to_object_element(self) -> ET._Element:
        """Convert the Pydantic data type to an <object /> element.

        The result has one child element per property of the model's schema.
        Validator names are joined with "; " into the ``format`` attribute,
        and the ``name`` / ``description`` attributes of the original element
        are carried over when they are present and non-empty.

        All attribute values are XML-escaped before the markup is parsed, so
        descriptions containing quotes, ``<`` or ``&`` no longer break parsing
        or inject markup.

        Raises:
            KeyError: If a schema property has no ``type`` entry or its type
                is missing from ``PYDANTIC_SCHEMA_TYPE_MAP``.
                NOTE(review): this presumably includes nested models - confirm.
        """
        from xml.sax.saxutils import escape

        from guardrails.utils.pydantic_utils import (
            PYDANTIC_SCHEMA_TYPE_MAP,
            get_field_descriptions,
            pydantic_validators,
        )

        def attr(key: str, value: Any) -> str:
            # Escape &, <, > and the double quote so arbitrary text is safe
            # inside a double-quoted attribute value.
            escaped = escape(str(value), {'"': "&quot;"})
            return f' {key}="{escaped}"'

        # Get the following attributes
        # TODO: add on-fail
        name = self.element.attrib.get("name")
        description = self.element.attrib.get("description")

        # Get the Pydantic model schema.
        schema = self.model.schema()
        field_descriptions = get_field_descriptions(self.model)

        # Make the XML as follows using lxml
        # <object name="..." description="..." format="semicolon separated root validators" pydantic="ModelName"> # noqa: E501
        #     <type name="..." description="..." format="semicolon separated validators" /> # noqa: E501
        # </object>

        # Add the object element, opening tag
        root_validators = "; ".join(
            list(pydantic_validators[self.model]["__root__"].keys())
        )
        parts = ["<object"]
        if name:
            parts.append(attr("name", name))
        if description:
            parts.append(attr("description", description))
        if root_validators:
            parts.append(attr("format", root_validators))
        parts.append(attr("pydantic", self.model.__name__))
        parts.append(">")

        # Add all the nested fields
        for field, properties in schema["properties"].items():
            field_type = PYDANTIC_SCHEMA_TYPE_MAP[properties["type"]]
            field_validators = "; ".join(
                list(pydantic_validators[self.model][field].keys())
            )
            try:
                field_description = field_descriptions[field]
            except KeyError:
                field_description = ""
            parts.append(f"<{field_type}")
            parts.append(attr("name", field))
            if field_description:
                parts.append(attr("description", field_description))
            if field_validators:
                parts.append(attr("format", field_validators))
            parts.append(" />")

        # Close the object element
        parts.append("</object>")

        # Convert the string to an XML element, making sure to format it.
        return ET.fromstring(
            "".join(parts),
            parser=ET.XMLParser(encoding="utf-8", remove_blank_text=True),
        )


@register_type("field")
class Field(ScalarType):
    """Element tag: `<field>`

    Scalar data type registered under the ``field`` tag. It defines no
    behavior of its own; everything comes from ``ScalarType``.
    """


# @register_type("key")
# class Key(DataType):
# """
Expand Down
35 changes: 32 additions & 3 deletions guardrails/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,11 @@ class FormatAttr:
# The XML element that this format attribute is associated with.
element: Optional[ET._Element] = None

@property
def empty(self) -> bool:
    """Return True if the format attribute is empty, False otherwise.

    "Empty" means that ``self.format`` is ``None``; any other value,
    including an empty string, is reported as non-empty.
    """
    return self.format is None

@classmethod
def from_element(cls, element: ET._Element) -> "FormatAttr":
"""Create a FormatAttr object from an XML element.
Expand Down Expand Up @@ -384,6 +389,23 @@ def _inner(dt: DataType, el: ET._Element):
dt_child = schema[el_child.attrib["name"]]
_inner(dt_child, el_child)

@staticmethod
def pydantic_to_object(schema: Schema) -> None:
    """Recursively replace all pydantic elements with object elements.

    Every element whose data type is a ``Pydantic`` instance is swapped, in
    its parent, for the element returned by ``to_object_element()``. The
    walk then continues over the children of the original element, paired
    with their data types through ``dt.iter``.

    Args:
        schema: The schema whose ``root`` element tree is rewritten in
            place. Top-level data types are looked up as ``schema[name]``.

    Raises:
        KeyError: If a direct child of ``schema.root`` has no ``name``
            attribute.
    """
    from guardrails.datatypes import Pydantic

    def _inner(dt: DataType, el: ET._Element) -> None:
        if isinstance(dt, Pydantic):
            new_el = dt.to_object_element()
            el.getparent().replace(el, new_el)

        # Materialize the (name, data type, element) triples before
        # recursing: the recursive calls replace elements inside the very
        # tree that is being walked.
        for _, dt_child, el_child in list(dt.iter(el)):
            _inner(dt_child, el_child)

    # Same reason as above: iterate over a snapshot of the root's children,
    # since `_inner` may replace any of them in `schema.root`.
    for el_child in list(schema.root):
        dt_child = schema[el_child.attrib["name"]]
        _inner(dt_child, el_child)

@classmethod
def default(cls, schema: Schema) -> str:
"""Default transpiler.
Expand All @@ -407,6 +429,13 @@ def default(cls, schema: Schema) -> str:
cls.remove_on_fail_attributes(schema.root)
# Remove validators with arguments.
cls.validator_to_prompt(schema)

# Return the XML as a string.
return ET.tostring(schema.root, encoding="unicode", method="xml")
# Replace pydantic elements with object elements.
cls.pydantic_to_object(schema)

# Return the XML as a string.
return ET.tostring(
schema.root,
encoding="unicode",
method="xml",
# pretty_print=True,
)
13 changes: 12 additions & 1 deletion guardrails/utils/logs_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,8 @@ def merge_reask_output(prev_logs: GuardLogs, current_logs: GuardLogs) -> Dict:
Returns:
The merged output.
"""
from guardrails.validators import PydanticReAsk

previous_response = prev_logs.validated_output
pruned_reask_json = prune_json_for_reasking(previous_response)
reask_response = current_logs.validated_output
Expand All @@ -99,7 +101,16 @@ def merge_reask_output(prev_logs: GuardLogs, current_logs: GuardLogs) -> Dict:
merged_json = deepcopy(previous_response)

def update_reasked_elements(pruned_reask_json, reask_response_dict):
if isinstance(pruned_reask_json, dict):
if isinstance(pruned_reask_json, PydanticReAsk):
corrected_value = reask_response_dict
# Get the path from any of the ReAsk objects in the PydanticReAsk object
# all of them have the same path.
path = [v.path for v in pruned_reask_json.values() if isinstance(v, ReAsk)][
0
]
update_response_by_path(merged_json, path, corrected_value)

elif isinstance(pruned_reask_json, dict):
for key, value in pruned_reask_json.items():
if isinstance(value, ReAsk):
corrected_value = reask_response_dict[key]
Expand Down
Loading

0 comments on commit 9715037

Please sign in to comment.