Skip to content

Commit

Permalink
Update transport.py
Browse files Browse the repository at this point in the history
  • Loading branch information
saimedhi authored Mar 18, 2024
1 parent 1047f0b commit fc9b216
Showing 1 changed file with 35 additions and 10 deletions.
45 changes: 35 additions & 10 deletions opensearchpy/transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
from itertools import chain
from typing import Any, Callable, Collection, Dict, List, Mapping, Optional, Type, Union

from opensearchpy.metrics import Metrics

from .connection import Connection, Urllib3HttpConnection
from .connection_pool import ConnectionPool, DummyConnectionPool, EmptyConnectionPool
from .exceptions import (
Expand Down Expand Up @@ -91,6 +93,7 @@ class Transport(object):
last_sniff: float
sniff_timeout: Optional[float]
host_info_callback: Any
metrics: Type[Metrics]

def __init__(
self,
Expand All @@ -112,6 +115,7 @@ def __init__(
retry_on_status: Collection[int] = (502, 503, 504),
retry_on_timeout: bool = False,
send_get_body_as: str = "GET",
metrics: Type[Metrics] = None,
**kwargs: Any
) -> None:
"""
Expand Down Expand Up @@ -153,6 +157,7 @@ def __init__(
when creating and instance unless overridden by that connection's
options provided as part of the hosts parameter.
"""
self.metrics = metrics
if connection_class is None:
connection_class = self.DEFAULT_CONNECTION_CLASS

Expand Down Expand Up @@ -242,7 +247,7 @@ def _create_connection(host: Any) -> Any:
kwargs.update(host)
if self.pool_maxsize and isinstance(self.pool_maxsize, int):
kwargs["pool_maxsize"] = self.pool_maxsize
return self.connection_class(**kwargs)
return self.connection_class(metrics=self.metrics, **kwargs)

connections = list(zip(map(_create_connection, hosts), hosts))
if len(connections) == 1:
Expand Down Expand Up @@ -405,15 +410,31 @@ def perform_request(
connection = self.get_connection()

try:
status, headers_response, data = connection.perform_request(
method,
url,
params,
body,
headers=headers,
ignore=ignore,
timeout=timeout,
)
if self.metrics:
(
status,
headers_response,
data,
service_time,
) = connection.perform_request(
method,
url,
params,
body,
headers=headers,
ignore=ignore,
timeout=timeout,
)
else:
status, headers_response, data = connection.perform_request(
method,
url,
params,
body,
headers=headers,
ignore=ignore,
timeout=timeout,
)

# Lowercase all the header names for consistency in accessing them.
headers_response = {
Expand Down Expand Up @@ -457,6 +478,10 @@ def perform_request(
data = self.deserializer.loads(
data, headers_response.get("content-type")
)

if self.metrics:
data["client_metrics"] = {"service_time": service_time}

return data

def close(self) -> Any:
Expand Down

0 comments on commit fc9b216

Please sign in to comment.