From 9f939ce74ef43ca77e6cd0000b5ac85d534327e7 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 20 Dec 2024 09:34:56 -1000 Subject: [PATCH] Move encryption and api functions into the base class (#277) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- switchbot/devices/device.py | 164 +++++++++++++++++++++++++++++- switchbot/devices/lock.py | 145 ++------------------------ switchbot/devices/relay_switch.py | 32 +++--- 3 files changed, 188 insertions(+), 153 deletions(-) diff --git a/switchbot/devices/device.py b/switchbot/devices/device.py index 328adce..ff05fca 100644 --- a/switchbot/devices/device.py +++ b/switchbot/devices/device.py @@ -12,6 +12,7 @@ from collections.abc import Callable from uuid import UUID +import aiohttp from bleak.backends.device import BLEDevice from bleak.backends.service import BleakGATTCharacteristic, BleakGATTServiceCollection from bleak.exc import BleakDBusError @@ -23,7 +24,15 @@ establish_connection, ) -from ..const import DEFAULT_RETRY_COUNT, DEFAULT_SCAN_TIMEOUT +from ..api_config import SWITCHBOT_APP_API_BASE_URL, SWITCHBOT_APP_CLIENT_ID +from ..const import ( + DEFAULT_RETRY_COUNT, + DEFAULT_SCAN_TIMEOUT, + SwitchbotAccountConnectionError, + SwitchbotApiError, + SwitchbotAuthenticationError, + SwitchbotModel, +) from ..discovery import GetSwitchbotDevices from ..models import SwitchBotAdvertisement @@ -152,6 +161,35 @@ def __init__( self._last_full_update: float = -PASSIVE_POLL_INTERVAL self._timed_disconnect_task: asyncio.Task[None] | None = None + @classmethod + async def api_request( + cls, + session: aiohttp.ClientSession, + subdomain: str, + path: str, + data: dict = None, + headers: dict = None, + ) -> dict: + url = f"https://{subdomain}.{SWITCHBOT_APP_API_BASE_URL}/{path}" + async with session.post( + url, + json=data, + headers=headers, + timeout=aiohttp.ClientTimeout(total=10), + ) as result: + if result.status > 299: + raise SwitchbotApiError( + f"Unexpected status code returned by SwitchBot API: {result.status}" + ) + + response = await result.json() + if response["statusCode"] != 100: + raise SwitchbotApiError( + f"{response['message']}, status code: {response['statusCode']}" + ) + + return response["body"] + def advertisement_changed(self, advertisement: SwitchBotAdvertisement) -> bool: """Check if the advertisement has changed.""" return bool( @@ -666,6 +704,130 @@ def update_from_advertisement(self, advertisement: SwitchBotAdvertisement) -> No self._set_advertisement_data(advertisement) +class SwitchbotEncryptedDevice(SwitchbotDevice): + """A Switchbot device that uses encryption.""" + + def __init__( + self, + device: BLEDevice, + key_id: str, + encryption_key: str, + model: SwitchbotModel, + interface: int = 0, + **kwargs: Any, + ) -> None: + """Switchbot base class constructor for encrypted devices.""" + if len(key_id) == 0: + raise ValueError("key_id is missing") + elif len(key_id) != 2: + raise ValueError("key_id is invalid") + if len(encryption_key) == 0: + raise ValueError("encryption_key is missing") + elif len(encryption_key) != 32: + raise ValueError("encryption_key is invalid") + self._key_id = key_id + self._encryption_key = bytearray.fromhex(encryption_key) + self._iv: bytes | None = None + self._cipher: bytes | None = None + self._model = model + super().__init__(device, None, interface, **kwargs) + + # Old non-async method preserved for backwards compatibility + @classmethod + def retrieve_encryption_key(cls, device_mac: str, username: str, password: str): + async def async_fn(): + async with aiohttp.ClientSession() as session: + return await cls.async_retrieve_encryption_key( + session, device_mac, username, password + ) + + return asyncio.run(async_fn()) + + @classmethod + async def async_retrieve_encryption_key( + cls, + session: aiohttp.ClientSession, + device_mac: str, + username: str, + password: str, + ) -> dict: + """Retrieve lock key from internal SwitchBot API.""" + device_mac = device_mac.replace(":", "").replace("-", "").upper() + + try: + auth_result = await cls.api_request( + session, + "account", + "account/api/v1/user/login", + { + "clientId": SWITCHBOT_APP_CLIENT_ID, + "username": username, + "password": password, + "grantType": "password", + "verifyCode": "", + }, + ) + auth_headers = {"authorization": auth_result["access_token"]} + except Exception as err: + raise SwitchbotAuthenticationError(f"Authentication failed: {err}") from err + + try: + userinfo = await cls.api_request( + session, "account", "account/api/v1/user/userinfo", {}, auth_headers + ) + if "botRegion" in userinfo and userinfo["botRegion"] != "": + region = userinfo["botRegion"] + else: + region = "us" + except Exception as err: + raise SwitchbotAccountConnectionError( + f"Failed to retrieve SwitchBot Account user details: {err}" + ) from err + + try: + device_info = await cls.api_request( + session, + f"wonderlabs.{region}", + "wonder/keys/v1/communicate", + { + "device_mac": device_mac, + "keyType": "user", + }, + auth_headers, + ) + + return { + "key_id": device_info["communicationKey"]["keyId"], + "encryption_key": device_info["communicationKey"]["key"], + } + except Exception as err: + raise SwitchbotAccountConnectionError( + f"Failed to retrieve encryption key from SwitchBot Account: {err}" + ) from err + + @classmethod + async def verify_encryption_key( + cls, + device: BLEDevice, + key_id: str, + encryption_key: str, + model: SwitchbotModel, + **kwargs: Any, + ) -> bool: + try: + device = cls( + device, key_id=key_id, encryption_key=encryption_key, model=model + ) + except ValueError: + return False + try: + info = await device.get_basic_info() + except SwitchbotOperationError: + return False + + return info is not None + + class SwitchbotDeviceOverrideStateDuringConnection(SwitchbotBaseDevice): """Base Representation of a Switchbot Device. diff --git a/switchbot/devices/lock.py b/switchbot/devices/lock.py index 8c4fe31..8766af0 100644 --- a/switchbot/devices/lock.py +++ b/switchbot/devices/lock.py @@ -2,24 +2,15 @@ from __future__ import annotations -import asyncio import logging import time from typing import Any -import aiohttp from bleak.backends.device import BLEDevice from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes -from ..api_config import SWITCHBOT_APP_API_BASE_URL, SWITCHBOT_APP_CLIENT_ID -from ..const import ( - LockStatus, - SwitchbotAccountConnectionError, - SwitchbotApiError, - SwitchbotAuthenticationError, - SwitchbotModel, -) -from .device import SwitchbotDevice, SwitchbotOperationError +from ..const import LockStatus, SwitchbotModel +from .device import SwitchbotEncryptedDevice COMMAND_HEADER = "57" COMMAND_GET_CK_IV = f"{COMMAND_HEADER}0f2103" @@ -54,7 +45,7 @@ # The return value of the command is 6 when the command is successful but the battery is low. -class SwitchbotLock(SwitchbotDevice): +class SwitchbotLock(SwitchbotEncryptedDevice): """Representation of a Switchbot Lock.""" def __init__( @@ -66,141 +57,23 @@ def __init__( model: SwitchbotModel = SwitchbotModel.LOCK, **kwargs: Any, ) -> None: - if len(key_id) == 0: - raise ValueError("key_id is missing") - elif len(key_id) != 2: - raise ValueError("key_id is invalid") - if len(encryption_key) == 0: - raise ValueError("encryption_key is missing") - elif len(encryption_key) != 32: - raise ValueError("encryption_key is invalid") if model not in (SwitchbotModel.LOCK, SwitchbotModel.LOCK_PRO): raise ValueError("initializing SwitchbotLock with a non-lock model") - self._iv = None - self._cipher = None - self._key_id = key_id - self._encryption_key = bytearray.fromhex(encryption_key) self._notifications_enabled: bool = False - self._model: SwitchbotModel = model - super().__init__(device, None, interface, **kwargs) + super().__init__(device, key_id, encryption_key, model, interface, **kwargs) - @staticmethod + @classmethod async def verify_encryption_key( + cls, device: BLEDevice, key_id: str, encryption_key: str, model: SwitchbotModel = SwitchbotModel.LOCK, **kwargs: Any, ) -> bool: - try: - lock = SwitchbotLock( - device, key_id=key_id, encryption_key=encryption_key, model=model - ) - except ValueError: - return False - try: - lock_info = await lock.get_basic_info() - except SwitchbotOperationError: - return False - - return lock_info is not None - - @staticmethod - async def api_request( - session: aiohttp.ClientSession, - subdomain: str, - path: str, - data: dict = None, - headers: dict = None, - ) -> dict: - url = f"https://{subdomain}.{SWITCHBOT_APP_API_BASE_URL}/{path}" - async with session.post( - url, - json=data, - headers=headers, - timeout=aiohttp.ClientTimeout(total=10), - ) as result: - if result.status > 299: - raise SwitchbotApiError( - f"Unexpected status code returned by SwitchBot API: {result.status}" - ) - - response = await result.json() - if response["statusCode"] != 100: - raise SwitchbotApiError( - f"{response['message']}, status code: {response['statusCode']}" - ) - - return response["body"] - - # Old non-async method preserved for backwards compatibility - @staticmethod - def retrieve_encryption_key(device_mac: str, username: str, password: str): - async def async_fn(): - async with aiohttp.ClientSession() as session: - return await SwitchbotLock.async_retrieve_encryption_key( - session, device_mac, username, password - ) - - return asyncio.run(async_fn()) - - @staticmethod - async def async_retrieve_encryption_key( - session: aiohttp.ClientSession, device_mac: str, username: str, password: str - ) -> dict: - """Retrieve lock key from internal SwitchBot API.""" - device_mac = device_mac.replace(":", "").replace("-", "").upper() - - try: - auth_result = await SwitchbotLock.api_request( - session, - "account", - "account/api/v1/user/login", - { - "clientId": SWITCHBOT_APP_CLIENT_ID, - "username": username, - "password": password, - "grantType": "password", - "verifyCode": "", - }, - ) - auth_headers = {"authorization": auth_result["access_token"]} - except Exception as err: - raise SwitchbotAuthenticationError(f"Authentication failed: {err}") from err - - try: - userinfo = await SwitchbotLock.api_request( - session, "account", "account/api/v1/user/userinfo", {}, auth_headers - ) - if "botRegion" in userinfo and userinfo["botRegion"] != "": - region = userinfo["botRegion"] - else: - region = "us" - except Exception as err: - raise SwitchbotAccountConnectionError( - f"Failed to retrieve SwitchBot Account user details: {err}" - ) from err - - try: - device_info = await SwitchbotLock.api_request( - session, - f"wonderlabs.{region}", - "wonder/keys/v1/communicate", - { - "device_mac": device_mac, - "keyType": "user", - }, - auth_headers, - ) - - return { - "key_id": device_info["communicationKey"]["keyId"], - "encryption_key": device_info["communicationKey"]["key"], - } - except Exception as err: - raise SwitchbotAccountConnectionError( - f"Failed to retrieve encryption key from SwitchBot Account: {err}" - ) from err + return super().verify_encryption_key( + device, key_id, encryption_key, model, **kwargs + ) async def lock(self) -> bool: """Send lock command.""" diff --git a/switchbot/devices/relay_switch.py b/switchbot/devices/relay_switch.py index 2947a7b..97b071b 100644 --- a/switchbot/devices/relay_switch.py +++ b/switchbot/devices/relay_switch.py @@ -7,7 +7,7 @@ from ..const import SwitchbotModel from ..models import SwitchBotAdvertisement -from .device import SwitchbotDevice +from .device import SwitchbotEncryptedDevice _LOGGER = logging.getLogger(__name__) @@ -20,7 +20,7 @@ PASSIVE_POLL_INTERVAL = 10 * 60 -class SwitchbotRelaySwitch(SwitchbotDevice): +class SwitchbotRelaySwitch(SwitchbotEncryptedDevice): """Representation of a Switchbot relay switch 1pm.""" def __init__( @@ -32,21 +32,21 @@ def __init__( model: SwitchbotModel = SwitchbotModel.RELAY_SWITCH_1PM, **kwargs: Any, ) -> None: - if len(key_id) == 0: - raise ValueError("key_id is missing") - elif len(key_id) != 2: - raise ValueError("key_id is invalid") - if len(encryption_key) == 0: - raise ValueError("encryption_key is missing") - elif len(encryption_key) != 32: - raise ValueError("encryption_key is invalid") - self._iv = None - self._cipher = None - self._key_id = key_id - self._encryption_key = bytearray.fromhex(encryption_key) - self._model: SwitchbotModel = model self._force_next_update = False - super().__init__(device, None, interface, **kwargs) + super().__init__(device, key_id, encryption_key, model, interface, **kwargs) + + @classmethod + async def verify_encryption_key( + cls, + device: BLEDevice, + key_id: str, + encryption_key: str, + model: SwitchbotModel = SwitchbotModel.RELAY_SWITCH_1PM, + **kwargs: Any, + ) -> bool: + return super().verify_encryption_key( + device, key_id, encryption_key, model, **kwargs + ) def update_from_advertisement(self, advertisement: SwitchBotAdvertisement) -> None: """Update device data from advertisement."""