From bfa3b08d4183f88456adc127f175b0a4c6e00f6b Mon Sep 17 00:00:00 2001 From: Allen Porter Date: Sat, 10 Jun 2023 08:07:52 -0700 Subject: [PATCH] Improve exception handling for subscriber credential refresh (#637) * Improve exception handling for subscriber credential refresh * Address spelling error --- google_nest_sdm/google_nest_subscriber.py | 34 +++++++++++++++-------- tests/test_google_nest_subscriber.py | 31 ++++++++++++++++++++- 2 files changed, 53 insertions(+), 12 deletions(-) diff --git a/google_nest_sdm/google_nest_subscriber.py b/google_nest_sdm/google_nest_subscriber.py index f9d6a0dd..a853292c 100644 --- a/google_nest_sdm/google_nest_subscriber.py +++ b/google_nest_sdm/google_nest_subscriber.py @@ -13,7 +13,7 @@ from aiohttp.client_exceptions import ClientError from google.api_core.exceptions import GoogleAPIError, NotFound, Unauthenticated -from google.auth.exceptions import RefreshError +from google.auth.exceptions import RefreshError, GoogleAuthError, TransportError from google.auth.transport.requests import Request from google.cloud import pubsub_v1 from google.oauth2.credentials import Credentials @@ -119,6 +119,26 @@ def _validate_topic_name(topic_name: str) -> None: ) +def refresh_creds(creds: Credentials) -> Credentials: + """Refresh credentials. + + This is not part of the subscriber API, exposed only to facilitate testing. + """ + try: + creds.refresh(Request()) + except RefreshError as err: + raise AuthException(f"Authentication refresh failure: {err}") from err + except TransportError as err: + raise SubscriberException( + f"Connectivity error during authentication refresh: {err}" + ) from err + except GoogleAuthError as err: + raise SubscriberException( + f"Error during authentication refresh: {err}" + ) from err + return creds + + class AbstractSubscriberFactory(ABC): """Abstract class for creating a subscriber, to facilitate testing.""" @@ -232,7 +252,7 @@ def _delete_subscription( subscription_name: str, ) -> None: """Deletes a subscription.""" - creds = self._refresh_creds(creds) + creds = refresh_creds(creds) subscriber = pubsub_v1.SubscriberClient(credentials=creds) _LOGGER.debug(f"Deleting subscription '{subscription_name}'") subscriber.delete_subscription(subscription=subscription_name) @@ -265,14 +285,6 @@ def callback_wrapper(message: pubsub_v1.subscriber.message.Message) -> None: callback_wrapper, ) - def _refresh_creds(self, creds: Credentials) -> Credentials: - """Refresh credentials.""" - try: - creds.refresh(Request()) - except RefreshError as err: - raise AuthException(f"Access token failure: {err}") from err - return creds - def _new_subscriber( self, creds: Credentials, @@ -280,7 +292,7 @@ def _new_subscriber( callback_wrapper: Callable[[pubsub_v1.subscriber.message.Message], None], ) -> pubsub_v1.subscriber.futures.StreamingPullFuture: """Issue a command to verify subscriber creds are correct.""" - creds = self._refresh_creds(creds) + creds = refresh_creds(creds) subscriber = pubsub_v1.SubscriberClient(credentials=creds) subscription = subscriber.get_subscription(subscription=subscription_name) if subscription.topic: diff --git a/tests/test_google_nest_subscriber.py b/tests/test_google_nest_subscriber.py index 59d95a59..108f6311 100644 --- a/tests/test_google_nest_subscriber.py +++ b/tests/test_google_nest_subscriber.py @@ -3,10 +3,11 @@ import asyncio import json from typing import Any, Awaitable, Callable, Dict, Optional -from unittest.mock import create_autospec +from unittest.mock import create_autospec, Mock import aiohttp import pytest +from google.auth.exceptions import RefreshError, GoogleAuthError, TransportError from google.api_core.exceptions import ClientError, Unauthenticated from google.cloud import pubsub_v1 from google.oauth2.credentials import Credentials @@ -23,6 +24,7 @@ AbstractSubscriberFactory, GoogleNestSubscriber, get_api_env, + refresh_creds, ) from .conftest import DeviceHandler, EventCallback, StructureHandler, assert_diagnostics @@ -92,6 +94,33 @@ async def make_subscriber( return make_subscriber +async def test_refresh_creds() -> None: + """Test low level refresh errors.""" + mock_refresh = Mock() + mock_creds = Mock() + mock_creds.refresh = mock_refresh + refresh_creds(mock_creds) + assert mock_refresh.call_count == 1 + + +@pytest.mark.parametrize( + ("raised", "expected"), + [ + (RefreshError(), AuthException), + (TransportError(), SubscriberException), + (GoogleAuthError(), SubscriberException), + ], +) +async def test_refresh_creds_error(raised: Exception, expected: Any) -> None: + """Test low level refresh errors.""" + mock_refresh = Mock() + mock_refresh.side_effect = raised + mock_creds = Mock() + mock_creds.refresh = mock_refresh + with pytest.raises(expected): + refresh_creds(mock_creds) + + async def test_subscribe_no_events( device_handler: DeviceHandler, structure_handler: StructureHandler,