diff --git a/.gitignore b/.gitignore index 5a330a9..2ea6a99 100644 --- a/.gitignore +++ b/.gitignore @@ -125,3 +125,6 @@ dmypy.json .js .vscode .direnv + +# MacOS finder stuff +.DS_Store \ No newline at end of file diff --git a/README.md b/README.md index 03c3fd6..0a43852 100644 --- a/README.md +++ b/README.md @@ -100,15 +100,15 @@ from tesla_powerwall import User # Login as customer without email # The default value for the email is "" -powerwall.login("") +await powerwall.login("") #=> # Login as customer with email -powerwall.login("", "") +await powerwall.login("", "") #=> # Login with different user -powerwall.login_as(User.INSTALLER, "", "") +await powerwall.login_as(User.INSTALLER, "", "") #=> # Check if we are logged in @@ -118,7 +118,7 @@ powerwall.is_authenticated() #=> True # Logout -powerwall.logout() +await powerwall.logout() powerwall.is_authenticated() #=> False ``` @@ -133,12 +133,12 @@ from tesla_powerwall import API # Manually create API object api = API('https:///') # Perform get on 'system_status/soe' -api.get_system_status_soe() +await api.get_system_status_soe() #=> {'percentage': 97.59281925744594} # From existing powerwall api = powerwall.get_api() -api.get_system_status_soe() +await api.get_system_status_soe() ``` The `Powerwall` objet provides a wrapper around the API and exposes common methods. @@ -148,14 +148,14 @@ The `Powerwall` objet provides a wrapper around the API and exposes common metho Get charge in percent: ```python -powerwall.get_charge() +await powerwall.get_charge() #=> 97.59281925744594 (%) ``` Get charge in watt: ```python -powerwall.get_energy() +await powerwall.get_energy() #=> 14807 (Wh) ``` @@ -164,7 +164,7 @@ powerwall.get_energy() Get the capacity of your powerwall in watt: ```python -powerwall.get_capacity() +await powerwall.get_capacity() #=> 28078 (Wh) ``` @@ -173,7 +173,7 @@ powerwall.get_capacity() Get information about the battery packs that are installed: ```python -batteries = powerwall.get_batteries() +batteries = await powerwall.get_batteries() #=> [, ] batteries[0].part_number #=> "XXX-G" @@ -194,7 +194,7 @@ batteries[0].wobble_detected ### Powerwall Status ```python -status = powerwall.get_status() +status = await powerwall.get_status() #=> status.version #=> '1.49.0' @@ -209,7 +209,7 @@ status.device_type ### Sitemaster ```python -sm = powerwall.sitemaster +sm = await powerwall.get_sitemaster() #=> sm.status #=> StatusUp @@ -224,7 +224,7 @@ The sitemaster can be started and stopped using `run()` and `stop()` ### Siteinfo ```python -info = powerwall.get_site_info() +info = await powerwall.get_site_info() #=> info.site_name #=> 'Tesla Home' @@ -243,7 +243,7 @@ info.timezone ```python from tesla_powerwall import MeterType -meters = powerwall.get_meters() +meters = await powerwall.get_meters() #=> # access meter, but may return None when meter is not available @@ -266,7 +266,7 @@ Available meters are: `solar`, `site`, `load`, `battery`, `generator`, and `busw `Meter` provides different methods for checking current power supply/draw: ```python -meters = powerwall.get_meters() +meters = await powerwall.get_meters() meters.solar.get_power() #=> 0.4 (kW) meters.solar.instant_power @@ -305,7 +305,7 @@ meters.battery.get_energy_imported() You can receive more detailed information about the meters `site` and `solar`: ```python -meter_details = powerwall.get_meter_site() # or get_meter_solar() for the solar meter +meter_details = await powerwall.get_meter_site() # or get_meter_solar() for the solar meter #=> readings = meter_details.readings #=> @@ -327,7 +327,7 @@ As `MeterDetailsReadings` inherits from `MeterResponse` (which is used in `Meter ### Device Type ```python -powerwall.get_device_type() +await powerwall.get_device_type() #=> ``` @@ -336,39 +336,39 @@ powerwall.get_device_type() Get current grid status. ```python -powerwall.get_grid_status() +await powerwall.get_grid_status() #=> -powerwall.is_grid_services_active() +await powerwall.is_grid_services_active() #=> False ``` ### Operation mode ```python -powerwall.get_operation_mode() +await powerwall.get_operation_mode() #=> -powerwall.get_backup_reserve_percentage() +await powerwall.get_backup_reserve_percentage() #=> 5.000019999999999 (%) ``` ### Powerwalls Serial Numbers ```python -serials = powerwall.get_serial_numbers() +await serials = powerwall.get_serial_numbers() #=> ["...", "...", ...] ``` ### Gateway DIN ```python -din = powerwall.get_gateway_din() +await din = powerwall.get_gateway_din() #=> 4159645-02-A--TGXXX ``` ### VIN ```python -vin = powerwall.get_vin() +await vin = powerwall.get_vin() ``` ### Off-grid status (Set Island mode) @@ -378,13 +378,13 @@ Take your powerwall on- and off-grid similar to the "Take off-grid" button in th #### Set powerwall to off-grid (Islanded) ```python -powerwall.set_island_mode(IslandMode.OFFGRID) +await powerwall.set_island_mode(IslandMode.OFFGRID) ``` #### Set powerwall to off-grid (Connected) ```python -powerwall.set_island_mode(IslandMode.ONGRID) +await powerwall.set_island_mode(IslandMode.ONGRID) ``` # Development diff --git a/pyproject.toml b/pyproject.toml index 32158ed..1dfb5f9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,7 +14,8 @@ classifiers = [ ] keywords = ["api", "tesla", "powerwall", "tesla_powerwall"] dependencies = [ - "requests>=2.22.0" + "aiohttp>=3.7.4", + "urllib3>=1.26.18", ] [project.urls] diff --git a/tesla_powerwall/api.py b/tesla_powerwall/api.py index 4fc8827..a862691 100644 --- a/tesla_powerwall/api.py +++ b/tesla_powerwall/api.py @@ -1,9 +1,10 @@ +import aiohttp from http.client import responses from json.decoder import JSONDecodeError -from typing import Any, List, Optional +from types import TracebackType +from typing import Any, List, Optional, Type from urllib.parse import urljoin -import requests from urllib3 import disable_warnings from urllib3.exceptions import InsecureRequestWarning @@ -15,7 +16,7 @@ def __init__( self, endpoint: str, timeout: int = 10, - http_session: Optional[requests.Session] = None, + http_session: Optional[aiohttp.ClientSession] = None, verify_ssl: bool = False, disable_insecure_warning: bool = True, ) -> None: @@ -23,9 +24,20 @@ def __init__( disable_warnings(InsecureRequestWarning) self._endpoint = self._parse_endpoint(endpoint) - self._timeout = timeout - self._http_session = http_session if http_session else requests.Session() - self._http_session.verify = verify_ssl + self._timeout = aiohttp.ClientTimeout(total=timeout) + self._owns_http_session = False if http_session else True + self._ssl = None if verify_ssl else False + + if http_session: + self._owns_http_session = False + self._http_session = http_session + else: + self._owns_http_session = True + + # Allow unsafe cookies so that folks can use IP addresses in their configs + # See: https://docs.aiohttp.org/en/v3.7.3/client_advanced.html#cookie-safety + jar = aiohttp.CookieJar(unsafe=True) + self._http_session = aiohttp.ClientSession(cookie_jar=jar) @staticmethod def _parse_endpoint(endpoint: str) -> str: @@ -47,50 +59,52 @@ def _parse_endpoint(endpoint: str) -> str: return endpoint @staticmethod - def _handle_error(response: requests.Response) -> None: - if response.status_code == 404: + async def _handle_error(response: aiohttp.ClientResponse) -> None: + if response.status == 404: raise ApiError( - "The url {} returned error 404".format(response.request.path_url) + "The url {} returned error 404".format(str(response.real_url)) ) - if response.status_code == 401 or response.status_code == 403: + if response.status == 401 or response.status == 403: response_json = None try: - response_json = response.json() + response_json = await response.json() except Exception: - raise AccessDeniedError(response.request.path_url) + raise AccessDeniedError(str(response.real_url)) else: raise AccessDeniedError( - response.request.path_url, + str(response.real_url), response_json.get("error"), response_json.get("message"), ) - if response.text is not None and len(response.text) > 0: + response_text = await response.text() + if response_text is not None and len(response_text) > 0: raise ApiError( "API returned status code '{}: {}' with body: {}".format( - response.status_code, - responses.get(response.status_code), - response.text, + response.status, + responses.get(response.status), + response_text, ) ) else: raise ApiError( "API returned status code '{}: {}' ".format( - response.status_code, responses.get(response.status_code) + response.status, responses.get(response.status) ) ) - def _process_response(self, response: requests.Response) -> dict: - if response.status_code >= 400: + async def _process_response(self, response: aiohttp.ClientResponse) -> dict: + if response.status >= 400: # API returned some sort of error that must be handled - self._handle_error(response) + await self._handle_error(response) - if len(response.content) == 0: + content = await response.read() + if len(content) == 0: return {} try: - response_json = response.json() + response_json = await response.json(content_type=None) except JSONDecodeError: raise ApiError( "Error while decoding json of response: {}".format(response.text) @@ -109,46 +123,45 @@ def _process_response(self, response: requests.Response) -> dict: def url(self, path: str): return urljoin(self._endpoint, path) - def get(self, path: str, headers: dict = {}) -> Any: + async def get(self, path: str, headers: dict = {}) -> Any: try: - response = self._http_session.get( + response = await self._http_session.get( url=self.url(path), timeout=self._timeout, headers=headers, + ssl=self._ssl, ) - except ( - requests.exceptions.ConnectionError, - requests.exceptions.ReadTimeout, - ) as e: + except aiohttp.ClientConnectionError as e: raise PowerwallUnreachableError(str(e)) - return self._process_response(response) + return await self._process_response(response) - def post( + async def post( self, path: str, payload: dict, headers: dict = {}, ) -> Any: try: - response = self._http_session.post( + response = await self._http_session.post( url=self.url(path), json=payload, timeout=self._timeout, headers=headers, + ssl=self._ssl, ) - except ( - requests.exceptions.ConnectionError, - requests.exceptions.ReadTimeout, - ) as e: + except aiohttp.ClientConnectionError as e: raise PowerwallUnreachableError(str(e)) - return self._process_response(response) + return await self._process_response(response) def is_authenticated(self) -> bool: - return "AuthCookie" in self._http_session.cookies.keys() + for cookie in self._http_session.cookie_jar: + if "AuthCookie" == cookie.key: + return True + return False - def login( + async def login( self, username: str, email: str, @@ -156,7 +169,7 @@ def login( force_sm_off: bool = False, ) -> dict: # force_sm_off is referred to as 'shouldForceLogin' in the web source code - return self.post( + return await self.post( "login/Basic", { "username": username, @@ -166,97 +179,106 @@ def login( }, ) - def logout(self) -> None: + async def logout(self) -> None: if not self.is_authenticated(): raise ApiError("Must be logged in to log out") # The api unsets the auth cookie and the token is invalidated - self.get("logout") + await self.get("logout") + + async def close(self) -> None: + if self._owns_http_session: + await self._http_session.close() - def close(self) -> None: - # Close the HTTP Session - # THis method is required for testing, - # so python doesn't complain about unclosed resources - self._http_session.close() + async def __aenter__(self) -> "API": + return self + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + await self.close() # Endpoints are mapped to one method by _ so they can be easily accessed - def get_system_status(self) -> dict: - return self.get("system_status") + async def get_system_status(self) -> dict: + return await self.get("system_status") - def get_system_status_soe(self) -> dict: - return self.get("system_status/soe") + async def get_system_status_soe(self) -> dict: + return await self.get("system_status/soe") - def get_meters_aggregates(self) -> dict: - return self.get("meters/aggregates") + async def get_meters_aggregates(self) -> dict: + return await self.get("meters/aggregates") - def get_sitemaster_run(self): - return self.get("sitemaster/run") + async def get_sitemaster_run(self): + return await self.get("sitemaster/run") - def get_sitemaster_stop(self): - return self.get("sitemaster/stop") + async def get_sitemaster_stop(self): + return await self.get("sitemaster/stop") - def get_sitemaster(self) -> dict: - return self.get("sitemaster") + async def get_sitemaster(self) -> dict: + return await self.get("sitemaster") - def get_status(self) -> dict: - return self.get("status") + async def get_status(self) -> dict: + return await self.get("status") - def get_customer_registration(self) -> dict: - return self.get("customer/registration") + async def get_customer_registration(self) -> dict: + return await self.get("customer/registration") - def get_powerwalls(self): - return self.get("powerwalls") + async def get_powerwalls(self): + return await self.get("powerwalls") - def get_operation(self) -> dict: - return self.get("operation") + async def get_operation(self) -> dict: + return await self.get("operation") - def get_networks(self) -> list: - return self.get("networks") + async def get_networks(self) -> list: + return await self.get("networks") - def get_phase_usage(self): - return self.get("powerwalls/phase_usages") + async def get_phase_usage(self): + return await self.get("powerwalls/phase_usages") - def post_sitemaster_run_for_commissioning(self): - return self.post("sitemaster/run_for_commissioning", payload={}) + async def post_sitemaster_run_for_commissioning(self): + return await self.post("sitemaster/run_for_commissioning", payload={}) - def get_solars(self): - return self.get("solars") + async def get_solars(self): + return await self.get("solars") - def get_config(self): - return self.get("config") + async def get_config(self): + return await self.get("config") - def get_logs(self): - return self.get("getlogs") + async def get_logs(self): + return await self.get("getlogs") - def get_meters(self) -> list: - return self.get("meters") + async def get_meters(self) -> list: + return await self.get("meters") - def get_meters_site(self) -> list: - return self.get("meters/site") + async def get_meters_site(self) -> list: + return await self.get("meters/site") - def get_meters_solar(self) -> list: - return self.get("meters/solar") + async def get_meters_solar(self) -> list: + return await self.get("meters/solar") - def get_installer(self) -> dict: - return self.get("installer") + async def get_installer(self) -> dict: + return await self.get("installer") - def get_solar_brands(self) -> List[str]: - return self.get("solars/brands") + async def get_solar_brands(self) -> List[str]: + return await self.get("solars/brands") - def get_system_update_status(self) -> dict: - return self.get("system/update/status") + async def get_system_update_status(self) -> dict: + return await self.get("system/update/status") - def get_system_status_grid_status(self) -> dict: - return self.get("system_status/grid_status") + async def get_system_status_grid_status(self) -> dict: + return await self.get("system_status/grid_status") - def get_site_info(self) -> dict: - return self.get("site_info") + async def get_site_info(self) -> dict: + return await self.get("site_info") - def get_site_info_grid_codes(self) -> list: - return self.get("site_info/grid_codes") + async def get_site_info_grid_codes(self) -> list: + return await self.get("site_info/grid_codes") - def post_site_info_site_name(self, body: dict) -> dict: - return self.post("site_info/site_name", body) + async def post_site_info_site_name(self, body: dict) -> dict: + return await self.post("site_info/site_name", body) - def post_islanding_mode(self, body: dict) -> dict: - return self.post("v2/islanding/mode", body) + async def post_islanding_mode(self, body: dict) -> dict: + return await self.post("v2/islanding/mode", body) diff --git a/tesla_powerwall/powerwall.py b/tesla_powerwall/powerwall.py index 6173888..99a9a00 100644 --- a/tesla_powerwall/powerwall.py +++ b/tesla_powerwall/powerwall.py @@ -1,6 +1,7 @@ -from typing import List, Union +from types import TracebackType +from typing import List, Union, Optional, Type -import requests +import aiohttp from .api import API from .const import DeviceType, GridStatus, IslandMode, OperationMode, User @@ -23,19 +24,19 @@ def __init__( self, endpoint: str, timeout: int = 10, - http_session: Union[requests.Session, None] = None, + http_session: Union[aiohttp.ClientSession, None] = None, verify_ssl: bool = False, disable_insecure_warning: bool = True, ): self._api = API( - endpoint, - timeout, - http_session, - verify_ssl, - disable_insecure_warning, + endpoint=endpoint, + timeout=timeout, + http_session=http_session, + verify_ssl=verify_ssl, + disable_insecure_warning=disable_insecure_warning, ) - def login_as( + async def login_as( self, user: Union[User, str], password: str, @@ -45,143 +46,153 @@ def login_as( if isinstance(user, User): user = user.value - response = self._api.login(user, email, password, force_sm_off) + response = await self._api.login(user, email, password, force_sm_off) # The api returns an auth cookie which is automatically set # so there is no need to further process the response return LoginResponse.from_dict(response) - def login( + async def login( self, password: str, email: str = "", force_sm_off: bool = False ) -> LoginResponse: - return self.login_as(User.CUSTOMER, password, email, force_sm_off) + return await self.login_as(User.CUSTOMER, password, email, force_sm_off) - def logout(self) -> None: - self._api.logout() + async def logout(self) -> None: + await self._api.logout() def is_authenticated(self) -> bool: return self._api.is_authenticated() - def run(self) -> None: - self._api.get_sitemaster_run() + async def run(self) -> None: + await self._api.get_sitemaster_run() - def stop(self) -> None: - self._api.get_sitemaster_stop() + async def stop(self) -> None: + await self._api.get_sitemaster_stop() - def get_charge(self) -> Union[float, int]: - return assert_attribute(self._api.get_system_status_soe(), "percentage", "soe") + async def get_charge(self) -> Union[float, int]: + return assert_attribute( + await self._api.get_system_status_soe(), "percentage", "soe" + ) - def get_energy(self) -> int: + async def get_energy(self) -> int: return assert_attribute( - self._api.get_system_status(), + await self._api.get_system_status(), "nominal_energy_remaining", "system_status", ) - def get_sitemaster(self) -> SiteMasterResponse: - return SiteMasterResponse.from_dict(self._api.get_sitemaster()) + async def get_sitemaster(self) -> SiteMasterResponse: + return SiteMasterResponse.from_dict(await self._api.get_sitemaster()) - def get_meters(self) -> MetersAggregatesResponse: - return MetersAggregatesResponse.from_dict(self._api.get_meters_aggregates()) + async def get_meters(self) -> MetersAggregatesResponse: + return MetersAggregatesResponse.from_dict( + await self._api.get_meters_aggregates() + ) - def get_meter_site(self) -> MeterDetailsResponse: - meter_response = self._api.get_meters_site() + async def get_meter_site(self) -> MeterDetailsResponse: + meter_response = await self._api.get_meters_site() if meter_response is None or len(meter_response) == 0: raise ApiError("The powerwall returned no values for the site meter") return MeterDetailsResponse.from_dict(meter_response[0]) - def get_meter_solar(self) -> MeterDetailsResponse: - meter_response = self._api.get_meters_solar() + async def get_meter_solar(self) -> MeterDetailsResponse: + meter_response = await self._api.get_meters_solar() if meter_response is None or len(meter_response) == 0: raise ApiError("The powerwall returned no values for the solar meter") return MeterDetailsResponse.from_dict(meter_response[0]) - def get_grid_status(self) -> GridStatus: + async def get_grid_status(self) -> GridStatus: """Returns the current grid status.""" status = assert_attribute( - self._api.get_system_status_grid_status(), + await self._api.get_system_status_grid_status(), "grid_status", "grid_status", ) return GridStatus(status) - def get_capacity(self) -> float: + async def get_capacity(self) -> float: return assert_attribute( - self._api.get_system_status(), + await self._api.get_system_status(), "nominal_full_pack_energy", "system_status", ) - def get_batteries(self) -> List[BatteryResponse]: + async def get_batteries(self) -> List[BatteryResponse]: batteries = assert_attribute( - self._api.get_system_status(), "battery_blocks", "system_status" + await self._api.get_system_status(), "battery_blocks", "system_status" ) return [BatteryResponse.from_dict(battery) for battery in batteries] - def is_grid_services_active(self) -> bool: + async def is_grid_services_active(self) -> bool: return assert_attribute( - self._api.get_system_status_grid_status(), + await self._api.get_system_status_grid_status(), "grid_services_active", "grid_status", ) - def get_site_info(self) -> SiteInfoResponse: + async def get_site_info(self) -> SiteInfoResponse: """Returns information about the powerwall site""" - return SiteInfoResponse.from_dict(self._api.get_site_info()) + return SiteInfoResponse.from_dict(await self._api.get_site_info()) - def set_site_name(self, site_name: str) -> dict: - return self._api.post_site_info_site_name({"site_name": site_name}) + async def set_site_name(self, site_name: str) -> dict: + return await self._api.post_site_info_site_name({"site_name": site_name}) - def get_status(self) -> PowerwallStatusResponse: - return PowerwallStatusResponse.from_dict(self._api.get_status()) + async def get_status(self) -> PowerwallStatusResponse: + return PowerwallStatusResponse.from_dict(await self._api.get_status()) - def get_device_type(self) -> DeviceType: + async def get_device_type(self) -> DeviceType: """Returns the device type of the powerwall""" - return self.get_status().device_type + return (await self.get_status()).device_type - def get_serial_numbers(self) -> List[str]: + async def get_serial_numbers(self) -> List[str]: powerwalls = assert_attribute( - self._api.get_powerwalls(), "powerwalls", "powerwalls" + await self._api.get_powerwalls(), "powerwalls", "powerwalls" ) return [ assert_attribute(powerwall, "PackageSerialNumber") for powerwall in powerwalls ] - def get_gateway_din(self) -> str: + async def get_gateway_din(self) -> str: """Return the gateway din.""" - return assert_attribute(self._api.get_powerwalls(), "gateway_din", "powerwalls") + return assert_attribute( + await self._api.get_powerwalls(), "gateway_din", "powerwalls" + ) - def get_operation_mode(self) -> OperationMode: + async def get_operation_mode(self) -> OperationMode: operation_mode = assert_attribute( - self._api.get_operation(), "real_mode", "operation" + await self._api.get_operation(), "real_mode", "operation" ) return OperationMode(operation_mode) - def get_backup_reserve_percentage(self) -> float: + async def get_backup_reserve_percentage(self) -> float: return assert_attribute( - self._api.get_operation(), "backup_reserve_percent", "operation" + await self._api.get_operation(), "backup_reserve_percent", "operation" ) - def get_solars(self) -> List[SolarResponse]: - return [SolarResponse.from_dict(solar) for solar in self._api.get_solars()] + async def get_solars(self) -> List[SolarResponse]: + return [ + SolarResponse.from_dict(solar) for solar in await self._api.get_solars() + ] - def get_vin(self) -> str: - return assert_attribute(self._api.get_config(), "vin", "config") + async def get_vin(self) -> str: + return assert_attribute(await self._api.get_config(), "vin", "config") - def set_island_mode(self, mode: IslandMode) -> IslandMode: + async def set_island_mode(self, mode: IslandMode) -> IslandMode: return IslandMode( assert_attribute( - self._api.post_islanding_mode({"island_mode": mode.value}), + await self._api.post_islanding_mode({"island_mode": mode.value}), "island_mode", ) ) - def get_version(self) -> str: - version_str = assert_attribute(self._api.get_status(), "version", "status") + async def get_version(self) -> str: + version_str = assert_attribute( + await self._api.get_status(), "version", "status" + ) return version_str.split(" ")[ 0 ] # newer versions include a sha trailer '21.44.1 c58c2df3' @@ -189,5 +200,16 @@ def get_version(self) -> str: def get_api(self) -> API: return self._api - def close(self) -> None: - self._api.close() + async def close(self) -> None: + await self._api.close() + + async def __aenter__(self) -> "Powerwall": + return self + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + await self.close() diff --git a/tests/integration/test_powerwall.py b/tests/integration/test_powerwall.py index 2fe46e3..e1dc035 100644 --- a/tests/integration/test_powerwall.py +++ b/tests/integration/test_powerwall.py @@ -1,5 +1,6 @@ +import aiohttp +import asyncio import unittest -from time import sleep from tesla_powerwall import GridStatus, IslandMode, MeterType, Powerwall from tesla_powerwall.responses import ( @@ -12,23 +13,25 @@ from tests.integration import POWERWALL_IP, POWERWALL_PASSWORD -class TestPowerwall(unittest.TestCase): - def setUp(self) -> None: +class TestPowerwall(unittest.IsolatedAsyncioTestCase): + async def asyncSetUp(self): self.powerwall = Powerwall(POWERWALL_IP) - self.powerwall.login(POWERWALL_PASSWORD) + await self.powerwall.login(POWERWALL_PASSWORD) + assert self.powerwall.is_authenticated() - def tearDown(self) -> None: - self.powerwall.close() + async def asyncTearDown(self): + await self.powerwall.close() + await self.http_session.close() - def test_get_charge(self) -> None: - charge = self.powerwall.get_charge() + async def test_get_charge(self) -> None: + charge = await self.powerwall.get_charge() if charge < 100: self.assertIsInstance(charge, float) else: self.assertEqual(charge, 100) - def test_get_meters(self) -> None: - meters = self.powerwall.get_meters() + async def test_get_meters(self) -> None: + meters = await self.powerwall.get_meters() self.assertIsInstance(meters, MetersAggregatesResponse) self.assertIsInstance(meters.get_meter(MeterType.BATTERY), MeterResponse) @@ -49,8 +52,8 @@ def test_get_meters(self) -> None: self.assertIsInstance(meter.is_drawing_from(), bool) self.assertIsInstance(meter.is_sending_to(), bool) - def test_sitemaster(self) -> None: - sitemaster = self.powerwall.get_sitemaster() + async def test_sitemaster(self) -> None: + sitemaster = await self.powerwall.get_sitemaster() self.assertIsInstance(sitemaster, SiteMasterResponse) @@ -59,8 +62,8 @@ def test_sitemaster(self) -> None: sitemaster.is_connected_to_tesla sitemaster.is_power_supply_mode - def test_site_info(self) -> None: - site_info = self.powerwall.get_site_info() + async def test_site_info(self) -> None: + site_info = await self.powerwall.get_site_info() self.assertIsInstance(site_info, SiteInfoResponse) @@ -68,14 +71,14 @@ def test_site_info(self) -> None: site_info.site_name site_info.timezone - def test_capacity(self) -> None: - self.assertIsInstance(self.powerwall.get_capacity(), int) + async def test_capacity(self) -> None: + self.assertIsInstance(await self.powerwall.get_capacity(), int) - def test_energy(self) -> None: - self.assertIsInstance(self.powerwall.get_energy(), int) + async def test_energy(self) -> None: + self.assertIsInstance(await self.powerwall.get_energy(), int) - def test_batteries(self) -> None: - batteries = self.powerwall.get_batteries() + async def test_batteries(self) -> None: + batteries = await self.powerwall.get_batteries() self.assertGreater(len(batteries), 0) for battery in batteries: battery.wobble_detected @@ -86,41 +89,41 @@ def test_batteries(self) -> None: battery.part_number battery.serial_number - def test_grid_status(self) -> None: - grid_status = self.powerwall.get_grid_status() + async def test_grid_status(self) -> None: + grid_status = await self.powerwall.get_grid_status() self.assertIsInstance(grid_status, GridStatus) - def test_status(self) -> None: - status = self.powerwall.get_status() + async def test_status(self) -> None: + status = await self.powerwall.get_status() self.assertIsInstance(status, PowerwallStatusResponse) status.up_time_seconds status.start_time status.version - def test_islanding(self) -> None: - initial_grid_status = self.powerwall.get_grid_status() + async def test_islanding(self) -> None: + initial_grid_status = await self.powerwall.get_grid_status() self.assertIsInstance(initial_grid_status, GridStatus) if initial_grid_status == GridStatus.CONNECTED: - self.go_offline() - self.go_online() + await self.go_offline() + await self.go_online() elif initial_grid_status == GridStatus.ISLANDED: - self.go_offline() - self.go_online() + await self.go_offline() + await self.go_online() - def go_offline(self) -> None: - observedIslandMode = self.powerwall.set_island_mode(IslandMode.OFFGRID) + async def go_offline(self) -> None: + observedIslandMode = await self.powerwall.set_island_mode(IslandMode.OFFGRID) self.assertEqual(observedIslandMode, IslandMode.OFFGRID) - self.wait_until_grid_status(GridStatus.ISLANDED) - self.assertEqual(self.powerwall.get_grid_status(), GridStatus.ISLANDED) + await self.wait_until_grid_status(GridStatus.ISLANDED) + self.assertEqual(await self.powerwall.get_grid_status(), GridStatus.ISLANDED) - def go_online(self) -> None: - observedIslandMode = self.powerwall.set_island_mode(IslandMode.ONGRID) + async def go_online(self) -> None: + observedIslandMode = await self.powerwall.set_island_mode(IslandMode.ONGRID) self.assertEqual(observedIslandMode, IslandMode.ONGRID) - self.wait_until_grid_status(GridStatus.CONNECTED) - self.assertEqual(self.powerwall.get_grid_status(), GridStatus.CONNECTED) + await self.wait_until_grid_status(GridStatus.CONNECTED) + self.assertEqual(await self.powerwall.get_grid_status(), GridStatus.CONNECTED) - def wait_until_grid_status( + async def wait_until_grid_status( self, expectedStatus: GridStatus, sleepTime: int = 1, @@ -130,10 +133,10 @@ def wait_until_grid_status( observedStatus: GridStatus while cycles < maxCycles: - observedStatus = self.powerwall.get_grid_status() + observedStatus = await self.powerwall.get_grid_status() if observedStatus == expectedStatus: break - sleep(sleepTime) + await asyncio.sleep(sleepTime) cycles = cycles + 1 self.assertEqual(observedStatus, expectedStatus) diff --git a/tests/unit/__init__.py b/tests/unit/__init__.py index f121edb..668861a 100644 --- a/tests/unit/__init__.py +++ b/tests/unit/__init__.py @@ -1,7 +1,10 @@ import json from pathlib import Path -ENDPOINT = "https://1.1.1.1/api/" +ENDPOINT_SCHEME = "https://" +ENDPOINT_HOST = "1.1.1.1" +ENDPOINT_PATH = "/api/" +ENDPOINT = f"{ENDPOINT_SCHEME}{ENDPOINT_HOST}{ENDPOINT_PATH}" FIXTURE_BASE_PATH = Path("tests/unit/fixtures") diff --git a/tests/unit/test_api.py b/tests/unit/test_api.py index 3558cb9..4da39b7 100644 --- a/tests/unit/test_api.py +++ b/tests/unit/test_api.py @@ -1,17 +1,30 @@ -import json import unittest -import requests -import responses -from responses import GET, Response, add +import aiohttp +import aresponses +import json + +from tesla_powerwall import API, AccessDeniedError, ApiError, PowerwallUnreachableError +from tesla_powerwall.const import User +from tests.unit import ( + ENDPOINT_HOST, + ENDPOINT_PATH, + ENDPOINT, +) + -from tesla_powerwall import API, AccessDeniedError, ApiError -from tests.unit import ENDPOINT +class TestAPI(unittest.IsolatedAsyncioTestCase): + async def asyncSetUp(self): + self.aresponses = aresponses.ResponsesMockServer() + await self.aresponses.__aenter__() + self.session = aiohttp.ClientSession() + self.api = API(ENDPOINT, http_session=self.session) -class TestAPI(unittest.TestCase): - def setUp(self): - self.api = API(ENDPOINT) + async def asyncTearDown(self): + await self.api.close() + await self.session.close() + await self.aresponses.__aexit__(None, None, None) def test_parse_endpoint(self): test_endpoints = [ @@ -24,76 +37,176 @@ def test_parse_endpoint(self): for test_endpoint in test_endpoints: self.assertEqual(self.api._parse_endpoint(test_endpoint), ENDPOINT) - @responses.activate - def test_process_response(self): - res = requests.Response() - res.request = requests.Request(method="GET", url=f"{ENDPOINT}test").prepare() - res.status_code = 401 - with self.assertRaises(AccessDeniedError): - self.api._process_response(res) - - res.status_code = 404 - with self.assertRaises(ApiError): - self.api._process_response(res) - - res.status_code = 502 - with self.assertRaises(ApiError): - self.api._process_response(res) - - res.status_code = 200 - res._content = b'{"error": "test_error"}' - with self.assertRaises(ApiError): - self.api._process_response(res) - - res._content = b'{invalid_json"' - with self.assertRaises(ApiError): - self.api._process_response(res) - - res._content = b"{}" - self.assertEqual(self.api._process_response(res), {}) - - res._content = b'{"response": "ok"}' - self.assertEqual(self.api._process_response(res), {"response": "ok"}) - - @responses.activate - def test_get(self): - add(Response(GET, url=f"{ENDPOINT}test_get", json={"test_get": True})) - - self.assertEqual(self.api.get("test_get"), {"test_get": True}) - - @responses.activate - def test_post(self): - def post_callback(request): - resp_body = {"test_post": True} - headers = {} - return (200, headers, json.dumps(resp_body)) - - responses.add_callback( - responses.POST, - url=f"{ENDPOINT}test_post", - callback=post_callback, - content_type="application/json", + async def test_process_response(self): + status = 0 + text = None + + def response_handler(request): + return self.aresponses.Response(status=status, text=text) + + self.aresponses.add( + ENDPOINT_HOST, + f"{ENDPOINT_PATH}test", + "GET", + response_handler, + repeat=self.aresponses.INFINITY, + ) + + status = 401 + async with self.session.get(f"{ENDPOINT}test") as response: + with self.assertRaises(AccessDeniedError): + await self.api._process_response(response) + + status = 404 + async with self.session.get(f"{ENDPOINT}test") as response: + with self.assertRaises(ApiError): + await self.api._process_response(response) + + status = 502 + async with self.session.get(f"{ENDPOINT}test") as response: + with self.assertRaises(ApiError): + await self.api._process_response(response) + + status = 200 + text = '{"error": "test_error"}' + async with self.session.get(f"{ENDPOINT}test") as response: + with self.assertRaises(ApiError): + await self.api._process_response(response) + + status = 200 + text = '{invalid_json"' + async with self.session.get(f"{ENDPOINT}test") as response: + with self.assertRaises(ApiError): + await self.api._process_response(response) + + status = 200 + text = "{}" + async with self.session.get(f"{ENDPOINT}test") as response: + self.assertEqual(await self.api._process_response(response), {}) + + status = 200 + text = '{"response": "ok"}' + async with self.session.get(f"{ENDPOINT}test") as response: + self.assertEqual( + await self.api._process_response(response), {"response": "ok"} + ) + + async def test_get(self): + self.aresponses.add( + ENDPOINT_HOST, + f"{ENDPOINT_PATH}test_get", + "GET", + self.aresponses.Response(text='{"test_get": true}'), ) - resp = self.api.post("test_post", {"test": True}) + self.assertEqual(await self.api.get("test_get"), {"test_get": True}) + self.aresponses.assert_plan_strictly_followed() + + async def test_post(self): + self.aresponses.add( + ENDPOINT_HOST, + f"{ENDPOINT_PATH}test_post", + "POST", + self.aresponses.Response( + text='{"test_post": true}', headers={"Content-Type": "application/json"} + ), + ) + + resp = await self.api.post("test_post", {"test": True}) self.assertIsInstance(resp, dict) self.assertEqual(resp, {"test_post": True}) - def test_is_authenticated(self): - api = API(ENDPOINT) - self.assertEqual(api.is_authenticated(), False) + self.aresponses.assert_plan_strictly_followed() + + async def test_is_authenticated(self): + self.assertEqual(self.api.is_authenticated(), False) + + self.session.cookie_jar.update_cookies(cookies={"AuthCookie": "foo"}) + self.assertEqual(self.api.is_authenticated(), True) def test_url(self): self.assertEqual(self.api.url("test"), ENDPOINT + "test") - @responses.activate - def test_logout(self): - add( - Response(GET, url=f"{ENDPOINT}logout"), - body="", - content_type="application/json", - ) - self.api._http_session.cookies.set("AuthCookie", "foo") + async def test_login(self): + jar = aiohttp.CookieJar(unsafe=True) + async with aiohttp.ClientSession(cookie_jar=jar) as http_session: + async with API(ENDPOINT, http_session=http_session) as api: + username = User.CUSTOMER.value + password = "password" + email = "email@email.com" + + async def response_handler(request) -> aresponses.Response: + request_json = await request.json() + + self.assertEqual(request_json["username"], username) + self.assertEqual(request_json["password"], password) + self.assertEqual(request_json["email"], email) + + login_response = self.aresponses.Response( + text=json.dumps( + { + "email": request_json["email"], + "firstname": "Tesla", + "lastname": "Energy", + "roles": ["Home_Owner"], + "token": "x4jbH...XMP8w==", + "provider": "Basic", + "loginTime": "2023-03-25T13:10:48.9029581+01:00", + } + ), + headers={"Content-Type": "application/json"}, + ) + login_response.set_cookie("AuthCookie", "foo") + return login_response + + self.aresponses.add( + ENDPOINT_HOST, + f"{ENDPOINT_PATH}login/Basic", + "POST", + response_handler, + ) + + await api.login(username=username, email=email, password=password) + + self.aresponses.add( + ENDPOINT_HOST, + f"{ENDPOINT_PATH}logout", + "GET", + self.aresponses.Response( + text="", headers={"Content-Type": "application/json"} + ), + ) + + await api.logout() + + self.aresponses.assert_plan_strictly_followed() + + async def test_close(self): + api_session = None + async with API(ENDPOINT) as api: + api_session = api._http_session + self.assertFalse(api_session.closed) + self.assertTrue(api_session.closed) + + async with aiohttp.ClientSession() as session: + async with API(ENDPOINT, http_session=session) as api: + api_session = api._http_session + self.assertFalse(api_session.closed) + + self.assertFalse(api_session.closed) + self.assertTrue(api_session.closed) - self.api.logout() + api = API(ENDPOINT) + api_session = api._http_session + self.assertFalse(api_session.closed) + await api.close() + self.assertTrue(api_session.closed) + + async with aiohttp.ClientSession() as session: + api_session = session + api = API(ENDPOINT, http_session=session) + self.assertFalse(api_session.closed) + await api.close() + self.assertFalse(api_session.closed) + self.assertTrue(api_session.closed) diff --git a/tests/unit/test_powerwall.py b/tests/unit/test_powerwall.py index 381fe2d..a8ce631 100644 --- a/tests/unit/test_powerwall.py +++ b/tests/unit/test_powerwall.py @@ -1,9 +1,10 @@ +import aiohttp +import aresponses import datetime +import json +from typing import Optional, Union import unittest -import responses -from responses import GET, Response, add - from tesla_powerwall import ( API, DeviceType, @@ -23,6 +24,8 @@ ) from tesla_powerwall.const import OperationMode from tests.unit import ( + ENDPOINT_HOST, + ENDPOINT_PATH, ENDPOINT, GRID_STATUS_RESPONSE, ISLANDING_MODE_OFFGRID_RESPONSE, @@ -39,52 +42,57 @@ ) -class TestPowerWall(unittest.TestCase): - def setUp(self): +class TestPowerWall(unittest.IsolatedAsyncioTestCase): + async def asyncSetUp(self): + self.aresponses = aresponses.ResponsesMockServer() + await self.aresponses.__aenter__() + self.powerwall = Powerwall(ENDPOINT) + async def asyncTearDown(self): + await self.powerwall.close() + await self.aresponses.__aexit__(None, None, None) + def test_get_api(self): self.assertIsInstance(self.powerwall.get_api(), API) - @responses.activate - def test_get_charge(self): - add( - Response( - GET, - url=f"{ENDPOINT}system_status/soe", - json={"percentage": 53.123423}, - ) - ) - self.assertEqual(self.powerwall.get_charge(), 53.123423) - - @responses.activate - def test_get_sitemaster(self): - add( - Response( - responses.GET, - url=f"{ENDPOINT}sitemaster", - json=SITEMASTER_RESPONSE, - ) + def add_response( + self, + path: str, + method: str = "GET", + content_type: str = "application/json", + body: Optional[Union[str, dict]] = None, + ): + self.aresponses.add( + ENDPOINT_HOST, + f"{ENDPOINT_PATH}{path}", + method, + self.aresponses.Response( + headers={"Content-Type": content_type}, + text=json.dumps(body), + ), ) - sitemaster = self.powerwall.get_sitemaster() + async def test_get_charge(self): + self.add_response("system_status/soe", body={"percentage": 53.123423}) + self.assertEqual(await self.powerwall.get_charge(), 53.123423) + self.aresponses.assert_plan_strictly_followed() + + async def test_get_sitemaster(self): + self.add_response("sitemaster", body=SITEMASTER_RESPONSE) + + sitemaster = await self.powerwall.get_sitemaster() self.assertIsInstance(sitemaster, SiteMasterResponse) self.assertEqual(sitemaster.status, "StatusUp") self.assertEqual(sitemaster.is_running, True) self.assertEqual(sitemaster.is_connected_to_tesla, True) self.assertEqual(sitemaster.is_power_supply_mode, False) + self.aresponses.assert_plan_strictly_followed() - @responses.activate - def test_get_meters(self): - add( - Response( - responses.GET, - url=f"{ENDPOINT}meters/aggregates", - json=METERS_AGGREGATES_RESPONSE, - ) - ) - meters = self.powerwall.get_meters() + async def test_get_meters(self): + self.add_response("meters/aggregates", body=METERS_AGGREGATES_RESPONSE) + meters = await self.powerwall.get_meters() self.assertIsInstance(meters, MetersAggregatesResponse) self.assertListEqual( list(meters.meters.keys()), @@ -100,17 +108,11 @@ def test_get_meters(self): self.assertIsNone(meters.get_meter(MeterType.GENERATOR)) with self.assertRaises(MeterNotAvailableError): meters.generator + self.aresponses.assert_plan_strictly_followed() - @responses.activate - def test_get_meter_site(self): - add( - Response( - responses.GET, - url=f"{ENDPOINT}meters/site", - json=METER_SITE_RESPONSE, - ) - ) - meter = self.powerwall.get_meter_site() + async def test_get_meter_site(self): + self.add_response("meters/site", body=METER_SITE_RESPONSE) + meter = await self.powerwall.get_meter_site() self.assertIsInstance(meter, MeterDetailsResponse) self.assertEqual(meter.location, MeterType.SITE) readings = meter.readings @@ -122,17 +124,11 @@ def test_get_meter_site(self): self.assertEqual(readings.instant_power, -18.00000076368451) self.assertEqual(readings.get_power(), -0.0) + self.aresponses.assert_plan_strictly_followed() - @responses.activate - def test_get_meter_solar(self): - add( - Response( - responses.GET, - url=f"{ENDPOINT}meters/solar", - json=METER_SOLAR_RESPONSE, - ) - ) - meter = self.powerwall.get_meter_solar() + async def test_get_meter_solar(self): + self.add_response("meters/solar", body=METER_SOLAR_RESPONSE) + meter = await self.powerwall.get_meter_solar() self.assertIsInstance(meter, MeterDetailsResponse) self.assertEqual(meter.location, MeterType.SOLAR) readings = meter.readings @@ -141,17 +137,11 @@ def test_get_meter_solar(self): self.assertIsInstance(readings.v_l1n, float) self.assertIsNone(readings.v_l2n) self.assertIsNone(readings.v_l3n) + self.aresponses.assert_plan_strictly_followed() - @responses.activate - def test_is_sending(self): - add( - Response( - responses.GET, - url=f"{ENDPOINT}meters/aggregates", - json=METERS_AGGREGATES_RESPONSE, - ) - ) - meters = self.powerwall.get_meters() + async def test_is_sending(self): + self.add_response("meters/aggregates", body=METERS_AGGREGATES_RESPONSE) + meters = await self.powerwall.get_meters() self.assertEqual(meters.get_meter(MeterType.SOLAR).is_sending_to(), False) self.assertEqual(meters.get_meter(MeterType.SOLAR).is_active(), True) self.assertEqual(meters.get_meter(MeterType.SOLAR).is_drawing_from(), True) @@ -159,48 +149,30 @@ def test_is_sending(self): self.assertEqual(meters.get_meter(MeterType.LOAD).is_sending_to(), True) self.assertEqual(meters.get_meter(MeterType.LOAD).is_drawing_from(), False) self.assertEqual(meters.get_meter(MeterType.LOAD).is_active(), True) + self.aresponses.assert_plan_strictly_followed() - @responses.activate - def test_get_grid_status(self): - add( - Response( - responses.GET, - url=f"{ENDPOINT}system_status/grid_status", - json=GRID_STATUS_RESPONSE, - ) - ) - grid_status = self.powerwall.get_grid_status() + async def test_get_grid_status(self): + self.add_response("system_status/grid_status", body=GRID_STATUS_RESPONSE) + grid_status = await self.powerwall.get_grid_status() self.assertEqual(grid_status, GridStatus.CONNECTED) + self.aresponses.assert_plan_strictly_followed() - @responses.activate - def test_is_grid_services_active(self): - add( - Response( - responses.GET, - url=f"{ENDPOINT}system_status/grid_status", - json=GRID_STATUS_RESPONSE, - ) - ) - self.assertEqual(self.powerwall.is_grid_services_active(), False) - - @responses.activate - def test_get_site_info(self): - add( - Response( - responses.GET, - url=f"{ENDPOINT}site_info", - json=SITE_INFO_RESPONSE, - ) - ) - site_info = self.powerwall.get_site_info() + async def test_is_grid_services_active(self): + self.add_response("system_status/grid_status", body=GRID_STATUS_RESPONSE) + self.assertEqual(await self.powerwall.is_grid_services_active(), False) + self.aresponses.assert_plan_strictly_followed() + + async def test_get_site_info(self): + self.add_response("site_info", body=SITE_INFO_RESPONSE) + site_info = await self.powerwall.get_site_info() self.assertEqual(site_info.nominal_system_energy, 27) self.assertEqual(site_info.site_name, "test") self.assertEqual(site_info.timezone, "Europe/Berlin") + self.aresponses.assert_plan_strictly_followed() - @responses.activate - def test_get_status(self): - add(Response(responses.GET, url=f"{ENDPOINT}status", json=STATUS_RESPONSE)) - status = self.powerwall.get_status() + async def test_get_status(self): + self.add_response("status", body=STATUS_RESPONSE) + status = await self.powerwall.get_status() self.assertEqual( status.up_time_seconds, datetime.timedelta(seconds=61891, microseconds=214751), @@ -219,82 +191,55 @@ def test_get_status(self): ) self.assertEqual(status.device_type, DeviceType.GW1) self.assertEqual(status.version, "1.50.1 c58c2df3") + self.aresponses.assert_plan_strictly_followed() - @responses.activate - def test_get_device_type(self): - add(Response(responses.GET, url=f"{ENDPOINT}status", json=STATUS_RESPONSE)) - device_type = self.powerwall.get_device_type() + async def test_get_device_type(self): + self.add_response("status", body=STATUS_RESPONSE) + device_type = await self.powerwall.get_device_type() self.assertIsInstance(device_type, DeviceType) self.assertEqual(device_type, DeviceType.GW1) + self.aresponses.assert_plan_strictly_followed() - @responses.activate - def test_get_serial_numbers(self): - add( - Response( - responses.GET, - url=f"{ENDPOINT}powerwalls", - json=POWERWALLS_RESPONSE, - ) - ) - serial_numbers = self.powerwall.get_serial_numbers() + async def test_get_serial_numbers(self): + self.add_response("powerwalls", body=POWERWALLS_RESPONSE) + serial_numbers = await self.powerwall.get_serial_numbers() self.assertEqual(serial_numbers, ["SerialNumber1", "SerialNumber2"]) + self.aresponses.assert_plan_strictly_followed() - @responses.activate - def test_get_gateway_din(self): - add( - Response( - responses.GET, - url=f"{ENDPOINT}powerwalls", - json=POWERWALLS_RESPONSE, - ) - ) - gateway_din = self.powerwall.get_gateway_din() + async def test_get_gateway_din(self): + self.add_response("powerwalls", body=POWERWALLS_RESPONSE) + gateway_din = await self.powerwall.get_gateway_din() self.assertEqual(gateway_din, "gateway_din") + self.aresponses.assert_plan_strictly_followed() - @responses.activate - def test_get_backup_reserved_percentage(self): - add( - Response( - responses.GET, - url=f"{ENDPOINT}operation", - json=OPERATION_RESPONSE, - ) - ) + async def test_get_backup_reserved_percentage(self): + self.add_response("operation", body=OPERATION_RESPONSE) self.assertEqual( - self.powerwall.get_backup_reserve_percentage(), 5.000019999999999 + await self.powerwall.get_backup_reserve_percentage(), 5.000019999999999 ) + self.aresponses.assert_plan_strictly_followed() - @responses.activate - def test_get_operation_mode(self): - add( - Response( - responses.GET, - url=f"{ENDPOINT}operation", - json=OPERATION_RESPONSE, - ) - ) + async def test_get_operation_mode(self): + self.add_response("operation", body=OPERATION_RESPONSE) self.assertEqual( - self.powerwall.get_operation_mode(), OperationMode.SELF_CONSUMPTION + await self.powerwall.get_operation_mode(), OperationMode.SELF_CONSUMPTION ) + self.aresponses.assert_plan_strictly_followed() - @responses.activate - def test_get_version(self): - add(Response(responses.GET, url=f"{ENDPOINT}status", json=STATUS_RESPONSE)) - self.assertEqual(self.powerwall.get_version(), "1.50.1") - - @responses.activate - def test_system_status(self): - add( - Response( - responses.GET, - url=f"{ENDPOINT}system_status", - json=SYSTEM_STATUS_RESPONSE, - ) - ) - self.assertEqual(self.powerwall.get_capacity(), 28078) - self.assertEqual(self.powerwall.get_energy(), 14807) + async def test_get_version(self): + self.add_response("status", body=STATUS_RESPONSE) + self.assertEqual(await self.powerwall.get_version(), "1.50.1") + self.aresponses.assert_plan_strictly_followed() + + async def test_system_status(self): + self.add_response("system_status", body=SYSTEM_STATUS_RESPONSE) + self.assertEqual(await self.powerwall.get_capacity(), 28078) - batteries = self.powerwall.get_batteries() + self.add_response("system_status", body=SYSTEM_STATUS_RESPONSE) + self.assertEqual(await self.powerwall.get_energy(), 14807) + + self.add_response("system_status", body=SYSTEM_STATUS_RESPONSE) + batteries = await self.powerwall.get_batteries() self.assertEqual(len(batteries), 2) self.assertEqual(batteries[0].part_number, "XXX-G") self.assertEqual(batteries[0].serial_number, "TGXXX") @@ -303,32 +248,25 @@ def test_system_status(self): self.assertEqual(batteries[0].energy_charged, 5525740) self.assertEqual(batteries[0].energy_discharged, 4659550) self.assertEqual(batteries[0].wobble_detected, False) + self.aresponses.assert_plan_strictly_followed() - @responses.activate - def test_islanding_mode_offgrid(self): - add( - Response( - responses.POST, - url=f"{ENDPOINT}v2/islanding/mode", - json=ISLANDING_MODE_OFFGRID_RESPONSE, - ) + async def test_islanding_mode_offgrid(self): + self.add_response( + "v2/islanding/mode", method="POST", body=ISLANDING_MODE_OFFGRID_RESPONSE ) - mode = self.powerwall.set_island_mode(IslandMode.OFFGRID) + mode = await self.powerwall.set_island_mode(IslandMode.OFFGRID) self.assertEqual(mode, IslandMode.OFFGRID) + self.aresponses.assert_plan_strictly_followed() - @responses.activate - def test_islanding_mode_ongrid(self): - add( - Response( - responses.POST, - url=f"{ENDPOINT}v2/islanding/mode", - json=ISLANDING_MODE_ONGRID_RESPONSE, - ) + async def test_islanding_mode_ongrid(self): + self.add_response( + "v2/islanding/mode", method="POST", body=ISLANDING_MODE_ONGRID_RESPONSE ) - mode = self.powerwall.set_island_mode(IslandMode.ONGRID) + mode = await self.powerwall.set_island_mode(IslandMode.ONGRID) self.assertEqual(mode, IslandMode.ONGRID) + self.aresponses.assert_plan_strictly_followed() def test_helpers(self): resp = {"a": 1} @@ -339,3 +277,31 @@ def test_helpers(self): assert_attribute(resp, "test", "test") self.assertEqual(convert_to_kw(2500, -1), 2.5) + + async def test_close(self): + api_session = None + async with Powerwall(ENDPOINT) as powerwall: + api_session = powerwall._api._http_session + self.assertFalse(api_session.closed) + self.assertTrue(api_session.closed) + + async with aiohttp.ClientSession() as session: + api_session = session + async with Powerwall(ENDPOINT, http_session=session) as powerwall: + self.assertFalse(api_session.closed) + self.assertFalse(api_session.closed) + self.assertTrue(api_session.closed) + + powerwall = Powerwall(ENDPOINT) + api_session = powerwall._api._http_session + self.assertFalse(api_session.closed) + await powerwall.close() + self.assertTrue(api_session.closed) + + async with aiohttp.ClientSession() as session: + api_session = session + powerwall = Powerwall(ENDPOINT, http_session=session) + self.assertFalse(api_session.closed) + await powerwall.close() + self.assertFalse(api_session.closed) + self.assertTrue(api_session.closed) diff --git a/tox.ini b/tox.ini index f1ddd94..fd89342 100644 --- a/tox.ini +++ b/tox.ini @@ -2,12 +2,12 @@ envlist = testenv [testenv] -deps = responses +deps = aresponses commands = python -m unittest discover {posargs:tests/unit} [testenv:unit] commands = python -m unittest discover tests/unit [testenv:integration] -passenv = POWERWALL_IP POWERWALL_PASSWORD -commands = python -m unittest discover tests/integration +passenv = POWERWALL_IP,POWERWALL_PASSWORD +commands = python -m unittest discover tests/integration \ No newline at end of file