From 08eee5f4c74494328033480ea55511de1e1458c4 Mon Sep 17 00:00:00 2001 From: Luke Date: Sat, 6 Apr 2024 19:56:49 -0400 Subject: [PATCH] chore: move more things around in version 1 api --- roborock/api.py | 236 +--------------- roborock/cloud_api.py | 9 +- roborock/code_mappings.py | 85 +++++- roborock/containers.py | 25 ++ roborock/local_api.py | 39 +-- roborock/protocol.py | 27 +- roborock/roborock_message.py | 4 +- roborock/version_1_apis/roborock_client_v1.py | 256 +++++++++++++++++- .../roborock_local_client_v1.py | 44 ++- .../version_1_apis/roborock_mqtt_client_v1.py | 7 +- 10 files changed, 438 insertions(+), 294 deletions(-) diff --git a/roborock/api.py b/roborock/api.py index 92b579c..7a4338f 100644 --- a/roborock/api.py +++ b/roborock/api.py @@ -4,125 +4,35 @@ import asyncio import base64 -import dataclasses -import hashlib -import json import logging import secrets -import struct import time from collections.abc import Callable, Coroutine -from typing import Any, TypeVar, final +from typing import Any -from .command_cache import CacheableAttribute, CommandType, RoborockAttribute, find_cacheable_attribute, get_cache_map from .containers import ( - Consumable, DeviceData, ModelStatus, - RoborockBase, S7MaxVStatus, Status, ) from .exceptions import ( - RoborockException, RoborockTimeout, UnknownMethodError, VacuumError, ) -from .protocol import Utils from .roborock_future import RoborockFuture from .roborock_message import ( - ROBOROCK_DATA_CONSUMABLE_PROTOCOL, - ROBOROCK_DATA_STATUS_PROTOCOL, - RoborockDataProtocol, RoborockMessage, - RoborockMessageProtocol, ) from .roborock_typing import RoborockCommand -from .util import RepeatableTask, RoborockLoggerAdapter, get_running_loop_or_create_one +from .util import RoborockLoggerAdapter, get_running_loop_or_create_one _LOGGER = logging.getLogger(__name__) KEEPALIVE = 60 -RT = TypeVar("RT", bound=RoborockBase) - - -def md5hex(message: str) -> str: - md5 = hashlib.md5() - md5.update(message.encode()) - return md5.hexdigest() - - -EVICT_TIME = 60 - - -class AttributeCache: - def __init__(self, attribute: RoborockAttribute, api: RoborockClient): - self.attribute = attribute - self.api = api - self.attribute = attribute - self.task = RepeatableTask(self.api.event_loop, self._async_value, EVICT_TIME) - self._value: Any = None - self._mutex = asyncio.Lock() - self.unsupported: bool = False - - @property - def value(self): - return self._value - - async def _async_value(self): - if self.unsupported: - return None - try: - self._value = await self.api._send_command(self.attribute.get_command) - except UnknownMethodError as err: - # Limit the amount of times we call unsupported methods - self.unsupported = True - raise err - return self._value - - async def async_value(self): - async with self._mutex: - if self._value is None: - return await self.task.reset() - return self._value - - def stop(self): - self.task.cancel() - - async def update_value(self, params): - if self.attribute.set_command is None: - raise RoborockException(f"{self.attribute.attribute} have no set command") - response = await self.api._send_command(self.attribute.set_command, params) - await self._async_value() - return response - - async def add_value(self, params): - if self.attribute.add_command is None: - raise RoborockException(f"{self.attribute.attribute} have no add command") - response = await self.api._send_command(self.attribute.add_command, params) - await self._async_value() - return response - - async def close_value(self, params=None): - if self.attribute.close_command is None: - raise RoborockException(f"{self.attribute.attribute} have no close command") - response = await self.api._send_command(self.attribute.close_command, params) - await self._async_value() - return response - - async def refresh_value(self): - await self._async_value() - - -@dataclasses.dataclass -class ListenerModel: - protocol_handlers: dict[RoborockDataProtocol, list[Callable[[Status | Consumable], None]]] - cache: dict[CacheableAttribute, AttributeCache] class RoborockClient: - _listeners: dict[str, ListenerModel] = {} - def __init__(self, endpoint: str, device_info: DeviceData, queue_timeout: int = 4) -> None: self.event_loop = get_running_loop_or_create_one() self.device_info = device_info @@ -136,15 +46,9 @@ def __init__(self, endpoint: str, device_info: DeviceData, queue_timeout: int = "misc_info": {"Nonce": base64.b64encode(self._nonce).decode("utf-8")} } self._logger = RoborockLoggerAdapter(device_info.device.name, _LOGGER) - self.cache: dict[CacheableAttribute, AttributeCache] = { - cacheable_attribute: AttributeCache(attr, self) for cacheable_attribute, attr in get_cache_map().items() - } self.is_available: bool = True self.queue_timeout = queue_timeout self._status_type: type[Status] = ModelStatus.get(self.device_info.model, S7MaxVStatus) - if device_info.device.duid not in self._listeners: - self._listeners[device_info.device.duid] = ListenerModel({}, self.cache) - self.listener_model = self._listeners[device_info.device.duid] def __del__(self) -> None: self.release() @@ -156,11 +60,9 @@ def status_type(self) -> type[Status]: def release(self): self.sync_disconnect() - [item.stop() for item in self.cache.values()] async def async_release(self): await self.async_disconnect() - [item.stop() for item in self.cache.values()] @property def diagnostic_data(self) -> dict: @@ -185,95 +87,7 @@ async def async_disconnect(self) -> Any: raise NotImplementedError def on_message_received(self, messages: list[RoborockMessage]) -> None: - try: - self._last_device_msg_in = self.time_func() - for data in messages: - protocol = data.protocol - if data.payload and protocol in [ - RoborockMessageProtocol.RPC_RESPONSE, - RoborockMessageProtocol.GENERAL_REQUEST, - ]: - payload = json.loads(data.payload.decode()) - for data_point_number, data_point in payload.get("dps").items(): - if data_point_number == "102": - data_point_response = json.loads(data_point) - request_id = data_point_response.get("id") - queue = self._waiting_queue.get(request_id) - if queue and queue.protocol == protocol: - error = data_point_response.get("error") - if error: - queue.resolve( - ( - None, - VacuumError( - error.get("code"), - error.get("message"), - ), - ) - ) - else: - result = data_point_response.get("result") - if isinstance(result, list) and len(result) == 1: - result = result[0] - queue.resolve((result, None)) - else: - try: - data_protocol = RoborockDataProtocol(int(data_point_number)) - self._logger.debug(f"Got device update for {data_protocol.name}: {data_point}") - if data_protocol in ROBOROCK_DATA_STATUS_PROTOCOL: - if data_protocol not in self.listener_model.protocol_handlers: - self._logger.debug( - f"Got status update({data_protocol.name}) before get_status was called." - ) - return - value = self.listener_model.cache[CacheableAttribute.status].value - value[data_protocol.name] = data_point - status = self._status_type.from_dict(value) - for listener in self.listener_model.protocol_handlers.get(data_protocol, []): - listener(status) - elif data_protocol in ROBOROCK_DATA_CONSUMABLE_PROTOCOL: - if data_protocol not in self.listener_model.protocol_handlers: - self._logger.debug( - f"Got consumable update({data_protocol.name})" - + "before get_consumable was called." - ) - return - value = self.listener_model.cache[CacheableAttribute.consumable].value - value[data_protocol.name] = data_point - consumable = Consumable.from_dict(value) - for listener in self.listener_model.protocol_handlers.get(data_protocol, []): - listener(consumable) - return - except ValueError: - self._logger.warning( - f"Got listener data for {data_point_number}, data: {data_point}. " - f"This lets us update data quicker, please open an issue " - f"at https://github.com/humbertogontijo/python-roborock/issues" - ) - - pass - dps = {data_point_number: data_point} - self._logger.debug(f"Got unknown data point {dps}") - elif data.payload and protocol == RoborockMessageProtocol.MAP_RESPONSE: - payload = data.payload[0:24] - [endpoint, _, request_id, _] = struct.unpack("<8s8sH6s", payload) - if endpoint.decode().startswith(self._endpoint): - try: - decrypted = Utils.decrypt_cbc(data.payload[24:], self._nonce) - except ValueError as err: - raise RoborockException(f"Failed to decode {data.payload!r} for {data.protocol}") from err - decompressed = Utils.decompress(decrypted) - queue = self._waiting_queue.get(request_id) - if queue: - if isinstance(decompressed, list): - decompressed = decompressed[0] - queue.resolve((decompressed, None)) - else: - queue = self._waiting_queue.get(data.seq) - if queue: - queue.resolve((data.payload, None)) - except Exception as ex: - self._logger.exception(ex) + raise NotImplementedError def on_connection_lost(self, exc: Exception | None) -> None: self._last_disconnection = self.time_func() @@ -320,47 +134,3 @@ async def _send_command( params: list | dict | int | None = None, ): raise NotImplementedError - - @final - async def send_command( - self, - method: RoborockCommand | str, - params: list | dict | int | None = None, - return_type: type[RT] | None = None, - ) -> RT: - cacheable_attribute_result = find_cacheable_attribute(method) - - cache = None - command_type = None - if cacheable_attribute_result is not None: - cache = self.cache[cacheable_attribute_result.attribute] - command_type = cacheable_attribute_result.type - - response: Any = None - if cache is not None and command_type == CommandType.GET: - response = await cache.async_value() - else: - response = await self._send_command(method, params) - if cache is not None and command_type == CommandType.CHANGE: - await cache.refresh_value() - - if return_type: - return return_type.from_dict(response) - return response - - def add_listener( - self, protocol: RoborockDataProtocol, listener: Callable, cache: dict[CacheableAttribute, AttributeCache] - ) -> None: - self.listener_model.cache = cache - if protocol not in self.listener_model.protocol_handlers: - self.listener_model.protocol_handlers[protocol] = [] - self.listener_model.protocol_handlers[protocol].append(listener) - - def remove_listener(self, protocol: RoborockDataProtocol, listener: Callable) -> None: - self.listener_model.protocol_handlers[protocol].remove(listener) - - async def get_from_cache(self, key: CacheableAttribute) -> AttributeCache | None: - val = self.cache.get(key) - if val is not None: - return await val.async_value() - return None diff --git a/roborock/cloud_api.py b/roborock/cloud_api.py index a8fa408..9b79406 100644 --- a/roborock/cloud_api.py +++ b/roborock/cloud_api.py @@ -4,6 +4,7 @@ import base64 import logging import threading +import typing import uuid from asyncio import Lock, Task from typing import Any @@ -11,15 +12,17 @@ import paho.mqtt.client as mqtt -from .api import KEEPALIVE, RoborockClient, md5hex +from .api import KEEPALIVE, RoborockClient from .containers import DeviceData, UserData from .exceptions import RoborockException, VacuumError -from .protocol import MessageParser, Utils +from .protocol import MessageParser, Utils, md5hex from .roborock_future import RoborockFuture from .roborock_message import RoborockMessage from .roborock_typing import RoborockCommand from .util import RoborockLoggerAdapter +if typing.TYPE_CHECKING: + pass _LOGGER = logging.getLogger(__name__) CONNECT_REQUEST_ID = 0 DISCONNECT_REQUEST_ID = 1 @@ -78,7 +81,7 @@ def on_connect(self, *args, **kwargs): connection_queue.resolve((True, None)) def on_message(self, *args, **kwargs): - _, __, msg = args + client, __, msg = args try: messages, _ = MessageParser.parse(msg.payload, local_key=self.device_info.device.local_key) super().on_message_received(messages) diff --git a/roborock/code_mappings.py b/roborock/code_mappings.py index f6f85ab..244314d 100644 --- a/roborock/code_mappings.py +++ b/roborock/code_mappings.py @@ -98,8 +98,8 @@ class RoborockDyadStateCode(RoborockEnum): self_clean_deep_cleaning = 6 self_clean_rinsing = 7 self_clean_dehydrating = 8 - drying = 10 - ventilating = 11 # drying + drying = 9 + ventilating = 10 # drying reserving = 12 mop_washing_paused = 13 dusting_mode = 14 @@ -369,3 +369,84 @@ class RoborockCategory(Enum): def __missing__(self, key): _LOGGER.warning("Missing key %s from category", key) return RoborockCategory.UNKNOWN + + +class DyadSelfCleanMode(RoborockEnum): + self_clean = 1 + self_clean_and_dry = 2 + dry = 3 + ventilation = 4 + + +class DyadSelfCleanLevel(RoborockEnum): + normal = 1 + deep = 2 + + +class DyadWarmLevel(RoborockEnum): + normal = 1 + deep = 2 + + +class DyadMode(RoborockEnum): + wash = 1 + wash_and_dry = 2 + dry = 3 + + +class DyadCleanMode(RoborockEnum): + auto = 1 + max = 2 + dehydration = 3 + power_saving = 4 + + +class DyadSuction(RoborockEnum): + l1 = 1 + l2 = 2 + l3 = 3 + l4 = 4 + l5 = 5 + l6 = 6 + + +class DyadWaterLevel(RoborockEnum): + l1 = 1 + l2 = 2 + l3 = 3 + l4 = 4 + + +class DyadBrushSpeed(RoborockEnum): + l1 = 1 + l2 = 2 + + +class DyadCleanser(RoborockEnum): + none = 0 + normal = 1 + deep = 2 + max = 3 + + +class DyadError(RoborockEnum): + none = 0 + dirty_tank_full = 20000 # Dirty tank full. Empty it + water_level_sensor_stuck = 20001 # Water level sensor is stuck. Clean it. + clean_tank_empty = 20002 # Clean tank empty. Refill now + clean_head_entangled = 20003 # Check if the cleaning head is entangled with foreign objects. + clean_head_too_hot = 20004 # Cleaning head temperature protection. Wait for the temperature to return to normal. + fan_protection_e5 = 10005 # Fan protection (E5). Restart the vacuum cleaner. + cleaning_head_blocked = 20005 # Remove blockages from the cleaning head and pipes. + temperature_protection = 20006 # Temperature protection. Wait for the temperature to return to normal + fan_protection_e4 = 10004 # Fan protection (E4). Restart the vacuum cleaner. + fan_protection_e9 = 10009 # Fan protection (E9). Restart the vacuum cleaner. + battery_temperature_protection_e0 = 10000 + battery_temperature_protection = ( + 20007 # Battery temperature protection. Wait for the temperature to return to a normal range. + ) + battery_temperature_protection_2 = 20008 + power_adapter_error = 20009 # Check if the power adapter is working properly. + dirty_charging_contacts = 10007 # Disconnection between the device and dock. Wipe charging contacts. + low_battery = 20017 # Low battery level. Charge before starting self-cleaning. + battery_under_10 = 20018 # Charge until the battery level exceeds 10% before manually starting self-cleaning. diff --git a/roborock/containers.py b/roborock/containers.py index a6c33b5..96f7ff0 100644 --- a/roborock/containers.py +++ b/roborock/containers.py @@ -829,3 +829,28 @@ class RoborockCategoryDetail(RoborockBase): @dataclass class ProductResponse(RoborockBase): category_detail_list: list[RoborockCategoryDetail] + + +@dataclass +class DyadProductInfo(RoborockBase): + sn: str + ssid: str + timezone: str + posix_timezone: str + ip: str + mac: str + oba: dict + + +@dataclass +class DyadSndState(RoborockBase): + sid_in_use: int + sid_version: int + location: str + bom: str + language: str + + +@dataclass +class DyadOtaNfo(RoborockBase): + mqttOtaData: dict diff --git a/roborock/local_api.py b/roborock/local_api.py index 2f519fc..92e95bc 100644 --- a/roborock/local_api.py +++ b/roborock/local_api.py @@ -8,7 +8,7 @@ from . import DeviceData from .api import RoborockClient -from .exceptions import CommandVacuumError, RoborockConnectionException, RoborockException +from .exceptions import RoborockConnectionException, RoborockException from .protocol import MessageParser from .roborock_message import RoborockMessage, RoborockMessageProtocol from .roborock_typing import RoborockCommand @@ -121,40 +121,3 @@ def _send_msg_raw(self, data: bytes): self.transport.write(data) except Exception as e: raise RoborockException(e) from e - - async def send_message(self, roborock_message: RoborockMessage): - await self.validate_connection() - method = roborock_message.get_method() - params = roborock_message.get_params() - request_id: int | None - if not method or not method.startswith("get"): - request_id = roborock_message.seq - response_protocol = request_id + 1 - else: - request_id = roborock_message.get_request_id() - response_protocol = RoborockMessageProtocol.GENERAL_REQUEST - if request_id is None: - raise RoborockException(f"Failed build message {roborock_message}") - local_key = self.device_info.device.local_key - msg = MessageParser.build(roborock_message, local_key=local_key) - if method: - self._logger.debug(f"id={request_id} Requesting method {method} with {params}") - # Send the command to the Roborock device - async_response = asyncio.ensure_future(self._async_response(request_id, response_protocol)) - self._send_msg_raw(msg) - (response, err) = await async_response - self._diagnostic_data[method if method is not None else "unknown"] = { - "params": roborock_message.get_params(), - "response": response, - "error": err, - } - if err: - raise CommandVacuumError(method, err) from err - if roborock_message.protocol == RoborockMessageProtocol.GENERAL_REQUEST: - self._logger.debug(f"id={request_id} Response from method {roborock_message.get_method()}: {response}") - if response == "retry": - retry_id = roborock_message.get_retry_id() - return self.send_command( - RoborockCommand.RETRY_REQUEST, {"retry_id": retry_id, "retry_count": 8, "method": method} - ) - return response diff --git a/roborock/protocol.py b/roborock/protocol.py index 51750fc..603f06c 100644 --- a/roborock/protocol.py +++ b/roborock/protocol.py @@ -13,7 +13,6 @@ Bytes, Checksum, ChecksumError, - Const, Construct, Container, GreedyBytes, @@ -36,11 +35,19 @@ _LOGGER = logging.getLogger(__name__) SALT = b"TXdfu$jyZ#TZHsg4" +A01_HASH = "726f626f726f636b2d67a6d6da" +A01_AES_DECIPHER = "ELSYN0wTI4AUm7C4" BROADCAST_TOKEN = b"qWKYcdQWrbm9hPqe" AP_CONFIG = 1 SOCK_DISCOVERY = 2 +def md5hex(message: str) -> str: + md5 = hashlib.md5() + md5.update(message.encode()) + return md5.hexdigest() + + class RoborockProtocol(asyncio.DatagramProtocol): def __init__(self, timeout: int = 5): self.timeout = timeout @@ -199,12 +206,22 @@ def _encode(self, obj, context, _): :param obj: JSON object to encrypt """ + if context.version == b"A01": + iv = md5hex(format(context.random, "08x") + A01_HASH)[8:24] + decipher = AES.new(bytes(A01_AES_DECIPHER, "utf-8"), AES.MODE_CBC, bytes(iv, "utf-8")) + f = decipher.encrypt(obj) + return f token = self.token_func(context) encrypted = Utils.encrypt_ecb(obj, token) return encrypted def _decode(self, obj, context, _): """Decrypts the given payload with the token stored in the context.""" + if context.version == b"A01": + iv = md5hex(format(context.random, "08x") + A01_HASH)[8:24] + decipher = AES.new(bytes(A01_AES_DECIPHER, "utf-8"), AES.MODE_CBC, bytes(iv, "utf-8")) + f = decipher.decrypt(obj) + return f token = self.token_func(context) decrypted = Utils.decrypt_ecb(obj, token) return decrypted @@ -227,9 +244,9 @@ def _parse(self, stream, context, path): class PrefixedStruct(Struct): def _parse(self, stream, context, path): - subcon1 = Peek(Optional(Const(b"1.0"))) + subcon1 = Peek(Optional(Bytes(3))) peek_version = subcon1.parse_stream(stream, **context) - if peek_version is None: + if peek_version not in (b"1.0", b"A01"): subcon2 = Bytes(4) subcon2.parse_stream(stream, **context) return super()._parse(stream, context, path) @@ -251,7 +268,7 @@ def _build(self, obj, stream, context, path): _Message = RawCopy( Struct( - "version" / Const(b"1.0"), + "version" / Bytes(3), "seq" / Int32ub, "random" / Int32ub, "timestamp" / Int32ub, @@ -280,7 +297,7 @@ def _build(self, obj, stream, context, path): "message" / RawCopy( Struct( - "version" / Const(b"1.0"), + "version" / Bytes(3), "seq" / Int32ub, "protocol" / Int16ub, "payload" / EncryptionAdapter(lambda ctx: BROADCAST_TOKEN), diff --git a/roborock/roborock_message.py b/roborock/roborock_message.py index 43ad3fe..2128892 100644 --- a/roborock/roborock_message.py +++ b/roborock/roborock_message.py @@ -57,7 +57,7 @@ class RoborockDyadDataProtocol(RoborockEnum): COUNTDOWN_TIME = 210 AUTO_SELF_CLEAN_SET = 212 AUTO_DRY = 213 - MESH_LEF = 214 + MESH_LEFT = 214 BRUSH_LEFT = 215 ERROR = 216 MESH_RESET = 218 @@ -70,7 +70,7 @@ class RoborockDyadDataProtocol(RoborockEnum): SILENT_MODE = 226 SILENT_MODE_START_TIME = 227 SILENT_MODE_END_TIME = 228 - RECENT_RUN_TIMe = 229 + RECENT_RUN_TIME = 229 TOTAL_RUN_TIME = 230 FEATURE_INFO = 235 RECOVER_SETTINGS = 236 diff --git a/roborock/version_1_apis/roborock_client_v1.py b/roborock/version_1_apis/roborock_client_v1.py index 908872e..14f8861 100644 --- a/roborock/version_1_apis/roborock_client_v1.py +++ b/roborock/version_1_apis/roborock_client_v1.py @@ -1,14 +1,30 @@ import asyncio +import dataclasses import json import math +import struct import time -from collections.abc import Coroutine +from collections.abc import Callable, Coroutine from random import randint -from typing import Any +from typing import Any, TypeVar, final -from roborock import DeviceProp, DockSummary, RoborockCommand, RoborockDockTypeCode +from roborock import ( + DeviceProp, + DockSummary, + RoborockCommand, + RoborockDockTypeCode, + RoborockException, + UnknownMethodError, + VacuumError, +) from roborock.api import RoborockClient -from roborock.command_cache import CacheableAttribute +from roborock.command_cache import ( + CacheableAttribute, + CommandType, + RoborockAttribute, + find_cacheable_attribute, + get_cache_map, +) from roborock.containers import ( ChildLockStatus, CleanRecord, @@ -21,6 +37,7 @@ ModelStatus, MultiMapsList, NetworkInfo, + RoborockBase, RoomMapping, S7MaxVStatus, ServerTimer, @@ -29,7 +46,15 @@ ValleyElectricityTimer, WashTowelMode, ) -from roborock.util import unpack_list +from roborock.protocol import Utils +from roborock.roborock_message import ( + ROBOROCK_DATA_CONSUMABLE_PROTOCOL, + ROBOROCK_DATA_STATUS_PROTOCOL, + RoborockDataProtocol, + RoborockMessage, + RoborockMessageProtocol, +) +from roborock.util import RepeatableTask, unpack_list COMMANDS_SECURED = [ RoborockCommand.GET_MAP_V1, @@ -41,14 +66,96 @@ RoborockDockTypeCode.s8_dock, RoborockDockTypeCode.p10_dock, ] +RT = TypeVar("RT", bound=RoborockBase) +EVICT_TIME = 60 + + +class AttributeCache: + def __init__(self, attribute: RoborockAttribute, api: RoborockClient): + self.attribute = attribute + self.api = api + self.attribute = attribute + self.task = RepeatableTask(self.api.event_loop, self._async_value, EVICT_TIME) + self._value: Any = None + self._mutex = asyncio.Lock() + self.unsupported: bool = False + + @property + def value(self): + return self._value + + async def _async_value(self): + if self.unsupported: + return None + try: + self._value = await self.api._send_command(self.attribute.get_command) + except UnknownMethodError as err: + # Limit the amount of times we call unsupported methods + self.unsupported = True + raise err + return self._value + + async def async_value(self): + async with self._mutex: + if self._value is None: + return await self.task.reset() + return self._value + + def stop(self): + self.task.cancel() + + async def update_value(self, params): + if self.attribute.set_command is None: + raise RoborockException(f"{self.attribute.attribute} have no set command") + response = await self.api._send_command(self.attribute.set_command, params) + await self._async_value() + return response + + async def add_value(self, params): + if self.attribute.add_command is None: + raise RoborockException(f"{self.attribute.attribute} have no add command") + response = await self.api._send_command(self.attribute.add_command, params) + await self._async_value() + return response + + async def close_value(self, params=None): + if self.attribute.close_command is None: + raise RoborockException(f"{self.attribute.attribute} have no close command") + response = await self.api._send_command(self.attribute.close_command, params) + await self._async_value() + return response + + async def refresh_value(self): + await self._async_value() + + +@dataclasses.dataclass +class ListenerModel: + protocol_handlers: dict[RoborockDataProtocol, list[Callable[[Status | Consumable], None]]] + cache: dict[CacheableAttribute, AttributeCache] class RoborockClientV1(RoborockClient): - def __init__(self, device_info: DeviceData, cache, logger, endpoint: str): + _listeners: dict[str, ListenerModel] = {} + + def __init__(self, device_info: DeviceData, logger, endpoint: str): super().__init__(endpoint, device_info) self._status_type: type[Status] = ModelStatus.get(device_info.model, S7MaxVStatus) - self.cache = cache self._logger = logger + self.cache: dict[CacheableAttribute, AttributeCache] = { + cacheable_attribute: AttributeCache(attr, self) for cacheable_attribute, attr in get_cache_map().items() + } + if device_info.device.duid not in self._listeners: + self._listeners[device_info.device.duid] = ListenerModel({}, self.cache) + self.listener_model = self._listeners[device_info.device.duid] + + def release(self): + super().release() + [item.stop() for item in self.cache.values()] + + async def async_release(self): + await super().async_release() + [item.stop() for item in self.cache.values()] @property def status_type(self) -> type[Status]: @@ -225,3 +332,138 @@ def _get_payload( ).encode() ) return request_id, timestamp, payload + + def on_message_received(self, messages: list[RoborockMessage]) -> None: + try: + self._last_device_msg_in = self.time_func() + for data in messages: + protocol = data.protocol + if data.payload and protocol in [ + RoborockMessageProtocol.RPC_RESPONSE, + RoborockMessageProtocol.GENERAL_REQUEST, + ]: + payload = json.loads(data.payload.decode()) + for data_point_number, data_point in payload.get("dps").items(): + if data_point_number == "102": + data_point_response = json.loads(data_point) + request_id = data_point_response.get("id") + queue = self._waiting_queue.get(request_id) + if queue and queue.protocol == protocol: + error = data_point_response.get("error") + if error: + queue.resolve( + ( + None, + VacuumError( + error.get("code"), + error.get("message"), + ), + ) + ) + else: + result = data_point_response.get("result") + if isinstance(result, list) and len(result) == 1: + result = result[0] + queue.resolve((result, None)) + else: + try: + data_protocol = RoborockDataProtocol(int(data_point_number)) + self._logger.debug(f"Got device update for {data_protocol.name}: {data_point}") + if data_protocol in ROBOROCK_DATA_STATUS_PROTOCOL: + if data_protocol not in self.listener_model.protocol_handlers: + self._logger.debug( + f"Got status update({data_protocol.name}) before get_status was called." + ) + return + value = self.listener_model.cache[CacheableAttribute.status].value + value[data_protocol.name] = data_point + status = self._status_type.from_dict(value) + for listener in self.listener_model.protocol_handlers.get(data_protocol, []): + listener(status) + elif data_protocol in ROBOROCK_DATA_CONSUMABLE_PROTOCOL: + if data_protocol not in self.listener_model.protocol_handlers: + self._logger.debug( + f"Got consumable update({data_protocol.name})" + + "before get_consumable was called." + ) + return + value = self.listener_model.cache[CacheableAttribute.consumable].value + value[data_protocol.name] = data_point + consumable = Consumable.from_dict(value) + for listener in self.listener_model.protocol_handlers.get(data_protocol, []): + listener(consumable) + return + except ValueError: + self._logger.warning( + f"Got listener data for {data_point_number}, data: {data_point}. " + f"This lets us update data quicker, please open an issue " + f"at https://github.com/humbertogontijo/python-roborock/issues" + ) + + pass + dps = {data_point_number: data_point} + self._logger.debug(f"Got unknown data point {dps}") + elif data.payload and protocol == RoborockMessageProtocol.MAP_RESPONSE: + payload = data.payload[0:24] + [endpoint, _, request_id, _] = struct.unpack("<8s8sH6s", payload) + if endpoint.decode().startswith(self._endpoint): + try: + decrypted = Utils.decrypt_cbc(data.payload[24:], self._nonce) + except ValueError as err: + raise RoborockException(f"Failed to decode {data.payload!r} for {data.protocol}") from err + decompressed = Utils.decompress(decrypted) + queue = self._waiting_queue.get(request_id) + if queue: + if isinstance(decompressed, list): + decompressed = decompressed[0] + queue.resolve((decompressed, None)) + else: + queue = self._waiting_queue.get(data.seq) + if queue: + queue.resolve((data.payload, None)) + except Exception as ex: + self._logger.exception(ex) + + async def get_from_cache(self, key: CacheableAttribute) -> AttributeCache | None: + val = self.cache.get(key) + if val is not None: + return await val.async_value() + return None + + def add_listener( + self, protocol: RoborockDataProtocol, listener: Callable, cache: dict[CacheableAttribute, AttributeCache] + ) -> None: + self.listener_model.cache = cache + if protocol not in self.listener_model.protocol_handlers: + self.listener_model.protocol_handlers[protocol] = [] + self.listener_model.protocol_handlers[protocol].append(listener) + + def remove_listener(self, protocol: RoborockDataProtocol, listener: Callable) -> None: + self.listener_model.protocol_handlers[protocol].remove(listener) + + @final + async def send_command( + self, + method: RoborockCommand | str, + params: list | dict | int | None = None, + return_type: type[RT] | None = None, + ) -> RT: + cacheable_attribute_result = find_cacheable_attribute(method) + + cache = None + command_type = None + if cacheable_attribute_result is not None: + cache = self.cache[cacheable_attribute_result.attribute] + command_type = cacheable_attribute_result.type + + response: Any = None + if cache is not None and command_type == CommandType.GET: + response = await cache.async_value() + else: + response = await self._send_command(method, params) + if cache is not None and command_type == CommandType.CHANGE: + await cache.refresh_value() + + if return_type: + return return_type.from_dict(response) + return response diff --git a/roborock/version_1_apis/roborock_local_client_v1.py b/roborock/version_1_apis/roborock_local_client_v1.py index 5ef1567..c4f30a3 100644 --- a/roborock/version_1_apis/roborock_local_client_v1.py +++ b/roborock/version_1_apis/roborock_local_client_v1.py @@ -1,6 +1,9 @@ +import asyncio + from roborock.local_api import RoborockLocalClient -from .. import DeviceData, RoborockCommand +from .. import CommandVacuumError, DeviceData, RoborockCommand, RoborockException +from ..protocol import MessageParser from ..roborock_message import MessageRetry, RoborockMessage, RoborockMessageProtocol from .roborock_client_v1 import COMMANDS_SECURED, RoborockClientV1 @@ -8,7 +11,7 @@ class RoborockLocalClientV1(RoborockLocalClient, RoborockClientV1): def __init__(self, device_data: DeviceData, queue_timeout: int = 4): RoborockLocalClient.__init__(self, device_data, queue_timeout) - RoborockClientV1.__init__(self, device_data, self.cache, self._logger, "abc") + RoborockClientV1.__init__(self, device_data, self._logger, "abc") def build_roborock_message( self, method: RoborockCommand | str, params: list | dict | int | None = None @@ -30,3 +33,40 @@ async def _send_command( ): roborock_message = self.build_roborock_message(method, params) return await self.send_message(roborock_message) + + async def send_message(self, roborock_message: RoborockMessage): + await self.validate_connection() + method = roborock_message.get_method() + params = roborock_message.get_params() + request_id: int | None + if not method or not method.startswith("get"): + request_id = roborock_message.seq + response_protocol = request_id + 1 + else: + request_id = roborock_message.get_request_id() + response_protocol = RoborockMessageProtocol.GENERAL_REQUEST + if request_id is None: + raise RoborockException(f"Failed build message {roborock_message}") + local_key = self.device_info.device.local_key + msg = MessageParser.build(roborock_message, local_key=local_key) + if method: + self._logger.debug(f"id={request_id} Requesting method {method} with {params}") + # Send the command to the Roborock device + async_response = asyncio.ensure_future(self._async_response(request_id, response_protocol)) + self._send_msg_raw(msg) + (response, err) = await async_response + self._diagnostic_data[method if method is not None else "unknown"] = { + "params": roborock_message.get_params(), + "response": response, + "error": err, + } + if err: + raise CommandVacuumError(method, err) from err + if roborock_message.protocol == RoborockMessageProtocol.GENERAL_REQUEST: + self._logger.debug(f"id={request_id} Response from method {roborock_message.get_method()}: {response}") + if response == "retry": + retry_id = roborock_message.get_retry_id() + return self.send_command( + RoborockCommand.RETRY_REQUEST, {"retry_id": retry_id, "retry_count": 8, "method": method} + ) + return response diff --git a/roborock/version_1_apis/roborock_mqtt_client_v1.py b/roborock/version_1_apis/roborock_mqtt_client_v1.py index a403cea..f9d885d 100644 --- a/roborock/version_1_apis/roborock_mqtt_client_v1.py +++ b/roborock/version_1_apis/roborock_mqtt_client_v1.py @@ -8,7 +8,10 @@ from ..containers import DeviceData, UserData from ..exceptions import CommandVacuumError, RoborockException from ..protocol import MessageParser, Utils -from ..roborock_message import RoborockMessage, RoborockMessageProtocol +from ..roborock_message import ( + RoborockMessage, + RoborockMessageProtocol, +) from ..roborock_typing import RoborockCommand from .roborock_client_v1 import COMMANDS_SECURED, RoborockClientV1 @@ -21,7 +24,7 @@ def __init__(self, user_data: UserData, device_info: DeviceData, queue_timeout: endpoint = base64.b64encode(Utils.md5(rriot.k.encode())[8:14]).decode() RoborockMqttClient.__init__(self, user_data, device_info, queue_timeout) - RoborockClientV1.__init__(self, device_info, self.cache, self._logger, endpoint) + RoborockClientV1.__init__(self, device_info, self._logger, endpoint) def _send_msg_raw(self, msg: bytes) -> None: info = self.publish(f"rr/m/i/{self._mqtt_user}/{self._hashed_user}/{self.device_info.device.duid}", msg)