diff --git a/pymilvus/client/grpc_handler.py b/pymilvus/client/grpc_handler.py
index cbb55b627..e02ff100c 100644
--- a/pymilvus/client/grpc_handler.py
+++ b/pymilvus/client/grpc_handler.py
@@ -4,7 +4,7 @@
 import socket
 import time
 from pathlib import Path
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Union
 from urllib import parse
 
 import grpc
@@ -88,6 +88,17 @@ def __init__(
         self._setup_db_interceptor(kwargs.get("db_name", None))
         self._setup_grpc_channel()
+        # Callbacks subscribed via register_state_change_callback(); kept so
+        # that close() can unsubscribe exactly those callbacks again.
+        self._state_change_callbacks: List[Callable] = []
 
+    def register_state_change_callback(self, callback: Callable):
+        """Subscribe ``callback`` to connectivity-state changes of the channel.
+
+        The callback is remembered so that ``close()`` can unsubscribe it.
+        """
+        self._state_change_callbacks.append(callback)
+        self._final_channel.subscribe(callback, try_to_connect=True)
+
     def __get_address(self, uri: str, host: str, port: str) -> str:
         if host != "" and port != "" and is_legal_host(host) and is_legal_port(port):
             return f"{host}:{port}"
@@ -141,6 +152,10 @@ def _wait_for_channel_ready(self, timeout: Union[float] = 10):
                 raise e from e
 
     def close(self):
+        # Drop every state-change subscription before closing the channel.
+        for callback in self._state_change_callbacks:
+            self._final_channel.unsubscribe(callback)
+        self._state_change_callbacks.clear()
         self._channel.close()
 
     def reset_db_name(self, db_name: str):
diff --git a/pymilvus/orm/connections.py b/pymilvus/orm/connections.py
index 7826de8f2..bf8db784a 100644
--- a/pymilvus/orm/connections.py
+++ b/pymilvus/orm/connections.py
@@ -11,7 +11,9 @@
 # the License.
 
 import copy
+import logging
 import threading
+import time
 from typing import Callable, Tuple, Union
 from urllib import parse
 
@@ -24,6 +26,8 @@
 )
 from pymilvus.settings import Config
 
+logger = logging.getLogger(__name__)
+
 VIRTUAL_PORT = 443
 
 
@@ -58,6 +62,65 @@ def __new__(cls, *args, **kwargs):
         return super().__new__(cls, *args, **kwargs)
 
 
+class ReconnectHandler:
+    """Re-establish a named connection when its gRPC channel turns idle.
+
+    ``reconnect_on_idle`` is meant to be subscribed as a channel connectivity
+    callback. When the channel reports the idle state, a background thread
+    waits a few seconds and, if the state is still idle, disconnects the alias
+    and reconnects it with the saved keyword arguments.
+    """
+
+    def __init__(self, conns: object, connection_name: str, kwargs: object) -> None:
+        self.connection_name = connection_name
+        self.conns = conns
+        self._kwargs = kwargs
+        self.is_idle_state = False
+        self.reconnect_lock = threading.Lock()
+
+    def check_state_and_reconnect_later(self):
+        """Sleep briefly, then reconnect if the channel is still idle."""
+        check_after_seconds = 3
+        logger.debug(f"state is idle, schedule reconnect in {check_after_seconds} seconds")
+        time.sleep(check_after_seconds)
+        with self.reconnect_lock:
+            # Test the flag under the lock so that, of several scheduled
+            # threads, only one performs the reconnect.
+            if not self.is_idle_state:
+                logger.debug("idle state changed, skip reconnect")
+                return
+            logger.info("reconnect on idle state")
+            self.is_idle_state = False
+            try:
+                logger.debug("try disconnecting old connection...")
+                self.conns.disconnect(self.connection_name)
+            except Exception as e:
+                # Best effort: a failed disconnect must not prevent reconnecting.
+                logger.warning(f"disconnect failed: {e}")
+            while True:
+                try:
+                    logger.debug("try reconnecting...")
+                    self.conns.connect(self.connection_name, **self._kwargs)
+                    break
+                except Exception as e:
+                    logger.warning(
+                        f"reconnect failed: {e}, try again after {check_after_seconds} seconds"
+                    )
+                    time.sleep(check_after_seconds)
+            logger.info("reconnected")
+
+    def reconnect_on_idle(self, state: object):
+        """Record the new channel state and schedule a reconnect when it is idle."""
+        logger.debug(f"state change to: {state}")
+        with self.reconnect_lock:
+            if state.value[1] != "idle":
+                self.is_idle_state = False
+                return
+            self.is_idle_state = True
+            # Daemon thread: the unbounded retry loop must not block interpreter exit.
+            threading.Thread(target=self.check_state_and_reconnect_later, daemon=True).start()
+
+
 class Connections(metaclass=SingleInstanceMetaClass):
     """Class for managing all connections of milvus.  Used as a singleton in this module."""
 
@@ -293,6 +356,11 @@ def connect(
             >>> connections.connect("test", host="localhost", port="19530")
         """
 
+        # kwargs_copy is used for auto reconnect
+        # NOTE(review): only **kwargs is captured here; confirm that any
+        # explicitly named connect() arguments are also replayed on reconnect.
+        kwargs_copy = copy.deepcopy(kwargs)
+
         def connect_milvus(**kwargs):
             gh = GrpcHandler(**kwargs)
 
@@ -300,6 +368,9 @@ def connect_milvus(**kwargs):
             timeout = t if isinstance(t, (int, float)) else Config.MILVUS_CONN_TIMEOUT
             gh._wait_for_channel_ready(timeout=timeout)
 
+            gh.register_state_change_callback(
+                ReconnectHandler(self, alias, kwargs_copy).reconnect_on_idle
+            )
             kwargs.pop("password")
             kwargs.pop("token", None)
             kwargs.pop("secure", None)