diff --git a/pymilvus/client/grpc_handler.py b/pymilvus/client/grpc_handler.py
index 89046dec2..1d6fe4e18 100644
--- a/pymilvus/client/grpc_handler.py
+++ b/pymilvus/client/grpc_handler.py
@@ -64,7 +64,6 @@
     len_of,
 )
 
-
 class GrpcHandler:
     # pylint: disable=too-many-instance-attributes
 
@@ -88,6 +87,13 @@ def __init__(
         self._setup_db_interceptor(kwargs.get("db_name", None))
         self._setup_grpc_channel()
 
+    def register_state_change_callback(self, callback):
+        """Subscribe ``callback`` to state changes of the underlying channel.
+
+        The subscription is made with ``try_to_connect=True``.
+        """
+        self._final_channel.subscribe(callback, try_to_connect=True)
+
     def __get_address(self, uri: str, host: str, port: str) -> str:
         if host != "" and port != "" and is_legal_host(host) and is_legal_port(port):
             return f"{host}:{port}"
diff --git a/pymilvus/orm/connections.py b/pymilvus/orm/connections.py
index 7826de8f2..15f95d49f 100644
--- a/pymilvus/orm/connections.py
+++ b/pymilvus/orm/connections.py
@@ -14,6 +14,8 @@
 import threading
 from typing import Callable, Tuple, Union
 from urllib import parse
+import logging
+import time
 
 from pymilvus.client.check import is_legal_address, is_legal_host, is_legal_port
 from pymilvus.client.grpc_handler import GrpcHandler
@@ -24,6 +26,8 @@
 )
 from pymilvus.settings import Config
 
+logger = logging.getLogger(__name__)
+
 VIRTUAL_PORT = 443
 
 
@@ -57,6 +61,83 @@ def __call__(cls, *args, **kwargs):
     def __new__(cls, *args, **kwargs):
         return super().__new__(cls, *args, **kwargs)
 
+
+class ReconnectHandler:
+    """Reconnect a named connection after its gRPC channel has gone idle.
+
+    ``reconnect_on_idle`` is meant to be registered as a channel
+    connectivity callback. When the channel reports the idle state, a
+    background thread waits ``RECONNECT_INTERVAL`` seconds and, if the
+    channel is still idle, drops the old connection and retries
+    ``conns.connect`` until it succeeds.
+    """
+
+    # Seconds to wait before acting on an idle state and between retries.
+    RECONNECT_INTERVAL = 3
+
+    def __init__(self, conns, connection_name, kwargs) -> None:
+        self.connection_name = connection_name
+        self.conns = conns
+        self._kwargs = kwargs
+        self.is_idle_state = False
+        # Guards is_idle_state, which is shared between the channel
+        # callback and the reconnect worker threads.
+        self.reconnect_lock = threading.Lock()
+
+    def check_state_and_reconnect_later(self):
+        """Wait, then reconnect if the channel is still marked idle."""
+        check_after_seconds = self.RECONNECT_INTERVAL
+        logger.debug("state is idle, schedule reconnect in %s seconds", check_after_seconds)
+        time.sleep(check_after_seconds)
+
+        # Test and clear the flag under the lock so that at most one worker
+        # acts on a given idle notification.
+        with self.reconnect_lock:
+            if not self.is_idle_state:
+                logger.debug("idle state changed, skip reconnect")
+                return
+            self.is_idle_state = False
+
+        # The lock is released before the slow part: reconnect_on_idle also
+        # takes it and must not be blocked while we sleep and retry.
+        logger.info("reconnect on idle state")
+        try:
+            logger.debug("try disconnecting old connection...")
+            self.conns.disconnect(self.connection_name)
+        except Exception as e:
+            # Best effort: a failed disconnect must not prevent reconnecting.
+            logger.warning("disconnect failed: %s", e)
+
+        while True:
+            try:
+                logger.debug("try reconnecting...")
+                self.conns.connect(self.connection_name, **self._kwargs)
+            except Exception as e:
+                logger.warning(
+                    "reconnect failed: %s, try again after %s seconds", e, check_after_seconds
+                )
+                time.sleep(check_after_seconds)
+            else:
+                break
+        logger.info("reconnected")
+
+    def reconnect_on_idle(self, state):
+        """Record the new channel state; schedule a reconnect if it is idle.
+
+        :param state: connectivity value delivered by the channel callback;
+            its ``value[1]`` is compared with ``"idle"``.
+        """
+        logger.debug("state change to: %s", state)
+        is_idle = state.value[1] == "idle"
+        with self.reconnect_lock:
+            self.is_idle_state = is_idle
+        if is_idle:
+            # Daemon thread: the retry loop is unbounded and must not keep
+            # the interpreter alive at exit.
+            threading.Thread(
+                target=self.check_state_and_reconnect_later, daemon=True
+            ).start()
+
 
 class Connections(metaclass=SingleInstanceMetaClass):
     """Class for managing all connections of milvus. Used as a singleton in this module."""
@@ -293,6 +374,12 @@ def connect(
         >>> connections.connect("test", host="localhost", port="19530")
         """
 
+        # Snapshot of the caller's keyword arguments, replayed by
+        # ReconnectHandler when it reconnects this alias.
+        # NOTE(review): this holds only **kwargs; confirm that any explicitly
+        # named connect() parameters the reconnect needs are included too.
+        kwargs_copy = copy.deepcopy(kwargs)
+
         def connect_milvus(**kwargs):
             gh = GrpcHandler(**kwargs)
 
@@ -300,6 +387,9 @@ def connect_milvus(**kwargs):
             timeout = t if isinstance(t, (int, float)) else Config.MILVUS_CONN_TIMEOUT
             gh._wait_for_channel_ready(timeout=timeout)
 
+            # Reconnect this alias automatically if the channel later goes idle.
+            reconnect_handler = ReconnectHandler(self, alias, kwargs_copy)
+            gh.register_state_change_callback(reconnect_handler.reconnect_on_idle)
             kwargs.pop("password")
             kwargs.pop("token", None)
             kwargs.pop("secure", None)