Skip to content

Commit

Permalink
Auto reconnection when channel state changed to idle
Browse files Browse the repository at this point in the history
Signed-off-by: shaoyue.chen <[email protected]>
  • Loading branch information
haorenfsa committed Jan 2, 2024
1 parent 11963c8 commit 96fb4d9
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 1 deletion.
13 changes: 12 additions & 1 deletion pymilvus/client/grpc_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import socket
import time
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
from typing import Any, Callable, Dict, List, Optional, Union
from urllib import parse

import grpc
Expand Down Expand Up @@ -87,6 +87,16 @@ def __init__(
self._set_authorization(**kwargs)
self._setup_db_interceptor(kwargs.get("db_name", None))
self._setup_grpc_channel()
self.callbacks = []

def register_state_change_callback(self, callback: Callable):
self.callbacks.append(callback)
self._channel.subscribe(callback, try_to_connect=True)

def deregister_state_change_callbacks(self):
for callback in self.callbacks:
self._channel.unsubscribe(callback)
self.callbacks = []

def __get_address(self, uri: str, host: str, port: str) -> str:
if host != "" and port != "" and is_legal_host(host) and is_legal_port(port):
Expand Down Expand Up @@ -141,6 +151,7 @@ def _wait_for_channel_ready(self, timeout: Union[float] = 10):
raise e from e

def close(self):
self.deregister_state_change_callbacks()
self._channel.close()

def reset_db_name(self, db_name: str):
Expand Down
64 changes: 64 additions & 0 deletions pymilvus/orm/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@
# the License.

import copy
import logging
import threading
import time
from typing import Callable, Tuple, Union
from urllib import parse

Expand All @@ -24,6 +26,8 @@
)
from pymilvus.settings import Config

logger = logging.getLogger(__name__)

VIRTUAL_PORT = 443


Expand Down Expand Up @@ -58,6 +62,53 @@ def __new__(cls, *args, **kwargs):
return super().__new__(cls, *args, **kwargs)


class ReconnectHandler:
def __init__(self, conns: object, connection_name: str, kwargs: object) -> None:
self.connection_name = connection_name
self.conns = conns
self._kwargs = kwargs
self.is_idle_state = False
self.reconnect_lock = threading.Lock()

def check_state_and_reconnect_later(self):
check_after_seconds = 3
logger.debug(f"state is idle, schedule reconnect in {check_after_seconds} seconds")
time.sleep(check_after_seconds)
if not self.is_idle_state:
logger.debug("idle state changed, skip reconnect")
return
with self.reconnect_lock:
logger.info("reconnect on idle state")
self.is_idle_state = False
try:
logger.debug("try disconnecting old connection...")
self.conns.disconnect(self.connection_name)
except Exception:
logger.warning("disconnect failed: {e}")
finally:
reconnected = False
while not reconnected:
try:
logger.debug("try reconnecting...")
self.conns.connect(self.connection_name, **self._kwargs)
reconnected = True
except Exception as e:
logger.warning(
f"reconnect failed: {e}, try again after {check_after_seconds} seconds"
)
time.sleep(check_after_seconds)
logger.info("reconnected")

def reconnect_on_idle(self, state: object):
logger.debug(f"state change to: {state}")
with self.reconnect_lock:
if state.value[1] != "idle":
self.is_idle_state = False
return
self.is_idle_state = True
threading.Thread(target=self.check_state_and_reconnect_later).start()


class Connections(metaclass=SingleInstanceMetaClass):
"""Class for managing all connections of milvus. Used as a singleton in this module."""

Expand Down Expand Up @@ -270,6 +321,8 @@ def connect(
Optional. Serving as the key for identification and authentication purposes.
Whenever a token is furnished, we shall supplement the corresponding header
to each RPC call.
* *keep_alive* (``bool``) --
Optional. Default is false. If set to true, client will keep an alive connection.
* *db_name* (``str``) --
Optional. default database name of this connection
* *client_key_path* (``str``) --
Expand All @@ -293,13 +346,24 @@ def connect(
>>> connections.connect("test", host="localhost", port="19530")
"""

# kwargs_copy is used for auto reconnect
kwargs_copy = copy.deepcopy(kwargs)
kwargs_copy["user"] = user
kwargs_copy["password"] = password
kwargs_copy["db_name"] = db_name
kwargs_copy["token"] = token

def connect_milvus(**kwargs):
gh = GrpcHandler(**kwargs)

t = kwargs.get("timeout")
timeout = t if isinstance(t, (int, float)) else Config.MILVUS_CONN_TIMEOUT

gh._wait_for_channel_ready(timeout=timeout)
if kwargs.get("keep_alive", False):
gh.register_state_change_callback(
ReconnectHandler(self, alias, kwargs_copy).reconnect_on_idle
)
kwargs.pop("password")
kwargs.pop("token", None)
kwargs.pop("secure", None)
Expand Down

0 comments on commit 96fb4d9

Please sign in to comment.