Skip to content

Commit

Permalink
Merge branch 'main' into crypto_bump
Browse files Browse the repository at this point in the history
  • Loading branch information
humbertogontijo authored Mar 1, 2023
2 parents 1931073 + 311af16 commit d514ae0
Show file tree
Hide file tree
Showing 5 changed files with 158 additions and 150 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ roborock = "roborock.cli:main"
python = "^3.8"
click = ">=8"
aiohttp = "*"
async-timeout = "*"
pycryptodome = "~3.17.0"
pycryptodomex = {version = "~3.17.0", markers = "sys_platform == 'darwin'"}
paho-mqtt = "~1.6.1"
Expand Down
59 changes: 30 additions & 29 deletions roborock/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import time
from asyncio import Lock
from asyncio.exceptions import TimeoutError, CancelledError
from typing import Any
from urllib.parse import urlparse

import aiohttp
Expand Down Expand Up @@ -53,31 +54,31 @@
MQTT_KEEPALIVE = 60


def md5hex(message: str):
def md5hex(message: str) -> str:
md5 = hashlib.md5()
md5.update(message.encode())
return md5.hexdigest()


def md5bin(message: str):
def md5bin(message: str) -> bytes:
md5 = hashlib.md5()
md5.update(message.encode())
return md5.digest()


def encode_timestamp(_timestamp: int):
def encode_timestamp(_timestamp: int) -> str:
hex_value = f"{_timestamp:x}".zfill(8)
return "".join(list(map(lambda idx: hex_value[idx], [5, 6, 3, 7, 1, 2, 0, 4])))


class PreparedRequest:
def __init__(self, base_url: str, base_headers: dict = None):
def __init__(self, base_url: str, base_headers: dict = None) -> None:
self.base_url = base_url
self.base_headers = base_headers or {}

async def request(
self, method: str, url: str, params=None, data=None, headers=None
):
) -> dict | list:
_url = "/".join(s.strip("/") for s in [self.base_url, url])
_headers = {**self.base_headers, **(headers or {})}
async with aiohttp.ClientSession() as session:
Expand All @@ -99,7 +100,7 @@ async def request(
class RoborockMqttClient(mqtt.Client):
_thread: threading.Thread

def __init__(self, user_data: UserData, device_map: dict[str, RoborockDeviceInfo]):
def __init__(self, user_data: UserData, device_map: dict[str, RoborockDeviceInfo]) -> None:
rriot = user_data.rriot
self._mqtt_user = rriot.user
self._mqtt_domain = rriot.domain
Expand All @@ -126,11 +127,11 @@ def __init__(self, user_data: UserData, device_map: dict[str, RoborockDeviceInfo
self._last_device_msg_in = mqtt.time_func()
self._last_disconnection = mqtt.time_func()

def __del__(self):
def __del__(self) -> None:
self.sync_disconnect()

@run_in_executor()
async def on_connect(self, _client, _, __, rc, ___=None):
async def on_connect(self, _client, _, __, rc, ___=None) -> None:
connection_queue = self._waiting_queue.get(0)
if rc != mqtt.MQTT_ERR_SUCCESS:
message = f"Failed to connect (rc: {rc})"
Expand All @@ -156,7 +157,7 @@ async def on_connect(self, _client, _, __, rc, ___=None):
await connection_queue.async_put((True, None), timeout=QUEUE_TIMEOUT)

@run_in_executor()
async def on_message(self, _client, _, msg, __=None):
async def on_message(self, _client, _, msg, __=None) -> None:
try:
async with self._mutex:
self._last_device_msg_in = mqtt.time_func()
Expand Down Expand Up @@ -219,7 +220,7 @@ async def on_message(self, _client, _, msg, __=None):
_LOGGER.exception(ex)

@run_in_executor()
async def on_disconnect(self, _client: mqtt.Client, _, rc, __=None):
async def on_disconnect(self, _client: mqtt.Client, _, rc, __=None) -> None:
try:
async with self._mutex:
self._last_disconnection = mqtt.time_func()
Expand All @@ -241,28 +242,28 @@ async def on_disconnect(self, _client: mqtt.Client, _, rc, __=None):
_LOGGER.exception(ex)

@run_in_executor()
async def _async_check_keepalive(self):
async def _async_check_keepalive(self) -> None:
async with self._mutex:
now = mqtt.time_func()
if now - self._last_disconnection > self._keepalive ** 2 and now - self._last_device_msg_in > self._keepalive:
self._ping_t = self._last_device_msg_in

def _check_keepalive(self):
def _check_keepalive(self) -> None:
self._async_check_keepalive()
super()._check_keepalive()

def sync_stop_loop(self):
def sync_stop_loop(self) -> None:
if self._thread:
_LOGGER.info("Stopping mqtt loop")
super().loop_stop()

def sync_start_loop(self):
def sync_start_loop(self) -> None:
if not self._thread or not self._thread.is_alive():
self.sync_stop_loop()
_LOGGER.info("Starting mqtt loop")
super().loop_start()

def sync_disconnect(self):
def sync_disconnect(self) -> bool:
rc = mqtt.MQTT_ERR_AGAIN
if self.is_connected():
_LOGGER.info("Disconnecting from mqtt")
Expand All @@ -271,7 +272,7 @@ def sync_disconnect(self):
raise RoborockException(f"Failed to disconnect (rc:{rc})")
return rc == mqtt.MQTT_ERR_SUCCESS

def sync_connect(self):
def sync_connect(self) -> bool:
rc = mqtt.MQTT_ERR_AGAIN
self.sync_start_loop()
if not self.is_connected():
Expand All @@ -285,7 +286,7 @@ def sync_connect(self):
raise RoborockException(f"Failed to connect (rc:{rc})")
return rc == mqtt.MQTT_ERR_SUCCESS

async def _async_response(self, request_id: int, protocol_id: int = 0):
async def _async_response(self, request_id: int, protocol_id: int = 0) -> tuple[Any, RoborockException | None]:
try:
queue = RoborockQueue(protocol_id)
self._waiting_queue[request_id] = queue
Expand All @@ -298,7 +299,7 @@ async def _async_response(self, request_id: int, protocol_id: int = 0):
finally:
del self._waiting_queue[request_id]

async def async_disconnect(self):
async def async_disconnect(self) -> Any:
async with self._mutex:
disconnecting = self.sync_disconnect()
if disconnecting:
Expand All @@ -307,7 +308,7 @@ async def async_disconnect(self):
raise RoborockException(err) from err
return response

async def async_connect(self):
async def async_connect(self) -> Any:
async with self._mutex:
connecting = self.sync_connect()
if connecting:
Expand All @@ -316,10 +317,10 @@ async def async_connect(self):
raise RoborockException(err) from err
return response

async def validate_connection(self):
async def validate_connection(self) -> None:
await self.async_connect()

def _decode_msg(self, msg, device: HomeDataDevice):
def _decode_msg(self, msg, device: HomeDataDevice) -> dict[str, Any]:
if msg[0:3] != "1.0".encode():
raise RoborockException("Unknown protocol version")
crc32 = binascii.crc32(msg[0: len(msg) - 4])
Expand All @@ -344,7 +345,7 @@ def _decode_msg(self, msg, device: HomeDataDevice):
"payload": decrypted_payload,
}

def _send_msg_raw(self, device_id, protocol, timestamp, payload):
def _send_msg_raw(self, device_id, protocol, timestamp, payload) -> None:
local_key = self.device_map[device_id].device.local_key
aes_key = md5bin(encode_timestamp(timestamp) + local_key + self._salt)
cipher = AES.new(aes_key, AES.MODE_ECB)
Expand Down Expand Up @@ -438,7 +439,7 @@ async def get_consumable(self, device_id: str) -> Consumable:
if isinstance(consumable, dict):
return Consumable(consumable)

async def get_prop(self, device_id: str):
async def get_prop(self, device_id: str) -> RoborockDeviceProp:
[status, dnd_timer, clean_summary, consumable] = await asyncio.gather(
*[
self.get_status(device_id),
Expand All @@ -457,7 +458,7 @@ async def get_prop(self, device_id: str):
status, dnd_timer, clean_summary, consumable, last_clean_record
)

async def get_multi_maps_list(self, device_id):
async def get_multi_maps_list(self, device_id) -> MultiMapsList:
multi_maps_list = await self.send_command(
device_id, RoborockCommand.GET_MULTI_MAPS_LIST
)
Expand All @@ -476,7 +477,7 @@ def __init__(self, username: str, base_url=None) -> None:
self.base_url = base_url
self._device_identifier = secrets.token_urlsafe(16)

async def _get_base_url(self):
async def _get_base_url(self) -> str:
if not self.base_url:
url_request = PreparedRequest(self._default_url)
response = await url_request.request(
Expand All @@ -495,7 +496,7 @@ def _get_header_client_id(self):
md5.update(self._device_identifier.encode())
return base64.b64encode(md5.digest()).decode()

async def request_code(self):
async def request_code(self) -> None:
base_url = await self._get_base_url()
header_clientid = self._get_header_client_id()
code_request = PreparedRequest(base_url, {"header_clientid": header_clientid})
Expand All @@ -512,7 +513,7 @@ async def request_code(self):
if code_response.get("code") != 200:
raise RoborockException(code_response.get("msg"))

async def pass_login(self, password: str):
async def pass_login(self, password: str) -> UserData:
base_url = await self._get_base_url()
header_clientid = self._get_header_client_id()

Expand All @@ -531,7 +532,7 @@ async def pass_login(self, password: str):
raise RoborockException(login_response.get("msg"))
return UserData(login_response.get("data"))

async def code_login(self, code):
async def code_login(self, code) -> UserData:
base_url = await self._get_base_url()
header_clientid = self._get_header_client_id()

Expand All @@ -550,7 +551,7 @@ async def code_login(self, code):
raise RoborockException(login_response.get("msg"))
return UserData(login_response.get("data"))

async def get_home_data(self, user_data: UserData):
async def get_home_data(self, user_data: UserData) -> HomeData:
base_url = await self._get_base_url()
header_clientid = self._get_header_client_id()
rriot = user_data.rriot
Expand Down
Loading

0 comments on commit d514ae0

Please sign in to comment.