Skip to content

Commit

Permalink
Merge pull request #1164 from guardrails-ai/temp-fix-async-mlflow
Browse files Browse the repository at this point in the history
Hotfix for MLFlow validator spans during async execution
  • Loading branch information
dtam authored Nov 14, 2024
2 parents 358cdbf + ee3fdc1 commit 93a6a36
Show file tree
Hide file tree
Showing 2 changed files with 128 additions and 0 deletions.
75 changes: 75 additions & 0 deletions guardrails/integrations/databricks/ml_flow_instrumentor.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,12 @@ def instrument(self):
export.validate
)
setattr(export, "validate", wrapped_validator_validate)

wrapped_validator_async_validate = (
self._instrument_validator_async_validate(export.async_validate)
)
setattr(export, "async_validate", wrapped_validator_async_validate)

setattr(guardrails.hub, validator_name, export) # type: ignore

def _instrument_guard(
Expand Down Expand Up @@ -387,6 +393,14 @@ def trace_validator_wrapper(*args, **kwargs):
init_kwargs = validator_self._kwargs

validator_span_name = f"{validator_name}.validate"

# Skip this instrumentation in the case of async
# when the parent span cannot be fetched from the current context
# because Validator.validate is running in a ThreadPoolExecutor
parent_span = mlflow.get_current_active_span()
if not parent_span:
return validator_validate(*args, **kwargs)

with mlflow.start_span(
name=validator_span_name,
span_type="validator",
Expand Down Expand Up @@ -425,3 +439,64 @@ def trace_validator_wrapper(*args, **kwargs):
raise e

return trace_validator_wrapper

def _instrument_validator_async_validate(
    self,
    validator_async_validate: Callable[..., Coroutine[Any, Any, ValidationResult]],
):
    """Wrap a validator's ``async_validate`` so each call runs in an MLflow span.

    Args:
        validator_async_validate: The coroutine function to instrument. It is
            awaited with exactly the arguments the wrapper receives.

    Returns:
        An async wrapper (decorated with ``functools.wraps``) that opens a
        ``validator`` span around the awaited call, records validator
        attributes on that span, and returns the wrapped call's result.
        Exceptions raised by the wrapped call are re-raised after the span
        has been marked with ``SpanStatusCode.ERROR``.
    """

    @wraps(validator_async_validate)
    async def trace_async_validator_wrapper(*args, **kwargs):
        # Fallback metadata, used when the first positional argument is not
        # a Validator instance.
        validator_name = "validator"
        obj_id = id(validator_async_validate)
        on_fail_descriptor = "unknown"
        init_kwargs = {}
        validation_session_id = "unknown"

        # Tolerate calls with no positional arguments instead of raising
        # IndexError before the wrapped coroutine is even invoked.
        validator_self = args[0] if args else None
        if isinstance(validator_self, Validator):
            validator_name = validator_self.rail_alias
            obj_id = id(validator_self)
            on_fail_descriptor = validator_self.on_fail_descriptor
            init_kwargs = validator_self._kwargs

        # The span name keeps the ".validate" suffix; the "async" attribute
        # below is what marks the span as coming from the async path.
        validator_span_name = f"{validator_name}.validate"

        with mlflow.start_span(
            name=validator_span_name,
            span_type="validator",
            attributes={
                "guardrails.version": GUARDRAILS_VERSION,
                "type": "guardrails/guard/step/validator",
                "async": True,
            },
        ) as validator_span:

            def record_attributes(result):
                # Single place that attaches validator metadata to the span,
                # shared by the success and failure paths.
                add_validator_attributes(
                    *args,
                    validator_span=validator_span,  # type: ignore
                    validator_name=validator_name,
                    obj_id=obj_id,
                    on_fail_descriptor=on_fail_descriptor,
                    result=result,
                    init_kwargs=init_kwargs,
                    validation_session_id=validation_session_id,
                    **kwargs,
                )

            try:
                resp = await validator_async_validate(*args, **kwargs)
            except Exception:
                validator_span.set_status(status=SpanStatusCode.ERROR)
                record_attributes(None)
                # Bare raise re-raises the active exception unchanged.
                raise
            else:
                # Recorded outside the try body so a failure while recording
                # attributes is not handled as a validator failure (which
                # would record the attributes a second time).
                record_attributes(resp)
                return resp

    return trace_async_validator_wrapper
Original file line number Diff line number Diff line change
Expand Up @@ -588,6 +588,10 @@ async def test__instrument_async_runner_call(self, mocker):

def test__instrument_validator_validate(self, mocker):
mock_span = MockSpan()
mock_start_span = mocker.patch(
"guardrails.integrations.databricks.ml_flow_instrumentor.mlflow.get_current_active_span",
return_value=mock_span,
)
mock_start_span = mocker.patch(
"guardrails.integrations.databricks.ml_flow_instrumentor.mlflow.start_span",
return_value=mock_span,
Expand Down Expand Up @@ -630,3 +634,52 @@ def test__instrument_validator_validate(self, mocker):
init_kwargs={},
validation_session_id="unknown",
)

@pytest.mark.asyncio
async def test__instrument_validator_async_validate(self, mocker):
    """Check the span and attribute calls made by the async validate wrapper.

    ``mlflow.start_span`` and ``add_validator_attributes`` are patched, the
    wrapper built from ``MockValidator.async_validate`` is awaited once, and
    both patches are asserted to have been called exactly once with the
    expected arguments.
    """
    span = MockSpan()
    start_span_patch = mocker.patch(
        "guardrails.integrations.databricks.ml_flow_instrumentor.mlflow.start_span",
        return_value=span,
    )
    attributes_patch = mocker.patch(
        "guardrails.integrations.databricks.ml_flow_instrumentor.add_validator_attributes"
    )

    from guardrails.integrations.databricks import MlFlowInstrumentor
    from tests.unit_tests.mocks.mock_hub import MockValidator

    instrumentor = MlFlowInstrumentor("mock experiment")
    traced_async_validate = instrumentor._instrument_validator_async_validate(
        MockValidator.async_validate
    )

    validator = MockValidator()
    # Positional arguments handed to the wrapper; the same values are
    # expected to be forwarded to add_validator_attributes.
    call_args = (validator, True, {})

    result = await traced_async_validate(*call_args)

    expected_span_attributes = {
        "guardrails.version": GUARDRAILS_VERSION,
        "type": "guardrails/guard/step/validator",
        "async": True,
    }
    start_span_patch.assert_called_once_with(
        name="mock-validator.validate",
        span_type="validator",
        attributes=expected_span_attributes,
    )

    # This is the call made inside the wrapper, not the awaited call above.
    attributes_patch.assert_called_once_with(
        *call_args,
        validator_span=span,  # type: ignore
        validator_name="mock-validator",
        obj_id=id(validator),
        on_fail_descriptor="exception",
        result=result,
        init_kwargs={},
        validation_session_id="unknown",
    )

0 comments on commit 93a6a36

Please sign in to comment.