Skip to content

Commit

Permalink
Merge pull request #1068 from guardrails-ai/async-validator-service
Browse files Browse the repository at this point in the history
AsyncValidatorService Update
  • Loading branch information
dtam authored Sep 17, 2024
2 parents ab12701 + cfff6e1 commit 2ec1ee5
Show file tree
Hide file tree
Showing 32 changed files with 3,581 additions and 1,683 deletions.
137 changes: 130 additions & 7 deletions docs/concepts/parallelization.md → docs/concepts/concurrency.md
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
# Parallelization
# Concurrency
## And the Orchestration of Guard Executions

This document is a description of the current implementation of the Guardrails' validation loop. It attempts to explain the current patterns used with some notes on why those patterns were accepted at the time of implementation and potential future optimizations. It is _not_ meant to be prescriptive as there can, and will, be improvements made in future versions.

In general you will find that our approach to performance is two fold:
1. Complete computationally cheaper, static checks first and exit early to avoid spending time and resources on more expensive checks that are unlikely to pass when the former fail.
2. Parallelize processing where possible.
2. Run processes concurrently where possible.

## Background: The Validation Loop
When a Guard is executed, that is called via `guard()`, `guard.parse()`, `guard.validate()`, etc., it goes through an internal process that has the following steps:
Expand Down Expand Up @@ -60,7 +60,7 @@ Besides handling asynchronous calls to the LLM, using an `AsyncGuard` also ensur
* An asyncio event loop is available.
* The asyncio event loop is not taken/already running.
## Validation Orchestration and Parallelization
## Validation Orchestration and Concurrency
### Structured Data Validation
We perform validation with a "deep-first" approach. This has no meaning for unstructured text output since there is only one value, but for structured output it means that the objects are validated from the inside out.
Expand All @@ -79,7 +79,7 @@ Take the below structure as an example:
}
```
As of versions v0.4.x and v0.5.x of Guardrails, the above object would validated as follows:
As of versions v0.4.x and v0.5.x of Guardrails, the above object would be validated as follows:
1. foo.baz
2. foo.bez
Expand All @@ -88,11 +88,134 @@ As of versions v0.4.x and v0.5.x of Guardrails, the above object would validated
5. bar.buz
6. bar
> NOTE: The approach currently used, and outlined above, was predicated on the assumption that if child properties fail validation, it is unlikely that the parent property would pass. With the current atomic state of validation, it can be argued that this assumption is false. That is, the types of validations applied to parent properties typically take the form of checking the appropriate format of the container like a length check on a list. These types of checks are generally independent of any requirements the child properties have. This opens up the possibility of running all six paths listed above in parallel at once instead of performing them in steps based on key path.
> NOTE: The approach currently used, and outlined above, was predicated on the assumption that if child properties fail validation, it is unlikely that the parent property would pass. With the current atomic state of validation, it can be argued that this assumption is false. That is, the types of validations applied to parent properties typically take the form of checking the appropriate format of the container like a length check on a list. These types of checks are generally independent of any requirements the child properties have. This opens up the possibility of running all six paths listed above concurrently instead of performing them in steps based on key path.
When synchronous validation occurs as defined in [Benefits of AsyncGuard](#benefits-of-async-guard), the validators for each property would be run in the order they are defined on the schema. That also means that any on fail actions are applied in that same order.
When asynchronous validation occurs, there are multiple levels of parallelization possible. First, running validation on the child properties (e.g. `foo.baz` and `foo.bez`) will happen in parallel via the asyncio event loop. Second, within the validation for each property, if the validators have `run_in_separate_process` set to `True`, they are run in parallel via multiprocessing. This multiprocessing is capped to the process count specified by the `GUARDRAILS_PROCESS_COUNT` environment variable which defaults to 10. Note that some environments, like AWS Lambda, may not support multiprocessing in which case you would need to set this environment variable to 1.
When asynchronous validation occurs, there are multiple levels of concurrency possible. First, running validation on the child properties (e.g. `foo.baz` and `foo.bez`) will happen concurrently via the asyncio event loop. Second, the validators on any given property are also run concurrently via the event loop. For validators that only define a synchronous `validate` method, calls to this method are run in the event loops default executor. Note that some environments, like AWS Lambda, may not support multiprocessing in which case you would need to either set the executor to a thread processor instead or limit validation to running synchronously by setting `GUARDRAILS_PROCESS_COUNT=1` or `GUARDRAILS_RUN_SYNC=true`.
### Unstructured Data Validation
When validating unstructured data, i.e. text, the LLM output is treated the same as if it were a property on an object. This means that the validators applied to is have the ability to run in parallel utilizing multiprocessing when `run_in_separate_process` is set to `True` on the validators.
When validating unstructured data, i.e. text, the LLM output is treated the same as if it were a property on an object. This means that the validators applied to is have the ability to run concurrently utilizing the event loop.
### Handling Failures During Async Concurrency
The Guardrails validation loop is opinionated about how it handles failures when running validators concurrently so that it spends the least amount of time processing an output that would result in a failure. It's behavior comes down to when and what it returns based on the [corrective action](/how_to_guides/custom_validators#on-fail) specified on a validator. Corrective actions are processed concurrently since they are specific to a given validator on a given property. This means that interruptive corrective actions, namely `EXCEPTION`, will be the first corrective action enforced because the exception is raised as soon as the failure is evaluated. The remaining actions are handled in the following order after all futures are collected from the validation of a specific property:
1. `FILTER` and `REFRAIN`
2. `REASK`
3. `FIX`
\*_NOTE:_ `NOOP` Does not require any special handling because it does not alter the value.
\*_NOTE:_ `FIX_REASK` Will fall into either the `REASK` or `FIX` bucket based on if the fixed value passes the second round of validation.
This means that if any validator with `on_fail=OnFailAction.EXCEPTION` returns a `FailResult`, then Guardrails will raise a `ValidationError` interrupting the process.
If any validator on a specific property which has `on_fail=OnFailAction.FILTER` or `on_fail=OnFailAction.REFRAIN` returns a `FailResult`, whichever of these is the first to finish will the returned early as the value for that property,
If any validator on a specific property which has `on_fail=OnFailAction.REASK` returns a `FailResult`, all reasks for that property will be merged and a `FieldReAsk` will be returned early as the value for that property.
If any validator on a specific property which has `on_fail=OnFailAction.FIX` returns a `FailResult`, all fix values for that property will be merged and the result of that merge will be returned as the value for that property.
Custom on_fail handlers will fall into one of the above actions based on what it returns; i.e. if it returns an updated value it's considered a `FIX`, if it returns an instance of `Filter` then `FILTER`, etc..
Let's look at an example. We'll keep the validation logic simple and write out some assertions to demonstrate the evaluation order discussed above.
```py
import asyncio
from random import randint
from typing import Optional
from guardrails import AsyncGuard, ValidationOutcome
from guardrails.errors import ValidationError
from guardrails.validators import (
Validator,
register_validator,
ValidationResult,
PassResult,
FailResult
)
@register_validator(name='custom/contains', data_type='string')
class Contains(Validator):
def __init__(self, match_value: str, **kwargs):
super().__init__(
match_value=match_value,
**kwargs
)
self.match_value = match_value
def validate(self, value, metadata = {}) -> ValidationResult:
if self.match_value in value:
return PassResult()
fix_value = None
if self.on_fail_descriptor == 'fix':
# Insert the match_value into the value at a random index
insertion = randint(0, len(value))
fix_value = f"{value[:insertion]}{self.match_value}{value[insertion:]}"
return FailResult(
error_message=f'Value must contain {self.match_value}',
fix_value=fix_value
)
exception_validator = Contains("a", on_fail='exception')
filter_validator = Contains("b", on_fail='filter')
refrain_validator = Contains("c", on_fail='refrain')
reask_validator_1 = Contains("d", on_fail='reask')
reask_validator_2 = Contains("e", on_fail='reask')
fix_validator_1 = Contains("f", on_fail='fix')
fix_validator_2 = Contains("g", on_fail='fix')
guard = AsyncGuard().use_many(
exception_validator,
filter_validator,
refrain_validator,
reask_validator_1,
reask_validator_2,
fix_validator_1,
fix_validator_2
)
### Trigger the exception validator ###
error = None
result: Optional[ValidationOutcome] = None
try:
result = asyncio.run(guard.validate("z", metadata={}))
except ValidationError as e:
error = e
assert result is None
assert error is not None
assert str(error) == "Validation failed for field with errors: Value must contain a"
### Trigger the Filter and Refrain validators ###
result = asyncio.run(guard.validate("a", metadata={}))
assert result.validation_passed is False
# The output was filtered or refrained
assert result.validated_output is None
assert result.reask is None
### Trigger the Reask validator ###
result = asyncio.run(guard.validate("abc", metadata={}))
assert result.validation_passed is False
# If allowed, a ReAsk would have occured
assert result.reask is not None
error_messages = [f.error_message for f in result.reask.fail_results]
assert error_messages == ["Value must contain d", "Value must contain e"]
### Trigger the Fix validator ###
result = asyncio.run(guard.validate("abcde", metadata={}))
assert result.validation_passed is True
# The fix values have been merged
assert "f" in result.validated_output
assert "g" in result.validated_output
print(result.validated_output)
```
2 changes: 1 addition & 1 deletion docs/faq.md
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ You can override the `fix` behavior by passing it as a function to the Guard obj
```python
from guardrails import Guard

def fix_is_cake(value, metadata):
def fix_is_cake(value, fail_result: FailResult):
return "IT IS cake"

guard = Guard().use(is_cake, on_fail=fix_is_cake)
Expand Down
1 change: 1 addition & 0 deletions docs/how_to_guides/custom_validators.md
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ Validators ship with several out of the box `on_fail` policies. The `OnFailActio
| `OnFailAction.NOOP` | Do nothing. The failure will still be recorded in the logs, but no corrective action will be taken. |
| `OnFailAction.EXCEPTION` | Raise an exception when validation fails. |
| `OnFailAction.FIX_REASK` | First, fix the generated output deterministically, and then rerun validation with the deterministically fixed output. If validation fails, then perform reasking. |
| `OnFailAction.CUSTOM` | This action is set internally when the validator is passed a custom function to handle failures. The function is called with the value that failed validation and the FailResult returned from the Validator. i.e. the custom on fail handler must implement the method signature `def on_fail(value: Any, fail_result: FailResult) -> Any` |

In the code below, a `fix_value` will be supplied in the `FailResult`. This value will represent a programmatic fix that can be applied to the output if `on_fail='fix'` is passed during validator initialization.
```py
Expand Down
2 changes: 1 addition & 1 deletion docusaurus/sidebars.js
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ const sidebars = {
"concepts/streaming_fixes",
],
},
"concepts/parallelization",
"concepts/concurrency",
"concepts/logs",
"concepts/telemetry",
"concepts/error_remediation",
Expand Down
2 changes: 2 additions & 0 deletions guardrails/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from guardrails.validator_base import Validator, register_validator
from guardrails.settings import settings
from guardrails.hub.install import install
from guardrails.classes.validation_outcome import ValidationOutcome

__all__ = [
"Guard",
Expand All @@ -25,4 +26,5 @@
"Instructions",
"settings",
"install",
"ValidationOutcome",
]
21 changes: 21 additions & 0 deletions guardrails/classes/generic/default_json_encoder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from datetime import datetime
from dataclasses import asdict, is_dataclass
from pydantic import BaseModel
from json import JSONEncoder


class DefaultJSONEncoder(JSONEncoder):
def default(self, o):
if hasattr(o, "to_dict"):
return o.to_dict()
elif isinstance(o, BaseModel):
return o.model_dump()
elif is_dataclass(o):
return asdict(o)
elif isinstance(o, set):
return list(o)
elif isinstance(o, datetime):
return o.isoformat()
elif hasattr(o, "__dict__"):
return o.__dict__
return super().default(o)
12 changes: 9 additions & 3 deletions guardrails/merge.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# SOURCE: https://github.com/spyder-ide/three-merge/blob/master/three_merge/merge.py
from typing import Optional
from diff_match_patch import diff_match_patch

# Constants
Expand All @@ -10,7 +11,12 @@
ADDITION = 1


def merge(source: str, target: str, base: str) -> str:
def merge(
source: Optional[str], target: Optional[str], base: Optional[str]
) -> Optional[str]:
if source is None or target is None or base is None:
return None

diff1_l = DIFFER.diff_main(base, source)
diff2_l = DIFFER.diff_main(base, target)

Expand Down Expand Up @@ -75,7 +81,7 @@ def merge(source: str, target: str, base: str) -> str:
invariant = ""
target = (target_status, target_text) # type: ignore
if advance:
prev_source_text = source[1]
prev_source_text = source[1] # type: ignore
source = next(diff1, None) # type: ignore
elif len(source_text) < len(target_text):
# Addition performed by source
Expand Down Expand Up @@ -119,7 +125,7 @@ def merge(source: str, target: str, base: str) -> str:
invariant = ""
source = (source_status, source_text) # type: ignore
if advance:
prev_target_text = target[1]
prev_target_text = target[1] # type: ignore
target = next(diff2, None) # type: ignore
else:
# Source and target are equal
Expand Down
3 changes: 2 additions & 1 deletion guardrails/telemetry/open_inference.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Any, Dict, List, Optional

from guardrails.telemetry.common import get_span, serialize, to_dict
from guardrails.telemetry.common import get_span, to_dict
from guardrails.utils.serialization_utils import serialize


def trace_operation(
Expand Down
3 changes: 2 additions & 1 deletion guardrails/telemetry/runner_tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@
from guardrails.classes.output_type import OT
from guardrails.classes.validation_outcome import ValidationOutcome
from guardrails.stores.context import get_guard_name
from guardrails.telemetry.common import get_tracer, serialize
from guardrails.telemetry.common import get_tracer
from guardrails.utils.safe_get import safe_get
from guardrails.utils.serialization_utils import serialize
from guardrails.version import GUARDRAILS_VERSION


Expand Down
66 changes: 65 additions & 1 deletion guardrails/telemetry/validator_tracing.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from functools import wraps
from typing import (
Any,
Awaitable,
Callable,
Dict,
Optional,
Expand All @@ -12,10 +13,11 @@

from guardrails.settings import settings
from guardrails.classes.validation.validation_result import ValidationResult
from guardrails.telemetry.common import get_tracer, serialize
from guardrails.telemetry.common import get_tracer
from guardrails.telemetry.open_inference import trace_operation
from guardrails.utils.casting_utils import to_string
from guardrails.utils.safe_get import safe_get
from guardrails.utils.serialization_utils import serialize
from guardrails.version import GUARDRAILS_VERSION


Expand Down Expand Up @@ -138,3 +140,65 @@ def trace_validator_wrapper(*args, **kwargs):
return trace_validator_wrapper

return trace_validator_decorator


def trace_async_validator(
validator_name: str,
obj_id: int,
on_fail_descriptor: Optional[str] = None,
tracer: Optional[Tracer] = None,
*,
validation_session_id: str,
**init_kwargs,
):
def trace_validator_decorator(
fn: Callable[..., Awaitable[Optional[ValidationResult]]],
):
@wraps(fn)
async def trace_validator_wrapper(*args, **kwargs):
if not settings.disable_tracing:
current_otel_context = context.get_current()
_tracer = get_tracer(tracer) or trace.get_tracer(
"guardrails-ai", GUARDRAILS_VERSION
)
validator_span_name = f"{validator_name}.validate"
with _tracer.start_as_current_span(
name=validator_span_name, # type: ignore
context=current_otel_context, # type: ignore
) as validator_span:
try:
resp = await fn(*args, **kwargs)
add_validator_attributes(
*args,
validator_span=validator_span,
validator_name=validator_name,
obj_id=obj_id,
on_fail_descriptor=on_fail_descriptor,
result=resp,
init_kwargs=init_kwargs,
validation_session_id=validation_session_id,
**kwargs,
)
return resp
except Exception as e:
validator_span.set_status(
status=StatusCode.ERROR, description=str(e)
)
add_validator_attributes(
*args,
validator_span=validator_span,
validator_name=validator_name,
obj_id=obj_id,
on_fail_descriptor=on_fail_descriptor,
result=None,
init_kwargs=init_kwargs,
validation_session_id=validation_session_id,
**kwargs,
)
raise e
else:
return await fn(*args, **kwargs)

return trace_validator_wrapper

return trace_validator_decorator
Loading

0 comments on commit 2ec1ee5

Please sign in to comment.