Skip to content

Commit

Permalink
Add an admin client for managing pubsub topics and subscriptions (#1123)
Browse files Browse the repository at this point in the history
* Add an admin client for managing pubsub topics and subscriptions

* List eligible topics in parallel

* Bump version to 6.1.0

* Remove unused import
  • Loading branch information
allenporter authored Oct 20, 2024
1 parent c6af86e commit d2e60ac
Show file tree
Hide file tree
Showing 5 changed files with 701 additions and 4 deletions.
233 changes: 233 additions & 0 deletions google_nest_sdm/admin_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,233 @@
"""Admin Client library for the Google Nest SDM API.
This manages administrative tasks for setting up pubsub topics and subscriptions.
This library exists to provide an asyncio interface given that the current pubsub
clients are synchronous.
"""

import logging
import re
import asyncio
from typing import Any
from dataclasses import dataclass, field

from .diagnostics import SUBSCRIBER_DIAGNOSTICS as DIAGNOSTICS
from .auth import AbstractAuth
from .exceptions import (
ApiException,
NotFoundException,
ApiForbiddenException,
ConfigurationException,
)

_LOGGER = logging.getLogger(__name__)

__all__ = [
"AdminClient",
"EligibleTopics",
"EligibleSubscriptions",
"validate_subscription_name",
"validate_topic_name",
]

API_HOST_FORMAT = "https://pubsub.googleapis.com/v1/"
SDM_MANAGED_TOPIC_FORMAT = (
"projects/sdm-prod/topics/enterprise-{device_access_project_id}"
)

# Used to catch invalid subscriber id
EXPECTED_SUBSCRIBER_REGEXP = re.compile("^projects/[^/]+/subscriptions/[^/]+$")

# Used to catch a topic misconfiguration
EXPECTED_TOPIC_REGEXP = re.compile("^projects/[^/]+/topics/[^/]+$")

# Topic prefix for the project
EXPECTED_PROJECS_PREFIX = re.compile("^projects/[^/]+$")


@dataclass
class EligibleTopics:
"""Eligible topics for the project."""

topic_names: list[str] = field(default_factory=list)


@dataclass
class EligibleSubscriptions:
"""Eligible topics for the project."""

subscription_names: list[str] = field(default_factory=list)


def validate_subscription_name(subscription_name: str) -> None:
"""Validates that a subscription name is correct.
Raises ConfigurationException on failure.
"""
if not EXPECTED_SUBSCRIBER_REGEXP.match(subscription_name):
DIAGNOSTICS.increment("subscription_name_invalid")
_LOGGER.debug("Subscription name did not match pattern: %s", subscription_name)
raise ConfigurationException(
"Subscription misconfigured. Expected subscriber_id to "
f"match '{EXPECTED_SUBSCRIBER_REGEXP.pattern}' but was "
f"'{subscription_name}'"
)


def validate_topic_name(topic_name: str) -> None:
"""Validates that a topic name is correct.
Raises ConfigurationException on failure.
"""
if not EXPECTED_TOPIC_REGEXP.match(topic_name):
DIAGNOSTICS.increment("topic_name_invalid")
_LOGGER.debug("Topic name did not match pattern: %s", topic_name)
raise ConfigurationException(
"Subscription misconfigured. Expected topic name to "
f"match '{EXPECTED_TOPIC_REGEXP.pattern}' but was "
f"'{topic_name}'."
)


def validate_projects_prefix(project_path: str) -> None:
"""Validates that a topic or subscription prefix is correct.
Raises ConfigurationException on failure.
"""
if not EXPECTED_PROJECS_PREFIX.match(project_path):
DIAGNOSTICS.increment("topic_prefix_invalid")
_LOGGER.debug("Topic prefix did not match pattern: %s", project_path)
raise ConfigurationException(
"Subscription misconfigured. Expected topic name to "
f"match '{EXPECTED_PROJECS_PREFIX.pattern}' but was "
f"'{project_path}'."
)


class AdminClient:
"""Admin client for the Google Nest SDM API."""

def __init__(
self, auth: AbstractAuth, cloud_project_id: str, host: str | None = None
) -> None:
"""Initialize the admin client."""
self._cloud_project_id = cloud_project_id
self._auth = auth.with_host(host if host is not None else API_HOST_FORMAT)

async def create_topic(self, topic_name: str) -> None:
"""Create a pubsub topic for the project."""
validate_topic_name(topic_name)
await self._auth.request("put", topic_name)

async def delete_topic(self, topic_name: str) -> None:
"""Delete a pubsub topic for the project."""
validate_topic_name(topic_name)
await self._auth.request("delete", topic_name)

async def list_topics(self, projects_prefix: str) -> list[str]:
"""List the pubsub topics for the project.
The topic prefix should be in the format `projects/{console_project_id}`.
"""
validate_projects_prefix(projects_prefix)
response = await self._auth.get_json(projects_prefix)
return [topic["name"] for topic in response["topics"]]

async def get_topic(self, topic_name: str) -> dict[str, Any]:
"""Get a pubsub topic for the project."""
validate_topic_name(topic_name)
return await self._auth.get_json(topic_name)

async def create_subscription(
self, topic_name: str, subscription_name: str
) -> None:
"""Create a pubsub subscription for the project."""
validate_topic_name(topic_name)
validate_subscription_name(subscription_name)
body = {"topic": topic_name}
await self._auth.request("put", subscription_name, json=body)

async def delete_subscription(self, subscription_name: str) -> None:
"""Delete a pubsub subscription for the project."""
validate_subscription_name(subscription_name)
await self._auth.request("delete", subscription_name)

async def list_subscriptions(self, projects_prefix: str) -> list[dict[str, Any]]:
"""List the pubsub subscriptions for the project.
The projects_prefix should be in the format `projects/{console_project_id}`.
"""
validate_projects_prefix(projects_prefix)
response = await self._auth.get_json(projects_prefix)
return response["subscriptions"] # type: ignore[no-any-return]

async def list_eligible_topics(
self, device_access_project_id: str
) -> EligibleTopics:
"""List the eligible topics for the project.
This will try to find any topics already created for the project by either
the device access console or by the user.
"""

sdm_topic_name = SDM_MANAGED_TOPIC_FORMAT.format(
device_access_project_id=device_access_project_id
)

async def get_sdm_topic() -> str | None:
try:
await self.get_topic(sdm_topic_name)
except ApiForbiddenException:
_LOGGER.debug(
"SDM topic exists but we do not have permission to access it (expected)"
)
# The SDM topic exists. It is normal that we do not have permission
# to access it.
return sdm_topic_name
except NotFoundException:
_LOGGER.debug(
"SDM topic does not exist, proceeding to check cloud projects"
)
return None
except ApiException as err:
_LOGGER.error(
"Unexpected error retrieving an SDM created topic: %s", err
)
raise ApiException("Error retrieving SDM created topic") from err
_LOGGER.debug(
"SDM topic exists and we have permission to access it (unexpected)"
)
return sdm_topic_name

async def get_cloud_topics() -> list[str]:
try:
return await self.list_topics(f"projects/{self._cloud_project_id}")
except ApiException as err:
_LOGGER.error("Unexpected error listing topics: %s", err)
raise ApiException(
"Error while listing existing cloud console topics"
) from err

(sdm_topic_task, cloud_topics_task) = await asyncio.gather(
get_sdm_topic(), get_cloud_topics()
)
topics = []
if sdm_topic_task:
topics.append(sdm_topic_task)
topics.extend(cloud_topics_task)
return EligibleTopics(topic_names=topics)

async def list_eligible_subscriptions(
self, expected_topic_name: str
) -> EligibleSubscriptions:
"""Return a set of eligible subscriptions for the project."""
subscriptions = await self.list_subscriptions(
f"projects/{self._cloud_project_id}"
)
return EligibleSubscriptions(
subscription_names=[
sub["name"]
for sub in subscriptions
if sub["topic"] == expected_topic_name
]
)
4 changes: 4 additions & 0 deletions google_nest_sdm/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,10 @@ def __init__(self, websession: aiohttp.ClientSession, host: str):
self._websession = websession
self._host = host

def with_host(self, host: str) -> AbstractAuth:
"""Return a new instance with a different host."""
return self.__class__(self._websession, host)

@abstractmethod
async def async_get_access_token(self) -> str:
"""Return a valid access token."""
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[metadata]
name = google_nest_sdm
version = 6.0.0
version = 6.1.0
description = Library for the Google Nest SDM API
long_description = file: README.md
long_description_content_type = text/markdown
Expand Down
13 changes: 10 additions & 3 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations

import uuid
from http import HTTPStatus
from abc import ABC, abstractmethod
from typing import (
Any,
Expand All @@ -27,6 +28,8 @@
FAKE_TOKEN = "some-token"
PROJECT_ID = "project-id1"

_LOGGER = logging.getLogger(__name__)


def pytest_configure(config: pytest.Config) -> None:
"""Register marker for tests that log exceptions."""
Expand All @@ -51,6 +54,7 @@ def mock_server(
) -> Callable[[], Awaitable[TestServer]]:
async def _make_server() -> TestServer:
server = await aiohttp_server(app)
server.skip_url_asserts = True
assert isinstance(server, TestServer)
return server

Expand All @@ -59,7 +63,6 @@ async def _make_server() -> TestServer:

@pytest.fixture(name="client")
def mock_client(
#event_loop: Any,
server: Callable[[], Awaitable[TestServer]],
aiohttp_client: Callable[[TestServer], Awaitable[TestClient]],
) -> Callable[[], Awaitable[TestClient]]:
Expand Down Expand Up @@ -171,6 +174,7 @@ def get_response(self) -> dict[str, Any]:
"""Implemented by subclasses to return a response."""

async def handler(self, request: aiohttp.web.Request) -> aiohttp.web.Response:
_LOGGER.debug("Request: %s", request)
assert request.headers["Authorization"] == "Bearer %s" % self.token
s = await request.text()
self.recorder.request = await request.json() if s else {}
Expand Down Expand Up @@ -274,13 +278,16 @@ def get_response(self) -> dict[str, Any]:


def NewHandler(
r: Recorder, responses: list[dict[str, Any]], token: str = FAKE_TOKEN
r: Recorder,
responses: list[dict[str, Any]],
token: str = FAKE_TOKEN,
status: HTTPStatus = HTTPStatus.OK,
) -> Callable[[aiohttp.web.Request], Awaitable[aiohttp.web.Response]]:
async def handler(request: aiohttp.web.Request) -> aiohttp.web.Response:
assert request.headers["Authorization"] == "Bearer %s" % token
s = await request.text()
r.request = await request.json() if s else {}
return aiohttp.web.json_response(responses.pop(0))
return aiohttp.web.json_response(responses.pop(0), status=status)

return handler

Expand Down
Loading

0 comments on commit d2e60ac

Please sign in to comment.