Skip to content

Commit

Permalink
Improve exception handling for subscriber credential refresh (#637)
Browse files Browse the repository at this point in the history
* Improve exception handling for subscriber credential refresh

* Address spelling error
  • Loading branch information
allenporter authored Jun 10, 2023
1 parent 8786058 commit bfa3b08
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 12 deletions.
34 changes: 23 additions & 11 deletions google_nest_sdm/google_nest_subscriber.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

from aiohttp.client_exceptions import ClientError
from google.api_core.exceptions import GoogleAPIError, NotFound, Unauthenticated
from google.auth.exceptions import RefreshError
from google.auth.exceptions import RefreshError, GoogleAuthError, TransportError
from google.auth.transport.requests import Request
from google.cloud import pubsub_v1
from google.oauth2.credentials import Credentials
Expand Down Expand Up @@ -119,6 +119,26 @@ def _validate_topic_name(topic_name: str) -> None:
)


def refresh_creds(creds: Credentials) -> Credentials:
"""Refresh credentials.
This is not part of the subscriber API, exposed only to facilitate testing.
"""
try:
creds.refresh(Request())
except RefreshError as err:
raise AuthException(f"Authentication refresh failure: {err}") from err
except TransportError as err:
raise SubscriberException(
f"Connectivity error during authentication refresh: {err}"
) from err
except GoogleAuthError as err:
raise SubscriberException(
f"Error during authentication refresh: {err}"
) from err
return creds


class AbstractSubscriberFactory(ABC):
"""Abstract class for creating a subscriber, to facilitate testing."""

Expand Down Expand Up @@ -232,7 +252,7 @@ def _delete_subscription(
subscription_name: str,
) -> None:
"""Deletes a subscription."""
creds = self._refresh_creds(creds)
creds = refresh_creds(creds)
subscriber = pubsub_v1.SubscriberClient(credentials=creds)
_LOGGER.debug(f"Deleting subscription '{subscription_name}'")
subscriber.delete_subscription(subscription=subscription_name)
Expand Down Expand Up @@ -265,22 +285,14 @@ def callback_wrapper(message: pubsub_v1.subscriber.message.Message) -> None:
callback_wrapper,
)

def _refresh_creds(self, creds: Credentials) -> Credentials:
"""Refresh credentials."""
try:
creds.refresh(Request())
except RefreshError as err:
raise AuthException(f"Access token failure: {err}") from err
return creds

def _new_subscriber(
self,
creds: Credentials,
subscription_name: str,
callback_wrapper: Callable[[pubsub_v1.subscriber.message.Message], None],
) -> pubsub_v1.subscriber.futures.StreamingPullFuture:
"""Issue a command to verify subscriber creds are correct."""
creds = self._refresh_creds(creds)
creds = refresh_creds(creds)
subscriber = pubsub_v1.SubscriberClient(credentials=creds)
subscription = subscriber.get_subscription(subscription=subscription_name)
if subscription.topic:
Expand Down
31 changes: 30 additions & 1 deletion tests/test_google_nest_subscriber.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@
import asyncio
import json
from typing import Any, Awaitable, Callable, Dict, Optional
from unittest.mock import create_autospec
from unittest.mock import create_autospec, Mock

import aiohttp
import pytest
from google.auth.exceptions import RefreshError, GoogleAuthError, TransportError
from google.api_core.exceptions import ClientError, Unauthenticated
from google.cloud import pubsub_v1
from google.oauth2.credentials import Credentials
Expand All @@ -23,6 +24,7 @@
AbstractSubscriberFactory,
GoogleNestSubscriber,
get_api_env,
refresh_creds,
)

from .conftest import DeviceHandler, EventCallback, StructureHandler, assert_diagnostics
Expand Down Expand Up @@ -92,6 +94,33 @@ async def make_subscriber(
return make_subscriber


async def test_refresh_creds() -> None:
"""Test low level refresh errors."""
mock_refresh = Mock()
mock_creds = Mock()
mock_creds.refresh = mock_refresh
refresh_creds(mock_creds)
assert mock_refresh.call_count == 1


@pytest.mark.parametrize(
("raised", "expected"),
[
(RefreshError(), AuthException),
(TransportError(), SubscriberException),
(GoogleAuthError(), SubscriberException),
],
)
async def test_refresh_creds_error(raised: Exception, expected: Any) -> None:
"""Test low level refresh errors."""
mock_refresh = Mock()
mock_refresh.side_effect = raised
mock_creds = Mock()
mock_creds.refresh = mock_refresh
with pytest.raises(expected):
refresh_creds(mock_creds)


async def test_subscribe_no_events(
device_handler: DeviceHandler,
structure_handler: StructureHandler,
Expand Down

0 comments on commit bfa3b08

Please sign in to comment.