Skip to content

Commit

Permalink
Raise exception and reload integration
Browse files Browse the repository at this point in the history
  • Loading branch information
arjenbos committed Dec 29, 2023
1 parent e0c8e60 commit bbc67eb
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 23 deletions.
19 changes: 15 additions & 4 deletions custom_components/postnl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,23 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> True:
return True


async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Unload PostNL config entry."""
_LOGGER.debug('Reloading PostNL integration')
unload_ok = await hass.config_entries.async_unload_platforms(entry, PLATFORMS)

if unload_ok:
hass.data[DOMAIN].pop(entry.entry_id)

return unload_ok


class AsyncConfigEntryAuth:
"""Provide PostNL authentication tied to an OAuth2 based config entry."""

def __init__(
self,
oauth2_session: config_entry_oauth2_flow.OAuth2Session,
self,
oauth2_session: config_entry_oauth2_flow.OAuth2Session,
) -> None:
"""Initialize PostNL Auth."""
self.oauth_session = oauth2_session
Expand All @@ -74,8 +85,8 @@ async def check_and_refresh_token(self) -> str:

except (ClientResponseError, ClientError) as ex:
if (
self.oauth_session.config_entry.state
is ConfigEntryState.SETUP_IN_PROGRESS
self.oauth_session.config_entry.state
is ConfigEntryState.SETUP_IN_PROGRESS
):
if isinstance(ex, ClientResponseError) and 400 <= ex.status < 500:
raise ConfigEntryAuthFailed(
Expand Down
41 changes: 22 additions & 19 deletions custom_components/postnl/coordinator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from datetime import timedelta

from homeassistant.core import HomeAssistant
from homeassistant.helpers.update_coordinator import DataUpdateCoordinator
from homeassistant.helpers.update_coordinator import DataUpdateCoordinator, UpdateFailed

from . import AsyncConfigEntryAuth, PostNLGraphql
from .const import DOMAIN
Expand All @@ -25,33 +25,36 @@ def __init__(self, hass: HomeAssistant) -> None:
update_interval=timedelta(seconds=120),
)

async def _async_update_data(self) -> dict[str, list[Package]]:
async def _async_update_data(self) -> dict[str, tuple[Package]]:
_LOGGER.debug('Get API data')
try:

auth: AsyncConfigEntryAuth = self.hass.data[DOMAIN][self.config_entry.entry_id]
await auth.check_and_refresh_token()
auth: AsyncConfigEntryAuth = self.hass.data[DOMAIN][self.config_entry.entry_id]
await auth.check_and_refresh_token()

self.graphq_api = PostNLGraphql(auth.access_token)
self.jouw_api = PostNLJouwAPI(auth.access_token)
self.graphq_api = PostNLGraphql(auth.access_token)
self.jouw_api = PostNLJouwAPI(auth.access_token)

data: dict[str, tuple[Package]] = {
'receiver': [],
'sender': []
}
data: dict[str, tuple[Package]] = {
'receiver': [],
'sender': []
}

shipments = await self.hass.async_add_executor_job(self.graphq_api.shipments)
shipments = await self.hass.async_add_executor_job(self.graphq_api.shipments)

receiver_shipments = [self.transform_shipment(shipment) for shipment in
shipments.get('trackedShipments', {}).get('receiverShipments', [])]
data['receiver'] = await asyncio.gather(*receiver_shipments)
receiver_shipments = [self.transform_shipment(shipment) for shipment in
shipments.get('trackedShipments', {}).get('receiverShipments', [])]
data['receiver'] = await asyncio.gather(*receiver_shipments)

sender_shipments = [self.transform_shipment(shipment) for shipment in
shipments.get('trackedShipments', {}).get('senderShipments', [])]
data['sender'] = await asyncio.gather(*sender_shipments)
sender_shipments = [self.transform_shipment(shipment) for shipment in
shipments.get('trackedShipments', {}).get('senderShipments', [])]
data['sender'] = await asyncio.gather(*sender_shipments)

_LOGGER.debug('Found %d packages', len(data['sender']) + len(data['receiver']))
_LOGGER.debug('Found %d packages', len(data['sender']) + len(data['receiver']))

return data
return data
except Exception as exception:
raise UpdateFailed(exception)

async def transform_shipment(self, shipment) -> Package:
_LOGGER.debug('Updating %s', shipment.get('key'))
Expand Down

0 comments on commit bbc67eb

Please sign in to comment.