From 9715037de6a4a8fb638cfb21b40945375ca26af5 Mon Sep 17 00:00:00 2001 From: Karan Goel Date: Sun, 19 Mar 2023 21:20:36 -0700 Subject: [PATCH] Add Pydantic Support (#35) * Add pydantic support [wip] * Temporary comment for field * Complete Pydantic integration with reasking * Add pydantic dep * add integration test for pydantic --------- Co-authored-by: Shreya Rajpal --- docs/integrations/pydantic_validation.ipynb | 569 ++++++++++++++++++ guardrails/constants.xml | 14 + guardrails/datatypes.py | 181 +++++- guardrails/schema.py | 35 +- guardrails/utils/logs_utils.py | 13 +- guardrails/utils/pydantic_utils.py | 107 ++++ guardrails/utils/reask_utils.py | 29 +- guardrails/validators.py | 92 ++- setup.py | 1 + tests/integration_tests/mock_llm_outputs.py | 5 +- .../test_cases/pydantic/__init__.py | 34 ++ .../test_cases/pydantic/compiled_prompt.txt | 28 + .../pydantic/compiled_prompt_reask_1.txt | 33 + .../pydantic/compiled_prompt_reask_2.txt | 33 + .../test_cases/pydantic/llm_output.txt | 20 + .../pydantic/llm_output_reask_1.txt | 2 + .../pydantic/llm_output_reask_2.txt | 10 + .../test_cases/pydantic/reask.rail | 60 ++ .../pydantic/validated_response_reask_1.py | 59 ++ .../pydantic/validated_response_reask_2.py | 59 ++ .../pydantic/validated_response_reask_3.py | 50 ++ tests/integration_tests/test_pydantic.py | 44 ++ 22 files changed, 1441 insertions(+), 37 deletions(-) create mode 100644 docs/integrations/pydantic_validation.ipynb create mode 100644 guardrails/utils/pydantic_utils.py create mode 100644 tests/integration_tests/test_cases/pydantic/__init__.py create mode 100644 tests/integration_tests/test_cases/pydantic/compiled_prompt.txt create mode 100644 tests/integration_tests/test_cases/pydantic/compiled_prompt_reask_1.txt create mode 100644 tests/integration_tests/test_cases/pydantic/compiled_prompt_reask_2.txt create mode 100644 tests/integration_tests/test_cases/pydantic/llm_output.txt create mode 100644 tests/integration_tests/test_cases/pydantic/llm_output_reask_1.txt create mode 100644 tests/integration_tests/test_cases/pydantic/llm_output_reask_2.txt create mode 100644 tests/integration_tests/test_cases/pydantic/reask.rail create mode 100644 tests/integration_tests/test_cases/pydantic/validated_response_reask_1.py create mode 100644 tests/integration_tests/test_cases/pydantic/validated_response_reask_2.py create mode 100644 tests/integration_tests/test_cases/pydantic/validated_response_reask_3.py create mode 100644 tests/integration_tests/test_pydantic.py diff --git a/docs/integrations/pydantic_validation.ipynb b/docs/integrations/pydantic_validation.ipynb new file mode 100644 index 000000000..f93535ef3 --- /dev/null +++ b/docs/integrations/pydantic_validation.ipynb @@ -0,0 +1,569 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Validating LLM Outputs with Pydantic\n", + "\n", + "!!! note\n", + " To download this example as a Jupyter notebook, click [here](https://github.com/ShreyaR/guardrails/blob/main/docs/examples/pydantic_validation.ipynb).\n", + "\n", + "In this example, we will use Guardrails with Pydantic.\n", + "\n", + "## Objective\n", + "\n", + "We want to generate synthetic data that is consistent with a `Person` Pydantic BaseModel." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import guardrails as gd\n", + "\n", + "from rich import print" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 1: Create the RAIL Spec\n", + "\n", + "Ordinarily, we would create a RAIL spec in a separate file. For the purposes of this example, we will create the spec in this notebook as a string following the RAIL syntax. For more information on RAIL, see the [RAIL documentation](../rail/output.md).\n", + "\n", + "Here, we define a Pydantic model for a `Person` with the following fields:\n", + "- `name`: a string\n", + "- `age`: an integer\n", + "- `zip_code`: a string zip code\n", + "\n", + "and write very simple validators for the fields as an example. To show how LLM reasking can be used to generate data that is consistent with the Pydantic model, we define a validator that requires a zip code in California (and is perversely opposed to the \"90210\" zip code). If this validator fails, the LLM is sent the error message and is reasked.\n", + "\n", + "This Pydantic model could also be any model that you already have in your codebase; it just needs to be decorated with `@register_pydantic`.\n", + "\n", + "\n", + "To use this model in the `` specification, we use the special\n", + "`pydantic` tag. This tag takes the name of the Pydantic model, as well as the\n", + "`on-fail-pydantic` attribute, which specifies what to do when the output\n", + "does not validate against the Pydantic model." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "rail_str = \"\"\"\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "    \n", + "    \n", + "    \n", + "\n", + "\n", + "\n", + "\n", + "Generate data for possible users in accordance with the specification below.\n", + "\n", + "@xml_prefix_prompt\n", + "\n", + "{output_schema}\n", + "\n", + "@complete_json_suffix_v2\n", + "\n", + "\n", + "\"\"\"" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 2: Create a `Guard` object with the RAIL Spec\n", + "\n", + "We create a `gd.Guard` object that will check, validate and correct the output of the LLM. This object:\n", + "\n", + "1. Enforces the quality criteria specified in the RAIL spec.\n", + "2. Takes corrective action when the quality criteria are not met.\n", + "3. Compiles the schema and type info from the RAIL spec and adds it to the prompt." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "guard = gd.Guard.from_rail_string(rail_str)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can see the prompt that will be sent to the LLM." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n",
+       "Generate data for possible users in accordance with the specification below.\n",
+       "\n",
+       "\n",
+       "Given below is XML that describes the information to extract from this document and the tags to extract it into.\n",
+       "\n",
+       "\n",
+       "<output>\n",
+       "    <list name=\"people\" description=\"A list of 3 people.\">\n",
+       "        <object description=\"Information about a person.\" pydantic=\"Person\"><string name=\"name\" description=\"The \n",
+       "name of the person.\"/><integer name=\"age\" description=\"The age of the person.\" \n",
+       "format=\"age-must-be-between-0-and-150\"/><string name=\"zip_code\" description=\"The zip code of the person.\" \n",
+       "format=\"zip-code-must-be-numeric; zip-code-in-california\"/></object></list>\n",
+       "</output>\n",
+       "\n",
+       "\n",
+       "Given below is XML that describes the information to extract from this document and the tags to extract it into.\n",
+       "\n",
+       "<output>\n",
+       "    <list name=\"people\" description=\"A list of 3 people.\">\n",
+       "        <object description=\"Information about a person.\" pydantic=\"Person\"><string name=\"name\" description=\"The \n",
+       "name of the person.\"/><integer name=\"age\" description=\"The age of the person.\" \n",
+       "format=\"age-must-be-between-0-and-150\"/><string name=\"zip_code\" description=\"The zip code of the person.\" \n",
+       "format=\"zip-code-must-be-numeric; zip-code-in-california\"/></object></list>\n",
+       "</output>\n",
+       "\n",
+       "ONLY return a valid JSON object (no other text is necessary), where the key of the field in JSON is the `name` \n",
+       "attribute of the corresponding XML, and the value is of the type specified by the corresponding XML's tag. The JSON\n",
+       "MUST conform to the XML format, including any types and format requests e.g. requests for lists, objects and \n",
+       "specific types. Be correct and concise.\n",
+       "\n",
+       "Here are examples of simple (XML, JSON) pairs that show the expected behavior:\n",
+       "- `<string name='foo' format='two-words lower-case' />` => `{{'foo': 'example one'}}`\n",
+       "- `<list name='bar'><string format='upper-case' /></list>` => `{{\"bar\": ['STRING ONE', 'STRING TWO', etc.]}}`\n",
+       "- `<object name='baz'><string name=\"foo\" format=\"capitalize two-words\" /><integer name=\"index\" format=\"1-indexed\" \n",
+       "/></object>` => `{{'baz': {{'foo': 'Some String', 'index': 1}}}}`\n",
+       "\n",
+       "JSON Object:\n",
+       "
\n" + ], + "text/plain": [ + "\n", + "Generate data for possible users in accordance with the specification below.\n", + "\n", + "\n", + "Given below is XML that describes the information to extract from this document and the tags to extract it into.\n", + "\n", + "\n", + "\u001b[1m<\u001b[0m\u001b[1;95moutput\u001b[0m\u001b[39m>\u001b[0m\n", + "\u001b[39m \u001b[0m\n", + "\u001b[39m <\u001b[0m\u001b[35m/\u001b[0m\u001b[95mobject\u001b[0m\u001b[39m><\u001b[0m\u001b[35m/\u001b[0m\u001b[95mlist\u001b[0m\u001b[39m>\u001b[0m\n", + "\u001b[39m<\u001b[0m\u001b[35m/\u001b[0m\u001b[95moutput\u001b[0m\u001b[39m>\u001b[0m\n", + "\n", + "\n", + "\u001b[39mGiven below is XML that describes the information to extract from this document and the tags to extract it into.\u001b[0m\n", + "\n", + "\u001b[39m\u001b[0m\n", + "\u001b[39m \u001b[0m\n", + "\u001b[39m <\u001b[0m\u001b[35m/\u001b[0m\u001b[95mobject\u001b[0m\u001b[39m><\u001b[0m\u001b[35m/\u001b[0m\u001b[95mlist\u001b[0m\u001b[39m>\u001b[0m\n", + "\u001b[39m<\u001b[0m\u001b[35m/\u001b[0m\u001b[95moutput\u001b[0m\u001b[39m>\u001b[0m\n", + "\n", + "\u001b[39mONLY return a valid JSON object \u001b[0m\u001b[1;39m(\u001b[0m\u001b[39mno other text is necessary\u001b[0m\u001b[1;39m)\u001b[0m\u001b[39m, where the key of the field in JSON is the `name` \u001b[0m\n", + "\u001b[39mattribute of the corresponding XML, and the value is of the type specified by the corresponding XML's tag. The JSON\u001b[0m\n", + "\u001b[39mMUST conform to the XML format, including any types and format requests e.g. requests for lists, objects and \u001b[0m\n", + "\u001b[39mspecific types. Be correct and concise.\u001b[0m\n", + "\n", + "\u001b[39mHere are examples of simple \u001b[0m\u001b[1;39m(\u001b[0m\u001b[39mXML, JSON\u001b[0m\u001b[1;39m)\u001b[0m\u001b[39m pairs that show the expected behavior:\u001b[0m\n", + "\u001b[39m- `` => `\u001b[0m\u001b[1;39m{\u001b[0m\u001b[1;39m{\u001b[0m\u001b[32m'foo'\u001b[0m\u001b[39m: \u001b[0m\u001b[32m'example one'\u001b[0m\u001b[1;39m}\u001b[0m\u001b[1;39m}\u001b[0m\u001b[39m`\u001b[0m\n", + "\u001b[39m- `<\u001b[0m\u001b[35m/\u001b[0m\u001b[95mlist\u001b[0m\u001b[39m>` => `\u001b[0m\u001b[1;39m{\u001b[0m\u001b[1;39m{\u001b[0m\u001b[32m\"bar\"\u001b[0m\u001b[39m: \u001b[0m\u001b[1;39m[\u001b[0m\u001b[32m'STRING ONE'\u001b[0m\u001b[39m, \u001b[0m\u001b[32m'STRING TWO'\u001b[0m\u001b[39m, etc.\u001b[0m\u001b[1;39m]\u001b[0m\u001b[1;39m}\u001b[0m\u001b[1;39m}\u001b[0m\u001b[39m`\u001b[0m\n", + "\u001b[39m- `<\u001b[0m\u001b[35m/\u001b[0m\u001b[95mobject\u001b[0m\u001b[39m>` =\u001b[0m\u001b[1m>\u001b[0m `\u001b[1m{\u001b[0m\u001b[1m{\u001b[0m\u001b[32m'baz'\u001b[0m: \u001b[1m{\u001b[0m\u001b[1m{\u001b[0m\u001b[32m'foo'\u001b[0m: \u001b[32m'Some String'\u001b[0m, \u001b[32m'index'\u001b[0m: \u001b[1;36m1\u001b[0m\u001b[1m}\u001b[0m\u001b[1m}\u001b[0m\u001b[1m}\u001b[0m\u001b[1m}\u001b[0m`\n", + "\n", + "JSON Object:\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "print(guard.base_prompt)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "!!! note\n", + " Notice that the prompt replaces the `pydantic` tag with the schema, validator and type information from the Pydantic model. This e.g. tells the LLM that we want that `zip-code-must-be-numeric` and `zip-code-in-california`. Guardrails will even automatically read the docstrings from the Pydantic model and add them to the prompt!" 
+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 3: Wrap the LLM API call with `Guard`" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/krandiash/opt/anaconda3/envs/guardrails/lib/python3.9/site-packages/eliot/json.py:22: FutureWarning: In the future `np.bool` will be defined as the corresponding NumPy scalar.\n", + " if isinstance(o, (numpy.bool, numpy.bool_)):\n" + ] + } + ], + "source": [ + "import openai\n", + "\n", + "raw_llm_response, validated_response = guard(\n", + " openai.Completion.create,\n", + " engine=\"text-davinci-003\",\n", + " max_tokens=512,\n", + " temperature=0.5,\n", + " num_reasks=2,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
{\n",
+       "    'people': [\n",
+       "        Person(name='John Doe', age=25, zip_code='90000'),\n",
+       "        Person(name='Jane Doe', age=30, zip_code='94105'),\n",
+       "        Person(name='John Smith', age=40, zip_code='90001')\n",
+       "    ]\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[32m'people'\u001b[0m: \u001b[1m[\u001b[0m\n", + " \u001b[1;35mPerson\u001b[0m\u001b[1m(\u001b[0m\u001b[33mname\u001b[0m=\u001b[32m'John Doe'\u001b[0m, \u001b[33mage\u001b[0m=\u001b[1;36m25\u001b[0m, \u001b[33mzip_code\u001b[0m=\u001b[32m'90000'\u001b[0m\u001b[1m)\u001b[0m,\n", + " \u001b[1;35mPerson\u001b[0m\u001b[1m(\u001b[0m\u001b[33mname\u001b[0m=\u001b[32m'Jane Doe'\u001b[0m, \u001b[33mage\u001b[0m=\u001b[1;36m30\u001b[0m, \u001b[33mzip_code\u001b[0m=\u001b[32m'94105'\u001b[0m\u001b[1m)\u001b[0m,\n", + " \u001b[1;35mPerson\u001b[0m\u001b[1m(\u001b[0m\u001b[33mname\u001b[0m=\u001b[32m'John Smith'\u001b[0m, \u001b[33mage\u001b[0m=\u001b[1;36m40\u001b[0m, \u001b[33mzip_code\u001b[0m=\u001b[32m'90001'\u001b[0m\u001b[1m)\u001b[0m\n", + " \u001b[1m]\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "print(validated_response)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The `guard` wrapper returns the raw_llm_respose (which is a simple string), and the validated and corrected output (which is a dictionary).\n", + "\n", + "We can see that the output is a dictionary with the correct schema and contains a few `Person` objects!\n", + "\n", + "We can even print out the logs of the most recent call. Notice that the first time the LLM actually returns a Beverly Hills zip code, the LLM is sent the error message and is reasked. The second time, the LLM returns a valid zip code and the output is returned." + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
GuardHistory(\n",
+       "    history=[\n",
+       "        GuardLogs(\n",
+       "            prompt='\\nGenerate data for possible users in accordance with the specification below.\\n\\n\\nGiven below\n",
+       "is XML that describes the information to extract from this document and the tags to extract it \n",
+       "into.\\n\\n\\n<output>\\n    <list name=\"people\" description=\"A list of 3 people.\">\\n        <object \n",
+       "description=\"Information about a person.\" pydantic=\"Person\"><string name=\"name\" description=\"The name of the \n",
+       "person.\"/><integer name=\"age\" description=\"The age of the person.\" format=\"age-must-be-between-0-and-150\"/><string \n",
+       "name=\"zip_code\" description=\"The zip code of the person.\" format=\"zip-code-must-be-numeric; \n",
+       "zip-code-in-california\"/></object></list>\\n</output>\\n\\n\\nGiven below is XML that describes the information to \n",
+       "extract from this document and the tags to extract it into.\\n\\n<output>\\n    <list name=\"people\" description=\"A \n",
+       "list of 3 people.\">\\n        <object description=\"Information about a person.\" pydantic=\"Person\"><string \n",
+       "name=\"name\" description=\"The name of the person.\"/><integer name=\"age\" description=\"The age of the person.\" \n",
+       "format=\"age-must-be-between-0-and-150\"/><string name=\"zip_code\" description=\"The zip code of the person.\" \n",
+       "format=\"zip-code-must-be-numeric; zip-code-in-california\"/></object></list>\\n</output>\\n\\nONLY return a valid JSON \n",
+       "object (no other text is necessary), where the key of the field in JSON is the `name` attribute of the \n",
+       "corresponding XML, and the value is of the type specified by the corresponding XML\\'s tag. The JSON MUST conform to\n",
+       "the XML format, including any types and format requests e.g. requests for lists, objects and specific types. Be \n",
+       "correct and concise.\\n\\nHere are examples of simple (XML, JSON) pairs that show the expected behavior:\\n- `<string \n",
+       "name=\\'foo\\' format=\\'two-words lower-case\\' />` => `{\\'foo\\': \\'example one\\'}`\\n- `<list name=\\'bar\\'><string \n",
+       "format=\\'upper-case\\' /></list>` => `{\"bar\": [\\'STRING ONE\\', \\'STRING TWO\\', etc.]}`\\n- `<object \n",
+       "name=\\'baz\\'><string name=\"foo\" format=\"capitalize two-words\" /><integer name=\"index\" format=\"1-indexed\" \n",
+       "/></object>` => `{\\'baz\\': {\\'foo\\': \\'Some String\\', \\'index\\': 1}}`\\n\\nJSON Object:',\n",
+       "            output=' \\n{\\n    \"people\": [\\n        {\\n            \"name\": \"John Doe\",\\n            \"age\": 25,\\n    \n",
+       "\"zip_code\": \"90210\"\\n        },\\n        {\\n            \"name\": \"Jane Doe\",\\n            \"age\": 30,\\n            \n",
+       "\"zip_code\": \"94105\"\\n        },\\n        {\\n            \"name\": \"John Smith\",\\n            \"age\": 40,\\n            \n",
+       "\"zip_code\": \"90001\"\\n        }\\n    ]\\n}',\n",
+       "            output_as_dict={\n",
+       "                'people': [\n",
+       "                    {'name': 'John Doe', 'age': 25, 'zip_code': '90210'},\n",
+       "                    {'name': 'Jane Doe', 'age': 30, 'zip_code': '94105'},\n",
+       "                    {'name': 'John Smith', 'age': 40, 'zip_code': '90001'}\n",
+       "                ]\n",
+       "            },\n",
+       "            validated_output={\n",
+       "                'people': [\n",
+       "                    {\n",
+       "                        'name': 'John Doe',\n",
+       "                        'age': 25,\n",
+       "                        'zip_code': ReAsk(\n",
+       "                            incorrect_value='90210',\n",
+       "                            error_message='Zip code must not be Beverly Hills.',\n",
+       "                            fix_value=None,\n",
+       "                            path=['people', 0]\n",
+       "                        )\n",
+       "                    },\n",
+       "                    Person(name='Jane Doe', age=30, zip_code='94105'),\n",
+       "                    Person(name='John Smith', age=40, zip_code='90001')\n",
+       "                ]\n",
+       "            },\n",
+       "            reasks=[\n",
+       "                ReAsk(\n",
+       "                    incorrect_value='90210',\n",
+       "                    error_message='Zip code must not be Beverly Hills.',\n",
+       "                    fix_value=None,\n",
+       "                    path=['people', 0]\n",
+       "                )\n",
+       "            ]\n",
+       "        ),\n",
+       "        GuardLogs(\n",
+       "            prompt='\\nI was given the following JSON response, which had problems due to incorrect values.\\n\\n{\\n  \n",
+       "\"people\": [\\n    {\\n      \"name\": \"John Doe\",\\n      \"age\": 25,\\n      \"zip_code\": {\\n        \"incorrect_value\": \n",
+       "\"90210\",\\n        \"error_message\": \"Zip code must not be Beverly Hills.\"\\n      }\\n    }\\n  ]\\n}\\n\\nHelp me correct\n",
+       "the incorrect values based on the given error messages.\\n\\nGiven below is XML that describes the information to \n",
+       "extract from this document and the tags to extract it into.\\n\\n<output>\\n    <list name=\"people\" description=\"A \n",
+       "list of 3 people.\">\\n        <object description=\"Information about a person.\" pydantic=\"Person\"><string \n",
+       "name=\"name\" description=\"The name of the person.\"/><integer name=\"age\" description=\"The age of the person.\" \n",
+       "format=\"age-must-be-between-0-and-150\"/><string name=\"zip_code\" description=\"The zip code of the person.\" \n",
+       "format=\"zip-code-must-be-numeric; zip-code-in-california\"/></object></list>\\n</output>\\n\\nONLY return a valid JSON \n",
+       "object (no other text is necessary), where the key of the field in JSON is the `name` attribute of the \n",
+       "corresponding XML, and the value is of the type specified by the corresponding XML\\'s tag. The JSON MUST conform to\n",
+       "the XML format, including any types and format requests e.g. requests for lists, objects and specific types. Be \n",
+       "correct and concise. If you are unsure anywhere, enter `None`.\\n\\nHere are examples of simple (XML, JSON) pairs \n",
+       "that show the expected behavior:\\n- `<string name=\\'foo\\' format=\\'two-words lower-case\\' />` => `{{\\'foo\\': \n",
+       "\\'example one\\'}}`\\n- `<list name=\\'bar\\'><string format=\\'upper-case\\' /></list>` => `{{\"bar\": [\\'STRING ONE\\', \n",
+       "\\'STRING TWO\\', etc.]}}`\\n- `<object name=\\'baz\\'><string name=\"foo\" format=\"capitalize two-words\" /><integer \n",
+       "name=\"index\" format=\"1-indexed\" /></object>` => `{{\\'baz\\': {{\\'foo\\': \\'Some String\\', \\'index\\': 1}}}}`\\n\\nJSON \n",
+       "Object:',\n",
+       "            output='\\n{\\n    \"people\": [\\n        {\\n            \"name\": \"John Doe\",\\n            \"age\": 25,\\n     \n",
+       "\"zip_code\": \"90000\"\\n        }\\n    ]\\n}',\n",
+       "            output_as_dict={'people': [{'name': 'John Doe', 'age': 25, 'zip_code': '90000'}]},\n",
+       "            validated_output={\n",
+       "                'people': [\n",
+       "                    Person(name='John Doe', age=25, zip_code='90000'),\n",
+       "                    Person(name='Jane Doe', age=30, zip_code='94105'),\n",
+       "                    Person(name='John Smith', age=40, zip_code='90001')\n",
+       "                ]\n",
+       "            },\n",
+       "            reasks=[]\n",
+       "        )\n",
+       "    ]\n",
+       ")\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1;35mGuardHistory\u001b[0m\u001b[1m(\u001b[0m\n", + " \u001b[33mhistory\u001b[0m=\u001b[1m[\u001b[0m\n", + " \u001b[1;35mGuardLogs\u001b[0m\u001b[1m(\u001b[0m\n", + " \u001b[33mprompt\u001b[0m=\u001b[32m'\\nGenerate data for possible users in accordance with the specification below.\\n\\n\\nGiven below\u001b[0m\n", + "\u001b[32mis XML that describes the information to extract from this document and the tags to extract it \u001b[0m\n", + "\u001b[32minto.\\n\\n\\n\u001b[0m\u001b[32m<\u001b[0m\u001b[32moutput\u001b[0m\u001b[32m>\\n \\n \\n\\n\\n\\nGiven below is XML that describes the information to \u001b[0m\n", + "\u001b[32mextract from this document and the tags to extract it into.\\n\\n\\n \\n \\n\\n\\nONLY return a valid JSON \u001b[0m\n", + "\u001b[32mobject \u001b[0m\u001b[32m(\u001b[0m\u001b[32mno other text is necessary\u001b[0m\u001b[32m)\u001b[0m\u001b[32m, where the key of the field in JSON is the `name` attribute of the \u001b[0m\n", + "\u001b[32mcorresponding XML, and the value is of the type specified by the corresponding XML\\'s tag. The JSON MUST conform to\u001b[0m\n", + "\u001b[32mthe XML format, including any types and format requests e.g. requests for lists, objects and specific types. Be \u001b[0m\n", + "\u001b[32mcorrect and concise.\\n\\nHere are examples of simple \u001b[0m\u001b[32m(\u001b[0m\u001b[32mXML, JSON\u001b[0m\u001b[32m)\u001b[0m\u001b[32m pairs that show the expected behavior:\\n- `` => `\u001b[0m\u001b[32m{\u001b[0m\u001b[32m\\'foo\\': \\'example one\\'\u001b[0m\u001b[32m}\u001b[0m\u001b[32m`\\n- `` => `\u001b[0m\u001b[32m{\u001b[0m\u001b[32m\"bar\": \u001b[0m\u001b[32m[\u001b[0m\u001b[32m\\'STRING ONE\\', \\'STRING TWO\\', etc.\u001b[0m\u001b[32m]\u001b[0m\u001b[32m}\u001b[0m\u001b[32m`\\n- `` => `\u001b[0m\u001b[32m{\u001b[0m\u001b[32m\\'baz\\': \u001b[0m\u001b[32m{\u001b[0m\u001b[32m\\'foo\\': \\'Some String\\', \\'index\\': 1\u001b[0m\u001b[32m}\u001b[0m\u001b[32m}\u001b[0m\u001b[32m`\\n\\nJSON Object:'\u001b[0m\u001b[39m,\u001b[0m\n", + "\u001b[39m \u001b[0m\u001b[33moutput\u001b[0m\u001b[39m=\u001b[0m\u001b[32m' \\n\u001b[0m\u001b[32m{\u001b[0m\u001b[32m\\n \"people\": \u001b[0m\u001b[32m[\u001b[0m\u001b[32m\\n \u001b[0m\u001b[32m{\u001b[0m\u001b[32m\\n \"name\": \"John Doe\",\\n \"age\": 25,\\n \u001b[0m\n", + "\u001b[32m\"zip_code\": \"90210\"\\n \u001b[0m\u001b[32m}\u001b[0m\u001b[32m,\\n \u001b[0m\u001b[32m{\u001b[0m\u001b[32m\\n \"name\": \"Jane Doe\",\\n \"age\": 30,\\n \u001b[0m\n", + "\u001b[32m\"zip_code\": \"94105\"\\n \u001b[0m\u001b[32m}\u001b[0m\u001b[32m,\\n \u001b[0m\u001b[32m{\u001b[0m\u001b[32m\\n \"name\": \"John Smith\",\\n \"age\": 40,\\n \u001b[0m\n", + "\u001b[32m\"zip_code\": \"90001\"\\n \u001b[0m\u001b[32m}\u001b[0m\u001b[32m\\n \u001b[0m\u001b[32m]\u001b[0m\u001b[32m\\n\u001b[0m\u001b[32m}\u001b[0m\u001b[32m'\u001b[0m\u001b[39m,\u001b[0m\n", + "\u001b[39m \u001b[0m\u001b[33moutput_as_dict\u001b[0m\u001b[39m=\u001b[0m\u001b[1;39m{\u001b[0m\n", + "\u001b[39m \u001b[0m\u001b[32m'people'\u001b[0m\u001b[39m: \u001b[0m\u001b[1;39m[\u001b[0m\n", + "\u001b[39m \u001b[0m\u001b[1;39m{\u001b[0m\u001b[32m'name'\u001b[0m\u001b[39m: \u001b[0m\u001b[32m'John Doe'\u001b[0m\u001b[39m, \u001b[0m\u001b[32m'age'\u001b[0m\u001b[39m: \u001b[0m\u001b[1;36m25\u001b[0m\u001b[39m, \u001b[0m\u001b[32m'zip_code'\u001b[0m\u001b[39m: \u001b[0m\u001b[32m'90210'\u001b[0m\u001b[1;39m}\u001b[0m\u001b[39m,\u001b[0m\n", + "\u001b[39m \u001b[0m\u001b[1;39m{\u001b[0m\u001b[32m'name'\u001b[0m\u001b[39m: 
\u001b[0m\u001b[32m'Jane Doe'\u001b[0m\u001b[39m, \u001b[0m\u001b[32m'age'\u001b[0m\u001b[39m: \u001b[0m\u001b[1;36m30\u001b[0m\u001b[39m, \u001b[0m\u001b[32m'zip_code'\u001b[0m\u001b[39m: \u001b[0m\u001b[32m'94105'\u001b[0m\u001b[1;39m}\u001b[0m\u001b[39m,\u001b[0m\n", + "\u001b[39m \u001b[0m\u001b[1;39m{\u001b[0m\u001b[32m'name'\u001b[0m\u001b[39m: \u001b[0m\u001b[32m'John Smith'\u001b[0m\u001b[39m, \u001b[0m\u001b[32m'age'\u001b[0m\u001b[39m: \u001b[0m\u001b[1;36m40\u001b[0m\u001b[39m, \u001b[0m\u001b[32m'zip_code'\u001b[0m\u001b[39m: \u001b[0m\u001b[32m'90001'\u001b[0m\u001b[1;39m}\u001b[0m\n", + "\u001b[39m \u001b[0m\u001b[1;39m]\u001b[0m\n", + "\u001b[39m \u001b[0m\u001b[1;39m}\u001b[0m\u001b[39m,\u001b[0m\n", + "\u001b[39m \u001b[0m\u001b[33mvalidated_output\u001b[0m\u001b[39m=\u001b[0m\u001b[1;39m{\u001b[0m\n", + "\u001b[39m \u001b[0m\u001b[32m'people'\u001b[0m\u001b[39m: \u001b[0m\u001b[1;39m[\u001b[0m\n", + "\u001b[39m \u001b[0m\u001b[1;39m{\u001b[0m\n", + "\u001b[39m \u001b[0m\u001b[32m'name'\u001b[0m\u001b[39m: \u001b[0m\u001b[32m'John Doe'\u001b[0m\u001b[39m,\u001b[0m\n", + "\u001b[39m \u001b[0m\u001b[32m'age'\u001b[0m\u001b[39m: \u001b[0m\u001b[1;36m25\u001b[0m\u001b[39m,\u001b[0m\n", + "\u001b[39m \u001b[0m\u001b[32m'zip_code'\u001b[0m\u001b[39m: \u001b[0m\u001b[1;35mReAsk\u001b[0m\u001b[1;39m(\u001b[0m\n", + "\u001b[39m \u001b[0m\u001b[33mincorrect_value\u001b[0m\u001b[39m=\u001b[0m\u001b[32m'90210'\u001b[0m\u001b[39m,\u001b[0m\n", + "\u001b[39m \u001b[0m\u001b[33merror_message\u001b[0m\u001b[39m=\u001b[0m\u001b[32m'Zip code must not be Beverly Hills.'\u001b[0m\u001b[39m,\u001b[0m\n", + "\u001b[39m \u001b[0m\u001b[33mfix_value\u001b[0m\u001b[39m=\u001b[0m\u001b[3;35mNone\u001b[0m\u001b[39m,\u001b[0m\n", + "\u001b[39m \u001b[0m\u001b[33mpath\u001b[0m\u001b[39m=\u001b[0m\u001b[1;39m[\u001b[0m\u001b[32m'people'\u001b[0m\u001b[39m, \u001b[0m\u001b[1;36m0\u001b[0m\u001b[1;39m]\u001b[0m\n", + "\u001b[39m \u001b[0m\u001b[1;39m)\u001b[0m\n", + "\u001b[39m \u001b[0m\u001b[1;39m}\u001b[0m\u001b[39m,\u001b[0m\n", + "\u001b[39m \u001b[0m\u001b[1;35mPerson\u001b[0m\u001b[1;39m(\u001b[0m\u001b[33mname\u001b[0m\u001b[39m=\u001b[0m\u001b[32m'Jane Doe'\u001b[0m\u001b[39m, \u001b[0m\u001b[33mage\u001b[0m\u001b[39m=\u001b[0m\u001b[1;36m30\u001b[0m\u001b[39m, \u001b[0m\u001b[33mzip_code\u001b[0m\u001b[39m=\u001b[0m\u001b[32m'94105'\u001b[0m\u001b[1;39m)\u001b[0m\u001b[39m,\u001b[0m\n", + "\u001b[39m \u001b[0m\u001b[1;35mPerson\u001b[0m\u001b[1;39m(\u001b[0m\u001b[33mname\u001b[0m\u001b[39m=\u001b[0m\u001b[32m'John Smith'\u001b[0m\u001b[39m, \u001b[0m\u001b[33mage\u001b[0m\u001b[39m=\u001b[0m\u001b[1;36m40\u001b[0m\u001b[39m, \u001b[0m\u001b[33mzip_code\u001b[0m\u001b[39m=\u001b[0m\u001b[32m'90001'\u001b[0m\u001b[1;39m)\u001b[0m\n", + "\u001b[39m \u001b[0m\u001b[1;39m]\u001b[0m\n", + "\u001b[39m \u001b[0m\u001b[1;39m}\u001b[0m\u001b[39m,\u001b[0m\n", + "\u001b[39m \u001b[0m\u001b[33mreasks\u001b[0m\u001b[39m=\u001b[0m\u001b[1;39m[\u001b[0m\n", + "\u001b[39m \u001b[0m\u001b[1;35mReAsk\u001b[0m\u001b[1;39m(\u001b[0m\n", + "\u001b[39m \u001b[0m\u001b[33mincorrect_value\u001b[0m\u001b[39m=\u001b[0m\u001b[32m'90210'\u001b[0m\u001b[39m,\u001b[0m\n", + "\u001b[39m \u001b[0m\u001b[33merror_message\u001b[0m\u001b[39m=\u001b[0m\u001b[32m'Zip code must not be Beverly Hills.'\u001b[0m\u001b[39m,\u001b[0m\n", + "\u001b[39m \u001b[0m\u001b[33mfix_value\u001b[0m\u001b[39m=\u001b[0m\u001b[3;35mNone\u001b[0m\u001b[39m,\u001b[0m\n", + "\u001b[39m 
\u001b[0m\u001b[33mpath\u001b[0m\u001b[39m=\u001b[0m\u001b[1;39m[\u001b[0m\u001b[32m'people'\u001b[0m\u001b[39m, \u001b[0m\u001b[1;36m0\u001b[0m\u001b[1;39m]\u001b[0m\n", + "\u001b[39m \u001b[0m\u001b[1;39m)\u001b[0m\n", + "\u001b[39m \u001b[0m\u001b[1;39m]\u001b[0m\n", + "\u001b[39m \u001b[0m\u001b[1;39m)\u001b[0m\u001b[39m,\u001b[0m\n", + "\u001b[39m \u001b[0m\u001b[1;35mGuardLogs\u001b[0m\u001b[1;39m(\u001b[0m\n", + "\u001b[39m \u001b[0m\u001b[33mprompt\u001b[0m\u001b[39m=\u001b[0m\u001b[32m'\\nI was given the following JSON response, which had problems due to incorrect values.\\n\\n\u001b[0m\u001b[32m{\u001b[0m\u001b[32m\\n \u001b[0m\n", + "\u001b[32m\"people\": \u001b[0m\u001b[32m[\u001b[0m\u001b[32m\\n \u001b[0m\u001b[32m{\u001b[0m\u001b[32m\\n \"name\": \"John Doe\",\\n \"age\": 25,\\n \"zip_code\": \u001b[0m\u001b[32m{\u001b[0m\u001b[32m\\n \"incorrect_value\": \u001b[0m\n", + "\u001b[32m\"90210\",\\n \"error_message\": \"Zip code must not be Beverly Hills.\"\\n \u001b[0m\u001b[32m}\u001b[0m\u001b[32m\\n \u001b[0m\u001b[32m}\u001b[0m\u001b[32m\\n \u001b[0m\u001b[32m]\u001b[0m\u001b[32m\\n\u001b[0m\u001b[32m}\u001b[0m\u001b[32m\\n\\nHelp me correct\u001b[0m\n", + "\u001b[32mthe incorrect values based on the given error messages.\\n\\nGiven below is XML that describes the information to \u001b[0m\n", + "\u001b[32mextract from this document and the tags to extract it into.\\n\\n\\n \\n \\n\\n\\nONLY return a valid JSON \u001b[0m\n", + "\u001b[32mobject \u001b[0m\u001b[32m(\u001b[0m\u001b[32mno other text is necessary\u001b[0m\u001b[32m)\u001b[0m\u001b[32m, where the key of the field in JSON is the `name` attribute of the \u001b[0m\n", + "\u001b[32mcorresponding XML, and the value is of the type specified by the corresponding XML\\'s tag. The JSON MUST conform to\u001b[0m\n", + "\u001b[32mthe XML format, including any types and format requests e.g. requests for lists, objects and specific types. Be \u001b[0m\n", + "\u001b[32mcorrect and concise. 
If you are unsure anywhere, enter `None`.\\n\\nHere are examples of simple \u001b[0m\u001b[32m(\u001b[0m\u001b[32mXML, JSON\u001b[0m\u001b[32m)\u001b[0m\u001b[32m pairs \u001b[0m\n", + "\u001b[32mthat show the expected behavior:\\n- `` => `\u001b[0m\u001b[32m{\u001b[0m\u001b[32m{\u001b[0m\u001b[32m\\'foo\\': \u001b[0m\n", + "\u001b[32m\\'example one\\'\u001b[0m\u001b[32m}\u001b[0m\u001b[32m}\u001b[0m\u001b[32m`\\n- `` => `\u001b[0m\u001b[32m{\u001b[0m\u001b[32m{\u001b[0m\u001b[32m\"bar\": \u001b[0m\u001b[32m[\u001b[0m\u001b[32m\\'STRING ONE\\', \u001b[0m\n", + "\u001b[32m\\'STRING TWO\\', etc.\u001b[0m\u001b[32m]\u001b[0m\u001b[32m}\u001b[0m\u001b[32m}\u001b[0m\u001b[32m`\\n- `` =\u001b[0m\u001b[32m>\u001b[0m\u001b[32m `\u001b[0m\u001b[32m{\u001b[0m\u001b[32m{\u001b[0m\u001b[32m\\'baz\\': \u001b[0m\u001b[32m{\u001b[0m\u001b[32m{\u001b[0m\u001b[32m\\'foo\\': \\'Some String\\', \\'index\\': 1\u001b[0m\u001b[32m}\u001b[0m\u001b[32m}\u001b[0m\u001b[32m}\u001b[0m\u001b[32m}\u001b[0m\u001b[32m`\\n\\nJSON \u001b[0m\n", + "\u001b[32mObject:'\u001b[0m,\n", + " \u001b[33moutput\u001b[0m=\u001b[32m'\\n\u001b[0m\u001b[32m{\u001b[0m\u001b[32m\\n \"people\": \u001b[0m\u001b[32m[\u001b[0m\u001b[32m\\n \u001b[0m\u001b[32m{\u001b[0m\u001b[32m\\n \"name\": \"John Doe\",\\n \"age\": 25,\\n \u001b[0m\n", + "\u001b[32m\"zip_code\": \"90000\"\\n \u001b[0m\u001b[32m}\u001b[0m\u001b[32m\\n \u001b[0m\u001b[32m]\u001b[0m\u001b[32m\\n\u001b[0m\u001b[32m}\u001b[0m\u001b[32m'\u001b[0m,\n", + " \u001b[33moutput_as_dict\u001b[0m=\u001b[1m{\u001b[0m\u001b[32m'people'\u001b[0m: \u001b[1m[\u001b[0m\u001b[1m{\u001b[0m\u001b[32m'name'\u001b[0m: \u001b[32m'John Doe'\u001b[0m, \u001b[32m'age'\u001b[0m: \u001b[1;36m25\u001b[0m, \u001b[32m'zip_code'\u001b[0m: \u001b[32m'90000'\u001b[0m\u001b[1m}\u001b[0m\u001b[1m]\u001b[0m\u001b[1m}\u001b[0m,\n", + " \u001b[33mvalidated_output\u001b[0m=\u001b[1m{\u001b[0m\n", + " \u001b[32m'people'\u001b[0m: \u001b[1m[\u001b[0m\n", + " \u001b[1;35mPerson\u001b[0m\u001b[1m(\u001b[0m\u001b[33mname\u001b[0m=\u001b[32m'John Doe'\u001b[0m, \u001b[33mage\u001b[0m=\u001b[1;36m25\u001b[0m, \u001b[33mzip_code\u001b[0m=\u001b[32m'90000'\u001b[0m\u001b[1m)\u001b[0m,\n", + " \u001b[1;35mPerson\u001b[0m\u001b[1m(\u001b[0m\u001b[33mname\u001b[0m=\u001b[32m'Jane Doe'\u001b[0m, \u001b[33mage\u001b[0m=\u001b[1;36m30\u001b[0m, \u001b[33mzip_code\u001b[0m=\u001b[32m'94105'\u001b[0m\u001b[1m)\u001b[0m,\n", + " \u001b[1;35mPerson\u001b[0m\u001b[1m(\u001b[0m\u001b[33mname\u001b[0m=\u001b[32m'John Smith'\u001b[0m, \u001b[33mage\u001b[0m=\u001b[1;36m40\u001b[0m, \u001b[33mzip_code\u001b[0m=\u001b[32m'90001'\u001b[0m\u001b[1m)\u001b[0m\n", + " \u001b[1m]\u001b[0m\n", + " \u001b[1m}\u001b[0m,\n", + " \u001b[33mreasks\u001b[0m=\u001b[1m[\u001b[0m\u001b[1m]\u001b[0m\n", + " \u001b[1m)\u001b[0m\n", + " \u001b[1m]\u001b[0m\n", + "\u001b[1m)\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "print(guard.state.most_recent_call)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "guardrails", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.16" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/guardrails/constants.xml b/guardrails/constants.xml index 8522334e6..253bc0f3e 100644 --- 
a/guardrails/constants.xml +++ b/guardrails/constants.xml @@ -49,5 +49,19 @@ Here are examples of simple (XML, JSON) pairs that show the expected behavior: JSON Object: + +Given below is XML that describes the information to extract from this document and the tags to extract it into. + +{output_schema} + +ONLY return a valid JSON object (no other text is necessary), where the key of the field in JSON is the `name` attribute of the corresponding XML, and the value is of the type specified by the corresponding XML's tag. The JSON MUST conform to the XML format, including any types and format requests e.g. requests for lists, objects and specific types. Be correct and concise. + +Here are examples of simple (XML, JSON) pairs that show the expected behavior: +- ``]]> => `{{{{'foo': 'example one'}}}}` +- `]]>` => `{{{{"bar": ['STRING ONE', 'STRING TWO', etc.]}}}}` +- `
]]>` => `{{{{'baz': {{{{'foo': 'Some String', 'index': 1}}}}}}}}` + +JSON Object: + \ No newline at end of file diff --git a/guardrails/datatypes.py b/guardrails/datatypes.py index 875d2a2d0..92ece416b 100644 --- a/guardrails/datatypes.py +++ b/guardrails/datatypes.py @@ -1,8 +1,9 @@ import datetime from types import SimpleNamespace -from typing import TYPE_CHECKING, Any, Dict, Generator, List, Tuple, Union +from typing import TYPE_CHECKING, Any, Dict, Generator, List, Tuple, Type, Union from lxml import etree as ET +from pydantic import BaseModel if TYPE_CHECKING: from guardrails.schema import FormatAttr @@ -38,14 +39,38 @@ def __iter__(self) -> Generator[Tuple[str, "DataType", ET._Element], None, None] assert len(self._children) == 1, "Must have exactly one child." yield None, list(self._children.values())[0], el_child + def iter( + self, element: ET._Element + ) -> Generator[Tuple[str, "DataType", ET._Element], None, None]: + """Return a tuple of (name, child_data_type, child_element) for each + child.""" + for el_child in element: + if "name" in el_child.attrib: + name: str = el_child.attrib["name"] + child_data_type: DataType = self._children[name] + yield name, child_data_type, el_child + else: + assert len(self._children) == 1, "Must have exactly one child." + yield None, list(self._children.values())[0], el_child + @classmethod def from_str(cls, s: str) -> "DataType": - """Create a DataType from a string.""" - raise NotImplementedError("Abstract method.") + """Create a DataType from a string. + + Note: ScalarTypes like int, float, bool, etc. will override this method. + Other ScalarTypes like string, email, url, etc. will not override this + """ + return s def validate(self, key: str, value: Any, schema: Dict) -> Dict: """Validate a value.""" - raise NotImplementedError("Abstract method.") + + value = self.from_str(value) + + for validator in self.validators: + schema = validator.validate_with_correction(key, value, schema) + + return schema def set_children(self, element: ET._Element): raise NotImplementedError("Abstract method.") @@ -83,29 +108,10 @@ def decorator(cls: type): class ScalarType(DataType): - def validate(self, key: str, value: Any, schema: Dict) -> Dict: - """Validate a value.""" - - value = self.from_str(value) - - for validator in self.validators: - schema = validator.validate_with_correction(key, value, schema) - - return schema - def set_children(self, element: ET._Element): for _ in element: raise ValueError("ScalarType data type must not have any children.") - @classmethod - def from_str(cls, s: str) -> "ScalarType": - """Create a ScalarType from a string. - - Note: ScalarTypes like int, float, bool, etc. will override this method. - Other ScalarTypes like string, email, url, etc. will not override this - """ - return s - class NonScalarType(DataType): pass @@ -276,6 +282,135 @@ def set_children(self, element: ET._Element): self._children[child.attrib["name"]] = child_data_type.from_xml(child) +@register_type("pydantic") +class Pydantic(NonScalarType): + """Element tag: ``""" + + def __init__( + self, + model: Type[BaseModel], + children: Dict[str, Any], + format_attr: "FormatAttr", + element: ET._Element, + ) -> None: + super().__init__(children, format_attr, element) + assert ( + format_attr.empty + ), "The data type does not support the `format` attribute." + assert isinstance(model, type) and issubclass( + model, BaseModel + ), "The `model` argument must be a Pydantic model." 
+ + self.model = model + + @property + def validators(self) -> List: + from guardrails.validators import Pydantic as PydanticValidator + + # Check if the element has an `on-fail` attribute. + # If so, use that as the `on_fail` argument for the PydanticValidator. + on_fail = None + on_fail_attr_name = "on-fail-pydantic" + if on_fail_attr_name in self.element.attrib: + on_fail = self.element.attrib[on_fail_attr_name] + return [PydanticValidator(self.model, on_fail=on_fail)] + + def set_children(self, element: ET._Element): + for child in element: + child_data_type = registry[child.tag] + self._children[child.attrib["name"]] = child_data_type.from_xml(child) + + @classmethod + def from_xml(cls, element: ET._Element, strict: bool = False) -> "DataType": + from guardrails.schema import FormatAttr + from guardrails.utils.pydantic_utils import pydantic_models + + model_name = element.attrib["model"] + model = pydantic_models.get(model_name, None) + + if model is None: + raise ValueError(f"Invalid Pydantic model: {model_name}") + + data_type = cls(model, {}, FormatAttr(), element) + data_type.set_children(element) + return data_type + + def to_object_element(self) -> ET._Element: + """Convert the Pydantic data type to an element.""" + from guardrails.utils.pydantic_utils import ( + PYDANTIC_SCHEMA_TYPE_MAP, + get_field_descriptions, + pydantic_validators, + ) + + # Get the following attributes + # TODO: add on-fail + try: + name = self.element.attrib["name"] + except KeyError: + name = None + try: + description = self.element.attrib["description"] + except KeyError: + description = None + + # Get the Pydantic model schema. + schema = self.model.schema() + field_descriptions = get_field_descriptions(self.model) + + # Make the XML as follows using lxml + # # noqa: E501 + # # noqa: E501 + # + + # Add the object element, opening tag + xml = "" + root_validators = "; ".join( + list(pydantic_validators[self.model]["__root__"].keys()) + ) + xml += "`""" + + # @register_type("key") # class Key(DataType): # """ diff --git a/guardrails/schema.py b/guardrails/schema.py index b7f53b442..41fd5bed0 100644 --- a/guardrails/schema.py +++ b/guardrails/schema.py @@ -41,6 +41,11 @@ class FormatAttr: # The XML element that this format attribute is associated with. element: Optional[ET._Element] = None + @property + def empty(self) -> bool: + """Return True if the format attribute is empty, False otherwise.""" + return self.format is None + @classmethod def from_element(cls, element: ET._Element) -> "FormatAttr": """Create a FormatAttr object from an XML element. @@ -384,6 +389,23 @@ def _inner(dt: DataType, el: ET._Element): dt_child = schema[el_child.attrib["name"]] _inner(dt_child, el_child) + @staticmethod + def pydantic_to_object(schema: Schema) -> None: + """Recursively replace all pydantic elements with object elements.""" + from guardrails.datatypes import Pydantic + + def _inner(dt: DataType, el: ET._Element): + if isinstance(dt, Pydantic): + new_el = dt.to_object_element() + el.getparent().replace(el, new_el) + + for _, dt_child, el_child in dt.iter(el): + _inner(dt_child, el_child) + + for el_child in schema.root: + dt_child = schema[el_child.attrib["name"]] + _inner(dt_child, el_child) + @classmethod def default(cls, schema: Schema) -> str: """Default transpiler. @@ -407,6 +429,13 @@ def default(cls, schema: Schema) -> str: cls.remove_on_fail_attributes(schema.root) # Remove validators with arguments. cls.validator_to_prompt(schema) - - # Return the XML as a string. 
- return ET.tostring(schema.root, encoding="unicode", method="xml") + # Replace pydantic elements with object elements. + cls.pydantic_to_object(schema) + + # Return the XML as a string that is + return ET.tostring( + schema.root, + encoding="unicode", + method="xml", + # pretty_print=True, + ) diff --git a/guardrails/utils/logs_utils.py b/guardrails/utils/logs_utils.py index 1d06c411b..3a64e9ed7 100644 --- a/guardrails/utils/logs_utils.py +++ b/guardrails/utils/logs_utils.py @@ -89,6 +89,8 @@ def merge_reask_output(prev_logs: GuardLogs, current_logs: GuardLogs) -> Dict: Returns: The merged output. """ + from guardrails.validators import PydanticReAsk + previous_response = prev_logs.validated_output pruned_reask_json = prune_json_for_reasking(previous_response) reask_response = current_logs.validated_output @@ -99,7 +101,16 @@ def merge_reask_output(prev_logs: GuardLogs, current_logs: GuardLogs) -> Dict: merged_json = deepcopy(previous_response) def update_reasked_elements(pruned_reask_json, reask_response_dict): - if isinstance(pruned_reask_json, dict): + if isinstance(pruned_reask_json, PydanticReAsk): + corrected_value = reask_response_dict + # Get the path from any of the ReAsk objects in the PydanticReAsk object + # all of them have the same path. + path = [v.path for v in pruned_reask_json.values() if isinstance(v, ReAsk)][ + 0 + ] + update_response_by_path(merged_json, path, corrected_value) + + elif isinstance(pruned_reask_json, dict): for key, value in pruned_reask_json.items(): if isinstance(value, ReAsk): corrected_value = reask_response_dict[key] diff --git a/guardrails/utils/pydantic_utils.py b/guardrails/utils/pydantic_utils.py new file mode 100644 index 000000000..010025df9 --- /dev/null +++ b/guardrails/utils/pydantic_utils.py @@ -0,0 +1,107 @@ +"""Utilities for working with Pydantic models. + +Guardrails lets users specify + + +""" +import logging +from typing import TYPE_CHECKING, Dict + +from griffe.dataclasses import Docstring +from griffe.docstrings.parsers import Parser, parse + +griffe_docstrings_google_logger = logging.getLogger("griffe.docstrings.google") +griffe_agents_nodes_logger = logging.getLogger("griffe.agents.nodes") + +if TYPE_CHECKING: + from pydantic import BaseModel + + +def get_field_descriptions(model: "BaseModel") -> Dict[str, str]: + """Get the descriptions of the fields in a Pydantic model using the + docstring.""" + griffe_docstrings_google_logger.disabled = True + griffe_agents_nodes_logger.disabled = True + try: + docstring = Docstring(model.__doc__, lineno=1) + except AttributeError: + return {} + parsed = parse(docstring, Parser.google) + griffe_docstrings_google_logger.disabled = False + griffe_agents_nodes_logger.disabled = False + + # TODO: change parsed[1] to an isinstance check for the args section + return { + field.name: field.description.replace("\n", " ") + for field in parsed[1].as_dict()["value"] + } + + +PYDANTIC_SCHEMA_TYPE_MAP = { + "string": "string", + "number": "float", + "integer": "integer", + "boolean": "bool", + "object": "object", + "array": "list", +} + +pydantic_validators = {} +pydantic_models = {} + + +# Create a class decorator to register all the validators in a BaseModel +def register_pydantic(cls: type): + """ + Register a Pydantic BaseModel. This is a class decorator that can + be used in the following way: + + ``` + @register_pydantic + class MyModel(BaseModel): + ... + ``` + + This decorator does the following: + 1. Add the model to the pydantic_models dictionary. + 2. Register all pre and post validators. + 3. 
Register all pre and post root validators. + """ + # Register the model + pydantic_models[cls.__name__] = cls + + # Create a dictionary to store all the validators + pydantic_validators[cls] = {} + # Add all pre and post validators, for each field in the model + for field in cls.__fields__.values(): + pydantic_validators[cls][field.name] = {} + if field.pre_validators: + for validator in field.pre_validators: + pydantic_validators[cls][field.name][ + validator.func_name.replace("_", "-") + ] = validator + if field.post_validators: + for validator in field.post_validators: + pydantic_validators[cls][field.name][ + validator.func_name.replace("_", "-") + ] = validator + + pydantic_validators[cls]["__root__"] = {} + # Add all pre and post root validators + if cls.__pre_root_validators__: + for _, validator in cls.__pre_root_validators__: + pydantic_validators[cls]["__root__"][ + validator.__name__.replace("_", "-") + ] = validator + + if cls.__post_root_validators__: + for _, validator in cls.__post_root_validators__: + pydantic_validators[cls]["__root__"][ + validator.__name__.replace("_", "-") + ] = validator + return cls diff --git a/guardrails/utils/reask_utils.py b/guardrails/utils/reask_utils.py index c82fae5a6..e86555ae2 100644 --- a/guardrails/utils/reask_utils.py +++ b/guardrails/utils/reask_utils.py @@ -27,12 +27,18 @@ def gather_reasks(validated_output: Dict) -> List[ReAsk]: Returns: A list of ReAsk objects found in the output. """ + from guardrails.validators import PydanticReAsk + reasks = [] def _gather_reasks_in_dict(output: Dict, path: List[str] = []) -> None: + is_pydantic = isinstance(output, PydanticReAsk) for field, value in output.items(): if isinstance(value, ReAsk): - value.path = path + [field] + if is_pydantic: + value.path = path + else: + value.path = path + [field] reasks.append(value) if isinstance(value, dict): @@ -62,6 +68,12 @@ def get_reasks_by_element( parsed_rail: ET._Element, ) -> Dict[ET._Element, List[tuple]]: """Cluster reasks by the XML element they are associated with.""" + # This should be guaranteed to work, since the path corresponding + # to a ReAsk should always be valid in the element tree. + + # This is because ReAsk objects are only created for elements + # with corresponding validators i.e. the element must have been + # in the tree in the first place for the ReAsk to be created. reasks_by_element = defaultdict(list) @@ -139,8 +151,11 @@ def prune_json_for_reasking(json_object: Any) -> Union[None, Dict, List]: Returns: The pruned validated JSON.
""" + from guardrails.validators import PydanticReAsk - if isinstance(json_object, list): + if isinstance(json_object, ReAsk) or isinstance(json_object, PydanticReAsk): + return json_object + elif isinstance(json_object, list): pruned_list = [] for item in json_object: pruned_output = prune_json_for_reasking(item) @@ -152,7 +167,7 @@ def prune_json_for_reasking(json_object: Any) -> Union[None, Dict, List]: elif isinstance(json_object, dict): pruned_json = {} for key, value in json_object.items(): - if isinstance(value, ReAsk): + if isinstance(value, ReAsk) or isinstance(value, PydanticReAsk): pruned_json[key] = value elif isinstance(value, dict): pruned_output = prune_json_for_reasking(value) @@ -171,14 +186,12 @@ def prune_json_for_reasking(json_object: Any) -> Union[None, Dict, List]: return pruned_json return None - else: - if isinstance(json_object, ReAsk): - return json_object - return None def get_reask_prompt( - parsed_rail, reasks: List[ReAsk], reask_json: Dict + parsed_rail: ET._Element, + reasks: List[ReAsk], + reask_json: Dict, ) -> Tuple[str, ET._Element]: """Construct a prompt for reasking. diff --git a/guardrails/validators.py b/guardrails/validators.py index c61b7318d..b9a7880bf 100644 --- a/guardrails/validators.py +++ b/guardrails/validators.py @@ -7,10 +7,12 @@ import logging import os from collections import defaultdict +from copy import deepcopy from dataclasses import dataclass -from typing import Any, Callable, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Type, Union import openai +from pydantic import BaseModel, ValidationError from guardrails.datatypes import registry as types_registry from guardrails.utils.reask_utils import ReAsk @@ -98,6 +100,8 @@ def filter_in_list(schema: List) -> List: for item in schema: if isinstance(item, Filter): pass + elif isinstance(item, PydanticReAsk): + filtered_list.append(item) elif isinstance(item, list): filtered_item = filter_in_list(item) if len(filtered_item): @@ -127,6 +131,8 @@ def filter_in_dict(schema: Dict) -> Dict: for key, value in schema.items(): if isinstance(value, Filter): pass + elif isinstance(value, PydanticReAsk): + filtered_dict[key] = value elif isinstance(value, list): filtered_item = filter_in_list(value) if len(filtered_item): @@ -277,6 +283,90 @@ def to_prompt(self, with_keywords: bool = True) -> str: # return value is not None +class PydanticReAsk(dict): + pass + + +@register_validator(name="pydantic", data_type="pydantic") +class Pydantic(Validator): + """Validate an object using Pydantic.""" + + def __init__( + self, + model: Type[BaseModel], + on_fail: Optional[Callable] = None, + ): + super().__init__(on_fail=on_fail) + + self.model = model + + def validate_with_correction( + self, key: str, value: Dict, schema: Union[Dict, List] + ) -> Dict: + """Validate an object using Pydantic. + + For example, consider the following data for a `Person` model + with fields `name`, `age`, and `zipcode`: + { + "user" : { + "name": "John", + "age": 30, + "zipcode": "12345", + } + } + then `key` is "user", `value` is the value of the "user" key, and + `schema` is the entire schema. + + If this validator succeeds, then the `schema` is returned and + looks like: + { + "user": Person(name="John", age=30, zipcode="12345") + } + + If it fails, then the `schema` is returned and looks like e.g. 
+ { + "user": { + "name": "John", + "age": 30, + "zipcode": ReAsk( + incorrect_value="12345", + error_message="...", + fix_value=None, + path=None, + ) + } + } + """ + try: + # Run the Pydantic model on the value. + schema[key] = self.model(**value) + except ValidationError as e: + # Create a copy of the value so that we can modify it + # to insert e.g. ReAsk objects. + new_value = deepcopy(value) + for error in e.errors(): + assert ( + len(error["loc"]) == 1 + ), "Pydantic validation errors should only have one location." + + field_name = error["loc"][0] + event_detail = EventDetail( + key=field_name, + value=new_value[field_name], + schema=new_value, + error_message=error["msg"], + fix_value=None, + ) + # Call the on_fail method and reassign the value. + new_value = self.on_fail(event_detail) + + # Insert the new `value` dictionary into the schema. + # This now contains e.g. ReAsk objects. + schema[key] = PydanticReAsk(new_value) + + return schema + + @register_validator(name="valid-range", data_type=["integer", "float", "percentage"]) class ValidRange(Validator): """Validate that a value is within a range. diff --git a/setup.py b/setup.py index 925b9c337..d889a7fe4 100644 --- a/setup.py +++ b/setup.py @@ -33,6 +33,7 @@ "rich", "eliot", "eliot-tree", + "pydantic", ] # Read in docs/requirements.txt diff --git a/tests/integration_tests/mock_llm_outputs.py b/tests/integration_tests/mock_llm_outputs.py index 2205ba7fc..0b2db735a 100644 --- a/tests/integration_tests/mock_llm_outputs.py +++ b/tests/integration_tests/mock_llm_outputs.py @@ -1,4 +1,4 @@ -from .test_cases import entity_extraction +from .test_cases import entity_extraction, pydantic def openai_completion_create(prompt, *args, **kwargs): @@ -7,6 +7,9 @@ def openai_completion_create(prompt, *args, **kwargs): mock_llm_responses = { entity_extraction.COMPILED_PROMPT: entity_extraction.LLM_OUTPUT, entity_extraction.COMPILED_PROMPT_REASK: entity_extraction.LLM_OUTPUT_REASK, + pydantic.COMPILED_PROMPT: pydantic.LLM_OUTPUT, + pydantic.COMPILED_PROMPT_REASK_1: pydantic.LLM_OUTPUT_REASK_1, + pydantic.COMPILED_PROMPT_REASK_2: pydantic.LLM_OUTPUT_REASK_2, } try: diff --git a/tests/integration_tests/test_cases/pydantic/__init__.py b/tests/integration_tests/test_cases/pydantic/__init__.py new file mode 100644 index 000000000..e1c88936f --- /dev/null +++ b/tests/integration_tests/test_cases/pydantic/__init__.py @@ -0,0 +1,34 @@ +# flake8: noqa: E501 +import os + +from .validated_response_reask_1 import VALIDATED_OUTPUT as VALIDATED_OUTPUT_REASK_1 +from .validated_response_reask_2 import VALIDATED_OUTPUT as VALIDATED_OUTPUT_REASK_2 +from .validated_response_reask_3 import VALIDATED_OUTPUT as VALIDATED_OUTPUT_REASK_3 + +DATA_DIR = os.path.join(os.path.abspath(os.path.dirname(__file__))) +reader = ( + lambda filename: open(os.path.join(DATA_DIR, filename)).read().replace("\r", "") +) + +COMPILED_PROMPT = reader("compiled_prompt.txt") +COMPILED_PROMPT_REASK_1 = reader("compiled_prompt_reask_1.txt") +COMPILED_PROMPT_REASK_2 = reader("compiled_prompt_reask_2.txt") + +LLM_OUTPUT = reader("llm_output.txt") +LLM_OUTPUT_REASK_1 = reader("llm_output_reask_1.txt") +LLM_OUTPUT_REASK_2 = reader("llm_output_reask_2.txt") + +RAIL_SPEC_WITH_REASK = reader("reask.rail") + +__all__ = [ + "COMPILED_PROMPT", + "COMPILED_PROMPT_REASK_1", + "COMPILED_PROMPT_REASK_2", + "LLM_OUTPUT", + "LLM_OUTPUT_REASK_1", + "LLM_OUTPUT_REASK_2", + "RAIL_SPEC_WITH_REASK", + "VALIDATED_OUTPUT_REASK_1", + "VALIDATED_OUTPUT_REASK_2", + "VALIDATED_OUTPUT_REASK_3", +] diff --git 
a/tests/integration_tests/test_cases/pydantic/compiled_prompt.txt b/tests/integration_tests/test_cases/pydantic/compiled_prompt.txt new file mode 100644 index 000000000..ccd8ff700 --- /dev/null +++ b/tests/integration_tests/test_cases/pydantic/compiled_prompt.txt @@ -0,0 +1,28 @@ + +Generate data for possible users in accordance with the specification below. + + +Given below is XML that describes the information to extract from this document and the tags to extract it into. + + + + + + + + +Given below is XML that describes the information to extract from this document and the tags to extract it into. + + + + + + +ONLY return a valid JSON object (no other text is necessary), where the key of the field in JSON is the `name` attribute of the corresponding XML, and the value is of the type specified by the corresponding XML's tag. The JSON MUST conform to the XML format, including any types and format requests e.g. requests for lists, objects and specific types. Be correct and concise. + +Here are examples of simple (XML, JSON) pairs that show the expected behavior: +- `` => `{'foo': 'example one'}` +- `` => `{"bar": ['STRING ONE', 'STRING TWO', etc.]}` +- `` => `{'baz': {'foo': 'Some String', 'index': 1}}` + +JSON Object: \ No newline at end of file diff --git a/tests/integration_tests/test_cases/pydantic/compiled_prompt_reask_1.txt b/tests/integration_tests/test_cases/pydantic/compiled_prompt_reask_1.txt new file mode 100644 index 000000000..81943e37e --- /dev/null +++ b/tests/integration_tests/test_cases/pydantic/compiled_prompt_reask_1.txt @@ -0,0 +1,33 @@ + +I was given the following JSON response, which had problems due to incorrect values. + +{ + "people": [ + { + "name": "John Doe", + "age": 28, + "zip_code": { + "incorrect_value": "90210", + "error_message": "Zip code must not be Beverly Hills." + } + } + ] +} + +Help me correct the incorrect values based on the given error messages. + +Given below is XML that describes the information to extract from this document and the tags to extract it into. + + + + + + +ONLY return a valid JSON object (no other text is necessary), where the key of the field in JSON is the `name` attribute of the corresponding XML, and the value is of the type specified by the corresponding XML's tag. The JSON MUST conform to the XML format, including any types and format requests e.g. requests for lists, objects and specific types. Be correct and concise. If you are unsure anywhere, enter `None`. + +Here are examples of simple (XML, JSON) pairs that show the expected behavior: +- `` => `{{'foo': 'example one'}}` +- `` => `{{"bar": ['STRING ONE', 'STRING TWO', etc.]}}` +- `` => `{{'baz': {{'foo': 'Some String', 'index': 1}}}}` + +JSON Object: \ No newline at end of file diff --git a/tests/integration_tests/test_cases/pydantic/compiled_prompt_reask_2.txt b/tests/integration_tests/test_cases/pydantic/compiled_prompt_reask_2.txt new file mode 100644 index 000000000..fe14da498 --- /dev/null +++ b/tests/integration_tests/test_cases/pydantic/compiled_prompt_reask_2.txt @@ -0,0 +1,33 @@ + +I was given the following JSON response, which had problems due to incorrect values. + +{ + "people": [ + { + "name": "John Doe", + "age": 28, + "zip_code": { + "incorrect_value": "None", + "error_message": "Zip code must be numeric." + } + } + ] +} + +Help me correct the incorrect values based on the given error messages. + +Given below is XML that describes the information to extract from this document and the tags to extract it into. 
+
+
+
+
+
+
+ONLY return a valid JSON object (no other text is necessary), where the key of the field in JSON is the `name` attribute of the corresponding XML, and the value is of the type specified by the corresponding XML's tag. The JSON MUST conform to the XML format, including any types and format requests e.g. requests for lists, objects and specific types. Be correct and concise. If you are unsure anywhere, enter `None`.
+
+Here are examples of simple (XML, JSON) pairs that show the expected behavior:
+- `` => `{{'foo': 'example one'}}`
+- `` => `{{"bar": ['STRING ONE', 'STRING TWO', etc.]}}`
+- `` => `{{'baz': {{'foo': 'Some String', 'index': 1}}}}`
+
+JSON Object:
\ No newline at end of file
diff --git a/tests/integration_tests/test_cases/pydantic/llm_output.txt b/tests/integration_tests/test_cases/pydantic/llm_output.txt
new file mode 100644
index 000000000..0206872a3
--- /dev/null
+++ b/tests/integration_tests/test_cases/pydantic/llm_output.txt
@@ -0,0 +1,20 @@
+
+{
+  "people": [
+    {
+      "name": "John Doe",
+      "age": 28,
+      "zip_code": "90210"
+    },
+    {
+      "name": "Jane Doe",
+      "age": 32,
+      "zip_code": "94103"
+    },
+    {
+      "name": "James Smith",
+      "age": 40,
+      "zip_code": "92101"
+    }
+  ]
+}
\ No newline at end of file
diff --git a/tests/integration_tests/test_cases/pydantic/llm_output_reask_1.txt b/tests/integration_tests/test_cases/pydantic/llm_output_reask_1.txt
new file mode 100644
index 000000000..1152da784
--- /dev/null
+++ b/tests/integration_tests/test_cases/pydantic/llm_output_reask_1.txt
@@ -0,0 +1,2 @@
+
+{"people": [{"name": "John Doe", "age": 28, "zip_code": "None"}]}
\ No newline at end of file
diff --git a/tests/integration_tests/test_cases/pydantic/llm_output_reask_2.txt b/tests/integration_tests/test_cases/pydantic/llm_output_reask_2.txt
new file mode 100644
index 000000000..5e2eb48c2
--- /dev/null
+++ b/tests/integration_tests/test_cases/pydantic/llm_output_reask_2.txt
@@ -0,0 +1,10 @@
+
+{
+  "people": [
+    {
+      "name": "John Doe",
+      "age": 28,
+      "zip_code": "None"
+    }
+  ]
+}
\ No newline at end of file
diff --git a/tests/integration_tests/test_cases/pydantic/reask.rail b/tests/integration_tests/test_cases/pydantic/reask.rail
new file mode 100644
index 000000000..56fe88137
--- /dev/null
+++ b/tests/integration_tests/test_cases/pydantic/reask.rail
@@ -0,0 +1,60 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+Generate data for possible users in accordance with the specification below.
+
+@xml_prefix_prompt
+
+{output_schema}
+
+@complete_json_suffix_v2
+
+
diff --git a/tests/integration_tests/test_cases/pydantic/validated_response_reask_1.py b/tests/integration_tests/test_cases/pydantic/validated_response_reask_1.py
new file mode 100644
index 000000000..b3359c674
--- /dev/null
+++ b/tests/integration_tests/test_cases/pydantic/validated_response_reask_1.py
@@ -0,0 +1,59 @@
+# flake8: noqa: E501
+from pydantic import BaseModel, validator
+
+from guardrails.utils.pydantic_utils import register_pydantic
+from guardrails.utils.reask_utils import ReAsk
+
+
+@register_pydantic
+class Person(BaseModel):
+    """
+    Information about a person.
+
+    Args:
+        name (str): The name of the person.
+        age (int): The age of the person.
+        zip_code (str): The zip code of the person.
+ """ + + name: str + age: int + zip_code: str + + @validator("zip_code") + def zip_code_must_be_numeric(cls, v): + if not v.isnumeric(): + raise ValueError("Zip code must be numeric.") + return v + + @validator("age") + def age_must_be_between_0_and_150(cls, v): + if not 0 <= v <= 150: + raise ValueError("Age must be between 0 and 150.") + return v + + @validator("zip_code") + def zip_code_in_california(cls, v): + if not v.startswith("9"): + raise ValueError("Zip code must be in California, and start with 9.") + if v == "90210": + raise ValueError("Zip code must not be Beverly Hills.") + return v + + +VALIDATED_OUTPUT = { + "people": [ + { + "name": "John Doe", + "age": 28, + "zip_code": ReAsk( + incorrect_value="90210", + error_message="Zip code must not be Beverly Hills.", + fix_value=None, + path=["people", 0], + ), + }, + Person(name="Jane Doe", age=32, zip_code="94103"), + Person(name="James Smith", age=40, zip_code="92101"), + ] +} diff --git a/tests/integration_tests/test_cases/pydantic/validated_response_reask_2.py b/tests/integration_tests/test_cases/pydantic/validated_response_reask_2.py new file mode 100644 index 000000000..5e2fc8509 --- /dev/null +++ b/tests/integration_tests/test_cases/pydantic/validated_response_reask_2.py @@ -0,0 +1,59 @@ +# flake8: noqa: E501 +from pydantic import BaseModel, validator + +from guardrails.utils.pydantic_utils import register_pydantic +from guardrails.utils.reask_utils import ReAsk + + +@register_pydantic +class Person(BaseModel): + """ + Information about a person. + + Args: + name (str): The name of the person. + age (int): The age of the person. + zip_code (str): The zip code of the person. + """ + + name: str + age: int + zip_code: str + + @validator("zip_code") + def zip_code_must_be_numeric(cls, v): + if not v.isnumeric(): + raise ValueError("Zip code must be numeric.") + return v + + @validator("age") + def age_must_be_between_0_and_150(cls, v): + if not 0 <= v <= 150: + raise ValueError("Age must be between 0 and 150.") + return v + + @validator("zip_code") + def zip_code_in_california(cls, v): + if not v.startswith("9"): + raise ValueError("Zip code must be in California, and start with 9.") + if v == "90210": + raise ValueError("Zip code must not be Beverly Hills.") + return v + + +VALIDATED_OUTPUT = { + "people": [ + { + "name": "John Doe", + "age": 28, + "zip_code": ReAsk( + incorrect_value="None", + error_message="Zip code must be numeric.", + fix_value=None, + path=["people", 0], + ), + }, + Person(name="Jane Doe", age=32, zip_code="94103"), + Person(name="James Smith", age=40, zip_code="92101"), + ] +} diff --git a/tests/integration_tests/test_cases/pydantic/validated_response_reask_3.py b/tests/integration_tests/test_cases/pydantic/validated_response_reask_3.py new file mode 100644 index 000000000..8f833e285 --- /dev/null +++ b/tests/integration_tests/test_cases/pydantic/validated_response_reask_3.py @@ -0,0 +1,50 @@ +# flake8: noqa: E501 + +from pydantic import BaseModel, validator + +from guardrails.utils.pydantic_utils import register_pydantic + + +@register_pydantic +class Person(BaseModel): + """ + Information about a person. + + Args: + name (str): The name of the person. + age (int): The age of the person. + zip_code (str): The zip code of the person. 
+ """ + + name: str + age: int + zip_code: str + + @validator("zip_code") + def zip_code_must_be_numeric(cls, v): + if not v.isnumeric(): + raise ValueError("Zip code must be numeric.") + return v + + @validator("age") + def age_must_be_between_0_and_150(cls, v): + if not 0 <= v <= 150: + raise ValueError("Age must be between 0 and 150.") + return v + + @validator("zip_code") + def zip_code_in_california(cls, v): + if not v.startswith("9"): + raise ValueError("Zip code must be in California, and start with 9.") + if v == "90210": + raise ValueError("Zip code must not be Beverly Hills.") + return v + + +VALIDATED_OUTPUT = { + "people": [ + {"name": "John Doe", "age": 28, "zip_code": None}, + Person(name="Jane Doe", age=32, zip_code="94103"), + Person(name="James Smith", age=40, zip_code="92101"), + ] +} diff --git a/tests/integration_tests/test_pydantic.py b/tests/integration_tests/test_pydantic.py new file mode 100644 index 000000000..bde14aa1b --- /dev/null +++ b/tests/integration_tests/test_pydantic.py @@ -0,0 +1,44 @@ +import openai + +import guardrails as gd + +from .mock_llm_outputs import openai_completion_create, pydantic + + +def test_pydantic_with_reask(mocker): + """Test that the entity extraction works with re-asking.""" + mocker.patch( + "guardrails.llm_providers.openai_wrapper", new=openai_completion_create + ) + + guard = gd.Guard.from_rail_string(pydantic.RAIL_SPEC_WITH_REASK) + _, final_output = guard( + openai.Completion.create, + engine="text-davinci-003", + max_tokens=512, + temperature=0.5, + num_reasks=2, + ) + + # Assertions are made on the guard state object. + assert final_output == pydantic.VALIDATED_OUTPUT_REASK_3 + + guard_history = guard.guard_state.most_recent_call.history + + # Check that the guard state object has the correct number of re-asks. + assert len(guard_history) == 3 + + # For orginal prompt and output + assert guard_history[0].prompt == pydantic.COMPILED_PROMPT + assert guard_history[0].output == pydantic.LLM_OUTPUT + assert guard_history[0].validated_output == pydantic.VALIDATED_OUTPUT_REASK_1 + + # For re-asked prompt and output + assert guard_history[1].prompt == pydantic.COMPILED_PROMPT_REASK_1 + assert guard_history[1].output == pydantic.LLM_OUTPUT_REASK_1 + assert guard_history[1].validated_output == pydantic.VALIDATED_OUTPUT_REASK_2 + + # For re-asked prompt #2 and output #2 + assert guard_history[2].prompt == pydantic.COMPILED_PROMPT_REASK_2 + assert guard_history[2].output == pydantic.LLM_OUTPUT_REASK_2 + assert guard_history[2].validated_output == pydantic.VALIDATED_OUTPUT_REASK_3