Skip to content

Commit

Permalink
Record spans and add tests for Vertex AI instrumentation
Browse files Browse the repository at this point in the history
  • Loading branch information
aabmass committed Jan 16, 2025
1 parent d2ae60f commit 105cf3c
Show file tree
Hide file tree
Showing 9 changed files with 746 additions and 40 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,17 @@

from typing import Any, Collection

from wrapt import (
wrap_function_wrapper, # type: ignore[reportUnknownVariableType]
)

from opentelemetry._events import get_event_logger
from opentelemetry.instrumentation.instrumentor import BaseInstrumentor
from opentelemetry.instrumentation.vertexai.package import _instruments
from opentelemetry.instrumentation.vertexai.patch import (
generate_content_create,
)
from opentelemetry.instrumentation.vertexai.utils import is_content_enabled
from opentelemetry.semconv.schemas import Schemas
from opentelemetry.trace import get_tracer

Expand All @@ -55,20 +63,29 @@ def instrumentation_dependencies(self) -> Collection[str]:
def _instrument(self, **kwargs: Any):
"""Enable VertexAI instrumentation."""
tracer_provider = kwargs.get("tracer_provider")
_tracer = get_tracer(
tracer = get_tracer(
__name__,
"",
tracer_provider,
schema_url=Schemas.V1_28_0.value,
)
event_logger_provider = kwargs.get("event_logger_provider")
_event_logger = get_event_logger(
event_logger = get_event_logger(
__name__,
"",
schema_url=Schemas.V1_28_0.value,
event_logger_provider=event_logger_provider,
)
# TODO: implemented in later PR

wrap_function_wrapper(
module="vertexai.generative_models._generative_models",
# Patching this base class also instruments the vertexai.preview.generative_models
# package
name="_GenerativeModel.generate_content",
wrapper=generate_content_create(
tracer, event_logger, is_content_enabled()
),
)

    def _uninstrument(self, **kwargs: Any) -> None:
        """Disable VertexAI instrumentation.

        TODO: to be implemented in a later PR (should unwrap
        ``_GenerativeModel.generate_content``).
        """
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,104 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import TYPE_CHECKING, Any, Callable, Iterable, Optional

from opentelemetry._events import EventLogger
from opentelemetry.instrumentation.vertexai.utils import (
GenerateContentParams,
get_genai_request_attributes,
get_span_name,
handle_span_exception,
)
from opentelemetry.trace import SpanKind, Tracer

if TYPE_CHECKING:
from vertexai.generative_models import (
GenerationResponse,
Tool,
ToolConfig,
)
from vertexai.generative_models._generative_models import (
ContentsType,
GenerationConfigType,
SafetySettingsType,
_GenerativeModel,
)


def generate_content_create(
    tracer: Tracer, event_logger: EventLogger, capture_content: bool
):
    """Wrap the `generate_content` method of the `GenerativeModel` class to trace it."""

    def extract_params(
        contents: ContentsType,
        *,
        generation_config: Optional[GenerationConfigType] = None,
        safety_settings: Optional[SafetySettingsType] = None,
        tools: Optional[list[Tool]] = None,
        tool_config: Optional[ToolConfig] = None,
        labels: Optional[dict[str, str]] = None,
        stream: bool = False,
    ) -> GenerateContentParams:
        # Mirrors the exact parameter signature of generate_content so that
        # positional and keyword invocations are normalized identically.
        return GenerateContentParams(
            contents=contents,
            generation_config=generation_config,
            safety_settings=safety_settings,
            tools=tools,
            tool_config=tool_config,
            labels=labels,
            stream=stream,
        )

    def traced_method(
        wrapped: Callable[
            ..., GenerationResponse | Iterable[GenerationResponse]
        ],
        instance: _GenerativeModel,
        args: Any,
        kwargs: Any,
    ):
        request_params = extract_params(*args, **kwargs)
        attributes = get_genai_request_attributes(instance, request_params)

        # end_on_exit=False: the span is ended explicitly below so that a
        # future streaming wrapper (see TODO) can keep it open past this frame.
        with tracer.start_as_current_span(
            name=get_span_name(attributes),
            kind=SpanKind.CLIENT,
            attributes=attributes,
            end_on_exit=False,
        ) as span:
            # TODO: emit request events
            # if span.is_recording():
            #     for message in kwargs.get("messages", []):
            #         event_logger.emit(
            #             message_to_event(message, capture_content)
            #         )

            try:
                response = wrapped(*args, **kwargs)
                # TODO: handle streaming
                # if is_streaming(kwargs):
                #     return StreamWrapper(
                #         result, span, event_logger, capture_content
                #     )

                # TODO: add response attributes and events
                # if span.is_recording():
                #     _set_response_attributes(
                #         span, result, event_logger, capture_content
                #     )
                span.end()
                return response

            except Exception as error:
                handle_span_exception(span, error)
                raise

    return traced_method
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
# Copyright The OpenTelemetry Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from dataclasses import dataclass
from os import environ
from typing import (
TYPE_CHECKING,
Dict,
List,
Mapping,
Optional,
TypedDict,
cast,
)

from opentelemetry.semconv._incubating.attributes import (
gen_ai_attributes as GenAIAttributes,
)
from opentelemetry.semconv.attributes import (
error_attributes as ErrorAttributes,
)
from opentelemetry.trace import Span
from opentelemetry.trace.status import Status, StatusCode
from opentelemetry.util.types import AttributeValue

if TYPE_CHECKING:
from vertexai.generative_models import Tool, ToolConfig
from vertexai.generative_models._generative_models import (
ContentsType,
GenerationConfigType,
SafetySettingsType,
_GenerativeModel,
)


@dataclass(frozen=True)
class GenerateContentParams:
    """Normalized arguments of ``GenerativeModel.generate_content``.

    Mirrors the method's signature. Optional arguments carry the same
    defaults the wrapped method uses (``None``, ``stream=False``) so an
    instance can be built from a call that supplies only ``contents``.
    """

    contents: ContentsType
    generation_config: Optional[GenerationConfigType] = None
    safety_settings: Optional[SafetySettingsType] = None
    tools: Optional[List["Tool"]] = None
    tool_config: Optional["ToolConfig"] = None
    labels: Optional[Dict[str, str]] = None
    stream: bool = False


class GenerationConfigDict(TypedDict, total=False):
    """Loose dict form of a vertexai generation config (all keys optional).

    Declares the subset of fields this instrumentation reads when building
    span attributes; ``total=False`` because every key may be absent.
    """

    temperature: Optional[float]
    top_p: Optional[float]
    top_k: Optional[int]
    max_output_tokens: Optional[int]
    stop_sequences: Optional[List[str]]
    presence_penalty: Optional[float]
    frequency_penalty: Optional[float]
    seed: Optional[int]
    # And more fields which aren't needed yet


def get_genai_request_attributes(
    # TODO: use types
    instance: _GenerativeModel,
    params: GenerateContentParams,
    operation_name: GenAIAttributes.GenAiOperationNameValues = GenAIAttributes.GenAiOperationNameValues.CHAT,
):
    """Build GenAI semconv request attributes for a generate_content call.

    Values absent from the resolved generation config come back as ``None``
    and are filtered out before returning.
    """
    request_model = _get_model_name(instance)
    config = _get_generation_config(instance, params)

    candidate_attributes = {
        GenAIAttributes.GEN_AI_OPERATION_NAME: operation_name.value,
        GenAIAttributes.GEN_AI_SYSTEM: GenAIAttributes.GenAiSystemValues.VERTEX_AI.value,
        GenAIAttributes.GEN_AI_REQUEST_MODEL: request_model,
        GenAIAttributes.GEN_AI_REQUEST_TEMPERATURE: config.get(
            "temperature"
        ),
        GenAIAttributes.GEN_AI_REQUEST_TOP_P: config.get("top_p"),
        GenAIAttributes.GEN_AI_REQUEST_MAX_TOKENS: config.get(
            "max_output_tokens"
        ),
        GenAIAttributes.GEN_AI_REQUEST_PRESENCE_PENALTY: config.get(
            "presence_penalty"
        ),
        GenAIAttributes.GEN_AI_REQUEST_FREQUENCY_PENALTY: config.get(
            "frequency_penalty"
        ),
        GenAIAttributes.GEN_AI_OPENAI_REQUEST_SEED: config.get("seed"),
        GenAIAttributes.GEN_AI_REQUEST_STOP_SEQUENCES: config.get(
            "stop_sequences"
        ),
    }

    # Drop attributes whose value is absent so they are not recorded as None.
    return {
        key: value
        for key, value in candidate_attributes.items()
        if value is not None
    }


def _get_generation_config(
    instance: _GenerativeModel,
    params: GenerateContentParams,
) -> GenerationConfigDict:
    """Resolve the effective generation config as a plain dict.

    A per-request config takes precedence over the model's default; when
    neither is set, an empty dict is returned.
    """
    config = params.generation_config or instance._generation_config
    if config is None:
        return {}
    if isinstance(config, dict):
        # Caller already supplied a plain mapping.
        return cast(GenerationConfigDict, config)
    # A GenerationConfig object; use its dict representation.
    return cast(GenerationConfigDict, config.to_dict())


_RESOURCE_PREFIX = "publishers/google/models/"


def _get_model_name(instance: _GenerativeModel) -> str:
    """Return the bare model name, stripping the publisher resource prefix."""
    full_name = instance._model_name

    # Can use str.removeprefix() once 3.8 is dropped
    if not full_name.startswith(_RESOURCE_PREFIX):
        return full_name
    return full_name[len(_RESOURCE_PREFIX) :]


# TODO: Everything below here should be replaced with
# opentelemetry.instrumentation.genai_utils instead once it is released.
# https://github.com/open-telemetry/opentelemetry-python-contrib/issues/3191

OTEL_INSTRUMENTATION_GENAI_CAPTURE_MESSAGE_CONTENT = (
    "OTEL_INSTRUMENTATION_GENAI_CAPTURE_MESSAGE_CONTENT"
)


def is_content_enabled() -> bool:
    """Return True when the opt-in env var is set to "true" (case-insensitive).

    Defaults to False when the variable is unset or holds any other value.
    """
    flag = environ.get(
        OTEL_INSTRUMENTATION_GENAI_CAPTURE_MESSAGE_CONTENT, "false"
    ).lower()
    return flag == "true"


def get_span_name(span_attributes: Mapping[str, AttributeValue]):
    """Build the "{operation} {model}" span name from request attributes.

    Missing attributes contribute an empty string rather than raising.
    """
    operation = span_attributes.get(GenAIAttributes.GEN_AI_OPERATION_NAME, "")
    request_model = span_attributes.get(
        GenAIAttributes.GEN_AI_REQUEST_MODEL, ""
    )
    return f"{operation} {request_model}"


def handle_span_exception(span: Span, error: Exception):
    """Record *error* on *span*, set ERROR status, and end the span."""
    span.set_status(Status(StatusCode.ERROR, str(error)))
    if span.is_recording():
        error_type = type(error).__qualname__
        span.set_attribute(ErrorAttributes.ERROR_TYPE, error_type)
    span.end()
Loading

0 comments on commit 105cf3c

Please sign in to comment.