From 5df8ed2897d4a62c5fb2c2cc611efc490484e803 Mon Sep 17 00:00:00 2001 From: Caleb Courier Date: Tue, 10 Sep 2024 15:17:03 -0500 Subject: [PATCH 01/13] refactor validator service classes, run validators in event loop, merge at end --- .../classes/generic/default_json_encoder.py | 21 + guardrails/telemetry/open_inference.py | 3 +- guardrails/telemetry/runner_tracing.py | 3 +- guardrails/telemetry/validator_tracing.py | 3 +- guardrails/utils/serialization_utils.py | 41 + guardrails/validator_service/__init__.py | 143 ++++ .../async_validator_service.py | 267 ++++++ .../sequential_validator_service.py} | 571 +------------ .../validator_service_base.py | 210 +++++ poetry.lock | 44 + pyproject.toml | 4 + tests/conftest.py | 4 +- .../test_async_validator_service.py | 345 -------- tests/unit_tests/test_validator_service.py | 97 --- .../utils/test_serialization_utils.py | 169 ++++ .../test_async_validator_service.py | 798 ++++++++++++++++++ .../test_validator_service.py | 152 ++++ 17 files changed, 1863 insertions(+), 1012 deletions(-) create mode 100644 guardrails/classes/generic/default_json_encoder.py create mode 100644 guardrails/utils/serialization_utils.py create mode 100644 guardrails/validator_service/__init__.py create mode 100644 guardrails/validator_service/async_validator_service.py rename guardrails/{validator_service.py => validator_service/sequential_validator_service.py} (50%) create mode 100644 guardrails/validator_service/validator_service_base.py delete mode 100644 tests/unit_tests/test_async_validator_service.py delete mode 100644 tests/unit_tests/test_validator_service.py create mode 100644 tests/unit_tests/utils/test_serialization_utils.py create mode 100644 tests/unit_tests/validator_service/test_async_validator_service.py create mode 100644 tests/unit_tests/validator_service/test_validator_service.py diff --git a/guardrails/classes/generic/default_json_encoder.py b/guardrails/classes/generic/default_json_encoder.py new file mode 100644 index 000000000..1319c0cd5 --- /dev/null +++ b/guardrails/classes/generic/default_json_encoder.py @@ -0,0 +1,21 @@ +from datetime import datetime +from dataclasses import asdict, is_dataclass +from pydantic import BaseModel +from json import JSONEncoder + + +class DefaultJSONEncoder(JSONEncoder): + def default(self, o): + if hasattr(o, "to_dict"): + return o.to_dict() + elif isinstance(o, BaseModel): + return o.model_dump() + elif is_dataclass(o): + return asdict(o) + elif isinstance(o, set): + return list(o) + elif isinstance(o, datetime): + return o.isoformat() + elif hasattr(o, "__dict__"): + return o.__dict__ + return super().default(o) diff --git a/guardrails/telemetry/open_inference.py b/guardrails/telemetry/open_inference.py index 2fc21649d..266c36c2c 100644 --- a/guardrails/telemetry/open_inference.py +++ b/guardrails/telemetry/open_inference.py @@ -1,6 +1,7 @@ from typing import Any, Dict, List, Optional -from guardrails.telemetry.common import get_span, serialize, to_dict +from guardrails.telemetry.common import get_span, to_dict +from guardrails.utils.serialization_utils import serialize def trace_operation( diff --git a/guardrails/telemetry/runner_tracing.py b/guardrails/telemetry/runner_tracing.py index 9bcba353f..1b3bb346f 100644 --- a/guardrails/telemetry/runner_tracing.py +++ b/guardrails/telemetry/runner_tracing.py @@ -17,8 +17,9 @@ from guardrails.classes.output_type import OT from guardrails.classes.validation_outcome import ValidationOutcome from guardrails.stores.context import get_guard_name -from 
guardrails.telemetry.common import get_tracer, serialize
+from guardrails.telemetry.common import get_tracer
from guardrails.utils.safe_get import safe_get
+from guardrails.utils.serialization_utils import serialize
from guardrails.version import GUARDRAILS_VERSION
diff --git a/guardrails/telemetry/validator_tracing.py b/guardrails/telemetry/validator_tracing.py
index 502e670f3..3d2904327 100644
--- a/guardrails/telemetry/validator_tracing.py
+++ b/guardrails/telemetry/validator_tracing.py
@@ -12,10 +12,11 @@
from guardrails.settings import settings
from guardrails.classes.validation.validation_result import ValidationResult
-from guardrails.telemetry.common import get_tracer, serialize
+from guardrails.telemetry.common import get_tracer
from guardrails.telemetry.open_inference import trace_operation
from guardrails.utils.casting_utils import to_string
from guardrails.utils.safe_get import safe_get
+from guardrails.utils.serialization_utils import serialize
from guardrails.version import GUARDRAILS_VERSION
diff --git a/guardrails/utils/serialization_utils.py b/guardrails/utils/serialization_utils.py
new file mode 100644
index 000000000..18e435374
--- /dev/null
+++ b/guardrails/utils/serialization_utils.py
@@ -0,0 +1,41 @@
+from datetime import datetime
+import json
+from typing import Any, Optional
+import warnings
+
+from guardrails.classes.generic.default_json_encoder import DefaultJSONEncoder
+
+
+# TODO: What other common cases should we consider?
+def serialize(val: Any) -> Optional[str]:
+    try:
+        return json.dumps(val, cls=DefaultJSONEncoder)
+    except Exception as e:
+        warnings.warn(e)
+        return None
+
+
+# We want to do the opposite of what we did in the DefaultJSONEncoder
+# TODO: What's a good way to expose a configurable API for this?
+# Do we wrap JSONDecoder with an extra layer to supply the original object?
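+# Illustrative round trip (hypothetical `Point` model; behavior mirrors
+# tests/unit_tests/utils/test_serialization_utils.py added in this change):
+#
+#     class Point(BaseModel):
+#         x: int
+#
+#     serialized = serialize(Point(x=1))             # '{"x": 1}'
+#     restored = deserialize(Point(x=0), serialized)
+#     assert restored.x == 1
+#
+# `original` is only used as a template to pick the target type; its field
+# values are not copied into the result.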
+def deserialize(original: Optional[Any], serialized: Optional[str]) -> Any:
+    try:
+        if original is None or serialized is None:
+            return None
+
+        loaded_val = json.loads(serialized)
+        if isinstance(original, datetime):
+            return datetime.fromisoformat(loaded_val)
+        elif isinstance(original, set):
+            return set(loaded_val)
+        elif hasattr(original, "__class__"):
+            # TODO: Handle nested classes
+            # NOTE: nested pydantic classes already work
+            if isinstance(loaded_val, dict):
+                return original.__class__(**loaded_val)
+            elif isinstance(loaded_val, list):
+                return original.__class__(loaded_val)
+        return loaded_val
+    except Exception as e:
+        warnings.warn(e)
+        return None
diff --git a/guardrails/validator_service/__init__.py b/guardrails/validator_service/__init__.py
new file mode 100644
index 000000000..416975294
--- /dev/null
+++ b/guardrails/validator_service/__init__.py
@@ -0,0 +1,143 @@
+import asyncio
+import os
+from typing import Any, Iterable, Optional, Tuple
+import warnings
+
+from guardrails.actions.filter import apply_filters
+from guardrails.actions.refrain import apply_refrain
+from guardrails.classes.history import Iteration
+from guardrails.classes.output_type import OutputTypes
+from guardrails.classes.validation.validation_result import (
+    StreamValidationResult,
+)
+from guardrails.types import ValidatorMap
+from guardrails.telemetry.legacy_validator_tracing import trace_validation_result
+from guardrails.validator_service.async_validator_service import AsyncValidatorService
+from guardrails.validator_service.sequential_validator_service import (
+    SequentialValidatorService,
+)
+
+
+try:
+    import uvloop  # type: ignore
+except ImportError:
+    uvloop = None
+
+
+def should_run_sync():
+    # NOTE: environment variables are strings and may be unset (None).
+    process_count = os.environ.get("GUARDRAILS_PROCESS_COUNT")
+    if process_count is not None:
+        warnings.warn(
+            "GUARDRAILS_PROCESS_COUNT is deprecated"
+            " and will be removed in a future release."
+            " To force synchronous validation, please use GUARDRAILS_RUN_SYNC instead.",
+            DeprecationWarning,
+        )
+    run_sync = os.environ.get("GUARDRAILS_RUN_SYNC", "false")
+    bool_values = ["true", "false"]
+    if run_sync.lower() not in bool_values:
+        warnings.warn(
+            f"GUARDRAILS_RUN_SYNC must be one of {bool_values}!"
+            f" Defaulting to 'false'."
+        )
+    return process_count == "1" or run_sync.lower() == "true"
+
+
+def get_loop() -> asyncio.AbstractEventLoop:
+    try:
+        loop = asyncio.get_running_loop()
+    except RuntimeError:
+        loop = None
+
+    if loop is not None:
+        raise RuntimeError("An event loop is already running.")
+
+    if uvloop is not None:
+        asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
+
+    return asyncio.get_event_loop()
+
+
+def validate(
+    value: Any,
+    metadata: dict,
+    validator_map: ValidatorMap,
+    iteration: Iteration,
+    disable_tracer: Optional[bool] = True,
+    path: Optional[str] = None,
+    **kwargs,
+):
+    if path is None:
+        path = "$"
+
+    loop = None
+    if should_run_sync():
+        validator_service = SequentialValidatorService(disable_tracer)
+    else:
+        try:
+            loop = get_loop()
+            validator_service = AsyncValidatorService(disable_tracer)
+        except RuntimeError:
+            warnings.warn(
+                "Could not obtain an event loop."
+                " Falling back to synchronous validation."
+ ) + validator_service = SequentialValidatorService(disable_tracer) + + return validator_service.validate( + value, metadata, validator_map, iteration, path, path, loop=loop, **kwargs + ) + + +def validate_stream( + value_stream: Iterable[Tuple[Any, bool]], + metadata: dict, + validator_map: ValidatorMap, + iteration: Iteration, + disable_tracer: Optional[bool] = True, + path: Optional[str] = None, + **kwargs, +) -> Iterable[StreamValidationResult]: + if path is None: + path = "$" + sequential_validator_service = SequentialValidatorService(disable_tracer) + gen = sequential_validator_service.validate_stream( + value_stream, metadata, validator_map, iteration, path, path, **kwargs + ) + return gen + + +async def async_validate( + value: Any, + metadata: dict, + validator_map: ValidatorMap, + iteration: Iteration, + disable_tracer: Optional[bool] = True, + path: Optional[str] = None, + stream: Optional[bool] = False, + **kwargs, +) -> Tuple[Any, dict]: + if path is None: + path = "$" + validator_service = AsyncValidatorService(disable_tracer) + return await validator_service.async_validate( + value, metadata, validator_map, iteration, path, path, stream, **kwargs + ) + + +def post_process_validation( + validation_response: Any, + attempt_number: int, + iteration: Iteration, + output_type: OutputTypes, +) -> Any: + validated_response = apply_refrain(validation_response, output_type) + + # Remove all keys that have `Filter` values. + validated_response = apply_filters(validated_response) + + trace_validation_result( + validation_logs=iteration.validator_logs, attempt_number=attempt_number + ) + + return validated_response diff --git a/guardrails/validator_service/async_validator_service.py b/guardrails/validator_service/async_validator_service.py new file mode 100644 index 000000000..0a13c5e99 --- /dev/null +++ b/guardrails/validator_service/async_validator_service.py @@ -0,0 +1,267 @@ +import asyncio +from typing import Any, Awaitable, Coroutine, Dict, List, Optional, Tuple, Union, cast + +from guardrails.actions.filter import Filter +from guardrails.actions.refrain import Refrain +from guardrails.classes.history import Iteration +from guardrails.classes.validation.validation_result import ( + FailResult, + PassResult, + ValidationResult, +) +from guardrails.types import ValidatorMap, OnFailAction +from guardrails.classes.validation.validator_logs import ValidatorLogs +from guardrails.actions.reask import FieldReAsk +from guardrails.validator_base import Validator +from guardrails.validator_service.validator_service_base import ( + ValidatorRun, + ValidatorServiceBase, +) + +ValidatorResult = Optional[Union[ValidationResult, Awaitable[ValidationResult]]] + + +class AsyncValidatorService(ValidatorServiceBase): + async def run_validator_async( + self, + validator: Validator, + value: Any, + metadata: Dict, + stream: Optional[bool] = False, + *, + validation_session_id: str, + **kwargs, + ) -> ValidationResult: + result: ValidatorResult = self.execute_validator( + validator, + value, + metadata, + stream, + validation_session_id=validation_session_id, + **kwargs, + ) + if asyncio.iscoroutine(result): + result = await result + + if result is None: + result = PassResult() + else: + result = cast(ValidationResult, result) + return result + + async def run_validator( + self, + iteration: Iteration, + validator: Validator, + value: Any, + metadata: Dict, + absolute_property_path: str, + stream: Optional[bool] = False, + **kwargs, + ) -> ValidatorRun: + validator_logs = self.before_run_validator( + 
iteration, validator, value, absolute_property_path + ) + + result = await self.run_validator_async( + validator, + value, + metadata, + stream, + validation_session_id=iteration.id, + **kwargs, + ) + + validator_logs = self.after_run_validator(validator, validator_logs, result) + + if isinstance(result, FailResult): + rechecked_value = None + if validator.on_fail_descriptor == OnFailAction.FIX_REASK: + fixed_value = result.fix_value + rechecked_value = await self.run_validator_async( + validator, + fixed_value, + result.metadata or {}, + stream, + validation_session_id=iteration.id, + **kwargs, + ) + value = self.perform_correction( + result, + value, + validator, + rechecked_value=rechecked_value, + ) + + # handle overrides + # QUESTION: Should this consider the rechecked_value as well? + elif ( + isinstance(result, PassResult) + and result.value_override is not PassResult.ValueOverrideSentinel + ): + value = result.value_override + + validator_logs.value_after_validation = value + + return ValidatorRun( + value=value, + metadata=metadata, + validator_logs=validator_logs, + ) + + async def run_validators( + self, + iteration: Iteration, + validator_map: ValidatorMap, + value: Any, + metadata: Dict, + absolute_property_path: str, + reference_property_path: str, + stream: Optional[bool] = False, + **kwargs, + ): + validators = validator_map.get(reference_property_path, []) + coroutines: List[Coroutine[Any, Any, ValidatorRun]] = [] + validators_logs: List[ValidatorLogs] = [] + for validator in validators: + coroutine: Coroutine[Any, Any, ValidatorRun] = self.run_validator( + iteration, + validator, + value, + metadata, + absolute_property_path, + stream=stream, + **kwargs, + ) + coroutines.append(coroutine) + + results = await asyncio.gather(*coroutines) + for res in results: + validators_logs.extend(res.validator_logs) + # QUESTION: Do we still want to do this here or handle it during the merge? 
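+            # NOTE: returning here short-circuits the merge below, so any
+            # results from sibling validators in this gather are discarded.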
+ # return early if we have a filter, refrain, or reask + if isinstance(res.value, (Filter, Refrain, FieldReAsk)): + return res.value, metadata + + # merge the results + if len(results) > 0: + values = [res.value for res in results] + value = self.merge_results(value, values) + + return value, metadata + + async def validate_children( + self, + value: Any, + metadata: Dict, + validator_map: ValidatorMap, + iteration: Iteration, + abs_parent_path: str, + ref_parent_path: str, + stream: Optional[bool] = False, + **kwargs, + ): + async def validate_child( + child_value: Any, *, key: Optional[str] = None, index: Optional[int] = None + ): + child_key = key or index + abs_child_path = f"{abs_parent_path}.{child_key}" + ref_child_path = ref_parent_path + if key is not None: + ref_child_path = f"{ref_child_path}.{key}" + elif index is not None: + ref_child_path = f"{ref_child_path}.*" + new_child_value, new_metadata = await self.async_validate( + child_value, + metadata, + validator_map, + iteration, + abs_child_path, + ref_child_path, + stream=stream, + **kwargs, + ) + return child_key, new_child_value, new_metadata + + coroutines = [] + if isinstance(value, List): + for index, child in enumerate(value): + coroutines.append(validate_child(child, index=index)) + elif isinstance(value, Dict): + for key in value: + child = value.get(key) + coroutines.append(validate_child(child, key=key)) + + results = await asyncio.gather(*coroutines) + + for key, child_value, child_metadata in results: + value[key] = child_value + # TODO address conflicting metadata entries + metadata = {**metadata, **child_metadata} + + return value, metadata + + async def async_validate( + self, + value: Any, + metadata: dict, + validator_map: ValidatorMap, + iteration: Iteration, + absolute_path: str, + reference_path: str, + stream: Optional[bool] = False, + **kwargs, + ) -> Tuple[Any, dict]: + child_ref_path = reference_path.replace(".*", "") + # Validate children first + if isinstance(value, List) or isinstance(value, Dict): + await self.validate_children( + value, + metadata, + validator_map, + iteration, + absolute_path, + child_ref_path, + stream=stream, + **kwargs, + ) + + # Then validate the parent value + value, metadata = await self.run_validators( + iteration, + validator_map, + value, + metadata, + absolute_path, + reference_path, + stream=stream, + **kwargs, + ) + + return value, metadata + + def validate( + self, + value: Any, + metadata: dict, + validator_map: ValidatorMap, + iteration: Iteration, + absolute_path: str, + reference_path: str, + loop: asyncio.AbstractEventLoop, + stream: Optional[bool] = False, + **kwargs, + ) -> Tuple[Any, dict]: + value, metadata = loop.run_until_complete( + self.async_validate( + value, + metadata, + validator_map, + iteration, + absolute_path, + reference_path, + stream=stream, + **kwargs, + ) + ) + return value, metadata diff --git a/guardrails/validator_service.py b/guardrails/validator_service/sequential_validator_service.py similarity index 50% rename from guardrails/validator_service.py rename to guardrails/validator_service/sequential_validator_service.py index cc9f53cbb..3c5893d1c 100644 --- a/guardrails/validator_service.py +++ b/guardrails/validator_service/sequential_validator_service.py @@ -1,190 +1,22 @@ import asyncio -import itertools -import os -from concurrent.futures import ProcessPoolExecutor -from datetime import datetime -from typing import Any, Awaitable, Dict, Iterable, List, Optional, Tuple, Union, cast +from typing import Any, Dict, Iterable, List, 
Optional, Tuple, cast -from guardrails.actions.filter import Filter, apply_filters -from guardrails.actions.refrain import Refrain, apply_refrain +from guardrails.actions.filter import Filter +from guardrails.actions.refrain import Refrain from guardrails.classes.history import Iteration -from guardrails.classes.output_type import OutputTypes from guardrails.classes.validation.validation_result import ( FailResult, PassResult, StreamValidationResult, ValidationResult, ) -from guardrails.errors import ValidationError from guardrails.merge import merge from guardrails.types import ValidatorMap, OnFailAction from guardrails.utils.exception_utils import UserFacingException -from guardrails.utils.hub_telemetry_utils import HubTelemetry from guardrails.classes.validation.validator_logs import ValidatorLogs -from guardrails.actions.reask import FieldReAsk, ReAsk -from guardrails.telemetry.legacy_validator_tracing import trace_validation_result -from guardrails.telemetry import trace_validator +from guardrails.actions.reask import ReAsk from guardrails.validator_base import Validator - -ValidatorResult = Optional[Union[ValidationResult, Awaitable[ValidationResult]]] - - -def key_not_empty(key: str) -> bool: - return key is not None and len(str(key)) > 0 - - -class ValidatorServiceBase: - """Base class for validator services.""" - - def __init__(self, disable_tracer: Optional[bool] = True): - self._disable_tracer = disable_tracer - - # NOTE: This is avoiding an issue with multiprocessing. - # If we wrap the validate methods at the class level or anytime before - # loop.run_in_executor is called, multiprocessing fails with a Pickling error. - # This is a well known issue without any real solutions. - # Using `fork` instead of `spawn` may alleviate the symptom for POSIX systems, - # but is relatively unsupported on Windows. - def execute_validator( - self, - validator: Validator, - value: Any, - metadata: Optional[Dict], - stream: Optional[bool] = False, - *, - validation_session_id: str, - **kwargs, - ) -> ValidatorResult: - validate_func = validator.validate_stream if stream else validator.validate - traced_validator = trace_validator( - validator_name=validator.rail_alias, - obj_id=id(validator), - on_fail_descriptor=validator.on_fail_descriptor, - validation_session_id=validation_session_id, - **validator._kwargs, - )(validate_func) - if stream: - result = traced_validator(value, metadata, **kwargs) - else: - result = traced_validator(value, metadata) - return result - - def perform_correction( - self, - results: List[FailResult], - value: Any, - validator: Validator, - on_fail_descriptor: Union[OnFailAction, str], - rechecked_value: Optional[ValidationResult] = None, - ): - if on_fail_descriptor == OnFailAction.FIX: - # FIXME: Should we still return fix_value if it is None? - # I think we should warn and return the original value. 
- return results[0].fix_value - elif on_fail_descriptor == OnFailAction.FIX_REASK: - # FIXME: Same thing here - fixed_value = results[0].fix_value - - if isinstance(rechecked_value, FailResult): - return FieldReAsk( - incorrect_value=fixed_value, - fail_results=results, - ) - - return fixed_value - if on_fail_descriptor == "custom": - if validator.on_fail_method is None: - raise ValueError("on_fail is 'custom' but on_fail_method is None") - return validator.on_fail_method(value, results) - if on_fail_descriptor == OnFailAction.REASK: - return FieldReAsk( - incorrect_value=value, - fail_results=results, - ) - if on_fail_descriptor == OnFailAction.EXCEPTION: - raise ValidationError( - "Validation failed for field with errors: " - + ", ".join([result.error_message for result in results]) - ) - if on_fail_descriptor == OnFailAction.FILTER: - return Filter() - if on_fail_descriptor == OnFailAction.REFRAIN: - return Refrain() - if on_fail_descriptor == OnFailAction.NOOP: - return value - else: - raise ValueError( - f"Invalid on_fail_descriptor {on_fail_descriptor}, " - f"expected 'fix' or 'exception'." - ) - - def before_run_validator( - self, - iteration: Iteration, - validator: Validator, - value: Any, - absolute_property_path: str, - ) -> ValidatorLogs: - validator_class_name = validator.__class__.__name__ - validator_logs = ValidatorLogs( - validator_name=validator_class_name, - value_before_validation=value, - registered_name=validator.rail_alias, - property_path=absolute_property_path, - # If we ever re-use validator instances across multiple properties, - # this will have to change. - instance_id=id(validator), - ) - iteration.outputs.validator_logs.append(validator_logs) - - start_time = datetime.now() - validator_logs.start_time = start_time - - return validator_logs - - def after_run_validator( - self, - validator: Validator, - validator_logs: ValidatorLogs, - result: Optional[ValidationResult], - ): - end_time = datetime.now() - validator_logs.validation_result = result - validator_logs.end_time = end_time - - if not self._disable_tracer: - # Get HubTelemetry singleton and create a new span to - # log the validator usage - _hub_telemetry = HubTelemetry() - _hub_telemetry.create_new_span( - span_name="/validator_usage", - attributes=[ - ("validator_name", validator.rail_alias), - ("validator_on_fail", validator.on_fail_descriptor), - ( - "validator_result", - result.outcome - if isinstance(result, ValidationResult) - else None, - ), - ], - is_parent=False, # This span will have no children - has_parent=True, # This span has a parent - ) - - return validator_logs - - def run_validator( - self, - iteration: Iteration, - validator: Validator, - value: Any, - metadata: Dict, - absolute_property_path: str, - stream: Optional[bool] = False, - **kwargs, - ) -> ValidatorLogs: - raise NotImplementedError +from guardrails.validator_service.validator_service_base import ValidatorServiceBase class SequentialValidatorService(ValidatorServiceBase): @@ -675,396 +507,3 @@ def validate_stream( **kwargs, ) return gen - - -class MultiprocMixin: - multiprocessing_executor: Optional[ProcessPoolExecutor] = None - process_count = int(os.environ.get("GUARDRAILS_PROCESS_COUNT", 10)) - - def __init__(self): - if MultiprocMixin.multiprocessing_executor is None: - MultiprocMixin.multiprocessing_executor = ProcessPoolExecutor( - max_workers=MultiprocMixin.process_count - ) - - -class AsyncValidatorService(ValidatorServiceBase, MultiprocMixin): - async def run_validator_async( - self, - validator: Validator, - 
value: Any, - metadata: Dict, - stream: Optional[bool] = False, - *, - validation_session_id: str, - **kwargs, - ) -> ValidationResult: - result: ValidatorResult = self.execute_validator( - validator, - value, - metadata, - stream, - validation_session_id=validation_session_id, - **kwargs, - ) - if asyncio.iscoroutine(result): - result = await result - - if result is None: - result = PassResult() - else: - result = cast(ValidationResult, result) - return result - - async def run_validator( - self, - iteration: Iteration, - validator: Validator, - value: Any, - metadata: Dict, - absolute_property_path: str, - stream: Optional[bool] = False, - **kwargs, - ) -> ValidatorLogs: - validator_logs = self.before_run_validator( - iteration, validator, value, absolute_property_path - ) - - result = await self.run_validator_async( - validator, - value, - metadata, - stream, - validation_session_id=iteration.id, - **kwargs, - ) - - return self.after_run_validator(validator, validator_logs, result) - - def group_validators(self, validators: List[Validator]): - groups = itertools.groupby( - validators, key=lambda v: (v.on_fail_descriptor, v.override_value_on_pass) - ) - # NOTE: This isn't ordering anything. - # If we want to yield fix-like valiators first, - # then we need to extract them outside of the loop. - for (on_fail_descriptor, override_on_pass), group in groups: - if override_on_pass or on_fail_descriptor in [ - OnFailAction.FIX, - OnFailAction.FIX_REASK, - "custom", - ]: - for validator in group: - yield on_fail_descriptor, [validator] - else: - yield on_fail_descriptor, list(group) - - async def run_validators( - self, - iteration: Iteration, - validator_map: ValidatorMap, - value: Any, - metadata: Dict, - absolute_property_path: str, - reference_property_path: str, - stream: Optional[bool] = False, - **kwargs, - ): - loop = asyncio.get_running_loop() - validators = validator_map.get(reference_property_path, []) - for on_fail, validator_group in self.group_validators(validators): - parallel_tasks = [] - validators_logs: List[ValidatorLogs] = [] - for validator in validator_group: - if validator.run_in_separate_process: - # queue the validators to run in a separate process - parallel_tasks.append( - loop.run_in_executor( - self.multiprocessing_executor, - self.run_validator, - iteration, - validator, - value, - metadata, - absolute_property_path, - stream, - ) - ) - else: - # run the validators in the current process - result = await self.run_validator( - iteration, - validator, - value, - metadata, - absolute_property_path, - stream=stream, - **kwargs, - ) - validators_logs.append(result) - - # wait for the parallel tasks to finish - if parallel_tasks: - parallel_results = await asyncio.gather(*parallel_tasks) - awaited_results = [] - for res in parallel_results: - if asyncio.iscoroutine(res): - res = await res - awaited_results.append(res) - validators_logs.extend(awaited_results) - - # process the results, handle failures - fails = [ - logs - for logs in validators_logs - if isinstance(logs.validation_result, FailResult) - ] - if fails: - # NOTE: Ignoring type bc we know it's a FailResult - fail_results: List[FailResult] = [ - logs.validation_result # type: ignore - for logs in fails - ] - rechecked_value = None - validator: Validator = validator_group[0] - if validator.on_fail_descriptor == OnFailAction.FIX_REASK: - fixed_value = fail_results[0].fix_value - rechecked_value = await self.run_validator_async( - validator, - fixed_value, - fail_results[0].metadata or {}, - stream, - 
validation_session_id=iteration.id, - **kwargs, - ) - value = self.perform_correction( - fail_results, - value, - validator_group[0], - on_fail, - rechecked_value=rechecked_value, - ) - - # handle overrides - if ( - len(validator_group) == 1 - and validator_group[0].override_value_on_pass - and isinstance(validators_logs[0].validation_result, PassResult) - and validators_logs[0].validation_result.value_override - is not PassResult.ValueOverrideSentinel - ): - value = validators_logs[0].validation_result.value_override - - for logs in validators_logs: - logs.value_after_validation = value - - # return early if we have a filter, refrain, or reask - if isinstance(value, (Filter, Refrain, FieldReAsk)): - return value, metadata - - return value, metadata - - async def validate_children( - self, - value: Any, - metadata: Dict, - validator_map: ValidatorMap, - iteration: Iteration, - abs_parent_path: str, - ref_parent_path: str, - stream: Optional[bool] = False, - **kwargs, - ): - async def validate_child( - child_value: Any, *, key: Optional[str] = None, index: Optional[int] = None - ): - child_key = key or index - abs_child_path = f"{abs_parent_path}.{child_key}" - ref_child_path = ref_parent_path - if key is not None: - ref_child_path = f"{ref_child_path}.{key}" - elif index is not None: - ref_child_path = f"{ref_child_path}.*" - new_child_value, new_metadata = await self.async_validate( - child_value, - metadata, - validator_map, - iteration, - abs_child_path, - ref_child_path, - stream=stream, - **kwargs, - ) - return child_key, new_child_value, new_metadata - - tasks = [] - if isinstance(value, List): - for index, child in enumerate(value): - tasks.append(validate_child(child, index=index)) - elif isinstance(value, Dict): - for key in value: - child = value.get(key) - tasks.append(validate_child(child, key=key)) - - results = await asyncio.gather(*tasks) - - for key, child_value, child_metadata in results: - value[key] = child_value - # TODO address conflicting metadata entries - metadata = {**metadata, **child_metadata} - - return value, metadata - - async def async_validate( - self, - value: Any, - metadata: dict, - validator_map: ValidatorMap, - iteration: Iteration, - absolute_path: str, - reference_path: str, - stream: Optional[bool] = False, - **kwargs, - ) -> Tuple[Any, dict]: - child_ref_path = reference_path.replace(".*", "") - # Validate children first - if isinstance(value, List) or isinstance(value, Dict): - await self.validate_children( - value, - metadata, - validator_map, - iteration, - absolute_path, - child_ref_path, - stream=stream, - **kwargs, - ) - - # Then validate the parent value - value, metadata = await self.run_validators( - iteration, - validator_map, - value, - metadata, - absolute_path, - reference_path, - stream=stream, - **kwargs, - ) - - return value, metadata - - def validate( - self, - value: Any, - metadata: dict, - validator_map: ValidatorMap, - iteration: Iteration, - absolute_path: str, - reference_path: str, - stream: Optional[bool] = False, - **kwargs, - ) -> Tuple[Any, dict]: - # Run validate_async in an async loop - loop = asyncio.get_event_loop() - if loop.is_running(): - raise RuntimeError( - "Async event loop found, please call `validate_async` instead." 
- ) - value, metadata = loop.run_until_complete( - self.async_validate( - value, - metadata, - validator_map, - iteration, - absolute_path, - reference_path, - stream=stream, - **kwargs, - ) - ) - return value, metadata - - -def validate( - value: Any, - metadata: dict, - validator_map: ValidatorMap, - iteration: Iteration, - disable_tracer: Optional[bool] = True, - path: Optional[str] = None, - **kwargs, -): - if path is None: - path = "$" - - process_count = int(os.environ.get("GUARDRAILS_PROCESS_COUNT", 10)) - try: - loop = asyncio.get_event_loop() - except RuntimeError: - loop = None - - if process_count == 1: - validator_service = SequentialValidatorService(disable_tracer) - elif loop is not None and not loop.is_running(): - validator_service = AsyncValidatorService(disable_tracer) - else: - validator_service = SequentialValidatorService(disable_tracer) - - return validator_service.validate( - value, metadata, validator_map, iteration, path, path, **kwargs - ) - - -def validate_stream( - value_stream: Iterable[Tuple[Any, bool]], - metadata: dict, - validator_map: ValidatorMap, - iteration: Iteration, - disable_tracer: Optional[bool] = True, - path: Optional[str] = None, - **kwargs, -) -> Iterable[StreamValidationResult]: - if path is None: - path = "$" - sequential_validator_service = SequentialValidatorService(disable_tracer) - gen = sequential_validator_service.validate_stream( - value_stream, metadata, validator_map, iteration, path, path, **kwargs - ) - return gen - - -async def async_validate( - value: Any, - metadata: dict, - validator_map: ValidatorMap, - iteration: Iteration, - disable_tracer: Optional[bool] = True, - path: Optional[str] = None, - stream: Optional[bool] = False, - **kwargs, -) -> Tuple[Any, dict]: - if path is None: - path = "$" - validator_service = AsyncValidatorService(disable_tracer) - return await validator_service.async_validate( - value, metadata, validator_map, iteration, path, path, stream, **kwargs - ) - - -def post_process_validation( - validation_response: Any, - attempt_number: int, - iteration: Iteration, - output_type: OutputTypes, -) -> Any: - validated_response = apply_refrain(validation_response, output_type) - - # Remove all keys that have `Filter` values. 
- validated_response = apply_filters(validated_response) - - trace_validation_result( - validation_logs=iteration.validator_logs, attempt_number=attempt_number - ) - - return validated_response diff --git a/guardrails/validator_service/validator_service_base.py b/guardrails/validator_service/validator_service_base.py new file mode 100644 index 000000000..bff57687c --- /dev/null +++ b/guardrails/validator_service/validator_service_base.py @@ -0,0 +1,210 @@ +from copy import deepcopy +from dataclasses import dataclass +from datetime import datetime +from typing import Any, Awaitable, Dict, Optional, Union + +from guardrails.actions.filter import Filter +from guardrails.actions.refrain import Refrain +from guardrails.classes.history import Iteration +from guardrails.classes.validation.validation_result import ( + FailResult, + ValidationResult, +) +from guardrails.errors import ValidationError +from guardrails.merge import merge +from guardrails.types import OnFailAction +from guardrails.utils.hub_telemetry_utils import HubTelemetry +from guardrails.classes.validation.validator_logs import ValidatorLogs +from guardrails.actions.reask import FieldReAsk +from guardrails.telemetry import trace_validator +from guardrails.utils.serialization_utils import deserialize, serialize +from guardrails.validator_base import Validator + +ValidatorResult = Optional[Union[ValidationResult, Awaitable[ValidationResult]]] + + +@dataclass +class ValidatorRun: + value: Any + metadata: Dict + validator_logs: ValidatorLogs + + +class ValidatorServiceBase: + """Base class for validator services.""" + + def __init__(self, disable_tracer: Optional[bool] = True): + self._disable_tracer = disable_tracer + + # NOTE: This is avoiding an issue with multiprocessing. + # If we wrap the validate methods at the class level or anytime before + # loop.run_in_executor is called, multiprocessing fails with a Pickling error. + # This is a well known issue without any real solutions. + # Using `fork` instead of `spawn` may alleviate the symptom for POSIX systems, + # but is relatively unsupported on Windows. + def execute_validator( + self, + validator: Validator, + value: Any, + metadata: Optional[Dict], + stream: Optional[bool] = False, + *, + validation_session_id: str, + **kwargs, + ) -> ValidatorResult: + validate_func = validator.validate_stream if stream else validator.validate + traced_validator = trace_validator( + validator_name=validator.rail_alias, + obj_id=id(validator), + on_fail_descriptor=validator.on_fail_descriptor, + validation_session_id=validation_session_id, + **validator._kwargs, + )(validate_func) + if stream: + result = traced_validator(value, metadata, **kwargs) + else: + result = traced_validator(value, metadata) + return result + + def perform_correction( + self, + result: FailResult, + value: Any, + validator: Validator, + rechecked_value: Optional[ValidationResult] = None, + ): + on_fail_descriptor = validator.on_fail_descriptor + if on_fail_descriptor == OnFailAction.FIX: + # FIXME: Should we still return fix_value if it is None? + # I think we should warn and return the original value. 
+ return result.fix_value + elif on_fail_descriptor == OnFailAction.FIX_REASK: + # FIXME: Same thing here + fixed_value = result.fix_value + + if isinstance(rechecked_value, FailResult): + return FieldReAsk( + incorrect_value=fixed_value, + fail_results=[result], + ) + + return fixed_value + if on_fail_descriptor == "custom": + if validator.on_fail_method is None: + raise ValueError("on_fail is 'custom' but on_fail_method is None") + return validator.on_fail_method(value, [result]) + if on_fail_descriptor == OnFailAction.REASK: + return FieldReAsk( + incorrect_value=value, + fail_results=[result], + ) + if on_fail_descriptor == OnFailAction.EXCEPTION: + raise ValidationError( + "Validation failed for field with errors: " + + ", ".join([result.error_message]) + ) + if on_fail_descriptor == OnFailAction.FILTER: + return Filter() + if on_fail_descriptor == OnFailAction.REFRAIN: + return Refrain() + if on_fail_descriptor == OnFailAction.NOOP: + return value + else: + raise ValueError( + f"Invalid on_fail_descriptor {on_fail_descriptor}, " + f"expected 'fix' or 'exception'." + ) + + def before_run_validator( + self, + iteration: Iteration, + validator: Validator, + value: Any, + absolute_property_path: str, + ) -> ValidatorLogs: + validator_class_name = validator.__class__.__name__ + validator_logs = ValidatorLogs( + validator_name=validator_class_name, + value_before_validation=value, + registered_name=validator.rail_alias, + property_path=absolute_property_path, + # If we ever re-use validator instances across multiple properties, + # this will have to change. + instance_id=id(validator), + ) + iteration.outputs.validator_logs.append(validator_logs) + + start_time = datetime.now() + validator_logs.start_time = start_time + + return validator_logs + + def after_run_validator( + self, + validator: Validator, + validator_logs: ValidatorLogs, + result: Optional[ValidationResult], + ) -> ValidatorLogs: + end_time = datetime.now() + validator_logs.validation_result = result + validator_logs.end_time = end_time + + if not self._disable_tracer: + # Get HubTelemetry singleton and create a new span to + # log the validator usage + _hub_telemetry = HubTelemetry() + _hub_telemetry.create_new_span( + span_name="/validator_usage", + attributes=[ + ("validator_name", validator.rail_alias), + ("validator_on_fail", validator.on_fail_descriptor), + ( + "validator_result", + result.outcome + if isinstance(result, ValidationResult) + else None, + ), + ], + is_parent=False, # This span will have no children + has_parent=True, # This span has a parent + ) + + return validator_logs + + def run_validator( + self, + iteration: Iteration, + validator: Validator, + value: Any, + metadata: Dict, + absolute_property_path: str, + stream: Optional[bool] = False, + **kwargs, + ) -> ValidatorRun: + raise NotImplementedError + + def merge_results(self, original_value: Any, new_values: list[Any]) -> Any: + new_vals = deepcopy(new_values) + current = new_values.pop() + while len(new_values) > 0: + nextval = new_values.pop() + # print("current:", current) + # print("serialize(current):", serialize(current)) + # print("nextval:", nextval) + # print("serialize(nextval):", serialize(nextval)) + # print("original_value:", original_value) + # print("serialize(original_value):", serialize(original_value)) + current = merge( + serialize(current), serialize(nextval), serialize(original_value) + ) + deserialized_value = deserialize(original_value, current) + if deserialized_value is None and current is not None: + # QUESTION: How do we 
escape hatch + # for when deserializing the merged value fails? + + # Should we return the original value? + # return original_value + + # Or just pick one of the new values? + return new_vals[0] + return deserialized_value diff --git a/poetry.lock b/poetry.lock index e176aef6f..8c8f56368 100644 --- a/poetry.lock +++ b/poetry.lock @@ -7871,6 +7871,50 @@ h2 = ["h2 (>=4,<5)"] socks = ["pysocks (>=1.5.6,!=1.5.7,<2.0)"] zstd = ["zstandard (>=0.18.0)"] +[[package]] +name = "uvloop" +version = "0.20.0" +description = "Fast implementation of asyncio event loop on top of libuv" +optional = true +python-versions = ">=3.8.0" +files = [ + {file = "uvloop-0.20.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:9ebafa0b96c62881d5cafa02d9da2e44c23f9f0cd829f3a32a6aff771449c996"}, + {file = "uvloop-0.20.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:35968fc697b0527a06e134999eef859b4034b37aebca537daeb598b9d45a137b"}, + {file = "uvloop-0.20.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b16696f10e59d7580979b420eedf6650010a4a9c3bd8113f24a103dfdb770b10"}, + {file = "uvloop-0.20.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9b04d96188d365151d1af41fa2d23257b674e7ead68cfd61c725a422764062ae"}, + {file = "uvloop-0.20.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:94707205efbe809dfa3a0d09c08bef1352f5d3d6612a506f10a319933757c006"}, + {file = "uvloop-0.20.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:89e8d33bb88d7263f74dc57d69f0063e06b5a5ce50bb9a6b32f5fcbe655f9e73"}, + {file = "uvloop-0.20.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:e50289c101495e0d1bb0bfcb4a60adde56e32f4449a67216a1ab2750aa84f037"}, + {file = "uvloop-0.20.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:e237f9c1e8a00e7d9ddaa288e535dc337a39bcbf679f290aee9d26df9e72bce9"}, + {file = "uvloop-0.20.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:746242cd703dc2b37f9d8b9f173749c15e9a918ddb021575a0205ec29a38d31e"}, + {file = "uvloop-0.20.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:82edbfd3df39fb3d108fc079ebc461330f7c2e33dbd002d146bf7c445ba6e756"}, + {file = "uvloop-0.20.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:80dc1b139516be2077b3e57ce1cb65bfed09149e1d175e0478e7a987863b68f0"}, + {file = "uvloop-0.20.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:4f44af67bf39af25db4c1ac27e82e9665717f9c26af2369c404be865c8818dcf"}, + {file = "uvloop-0.20.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:4b75f2950ddb6feed85336412b9a0c310a2edbcf4cf931aa5cfe29034829676d"}, + {file = "uvloop-0.20.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:77fbc69c287596880ecec2d4c7a62346bef08b6209749bf6ce8c22bbaca0239e"}, + {file = "uvloop-0.20.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6462c95f48e2d8d4c993a2950cd3d31ab061864d1c226bbf0ee2f1a8f36674b9"}, + {file = "uvloop-0.20.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:649c33034979273fa71aa25d0fe120ad1777c551d8c4cd2c0c9851d88fcb13ab"}, + {file = "uvloop-0.20.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:3a609780e942d43a275a617c0839d85f95c334bad29c4c0918252085113285b5"}, + {file = "uvloop-0.20.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:aea15c78e0d9ad6555ed201344ae36db5c63d428818b4b2a42842b3870127c00"}, + {file = "uvloop-0.20.0-cp38-cp38-macosx_10_9_universal2.whl", hash = 
"sha256:f0e94b221295b5e69de57a1bd4aeb0b3a29f61be6e1b478bb8a69a73377db7ba"}, + {file = "uvloop-0.20.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:fee6044b64c965c425b65a4e17719953b96e065c5b7e09b599ff332bb2744bdf"}, + {file = "uvloop-0.20.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:265a99a2ff41a0fd56c19c3838b29bf54d1d177964c300dad388b27e84fd7847"}, + {file = "uvloop-0.20.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b10c2956efcecb981bf9cfb8184d27d5d64b9033f917115a960b83f11bfa0d6b"}, + {file = "uvloop-0.20.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:e7d61fe8e8d9335fac1bf8d5d82820b4808dd7a43020c149b63a1ada953d48a6"}, + {file = "uvloop-0.20.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:2beee18efd33fa6fdb0976e18475a4042cd31c7433c866e8a09ab604c7c22ff2"}, + {file = "uvloop-0.20.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:d8c36fdf3e02cec92aed2d44f63565ad1522a499c654f07935c8f9d04db69e95"}, + {file = "uvloop-0.20.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:a0fac7be202596c7126146660725157d4813aa29a4cc990fe51346f75ff8fde7"}, + {file = "uvloop-0.20.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9d0fba61846f294bce41eb44d60d58136090ea2b5b99efd21cbdf4e21927c56a"}, + {file = "uvloop-0.20.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:95720bae002ac357202e0d866128eb1ac82545bcf0b549b9abe91b5178d9b541"}, + {file = "uvloop-0.20.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:36c530d8fa03bfa7085af54a48f2ca16ab74df3ec7108a46ba82fd8b411a2315"}, + {file = "uvloop-0.20.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:e97152983442b499d7a71e44f29baa75b3b02e65d9c44ba53b10338e98dedb66"}, + {file = "uvloop-0.20.0.tar.gz", hash = "sha256:4603ca714a754fc8d9b197e325db25b2ea045385e8a3ad05d3463de725fdf469"}, +] + +[package.extras] +docs = ["Sphinx (>=4.1.2,<4.2.0)", "sphinx-rtd-theme (>=0.5.2,<0.6.0)", "sphinxcontrib-asyncio (>=0.3.0,<0.4.0)"] +test = ["Cython (>=0.29.36,<0.30.0)", "aiohttp (==3.9.0b0)", "aiohttp (>=3.8.1)", "flake8 (>=5.0,<6.0)", "mypy (>=0.800)", "psutil", "pyOpenSSL (>=23.0.0,<23.1.0)", "pycodestyle (>=2.9.0,<2.10.0)"] + [[package]] name = "virtualenv" version = "20.26.2" diff --git a/pyproject.toml b/pyproject.toml index 3020192d9..76a927a54 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -105,6 +105,10 @@ pillow = "^10.1.0" cairosvg = "^2.7.1" mkdocs-glightbox = "^0.3.4" + +[tool.poetry.group.uv.dependencies] +uvloop = {version = "^0.20.0", optional = true} + [[tool.poetry.source]] name = "PyPI" diff --git a/tests/conftest.py b/tests/conftest.py index fb98d40f2..c8d13ee79 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -36,7 +36,9 @@ def mock_validator_base_hub_telemetry(): @pytest.fixture(autouse=True) def mock_validator_service_hub_telemetry(): - with patch("guardrails.validator_service.HubTelemetry") as MockHubTelemetry: + with patch( + "guardrails.validator_service.validator_service_base.HubTelemetry" + ) as MockHubTelemetry: MockHubTelemetry.return_value = MagicMock() yield MockHubTelemetry diff --git a/tests/unit_tests/test_async_validator_service.py b/tests/unit_tests/test_async_validator_service.py deleted file mode 100644 index 2c566035a..000000000 --- a/tests/unit_tests/test_async_validator_service.py +++ /dev/null @@ -1,345 +0,0 @@ -import asyncio - -import pytest - -from guardrails.classes.history.iteration import Iteration -from guardrails.classes.validation.validator_logs import ValidatorLogs -from 
guardrails.validator_base import OnFailAction -from guardrails.validator_service import AsyncValidatorService -from guardrails.classes.validation.validation_result import PassResult - -from .mocks import MockLoop -from .mocks.mock_validator import create_mock_validator - -avs = AsyncValidatorService() - - -def test_validate_with_running_loop(mocker): - iteration = Iteration( - call_id="mock-call", - index=0, - ) - with pytest.raises(RuntimeError) as e_info: - mock_loop = MockLoop(True) - mocker.patch("asyncio.get_event_loop", return_value=mock_loop) - avs.validate( - value=True, - metadata={}, - validator_map={}, - iteration=iteration, - absolute_path="$", - reference_path="$", - ) - - assert ( - str(e_info) - == "Async event loop found, please call `validate_async` instead." - ) - - -def test_validate_without_running_loop(mocker): - mock_loop = MockLoop(False) - mocker.patch("asyncio.get_event_loop", return_value=mock_loop) - async_validate_mock = mocker.MagicMock( - return_value=("async_validate_mock", {"async": True}) - ) - mocker.patch.object(avs, "async_validate", async_validate_mock) - loop_spy = mocker.spy(mock_loop, "run_until_complete") - - iteration = Iteration( - call_id="mock-call", - index=0, - ) - - validated_value, validated_metadata = avs.validate( - value=True, - metadata={}, - validator_map={}, - iteration=iteration, - absolute_path="$", - reference_path="$", - ) - - assert loop_spy.call_count == 1 - async_validate_mock.assert_called_once_with( - True, {}, {}, iteration, "$", "$", stream=False - ) - assert validated_value == "async_validate_mock" - assert validated_metadata == {"async": True} - - -@pytest.mark.asyncio -async def test_async_validate_with_children(mocker): - validate_children_mock = mocker.patch.object(avs, "validate_children") - - run_validators_mock = mocker.patch.object(avs, "run_validators") - run_validators_mock.return_value = ("run_validators_mock", {"async": True}) - - value = {"a": 1} - - iteration = Iteration( - call_id="mock-call", - index=0, - ) - - validated_value, validated_metadata = await avs.async_validate( - value=value, - metadata={}, - validator_map={}, - iteration=iteration, - absolute_path="$", - reference_path="$", - ) - - assert validate_children_mock.call_count == 1 - validate_children_mock.assert_called_once_with( - value, {}, {}, iteration, "$", "$", stream=False - ) - - assert run_validators_mock.call_count == 1 - run_validators_mock.assert_called_once_with( - iteration, {}, value, {}, "$", "$", stream=False - ) - - assert validated_value == "run_validators_mock" - assert validated_metadata == {"async": True} - - -@pytest.mark.asyncio -async def test_async_validate_without_children(mocker): - validate_children_mock = mocker.patch.object(avs, "validate_children") - - run_validators_mock = mocker.patch.object(avs, "run_validators") - run_validators_mock.return_value = ("run_validators_mock", {"async": True}) - - iteration = Iteration( - call_id="mock-call", - index=0, - ) - - validated_value, validated_metadata = await avs.async_validate( - value="Hello world!", - metadata={}, - validator_map={}, - iteration=iteration, - absolute_path="$", - reference_path="$", - ) - - assert validate_children_mock.call_count == 0 - - assert run_validators_mock.call_count == 1 - run_validators_mock.assert_called_once_with( - iteration, {}, "Hello world!", {}, "$", "$", stream=False - ) - - assert validated_value == "run_validators_mock" - assert validated_metadata == {"async": True} - - -@pytest.mark.asyncio -async def 
test_validate_children(mocker): - async def mock_async_validate(v, md, *args, **kwargs): - return (f"new-{v}", md) - - async_validate_mock = mocker.patch.object( - avs, "async_validate", side_effect=mock_async_validate - ) - - gather_spy = mocker.spy(asyncio, "gather") - - validator_map = { - "$.mock-parent-key": [], - "$.mock-parent-key.child-one-key": [], - "$.mock-parent-key.child-two-key": [], - } - - value = { - "mock-parent-key": { - "child-one-key": "child-one-value", - "child-two-key": "child-two-value", - } - } - - iteration = Iteration( - call_id="mock-call", - index=0, - ) - - validated_value, validated_metadata = await avs.validate_children( - value=value.get("mock-parent-key"), - metadata={}, - validator_map=validator_map, - iteration=iteration, - abs_parent_path="$.mock-parent-key", - ref_parent_path="$.mock-parent-key", - ) - - assert gather_spy.call_count == 1 - - assert async_validate_mock.call_count == 2 - async_validate_mock.assert_any_call( - "child-one-value", - {}, - validator_map, - iteration, - "$.mock-parent-key.child-one-key", - "$.mock-parent-key.child-one-key", - stream=False, - ) - async_validate_mock.assert_any_call( - "child-two-value", - {}, - validator_map, - iteration, - "$.mock-parent-key.child-two-key", - "$.mock-parent-key.child-two-key", - stream=False, - ) - - assert validated_value == { - "child-one-key": "new-child-one-value", - "child-two-key": "new-child-two-value", - } - assert validated_metadata == {} - - -@pytest.mark.asyncio -async def test_run_validators(mocker): - group_validators_mock = mocker.patch.object(avs, "group_validators") - fix_validator_type = create_mock_validator("fix_validator", OnFailAction.FIX) - fix_validator = fix_validator_type() - noop_validator_type = create_mock_validator("noop_validator") - noop_validator_1 = noop_validator_type() - noop_validator_type = create_mock_validator("noop_validator") - noop_validator_2 = noop_validator_type() - noop_validator_2.run_in_separate_process = True - group_validators_mock.return_value = [ - (OnFailAction.FIX, [fix_validator]), - (OnFailAction.NOOP, [noop_validator_1, noop_validator_2]), - ] - - def mock_run_validator( - iteration, validator, value, metadata, property_path, stream - ): - return ValidatorLogs( - registered_name=validator.name, - validator_name=validator.name, - value_before_validation=value, - validation_result=PassResult(), - property_path=property_path, - ) - - run_validator_mock = mocker.patch.object( - avs, "run_validator", side_effect=mock_run_validator - ) - - mock_loop = MockLoop(True) - run_in_executor_spy = mocker.spy(mock_loop, "run_in_executor") - get_running_loop_mock = mocker.patch( - "asyncio.get_running_loop", return_value=mock_loop - ) - - async def mock_gather(*args): - return args - - asyancio_gather_mock = mocker.patch("asyncio.gather", side_effect=mock_gather) - - iteration = Iteration( - call_id="mock-call", - index=0, - ) - - value, metadata = await avs.run_validators( - iteration=iteration, - validator_map={}, - value=True, - metadata={}, - absolute_property_path="$", - reference_property_path="$", - ) - - assert get_running_loop_mock.call_count == 1 - - assert group_validators_mock.call_count == 1 - group_validators_mock.assert_called_once_with([]) - - assert run_in_executor_spy.call_count == 1 - run_in_executor_spy.assert_called_once_with( - avs.multiprocessing_executor, - run_validator_mock, - iteration, - noop_validator_2, - True, - {}, - "$", - False, - ) - - assert run_validator_mock.call_count == 3 - - assert 
asyancio_gather_mock.call_count == 1 - - assert value is True - assert metadata == {} - - -@pytest.mark.asyncio -async def test_run_validators_with_override(mocker): - group_validators_mock = mocker.patch.object(avs, "group_validators") - override_validator_type = create_mock_validator("override") - override_validator = override_validator_type() - override_validator.override_value_on_pass = True - - group_validators_mock.return_value = [("exception", [override_validator])] - - run_validator_mock = mocker.patch.object(avs, "run_validator") - run_validator_mock.return_value = ValidatorLogs( - registered_name="override", - validator_name="override", - value_before_validation="mock-value", - validation_result=PassResult(value_override="override"), - property_path="$", - ) - - mock_loop = MockLoop(True) - run_in_executor_spy = mocker.spy(mock_loop, "run_in_executor") - get_running_loop_mock = mocker.patch( - "asyncio.get_running_loop", return_value=mock_loop - ) - - asyancio_gather_mock = mocker.patch("asyncio.gather") - - iteration = Iteration( - call_id="mock-call", - index=0, - ) - - value, metadata = await avs.run_validators( - iteration=iteration, - validator_map={}, - value=True, - metadata={}, - absolute_property_path="$", - reference_property_path="$", - ) - - assert get_running_loop_mock.call_count == 1 - - assert group_validators_mock.call_count == 1 - group_validators_mock.assert_called_once_with([]) - - assert run_in_executor_spy.call_count == 0 - - assert run_validator_mock.call_count == 1 - - assert asyancio_gather_mock.call_count == 0 - - assert value == "override" - assert metadata == {} - - -# TODO -@pytest.mark.asyncio -async def test_run_validators_with_failures(mocker): - assert True is True diff --git a/tests/unit_tests/test_validator_service.py b/tests/unit_tests/test_validator_service.py deleted file mode 100644 index 36b723382..000000000 --- a/tests/unit_tests/test_validator_service.py +++ /dev/null @@ -1,97 +0,0 @@ -import pytest - -import guardrails.validator_service as vs -from guardrails.classes.history.iteration import Iteration - -from .mocks import MockAsyncValidatorService, MockLoop, MockSequentialValidatorService - - -iteration = Iteration( - call_id="mock-call", - index=0, -) - - -@pytest.mark.asyncio -async def test_async_validate(mocker): - mocker.patch( - "guardrails.validator_service.AsyncValidatorService", - new=MockAsyncValidatorService, - ) - validated_value, validated_metadata = await vs.async_validate( - value=True, - metadata={}, - validator_map={}, - iteration=iteration, - ) - - assert validated_value == "MockAsyncValidatorService.async_validate" - assert validated_metadata == {"async": True} - - -def test_validate_with_running_loop(mocker): - mockLoop = MockLoop(True) - mocker.patch( - "guardrails.validator_service.AsyncValidatorService", - new=MockAsyncValidatorService, - ) - mocker.patch( - "guardrails.validator_service.SequentialValidatorService", - new=MockSequentialValidatorService, - ) - mocker.patch("asyncio.get_event_loop", return_value=mockLoop) - - validated_value, validated_metadata = vs.validate( - value=True, - metadata={}, - validator_map={}, - iteration=iteration, - ) - - assert validated_value == "MockSequentialValidatorService.validate" - assert validated_metadata == {"sync": True} - - -def test_validate_without_running_loop(mocker): - mockLoop = MockLoop(False) - mocker.patch( - "guardrails.validator_service.AsyncValidatorService", - new=MockAsyncValidatorService, - ) - mocker.patch( - 
"guardrails.validator_service.SequentialValidatorService", - new=MockSequentialValidatorService, - ) - mocker.patch("asyncio.get_event_loop", return_value=mockLoop) - validated_value, validated_metadata = vs.validate( - value=True, - metadata={}, - validator_map={}, - iteration=iteration, - ) - - assert validated_value == "MockAsyncValidatorService.validate" - assert validated_metadata == {"sync": True} - - -def test_validate_loop_runtime_error(mocker): - mocker.patch( - "guardrails.validator_service.AsyncValidatorService", - new=MockAsyncValidatorService, - ) - mocker.patch( - "guardrails.validator_service.SequentialValidatorService", - new=MockSequentialValidatorService, - ) - # raise RuntimeError in `get_event_loop` - mocker.patch("asyncio.get_event_loop", side_effect=RuntimeError) - - validated_value, validated_metadata = vs.validate( - value=True, - metadata={}, - validator_map={}, - iteration=iteration, - ) - - assert validated_value == "MockSequentialValidatorService.validate" - assert validated_metadata == {"sync": True} diff --git a/tests/unit_tests/utils/test_serialization_utils.py b/tests/unit_tests/utils/test_serialization_utils.py new file mode 100644 index 000000000..948604dd2 --- /dev/null +++ b/tests/unit_tests/utils/test_serialization_utils.py @@ -0,0 +1,169 @@ +import pytest +from datetime import datetime +from guardrails.utils.serialization_utils import serialize, deserialize + + +class TestSerializeAndDeserialize: + def test_string(self): + data = "value" + + serialized_data = serialize(data) + assert serialized_data == '"value"' + + deserialized_data = deserialize(data, serialized_data) + assert deserialized_data == data + + def test_int(self): + data = 1 + + serialized_data = serialize(data) + assert serialized_data == "1" + + deserialized_data = deserialize(data, serialized_data) + assert deserialized_data == data + + def test_float(self): + data = 1.0 + + serialized_data = serialize(data) + assert serialized_data == "1.0" + + deserialized_data = deserialize(data, serialized_data) + assert deserialized_data == data + + def test_bool(self): + data = True + + serialized_data = serialize(data) + assert serialized_data == "true" + + deserialized_data = deserialize(data, serialized_data) + assert deserialized_data == data + + def test_datetime(self): + data = datetime(2024, 9, 10, 0, 0, 0) + + serialized_data = serialize(data) + assert serialized_data == '"2024-09-10T00:00:00"' + + deserialized_data = deserialize(data, serialized_data) + assert deserialized_data == data + + def test_dictionary(self): + data = {"key": "value"} + + serialized_data = serialize(data) + assert serialized_data == '{"key": "value"}' + + deserialized_data = deserialize(data, serialized_data) + assert deserialized_data == data + + def test_list(self): + data = ["value1", "value2"] + + serialized_data = serialize(data) + assert serialized_data == '["value1", "value2"]' + + deserialized_data = deserialize(data, serialized_data) + assert deserialized_data == data + + def test_simple_class(self): + class TestClass: + def __init__(self, key: str): + self.key = key + + data = TestClass("value") + + serialized_data = serialize(data) + assert serialized_data == '{"key": "value"}' + + deserialized_data = deserialize(data, serialized_data) + assert deserialized_data.key == data.key + + def test_nested_classes_not_supported(self): + class TestClass: + def __init__(self, value: str): + self.value = value + + class TestClass2: + def __init__(self, value: TestClass): + self.value = value + + data = 
TestClass2(TestClass("value")) + + serialized_data = serialize(data) + assert serialized_data == '{"value": {"value": "value"}}' + + deserialized_data = deserialize(data, serialized_data) + with pytest.raises(AttributeError) as excinfo: + assert deserialized_data.value.value == data.value.value + + assert str(excinfo.value) == "'dict' object has no attribute 'value'" + + def test_simple_dataclass(self): + from dataclasses import dataclass + + @dataclass + class TestClass: + key: str + + data = TestClass("value") + + serialized_data = serialize(data) + assert serialized_data == '{"key": "value"}' + + deserialized_data = deserialize(data, serialized_data) + assert deserialized_data.key == data.key + + def test_nested_dataclasses_not_supported(self): + from dataclasses import dataclass + + @dataclass + class TestClass: + value: str + + @dataclass + class TestClass2: + value: TestClass + + data = TestClass2(TestClass("value")) + + serialized_data = serialize(data) + assert serialized_data == '{"value": {"value": "value"}}' + + deserialized_data = deserialize(data, serialized_data) + with pytest.raises(AttributeError) as excinfo: + assert deserialized_data.value.value == data.value.value + + assert str(excinfo.value) == "'dict' object has no attribute 'value'" + + def test_simple_pydantic_model(self): + from pydantic import BaseModel + + class TestClass(BaseModel): + key: str + + data = TestClass(key="value") + + serialized_data = serialize(data) + assert serialized_data == '{"key": "value"}' + + deserialized_data = deserialize(data, serialized_data) + assert deserialized_data.key == data.key + + def test_nested_pydantic_models(self): + from pydantic import BaseModel + + class TestClass(BaseModel): + value: str + + class TestClass2(BaseModel): + value: TestClass + + data = TestClass2(value=TestClass(value="value")) + + serialized_data = serialize(data) + assert serialized_data == '{"value": {"value": "value"}}' + + deserialized_data = deserialize(data, serialized_data) + assert deserialized_data.value.value == data.value.value diff --git a/tests/unit_tests/validator_service/test_async_validator_service.py b/tests/unit_tests/validator_service/test_async_validator_service.py new file mode 100644 index 000000000..fe6ef5e89 --- /dev/null +++ b/tests/unit_tests/validator_service/test_async_validator_service.py @@ -0,0 +1,798 @@ +from datetime import datetime +from unittest.mock import MagicMock, call + +from guardrails.actions.filter import Filter +from guardrails.validator_service.validator_service_base import ValidatorRun +import pytest + +from guardrails.classes.history.iteration import Iteration +from guardrails.classes.validation.validator_logs import ValidatorLogs +from guardrails.validator_base import OnFailAction, Validator +from guardrails.validator_service.async_validator_service import AsyncValidatorService +from guardrails.classes.validation.validation_result import FailResult, PassResult + + +avs = AsyncValidatorService() + + +def test_validate(mocker): + mock_loop = mocker.MagicMock() + mock_loop.run_until_complete = mocker.MagicMock(return_value=(True, {})) + # loop_spy = mocker.spy(mock_loop, "run_until_complete", return_value=(True, {})) + async_validate_mock = mocker.patch.object(avs, "async_validate") + + iteration = Iteration( + call_id="mock-call", + index=0, + ) + + avs.validate( + value=True, + metadata={}, + validator_map={}, + iteration=iteration, + absolute_path="$", + reference_path="$", + loop=mock_loop, + ) + + assert mock_loop.run_until_complete.call_count == 1 + 
async_validate_mock.assert_called_once_with( + True, {}, {}, iteration, "$", "$", stream=False + ) + + +class TestAsyncValidate: + @pytest.mark.asyncio + async def test_with_dictionary(self, mocker): + validate_children_mock = mocker.patch.object(avs, "validate_children") + + run_validators_mock = mocker.patch.object( + avs, "run_validators", return_value=("run_validators_mock", {"async": True}) + ) + + value = {"a": 1} + + iteration = Iteration( + call_id="mock-call", + index=0, + ) + + validated_value, validated_metadata = await avs.async_validate( + value=value, + metadata={}, + validator_map={}, + iteration=iteration, + absolute_path="$", + reference_path="$", + ) + + assert validate_children_mock.call_count == 1 + validate_children_mock.assert_called_once_with( + value, {}, {}, iteration, "$", "$", stream=False + ) + + assert run_validators_mock.call_count == 1 + run_validators_mock.assert_called_once_with( + iteration, {}, value, {}, "$", "$", stream=False + ) + + assert validated_value == "run_validators_mock" + assert validated_metadata == {"async": True} + + @pytest.mark.asyncio + async def test_with_list(self, mocker): + validate_children_mock = mocker.patch.object(avs, "validate_children") + + run_validators_mock = mocker.patch.object( + avs, "run_validators", return_value=("run_validators_mock", {"async": True}) + ) + + value = ["a"] + + iteration = Iteration( + call_id="mock-call", + index=0, + ) + + validated_value, validated_metadata = await avs.async_validate( + value=value, + metadata={}, + validator_map={}, + iteration=iteration, + absolute_path="$", + reference_path="$", + ) + + assert validate_children_mock.call_count == 1 + validate_children_mock.assert_called_once_with( + value, {}, {}, iteration, "$", "$", stream=False + ) + + assert run_validators_mock.call_count == 1 + run_validators_mock.assert_called_once_with( + iteration, {}, value, {}, "$", "$", stream=False + ) + + assert validated_value == "run_validators_mock" + assert validated_metadata == {"async": True} + + @pytest.mark.asyncio + async def test_without_children(self, mocker): + validate_children_mock = mocker.patch.object(avs, "validate_children") + + run_validators_mock = mocker.patch.object(avs, "run_validators") + run_validators_mock.return_value = ("run_validators_mock", {"async": True}) + + iteration = Iteration( + call_id="mock-call", + index=0, + ) + + validated_value, validated_metadata = await avs.async_validate( + value="Hello world!", + metadata={}, + validator_map={}, + iteration=iteration, + absolute_path="$", + reference_path="$", + ) + + assert validate_children_mock.call_count == 0 + + assert run_validators_mock.call_count == 1 + run_validators_mock.assert_called_once_with( + iteration, {}, "Hello world!", {}, "$", "$", stream=False + ) + + assert validated_value == "run_validators_mock" + assert validated_metadata == {"async": True} + + +class TestValidateChildren: + @pytest.mark.asyncio + async def test_with_list(self, mocker): + mock_async_validate = mocker.patch.object( + avs, + "async_validate", + side_effect=[ + ( + "mock-child-1-value", + { + "mock-child-1-metadata": "child-1-metadata", + "mock-shared-metadata": "shared-metadata-1", + }, + ), + ( + "mock-child-2-value", + { + "mock-child-2-metadata": "child-2-metadata", + "mock-shared-metadata": "shared-metadata-2", + }, + ), + ], + ) + + iteration = Iteration( + call_id="mock-call", + index=0, + ) + + validator_map = ({"$.*": [MagicMock(spec=Validator)]},) + value, metadata = await avs.validate_children( + value=["mock-child-1", 
"mock-child-2"], + metadata={"mock-shared-metadata": "shared-metadata"}, + validator_map=validator_map, + iteration=iteration, + abs_parent_path="$", + ref_parent_path="$", + ) + + assert mock_async_validate.call_count == 2 + mock_async_validate.assert_has_calls( + [ + call( + "mock-child-1", + { + "mock-shared-metadata": "shared-metadata", + }, + validator_map, + iteration, + "$.0", + "$.*", + stream=False, + ), + call( + "mock-child-2", + { + "mock-shared-metadata": "shared-metadata", + }, + validator_map, + iteration, + "$.1", + "$.*", + stream=False, + ), + ] + ) + + assert value == ["mock-child-1-value", "mock-child-2-value"] + assert metadata == { + "mock-child-1-metadata": "child-1-metadata", + "mock-child-2-metadata": "child-2-metadata", + # NOTE: This is overriden based on who finishes last + "mock-shared-metadata": "shared-metadata-2", + } + + @pytest.mark.asyncio + async def test_with_dictionary(self, mocker): + mock_async_validate = mocker.patch.object( + avs, + "async_validate", + side_effect=[ + ( + "mock-child-1-value", + { + "mock-child-1-metadata": "child-1-metadata", + "mock-shared-metadata": "shared-metadata-1", + }, + ), + ( + "mock-child-2-value", + { + "mock-child-2-metadata": "child-2-metadata", + "mock-shared-metadata": "shared-metadata-2", + }, + ), + ], + ) + + iteration = Iteration( + call_id="mock-call", + index=0, + ) + + validator_map = ( + { + "$.child-1": [MagicMock(spec=Validator)], + "$.child-2": [MagicMock(spec=Validator)], + }, + ) + value, metadata = await avs.validate_children( + value={"child-1": "mock-child-1", "child-2": "mock-child-2"}, + metadata={"mock-shared-metadata": "shared-metadata"}, + validator_map=validator_map, + iteration=iteration, + abs_parent_path="$", + ref_parent_path="$", + ) + + assert mock_async_validate.call_count == 2 + mock_async_validate.assert_has_calls( + [ + call( + "mock-child-1", + { + "mock-shared-metadata": "shared-metadata", + }, + validator_map, + iteration, + "$.child-1", + "$.child-1", + stream=False, + ), + call( + "mock-child-2", + { + "mock-shared-metadata": "shared-metadata", + }, + validator_map, + iteration, + "$.child-2", + "$.child-2", + stream=False, + ), + ] + ) + + assert value == { + "child-1": "mock-child-1-value", + "child-2": "mock-child-2-value", + } + assert metadata == { + "mock-child-1-metadata": "child-1-metadata", + "mock-child-2-metadata": "child-2-metadata", + # NOTE: This is overriden based on who finishes last + "mock-shared-metadata": "shared-metadata-2", + } + + +class TestRunValidators: + @pytest.mark.asyncio + async def test_filter_exits_early(self, mocker): + mock_run_validator = mocker.patch.object( + avs, + "run_validator", + side_effect=[ + ValidatorRun( + value="mock-value", + metadata={}, + validator_logs=ValidatorLogs( + registered_name="noop_validator", + validator_name="noop_validator", + value_before_validation="mock-value", + validation_result=PassResult(), + property_path="$", + ), + ), + ValidatorRun( + value=Filter(), + metadata={}, + validator_logs=ValidatorLogs( + registered_name="filter_validator", + validator_name="filter_validator", + value_before_validation="mock-value", + validation_result=FailResult(error_message="mock-error"), + property_path="$", + ), + ), + ], + ) + mock_merge_results = mocker.patch.object(avs, "merge_results") + + iteration = Iteration( + call_id="mock-call", + index=0, + ) + + value, metadata = await avs.run_validators( + iteration=iteration, + validator_map={ + "$": [ + MagicMock(spec=Validator), + MagicMock(spec=Validator), + ] + }, + 
value=True, + metadata={}, + absolute_property_path="$", + reference_property_path="$", + ) + + assert mock_run_validator.call_count == 2 + assert mock_merge_results.call_count == 0 + + assert isinstance(value, Filter) + assert metadata == {} + + @pytest.mark.asyncio + async def test_calls_merge(self, mocker): + mock_run_validator = mocker.patch.object( + avs, + "run_validator", + side_effect=[ + ValidatorRun( + value="mock-value", + metadata={}, + validator_logs=ValidatorLogs( + registered_name="noop_validator", + validator_name="noop_validator", + value_before_validation="mock-value", + validation_result=PassResult(), + property_path="$", + ), + ), + ValidatorRun( + value="mock-fix-value", + metadata={}, + validator_logs=ValidatorLogs( + registered_name="fix_validator", + validator_name="fix_validator", + value_before_validation="mock-value", + validation_result=FailResult( + error_message="mock-error", fix_value="mock-fix-value" + ), + property_path="$", + ), + ), + ], + ) + mock_merge_results = mocker.patch.object( + avs, "merge_results", return_value="mock-fix-value" + ) + + iteration = Iteration( + call_id="mock-call", + index=0, + ) + + value, metadata = await avs.run_validators( + iteration=iteration, + validator_map={ + "$": [ + MagicMock(spec=Validator), + MagicMock(spec=Validator), + ] + }, + value=True, + metadata={}, + absolute_property_path="$", + reference_property_path="$", + ) + + assert mock_run_validator.call_count == 2 + assert mock_merge_results.call_count == 1 + + assert value == "mock-fix-value" + assert metadata == {} + + @pytest.mark.asyncio + async def test_returns_value_if_no_results(self, mocker): + mock_run_validator = mocker.patch.object(avs, "run_validator") + mock_merge_results = mocker.patch.object(avs, "merge_results") + + iteration = Iteration( + call_id="mock-call", + index=0, + ) + + value, metadata = await avs.run_validators( + iteration=iteration, + validator_map={}, + value=True, + metadata={}, + absolute_property_path="$", + reference_property_path="$", + ) + + assert mock_run_validator.call_count == 0 + assert mock_merge_results.call_count == 0 + + assert value is True + assert metadata == {} + + +class TestRunValidator: + @pytest.mark.asyncio + async def test_pass_result(self, mocker): + validator_logs = ValidatorLogs( + validator_name="mock-validator", + registered_name="mock-validator", + instance_id=1234, + property_path="$", + value_before_validation="value", + start_time=datetime(2024, 9, 10, 9, 54, 0, 38391), + value_after_validation="value", + ) + mock_before_run_validator = mocker.patch.object( + avs, "before_run_validator", return_value=validator_logs + ) + + validation_result = PassResult() + mock_run_validator_async = mocker.patch.object( + avs, "run_validator_async", return_value=validation_result + ) + + mock_after_run_validator = mocker.patch.object( + avs, "after_run_validator", return_value=validator_logs + ) + + iteration = Iteration( + call_id="mock-call", + index=0, + ) + validator = MagicMock(spec=Validator) + + result = await avs.run_validator( + iteration=iteration, + validator=validator, + value="value", + metadata={}, + absolute_property_path="$", + ) + + assert mock_before_run_validator.call_count == 1 + mock_before_run_validator.assert_called_once_with( + iteration, validator, "value", "$" + ) + + assert mock_run_validator_async.call_count == 1 + mock_run_validator_async.assert_called_once_with( + validator, "value", {}, False, validation_session_id=iteration.id + ) + + assert mock_after_run_validator.call_count == 1 + 
mock_after_run_validator.assert_called_once_with( + validator, validator_logs, validation_result + ) + + assert isinstance(result, ValidatorRun) + assert result.value == "value" + assert result.metadata == {} + assert result.validator_logs == validator_logs + + @pytest.mark.asyncio + async def test_pass_result_with_override(self, mocker): + validator_logs = ValidatorLogs( + validator_name="mock-validator", + registered_name="mock-validator", + instance_id=1234, + property_path="$", + value_before_validation="value", + start_time=datetime(2024, 9, 10, 9, 54, 0, 38391), + value_after_validation="value", + ) + mock_before_run_validator = mocker.patch.object( + avs, "before_run_validator", return_value=validator_logs + ) + + validation_result = PassResult(value_override="override") + mock_run_validator_async = mocker.patch.object( + avs, "run_validator_async", return_value=validation_result + ) + + mock_after_run_validator = mocker.patch.object( + avs, "after_run_validator", return_value=validator_logs + ) + + iteration = Iteration( + call_id="mock-call", + index=0, + ) + validator = MagicMock(spec=Validator) + + result = await avs.run_validator( + iteration=iteration, + validator=validator, + value="value", + metadata={}, + absolute_property_path="$", + ) + + assert mock_before_run_validator.call_count == 1 + mock_before_run_validator.assert_called_once_with( + iteration, validator, "value", "$" + ) + + assert mock_run_validator_async.call_count == 1 + mock_run_validator_async.assert_called_once_with( + validator, "value", {}, False, validation_session_id=iteration.id + ) + + assert mock_after_run_validator.call_count == 1 + mock_after_run_validator.assert_called_once_with( + validator, validator_logs, validation_result + ) + + assert isinstance(result, ValidatorRun) + assert result.value == "override" + assert result.metadata == {} + assert result.validator_logs == validator_logs + + @pytest.mark.asyncio + async def test_fail_result(self, mocker): + validator_logs = ValidatorLogs( + validator_name="mock-validator", + registered_name="mock-validator", + instance_id=1234, + property_path="$", + value_before_validation="value", + start_time=datetime(2024, 9, 10, 9, 54, 0, 38391), + value_after_validation="value", + ) + mock_before_run_validator = mocker.patch.object( + avs, "before_run_validator", return_value=validator_logs + ) + + validation_result = FailResult(error_message="mock-error") + mock_run_validator_async = mocker.patch.object( + avs, "run_validator_async", return_value=validation_result + ) + + mock_after_run_validator = mocker.patch.object( + avs, "after_run_validator", return_value=validator_logs + ) + + mock_perform_correction = mocker.patch.object( + avs, "perform_correction", return_value="corrected-value" + ) + + iteration = Iteration( + call_id="mock-call", + index=0, + ) + validator = MagicMock(spec=Validator) + validator.on_fail_descriptor = "noop" + + result = await avs.run_validator( + iteration=iteration, + validator=validator, + value="value", + metadata={}, + absolute_property_path="$", + ) + + assert mock_before_run_validator.call_count == 1 + mock_before_run_validator.assert_called_once_with( + iteration, validator, "value", "$" + ) + + assert mock_run_validator_async.call_count == 1 + mock_run_validator_async.assert_called_once_with( + validator, "value", {}, False, validation_session_id=iteration.id + ) + + assert mock_after_run_validator.call_count == 1 + mock_after_run_validator.assert_called_once_with( + validator, validator_logs, validation_result + ) + + 
assert mock_perform_correction.call_count == 1 + mock_perform_correction.assert_called_once_with( + validation_result, "value", validator, rechecked_value=None + ) + + assert isinstance(result, ValidatorRun) + assert result.value == "corrected-value" + assert result.metadata == {} + assert result.validator_logs == validator_logs + + @pytest.mark.asyncio + async def test_fail_result_with_fix_reask(self, mocker): + validator_logs = ValidatorLogs( + validator_name="mock-validator", + registered_name="mock-validator", + instance_id=1234, + property_path="$", + value_before_validation="value", + start_time=datetime(2024, 9, 10, 9, 54, 0, 38391), + value_after_validation="value", + ) + mock_before_run_validator = mocker.patch.object( + avs, "before_run_validator", return_value=validator_logs + ) + + validation_result = FailResult( + error_message="mock-error", fix_value="fixed-value" + ) + rechecked_result = PassResult() + mock_run_validator_async = mocker.patch.object( + avs, + "run_validator_async", + side_effect=[validation_result, rechecked_result], + ) + + mock_after_run_validator = mocker.patch.object( + avs, "after_run_validator", return_value=validator_logs + ) + + mock_perform_correction = mocker.patch.object( + avs, "perform_correction", return_value="fixed-value" + ) + + iteration = Iteration( + call_id="mock-call", + index=0, + ) + validator = MagicMock(spec=Validator) + validator.on_fail_descriptor = OnFailAction.FIX_REASK + + result = await avs.run_validator( + iteration=iteration, + validator=validator, + value="value", + metadata={}, + absolute_property_path="$", + ) + + assert mock_before_run_validator.call_count == 1 + mock_before_run_validator.assert_called_once_with( + iteration, validator, "value", "$" + ) + + assert mock_run_validator_async.call_count == 2 + mock_run_validator_async.assert_has_calls( + [ + call(validator, "value", {}, False, validation_session_id=iteration.id), + call( + validator, + "fixed-value", + {}, + False, + validation_session_id=iteration.id, + ), + ] + ) + + assert mock_after_run_validator.call_count == 1 + mock_after_run_validator.assert_called_once_with( + validator, validator_logs, validation_result + ) + + assert mock_perform_correction.call_count == 1 + mock_perform_correction.assert_called_once_with( + validation_result, "value", validator, rechecked_value=rechecked_result + ) + + assert isinstance(result, ValidatorRun) + assert result.value == "fixed-value" + assert result.metadata == {} + assert result.validator_logs == validator_logs + + +class TestRunValidatorAsync: + @pytest.mark.asyncio + async def test_happy_path(self, mocker): + mock_validator = MagicMock(spec=Validator) + + validation_result = PassResult() + mock_execute_validator = mocker.patch.object( + avs, "execute_validator", return_value=validation_result + ) + + result = await avs.run_validator_async( + validator=mock_validator, + value="value", + metadata={}, + stream=False, + validation_session_id="mock-session", + ) + + assert result == validation_result + + assert mock_execute_validator.call_count == 1 + mock_execute_validator.assert_called_once_with( + mock_validator, "value", {}, False, validation_session_id="mock-session" + ) + + @pytest.mark.asyncio + async def test_result_is_a_coroutine(self, mocker): + mock_validator = MagicMock(spec=Validator) + + validation_result = PassResult() + + async def result_coroutine(): + return validation_result + + mock_execute_validator = mocker.patch.object( + avs, "execute_validator", return_value=result_coroutine() + ) + + result = 
await avs.run_validator_async( + validator=mock_validator, + value="value", + metadata={}, + stream=False, + validation_session_id="mock-session", + ) + + assert result == validation_result + + assert mock_execute_validator.call_count == 1 + mock_execute_validator.assert_called_once_with( + mock_validator, "value", {}, False, validation_session_id="mock-session" + ) + + @pytest.mark.asyncio + async def test_result_is_none(self, mocker): + mock_validator = MagicMock(spec=Validator) + + validation_result = None + mock_execute_validator = mocker.patch.object( + avs, "execute_validator", return_value=validation_result + ) + + result = await avs.run_validator_async( + validator=mock_validator, + value="value", + metadata={}, + stream=False, + validation_session_id="mock-session", + ) + + assert isinstance(result, PassResult) + + assert mock_execute_validator.call_count == 1 + mock_execute_validator.assert_called_once_with( + mock_validator, "value", {}, False, validation_session_id="mock-session" + ) diff --git a/tests/unit_tests/validator_service/test_validator_service.py b/tests/unit_tests/validator_service/test_validator_service.py new file mode 100644 index 000000000..4514ad22c --- /dev/null +++ b/tests/unit_tests/validator_service/test_validator_service.py @@ -0,0 +1,152 @@ +from unittest.mock import AsyncMock +import pytest + +import guardrails.validator_service as vs +from guardrails.classes.history.iteration import Iteration + + +iteration = Iteration( + call_id="mock-call", + index=0, +) + + +class TestShouldRunSync: + def test_process_count_of_1(self, mocker): + mocker.patch("os.environ.get", side_effect=["1", "false"]) + assert vs.should_run_sync() is True + + def test_run_sync_set_to_true(self, mocker): + mocker.patch("os.environ.get", side_effect=["10", "True"]) + assert vs.should_run_sync() is True + + def test_should_run_sync_default(self, mocker): + mocker.patch("os.environ.get", side_effect=["10", "false"]) + assert vs.should_run_sync() is False + + +class TestGetLoop: + def test_get_loop_with_running_loop(self, mocker): + mocker.patch("asyncio.get_running_loop", return_value="running loop") + with pytest.raises(RuntimeError): + vs.get_loop() + + def test_get_loop_without_running_loop(self, mocker): + mocker.patch("asyncio.get_running_loop", side_effect=RuntimeError) + mocker.patch("asyncio.get_event_loop", return_value="event loop") + assert vs.get_loop() == "event loop" + + def test_get_loop_with_uvloop(self, mocker): + mocker.patch("guardrails.validator_service.uvloop") + mock_event_loop_policy = mocker.patch( + "guardrails.validator_service.uvloop.EventLoopPolicy" + ) + mocker.patch("asyncio.get_running_loop", side_effect=RuntimeError) + mocker.patch("asyncio.get_event_loop", return_value="event loop") + mock_set_event_loop_policy = mocker.patch("asyncio.set_event_loop_policy") + + assert vs.get_loop() == "event loop" + + mock_event_loop_policy.assert_called_once() + mock_set_event_loop_policy.assert_called_once_with( + mock_event_loop_policy.return_value + ) + + +class TestValidate: + def test_validate_with_sync(self, mocker): + mocker.patch("guardrails.validator_service.should_run_sync", return_value=True) + mocker.patch("guardrails.validator_service.SequentialValidatorService") + mocker.patch("guardrails.validator_service.AsyncValidatorService") + mocker.patch("guardrails.validator_service.get_loop") + mocker.patch("guardrails.validator_service.warnings") + + vs.validate( + value=True, + metadata={}, + validator_map={}, + iteration=iteration, + ) + + 
vs.SequentialValidatorService.assert_called_once_with(True) + vs.SequentialValidatorService.return_value.validate.assert_called_once_with( + True, + {}, + {}, + iteration, + "$", + "$", + loop=None, + ) + + def test_validate_with_async(self, mocker): + mocker.patch("guardrails.validator_service.should_run_sync", return_value=False) + mocker.patch("guardrails.validator_service.SequentialValidatorService") + mocker.patch("guardrails.validator_service.AsyncValidatorService") + mocker.patch("guardrails.validator_service.get_loop", return_value="event loop") + mocker.patch("guardrails.validator_service.warnings") + + vs.validate( + value=True, + metadata={}, + validator_map={}, + iteration=iteration, + ) + + vs.AsyncValidatorService.assert_called_once_with(True) + vs.AsyncValidatorService.return_value.validate.assert_called_once_with( + True, + {}, + {}, + iteration, + "$", + "$", + loop="event loop", + ) + + def test_validate_with_no_available_event_loop(self, mocker): + mocker.patch("guardrails.validator_service.should_run_sync", return_value=False) + mocker.patch("guardrails.validator_service.SequentialValidatorService") + mocker.patch("guardrails.validator_service.AsyncValidatorService") + mocker.patch("guardrails.validator_service.get_loop", side_effect=RuntimeError) + mock_warn = mocker.patch("guardrails.validator_service.warnings.warn") + + vs.validate( + value=True, + metadata={}, + validator_map={}, + iteration=iteration, + ) + + mock_warn.assert_called_once_with( + "Could not obtain an event loop. Falling back to synchronous validation." + ) + + vs.SequentialValidatorService.assert_called_once_with(True) + vs.SequentialValidatorService.return_value.validate.assert_called_once_with( + True, + {}, + {}, + iteration, + "$", + "$", + loop=None, + ) + + +@pytest.mark.asyncio +async def test_async_validate(mocker): + mocker.patch( + "guardrails.validator_service.AsyncValidatorService", return_value=AsyncMock() + ) + await vs.async_validate( + value=True, + metadata={}, + validator_map={}, + iteration=iteration, + ) + + vs.AsyncValidatorService.assert_called_once_with(True) + vs.AsyncValidatorService.return_value.async_validate.assert_called_once_with( + True, {}, {}, iteration, "$", "$", False + ) From 1f9f206130a4df82645549fcc4b6114ea7631446 Mon Sep 17 00:00:00 2001 From: Caleb Courier Date: Tue, 10 Sep 2024 15:48:38 -0500 Subject: [PATCH 02/13] error to string for warning, don't int a None, avoid mock attr infinite loop --- guardrails/utils/serialization_utils.py | 4 ++-- guardrails/validator_service/__init__.py | 3 ++- tests/conftest.py | 4 ++++ 3 files changed, 8 insertions(+), 3 deletions(-) diff --git a/guardrails/utils/serialization_utils.py b/guardrails/utils/serialization_utils.py index 18e435374..d124a9069 100644 --- a/guardrails/utils/serialization_utils.py +++ b/guardrails/utils/serialization_utils.py @@ -11,7 +11,7 @@ def serialize(val: Any) -> Optional[str]: try: return json.dumps(val, cls=DefaultJSONEncoder) except Exception as e: - warnings.warn(e) + warnings.warn(str(e)) return None @@ -37,5 +37,5 @@ def deserialize(original: Optional[Any], serialized: Optional[str]) -> Any: return original.__class__(loaded_val) return loaded_val except Exception as e: - warnings.warn(e) + warnings.warn(str(e)) return None diff --git a/guardrails/validator_service/__init__.py b/guardrails/validator_service/__init__.py index 416975294..341168063 100644 --- a/guardrails/validator_service/__init__.py +++ b/guardrails/validator_service/__init__.py @@ -25,7 +25,7 @@ def should_run_sync(): - 
process_count = int(os.environ.get("GUARDRAILS_PROCESS_COUNT")) + process_count = os.environ.get("GUARDRAILS_PROCESS_COUNT") if process_count is not None: warnings.warn( "GUARDRAILS_PROCESS_COUNT is deprecated" @@ -33,6 +33,7 @@ def should_run_sync(): " To force synchronous validation, please use GUARDRAILS_RUN_SYNC instead.", DeprecationWarning, ) + process_count = int(process_count) run_sync = os.environ.get("GUARDRAILS_RUN_SYNC", "false") bool_values = ["true", "false"] if run_sync.lower() not in bool_values: diff --git a/tests/conftest.py b/tests/conftest.py index c8d13ee79..5db54f65b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -24,6 +24,7 @@ def mock_span(): def mock_guard_hub_telemetry(): with patch("guardrails.guard.HubTelemetry") as MockHubTelemetry: MockHubTelemetry.return_value = MagicMock() + MockHubTelemetry.return_value.to_dict = None yield MockHubTelemetry @@ -31,6 +32,7 @@ def mock_guard_hub_telemetry(): def mock_validator_base_hub_telemetry(): with patch("guardrails.validator_base.HubTelemetry") as MockHubTelemetry: MockHubTelemetry.return_value = MagicMock() + MockHubTelemetry.return_value.to_dict = None yield MockHubTelemetry @@ -40,6 +42,7 @@ def mock_validator_service_hub_telemetry(): "guardrails.validator_service.validator_service_base.HubTelemetry" ) as MockHubTelemetry: MockHubTelemetry.return_value = MagicMock() + MockHubTelemetry.return_value.to_dict = None yield MockHubTelemetry @@ -47,6 +50,7 @@ def mock_validator_service_hub_telemetry(): def mock_runner_hub_telemetry(): with patch("guardrails.run.runner.HubTelemetry") as MockHubTelemetry: MockHubTelemetry.return_value = MagicMock() + MockHubTelemetry.return_value.to_dict = None yield MockHubTelemetry From a37b632acb47383479ab82ca08ead1a51787100a Mon Sep 17 00:00:00 2001 From: Caleb Courier Date: Wed, 11 Sep 2024 13:09:39 -0500 Subject: [PATCH 03/13] fix existing tests, consider custom fixes, reask with all failures --- guardrails/validator_base.py | 7 +- .../async_validator_service.py | 32 +++- .../sequential_validator_service.py | 12 +- .../validator_service_base.py | 4 +- .../validator_parallelism_prompt_2.txt | 1 + .../validator_parallelism_reask_1.py | 5 + tests/unit_tests/test_async_guard.py | 107 ++++++++----- tests/unit_tests/test_guard.py | 100 +++++++----- tests/unit_tests/test_validator_base.py | 150 ++++++++++++++++++ .../test_async_validator_service.py | 6 + 10 files changed, 334 insertions(+), 90 deletions(-) diff --git a/guardrails/validator_base.py b/guardrails/validator_base.py index ce7484a18..a70bea84c 100644 --- a/guardrails/validator_base.py +++ b/guardrails/validator_base.py @@ -124,6 +124,7 @@ def __init__( ) self.on_fail_method = None else: + self.on_fail_descriptor = OnFailAction.CUSTOM self.on_fail_method = on_fail # Store the kwargs for the validator. 
@@ -336,11 +337,13 @@ def get_args(self): def __call__(self, value): result = self.validate(value, {}) if isinstance(result, FailResult): - from guardrails.validator_service import ValidatorServiceBase + from guardrails.validator_service.validator_service_base import ( + ValidatorServiceBase, + ) validator_service = ValidatorServiceBase() return validator_service.perform_correction( - [result], value, self, self.on_fail_descriptor + result, value, self, self.on_fail_descriptor ) return value diff --git a/guardrails/validator_service/async_validator_service.py b/guardrails/validator_service/async_validator_service.py index 0a13c5e99..fde57b0d8 100644 --- a/guardrails/validator_service/async_validator_service.py +++ b/guardrails/validator_service/async_validator_service.py @@ -106,6 +106,7 @@ async def run_validator( return ValidatorRun( value=value, metadata=metadata, + on_fail_action=validator.on_fail_descriptor, validator_logs=validator_logs, ) @@ -136,17 +137,40 @@ async def run_validators( coroutines.append(coroutine) results = await asyncio.gather(*coroutines) + reasks: List[FieldReAsk] = [] for res in results: validators_logs.extend(res.validator_logs) # QUESTION: Do we still want to do this here or handle it during the merge? # return early if we have a filter, refrain, or reask - if isinstance(res.value, (Filter, Refrain, FieldReAsk)): + if isinstance(res.value, (Filter, Refrain)): return res.value, metadata + elif isinstance(res.value, FieldReAsk): + reasks.append(res.value) + + # handle reasks + if len(reasks) > 0: + first_reask = reasks[0] + fail_results = [] + for reask in reasks: + fail_results.extend(reask.fail_results) + first_reask.fail_results = fail_results + return first_reask, metadata # merge the results - if len(results) > 0: - values = [res.value for res in results] - value = self.merge_results(value, values) + fix_values = [ + res.value + for res in results + if ( + isinstance(res.validator_logs.validation_result, FailResult) + and ( + res.on_fail_action == OnFailAction.FIX + or res.on_fail_action == OnFailAction.FIX_REASK + or res.on_fail_action == OnFailAction.CUSTOM + ) + ) + ] + if len(fix_values) > 0: + value = self.merge_results(value, fix_values) return value, metadata diff --git a/guardrails/validator_service/sequential_validator_service.py b/guardrails/validator_service/sequential_validator_service.py index 3c5893d1c..5be48461b 100644 --- a/guardrails/validator_service/sequential_validator_service.py +++ b/guardrails/validator_service/sequential_validator_service.py @@ -171,10 +171,9 @@ def run_validators_stream_fix( break rechecked_value = None chunk = self.perform_correction( - [result], + result, chunk, validator, - validator.on_fail_descriptor, rechecked_value=rechecked_value, ) fixed_values.append(chunk) @@ -239,10 +238,9 @@ def run_validators_stream_fix( if isinstance(result, FailResult): rechecked_value = None last_chunk = self.perform_correction( - [result], + result, last_chunk, validator, - validator.on_fail_descriptor, rechecked_value=rechecked_value, ) validator_partial_acc[id(validator)] += last_chunk # type: ignore @@ -303,10 +301,9 @@ def run_validators_stream_noop( if isinstance(result, FailResult): rechecked_value = None chunk = self.perform_correction( - [result], + result, chunk, validator, - validator.on_fail_descriptor, rechecked_value=rechecked_value, ) elif isinstance(result, PassResult): @@ -391,10 +388,9 @@ def run_validators( **kwargs, ) value = self.perform_correction( - [result], + result, value, validator, - 
validator.on_fail_descriptor, rechecked_value=rechecked_value, ) elif isinstance(result, PassResult): diff --git a/guardrails/validator_service/validator_service_base.py b/guardrails/validator_service/validator_service_base.py index bff57687c..c5122855a 100644 --- a/guardrails/validator_service/validator_service_base.py +++ b/guardrails/validator_service/validator_service_base.py @@ -27,6 +27,7 @@ class ValidatorRun: value: Any metadata: Dict + on_fail_action: Union[str, OnFailAction] validator_logs: ValidatorLogs @@ -89,7 +90,7 @@ def perform_correction( ) return fixed_value - if on_fail_descriptor == "custom": + if on_fail_descriptor == OnFailAction.CUSTOM: if validator.on_fail_method is None: raise ValueError("on_fail is 'custom' but on_fail_method is None") return validator.on_fail_method(value, [result]) @@ -197,6 +198,7 @@ def merge_results(self, original_value: Any, new_values: list[Any]) -> Any: current = merge( serialize(current), serialize(nextval), serialize(original_value) ) + current = deserialize(original_value, current) deserialized_value = deserialize(original_value, current) if deserialized_value is None and current is not None: # QUESTION: How do we escape hatch diff --git a/tests/integration_tests/test_assets/python_rail/validator_parallelism_prompt_2.txt b/tests/integration_tests/test_assets/python_rail/validator_parallelism_prompt_2.txt index a91f6ecbf..63fb66348 100644 --- a/tests/integration_tests/test_assets/python_rail/validator_parallelism_prompt_2.txt +++ b/tests/integration_tests/test_assets/python_rail/validator_parallelism_prompt_2.txt @@ -9,6 +9,7 @@ Generate a new response that corrects your old response such that the following - must be exactly two words - Value Hello a you and me is not lower case. +- Value has length greater than 10. Please return a shorter output, that is shorter than 10 characters. diff --git a/tests/integration_tests/test_assets/python_rail/validator_parallelism_reask_1.py b/tests/integration_tests/test_assets/python_rail/validator_parallelism_reask_1.py index 5c17e7188..463fe95c7 100644 --- a/tests/integration_tests/test_assets/python_rail/validator_parallelism_reask_1.py +++ b/tests/integration_tests/test_assets/python_rail/validator_parallelism_reask_1.py @@ -14,5 +14,10 @@ error_message="Value Hello a you\nand me is not lower case.", fix_value="hello a you\nand me", ), + FailResult( + outcome="fail", + error_message="Value has length greater than 10. Please return a shorter output, that is shorter than 10 characters.", # noqa: E501 + fix_value="Hello a yo", + ), ], ) diff --git a/tests/unit_tests/test_async_guard.py b/tests/unit_tests/test_async_guard.py index 78f63dd3c..d8b331b8d 100644 --- a/tests/unit_tests/test_async_guard.py +++ b/tests/unit_tests/test_async_guard.py @@ -445,55 +445,86 @@ def test_use_many_tuple(): ) -@pytest.mark.asyncio -async def test_validate(): - guard: AsyncGuard = ( - AsyncGuard() - .use(OneLine) - .use( - LowerCase(on_fail=OnFailAction.FIX), on="output" - ) # default on="output", still explicitly set - .use(TwoWords) - .use(ValidLength, 0, 12, on_fail=OnFailAction.REFRAIN) - ) +# TODO: Move to integration tests; these are not unit tests... 
+class TestValidate: + @pytest.mark.asyncio + async def test_output_only_success(self): + guard: AsyncGuard = ( + AsyncGuard() + .use(OneLine) + .use( + LowerCase(on_fail=OnFailAction.FIX), on="output" + ) # default on="output", still explicitly set + .use(TwoWords) + .use(ValidLength, 0, 12, on_fail=OnFailAction.REFRAIN) + ) - llm_output: str = "Oh Canada" # bc it meets our criteria - response = await guard.validate(llm_output) + llm_output: str = "Oh Canada" # bc it meets our criteria - assert response.validation_passed is True - assert response.validated_output == llm_output.lower() - llm_output_2 = "Star Spangled Banner" # to stick with the theme + response = await guard.validate(llm_output) - response_2 = await guard.validate(llm_output_2) + assert response.validation_passed is True + assert response.validated_output == llm_output.lower() - assert response_2.validation_passed is False - assert response_2.validated_output is None + @pytest.mark.asyncio + async def test_output_only_failure(self): + guard: AsyncGuard = ( + AsyncGuard() + .use(OneLine) + .use( + LowerCase(on_fail=OnFailAction.FIX), on="output" + ) # default on="output", still explicitly set + .use(TwoWords) + .use(ValidLength, 0, 12, on_fail=OnFailAction.REFRAIN) + ) - # Test with a combination of prompt, output, instructions and msg_history validators - # Should still only use the output validators to validate the output - guard: AsyncGuard = ( - AsyncGuard() - .use(OneLine, on="prompt") - .use(LowerCase, on="instructions") - .use(UpperCase, on="msg_history") - .use(LowerCase, on="output", on_fail=OnFailAction.FIX) - .use(TwoWords, on="output") - .use(ValidLength, 0, 12, on="output") - ) + llm_output = "Star Spangled Banner" # to stick with the theme + + response = await guard.validate(llm_output) - llm_output: str = "Oh Canada" # bc it meets our criteria + assert response.validation_passed is False + assert response.validated_output is None - response = await guard.validate(llm_output) + @pytest.mark.asyncio + async def test_on_many_success(self): + # Test with a combination of prompt, output, + # instructions and msg_history validators + # Should still only use the output validators to validate the output + guard: AsyncGuard = ( + AsyncGuard() + .use(OneLine, on="prompt") + .use(LowerCase, on="instructions") + .use(UpperCase, on="msg_history") + .use(LowerCase, on="output", on_fail=OnFailAction.FIX) + .use(TwoWords) + .use(ValidLength, 0, 12, on_fail=OnFailAction.REFRAIN) + ) - assert response.validation_passed is True - assert response.validated_output == llm_output.lower() + llm_output: str = "Oh Canada" # bc it meets our criteria + + response = await guard.validate(llm_output) + + assert response.validation_passed is True + assert response.validated_output == llm_output.lower() + + @pytest.mark.asyncio + async def test_on_many_failure(self): + guard: AsyncGuard = ( + AsyncGuard() + .use(OneLine, on="prompt") + .use(LowerCase, on="instructions") + .use(UpperCase, on="msg_history") + .use(LowerCase, on="output", on_fail=OnFailAction.FIX) + .use(TwoWords) + .use(ValidLength, 0, 12, on_fail=OnFailAction.REFRAIN) + ) - llm_output_2 = "Star Spangled Banner" # to stick with the theme + llm_output = "Star Spangled Banner" # to stick with the theme - response_2 = await guard.validate(llm_output_2) + response = await guard.validate(llm_output) - assert response_2.validation_passed is False - assert response_2.validated_output is None + assert response.validation_passed is False + assert response.validated_output is None def 
test_use_and_use_many(): diff --git a/tests/unit_tests/test_guard.py b/tests/unit_tests/test_guard.py index f3e951929..aad985d78 100644 --- a/tests/unit_tests/test_guard.py +++ b/tests/unit_tests/test_guard.py @@ -492,56 +492,82 @@ def test_use_many_tuple(): ) -def test_validate(): - guard: Guard = ( - Guard() - .use(OneLine) - .use( - LowerCase(on_fail=OnFailAction.FIX), on="output" - ) # default on="output", still explicitly set - .use(TwoWords) - .use(ValidLength, 0, 12, on_fail=OnFailAction.REFRAIN) - ) +# TODO: Move to integration tests; these are not unit tests... +class TestValidate: + def test_output_only_success(self): + guard: Guard = ( + Guard() + .use(OneLine) + .use( + LowerCase(on_fail=OnFailAction.FIX), on="output" + ) # default on="output", still explicitly set + .use(TwoWords) + .use(ValidLength, 0, 12, on_fail=OnFailAction.REFRAIN) + ) - llm_output: str = "Oh Canada" # bc it meets our criteria + llm_output: str = "Oh Canada" # bc it meets our criteria - response = guard.validate(llm_output) + response = guard.validate(llm_output) + + assert response.validation_passed is True + assert response.validated_output == llm_output.lower() + + def test_output_only_failure(self): + guard: Guard = ( + Guard() + .use(OneLine) + .use( + LowerCase(on_fail=OnFailAction.FIX), on="output" + ) # default on="output", still explicitly set + .use(TwoWords) + .use(ValidLength, 0, 12, on_fail=OnFailAction.REFRAIN) + ) - assert response.validation_passed is True - assert response.validated_output == llm_output.lower() + llm_output = "Star Spangled Banner" # to stick with the theme - llm_output_2 = "Star Spangled Banner" # to stick with the theme + response = guard.validate(llm_output) - response_2 = guard.validate(llm_output_2) + assert response.validation_passed is False + assert response.validated_output is None - assert response_2.validation_passed is False - assert response_2.validated_output is None + def test_on_many_success(self): + # Test with a combination of prompt, output, + # instructions and msg_history validators + # Should still only use the output validators to validate the output + guard: Guard = ( + Guard() + .use(OneLine, on="prompt") + .use(LowerCase, on="instructions") + .use(UpperCase, on="msg_history") + .use(LowerCase, on="output", on_fail=OnFailAction.FIX) + .use(TwoWords) + .use(ValidLength, 0, 12, on_fail=OnFailAction.REFRAIN) + ) - # Test with a combination of prompt, output, instructions and msg_history validators - # Should still only use the output validators to validate the output - guard: Guard = ( - Guard() - .use(OneLine, on="prompt") - .use(LowerCase, on="instructions") - .use(UpperCase, on="msg_history") - .use(LowerCase, on="output", on_fail=OnFailAction.FIX) - .use(TwoWords, on="output") - .use(ValidLength, 0, 12, on="output") - ) + llm_output: str = "Oh Canada" # bc it meets our criteria - llm_output: str = "Oh Canada" # bc it meets our criteria + response = guard.validate(llm_output) - response = guard.validate(llm_output) + assert response.validation_passed is True + assert response.validated_output == llm_output.lower() - assert response.validation_passed is True - assert response.validated_output == llm_output.lower() + def test_on_many_failure(self): + guard: Guard = ( + Guard() + .use(OneLine, on="prompt") + .use(LowerCase, on="instructions") + .use(UpperCase, on="msg_history") + .use(LowerCase, on="output", on_fail=OnFailAction.FIX) + .use(TwoWords) + .use(ValidLength, 0, 12, on_fail=OnFailAction.REFRAIN) + ) - llm_output_2 = "Star Spangled Banner" 
# to stick with the theme + llm_output = "Star Spangled Banner" # to stick with the theme - response_2 = guard.validate(llm_output_2) + response = guard.validate(llm_output) - assert response_2.validation_passed is False - assert response_2.validated_output is None + assert response.validation_passed is False + assert response.validated_output is None def test_use_and_use_many(): diff --git a/tests/unit_tests/test_validator_base.py b/tests/unit_tests/test_validator_base.py index 1b30aaee2..e6865d1d9 100644 --- a/tests/unit_tests/test_validator_base.py +++ b/tests/unit_tests/test_validator_base.py @@ -308,6 +308,156 @@ class Pet(BaseModel): assert response.validated_output == expected_result +class TestCustomOnFailHandler: + def test_custom_fix(self): + prompt = """ + What kind of pet should I get and what should I name it? + + ${gr.complete_json_suffix_v2} + """ + + output = """ + { + "pet_type": "dog", + "name": "Fido" + } + """ + expected_result = {"pet_type": "dog dog", "name": "Fido"} + + validator: Validator = TwoWords(on_fail=custom_fix_on_fail_handler) + + class Pet(BaseModel): + pet_type: str = Field(description="Species of pet", validators=[validator]) + name: str = Field(description="a unique pet name") + + guard = Guard.from_pydantic(output_class=Pet, prompt=prompt) + + response = guard.parse(output, num_reasks=0) + assert response.validation_passed is True + assert response.validated_output == expected_result + + def test_custom_reask(self): + prompt = """ + What kind of pet should I get and what should I name it? + + ${gr.complete_json_suffix_v2} + """ + + output = """ + { + "pet_type": "dog", + "name": "Fido" + } + """ + expected_result = FieldReAsk( + incorrect_value="dog", + path=["pet_type"], + fail_results=[ + FailResult( + error_message="must be exactly two words", + fix_value="dog dog", + ) + ], + ) + + validator: Validator = TwoWords(on_fail=custom_reask_on_fail_handler) + + class Pet(BaseModel): + pet_type: str = Field(description="Species of pet", validators=[validator]) + name: str = Field(description="a unique pet name") + + guard = Guard.from_pydantic(output_class=Pet, prompt=prompt) + + response = guard.parse(output, num_reasks=0) + + # Why? Because we have a bad habit of applying every fix value + # to the output even if the user doesn't ask us to. + assert response.validation_passed is True + assert guard.history.first.iterations.first.reasks[0] == expected_result + + def test_custom_exception(self): + prompt = """ + What kind of pet should I get and what should I name it? + + ${gr.complete_json_suffix_v2} + """ + + output = """ + { + "pet_type": "dog", + "name": "Fido" + } + """ + + validator: Validator = TwoWords(on_fail=custom_exception_on_fail_handler) + + class Pet(BaseModel): + pet_type: str = Field(description="Species of pet", validators=[validator]) + name: str = Field(description="a unique pet name") + + guard = Guard.from_pydantic(output_class=Pet, prompt=prompt) + + with pytest.raises(ValidationError) as excinfo: + guard.parse(output, num_reasks=0) + assert str(excinfo.value) == "Something went wrong!" + + def test_custom_filter(self): + prompt = """ + What kind of pet should I get and what should I name it? 
+ + ${gr.complete_json_suffix_v2} + """ + + output = """ + { + "pet_type": "dog", + "name": "Fido" + } + """ + + validator: Validator = TwoWords(on_fail=custom_filter_on_fail_handler) + + class Pet(BaseModel): + pet_type: str = Field(description="Species of pet", validators=[validator]) + name: str = Field(description="a unique pet name") + + guard = Guard.from_pydantic(output_class=Pet, prompt=prompt) + + response = guard.parse(output, num_reasks=0) + + # NOTE: This doesn't seem right. + # Shouldn't pass if filtering is successful on the target property? + assert response.validation_passed is False + assert response.validated_output is None + + def test_custom_refrain(self): + prompt = """ + What kind of pet should I get and what should I name it? + + ${gr.complete_json_suffix_v2} + """ + + output = """ + { + "pet_type": "dog", + "name": "Fido" + } + """ + + validator: Validator = TwoWords(on_fail=custom_refrain_on_fail_handler) + + class Pet(BaseModel): + pet_type: str = Field(description="Species of pet", validators=[validator]) + name: str = Field(description="a unique pet name") + + guard = Guard.from_pydantic(output_class=Pet, prompt=prompt) + + response = guard.parse(output, num_reasks=0) + + assert response.validation_passed is False + assert response.validated_output is None + + class Pet(BaseModel): name: str = Field(description="a unique pet name") diff --git a/tests/unit_tests/validator_service/test_async_validator_service.py b/tests/unit_tests/validator_service/test_async_validator_service.py index fe6ef5e89..e85e2150a 100644 --- a/tests/unit_tests/validator_service/test_async_validator_service.py +++ b/tests/unit_tests/validator_service/test_async_validator_service.py @@ -317,6 +317,7 @@ async def test_filter_exits_early(self, mocker): ValidatorRun( value="mock-value", metadata={}, + on_fail_action="noop", validator_logs=ValidatorLogs( registered_name="noop_validator", validator_name="noop_validator", @@ -328,6 +329,7 @@ async def test_filter_exits_early(self, mocker): ValidatorRun( value=Filter(), metadata={}, + on_fail_action="filter", validator_logs=ValidatorLogs( registered_name="filter_validator", validator_name="filter_validator", @@ -374,6 +376,7 @@ async def test_calls_merge(self, mocker): ValidatorRun( value="mock-value", metadata={}, + on_fail_action="noop", validator_logs=ValidatorLogs( registered_name="noop_validator", validator_name="noop_validator", @@ -385,6 +388,7 @@ async def test_calls_merge(self, mocker): ValidatorRun( value="mock-fix-value", metadata={}, + on_fail_action="fix", validator_logs=ValidatorLogs( registered_name="fix_validator", validator_name="fix_validator", @@ -482,6 +486,7 @@ async def test_pass_result(self, mocker): index=0, ) validator = MagicMock(spec=Validator) + validator.on_fail_descriptor = "noop" result = await avs.run_validator( iteration=iteration, @@ -540,6 +545,7 @@ async def test_pass_result_with_override(self, mocker): index=0, ) validator = MagicMock(spec=Validator) + validator.on_fail_descriptor = "noop" result = await avs.run_validator( iteration=iteration, From 541cbdd418b986496086c0340fd8cd15889cf2c6 Mon Sep 17 00:00:00 2001 From: Caleb Courier Date: Thu, 12 Sep 2024 12:28:43 -0500 Subject: [PATCH 04/13] make async validation async all the way down; integration tests to show concurrency --- guardrails/merge.py | 12 +- guardrails/telemetry/validator_tracing.py | 63 ++++ guardrails/validator_base.py | 28 +- guardrails/validator_service/__init__.py | 9 +- .../async_validator_service.py | 56 +++- 
.../sequential_validator_service.py | 2 +- .../validator_service_base.py | 8 +- poetry.lock | 3 +- pyproject.toml | 5 +- .../test_async_validator_service_it.py | 275 ++++++++++++++++ .../validator_service/test_init.py | 299 ++++++++++++++++++ .../test_async_validator_service.py | 28 -- .../test_validator_service.py | 37 ++- 13 files changed, 755 insertions(+), 70 deletions(-) create mode 100644 tests/integration_tests/validator_service/test_async_validator_service_it.py create mode 100644 tests/integration_tests/validator_service/test_init.py diff --git a/guardrails/merge.py b/guardrails/merge.py index e25acd3ce..07e3d9487 100644 --- a/guardrails/merge.py +++ b/guardrails/merge.py @@ -1,4 +1,5 @@ # SOURCE: https://github.com/spyder-ide/three-merge/blob/master/three_merge/merge.py +from typing import Optional from diff_match_patch import diff_match_patch # Constants @@ -10,7 +11,12 @@ ADDITION = 1 -def merge(source: str, target: str, base: str) -> str: +def merge( + source: Optional[str], target: Optional[str], base: Optional[str] +) -> Optional[str]: + if source is None or target is None or base is None: + return None + diff1_l = DIFFER.diff_main(base, source) diff2_l = DIFFER.diff_main(base, target) @@ -75,7 +81,7 @@ def merge(source: str, target: str, base: str) -> str: invariant = "" target = (target_status, target_text) # type: ignore if advance: - prev_source_text = source[1] + prev_source_text = source[1] # type: ignore source = next(diff1, None) # type: ignore elif len(source_text) < len(target_text): # Addition performed by source @@ -119,7 +125,7 @@ def merge(source: str, target: str, base: str) -> str: invariant = "" source = (source_status, source_text) # type: ignore if advance: - prev_target_text = target[1] + prev_target_text = target[1] # type: ignore target = next(diff2, None) # type: ignore else: # Source and target are equal diff --git a/guardrails/telemetry/validator_tracing.py b/guardrails/telemetry/validator_tracing.py index 3d2904327..4f88cd23e 100644 --- a/guardrails/telemetry/validator_tracing.py +++ b/guardrails/telemetry/validator_tracing.py @@ -1,6 +1,7 @@ from functools import wraps from typing import ( Any, + Awaitable, Callable, Dict, Optional, @@ -139,3 +140,65 @@ def trace_validator_wrapper(*args, **kwargs): return trace_validator_wrapper return trace_validator_decorator + + +def trace_async_validator( + validator_name: str, + obj_id: int, + on_fail_descriptor: Optional[str] = None, + tracer: Optional[Tracer] = None, + *, + validation_session_id: str, + **init_kwargs, +): + def trace_validator_decorator( + fn: Callable[..., Awaitable[Optional[ValidationResult]]], + ): + @wraps(fn) + async def trace_validator_wrapper(*args, **kwargs): + if not settings.disable_tracing: + current_otel_context = context.get_current() + _tracer = get_tracer(tracer) or trace.get_tracer( + "guardrails-ai", GUARDRAILS_VERSION + ) + validator_span_name = f"{validator_name}.validate" + with _tracer.start_as_current_span( + name=validator_span_name, # type: ignore + context=current_otel_context, # type: ignore + ) as validator_span: + try: + resp = await fn(*args, **kwargs) + add_validator_attributes( + *args, + validator_span=validator_span, + validator_name=validator_name, + obj_id=obj_id, + on_fail_descriptor=on_fail_descriptor, + result=resp, + init_kwargs=init_kwargs, + validation_session_id=validation_session_id, + **kwargs, + ) + return resp + except Exception as e: + validator_span.set_status( + status=StatusCode.ERROR, description=str(e) + ) + add_validator_attributes( + 
*args, + validator_span=validator_span, + validator_name=validator_name, + obj_id=obj_id, + on_fail_descriptor=on_fail_descriptor, + result=None, + init_kwargs=init_kwargs, + validation_session_id=validation_session_id, + **kwargs, + ) + raise e + else: + return await fn(*args, **kwargs) + + return trace_validator_wrapper + + return trace_validator_decorator diff --git a/guardrails/validator_base.py b/guardrails/validator_base.py index a70bea84c..66aaf9d66 100644 --- a/guardrails/validator_base.py +++ b/guardrails/validator_base.py @@ -3,6 +3,8 @@ # - [ ] Maintain validator_base.py for exports but deprecate them # - [ ] Remove validator_base.py in 0.6.x +import asyncio +from functools import partial import inspect import logging from collections import defaultdict @@ -175,6 +177,19 @@ def validate(self, value: Any, metadata: Dict[str, Any]) -> ValidationResult: self._log_telemetry() return validation_result + async def async_validate( + self, value: Any, metadata: Dict[str, Any] + ) -> ValidationResult: + """Use this function if your validation logic requires asyncio. + + Guaranteed to work with AsyncGuard + + May not work with synchronous Guards if they are used within an + async context due to lack of available event loops. + """ + loop = asyncio.get_event_loop() + return await loop.run_in_executor(None, self.validate, value, metadata) + def _inference(self, model_input: Any) -> Any: """Calls either a local or remote inference engine for use in the validation call. @@ -256,6 +271,15 @@ def validate_stream( return validation_result + async def async_validate_stream( + self, chunk: Any, metadata: Dict[str, Any], **kwargs + ) -> Optional[ValidationResult]: + loop = asyncio.get_event_loop() + validate_stream_partial = partial( + self.validate_stream, chunk, metadata, **kwargs + ) + return await loop.run_in_executor(None, validate_stream_partial) + def _hub_inference_request( self, request_body: Union[dict, str], validation_endpoint: str ) -> Any: @@ -342,9 +366,7 @@ def __call__(self, value): ) validator_service = ValidatorServiceBase() - return validator_service.perform_correction( - result, value, self, self.on_fail_descriptor - ) + return validator_service.perform_correction(result, value, self) return value def __eq__(self, other): diff --git a/guardrails/validator_service/__init__.py b/guardrails/validator_service/__init__.py index 341168063..1ea4ef9c7 100644 --- a/guardrails/validator_service/__init__.py +++ b/guardrails/validator_service/__init__.py @@ -86,7 +86,14 @@ def validate( validator_service = SequentialValidatorService(disable_tracer) return validator_service.validate( - value, metadata, validator_map, iteration, path, path, loop=loop, **kwargs + value, + metadata, + validator_map, + iteration, + path, + path, + loop=loop, # type: ignore It exists when we need it to. 
+ **kwargs, ) diff --git a/guardrails/validator_service/async_validator_service.py b/guardrails/validator_service/async_validator_service.py index fde57b0d8..323e6c410 100644 --- a/guardrails/validator_service/async_validator_service.py +++ b/guardrails/validator_service/async_validator_service.py @@ -1,5 +1,5 @@ import asyncio -from typing import Any, Awaitable, Coroutine, Dict, List, Optional, Tuple, Union, cast +from typing import Any, Awaitable, Coroutine, Dict, List, Optional, Tuple, Union from guardrails.actions.filter import Filter from guardrails.actions.refrain import Refrain @@ -9,6 +9,7 @@ PassResult, ValidationResult, ) +from guardrails.telemetry.validator_tracing import trace_async_validator from guardrails.types import ValidatorMap, OnFailAction from guardrails.classes.validation.validator_logs import ValidatorLogs from guardrails.actions.reask import FieldReAsk @@ -22,6 +23,32 @@ class AsyncValidatorService(ValidatorServiceBase): + async def execute_validator( + self, + validator: Validator, + value: Any, + metadata: Optional[Dict], + stream: Optional[bool] = False, + *, + validation_session_id: str, + **kwargs, + ) -> Optional[ValidationResult]: + validate_func = ( + validator.async_validate_stream if stream else validator.async_validate + ) + traced_validator = trace_async_validator( + validator_name=validator.rail_alias, + obj_id=id(validator), + on_fail_descriptor=validator.on_fail_descriptor, + validation_session_id=validation_session_id, + **validator._kwargs, + )(validate_func) + if stream: + result = await traced_validator(value, metadata, **kwargs) + else: + result = await traced_validator(value, metadata) + return result + async def run_validator_async( self, validator: Validator, @@ -32,7 +59,7 @@ async def run_validator_async( validation_session_id: str, **kwargs, ) -> ValidationResult: - result: ValidatorResult = self.execute_validator( + result = await self.execute_validator( validator, value, metadata, @@ -40,13 +67,9 @@ async def run_validator_async( validation_session_id=validation_session_id, **kwargs, ) - if asyncio.iscoroutine(result): - result = await result if result is None: result = PassResult() - else: - result = cast(ValidationResult, result) return result async def run_validator( @@ -125,21 +148,22 @@ async def run_validators( coroutines: List[Coroutine[Any, Any, ValidatorRun]] = [] validators_logs: List[ValidatorLogs] = [] for validator in validators: - coroutine: Coroutine[Any, Any, ValidatorRun] = self.run_validator( - iteration, - validator, - value, - metadata, - absolute_property_path, - stream=stream, - **kwargs, + coroutines.append( + self.run_validator( + iteration, + validator, + value, + metadata, + absolute_property_path, + stream=stream, + **kwargs, + ) ) - coroutines.append(coroutine) results = await asyncio.gather(*coroutines) reasks: List[FieldReAsk] = [] for res in results: - validators_logs.extend(res.validator_logs) + validators_logs.append(res.validator_logs) # QUESTION: Do we still want to do this here or handle it during the merge? 
# return early if we have a filter, refrain, or reask if isinstance(res.value, (Filter, Refrain)): diff --git a/guardrails/validator_service/sequential_validator_service.py b/guardrails/validator_service/sequential_validator_service.py index 5be48461b..e86598277 100644 --- a/guardrails/validator_service/sequential_validator_service.py +++ b/guardrails/validator_service/sequential_validator_service.py @@ -109,7 +109,7 @@ def run_validators_stream( ) # requires at least 2 validators - def multi_merge(self, original: str, new_values: list[str]) -> str: + def multi_merge(self, original: str, new_values: list[str]) -> Optional[str]: current = new_values.pop() print("Fmerging these:", new_values) while len(new_values) > 0: diff --git a/guardrails/validator_service/validator_service_base.py b/guardrails/validator_service/validator_service_base.py index c5122855a..6b675b251 100644 --- a/guardrails/validator_service/validator_service_base.py +++ b/guardrails/validator_service/validator_service_base.py @@ -52,6 +52,8 @@ def execute_validator( *, validation_session_id: str, **kwargs, + # TODO: Make this just Optional[ValidationResult] + # Also maybe move to SequentialValidatorService ) -> ValidatorResult: validate_func = validator.validate_stream if stream else validator.validate traced_validator = trace_validator( @@ -189,12 +191,6 @@ def merge_results(self, original_value: Any, new_values: list[Any]) -> Any: current = new_values.pop() while len(new_values) > 0: nextval = new_values.pop() - # print("current:", current) - # print("serialize(current):", serialize(current)) - # print("nextval:", nextval) - # print("serialize(nextval):", serialize(nextval)) - # print("original_value:", original_value) - # print("serialize(original_value):", serialize(original_value)) current = merge( serialize(current), serialize(nextval), serialize(original_value) ) diff --git a/poetry.lock b/poetry.lock index 8c8f56368..24b48dec6 100644 --- a/poetry.lock +++ b/poetry.lock @@ -8413,9 +8413,10 @@ docs-build = ["docspec_python", "nbdoc", "pydoc-markdown"] huggingface = ["jsonformer", "torch", "transformers"] manifest = ["manifest-ml"] sql = ["sqlalchemy", "sqlglot", "sqlvalidator"] +uv = ["uvloop"] vectordb = ["faiss-cpu", "numpy"] [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "aec41326aef66af046ce16d49c036fec48698032995f3f49df634b9da411caf7" +content-hash = "6253610141bb5686330057ae658550f9257aabe83ee7b279b783a7f4418a26a6" diff --git a/pyproject.toml b/pyproject.toml index 76a927a54..c10cf2a66 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -60,6 +60,7 @@ guardrails-api-client = ">=0.3.8" diff-match-patch = "^20230430" guardrails-api = ">=0.0.1" mlflow = {version = ">=2.0.1", optional = true} +uvloop = {version = "^0.20.0", optional = true} [tool.poetry.extras] sql = ["sqlvalidator", "sqlalchemy", "sqlglot"] @@ -70,6 +71,7 @@ docs-build = ["nbdoc", "docspec_python", "pydoc-markdown"] huggingface = ["transformers", "torch", "jsonformer"] api = ["guardrails-api"] databricks = ["mlflow"] +uv = ["uvloop"] [tool.poetry.group.dev.dependencies] @@ -106,9 +108,6 @@ cairosvg = "^2.7.1" mkdocs-glightbox = "^0.3.4" -[tool.poetry.group.uv.dependencies] -uvloop = {version = "^0.20.0", optional = true} - [[tool.poetry.source]] name = "PyPI" diff --git a/tests/integration_tests/validator_service/test_async_validator_service_it.py b/tests/integration_tests/validator_service/test_async_validator_service_it.py new file mode 100644 index 000000000..643105e8b --- /dev/null +++ 
b/tests/integration_tests/validator_service/test_async_validator_service_it.py @@ -0,0 +1,275 @@ +import asyncio +import pytest +from time import sleep +from guardrails.validator_base import Validator, register_validator +from guardrails.classes.validation.validation_result import PassResult + + +@register_validator(name="test/validator1", data_type="string") +class Validator1(Validator): + def validate(self, value, metadata): + # This seems more realistic but is unreliable + # counter = 0 + # for i in range(100000000): + # counter += 1 + # This seems suspicious, but is consistent + sleep(0.3) + metadata["order"].append("test/validator1") + return PassResult() + + +@register_validator(name="test/validator2", data_type="string") +class Validator2(Validator): + def validate(self, value, metadata): + # counter = 0 + # for i in range(1): + # counter += 1 + sleep(0.1) + metadata["order"].append("test/validator2") + return PassResult() + + +@register_validator(name="test/validator3", data_type="string") +class Validator3(Validator): + def validate(self, value, metadata): + # counter = 0 + # for i in range(100000): + # counter += 1 + sleep(0.2) + metadata["order"].append("test/validator3") + return PassResult() + + +@register_validator(name="test/async_validator1", data_type="string") +class AsyncValidator1(Validator): + async def async_validate(self, value, metadata): + await asyncio.sleep(0.3) + metadata["order"].append("test/async_validator1") + return PassResult() + + +@register_validator(name="test/async_validator2", data_type="string") +class AsyncValidator2(Validator): + async def async_validate(self, value, metadata): + await asyncio.sleep(0.1) + metadata["order"].append("test/async_validator2") + return PassResult() + + +@register_validator(name="test/async_validator3", data_type="string") +class AsyncValidator3(Validator): + async def async_validate(self, value, metadata): + await asyncio.sleep(0.2) + metadata["order"].append("test/async_validator3") + return PassResult() + + +class TestValidatorConcurrency: + @pytest.mark.asyncio + async def test_async_validate_with_sync_validators(self): + from guardrails.validator_service import AsyncValidatorService + from guardrails.classes.history import Iteration + + iteration = Iteration( + call_id="mock_call_id", + index=0, + ) + + async_validator_service = AsyncValidatorService() + + value, metadata = await async_validator_service.async_validate( + value="value", + metadata={"order": []}, + validator_map={ + "$": [ + # Note the order + Validator1(), + Validator2(), + Validator3(), + ] + }, + iteration=iteration, + absolute_path="$", + reference_path="$", + ) + + assert value == "value" + assert metadata == { + "order": ["test/validator2", "test/validator3", "test/validator1"] + } + + def test_validate_with_sync_validators(self): + from guardrails.validator_service import AsyncValidatorService + from guardrails.classes.history import Iteration + + iteration = Iteration( + call_id="mock_call_id", + index=0, + ) + + async_validator_service = AsyncValidatorService() + + loop = asyncio.get_event_loop() + value, metadata = async_validator_service.validate( + value="value", + metadata={"order": []}, + validator_map={ + "$": [ + # Note the order + Validator1(), + Validator2(), + Validator3(), + ] + }, + iteration=iteration, + absolute_path="$", + reference_path="$", + loop=loop, + ) + + assert value == "value" + assert metadata == { + "order": ["test/validator2", "test/validator3", "test/validator1"] + } + + @pytest.mark.asyncio + async def 
test_async_validate_with_async_validators(self): + from guardrails.validator_service import AsyncValidatorService + from guardrails.classes.history import Iteration + + iteration = Iteration( + call_id="mock_call_id", + index=0, + ) + + async_validator_service = AsyncValidatorService() + + value, metadata = await async_validator_service.async_validate( + value="value", + metadata={"order": []}, + validator_map={ + "$": [ + # Note the order + AsyncValidator1(), + AsyncValidator2(), + AsyncValidator3(), + ] + }, + iteration=iteration, + absolute_path="$", + reference_path="$", + ) + + assert value == "value" + assert metadata == { + "order": [ + "test/async_validator2", + "test/async_validator3", + "test/async_validator1", + ] + } + + def test_validate_with_async_validators(self): + from guardrails.validator_service import AsyncValidatorService + from guardrails.classes.history import Iteration + + iteration = Iteration( + call_id="mock_call_id", + index=0, + ) + + async_validator_service = AsyncValidatorService() + + loop = asyncio.get_event_loop() + value, metadata = async_validator_service.validate( + value="value", + metadata={"order": []}, + validator_map={ + "$": [ + # Note the order + AsyncValidator1(), + AsyncValidator2(), + AsyncValidator3(), + ] + }, + iteration=iteration, + absolute_path="$", + reference_path="$", + loop=loop, + ) + + assert value == "value" + assert metadata == { + "order": [ + "test/async_validator2", + "test/async_validator3", + "test/async_validator1", + ] + } + + @pytest.mark.asyncio + async def test_async_validate_with_mixed_validators(self): + from guardrails.validator_service import AsyncValidatorService + from guardrails.classes.history import Iteration + + iteration = Iteration( + call_id="mock_call_id", + index=0, + ) + + async_validator_service = AsyncValidatorService() + + value, metadata = await async_validator_service.async_validate( + value="value", + metadata={"order": []}, + validator_map={ + "$": [ + # Note the order + Validator1(), + Validator2(), + AsyncValidator3(), + ] + }, + iteration=iteration, + absolute_path="$", + reference_path="$", + ) + + assert value == "value" + assert metadata == { + "order": ["test/validator2", "test/async_validator3", "test/validator1"] + } + + def test_validate_with_mixed_validators(self): + from guardrails.validator_service import AsyncValidatorService + from guardrails.classes.history import Iteration + + iteration = Iteration( + call_id="mock_call_id", + index=0, + ) + + async_validator_service = AsyncValidatorService() + + loop = asyncio.get_event_loop() + value, metadata = async_validator_service.validate( + value="value", + metadata={"order": []}, + validator_map={ + "$": [ + # Note the order + Validator1(), + Validator2(), + AsyncValidator3(), + ] + }, + iteration=iteration, + absolute_path="$", + reference_path="$", + loop=loop, + ) + + assert value == "value" + assert metadata == { + "order": ["test/validator2", "test/async_validator3", "test/validator1"] + } diff --git a/tests/integration_tests/validator_service/test_init.py b/tests/integration_tests/validator_service/test_init.py new file mode 100644 index 000000000..cd08ba774 --- /dev/null +++ b/tests/integration_tests/validator_service/test_init.py @@ -0,0 +1,299 @@ +from asyncio import get_event_loop +from asyncio.unix_events import _UnixSelectorEventLoop +import os +import pytest + +from guardrails.validator_service import should_run_sync, get_loop +from guardrails.classes.history import Iteration + + +try: + import uvloop +except ImportError: + 
uvloop = None + + +class TestShouldRunSync: + def test_process_count_is_one(self): + GUARDRAILS_PROCESS_COUNT_bak = os.environ.get("GUARDRAILS_PROCESS_COUNT") + GUARDRAILS_RUN_SYNC_bak = os.environ.get("GUARDRAILS_RUN_SYNC") + os.environ["GUARDRAILS_PROCESS_COUNT"] = "1" + if os.environ.get("GUARDRAILS_RUN_SYNC"): + del os.environ["GUARDRAILS_RUN_SYNC"] + + with pytest.warns( + DeprecationWarning, + match=( + "GUARDRAILS_PROCESS_COUNT is deprecated" + " and will be removed in a future release." + " To force synchronous validation," + " please use GUARDRAILS_RUN_SYNC instead." + ), + ): + result = should_run_sync() + assert result is True + + if GUARDRAILS_PROCESS_COUNT_bak is not None: + os.environ["GUARDRAILS_PROCESS_COUNT"] = GUARDRAILS_PROCESS_COUNT_bak + else: + del os.environ["GUARDRAILS_PROCESS_COUNT"] + if GUARDRAILS_RUN_SYNC_bak is not None: + os.environ["GUARDRAILS_RUN_SYNC"] = GUARDRAILS_RUN_SYNC_bak + + def test_process_count_is_2(self): + GUARDRAILS_PROCESS_COUNT_bak = os.environ.get("GUARDRAILS_PROCESS_COUNT") + GUARDRAILS_RUN_SYNC_bak = os.environ.get("GUARDRAILS_RUN_SYNC") + os.environ["GUARDRAILS_PROCESS_COUNT"] = "2" + if os.environ.get("GUARDRAILS_RUN_SYNC"): + del os.environ["GUARDRAILS_RUN_SYNC"] + + with pytest.warns( + DeprecationWarning, + match=( + "GUARDRAILS_PROCESS_COUNT is deprecated" + " and will be removed in a future release." + " To force synchronous validation," + " please use GUARDRAILS_RUN_SYNC instead." + ), + ): + result = should_run_sync() + assert result is False + + if GUARDRAILS_PROCESS_COUNT_bak is not None: + os.environ["GUARDRAILS_PROCESS_COUNT"] = GUARDRAILS_PROCESS_COUNT_bak + else: + del os.environ["GUARDRAILS_PROCESS_COUNT"] + if GUARDRAILS_RUN_SYNC_bak is not None: + os.environ["GUARDRAILS_RUN_SYNC"] = GUARDRAILS_RUN_SYNC_bak + + def test_guardrails_run_sync_is_true(self): + GUARDRAILS_PROCESS_COUNT_bak = os.environ.get("GUARDRAILS_PROCESS_COUNT") + GUARDRAILS_RUN_SYNC_bak = os.environ.get("GUARDRAILS_RUN_SYNC") + os.environ["GUARDRAILS_RUN_SYNC"] = "true" + if os.environ.get("GUARDRAILS_PROCESS_COUNT"): + del os.environ["GUARDRAILS_PROCESS_COUNT"] + + result = should_run_sync() + assert result is True + + if GUARDRAILS_PROCESS_COUNT_bak is not None: + os.environ["GUARDRAILS_PROCESS_COUNT"] = GUARDRAILS_PROCESS_COUNT_bak + if GUARDRAILS_RUN_SYNC_bak is not None: + os.environ["GUARDRAILS_RUN_SYNC"] = GUARDRAILS_RUN_SYNC_bak + else: + del os.environ["GUARDRAILS_RUN_SYNC"] + + def test_guardrails_run_sync_is_false(self): + GUARDRAILS_PROCESS_COUNT_bak = os.environ.get("GUARDRAILS_PROCESS_COUNT") + GUARDRAILS_RUN_SYNC_bak = os.environ.get("GUARDRAILS_RUN_SYNC") + os.environ["GUARDRAILS_RUN_SYNC"] = "false" + if os.environ.get("GUARDRAILS_PROCESS_COUNT"): + del os.environ["GUARDRAILS_PROCESS_COUNT"] + + result = should_run_sync() + assert result is False + + if GUARDRAILS_PROCESS_COUNT_bak is not None: + os.environ["GUARDRAILS_PROCESS_COUNT"] = GUARDRAILS_PROCESS_COUNT_bak + if GUARDRAILS_RUN_SYNC_bak is not None: + os.environ["GUARDRAILS_RUN_SYNC"] = GUARDRAILS_RUN_SYNC_bak + else: + del os.environ["GUARDRAILS_RUN_SYNC"] + + def test_process_count_is_1_and_guardrails_run_sync_is_false(self): + GUARDRAILS_PROCESS_COUNT_bak = os.environ.get("GUARDRAILS_PROCESS_COUNT") + GUARDRAILS_RUN_SYNC_bak = os.environ.get("GUARDRAILS_RUN_SYNC") + os.environ["GUARDRAILS_PROCESS_COUNT"] = "1" + os.environ["GUARDRAILS_RUN_SYNC"] = "false" + + with pytest.warns( + DeprecationWarning, + match=( + "GUARDRAILS_PROCESS_COUNT is deprecated" + " and will be removed in 
a future release." + " To force synchronous validation," + " please use GUARDRAILS_RUN_SYNC instead." + ), + ): + result = should_run_sync() + assert result is True + + if GUARDRAILS_PROCESS_COUNT_bak is not None: + os.environ["GUARDRAILS_PROCESS_COUNT"] = GUARDRAILS_PROCESS_COUNT_bak + else: + del os.environ["GUARDRAILS_PROCESS_COUNT"] + if GUARDRAILS_RUN_SYNC_bak is not None: + os.environ["GUARDRAILS_RUN_SYNC"] = GUARDRAILS_RUN_SYNC_bak + else: + del os.environ["GUARDRAILS_RUN_SYNC"] + + def test_process_count_is_2_and_guardrails_run_sync_is_true(self): + GUARDRAILS_PROCESS_COUNT_bak = os.environ.get("GUARDRAILS_PROCESS_COUNT") + GUARDRAILS_RUN_SYNC_bak = os.environ.get("GUARDRAILS_RUN_SYNC") + os.environ["GUARDRAILS_PROCESS_COUNT"] = "2" + os.environ["GUARDRAILS_RUN_SYNC"] = "true" + + with pytest.warns( + DeprecationWarning, + match=( + "GUARDRAILS_PROCESS_COUNT is deprecated" + " and will be removed in a future release." + " To force synchronous validation," + " please use GUARDRAILS_RUN_SYNC instead." + ), + ): + result = should_run_sync() + assert result is True + + if GUARDRAILS_PROCESS_COUNT_bak is not None: + os.environ["GUARDRAILS_PROCESS_COUNT"] = GUARDRAILS_PROCESS_COUNT_bak + else: + del os.environ["GUARDRAILS_PROCESS_COUNT"] + if GUARDRAILS_RUN_SYNC_bak is not None: + os.environ["GUARDRAILS_RUN_SYNC"] = GUARDRAILS_RUN_SYNC_bak + else: + del os.environ["GUARDRAILS_RUN_SYNC"] + + +class TestGetLoop: + def test_raises_if_loop_is_running(self): + loop = get_event_loop() + + async def callback(): + # NOTE: This means only AsyncGuard will parallelize validators + # if it's called within an async function. + with pytest.raises(RuntimeError, match="An event loop is already running."): + get_loop() + + loop.run_until_complete(callback()) + + @pytest.mark.skipif(uvloop is None, reason="uvloop is not installed") + def test_uvloop_is_used_when_installed(self): + loop = get_loop() + assert isinstance(loop, uvloop.Loop) + + @pytest.mark.skipif(uvloop is not None, reason="uvloop is installed") + def test_asyncio_default_is_used_otherwise(self): + loop = get_loop() + assert isinstance(loop, _UnixSelectorEventLoop) + + +class TestValidate: + def test_forced_sync(self, mocker): + GUARDRAILS_PROCESS_COUNT_bak = os.environ.get("GUARDRAILS_PROCESS_COUNT") + GUARDRAILS_RUN_SYNC_bak = os.environ.get("GUARDRAILS_RUN_SYNC") + os.environ["GUARDRAILS_RUN_SYNC"] = "true" + if os.environ.get("GUARDRAILS_PROCESS_COUNT"): + del os.environ["GUARDRAILS_PROCESS_COUNT"] + + from guardrails.validator_service import validate, SequentialValidatorService + + mocker.spy(SequentialValidatorService, "__init__") + mocker.spy(SequentialValidatorService, "validate") + + iteration = Iteration( + call_id="mock_call_id", + index=0, + ) + + value, metadata = validate( + value="value", + metadata={}, + validator_map={}, + iteration=iteration, + ) + + assert value == "value" + assert metadata == {} + SequentialValidatorService.__init__.assert_called_once() + SequentialValidatorService.validate.assert_called_once() + + if GUARDRAILS_PROCESS_COUNT_bak is not None: + os.environ["GUARDRAILS_PROCESS_COUNT"] = GUARDRAILS_PROCESS_COUNT_bak + if GUARDRAILS_RUN_SYNC_bak is not None: + os.environ["GUARDRAILS_RUN_SYNC"] = GUARDRAILS_RUN_SYNC_bak + else: + del os.environ["GUARDRAILS_RUN_SYNC"] + + def test_async(self, mocker): + from guardrails.validator_service import validate, AsyncValidatorService + + mocker.spy(AsyncValidatorService, "__init__") + mocker.spy(AsyncValidatorService, "validate") + + iteration = Iteration( + 
call_id="mock_call_id", + index=0, + ) + + value, metadata = validate( + value="value", + metadata={}, + validator_map={}, + iteration=iteration, + ) + + assert value == "value" + assert metadata == {} + AsyncValidatorService.__init__.assert_called_once() + AsyncValidatorService.validate.assert_called_once() + + def test_sync_busy_loop(self, mocker): + from guardrails.validator_service import validate, SequentialValidatorService + + mocker.spy(SequentialValidatorService, "__init__") + mocker.spy(SequentialValidatorService, "validate") + + iteration = Iteration( + call_id="mock_call_id", + index=0, + ) + + loop = get_event_loop() + + async def callback(): + with pytest.warns( + Warning, + match=( + "Could not obtain an event loop." + " Falling back to synchronous validation." + ), + ): + value, metadata = validate( + value="value", + metadata={}, + validator_map={}, + iteration=iteration, + ) + assert value == "value" + assert metadata == {} + + loop.run_until_complete(callback()) + + SequentialValidatorService.__init__.assert_called_once() + SequentialValidatorService.validate.assert_called_once() + + +@pytest.mark.asyncio +async def test_async_validate(mocker): + from guardrails.validator_service import async_validate, AsyncValidatorService + + mocker.spy(AsyncValidatorService, "__init__") + mocker.spy(AsyncValidatorService, "async_validate") + + iteration = Iteration( + call_id="mock_call_id", + index=0, + ) + + value, metadata = await async_validate( + value="value", + metadata={}, + validator_map={}, + iteration=iteration, + ) + + assert value == "value" + assert metadata == {} + AsyncValidatorService.__init__.assert_called_once() + AsyncValidatorService.async_validate.assert_called_once() diff --git a/tests/unit_tests/validator_service/test_async_validator_service.py b/tests/unit_tests/validator_service/test_async_validator_service.py index e85e2150a..7c0ebd368 100644 --- a/tests/unit_tests/validator_service/test_async_validator_service.py +++ b/tests/unit_tests/validator_service/test_async_validator_service.py @@ -751,34 +751,6 @@ async def test_happy_path(self, mocker): mock_validator, "value", {}, False, validation_session_id="mock-session" ) - @pytest.mark.asyncio - async def test_result_is_a_coroutine(self, mocker): - mock_validator = MagicMock(spec=Validator) - - validation_result = PassResult() - - async def result_coroutine(): - return validation_result - - mock_execute_validator = mocker.patch.object( - avs, "execute_validator", return_value=result_coroutine() - ) - - result = await avs.run_validator_async( - validator=mock_validator, - value="value", - metadata={}, - stream=False, - validation_session_id="mock-session", - ) - - assert result == validation_result - - assert mock_execute_validator.call_count == 1 - mock_execute_validator.assert_called_once_with( - mock_validator, "value", {}, False, validation_session_id="mock-session" - ) - @pytest.mark.asyncio async def test_result_is_none(self, mocker): mock_validator = MagicMock(spec=Validator) diff --git a/tests/unit_tests/validator_service/test_validator_service.py b/tests/unit_tests/validator_service/test_validator_service.py index 4514ad22c..7d0be1ae0 100644 --- a/tests/unit_tests/validator_service/test_validator_service.py +++ b/tests/unit_tests/validator_service/test_validator_service.py @@ -13,27 +13,42 @@ class TestShouldRunSync: def test_process_count_of_1(self, mocker): - mocker.patch("os.environ.get", side_effect=["1", "false"]) + mocker.patch( + "guardrails.validator_service.os.environ.get", side_effect=["1", 
"false"] + ) assert vs.should_run_sync() is True def test_run_sync_set_to_true(self, mocker): - mocker.patch("os.environ.get", side_effect=["10", "True"]) + mocker.patch( + "guardrails.validator_service.os.environ.get", side_effect=["10", "True"] + ) assert vs.should_run_sync() is True def test_should_run_sync_default(self, mocker): - mocker.patch("os.environ.get", side_effect=["10", "false"]) + mocker.patch( + "guardrails.validator_service.os.environ.get", side_effect=["10", "false"] + ) assert vs.should_run_sync() is False class TestGetLoop: def test_get_loop_with_running_loop(self, mocker): - mocker.patch("asyncio.get_running_loop", return_value="running loop") + mocker.patch( + "guardrails.validator_service.asyncio.get_running_loop", + return_value="running loop", + ) with pytest.raises(RuntimeError): vs.get_loop() def test_get_loop_without_running_loop(self, mocker): - mocker.patch("asyncio.get_running_loop", side_effect=RuntimeError) - mocker.patch("asyncio.get_event_loop", return_value="event loop") + mocker.patch( + "guardrails.validator_service.asyncio.get_running_loop", + side_effect=RuntimeError, + ) + mocker.patch( + "guardrails.validator_service.asyncio.get_event_loop", + return_value="event loop", + ) assert vs.get_loop() == "event loop" def test_get_loop_with_uvloop(self, mocker): @@ -41,8 +56,14 @@ def test_get_loop_with_uvloop(self, mocker): mock_event_loop_policy = mocker.patch( "guardrails.validator_service.uvloop.EventLoopPolicy" ) - mocker.patch("asyncio.get_running_loop", side_effect=RuntimeError) - mocker.patch("asyncio.get_event_loop", return_value="event loop") + mocker.patch( + "guardrails.validator_service.asyncio.get_running_loop", + side_effect=RuntimeError, + ) + mocker.patch( + "guardrails.validator_service.asyncio.get_event_loop", + return_value="event loop", + ) mock_set_event_loop_policy = mocker.patch("asyncio.set_event_loop_policy") assert vs.get_loop() == "event loop" From 9dfa0dbf076f43f1499b20ab83cbe770ed80b77f Mon Sep 17 00:00:00 2001 From: Caleb Courier Date: Fri, 13 Sep 2024 10:16:43 -0500 Subject: [PATCH 05/13] stricter typing for custom on_fail methods --- guardrails/validator_base.py | 31 ++++- .../validator_service_base.py | 2 +- tests/unit_tests/test_validator_base.py | 119 ++++++------------ 3 files changed, 70 insertions(+), 82 deletions(-) diff --git a/guardrails/validator_base.py b/guardrails/validator_base.py index 66aaf9d66..c773e7c30 100644 --- a/guardrails/validator_base.py +++ b/guardrails/validator_base.py @@ -12,6 +12,7 @@ from string import Template from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union from warnings import warn +import warnings import nltk import requests @@ -26,6 +27,7 @@ from guardrails.logger import logger from guardrails.remote_inference import remote_inference from guardrails.types.on_fail import OnFailAction +from guardrails.utils.safe_get import safe_get from guardrails.utils.hub_telemetry_utils import HubTelemetry # See: https://github.com/guardrails-ai/guardrails/issues/829 @@ -78,7 +80,7 @@ class Validator: def __init__( self, - on_fail: Optional[Union[Callable, OnFailAction]] = None, + on_fail: Optional[Union[Callable[[Any, FailResult], Any], OnFailAction]] = None, **kwargs, ): self.creds = Credentials.from_rc_file() @@ -127,7 +129,7 @@ def __init__( self.on_fail_method = None else: self.on_fail_descriptor = OnFailAction.CUSTOM - self.on_fail_method = on_fail + self._set_on_fail_method(on_fail) # Store the kwargs for the validator. 
self._kwargs = kwargs @@ -136,6 +138,31 @@ def __init__( self.rail_alias in validators_registry ), f"Validator {self.__class__.__name__} is not registered. " + def _set_on_fail_method(self, on_fail: Callable[[Any, FailResult], Any]): + """Set the on_fail method for the validator.""" + on_fail_args = inspect.getfullargspec(on_fail) + second_arg = safe_get(on_fail_args.args, 1) + if second_arg is None: + raise ValueError( + "The on_fail method must take two arguments: " + "the value being validated and the FailResult." + ) + second_arg_type = on_fail_args.annotations.get(second_arg) + if second_arg_type == List[FailResult]: + warnings.warn( + "Specifying a List[FailResult] as the second argument" + " for a custom on_fail handler is deprecated. " + "Please use FailResult instead.", + DeprecationWarning, + ) + + def on_fail_wrapper(value: Any, fail_result: FailResult) -> Any: + return on_fail(value, [fail_result]) # type: ignore + + self.on_fail_method = on_fail_wrapper + else: + self.on_fail_method = on_fail + def _validate(self, value: Any, metadata: Dict[str, Any]) -> ValidationResult: """User implementable function. diff --git a/guardrails/validator_service/validator_service_base.py b/guardrails/validator_service/validator_service_base.py index 6b675b251..5da7e33a6 100644 --- a/guardrails/validator_service/validator_service_base.py +++ b/guardrails/validator_service/validator_service_base.py @@ -95,7 +95,7 @@ def perform_correction( if on_fail_descriptor == OnFailAction.CUSTOM: if validator.on_fail_method is None: raise ValueError("on_fail is 'custom' but on_fail_method is None") - return validator.on_fail_method(value, [result]) + return validator.on_fail_method(value, result) if on_fail_descriptor == OnFailAction.REASK: return FieldReAsk( incorrect_value=value, diff --git a/tests/unit_tests/test_validator_base.py b/tests/unit_tests/test_validator_base.py index e6865d1d9..78069c349 100644 --- a/tests/unit_tests/test_validator_base.py +++ b/tests/unit_tests/test_validator_base.py @@ -1,4 +1,5 @@ import json +import re from typing import Any, Dict, List import pytest @@ -209,106 +210,66 @@ def test_to_xml_attrib(min, max, expected_xml): assert xml_validator == expected_xml -def custom_fix_on_fail_handler(value: Any, fail_results: List[FailResult]): +def custom_deprecated_on_fail_handler(value: Any, fail_results: List[FailResult]): + return value + " deprecated" + + +def custom_fix_on_fail_handler(value: Any, fail_result: FailResult): return value + " " + value -def custom_reask_on_fail_handler(value: Any, fail_results: List[FailResult]): - return FieldReAsk(incorrect_value=value, fail_results=fail_results) +def custom_reask_on_fail_handler(value: Any, fail_result: FailResult): + return FieldReAsk(incorrect_value=value, fail_results=[fail_result]) -def custom_exception_on_fail_handler(value: Any, fail_results: List[FailResult]): +def custom_exception_on_fail_handler(value: Any, fail_result: FailResult): raise ValidationError("Something went wrong!") -def custom_filter_on_fail_handler(value: Any, fail_results: List[FailResult]): +def custom_filter_on_fail_handler(value: Any, fail_result: FailResult): return Filter() -def custom_refrain_on_fail_handler(value: Any, fail_results: List[FailResult]): +def custom_refrain_on_fail_handler(value: Any, fail_result: FailResult): return Refrain() -@pytest.mark.parametrize( - "custom_reask_func, expected_result", - [ - ( - custom_fix_on_fail_handler, - {"pet_type": "dog dog", "name": "Fido"}, - ), - ( - custom_reask_on_fail_handler, - FieldReAsk( - 
incorrect_value="dog", - path=["pet_type"], - fail_results=[ - FailResult( - error_message="must be exactly two words", - fix_value="dog dog", - ) - ], - ), - ), - ( - custom_exception_on_fail_handler, - ValidationError, - ), - ( - custom_filter_on_fail_handler, - None, - ), - ( - custom_refrain_on_fail_handler, - None, - ), - ], -) -# @pytest.mark.parametrize( -# "validator_spec", -# [ -# lambda val_func: TwoWords(on_fail=val_func), -# # This was never supported even pre-0.5.x. -# # Trying this with function calling will throw. -# lambda val_func: ("two-words", val_func), -# ], -# ) -def test_custom_on_fail_handler( - custom_reask_func, - expected_result, -): - prompt = """ - What kind of pet should I get and what should I name it? +class TestCustomOnFailHandler: + def test_deprecated_on_fail_handler(self): + prompt = """ + What kind of pet should I get and what should I name it? - ${gr.complete_json_suffix_v2} - """ + ${gr.complete_json_suffix_v2} + """ - output = """ - { - "pet_type": "dog", - "name": "Fido" - } - """ + output = """ + { + "pet_type": "dog", + "name": "Fido" + } + """ + expected_result = {"pet_type": "dog deprecated", "name": "Fido"} + + with pytest.warns( + DeprecationWarning, + match=re.escape( # Becuase of square brackets in the message + "Specifying a List[FailResult] as the second argument" + " for a custom on_fail handler is deprecated. " + "Please use FailResult instead." + ), + ): + validator: Validator = TwoWords(on_fail=custom_deprecated_on_fail_handler) # type: ignore - validator: Validator = TwoWords(on_fail=custom_reask_func) + class Pet(BaseModel): + pet_type: str = Field(description="Species of pet", validators=[validator]) + name: str = Field(description="a unique pet name") - class Pet(BaseModel): - pet_type: str = Field(description="Species of pet", validators=[validator]) - name: str = Field(description="a unique pet name") + guard = Guard.from_pydantic(output_class=Pet, prompt=prompt) - guard = Guard.from_pydantic(output_class=Pet, prompt=prompt) - if isinstance(expected_result, type) and issubclass(expected_result, Exception): - with pytest.raises(ValidationError) as excinfo: - guard.parse(output, num_reasks=0) - assert str(excinfo.value) == "Something went wrong!" - else: response = guard.parse(output, num_reasks=0) - if isinstance(expected_result, FieldReAsk): - assert guard.history.first.iterations.first.reasks[0] == expected_result - else: - assert response.validated_output == expected_result - + assert response.validation_passed is True + assert response.validated_output == expected_result -class TestCustomOnFailHandler: def test_custom_fix(self): prompt = """ What kind of pet should I get and what should I name it? From d3c8eb402611433e8803b130dde2f5160b8d8b0c Mon Sep 17 00:00:00 2001 From: Caleb Courier Date: Fri, 13 Sep 2024 15:01:26 -0500 Subject: [PATCH 06/13] document custom on fail --- docs/how_to_guides/custom_validators.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/how_to_guides/custom_validators.md b/docs/how_to_guides/custom_validators.md index ee113ae64..5c5446605 100644 --- a/docs/how_to_guides/custom_validators.md +++ b/docs/how_to_guides/custom_validators.md @@ -75,6 +75,7 @@ Validators ship with several out of the box `on_fail` policies. The `OnFailActio | `OnFailAction.NOOP` | Do nothing. The failure will still be recorded in the logs, but no corrective action will be taken. | | `OnFailAction.EXCEPTION` | Raise an exception when validation fails. 
| | `OnFailAction.FIX_REASK` | First, fix the generated output deterministically, and then rerun validation with the deterministically fixed output. If validation fails, then perform reasking. | +| `OnFailAction.CUSTOM` | This action is set internally when the validator is passed a custom function to handle failures. The function is called with the value that failed validation and the FailResult returned from the Validator. i.e. the custom on fail handler must implement the method signature `def on_fail(value: Any, fail_result: FailResult) -> Any` | In the code below, a `fix_value` will be supplied in the `FailResult`. This value will represent a programmatic fix that can be applied to the output if `on_fail='fix'` is passed during validator initialization. ```py From 4badf318ef4fc97e2dadce847e05d07c02546cfc Mon Sep 17 00:00:00 2001 From: Caleb Courier Date: Fri, 13 Sep 2024 15:01:44 -0500 Subject: [PATCH 07/13] dont deserialize twice --- guardrails/validator_service/validator_service_base.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/guardrails/validator_service/validator_service_base.py b/guardrails/validator_service/validator_service_base.py index 5da7e33a6..d580a9ff0 100644 --- a/guardrails/validator_service/validator_service_base.py +++ b/guardrails/validator_service/validator_service_base.py @@ -195,8 +195,7 @@ def merge_results(self, original_value: Any, new_values: list[Any]) -> Any: serialize(current), serialize(nextval), serialize(original_value) ) current = deserialize(original_value, current) - deserialized_value = deserialize(original_value, current) - if deserialized_value is None and current is not None: + if current is None and original_value is not None: # QUESTION: How do we escape hatch # for when deserializing the merged value fails? @@ -205,4 +204,4 @@ def merge_results(self, original_value: Any, new_values: list[Any]) -> Any: # Or just pick one of the new values? 
return new_vals[0] - return deserialized_value + return current From 224dddc9d8a17496aa0cfef63278111ddbf0fde0 Mon Sep 17 00:00:00 2001 From: Caleb Courier Date: Fri, 13 Sep 2024 15:01:53 -0500 Subject: [PATCH 08/13] export ValidationOutcome from main init --- guardrails/__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/guardrails/__init__.py b/guardrails/__init__.py index 7c2f5f255..6dcb563b7 100644 --- a/guardrails/__init__.py +++ b/guardrails/__init__.py @@ -10,6 +10,7 @@ from guardrails.validator_base import Validator, register_validator from guardrails.settings import settings from guardrails.hub.install import install +from guardrails.classes.validation_outcome import ValidationOutcome __all__ = [ "Guard", @@ -25,4 +26,5 @@ "Instructions", "settings", "install", + "ValidationOutcome", ] From c6d052b883332b70aa40c01a9527808180bcb032 Mon Sep 17 00:00:00 2001 From: Caleb Courier Date: Fri, 13 Sep 2024 15:03:36 -0500 Subject: [PATCH 09/13] parallel -> concurrent, document on fail in concurrency --- .../{parallelization.md => concurrency.md} | 137 +++++++++++++++++- docusaurus/sidebars.js | 2 +- 2 files changed, 131 insertions(+), 8 deletions(-) rename docs/concepts/{parallelization.md => concurrency.md} (54%) diff --git a/docs/concepts/parallelization.md b/docs/concepts/concurrency.md similarity index 54% rename from docs/concepts/parallelization.md rename to docs/concepts/concurrency.md index 1c4a069e5..437880cdb 100644 --- a/docs/concepts/parallelization.md +++ b/docs/concepts/concurrency.md @@ -1,11 +1,11 @@ -# Parallelization +# Concurrency ## And the Orchestration of Guard Executions This document is a description of the current implementation of the Guardrails' validation loop. It attempts to explain the current patterns used with some notes on why those patterns were accepted at the time of implementation and potential future optimizations. It is _not_ meant to be prescriptive as there can, and will, be improvements made in future versions. In general you will find that our approach to performance is two fold: 1. Complete computationally cheaper, static checks first and exit early to avoid spending time and resources on more expensive checks that are unlikely to pass when the former fail. -2. Parallelize processing where possible. +2. Run processes concurrently where possible. ## Background: The Validation Loop When a Guard is executed, that is called via `guard()`, `guard.parse()`, `guard.validate()`, etc., it goes through an internal process that has the following steps: @@ -60,7 +60,7 @@ Besides handling asynchronous calls to the LLM, using an `AsyncGuard` also ensur * An asyncio event loop is available. * The asyncio event loop is not taken/already running. -## Validation Orchestration and Parallelization +## Validation Orchestration and Concurrency ### Structured Data Validation We perform validation with a "deep-first" approach. This has no meaning for unstructured text output since there is only one value, but for structured output it means that the objects are validated from the inside out. @@ -79,7 +79,7 @@ Take the below structure as an example: } ``` -As of versions v0.4.x and v0.5.x of Guardrails, the above object would validated as follows: +As of versions v0.4.x and v0.5.x of Guardrails, the above object would be validated as follows: 1. foo.baz 2. foo.bez @@ -88,11 +88,134 @@ As of versions v0.4.x and v0.5.x of Guardrails, the above object would validated 5. bar.buz 6. 
bar - > NOTE: The approach currently used, and outlined above, was predicated on the assumption that if child properties fail validation, it is unlikely that the parent property would pass. With the current atomic state of validation, it can be argued that this assumption is false. That is, the types of validations applied to parent properties typically take the form of checking the appropriate format of the container like a length check on a list. These types of checks are generally independent of any requirements the child properties have. This opens up the possibility of running all six paths listed above in parallel at once instead of performing them in steps based on key path. + > NOTE: The approach currently used, and outlined above, was predicated on the assumption that if child properties fail validation, it is unlikely that the parent property would pass. With the current atomic state of validation, it can be argued that this assumption is false. That is, the types of validations applied to parent properties typically take the form of checking the appropriate format of the container like a length check on a list. These types of checks are generally independent of any requirements the child properties have. This opens up the possibility of running all six paths listed above concurrently instead of performing them in steps based on key path. When synchronous validation occurs as defined in [Benefits of AsyncGuard](#benefits-of-async-guard), the validators for each property would be run in the order they are defined on the schema. That also means that any on fail actions are applied in that same order. -When asynchronous validation occurs, there are multiple levels of parallelization possible. First, running validation on the child properties (e.g. `foo.baz` and `foo.bez`) will happen in parallel via the asyncio event loop. Second, within the validation for each property, if the validators have `run_in_separate_process` set to `True`, they are run in parallel via multiprocessing. This multiprocessing is capped to the process count specified by the `GUARDRAILS_PROCESS_COUNT` environment variable which defaults to 10. Note that some environments, like AWS Lambda, may not support multiprocessing in which case you would need to set this environment variable to 1. +When asynchronous validation occurs, there are multiple levels of concurrency possible. First, running validation on the child properties (e.g. `foo.baz` and `foo.bez`) will happen concurrently via the asyncio event loop. Second, the validators on any given property are also run concurrently via the event loop. For validators that only define a synchronous `validate` method, calls to this method are run in the event loops default executore. Note that some environments, like AWS Lambda, may not support multiprocessing in which case you would need to either set the executor to a thread processor instead or limit validation to running synchronously by setting `GUARDRAILS_PROCESS_COUNT=1` or `GUARDRAILS_RUN_SYNC=true`. ### Unstructured Data Validation -When validating unstructured data, i.e. text, the LLM output is treated the same as if it were a property on an object. This means that the validators applied to is have the ability to run in parallel utilizing multiprocessing when `run_in_separate_process` is set to `True` on the validators. \ No newline at end of file +When validating unstructured data, i.e. text, the LLM output is treated the same as if it were a property on an object. 
This means that the validators applied to it have the ability to run concurrently utilizing the event loop.
+
+### Handling Failures During Async Concurrency
+The Guardrails validation loop is opinionated about how it handles failures when running validators concurrently so that it spends the least amount of time processing an output that would result in a failure. It's behaviour comes down to when and what it returns based on the [corrective action](/how_to_guides/custom_validators#on-fail) specified on a validator. Corrective actions are processed concurrently since they are specific to a given validator on a given property. This means that interuptive corrective actions, namely `EXCEPTION`, will be the first corrective action enforced because the exception is raised as soon as the failure is evaluated. The remaining actions are handled in the following order after all futures are collected from the validation of a specific property:
+1. `FILTER` and `REFRAIN`
+2. `REASK`
+3. `FIX`
+
+ \*_NOTE:_ `NOOP` Does not require any special handling because it does not alter the value.
+
+ \*_NOTE:_ `FIX_REASK` Will fall into either the `REASK` or `FIX` bucket based on if the fixed value passes the second round of validation.
+
+This means that if any valdiator with `on_fail=OnFailAction.EXCEPTION` returns a `FailResult`, then Guardrails will raise a `ValidationError` interrupting the process.
+
+If any validator on a specific property which has `on_fail=OnFailAction.FILTER` or `on_fail=OnFailAction.REFRAIN` returns a `FailResult`, whichever of these is the first to finish will be returned early as the value for that property.
+
+If any validator on a specific property which has `on_fail=OnFailAction.REASK` returns a `FailResult`, all reasks for that property will be merged and a `FieldReAsk` will be returned early as the value for that property.
+
+If any validator on a specific property which has `on_fail=OnFailAction.FIX` returns a `FailResult`, all fix values for that property will be merged and the result of that merge will be returned as the value for that property.
+
+Custom on_fail handlers will fall into one of the above actions based on what they return; i.e. if a handler returns an updated value it's considered a `FIX`, if it returns an instance of `Filter` then `FILTER`, etc.
+
+Let's look at an example. We'll keep the validation logic simple and write out some assertions to demonstrate the evaluation order discussed above.
+ +```py +import asyncio +from random import randint +from typing import Optional +from guardrails import AsyncGuard, ValidationOutcome +from guardrails.errors import ValidationError +from guardrails.validators import ( + Validator, + register_validator, + ValidationResult, + PassResult, + FailResult +) + +@register_validator(name='custom/contains', data_type='string') +class Contains(Validator): + def __init__(self, match_value: str, **kwargs): + super().__init__( + match_value=match_value, + **kwargs + ) + self.match_value = match_value + + def validate(self, value, metadata = {}) -> ValidationResult: + if self.match_value in value: + return PassResult() + + fix_value = None + if self.on_fail_descriptor == 'fix': + # Insert the match_value into the value at a random index + insertion = randint(0, len(value)) + fix_value = f"{value[:insertion]}{self.match_value}{value[insertion:]}" + + return FailResult( + error_message=f'Value must contain {self.match_value}', + fix_value=fix_value + ) + +exception_validator = Contains("a", on_fail='exception') +filter_validator = Contains("b", on_fail='filter') +refrain_validator = Contains("c", on_fail='refrain') +reask_validator_1 = Contains("d", on_fail='reask') +reask_validator_2 = Contains("e", on_fail='reask') +fix_validator_1 = Contains("f", on_fail='fix') +fix_validator_2 = Contains("g", on_fail='fix') + +guard = AsyncGuard().use_many( + exception_validator, + filter_validator, + refrain_validator, + reask_validator_1, + reask_validator_2, + fix_validator_1, + fix_validator_2 +) + +### Trigger the exception validator ### +error = None +result: Optional[ValidationOutcome] = None +try: + result = asyncio.run(guard.validate("z", metadata={})) + +except ValidationError as e: + error = e + +assert result is None +assert error is not None +assert str(error) == "Validation failed for field with errors: Value must contain a" + + + +### Trigger the Filter and Refrain validators ### +result = asyncio.run(guard.validate("a", metadata={})) + +assert result.validation_passed is False +# The output was filtered or refrained +assert result.validated_output is None +assert result.reask is None + + + +### Trigger the Reask validator ### +result = asyncio.run(guard.validate("abc", metadata={})) + +assert result.validation_passed is False +# If allowed, a ReAsk would have occured +assert result.reask is not None +error_messages = [f.error_message for f in result.reask.fail_results] +assert error_messages == ["Value must contain d", "Value must contain e"] + + +### Trigger the Fix validator ### +result = asyncio.run(guard.validate("abcde", metadata={})) + +assert result.validation_passed is True +# The fix values have been merged +assert "f" in result.validated_output +assert "g" in result.validated_output +print(result.validated_output) +``` \ No newline at end of file diff --git a/docusaurus/sidebars.js b/docusaurus/sidebars.js index 1cce74653..ddb5a7034 100644 --- a/docusaurus/sidebars.js +++ b/docusaurus/sidebars.js @@ -66,7 +66,7 @@ const sidebars = { "concepts/streaming_fixes", ], }, - "concepts/parallelization", + "concepts/concurrency", "concepts/logs", "concepts/telemetry", "concepts/error_remediation", From 6b06136351f180786249bfd659eed7f2aef79c64 Mon Sep 17 00:00:00 2001 From: Caleb Courier Date: Mon, 16 Sep 2024 11:41:49 -0500 Subject: [PATCH 10/13] update faq regarding custom fix --- docs/faq.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/faq.md b/docs/faq.md index 5d92c9a44..72e5e8ff9 100644 --- a/docs/faq.md +++ 
b/docs/faq.md @@ -58,7 +58,7 @@ You can override the `fix` behavior by passing it as a function to the Guard obj ```python from guardrails import Guard -def fix_is_cake(value, metadata): +def fix_is_cake(value, fail_result: FailResult): return "IT IS cake" guard = Guard().use(is_cake, on_fail=fix_is_cake) From 435a1806a502430ac4214ea94d35ce88cd1feec8 Mon Sep 17 00:00:00 2001 From: Caleb Courier <13314870+CalebCourier@users.noreply.github.com> Date: Tue, 17 Sep 2024 13:12:05 -0500 Subject: [PATCH 11/13] typo Co-authored-by: dtam --- docs/concepts/concurrency.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/concepts/concurrency.md b/docs/concepts/concurrency.md index 437880cdb..1e9ebd69f 100644 --- a/docs/concepts/concurrency.md +++ b/docs/concepts/concurrency.md @@ -92,7 +92,7 @@ As of versions v0.4.x and v0.5.x of Guardrails, the above object would be valida When synchronous validation occurs as defined in [Benefits of AsyncGuard](#benefits-of-async-guard), the validators for each property would be run in the order they are defined on the schema. That also means that any on fail actions are applied in that same order. -When asynchronous validation occurs, there are multiple levels of concurrency possible. First, running validation on the child properties (e.g. `foo.baz` and `foo.bez`) will happen concurrently via the asyncio event loop. Second, the validators on any given property are also run concurrently via the event loop. For validators that only define a synchronous `validate` method, calls to this method are run in the event loops default executore. Note that some environments, like AWS Lambda, may not support multiprocessing in which case you would need to either set the executor to a thread processor instead or limit validation to running synchronously by setting `GUARDRAILS_PROCESS_COUNT=1` or `GUARDRAILS_RUN_SYNC=true`. +When asynchronous validation occurs, there are multiple levels of concurrency possible. First, running validation on the child properties (e.g. `foo.baz` and `foo.bez`) will happen concurrently via the asyncio event loop. Second, the validators on any given property are also run concurrently via the event loop. For validators that only define a synchronous `validate` method, calls to this method are run in the event loops default executor. Note that some environments, like AWS Lambda, may not support multiprocessing in which case you would need to either set the executor to a thread processor instead or limit validation to running synchronously by setting `GUARDRAILS_PROCESS_COUNT=1` or `GUARDRAILS_RUN_SYNC=true`. ### Unstructured Data Validation When validating unstructured data, i.e. text, the LLM output is treated the same as if it were a property on an object. This means that the validators applied to is have the ability to run concurrently utilizing the event loop. From 70b929a757f9e55d3c71383108b3e173d6e3b842 Mon Sep 17 00:00:00 2001 From: Caleb Courier <13314870+CalebCourier@users.noreply.github.com> Date: Tue, 17 Sep 2024 13:12:30 -0500 Subject: [PATCH 12/13] typos Co-authored-by: dtam --- docs/concepts/concurrency.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/concepts/concurrency.md b/docs/concepts/concurrency.md index 1e9ebd69f..f5c85ca56 100644 --- a/docs/concepts/concurrency.md +++ b/docs/concepts/concurrency.md @@ -98,7 +98,7 @@ When asynchronous validation occurs, there are multiple levels of concurrency po When validating unstructured data, i.e. 
text, the LLM output is treated the same as if it were a property on an object. This means that the validators applied to it have the ability to run concurrently utilizing the event loop.
 
 ### Handling Failures During Async Concurrency
-The Guardrails validation loop is opinionated about how it handles failures when running validators concurrently so that it spends the least amount of time processing an output that would result in a failure. It's behaviour comes down to when and what it returns based on the [corrective action](/how_to_guides/custom_validators#on-fail) specified on a validator. Corrective actions are processed concurrently since they are specific to a given validator on a given property. This means that interuptive corrective actions, namely `EXCEPTION`, will be the first corrective action enforced because the exception is raised as soon as the failure is evaluated. The remaining actions are handled in the following order after all futures are collected from the validation of a specific property:
+The Guardrails validation loop is opinionated about how it handles failures when running validators concurrently so that it spends the least amount of time processing an output that would result in a failure. It's behavior comes down to when and what it returns based on the [corrective action](/how_to_guides/custom_validators#on-fail) specified on a validator. Corrective actions are processed concurrently since they are specific to a given validator on a given property. This means that interruptive corrective actions, namely `EXCEPTION`, will be the first corrective action enforced because the exception is raised as soon as the failure is evaluated. The remaining actions are handled in the following order after all futures are collected from the validation of a specific property:
 1. `FILTER` and `REFRAIN`
 2. `REASK`
 3. `FIX`

From cfff6e1dc15cc94e52a3d51b86562a8edafcc39d Mon Sep 17 00:00:00 2001
From: Caleb Courier <13314870+CalebCourier@users.noreply.github.com>
Date: Tue, 17 Sep 2024 13:12:45 -0500
Subject: [PATCH 13/13] typo

Co-authored-by: dtam
---
 docs/concepts/concurrency.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/docs/concepts/concurrency.md b/docs/concepts/concurrency.md
index f5c85ca56..bbe23fad8 100644
--- a/docs/concepts/concurrency.md
+++ b/docs/concepts/concurrency.md
@@ -107,7 +107,7 @@ The Guardrails validation loop is opinionated about how it handles failures when
 
 \*_NOTE:_ `FIX_REASK` Will fall into either the `REASK` or `FIX` bucket based on if the fixed value passes the second round of validation.
 
-This means that if any valdiator with `on_fail=OnFailAction.EXCEPTION` returns a `FailResult`, then Guardrails will raise a `ValidationError` interrupting the process.
+This means that if any validator with `on_fail=OnFailAction.EXCEPTION` returns a `FailResult`, then Guardrails will raise a `ValidationError` interrupting the process.
 
 If any validator on a specific property which has `on_fail=OnFailAction.FILTER` or `on_fail=OnFailAction.REFRAIN` returns a `FailResult`, whichever of these is the first to finish will be returned early as the value for that property.
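For reference, here is a minimal, self-contained sketch of the single-`FailResult` custom `on_fail` signature introduced in this series. The `custom/uppercase` validator and `fix_uppercase` handler below are illustrative names only, and the import paths mirror the ones used in the documentation examples above; treat both as assumptions rather than part of the patch.

```py
from typing import Any

from guardrails import Guard
from guardrails.validators import (
    FailResult,
    PassResult,
    ValidationResult,
    Validator,
    register_validator,
)


@register_validator(name="custom/uppercase", data_type="string")
class Uppercase(Validator):
    """Toy validator: passes only when the value is fully upper-cased."""

    def validate(self, value: Any, metadata: dict = {}) -> ValidationResult:
        if str(value).isupper():
            return PassResult()
        return FailResult(
            error_message="Value must be upper-case",
            fix_value=str(value).upper(),
        )


# The custom handler now receives a single FailResult rather than a List[FailResult].
def fix_uppercase(value: Any, fail_result: FailResult) -> Any:
    # Returning a replacement value makes the handler behave like a FIX action.
    return fail_result.fix_value or str(value).upper()


guard = Guard().use(Uppercase, on_fail=fix_uppercase)
outcome = guard.parse("hello world")

assert outcome.validation_passed is True
print(outcome.validated_output)  # expected: "HELLO WORLD"
```

Because the handler returns an updated value, it is treated like a `FIX` action in the ordering described above; returning `Filter()`, `Refrain()`, or a `FieldReAsk` instead would slot it into the corresponding bucket.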