diff --git a/google_nest_sdm/admin_client.py b/google_nest_sdm/admin_client.py new file mode 100644 index 00000000..841398a1 --- /dev/null +++ b/google_nest_sdm/admin_client.py @@ -0,0 +1,233 @@ +"""Admin Client library for the Google Nest SDM API. + +This manages administrative tasks for setting up pubsub topics and subscriptions. + +This library exists to provide an asyncio interface given that the current pubsub +clients are synchronous. +""" + +import logging +import re +import asyncio +from typing import Any +from dataclasses import dataclass, field + +from .diagnostics import SUBSCRIBER_DIAGNOSTICS as DIAGNOSTICS +from .auth import AbstractAuth +from .exceptions import ( + ApiException, + NotFoundException, + ApiForbiddenException, + ConfigurationException, +) + +_LOGGER = logging.getLogger(__name__) + +__all__ = [ + "AdminClient", + "EligibleTopics", + "EligibleSubscriptions", + "validate_subscription_name", + "validate_topic_name", +] + +API_HOST_FORMAT = "https://pubsub.googleapis.com/v1/" +SDM_MANAGED_TOPIC_FORMAT = ( + "projects/sdm-prod/topics/enterprise-{device_access_project_id}" +) + +# Used to catch invalid subscriber id +EXPECTED_SUBSCRIBER_REGEXP = re.compile("^projects/[^/]+/subscriptions/[^/]+$") + +# Used to catch a topic misconfiguration +EXPECTED_TOPIC_REGEXP = re.compile("^projects/[^/]+/topics/[^/]+$") + +# Topic prefix for the project +EXPECTED_PROJECS_PREFIX = re.compile("^projects/[^/]+$") + + +@dataclass +class EligibleTopics: + """Eligible topics for the project.""" + + topic_names: list[str] = field(default_factory=list) + + +@dataclass +class EligibleSubscriptions: + """Eligible topics for the project.""" + + subscription_names: list[str] = field(default_factory=list) + + +def validate_subscription_name(subscription_name: str) -> None: + """Validates that a subscription name is correct. + + Raises ConfigurationException on failure. + """ + if not EXPECTED_SUBSCRIBER_REGEXP.match(subscription_name): + DIAGNOSTICS.increment("subscription_name_invalid") + _LOGGER.debug("Subscription name did not match pattern: %s", subscription_name) + raise ConfigurationException( + "Subscription misconfigured. Expected subscriber_id to " + f"match '{EXPECTED_SUBSCRIBER_REGEXP.pattern}' but was " + f"'{subscription_name}'" + ) + + +def validate_topic_name(topic_name: str) -> None: + """Validates that a topic name is correct. + + Raises ConfigurationException on failure. + """ + if not EXPECTED_TOPIC_REGEXP.match(topic_name): + DIAGNOSTICS.increment("topic_name_invalid") + _LOGGER.debug("Topic name did not match pattern: %s", topic_name) + raise ConfigurationException( + "Subscription misconfigured. Expected topic name to " + f"match '{EXPECTED_TOPIC_REGEXP.pattern}' but was " + f"'{topic_name}'." + ) + + +def validate_projects_prefix(project_path: str) -> None: + """Validates that a topic or subscription prefix is correct. + + Raises ConfigurationException on failure. + """ + if not EXPECTED_PROJECS_PREFIX.match(project_path): + DIAGNOSTICS.increment("topic_prefix_invalid") + _LOGGER.debug("Topic prefix did not match pattern: %s", project_path) + raise ConfigurationException( + "Subscription misconfigured. Expected topic name to " + f"match '{EXPECTED_PROJECS_PREFIX.pattern}' but was " + f"'{project_path}'." + ) + + +class AdminClient: + """Admin client for the Google Nest SDM API.""" + + def __init__( + self, auth: AbstractAuth, cloud_project_id: str, host: str | None = None + ) -> None: + """Initialize the admin client.""" + self._cloud_project_id = cloud_project_id + self._auth = auth.with_host(host if host is not None else API_HOST_FORMAT) + + async def create_topic(self, topic_name: str) -> None: + """Create a pubsub topic for the project.""" + validate_topic_name(topic_name) + await self._auth.request("put", topic_name) + + async def delete_topic(self, topic_name: str) -> None: + """Delete a pubsub topic for the project.""" + validate_topic_name(topic_name) + await self._auth.request("delete", topic_name) + + async def list_topics(self, projects_prefix: str) -> list[str]: + """List the pubsub topics for the project. + + The topic prefix should be in the format `projects/{console_project_id}`. + """ + validate_projects_prefix(projects_prefix) + response = await self._auth.get_json(projects_prefix) + return [topic["name"] for topic in response["topics"]] + + async def get_topic(self, topic_name: str) -> dict[str, Any]: + """Get a pubsub topic for the project.""" + validate_topic_name(topic_name) + return await self._auth.get_json(topic_name) + + async def create_subscription( + self, topic_name: str, subscription_name: str + ) -> None: + """Create a pubsub subscription for the project.""" + validate_topic_name(topic_name) + validate_subscription_name(subscription_name) + body = {"topic": topic_name} + await self._auth.request("put", subscription_name, json=body) + + async def delete_subscription(self, subscription_name: str) -> None: + """Delete a pubsub subscription for the project.""" + validate_subscription_name(subscription_name) + await self._auth.request("delete", subscription_name) + + async def list_subscriptions(self, projects_prefix: str) -> list[dict[str, Any]]: + """List the pubsub subscriptions for the project. + The projects_prefix should be in the format `projects/{console_project_id}`. + """ + validate_projects_prefix(projects_prefix) + response = await self._auth.get_json(projects_prefix) + return response["subscriptions"] # type: ignore[no-any-return] + + async def list_eligible_topics( + self, device_access_project_id: str + ) -> EligibleTopics: + """List the eligible topics for the project. + + This will try to find any topics already created for the project by either + the device access console or by the user. + """ + + sdm_topic_name = SDM_MANAGED_TOPIC_FORMAT.format( + device_access_project_id=device_access_project_id + ) + + async def get_sdm_topic() -> str | None: + try: + await self.get_topic(sdm_topic_name) + except ApiForbiddenException: + _LOGGER.debug( + "SDM topic exists but we do not have permission to access it (expected)" + ) + # The SDM topic exists. It is normal that we do not have permission + # to access it. + return sdm_topic_name + except NotFoundException: + _LOGGER.debug( + "SDM topic does not exist, proceeding to check cloud projects" + ) + return None + except ApiException as err: + _LOGGER.error( + "Unexpected error retrieving an SDM created topic: %s", err + ) + raise ApiException("Error retrieving SDM created topic") from err + _LOGGER.debug( + "SDM topic exists and we have permission to access it (unexpected)" + ) + return sdm_topic_name + + async def get_cloud_topics() -> list[str]: + try: + return await self.list_topics(f"projects/{self._cloud_project_id}") + except ApiException as err: + _LOGGER.error("Unexpected error listing topics: %s", err) + raise ApiException( + "Error while listing existing cloud console topics" + ) from err + + (sdm_topic_task, cloud_topics_task) = await asyncio.gather( + get_sdm_topic(), get_cloud_topics() + ) + topics = [] + if sdm_topic_task: + topics.append(sdm_topic_task) + topics.extend(cloud_topics_task) + return EligibleTopics(topic_names=topics) + + async def list_eligible_subscriptions( + self, expected_topic_name: str + ) -> EligibleSubscriptions: + """Return a set of eligible subscriptions for the project.""" + subscriptions = await self.list_subscriptions( + f"projects/{self._cloud_project_id}" + ) + return EligibleSubscriptions( + subscription_names=[ + sub["name"] + for sub in subscriptions + if sub["topic"] == expected_topic_name + ] + ) diff --git a/google_nest_sdm/auth.py b/google_nest_sdm/auth.py index 0212dbb0..80d24370 100644 --- a/google_nest_sdm/auth.py +++ b/google_nest_sdm/auth.py @@ -101,6 +101,10 @@ def __init__(self, websession: aiohttp.ClientSession, host: str): self._websession = websession self._host = host + def with_host(self, host: str) -> AbstractAuth: + """Return a new instance with a different host.""" + return self.__class__(self._websession, host) + @abstractmethod async def async_get_access_token(self) -> str: """Return a valid access token.""" diff --git a/setup.cfg b/setup.cfg index deb6f24a..d5e513ab 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,6 +1,6 @@ [metadata] name = google_nest_sdm -version = 6.0.0 +version = 6.1.0 description = Library for the Google Nest SDM API long_description = file: README.md long_description_content_type = text/markdown diff --git a/tests/conftest.py b/tests/conftest.py index 2a05f6c8..582737a7 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,6 +3,7 @@ from __future__ import annotations import uuid +from http import HTTPStatus from abc import ABC, abstractmethod from typing import ( Any, @@ -27,6 +28,8 @@ FAKE_TOKEN = "some-token" PROJECT_ID = "project-id1" +_LOGGER = logging.getLogger(__name__) + def pytest_configure(config: pytest.Config) -> None: """Register marker for tests that log exceptions.""" @@ -51,6 +54,7 @@ def mock_server( ) -> Callable[[], Awaitable[TestServer]]: async def _make_server() -> TestServer: server = await aiohttp_server(app) + server.skip_url_asserts = True assert isinstance(server, TestServer) return server @@ -59,7 +63,6 @@ async def _make_server() -> TestServer: @pytest.fixture(name="client") def mock_client( - #event_loop: Any, server: Callable[[], Awaitable[TestServer]], aiohttp_client: Callable[[TestServer], Awaitable[TestClient]], ) -> Callable[[], Awaitable[TestClient]]: @@ -171,6 +174,7 @@ def get_response(self) -> dict[str, Any]: """Implemented by subclasses to return a response.""" async def handler(self, request: aiohttp.web.Request) -> aiohttp.web.Response: + _LOGGER.debug("Request: %s", request) assert request.headers["Authorization"] == "Bearer %s" % self.token s = await request.text() self.recorder.request = await request.json() if s else {} @@ -274,13 +278,16 @@ def get_response(self) -> dict[str, Any]: def NewHandler( - r: Recorder, responses: list[dict[str, Any]], token: str = FAKE_TOKEN + r: Recorder, + responses: list[dict[str, Any]], + token: str = FAKE_TOKEN, + status: HTTPStatus = HTTPStatus.OK, ) -> Callable[[aiohttp.web.Request], Awaitable[aiohttp.web.Response]]: async def handler(request: aiohttp.web.Request) -> aiohttp.web.Response: assert request.headers["Authorization"] == "Bearer %s" % token s = await request.text() r.request = await request.json() if s else {} - return aiohttp.web.json_response(responses.pop(0)) + return aiohttp.web.json_response(responses.pop(0), status=status) return handler diff --git a/tests/test_admin_client.py b/tests/test_admin_client.py new file mode 100644 index 00000000..5058e8f7 --- /dev/null +++ b/tests/test_admin_client.py @@ -0,0 +1,453 @@ +"""Tests for the admin client library.""" + +from typing import Awaitable, Callable +from http import HTTPStatus + +import aiohttp +import pytest + +from google_nest_sdm.admin_client import AdminClient +from google_nest_sdm.auth import AbstractAuth +from google_nest_sdm.exceptions import ( + ApiException, + ConfigurationException, +) + +from .conftest import Recorder, NewHandler + +GOOGLE_CLOUD_CONSOLE_PROJECT_ID = "google-cloud-console-project-id" +DEVICE_ACCESS_PROJECT_ID = "device-access-project-id" + + +@pytest.fixture(name="admin_client") +def mock_admin_client( + auth_client: Callable[[], Awaitable[AbstractAuth]], +) -> Callable[[], Awaitable[AdminClient]]: + + async def _make_admin_client() -> AdminClient: + mock_auth = await auth_client() + return AdminClient(mock_auth, GOOGLE_CLOUD_CONSOLE_PROJECT_ID, host="") + + return _make_admin_client + + +async def test_invalid_topic_format( + app: aiohttp.web.Application, + admin_client: Callable[[], Awaitable[AdminClient]], + recorder: Recorder, +) -> None: + """Test creating a topic.""" + client = await admin_client() + with pytest.raises(ConfigurationException): + await client.create_topic("some-topic") + + +async def test_create_topic( + app: aiohttp.web.Application, + admin_client: Callable[[], Awaitable[AdminClient]], + recorder: Recorder, +) -> None: + """Test creating a topic.""" + + handler = NewHandler( + recorder, + [{}], + ) + app.router.add_put( + f"/projects/{GOOGLE_CLOUD_CONSOLE_PROJECT_ID}/topics/topic-name", handler + ) + + client = await admin_client() + await client.create_topic( + f"projects/{GOOGLE_CLOUD_CONSOLE_PROJECT_ID}/topics/topic-name" + ) + + assert recorder.request == {} + + +async def test_delete_topic( + app: aiohttp.web.Application, + admin_client: Callable[[], Awaitable[AdminClient]], + recorder: Recorder, +) -> None: + """Test deleting a topic.""" + + handler = NewHandler( + recorder, + [{}], + ) + app.router.add_delete( + f"/projects/{GOOGLE_CLOUD_CONSOLE_PROJECT_ID}/topics/topic-name", handler + ) + + client = await admin_client() + await client.delete_topic( + f"projects/{GOOGLE_CLOUD_CONSOLE_PROJECT_ID}/topics/topic-name" + ) + + assert recorder.request == {} + + +async def test_list_topics( + app: aiohttp.web.Application, + admin_client: Callable[[], Awaitable[AdminClient]], + recorder: Recorder, +) -> None: + """Test listing topics.""" + + handler = NewHandler( + recorder, + [ + { + "topics": [ + {"name": "projects/project-id/topics/topic1"}, + {"name": "projects/project-id/topics/topic2"}, + ] + } + ], + ) + app.router.add_get(f"/projects/{GOOGLE_CLOUD_CONSOLE_PROJECT_ID}", handler) + + client = await admin_client() + topics = await client.list_topics(f"projects/{GOOGLE_CLOUD_CONSOLE_PROJECT_ID}") + + assert topics == [ + "projects/project-id/topics/topic1", + "projects/project-id/topics/topic2", + ] + assert recorder.request == {} + + +async def test_list_topics_invalid_prefix( + app: aiohttp.web.Application, + admin_client: Callable[[], Awaitable[AdminClient]], + recorder: Recorder, +) -> None: + """Test listing topics with an invalid prefix.""" + + handler = NewHandler( + recorder, + [ + { + "topics": [ + {"name": "projects/project-id/topics/topic1"}, + {"name": "projects/project-id/topics/topic2"}, + ] + } + ], + ) + app.router.add_get(f"/projects/{GOOGLE_CLOUD_CONSOLE_PROJECT_ID}/topics", handler) + + client = await admin_client() + with pytest.raises(ConfigurationException): + await client.list_topics("projects/{GOOGLE_CLOUD_CONSOLE_PROJECT_ID}/topics") + + +async def test_get_topic( + app: aiohttp.web.Application, + admin_client: Callable[[], Awaitable[AdminClient]], + recorder: Recorder, +) -> None: + """Test getting a topic.""" + + handler = NewHandler( + recorder, + [{"name": f"projects/{GOOGLE_CLOUD_CONSOLE_PROJECT_ID}/topics/topic-name"}], + ) + app.router.add_get( + f"/projects/{GOOGLE_CLOUD_CONSOLE_PROJECT_ID}/topics/topic-name", handler + ) + + client = await admin_client() + response = await client.get_topic( + f"projects/{GOOGLE_CLOUD_CONSOLE_PROJECT_ID}/topics/topic-name" + ) + assert recorder.request == {} + assert response == { + "name": f"projects/{GOOGLE_CLOUD_CONSOLE_PROJECT_ID}/topics/topic-name" + } + + +async def test_create_subscription( + app: aiohttp.web.Application, + admin_client: Callable[[], Awaitable[AdminClient]], + recorder: Recorder, +) -> None: + """Test creating a subscription.""" + + handler = NewHandler( + recorder, + [{}], + ) + app.router.add_put( + f"/projects/{GOOGLE_CLOUD_CONSOLE_PROJECT_ID}/subscriptions/subscription-name", + handler, + ) + + client = await admin_client() + await client.create_subscription( + f"projects/{GOOGLE_CLOUD_CONSOLE_PROJECT_ID}/topics/topic-name", + f"projects/{GOOGLE_CLOUD_CONSOLE_PROJECT_ID}/subscriptions/subscription-name", + ) + + assert recorder.request == { + "topic": f"projects/{GOOGLE_CLOUD_CONSOLE_PROJECT_ID}/topics/topic-name" + } + + +async def test_delete_subscription( + app: aiohttp.web.Application, + admin_client: Callable[[], Awaitable[AdminClient]], + recorder: Recorder, +) -> None: + """Test deleting a subscription.""" + + handler = NewHandler( + recorder, + [{}], + ) + app.router.add_delete( + f"/projects/{GOOGLE_CLOUD_CONSOLE_PROJECT_ID}/subscriptions/subscription-name", + handler, + ) + + client = await admin_client() + await client.delete_subscription( + f"projects/{GOOGLE_CLOUD_CONSOLE_PROJECT_ID}/subscriptions/subscription-name" + ) + + assert recorder.request == {} + + +async def test_list_subscriptions( + app: aiohttp.web.Application, + admin_client: Callable[[], Awaitable[AdminClient]], + recorder: Recorder, +) -> None: + """Test listing subscriptions.""" + + handler = NewHandler( + recorder, + [ + { + "subscriptions": [ + { + "name": f"projects/{GOOGLE_CLOUD_CONSOLE_PROJECT_ID}/subscriptions/subscription1" + }, + { + "name": f"projects/{GOOGLE_CLOUD_CONSOLE_PROJECT_ID}/subscriptions/subscription2" + }, + ] + } + ], + ) + app.router.add_get(f"/projects/{GOOGLE_CLOUD_CONSOLE_PROJECT_ID}", handler) + + client = await admin_client() + subscriptions = await client.list_subscriptions( + f"projects/{GOOGLE_CLOUD_CONSOLE_PROJECT_ID}" + ) + + assert subscriptions == [ + { + "name": "projects/google-cloud-console-project-id/subscriptions/subscription1" + }, + { + "name": "projects/google-cloud-console-project-id/subscriptions/subscription2" + }, + ] + + +async def test_invalid_subscription_format( + app: aiohttp.web.Application, + admin_client: Callable[[], Awaitable[AdminClient]], + recorder: Recorder, +) -> None: + """Test creating a subscription.""" + client = await admin_client() + with pytest.raises(ConfigurationException): + await client.create_subscription( + f"projects/{GOOGLE_CLOUD_CONSOLE_PROJECT_ID}/topics/topic-name", + "some-subscription", + ) + with pytest.raises(ConfigurationException): + await client.create_subscription( + "some-topic", + f"projects/{GOOGLE_CLOUD_CONSOLE_PROJECT_ID}/subscriptions/subscription-name", + ) + + +async def test_list_eligible_topics( + app: aiohttp.web.Application, + admin_client: Callable[[], Awaitable[AdminClient]], + recorder: Recorder, +) -> None: + """Test listing eligible topics.""" + + # SDM created pubsub topic exists (but is not visible, which is expected) exists + sdm_handler = NewHandler( + recorder, + [ + { + "error": { + "code": 403, + "message": "User not authorized to perform this action.", + "status": "PERMISSION_DENIED", + } + }, + ], + status=HTTPStatus.FORBIDDEN, + ) + app.router.add_get( + f"/projects/sdm-prod/topics/enterprise-{DEVICE_ACCESS_PROJECT_ID}", sdm_handler + ) + # Cloud topic also exists + cloud_handler = NewHandler( + recorder, + [ + { + "topics": [ + { + "name": f"projects/{GOOGLE_CLOUD_CONSOLE_PROJECT_ID}/topics/sdm-testing" + } + ] + } + ], + ) + app.router.add_get(f"/projects/{GOOGLE_CLOUD_CONSOLE_PROJECT_ID}", cloud_handler) + + client = await admin_client() + eligible_topics = await client.list_eligible_topics(DEVICE_ACCESS_PROJECT_ID) + assert eligible_topics.topic_names == [ + "projects/sdm-prod/topics/enterprise-device-access-project-id", + "projects/google-cloud-console-project-id/topics/sdm-testing", + ] + + +async def test_list_eligible_topics_no_sdm_topic( + app: aiohttp.web.Application, + admin_client: Callable[[], Awaitable[AdminClient]], + recorder: Recorder, +) -> None: + """Test listing eligible topics when no SDM topic exists.""" + + # SDM created pubsub topic does not exist + sdm_handler = NewHandler( + recorder, + [ + { + "error": { + "code": 404, + "message": f"Resource not found (resource=enterprise-{DEVICE_ACCESS_PROJECT_ID}).", + "status": "NOT_FOUND", + } + } + ], + status=HTTPStatus.NOT_FOUND, + ) + app.router.add_get( + f"/projects/sdm-prod/topics/enterprise-{DEVICE_ACCESS_PROJECT_ID}", sdm_handler + ) + + # Cloud topic exists + cloud_handler = NewHandler( + recorder, + [ + { + "topics": [ + { + "name": f"projects/{GOOGLE_CLOUD_CONSOLE_PROJECT_ID}/topics/sdm-testing" + } + ] + } + ], + ) + app.router.add_get(f"/projects/{GOOGLE_CLOUD_CONSOLE_PROJECT_ID}", cloud_handler) + + client = await admin_client() + eligible_topics = await client.list_eligible_topics(DEVICE_ACCESS_PROJECT_ID) + assert eligible_topics.topic_names == [ + "projects/google-cloud-console-project-id/topics/sdm-testing" + ] + + +@pytest.mark.parametrize( + "sdm_status, cloud_status", + [ + (HTTPStatus.INTERNAL_SERVER_ERROR, HTTPStatus.OK), + (HTTPStatus.OK, HTTPStatus.INTERNAL_SERVER_ERROR), + ], +) +async def test_list_cloud_console_api_error( + app: aiohttp.web.Application, + admin_client: Callable[[], Awaitable[AdminClient]], + recorder: Recorder, + sdm_status: HTTPStatus, + cloud_status: HTTPStatus, +) -> None: + """Test listing eligible topics when an error occurs listing the cloud console topics.""" + + # SDM created pubsub topic exists (but is not visible, which is expected) exists + sdm_handler = NewHandler( + recorder, + [ + { + "error": { + "code": 403, + "message": "User not authorized to perform this action.", + "status": "PERMISSION_DENIED", + } + }, + ], + status=sdm_status, + ) + app.router.add_get( + f"/projects/sdm-prod/topics/enterprise-{DEVICE_ACCESS_PROJECT_ID}", sdm_handler + ) + # Cloud topic also exists + cloud_handler = NewHandler( + recorder, + [{}], + status=cloud_status, + ) + app.router.add_get(f"/projects/{GOOGLE_CLOUD_CONSOLE_PROJECT_ID}", cloud_handler) + + client = await admin_client() + with pytest.raises(ApiException): + await client.list_eligible_topics(DEVICE_ACCESS_PROJECT_ID) + + +async def test_list_eligible_subscriptions( + app: aiohttp.web.Application, + admin_client: Callable[[], Awaitable[AdminClient]], + recorder: Recorder, +) -> None: + """Test listing eligible subscriptions.""" + + cloud_handler = NewHandler( + recorder, + [ + { + "subscriptions": [ + { + "name": f"projects/{GOOGLE_CLOUD_CONSOLE_PROJECT_ID}/subscriptions/sdm-testing-sub", + "topic": f"projects/{GOOGLE_CLOUD_CONSOLE_PROJECT_ID}/topics/sdm-testing", + "pushConfig": {}, + "ackDeadlineSeconds": 10, + "messageRetentionDuration": "604800s", + "expirationPolicy": {"ttl": "2678400s"}, + "state": "ACTIVE", + } + ] + } + ], + ) + app.router.add_get(f"/projects/{GOOGLE_CLOUD_CONSOLE_PROJECT_ID}", cloud_handler) + + client = await admin_client() + eligible_subscriptions = await client.list_eligible_subscriptions( + expected_topic_name=f"projects/{GOOGLE_CLOUD_CONSOLE_PROJECT_ID}/topics/sdm-testing" + ) + assert eligible_subscriptions.subscription_names == [ + f"projects/{GOOGLE_CLOUD_CONSOLE_PROJECT_ID}/subscriptions/sdm-testing-sub", + ]