From 6fcaa0a1048e3e26ddf7eb988e2e6f75258bf7dd Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Tue, 17 Dec 2024 19:40:40 +0100 Subject: [PATCH] Add type hints to Redis (#3110) --- .../instrumentation/redis/__init__.py | 105 +++++++++++++----- .../instrumentation/redis/py.typed | 0 2 files changed, 79 insertions(+), 26 deletions(-) create mode 100644 instrumentation/opentelemetry-instrumentation-redis/src/opentelemetry/instrumentation/redis/py.typed diff --git a/instrumentation/opentelemetry-instrumentation-redis/src/opentelemetry/instrumentation/redis/__init__.py b/instrumentation/opentelemetry-instrumentation-redis/src/opentelemetry/instrumentation/redis/__init__.py index e81beb6f3d..8a3096ad41 100644 --- a/instrumentation/opentelemetry-instrumentation-redis/src/opentelemetry/instrumentation/redis/__init__.py +++ b/instrumentation/opentelemetry-instrumentation-redis/src/opentelemetry/instrumentation/redis/__init__.py @@ -91,8 +91,9 @@ def response_hook(span, instance, response): --- """ -import typing -from typing import Any, Collection +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Callable, Collection import redis from wrapt import wrap_function_wrapper @@ -109,18 +110,43 @@ def response_hook(span, instance, response): from opentelemetry.instrumentation.redis.version import __version__ from opentelemetry.instrumentation.utils import unwrap from opentelemetry.semconv.trace import SpanAttributes -from opentelemetry.trace import Span, StatusCode +from opentelemetry.trace import Span, StatusCode, Tracer -_DEFAULT_SERVICE = "redis" +if TYPE_CHECKING: + from typing import Awaitable, TypeVar -_RequestHookT = typing.Optional[ - typing.Callable[ - [Span, redis.connection.Connection, typing.List, typing.Dict], None + import redis.asyncio.client + import redis.asyncio.cluster + import redis.client + import redis.cluster + import redis.connection + + _RequestHookT = Callable[ + [Span, redis.connection.Connection, list[Any], dict[str, Any]], None ] -] -_ResponseHookT = typing.Optional[ - typing.Callable[[Span, redis.connection.Connection, Any], None] -] + _ResponseHookT = Callable[[Span, redis.connection.Connection, Any], None] + + AsyncPipelineInstance = TypeVar( + "AsyncPipelineInstance", + redis.asyncio.client.Pipeline, + redis.asyncio.cluster.ClusterPipeline, + ) + AsyncRedisInstance = TypeVar( + "AsyncRedisInstance", redis.asyncio.Redis, redis.asyncio.RedisCluster + ) + PipelineInstance = TypeVar( + "PipelineInstance", + redis.client.Pipeline, + redis.cluster.ClusterPipeline, + ) + RedisInstance = TypeVar( + "RedisInstance", redis.client.Redis, redis.cluster.RedisCluster + ) + R = TypeVar("R") + + +_DEFAULT_SERVICE = "redis" + _REDIS_ASYNCIO_VERSION = (4, 2, 0) if redis.VERSION >= _REDIS_ASYNCIO_VERSION: @@ -132,7 +158,9 @@ def response_hook(span, instance, response): _FIELD_TYPES = ["NUMERIC", "TEXT", "GEO", "TAG", "VECTOR"] -def _set_connection_attributes(span, conn): +def _set_connection_attributes( + span: Span, conn: RedisInstance | AsyncRedisInstance +) -> None: if not span.is_recording() or not hasattr(conn, "connection_pool"): return for key, value in _extract_conn_attributes( @@ -141,7 +169,9 @@ def _set_connection_attributes(span, conn): span.set_attribute(key, value) -def _build_span_name(instance, cmd_args): +def _build_span_name( + instance: RedisInstance | AsyncRedisInstance, cmd_args: tuple[Any, ...] +) -> str: if len(cmd_args) > 0 and cmd_args[0]: if cmd_args[0] == "FT.SEARCH": name = "redis.search" @@ -154,7 +184,9 @@ def _build_span_name(instance, cmd_args): return name -def _build_span_meta_data_for_pipeline(instance): +def _build_span_meta_data_for_pipeline( + instance: PipelineInstance | AsyncPipelineInstance, +) -> tuple[list[Any], str, str]: try: command_stack = ( instance.command_stack @@ -184,11 +216,16 @@ def _build_span_meta_data_for_pipeline(instance): # pylint: disable=R0915 def _instrument( - tracer, - request_hook: _RequestHookT = None, - response_hook: _ResponseHookT = None, + tracer: Tracer, + request_hook: _RequestHookT | None = None, + response_hook: _ResponseHookT | None = None, ): - def _traced_execute_command(func, instance, args, kwargs): + def _traced_execute_command( + func: Callable[..., R], + instance: RedisInstance, + args: tuple[Any, ...], + kwargs: dict[str, Any], + ) -> R: query = _format_command_args(args) name = _build_span_name(instance, args) with tracer.start_as_current_span( @@ -210,7 +247,12 @@ def _traced_execute_command(func, instance, args, kwargs): response_hook(span, instance, response) return response - def _traced_execute_pipeline(func, instance, args, kwargs): + def _traced_execute_pipeline( + func: Callable[..., R], + instance: PipelineInstance, + args: tuple[Any, ...], + kwargs: dict[str, Any], + ) -> R: ( command_stack, resource, @@ -242,7 +284,7 @@ def _traced_execute_pipeline(func, instance, args, kwargs): return response - def _add_create_attributes(span, args): + def _add_create_attributes(span: Span, args: tuple[Any, ...]): _set_span_attribute_if_value( span, "redis.create_index.index", _value_or_none(args, 1) ) @@ -266,7 +308,7 @@ def _add_create_attributes(span, args): field_attribute, ) - def _add_search_attributes(span, response, args): + def _add_search_attributes(span: Span, response, args): _set_span_attribute_if_value( span, "redis.search.index", _value_or_none(args, 1) ) @@ -326,7 +368,12 @@ def _add_search_attributes(span, response, args): _traced_execute_pipeline, ) - async def _async_traced_execute_command(func, instance, args, kwargs): + async def _async_traced_execute_command( + func: Callable[..., Awaitable[R]], + instance: AsyncRedisInstance, + args: tuple[Any, ...], + kwargs: dict[str, Any], + ) -> Awaitable[R]: query = _format_command_args(args) name = _build_span_name(instance, args) @@ -344,7 +391,12 @@ async def _async_traced_execute_command(func, instance, args, kwargs): response_hook(span, instance, response) return response - async def _async_traced_execute_pipeline(func, instance, args, kwargs): + async def _async_traced_execute_pipeline( + func: Callable[..., Awaitable[R]], + instance: AsyncPipelineInstance, + args: tuple[Any, ...], + kwargs: dict[str, Any], + ) -> Awaitable[R]: ( command_stack, resource, @@ -408,14 +460,15 @@ async def _async_traced_execute_pipeline(func, instance, args, kwargs): class RedisInstrumentor(BaseInstrumentor): - """An instrumentor for Redis + """An instrumentor for Redis. + See `BaseInstrumentor` """ def instrumentation_dependencies(self) -> Collection[str]: return _instruments - def _instrument(self, **kwargs): + def _instrument(self, **kwargs: Any): """Instruments the redis module Args: @@ -436,7 +489,7 @@ def _instrument(self, **kwargs): response_hook=kwargs.get("response_hook"), ) - def _uninstrument(self, **kwargs): + def _uninstrument(self, **kwargs: Any): if redis.VERSION < (3, 0, 0): unwrap(redis.StrictRedis, "execute_command") unwrap(redis.StrictRedis, "pipeline") diff --git a/instrumentation/opentelemetry-instrumentation-redis/src/opentelemetry/instrumentation/redis/py.typed b/instrumentation/opentelemetry-instrumentation-redis/src/opentelemetry/instrumentation/redis/py.typed new file mode 100644 index 0000000000..e69de29bb2