From 3a72b58273d80f0a5d8d8da473e2b0e16aeea722 Mon Sep 17 00:00:00 2001 From: Luke Date: Mon, 20 Feb 2023 15:32:09 -0500 Subject: [PATCH 1/6] chore: added some typing --- roborock/api.py | 59 +++++++++++++++++++------------------- roborock/roborock_queue.py | 7 +++-- 2 files changed, 35 insertions(+), 31 deletions(-) diff --git a/roborock/api.py b/roborock/api.py index e5c3cd3..ed013b9 100644 --- a/roborock/api.py +++ b/roborock/api.py @@ -16,6 +16,7 @@ import time from asyncio import Lock from asyncio.exceptions import TimeoutError, CancelledError +from typing import Any from urllib.parse import urlparse import aiohttp @@ -53,31 +54,31 @@ MQTT_KEEPALIVE = 60 -def md5hex(message: str): +def md5hex(message: str) -> str: md5 = hashlib.md5() md5.update(message.encode()) return md5.hexdigest() -def md5bin(message: str): +def md5bin(message: str) -> bytes: md5 = hashlib.md5() md5.update(message.encode()) return md5.digest() -def encode_timestamp(_timestamp: int): +def encode_timestamp(_timestamp: int) -> str: hex_value = f"{_timestamp:x}".zfill(8) return "".join(list(map(lambda idx: hex_value[idx], [5, 6, 3, 7, 1, 2, 0, 4]))) class PreparedRequest: - def __init__(self, base_url: str, base_headers: dict = None): + def __init__(self, base_url: str, base_headers: dict = None) -> None: self.base_url = base_url self.base_headers = base_headers or {} async def request( self, method: str, url: str, params=None, data=None, headers=None - ): + ) -> dict | list: _url = "/".join(s.strip("/") for s in [self.base_url, url]) _headers = {**self.base_headers, **(headers or {})} async with aiohttp.ClientSession() as session: @@ -99,7 +100,7 @@ async def request( class RoborockMqttClient(mqtt.Client): _thread: threading.Thread - def __init__(self, user_data: UserData, device_map: dict[str, RoborockDeviceInfo]): + def __init__(self, user_data: UserData, device_map: dict[str, RoborockDeviceInfo]) -> None: rriot = user_data.rriot self._mqtt_user = rriot.user self._mqtt_domain = rriot.domain @@ -126,11 +127,11 @@ def __init__(self, user_data: UserData, device_map: dict[str, RoborockDeviceInfo self._last_device_msg_in = mqtt.time_func() self._last_disconnection = mqtt.time_func() - def __del__(self): + def __del__(self) -> None: self.sync_disconnect() @run_in_executor() - async def on_connect(self, _client, _, __, rc, ___=None): + async def on_connect(self, _client, _, __, rc, ___=None) -> None: connection_queue = self._waiting_queue.get(0) if rc != mqtt.MQTT_ERR_SUCCESS: message = f"Failed to connect (rc: {rc})" @@ -156,7 +157,7 @@ async def on_connect(self, _client, _, __, rc, ___=None): await connection_queue.async_put((True, None), timeout=QUEUE_TIMEOUT) @run_in_executor() - async def on_message(self, _client, _, msg, __=None): + async def on_message(self, _client, _, msg, __=None) -> None: try: async with self._mutex: self._last_device_msg_in = mqtt.time_func() @@ -219,7 +220,7 @@ async def on_message(self, _client, _, msg, __=None): _LOGGER.exception(ex) @run_in_executor() - async def on_disconnect(self, _client: mqtt.Client, _, rc, __=None): + async def on_disconnect(self, _client: mqtt.Client, _, rc, __=None) -> None: try: async with self._mutex: self._last_disconnection = mqtt.time_func() @@ -241,28 +242,28 @@ async def on_disconnect(self, _client: mqtt.Client, _, rc, __=None): _LOGGER.exception(ex) @run_in_executor() - async def _async_check_keepalive(self): + async def _async_check_keepalive(self) -> None: async with self._mutex: now = mqtt.time_func() if now - self._last_disconnection > self._keepalive ** 2 and now - self._last_device_msg_in > self._keepalive: self._ping_t = self._last_device_msg_in - def _check_keepalive(self): + def _check_keepalive(self) -> None: self._async_check_keepalive() super()._check_keepalive() - def sync_stop_loop(self): + def sync_stop_loop(self) -> None: if self._thread: _LOGGER.info("Stopping mqtt loop") super().loop_stop() - def sync_start_loop(self): + def sync_start_loop(self) -> None: if not self._thread or not self._thread.is_alive(): self.sync_stop_loop() _LOGGER.info("Starting mqtt loop") super().loop_start() - def sync_disconnect(self): + def sync_disconnect(self) -> bool: rc = mqtt.MQTT_ERR_AGAIN if self.is_connected(): _LOGGER.info("Disconnecting from mqtt") @@ -271,7 +272,7 @@ def sync_disconnect(self): raise RoborockException(f"Failed to disconnect (rc:{rc})") return rc == mqtt.MQTT_ERR_SUCCESS - def sync_connect(self): + def sync_connect(self) -> bool: rc = mqtt.MQTT_ERR_AGAIN self.sync_start_loop() if not self.is_connected(): @@ -285,7 +286,7 @@ def sync_connect(self): raise RoborockException(f"Failed to connect (rc:{rc})") return rc == mqtt.MQTT_ERR_SUCCESS - async def _async_response(self, request_id: int, protocol_id: int = 0): + async def _async_response(self, request_id: int, protocol_id: int = 0) -> tuple[Any, RoborockException | None]: try: queue = RoborockQueue(protocol_id) self._waiting_queue[request_id] = queue @@ -298,7 +299,7 @@ async def _async_response(self, request_id: int, protocol_id: int = 0): finally: del self._waiting_queue[request_id] - async def async_disconnect(self): + async def async_disconnect(self) -> Any: async with self._mutex: disconnecting = self.sync_disconnect() if disconnecting: @@ -307,7 +308,7 @@ async def async_disconnect(self): raise RoborockException(err) from err return response - async def async_connect(self): + async def async_connect(self) -> Any: async with self._mutex: connecting = self.sync_connect() if connecting: @@ -316,10 +317,10 @@ async def async_connect(self): raise RoborockException(err) from err return response - async def validate_connection(self): + async def validate_connection(self) -> None: await self.async_connect() - def _decode_msg(self, msg, device: HomeDataDevice): + def _decode_msg(self, msg, device: HomeDataDevice) -> dict[str, Any]: if msg[0:3] != "1.0".encode(): raise RoborockException("Unknown protocol version") crc32 = binascii.crc32(msg[0: len(msg) - 4]) @@ -344,7 +345,7 @@ def _decode_msg(self, msg, device: HomeDataDevice): "payload": decrypted_payload, } - def _send_msg_raw(self, device_id, protocol, timestamp, payload): + def _send_msg_raw(self, device_id, protocol, timestamp, payload) -> None: local_key = self.device_map[device_id].device.local_key aes_key = md5bin(encode_timestamp(timestamp) + local_key + self._salt) cipher = AES.new(aes_key, AES.MODE_ECB) @@ -438,7 +439,7 @@ async def get_consumable(self, device_id: str) -> Consumable: if isinstance(consumable, dict): return Consumable(consumable) - async def get_prop(self, device_id: str): + async def get_prop(self, device_id: str) -> RoborockDeviceProp: [status, dnd_timer, clean_summary, consumable] = await asyncio.gather( *[ self.get_status(device_id), @@ -457,7 +458,7 @@ async def get_prop(self, device_id: str): status, dnd_timer, clean_summary, consumable, last_clean_record ) - async def get_multi_maps_list(self, device_id): + async def get_multi_maps_list(self, device_id) -> MultiMapsList: multi_maps_list = await self.send_command( device_id, RoborockCommand.GET_MULTI_MAPS_LIST ) @@ -476,7 +477,7 @@ def __init__(self, username: str, base_url=None) -> None: self.base_url = base_url self._device_identifier = secrets.token_urlsafe(16) - async def _get_base_url(self): + async def _get_base_url(self) -> str: if not self.base_url: url_request = PreparedRequest(self._default_url) response = await url_request.request( @@ -495,7 +496,7 @@ def _get_header_client_id(self): md5.update(self._device_identifier.encode()) return base64.b64encode(md5.digest()).decode() - async def request_code(self): + async def request_code(self) -> None: base_url = await self._get_base_url() header_clientid = self._get_header_client_id() code_request = PreparedRequest(base_url, {"header_clientid": header_clientid}) @@ -512,7 +513,7 @@ async def request_code(self): if code_response.get("code") != 200: raise RoborockException(code_response.get("msg")) - async def pass_login(self, password: str): + async def pass_login(self, password: str) -> UserData: base_url = await self._get_base_url() header_clientid = self._get_header_client_id() @@ -531,7 +532,7 @@ async def pass_login(self, password: str): raise RoborockException(login_response.get("msg")) return UserData(login_response.get("data")) - async def code_login(self, code): + async def code_login(self, code) -> UserData: base_url = await self._get_base_url() header_clientid = self._get_header_client_id() @@ -550,7 +551,7 @@ async def code_login(self, code): raise RoborockException(login_response.get("msg")) return UserData(login_response.get("data")) - async def get_home_data(self, user_data: UserData): + async def get_home_data(self, user_data: UserData) -> HomeData: base_url = await self._get_base_url() header_clientid = self._get_header_client_id() rriot = user_data.rriot diff --git a/roborock/roborock_queue.py b/roborock/roborock_queue.py index 006bd19..6902ff1 100644 --- a/roborock/roborock_queue.py +++ b/roborock/roborock_queue.py @@ -1,5 +1,8 @@ import asyncio from asyncio import Queue +from typing import Any + +from roborock import RoborockException class RoborockQueue(Queue): @@ -8,8 +11,8 @@ def __init__(self, protocol: int, *args): super().__init__(*args) self.protocol = protocol - async def async_put(self, item, timeout): + async def async_put(self, item: tuple[Any, RoborockException | None], timeout: float | int) -> None: return await asyncio.wait_for(self.put(item), timeout=timeout) - async def async_get(self, timeout): + async def async_get(self, timeout: float | int) -> tuple[Any, RoborockException | None]: return await asyncio.wait_for(self.get(), timeout=timeout) From be20ae1fb8c3055b54de083b542cee86874ba9f7 Mon Sep 17 00:00:00 2001 From: Luke Date: Mon, 20 Feb 2023 16:29:46 -0500 Subject: [PATCH 2/6] chore: added typing for containers --- roborock/containers.py | 232 ++++++++++++++++++++--------------------- 1 file changed, 116 insertions(+), 116 deletions(-) diff --git a/roborock/containers.py b/roborock/containers.py index b3cda46..3b9f4e3 100644 --- a/roborock/containers.py +++ b/roborock/containers.py @@ -232,19 +232,19 @@ def __init__(self, data: dict[str, any]) -> None: super().__init__(data) @property - def region(self): + def region(self) -> str: return self.get(UserDataRRiotReferenceField.REGION) @property - def api(self): + def api(self) -> str: return self.get(UserDataRRiotReferenceField.API) @property - def mqtt(self): + def mqtt(self) -> str: return self.get(UserDataRRiotReferenceField.MQTT) @property - def l_unknown(self): + def l_unknown(self) -> str: return self.get(UserDataRRiotReferenceField.L_UNKNOWN) @@ -253,19 +253,19 @@ def __init__(self, data: dict[str, any]) -> None: super().__init__(data) @property - def user(self): + def user(self) -> str: return self.get(UserDataRRiotField.USER) @property - def password(self): + def password(self) -> str: return self.get(UserDataRRiotField.PASSWORD) @property - def h_unknown(self): + def h_unknown(self) -> str: return self.get(UserDataRRiotField.H_UNKNOWN) @property - def domain(self): + def domain(self) -> str: return self.get(UserDataRRiotField.DOMAIN) @property @@ -278,35 +278,35 @@ def __init__(self, data: dict[str, any]) -> None: super().__init__(data) @property - def uid(self): + def uid(self) -> int: return self.get(UserDataField.UID) @property - def token_type(self): + def token_type(self) -> str: return self.get(UserDataField.TOKEN_TYPE) @property - def token(self): + def token(self) -> str: return self.get(UserDataField.TOKEN) @property - def rr_uid(self): + def rr_uid(self) -> str: return self.get(UserDataField.RR_UID) @property - def region(self): + def region(self) -> str: return self.get(UserDataField.REGION) @property - def country_code(self): + def country_code(self) -> str: return self.get(UserDataField.COUNTRY_CODE) @property - def country(self): + def country(self) -> str: return self.get(UserDataField.COUNTRY) @property - def nickname(self): + def nickname(self) -> str: return self.get(UserDataField.NICKNAME) @property @@ -314,11 +314,11 @@ def rriot(self) -> RRiot: return RRiot(self.get(UserDataField.RRIOT)) @property - def tuya_device_state(self): + def tuya_device_state(self) -> int: return self.get(UserDataField.TUYA_DEVICE_STATE) @property - def avatar_url(self): + def avatar_url(self) -> str: return self.get(UserDataField.AVATAR_URL) @@ -381,23 +381,23 @@ def __init__(self, data: dict[str, any]) -> None: super().__init__(data) @property - def id(self): + def id(self) -> str: return self.get(HomeDataProductField.ID) @property - def name(self): + def name(self) -> str: return self.get(HomeDataProductField.NAME) @property - def code(self): + def code(self) -> str: return self.get(HomeDataProductField.CODE) @property - def model(self): + def model(self) -> str: return self.get(HomeDataProductField.MODEL) @property - def iconurl(self): + def iconurl(self) -> str: return self.get(HomeDataProductField.ICONURL) @property @@ -405,11 +405,11 @@ def attribute(self): return self.get(HomeDataProductField.ATTRIBUTE) @property - def capability(self): + def capability(self) -> int: return self.get(HomeDataProductField.CAPABILITY) @property - def category(self): + def category(self) -> str: return self.get(HomeDataProductField.CATEGORY) @property @@ -467,7 +467,7 @@ def duid(self) -> str: return self.get(HomeDataDeviceField.DUID) @property - def name(self): + def name(self) -> str: return self.get(HomeDataDeviceField.NAME) @property @@ -475,7 +475,7 @@ def attribute(self): return self.get(HomeDataDeviceField.ATTRIBUTE) @property - def activetime(self): + def activetime(self) -> int: return self.get(HomeDataDeviceField.ACTIVETIME) @property @@ -487,15 +487,15 @@ def runtime_env(self): return self.get(HomeDataDeviceField.RUNTIME_ENV) @property - def time_zone_id(self): + def time_zone_id(self) -> str: return self.get(HomeDataDeviceField.TIME_ZONE_ID) @property - def icon_url(self): + def icon_url(self) -> str: return self.get(HomeDataDeviceField.ICON_URL) @property - def product_id(self): + def product_id(self) -> str: return self.get(HomeDataDeviceField.PRODUCT_ID) @property @@ -507,7 +507,7 @@ def lat(self): return self.get(HomeDataDeviceField.LAT) @property - def share(self): + def share(self) -> bool: return self.get(HomeDataDeviceField.SHARE) @property @@ -515,15 +515,15 @@ def share_time(self): return self.get(HomeDataDeviceField.SHARE_TIME) @property - def online(self): + def online(self) -> bool: return self.get(HomeDataDeviceField.ONLINE) @property - def fv(self): + def fv(self) -> str: return self.get(HomeDataDeviceField.FV) @property - def pv(self): + def pv(self) -> str: return self.get(HomeDataDeviceField.PV) @property @@ -535,7 +535,7 @@ def tuya_uuid(self): return self.get(HomeDataDeviceField.TUYA_UUID) @property - def tuya_migrated(self): + def tuya_migrated(self) -> bool: return self.get(HomeDataDeviceField.TUYA_MIGRATED) @property @@ -543,15 +543,15 @@ def extra(self): return self.get(HomeDataDeviceField.EXTRA) @property - def sn(self): + def sn(self) -> str: return self.get(HomeDataDeviceField.SN) @property - def feature_set(self): + def feature_set(self) -> str: return self.get(HomeDataDeviceField.FEATURE_SET) @property - def new_feature_set(self): + def new_feature_set(self) -> str: return self.get(HomeDataDeviceField.NEW_FEATURE_SET) @property @@ -559,7 +559,7 @@ def device_status(self) -> HomeDataDeviceStatus: return HomeDataDeviceStatus(self.get(HomeDataDeviceField.DEVICE_STATUS)) @property - def silent_ota_switch(self): + def silent_ota_switch(self) -> bool: return self.get(HomeDataDeviceField.SILENT_OTA_SWITCH) @@ -581,11 +581,11 @@ def __init__(self, data: dict[str, any]) -> None: super().__init__(data) @property - def id(self): + def id(self) -> int: return self.get(HomeDataField.ID) @property - def name(self): + def name(self) -> str: return self.get(HomeDataField.NAME) @property @@ -622,27 +622,27 @@ def __init__(self, data: dict[str, any]) -> None: super().__init__(data) @property - def msg_ver(self): + def msg_ver(self) -> int: return self.get(StatusField.MSG_VER) @property - def msg_seq(self): + def msg_seq(self) -> int: return self.get(StatusField.MSG_SEQ) @property - def state(self): + def state(self) -> int: return self.get(StatusField.STATE) @property - def battery(self): + def battery(self) -> int: return self.get(StatusField.BATTERY) @property - def clean_time(self): + def clean_time(self) -> int: return self.get(StatusField.CLEAN_TIME) @property - def clean_area(self): + def clean_area(self) -> int: return self.get(StatusField.CLEAN_AREA) @property @@ -650,143 +650,143 @@ def error_code(self) -> int: return self.get(StatusField.ERROR_CODE) @property - def map_present(self): + def map_present(self) -> int: return self.get(StatusField.MAP_PRESENT) @property - def in_cleaning(self): + def in_cleaning(self) -> int: return self.get(StatusField.IN_CLEANING) @property - def in_returning(self): + def in_returning(self) -> int: return self.get(StatusField.IN_RETURNING) @property - def in_fresh_state(self): + def in_fresh_state(self) -> int: return self.get(StatusField.IN_FRESH_STATE) @property - def lab_status(self): + def lab_status(self) -> int: return self.get(StatusField.LAB_STATUS) @property - def water_box_status(self): + def water_box_status(self) -> int: return self.get(StatusField.WATER_BOX_STATUS) @property - def back_type(self): + def back_type(self) -> int: return self.get(StatusField.BACK_TYPE) @property - def wash_phase(self): + def wash_phase(self) -> int: return self.get(StatusField.WASH_PHASE) @property - def wash_ready(self): + def wash_ready(self) -> int: return self.get(StatusField.WASH_READY) @property - def fan_power(self): + def fan_power(self) -> int: return self.get(StatusField.FAN_POWER) @property - def dnd_enabled(self): + def dnd_enabled(self) -> int: return self.get(StatusField.DND_ENABLED) @property - def map_status(self): + def map_status(self) -> int: return self.get(StatusField.MAP_STATUS) @property - def is_locating(self): + def is_locating(self) -> int: return self.get(StatusField.IS_LOCATING) @property - def lock_status(self): + def lock_status(self) -> int: return self.get(StatusField.LOCK_STATUS) @property - def water_box_mode(self): + def water_box_mode(self) -> int: return self.get(StatusField.WATER_BOX_MODE) @property - def water_box_carriage_status(self): + def water_box_carriage_status(self) -> int: return self.get(StatusField.WATER_BOX_CARRIAGE_STATUS) @property - def mop_forbidden_enable(self): + def mop_forbidden_enable(self) -> int: return self.get(StatusField.MOP_FORBIDDEN_ENABLE) @property - def camera_status(self): + def camera_status(self) -> int: return self.get(StatusField.CAMERA_STATUS) @property - def is_exploring(self): + def is_exploring(self) -> int: return self.get(StatusField.IS_EXPLORING) @property - def home_sec_status(self): + def home_sec_status(self) -> int: return self.get(StatusField.HOME_SEC_STATUS) @property - def home_sec_enable_password(self): + def home_sec_enable_password(self) -> int: return self.get(StatusField.HOME_SEC_ENABLE_PASSWORD) @property - def adbumper_status(self): + def adbumper_status(self) -> list[int]: return self.get(StatusField.ADBUMPER_STATUS) @property - def water_shortage_status(self): + def water_shortage_status(self) -> int: return self.get(StatusField.WATER_SHORTAGE_STATUS) @property - def dock_type(self): + def dock_type(self) -> int: return self.get(StatusField.DOCK_TYPE) @property - def dust_collection_status(self): + def dust_collection_status(self) -> int: return self.get(StatusField.DUST_COLLECTION_STATUS) @property - def auto_dust_collection(self): + def auto_dust_collection(self) -> int: return self.get(StatusField.AUTO_DUST_COLLECTION) @property - def avoid_count(self): + def avoid_count(self) -> int: return self.get(StatusField.AVOID_COUNT) @property - def mop_mode(self): + def mop_mode(self) -> int: return self.get(StatusField.MOP_MODE) @property - def debug_mode(self): + def debug_mode(self) -> int: return self.get(StatusField.DEBUG_MODE) @property - def collision_avoid_status(self): + def collision_avoid_status(self) -> int: return self.get(StatusField.COLLISION_AVOID_STATUS) @property - def switch_map_mode(self): + def switch_map_mode(self) -> int: return self.get(StatusField.SWITCH_MAP_MODE) @property - def dock_error_status(self): + def dock_error_status(self) -> int: return self.get(StatusField.DOCK_ERROR_STATUS) @property - def charge_status(self): + def charge_status(self) -> int: return self.get(StatusField.CHARGE_STATUS) @property - def unsave_map_reason(self): + def unsave_map_reason(self) -> int: return self.get(StatusField.UNSAVE_MAP_REASON) @property - def unsave_map_flag(self): + def unsave_map_flag(self) -> int: return self.get(StatusField.UNSAVE_MAP_FLAG) @@ -795,23 +795,23 @@ def __init__(self, data: dict[str, any]) -> None: super().__init__(data) @property - def start_hour(self): + def start_hour(self) -> int: return self.get(DNDTimerField.START_HOUR) @property - def start_minute(self): + def start_minute(self) -> int: return self.get(DNDTimerField.START_MINUTE) @property - def end_hour(self): + def end_hour(self) -> int: return self.get(DNDTimerField.END_HOUR) @property - def end_minute(self): + def end_minute(self) -> int: return self.get(DNDTimerField.END_MINUTE) @property - def enabled(self): + def enabled(self) -> int: return self.get(DNDTimerField.ENABLED) @@ -820,19 +820,19 @@ def __init__(self, data: dict[str, any]) -> None: super().__init__(data) @property - def clean_time(self): + def clean_time(self) -> int: return self.get(CleanSummaryField.CLEAN_TIME) @property - def clean_area(self): + def clean_area(self) -> int: return self.get(CleanSummaryField.CLEAN_AREA) @property - def clean_count(self): + def clean_count(self) -> int: return self.get(CleanSummaryField.CLEAN_COUNT) @property - def dust_collection_count(self): + def dust_collection_count(self) -> int: return self.get(CleanSummaryField.DUST_COLLECTION_COUNT) @property @@ -845,55 +845,55 @@ def __init__(self, data: dict[str, any]) -> None: super().__init__(data) @property - def begin(self): + def begin(self) -> int: return self.get(CleanRecordField.BEGIN) @property - def end(self): + def end(self) -> int: return self.get(CleanRecordField.END) @property - def duration(self): + def duration(self) -> int: return self.get(CleanRecordField.DURATION) @property - def area(self): + def area(self) -> int: return self.get(CleanRecordField.AREA) @property - def error(self): + def error(self) -> int: return self.get(CleanRecordField.ERROR) @property - def complete(self): + def complete(self) -> int: return self.get(CleanRecordField.COMPLETE) @property - def start_type(self): + def start_type(self) -> int: return self.get(CleanRecordField.START_TYPE) @property - def clean_type(self): + def clean_type(self) -> int: return self.get(CleanRecordField.CLEAN_TYPE) @property - def finish_reason(self): + def finish_reason(self) -> int: return self.get(CleanRecordField.FINISH_REASON) @property - def dust_collection_status(self): + def dust_collection_status(self) -> int: return self.get(CleanRecordField.DUST_COLLECTION_STATUS) @property - def avoid_count(self): + def avoid_count(self) -> int: return self.get(CleanRecordField.AVOID_COUNT) @property - def wash_count(self): + def wash_count(self) -> int: return self.get(CleanRecordField.WASH_COUNT) @property - def map_flag(self): + def map_flag(self) -> int: return self.get(CleanRecordField.MAP_FLAG) @@ -902,35 +902,35 @@ def __init__(self, data: dict[str, any]) -> None: super().__init__(data) @property - def main_brush_work_time(self): + def main_brush_work_time(self) -> int: return self.get(ConsumableField.MAIN_BRUSH_WORK_TIME) @property - def side_brush_work_time(self): + def side_brush_work_time(self) -> int: return self.get(ConsumableField.SIDE_BRUSH_WORK_TIME) @property - def filter_work_time(self): + def filter_work_time(self) -> int: return self.get(ConsumableField.FILTER_WORK_TIME) @property - def filter_element_work_time(self): + def filter_element_work_time(self) -> int: return self.get(ConsumableField.FILTER_ELEMENT_WORK_TIME) @property - def sensor_dirty_time(self): + def sensor_dirty_time(self) -> int: return self.get(ConsumableField.SENSOR_DIRTY_TIME) @property - def strainer_work_times(self): + def strainer_work_times(self) -> int: return self.get(ConsumableField.STRAINER_WORK_TIMES) @property - def dust_collection_work_times(self): + def dust_collection_work_times(self) -> int: return self.get(ConsumableField.DUST_COLLECTION_WORK_TIMES) @property - def cleaning_brush_work_times(self): + def cleaning_brush_work_times(self) -> int: return self.get(ConsumableField.CLEANING_BRUSH_WORK_TIMES) @@ -977,17 +977,17 @@ def __init__(self, data: dict[str, any]) -> None: super().__init__(data) @property - def max_multi_map(self): + def max_multi_map(self) -> int: return self.get(MultiMapListField.MAX_MULTI_MAP) @property - def max_bak_map(self): + def max_bak_map(self) -> int: return self.get(MultiMapListField.MAX_BAK_MAP) @property - def multi_map_count(self): + def multi_map_count(self) -> int: return self.get(MultiMapListField.MULTI_MAP_COUNT) @property - def map_info(self): + def map_info(self) -> list[MultiMapsListMapInfo]: return [MultiMapsListMapInfo(map_info) for map_info in self.get(MultiMapListField.MAP_INFO)] From 16f1d5dc10123987ee480bc4696a9a80a5bbe376 Mon Sep 17 00:00:00 2001 From: Luke Date: Tue, 21 Feb 2023 17:51:48 -0500 Subject: [PATCH 3/6] chore: add typing to user_data property --- roborock/containers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/roborock/containers.py b/roborock/containers.py index 3b9f4e3..2f9b42c 100644 --- a/roborock/containers.py +++ b/roborock/containers.py @@ -327,7 +327,7 @@ def __init__(self, data: dict[str, any]) -> None: super().__init__(data) @property - def user_data(self): + def user_data(self) -> UserData: user_data = self.get("user_data") if user_data: return UserData(user_data) From eaa4dee1dca696a5817205cd4387b92ce93df0bf Mon Sep 17 00:00:00 2001 From: Luke Date: Fri, 24 Feb 2023 08:26:43 -0500 Subject: [PATCH 4/6] fix: change to timeout from wait_for wait_for creates a task, async_timeout does the same work and avoids the task creation --- pyproject.toml | 1 + roborock/roborock_queue.py | 6 +++++- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 7d557ba..dac2b2d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,6 +14,7 @@ roborock = "roborock.cli:main" python = "^3.8" click = ">=8" aiohttp = "*" +async-timeout = "*" pycryptodome = "~3.16.0" pycryptodomex = {version = "~3.16.0", markers = "sys_platform == 'darwin'"} paho-mqtt = "~1.6.1" diff --git a/roborock/roborock_queue.py b/roborock/roborock_queue.py index 6902ff1..4db21c5 100644 --- a/roborock/roborock_queue.py +++ b/roborock/roborock_queue.py @@ -1,6 +1,7 @@ import asyncio from asyncio import Queue from typing import Any +import async_timeout from roborock import RoborockException @@ -12,7 +13,10 @@ def __init__(self, protocol: int, *args): self.protocol = protocol async def async_put(self, item: tuple[Any, RoborockException | None], timeout: float | int) -> None: - return await asyncio.wait_for(self.put(item), timeout=timeout) + async with async_timeout.timeout(timeout): + await self.put(item) async def async_get(self, timeout: float | int) -> tuple[Any, RoborockException | None]: + async with async_timeout.timeout(timeout): + await self.get() return await asyncio.wait_for(self.get(), timeout=timeout) From f2b4c89500ac169e9dc021de6e250474f6f75b15 Mon Sep 17 00:00:00 2001 From: Luke Date: Fri, 24 Feb 2023 08:31:17 -0500 Subject: [PATCH 5/6] fix: removed unneeded line --- roborock/roborock_queue.py | 3 +-- tests/test_api.py | 0 2 files changed, 1 insertion(+), 2 deletions(-) create mode 100644 tests/test_api.py diff --git a/roborock/roborock_queue.py b/roborock/roborock_queue.py index 4db21c5..4b083ae 100644 --- a/roborock/roborock_queue.py +++ b/roborock/roborock_queue.py @@ -18,5 +18,4 @@ async def async_put(self, item: tuple[Any, RoborockException | None], timeout: f async def async_get(self, timeout: float | int) -> tuple[Any, RoborockException | None]: async with async_timeout.timeout(timeout): - await self.get() - return await asyncio.wait_for(self.get(), timeout=timeout) + return await self.get() diff --git a/tests/test_api.py b/tests/test_api.py new file mode 100644 index 0000000..e69de29 From 6f907055eca9c84c2003100bd8056ad6aa8d1493 Mon Sep 17 00:00:00 2001 From: Luke Date: Mon, 27 Feb 2023 21:40:35 -0500 Subject: [PATCH 6/6] Remove circular dependenacy Accidentally introduced this when I was adding typing --- roborock/roborock_queue.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/roborock/roborock_queue.py b/roborock/roborock_queue.py index 6902ff1..e95f002 100644 --- a/roborock/roborock_queue.py +++ b/roborock/roborock_queue.py @@ -2,7 +2,7 @@ from asyncio import Queue from typing import Any -from roborock import RoborockException +from .exceptions import RoborockException class RoborockQueue(Queue):