diff --git a/midealocal/device.py b/midealocal/device.py index e6f9a78b..67d3c4c8 100644 --- a/midealocal/device.py +++ b/midealocal/device.py @@ -150,7 +150,7 @@ def fetch_v2_message(msg: bytes) -> tuple[list, bytes]: break return result, msg - def connect(self) -> bool: + def connect(self, init: bool = False, reconnect: bool = False) -> bool: """Connect to device.""" connected = False try: @@ -164,22 +164,35 @@ def connect(self) -> bool: ) self._socket.connect((self._ip_address, self._port)) _LOGGER.debug("[%s] Connected", self._device_id) + if self._protocol == ProtocolVersion.V3: + self.authenticate() + # 1. midea_ac_lan add device verify token with connect and auth + # 2. init connection, check_protocol + # 3. reconnect, skip check_protocol + if reconnect or init: + self.refresh_status(check_protocol=init) + if init: + self.get_capabilities() connected = True except TimeoutError: _LOGGER.debug("[%s] Connection timed out", self._device_id) - # set _socket to None when connect exception matched self._socket = None - except OSError: + except OSError: # refresh_status exception _LOGGER.debug("[%s] Connection error", self._device_id) - # set _socket to None when connect exception matched self._socket = None + except AuthException: # authenticate exception + _LOGGER.debug("[%s] Authentication failed", self._device_id) + except SocketException: # refresh_status exception + _LOGGER.debug("[%s] Connect socket exception", self._device_id) + self._socket = None + except NoSupportedProtocol: # refresh_status exception + _LOGGER.debug("[%s] No supported query protocol", self._device_id) except Exception as e: _LOGGER.exception( "[%s] Unknown error during connect device", self._device_id, exc_info=e, ) - # set _socket to None when connect exception matched self._socket = None self.set_available(connected) return connected @@ -187,21 +200,16 @@ def connect(self) -> bool: def authenticate(self) -> None: """Authenticate to device. V3 only.""" request = self._security.encode_8370(self._token, MSGTYPE_HANDSHAKE_REQUEST) - _LOGGER.debug("[%s] Authentication handshaking", self._device_id) if not self._socket: - _LOGGER.debug("[%s] socket is None, close and return", self._device_id) - self.enable_device(False) - return - try: - self._socket.send(request) - response = self._socket.recv(512) - except Exception as e: - _LOGGER.exception( - "[%s] authenticate Unexpected socket error", + _LOGGER.debug( + "[%s] authenticate failure, device socket is none", self._device_id, - exc_info=e, ) - self.close_socket() + # raise exception to connect loop + raise SocketException + _LOGGER.debug("[%s] Authentication handshaking", self._device_id) + self._socket.send(request) + response = self._socket.recv(512) _LOGGER.debug( "[%s] Received auth response with %d bytes: %s", self._device_id, @@ -209,7 +217,6 @@ def authenticate(self) -> None: response.hex(), ) if len(response) < MIN_AUTH_RESPONSE: - self.enable_device(False) _LOGGER.debug( "[%s] Received auth response len %d error, bytes: %s", self._device_id, @@ -230,24 +237,59 @@ def send_message(self, data: bytes, query: bool = False) -> None: def send_message_v2(self, data: bytes, query: bool = False) -> None: """Send message V2.""" - if self._socket is not None: - try: - if query: - self._socket.settimeout(QUERY_TIMEOUT) - self._socket.send(data) - except Exception as e: - _LOGGER.exception( - "[%s] send_message_v2 Unexpected socket error", - self._device_id, - exc_info=e, - ) - self.close_socket() - else: + if not self._socket: _LOGGER.debug( - "[%s] Send failure, device socket is none, data: %s", + "[%s] send_message_v2 failure, device socket is none, data: %s", self._device_id, data.hex(), ) + # raise exception to main loop + raise SocketException + try: + _LOGGER.debug( + "[%s] send_message_v2 with data %s", + self._device_id, + data.hex(), + ) + # query msg, set timeout to QUERY_TIMEOUT + if query: + self._socket.settimeout(QUERY_TIMEOUT) + self._socket.send(data) + _LOGGER.debug( + "[%s] send_message_v2 success", + self._device_id, + ) + except TimeoutError: + _LOGGER.debug( + "[%s] send_message_v2 timed out", + self._device_id, + ) + # raise exception to main loop + raise + except ConnectionResetError as e: + _LOGGER.debug( + "[%s] send_message_v2 ConnectionResetError: %s", + self._device_id, + e, + ) + # raise exception to main loop + raise + except OSError as e: + _LOGGER.debug( + "[%s] send_message_v2 OSError: %s", + self._device_id, + e, + ) + # raise exception to main loop + raise + except Exception as e: + _LOGGER.exception( + "[%s] send_message_v2 Unexpected socket error", + self._device_id, + exc_info=e, + ) + # raise exception to main loop + raise def send_message_v3( self, @@ -262,9 +304,18 @@ def send_message_v3( def build_send(self, cmd: MessageRequest, query: bool = False) -> None: """Serialize and send.""" data = cmd.serialize() - _LOGGER.debug("[%s] Sending: %s", self._device_id, cmd) + _LOGGER.debug("[%s] Sending: %s, query is %s", self._device_id, cmd, query) msg = PacketBuilder(self._device_id, data).finalize() self.send_message(msg, query=query) + # after send set command, force refresh_status + if cmd.message_type == MessageType.set: + _LOGGER.debug( + "[%s] Force refresh after set status to: %s", + self._device_id, + cmd, + ) + now = time.time() + self._previous_refresh = now - self._refresh_interval def get_capabilities(self) -> None: """Get device capabilities.""" @@ -272,111 +323,55 @@ def get_capabilities(self) -> None: for cmd in cmds: self.build_send(cmd) - def _recv_message( - self, - check_protocol: bool = False, - ) -> dict[str, MessageResult | bytes]: - """Recv message.""" - # already connected and socket error - if not self._socket: - _LOGGER.debug("[%s] _recv_message socket error, reconnect", self._device_id) - raise SocketException - try: - msg = self._socket.recv(512) - if len(msg) == 0: - _LOGGER.warning("[%s] Empty msg received", self._device_id) - return {"result": MessageResult.PADDING} - if msg: - return {"result": MessageResult.SUCCESS, "msg": msg} - except TimeoutError: - _LOGGER.debug( - "[%s] _recv_message Socket timed out with check_protocol %s", - self._device_id, - check_protocol, - ) - # close socket when timeout and not check_protocol - if not check_protocol: - self.close_socket() - return {"result": MessageResult.TIMEOUT} - except Exception as e: - _LOGGER.exception( - "[%s] Unexpected socket error", - self._device_id, - exc_info=e, - ) - # close socket when exception matched - self.close_socket() - return {"result": MessageResult.UNEXPECTED} - return {"result": MessageResult.UNKNOWN} # Add a fallback return - def refresh_status(self, check_protocol: bool = False) -> None: """Refresh device status.""" cmds: list = self.build_query() if self._appliance_query: cmds = [MessageQueryAppliance(self.device_type), *cmds] error_count = 0 + _LOGGER.debug( + "[%s] refresh_status with cmds: %s, check_protocol %s", + self._device_id, + cmds, + check_protocol, + ) for cmd in cmds: if cmd.__class__.__name__ not in self._unsupported_protocol: - # catch socket exception and continue + # set socket QUERY_TIMEOUT for query msg + # build_send exception should be catch by connect/run + self.build_send(cmd, query=True) try: - # set query flag for query timeout - self.build_send(cmd, query=True) - # recv socket message send query - response = self._recv_message(check_protocol=check_protocol) - # recovery timeout after _recv_message is success/padding - self._recovery_timeout() - except SocketException: - _LOGGER.debug( - "[%s] refresh_status socket error, close and reconnect", - self._device_id, - ) - self.close_socket() - break - # normal msg - if response.get("result") == MessageResult.SUCCESS: - if response.get("msg"): - # parse response - msg = response.get("msg") - if isinstance(msg, bytes): - result = self.parse_message(msg=msg) - if result != MessageResult.SUCCESS: - _LOGGER.error( - "[%s] parse_message %s result is %s", - self._device_id, - msg, - result, - ) - # empty msg - elif response.get("result") == MessageResult.PADDING: - continue - # timeout msg - elif response.get("result") == MessageResult.TIMEOUT: + while True: + if not self._socket: + _LOGGER.debug( + "[%s] authenticate failure, device socket is none", + self._device_id, + ) + # raise exception to connect/main loop + raise SocketException + msg = self._socket.recv(512) + if len(msg) == 0: + raise OSError("Empty message received.") + result = self.parse_message(msg) + # Prevent infinite loop + if result == MessageResult.SUCCESS: + break + elif result == MessageResult.PADDING: # noqa: RET508 + continue + else: + raise ResponseException # noqa: TRY301 + # recovery SOCKET_TIMEOUT after recv msg + self._socket.settimeout(SOCKET_TIMEOUT) + # only catch TimoutError for check_protocol + # unexpected exception in recv/settimeout, catch by main loop + except TimeoutError: _LOGGER.debug( "[%s] protocol %s, cmd %s, timeout", self._device_id, cmd.__class__.__name__, cmd, ) - # init connection, add timeout protocol to unsupported list - if check_protocol: - error_count += 1 - self._unsupported_protocol.append(cmd.__class__.__name__) - _LOGGER.debug( - "[%s] Does not supports the protocol %s, cmd %s, ignored", - self._device_id, - cmd.__class__.__name__, - cmd, - ) - # exception msg - else: - _LOGGER.debug( - "[%s] protocol %s, cmd %s, response exception %s", - self._device_id, - cmd.__class__.__name__, - cmd, - response, - ) - # init connection, add exception protocol to unsupported list + # init check_protocol, skip timeout exception if check_protocol: error_count += 1 self._unsupported_protocol.append(cmd.__class__.__name__) @@ -386,14 +381,22 @@ def refresh_status(self, check_protocol: bool = False) -> None: cmd.__class__.__name__, cmd, ) - # init connection and all the query failed, raise error - if check_protocol and error_count == len(cmds): - _LOGGER.debug( - "[%s] all the query cmds failed %s, please report bug", - self._device_id, - cmds, - ) - raise NoSupportedProtocol + # refresh_status, raise timeout exception to main loop + else: + raise + except ResponseException: + # parse msg error + error_count += 1 + else: + error_count += 1 + # init check_protocol and all the query failed + if check_protocol and error_count == len(cmds): + _LOGGER.debug( + "[%s] all the query cmds failed %s, please report bug", + self._device_id, + cmds, + ) + raise NoSupportedProtocol def pre_process_message(self, msg: bytearray) -> bool: """Pre process message.""" @@ -424,7 +427,6 @@ def parse_message(self, msg: bytes) -> MessageResult: payload_len = message[4] + (message[5] << 8) - 56 payload_type = message[2] + (message[3] << 8) if payload_type in [0x1001, 0x0001]: - # Heartbeat detected pass elif len(message) > MIN_MSG_LENGTH: cryptographic = bytes(message[40:-16]) @@ -554,12 +556,15 @@ def close(self) -> None: self._is_run = False self.close_socket() - def close_socket(self) -> None: + def close_socket(self, init: bool = False) -> None: """Close socket.""" - self._unsupported_protocol = [] + # init connection, check_protocol + if init: + self._unsupported_protocol = [] self._buffer = b"" if self._socket: try: + self._socket.shutdown(socket.SHUT_RDWR) self._socket.close() _LOGGER.debug("[%s] Socket closed", self._device_id) except OSError as e: @@ -572,7 +577,7 @@ def set_ip_address(self, ip_address: str) -> None: if self._ip_address != ip_address: _LOGGER.debug("[%s] Update IP address to %s", self._device_id, ip_address) self._ip_address = ip_address - self.close_socket() + self.close_socket(init=True) def set_refresh_interval(self, refresh_interval: int) -> None: """Set refresh interval.""" @@ -588,37 +593,6 @@ def _check_heartbeat(self, now: float) -> None: self.send_heartbeat() self._previous_heartbeat = now - def _recovery_timeout(self) -> None: - if not self._socket: - _LOGGER.debug("[%s] _recovery_timeout socket error", self._device_id) - raise SocketException - try: - self._socket.settimeout(SOCKET_TIMEOUT) - except TimeoutError: - self.close_socket() - _LOGGER.debug("_recovery_timeout socket timeout") - - def _connect_loop(self) -> None: - """Connect loop until device online.""" - # init connection or socket broken, socket loop until device online - connection_retries = 0 - while self._socket is None: - _LOGGER.debug("[%s] Socket is None, try to connect", self._device_id) - # connect and check result - if not self.connect(): - self.close_socket() - connection_retries += 1 - sleep_time = min(60 * connection_retries, 600) - _LOGGER.warning( - "[%s] Unable to connect, sleep %s seconds and retry", - self._device_id, - sleep_time, - ) - # sleep and reconnect loop - time.sleep(sleep_time) - continue - connection_retries = 0 - def run(self) -> None: """Run loop brief description. @@ -639,59 +613,70 @@ def run(self) -> None: 4.1 socket connection should exist 4.2 send heartbeat packet to keep alive + scenario/bug fix: + 1. while True loop should sleep 0.1 second to prevent cpu usage issue + 2. device running and power off become offline, status update + 3. device disconnected and power on, become online, status update + """ + # service loop while self._is_run: - # init connection - if self._socket is None: - # connect device loop - self._connect_loop() - # connect pass, auth for v3 device - if self._protocol == ProtocolVersion.V3: - self.authenticate() - try: - # probe device with query and check response - self.refresh_status(check_protocol=True) - except NoSupportedProtocol: - _LOGGER.debug( - "[%s] query device failed, please report bug", - self._device_id, - ) - break - except SocketException: - _LOGGER.debug( - "[%s] socket error, close and reconnect", + # connect loop until online + connection_retries = 0 + while self._socket is None: + _LOGGER.debug("[%s] Socket is None, try to connect", self._device_id) + if self.connect(init=True) is False: + self.close_socket(init=True) + connection_retries += 1 + # Sleep time with exponential backoff, maximum 600 seconds + sleep_time = min(5 * (2 ** (connection_retries - 1)), 600) + _LOGGER.warning( + "[%s] Unable to connect, sleep %s seconds and retry", self._device_id, + sleep_time, ) - self.close_socket() - continue - self.get_capabilities() - # socket exist + # sleep and reconnect loop until device online + time.sleep(sleep_time) + connection_retries = 0 start = time.time() self._previous_refresh = self._previous_heartbeat = start - # loop in query and parse response + # main loop after connected while True: + reconnect = False try: - # check refresh process now = time.time() self._check_refresh(now) - # check heartbeat - now = time.time() self._check_heartbeat(now) except TimeoutError: _LOGGER.debug("[%s] Socket timed out", self._device_id) - self.close_socket() - break + reconnect = True + except SocketException: # refresh_status + _LOGGER.debug("[%s] Socket Exception", self._device_id) + reconnect = True except NoSupportedProtocol: - _LOGGER.debug("[%s] query device failed", self._device_id) - self.close_socket() - break + _LOGGER.debug("[%s] No Supported protocol", self._device_id) + # ignore and continue loop + continue + except ConnectionResetError: # refresh_status -> build_send exception + _LOGGER.debug("[%s] Connection reset by peer", self._device_id) + reconnect = True + except OSError: # refresh_status + _LOGGER.debug("[%s] OS error", self._device_id) + reconnect = True except Exception as e: _LOGGER.exception( "[%s] Unexpected error", self._device_id, exc_info=e, ) + reconnect = True + # reconnect socket and try to skip check_protocol + if reconnect: self.close_socket() + if self.connect(reconnect=True): + # pass, continue while True loop + continue + # device disconnect, break while True loop, start main loop break # prevent while True loop cpu 100% time.sleep(0.1)