Skip to content

Commit

Permalink
Move encryption and api functions into the base class (#277)
Browse files Browse the repository at this point in the history
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
bdraco and pre-commit-ci[bot] authored Dec 20, 2024
1 parent 8f3172a commit 9f939ce
Show file tree
Hide file tree
Showing 3 changed files with 188 additions and 153 deletions.
164 changes: 163 additions & 1 deletion switchbot/devices/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from collections.abc import Callable
from uuid import UUID

import aiohttp
from bleak.backends.device import BLEDevice
from bleak.backends.service import BleakGATTCharacteristic, BleakGATTServiceCollection
from bleak.exc import BleakDBusError
Expand All @@ -23,7 +24,15 @@
establish_connection,
)

from ..const import DEFAULT_RETRY_COUNT, DEFAULT_SCAN_TIMEOUT
from ..api_config import SWITCHBOT_APP_API_BASE_URL, SWITCHBOT_APP_CLIENT_ID
from ..const import (
DEFAULT_RETRY_COUNT,
DEFAULT_SCAN_TIMEOUT,
SwitchbotAccountConnectionError,
SwitchbotApiError,
SwitchbotAuthenticationError,
SwitchbotModel,
)
from ..discovery import GetSwitchbotDevices
from ..models import SwitchBotAdvertisement

Expand Down Expand Up @@ -152,6 +161,35 @@ def __init__(
self._last_full_update: float = -PASSIVE_POLL_INTERVAL
self._timed_disconnect_task: asyncio.Task[None] | None = None

@classmethod
async def api_request(
cls,
session: aiohttp.ClientSession,
subdomain: str,
path: str,
data: dict = None,
headers: dict = None,
) -> dict:
url = f"https://{subdomain}.{SWITCHBOT_APP_API_BASE_URL}/{path}"
async with session.post(
url,
json=data,
headers=headers,
timeout=aiohttp.ClientTimeout(total=10),
) as result:
if result.status > 299:
raise SwitchbotApiError(
f"Unexpected status code returned by SwitchBot API: {result.status}"
)

response = await result.json()
if response["statusCode"] != 100:
raise SwitchbotApiError(
f"{response['message']}, status code: {response['statusCode']}"
)

return response["body"]

def advertisement_changed(self, advertisement: SwitchBotAdvertisement) -> bool:
"""Check if the advertisement has changed."""
return bool(
Expand Down Expand Up @@ -666,6 +704,130 @@ def update_from_advertisement(self, advertisement: SwitchBotAdvertisement) -> No
self._set_advertisement_data(advertisement)


class SwitchbotEncryptedDevice(SwitchbotDevice):
"""A Switchbot device that uses encryption."""

def __init__(
self,
device: BLEDevice,
key_id: str,
encryption_key: str,
model: SwitchbotModel,
interface: int = 0,
**kwargs: Any,
) -> None:
"""Switchbot base class constructor for encrypted devices."""
if len(key_id) == 0:
raise ValueError("key_id is missing")
elif len(key_id) != 2:
raise ValueError("key_id is invalid")
if len(encryption_key) == 0:
raise ValueError("encryption_key is missing")
elif len(encryption_key) != 32:
raise ValueError("encryption_key is invalid")
self._key_id = key_id
self._encryption_key = bytearray.fromhex(encryption_key)
self._iv: bytes | None = None
self._cipher: bytes | None = None
self._model = model
super().__init__(device, None, interface, **kwargs)

# Old non-async method preserved for backwards compatibility
@classmethod
def retrieve_encryption_key(cls, device_mac: str, username: str, password: str):
async def async_fn():
async with aiohttp.ClientSession() as session:
return await cls.async_retrieve_encryption_key(
session, device_mac, username, password
)

return asyncio.run(async_fn())

@classmethod
async def async_retrieve_encryption_key(
cls,
session: aiohttp.ClientSession,
device_mac: str,
username: str,
password: str,
) -> dict:
"""Retrieve lock key from internal SwitchBot API."""
device_mac = device_mac.replace(":", "").replace("-", "").upper()

try:
auth_result = await cls.api_request(
session,
"account",
"account/api/v1/user/login",
{
"clientId": SWITCHBOT_APP_CLIENT_ID,
"username": username,
"password": password,
"grantType": "password",
"verifyCode": "",
},
)
auth_headers = {"authorization": auth_result["access_token"]}
except Exception as err:
raise SwitchbotAuthenticationError(f"Authentication failed: {err}") from err

try:
userinfo = await cls.api_request(
session, "account", "account/api/v1/user/userinfo", {}, auth_headers
)
if "botRegion" in userinfo and userinfo["botRegion"] != "":
region = userinfo["botRegion"]
else:
region = "us"
except Exception as err:
raise SwitchbotAccountConnectionError(
f"Failed to retrieve SwitchBot Account user details: {err}"
) from err

try:
device_info = await cls.api_request(
session,
f"wonderlabs.{region}",
"wonder/keys/v1/communicate",
{
"device_mac": device_mac,
"keyType": "user",
},
auth_headers,
)

return {
"key_id": device_info["communicationKey"]["keyId"],
"encryption_key": device_info["communicationKey"]["key"],
}
except Exception as err:
raise SwitchbotAccountConnectionError(
f"Failed to retrieve encryption key from SwitchBot Account: {err}"
) from err

@classmethod
async def verify_encryption_key(
cls,
device: BLEDevice,
key_id: str,
encryption_key: str,
model: SwitchbotModel,
**kwargs: Any,
) -> bool:
try:
device = cls(
device, key_id=key_id, encryption_key=encryption_key, model=model
)
except ValueError:
return False
try:
info = await device.get_basic_info()
except SwitchbotOperationError:
return False

return info is not None


class SwitchbotDeviceOverrideStateDuringConnection(SwitchbotBaseDevice):
"""Base Representation of a Switchbot Device.
Expand Down
145 changes: 9 additions & 136 deletions switchbot/devices/lock.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,15 @@

from __future__ import annotations

import asyncio
import logging
import time
from typing import Any

import aiohttp
from bleak.backends.device import BLEDevice
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes

from ..api_config import SWITCHBOT_APP_API_BASE_URL, SWITCHBOT_APP_CLIENT_ID
from ..const import (
LockStatus,
SwitchbotAccountConnectionError,
SwitchbotApiError,
SwitchbotAuthenticationError,
SwitchbotModel,
)
from .device import SwitchbotDevice, SwitchbotOperationError
from ..const import LockStatus, SwitchbotModel
from .device import SwitchbotEncryptedDevice

COMMAND_HEADER = "57"
COMMAND_GET_CK_IV = f"{COMMAND_HEADER}0f2103"
Expand Down Expand Up @@ -54,7 +45,7 @@
# The return value of the command is 6 when the command is successful but the battery is low.


class SwitchbotLock(SwitchbotDevice):
class SwitchbotLock(SwitchbotEncryptedDevice):
"""Representation of a Switchbot Lock."""

def __init__(
Expand All @@ -66,141 +57,23 @@ def __init__(
model: SwitchbotModel = SwitchbotModel.LOCK,
**kwargs: Any,
) -> None:
if len(key_id) == 0:
raise ValueError("key_id is missing")
elif len(key_id) != 2:
raise ValueError("key_id is invalid")
if len(encryption_key) == 0:
raise ValueError("encryption_key is missing")
elif len(encryption_key) != 32:
raise ValueError("encryption_key is invalid")
if model not in (SwitchbotModel.LOCK, SwitchbotModel.LOCK_PRO):
raise ValueError("initializing SwitchbotLock with a non-lock model")
self._iv = None
self._cipher = None
self._key_id = key_id
self._encryption_key = bytearray.fromhex(encryption_key)
self._notifications_enabled: bool = False
self._model: SwitchbotModel = model
super().__init__(device, None, interface, **kwargs)
super().__init__(device, key_id, encryption_key, model, interface, **kwargs)

@staticmethod
@classmethod
async def verify_encryption_key(
cls,
device: BLEDevice,
key_id: str,
encryption_key: str,
model: SwitchbotModel = SwitchbotModel.LOCK,
**kwargs: Any,
) -> bool:
try:
lock = SwitchbotLock(
device, key_id=key_id, encryption_key=encryption_key, model=model
)
except ValueError:
return False
try:
lock_info = await lock.get_basic_info()
except SwitchbotOperationError:
return False

return lock_info is not None

@staticmethod
async def api_request(
session: aiohttp.ClientSession,
subdomain: str,
path: str,
data: dict = None,
headers: dict = None,
) -> dict:
url = f"https://{subdomain}.{SWITCHBOT_APP_API_BASE_URL}/{path}"
async with session.post(
url,
json=data,
headers=headers,
timeout=aiohttp.ClientTimeout(total=10),
) as result:
if result.status > 299:
raise SwitchbotApiError(
f"Unexpected status code returned by SwitchBot API: {result.status}"
)

response = await result.json()
if response["statusCode"] != 100:
raise SwitchbotApiError(
f"{response['message']}, status code: {response['statusCode']}"
)

return response["body"]

# Old non-async method preserved for backwards compatibility
@staticmethod
def retrieve_encryption_key(device_mac: str, username: str, password: str):
async def async_fn():
async with aiohttp.ClientSession() as session:
return await SwitchbotLock.async_retrieve_encryption_key(
session, device_mac, username, password
)

return asyncio.run(async_fn())

@staticmethod
async def async_retrieve_encryption_key(
session: aiohttp.ClientSession, device_mac: str, username: str, password: str
) -> dict:
"""Retrieve lock key from internal SwitchBot API."""
device_mac = device_mac.replace(":", "").replace("-", "").upper()

try:
auth_result = await SwitchbotLock.api_request(
session,
"account",
"account/api/v1/user/login",
{
"clientId": SWITCHBOT_APP_CLIENT_ID,
"username": username,
"password": password,
"grantType": "password",
"verifyCode": "",
},
)
auth_headers = {"authorization": auth_result["access_token"]}
except Exception as err:
raise SwitchbotAuthenticationError(f"Authentication failed: {err}") from err

try:
userinfo = await SwitchbotLock.api_request(
session, "account", "account/api/v1/user/userinfo", {}, auth_headers
)
if "botRegion" in userinfo and userinfo["botRegion"] != "":
region = userinfo["botRegion"]
else:
region = "us"
except Exception as err:
raise SwitchbotAccountConnectionError(
f"Failed to retrieve SwitchBot Account user details: {err}"
) from err

try:
device_info = await SwitchbotLock.api_request(
session,
f"wonderlabs.{region}",
"wonder/keys/v1/communicate",
{
"device_mac": device_mac,
"keyType": "user",
},
auth_headers,
)

return {
"key_id": device_info["communicationKey"]["keyId"],
"encryption_key": device_info["communicationKey"]["key"],
}
except Exception as err:
raise SwitchbotAccountConnectionError(
f"Failed to retrieve encryption key from SwitchBot Account: {err}"
) from err
return super().verify_encryption_key(
device, key_id, encryption_key, model, **kwargs
)

async def lock(self) -> bool:
"""Send lock command."""
Expand Down
Loading

0 comments on commit 9f939ce

Please sign in to comment.