From 7fa0b62d98673dc7b4cb1808ac0e4f5b89cb3ac6 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?M=C3=A1ty=C3=A1s=20Kuti?=
Date: Wed, 6 Dec 2023 10:09:26 +0100
Subject: [PATCH 1/2] Implement async confluent-kafka producer

The async `AsyncKafkaProducer` is implemented as a wrapper around
`KafkaProducer`, with an async `send` method and a poll-thread that
continuously sends messages in the background, making the result of
`send` an awaitable `asyncio.Future`.

This follows the asyncio example of confluent-kafka:
https://github.com/confluentinc/confluent-kafka-python/blob/master/examples/asyncio_example.py
---
 karapace/kafka/common.py                   | 11 +++-
 karapace/kafka/producer.py                 | 72 +++++++++++++++++++++-
 karapace/kafka_rest_apis/__init__.py       | 43 ++++++-------
 karapace/kafka_rest_apis/authentication.py |  3 +
 tests/integration/kafka/test_producer.py   | 71 ++++++++++++++++++++-
 tests/integration/test_rest.py             |  4 +-
 tests/unit/test_authentication.py          |  8 +++
 7 files changed, 186 insertions(+), 26 deletions(-)

diff --git a/karapace/kafka/common.py b/karapace/kafka/common.py
index 78dd1b78a..cf25d22dd 100644
--- a/karapace/kafka/common.py
+++ b/karapace/kafka/common.py
@@ -71,9 +71,12 @@ def token_with_expiry(self, config: str | None) -> tuple[str, int | None]:

 class KafkaClientParams(TypedDict, total=False):
+    acks: int | None
     client_id: str | None
     connections_max_idle_ms: int | None
-    max_block_ms: int | None
+    compression_type: str | None
+    linger_ms: int | None
+    message_max_bytes: int | None
     metadata_max_age_ms: int | None
     retries: int | None
     sasl_mechanism: str | None
@@ -83,6 +86,7 @@ class KafkaClientParams(TypedDict, total=False):
     socket_timeout_ms: int | None
     ssl_cafile: str | None
     ssl_certfile: str | None
+    ssl_crlfile: str | None
     ssl_keyfile: str | None
     sasl_oauth_token_provider: TokenWithExpiryProvider
     # Consumer-only
@@ -121,8 +125,12 @@ def _get_config_from_params(self, bootstrap_servers: Iterable[str] | str, **para
         config: dict[str, int | str | Callable | None] = {
             "bootstrap.servers": bootstrap_servers,
+            "acks": params.get("acks"),
             "client.id": params.get("client_id"),
             "connections.max.idle.ms": params.get("connections_max_idle_ms"),
+            "compression.type": params.get("compression_type"),
+            "linger.ms": params.get("linger_ms"),
+            "message.max.bytes": params.get("message_max_bytes"),
             "metadata.max.age.ms": params.get("metadata_max_age_ms"),
             "retries": params.get("retries"),
             "sasl.mechanism": params.get("sasl_mechanism"),
@@ -132,6 +140,7 @@ def _get_config_from_params(self, bootstrap_servers: Iterable[str] | str, **para
             "socket.timeout.ms": params.get("socket_timeout_ms"),
             "ssl.ca.location": params.get("ssl_cafile"),
             "ssl.certificate.location": params.get("ssl_certfile"),
+            "ssl.crl.location": params.get("ssl_crlfile"),
             "ssl.key.location": params.get("ssl_keyfile"),
             "error_cb": self._error_callback,
             # Consumer-only
diff --git a/karapace/kafka/producer.py b/karapace/kafka/producer.py
index 2ee677c3b..68f728ed0 100644
--- a/karapace/kafka/producer.py
+++ b/karapace/kafka/producer.py
@@ -5,15 +5,18 @@

 from __future__ import annotations

+from collections.abc import Iterable
 from concurrent.futures import Future
 from confluent_kafka import Message, Producer
 from confluent_kafka.admin import PartitionMetadata
 from confluent_kafka.error import KafkaError, KafkaException
 from functools import partial
-from karapace.kafka.common import _KafkaConfigMixin, raise_from_kafkaexception, translate_from_kafkaerror
+from karapace.kafka.common import _KafkaConfigMixin, KafkaClientParams, raise_from_kafkaexception, translate_from_kafkaerror
+from threading import Event, Thread
 from typing import cast, TypedDict
 from typing_extensions import Unpack

+import asyncio
 import logging

 LOG = logging.getLogger(__name__)
@@ -59,3 +62,70 @@ def partitions_for(self, topic: str) -> dict[int, PartitionMetadata]:
             return self.list_topics(topic).topics[topic].partitions
         except KafkaException as exc:
             raise_from_kafkaexception(exc)
+
+
+class AsyncKafkaProducer:
+    """An async wrapper around `KafkaProducer` built on top of confluent-kafka.
+
+    Calling `start` on an `AsyncKafkaProducer` instantiates a `KafkaProducer`
+    and starts a poll-thread.
+
+    The poll-thread continuously polls the underlying producer so buffered messages
+    are sent and asyncio futures returned by the `send` method can be awaited.
+    """
+
+    def __init__(
+        self,
+        bootstrap_servers: Iterable[str] | str,
+        loop: asyncio.AbstractEventLoop | None = None,
+        **params: Unpack[KafkaClientParams],
+    ) -> None:
+        self.loop = loop or asyncio.get_running_loop()
+
+        self.stopped = Event()
+        self.poll_thread = Thread(target=self.poll_loop)
+
+        self.producer: KafkaProducer | None = None
+        self._bootstrap_servers = bootstrap_servers
+        self._producer_params = params
+
+    def _start(self) -> None:
+        assert not self.stopped.is_set(), "The async producer cannot be restarted"
+
+        self.producer = KafkaProducer(self._bootstrap_servers, **self._producer_params)
+        self.poll_thread.start()
+
+    async def start(self) -> None:
+        # The `KafkaProducer` instantiation tries to establish a connection with
+        # retries, thus it can block for a relatively long time. Running it in the
+        # default executor and awaiting it makes it async-compatible.
+        await self.loop.run_in_executor(None, self._start)
+
+    def _stop(self) -> None:
+        self.stopped.set()
+        if self.poll_thread.is_alive():
+            self.poll_thread.join()
+        self.producer = None
+
+    async def stop(self) -> None:
+        # Run all actions needed to stop in the default executor, since
+        # some of them can be blocking.
+        await self.loop.run_in_executor(None, self._stop)
+
+    def poll_loop(self) -> None:
+        """Target of the poll-thread."""
+        assert self.producer is not None, "The async producer must be started"
+
+        while not self.stopped.is_set():
+            # The call to `poll` is blocking, necessitating running this loop in its own thread.
+            # If there are messages to be sent, this loop sends them (equivalent to
+            # a `flush` call); otherwise it sleeps for the given timeout (seconds).
+ self.producer.poll(timeout=0.1) + + async def send(self, topic: str, **params: Unpack[ProducerSendParams]) -> asyncio.Future[Message]: + assert self.producer is not None, "The async producer must be started" + + return asyncio.wrap_future( + self.producer.send(topic, **params), + loop=self.loop, + ) diff --git a/karapace/kafka_rest_apis/__init__.py b/karapace/kafka_rest_apis/__init__.py index d16fcd2ab..3c8bfc67c 100644 --- a/karapace/kafka_rest_apis/__init__.py +++ b/karapace/kafka_rest_apis/__init__.py @@ -1,5 +1,3 @@ -from aiokafka import AIOKafkaProducer -from aiokafka.errors import KafkaConnectionError from binascii import Error as B64DecodeError from collections import namedtuple from confluent_kafka.error import KafkaException @@ -13,9 +11,10 @@ TopicAuthorizationFailedError, UnknownTopicOrPartitionError, ) -from karapace.config import Config, create_client_ssl_context +from karapace.config import Config from karapace.errors import InvalidSchema from karapace.kafka.admin import KafkaAdminClient +from karapace.kafka.producer import AsyncKafkaProducer from karapace.kafka_rest_apis.authentication import ( get_auth_config_from_header, get_expiration_time_from_header, @@ -36,7 +35,7 @@ SchemaRetrievalError, ) from karapace.typing import NameStrategy, SchemaId, Subject, SubjectType -from karapace.utils import convert_to_int, json_encode, KarapaceKafkaClient +from karapace.utils import convert_to_int, json_encode from typing import Callable, Dict, List, Optional, Tuple, Union import asyncio @@ -441,7 +440,7 @@ def __init__( self._auth_expiry = auth_expiry self._async_producer_lock = asyncio.Lock() - self._async_producer: Optional[AIOKafkaProducer] = None + self._async_producer: Optional[AsyncKafkaProducer] = None self.naming_strategy = NameStrategy(self.config["name_strategy"]) def __str__(self) -> str: @@ -461,12 +460,12 @@ def auth_expiry(self) -> datetime.datetime: def num_consumers(self) -> int: return len(self.consumer_manager.consumers) - async def _maybe_create_async_producer(self) -> AIOKafkaProducer: + async def _maybe_create_async_producer(self) -> AsyncKafkaProducer: if self._async_producer is not None: return self._async_producer if self.config["producer_acks"] == "all": - acks = "all" + acks = -1 else: acks = int(self.config["producer_acks"]) @@ -477,33 +476,34 @@ async def _maybe_create_async_producer(self) -> AIOKafkaProducer: log.info("Creating async producer") - # Don't retry if creating the SSL context fails, likely a configuration issue with - # ciphers or certificate chains - ssl_context = create_client_ssl_context(self.config) - - # Don't retry if instantiating the producer fails, likely a configuration error. 
- producer = AIOKafkaProducer( + producer = AsyncKafkaProducer( acks=acks, bootstrap_servers=self.config["bootstrap_uri"], compression_type=self.config["producer_compression_type"], connections_max_idle_ms=self.config["connections_max_idle_ms"], linger_ms=self.config["producer_linger_ms"], - max_request_size=self.config["producer_max_request_size"], + message_max_bytes=self.config["producer_max_request_size"], metadata_max_age_ms=self.config["metadata_max_age_ms"], security_protocol=self.config["security_protocol"], - ssl_context=ssl_context, + ssl_cafile=self.config["ssl_cafile"], + ssl_certfile=self.config["ssl_certfile"], + ssl_keyfile=self.config["ssl_keyfile"], + ssl_crlfile=self.config["ssl_crlfile"], **get_kafka_client_auth_parameters_from_config(self.config), ) - try: await producer.start() - except KafkaConnectionError: + except (NoBrokersAvailable, AuthenticationFailedError): + await producer.stop() if retry: log.exception("Unable to connect to the bootstrap servers, retrying") else: log.exception("Giving up after trying to connect to the bootstrap servers") raise await asyncio.sleep(1) + except Exception: + await producer.stop() + raise else: self._async_producer = producer @@ -645,10 +645,8 @@ def init_admin_client(self): ssl_cafile=self.config["ssl_cafile"], ssl_certfile=self.config["ssl_certfile"], ssl_keyfile=self.config["ssl_keyfile"], - api_version=(1, 0, 0), metadata_max_age_ms=self.config["metadata_max_age_ms"], connections_max_idle_ms=self.config["connections_max_idle_ms"], - kafka_client=KarapaceKafkaClient, **get_kafka_client_auth_parameters_from_config(self.config, async_client=False), ) break @@ -1069,8 +1067,11 @@ async def produce_messages(self, *, topic: str, prepared_records: List) -> List: if not isinstance(result, Exception): produce_results.append( { - "offset": result.offset if result else -1, - "partition": result.topic_partition.partition if result else 0, + # In case the offset is not available, `confluent_kafka.Message.offset()` is + # `None`. To preserve backwards compatibility, we replace this with -1. + # -1 was the default `aiokafka` behaviour. 
+                        "offset": result.offset() if result and result.offset() is not None else -1,
+                        "partition": result.partition() if result else 0,
                     }
                 )
diff --git a/karapace/kafka_rest_apis/authentication.py b/karapace/kafka_rest_apis/authentication.py
index 19df87239..7550ceec1 100644
--- a/karapace/kafka_rest_apis/authentication.py
+++ b/karapace/kafka_rest_apis/authentication.py
@@ -142,6 +142,9 @@ class SimpleOauthTokenProviderAsync(AbstractTokenProviderAsync):
     async def token(self) -> str:
         return self._token

+    def token_with_expiry(self, _config: str | None = None) -> tuple[str, int | None]:
+        return (self._token, get_expiration_timestamp_from_jwt(self._token))
+

 class SASLOauthParams(TypedDict):
     sasl_mechanism: str
diff --git a/tests/integration/kafka/test_producer.py b/tests/integration/kafka/test_producer.py
index d82c6cce5..30f18eb7f 100644
--- a/tests/integration/kafka/test_producer.py
+++ b/tests/integration/kafka/test_producer.py
@@ -7,9 +7,12 @@

 from confluent_kafka.admin import NewTopic
 from kafka.errors import MessageSizeTooLargeError, UnknownTopicOrPartitionError
-from karapace.kafka.producer import KafkaProducer
+from karapace.kafka.producer import AsyncKafkaProducer, KafkaProducer
 from karapace.kafka.types import Timestamp
+from tests.integration.utils.kafka_server import KafkaServers
+from typing import Iterator

+import asyncio
 import pytest
 import time
@@ -71,3 +74,69 @@ def test_partitions_for(self, producer: KafkaProducer, new_topic: NewTopic) -> N
         assert partitions[0].id == 0
         assert partitions[0].replicas == [1]
         assert partitions[0].isrs == [1]
+
+
+@pytest.fixture(scope="function", name="asyncproducer")
+async def fixture_asyncproducer(
+    kafka_servers: KafkaServers,
+    loop: asyncio.AbstractEventLoop,
+) -> Iterator[AsyncKafkaProducer]:
+    asyncproducer = AsyncKafkaProducer(bootstrap_servers=kafka_servers.bootstrap_servers, loop=loop)
+    await asyncproducer.start()
+    yield asyncproducer
+    await asyncproducer.stop()
+
+
+class TestAsyncSend:
+    async def test_async_send(self, asyncproducer: AsyncKafkaProducer, new_topic: NewTopic) -> None:
+        key = b"key"
+        value = b"value"
+        partition = 0
+        timestamp = int(time.time() * 1000)
+        headers = [("something", b"123"), (None, "foobar")]
+
+        aiofut = await asyncproducer.send(
+            new_topic.topic,
+            key=key,
+            value=value,
+            partition=partition,
+            timestamp=timestamp,
+            headers=headers,
+        )
+        message = await aiofut
+
+        assert message.offset() == 0
+        assert message.partition() == partition
+        assert message.topic() == new_topic.topic
+        assert message.key() == key
+        assert message.value() == value
+        assert message.timestamp()[0] == Timestamp.CREATE_TIME
+        assert message.timestamp()[1] == timestamp
+
+    async def test_async_send_raises_for_unknown_topic(self, asyncproducer: AsyncKafkaProducer) -> None:
+        aiofut = await asyncproducer.send("nonexistent")
+
+        with pytest.raises(UnknownTopicOrPartitionError):
+            _ = await aiofut
+
+    async def test_async_send_raises_for_unknown_partition(
+        self, asyncproducer: AsyncKafkaProducer, new_topic: NewTopic
+    ) -> None:
+        aiofut = await asyncproducer.send(new_topic.topic, partition=99)
+
+        with pytest.raises(UnknownTopicOrPartitionError):
+            _ = await aiofut
+
+    async def test_async_send_raises_for_too_large_message(
+        self, asyncproducer: AsyncKafkaProducer, new_topic: NewTopic
+    ) -> None:
+        with pytest.raises(MessageSizeTooLargeError):
+            await asyncproducer.send(new_topic.topic, value=b"x" * 1000001)
+
+
+async def test_stop_unstarted_async_producer() -> None:
+    asyncproducer = AsyncKafkaProducer(bootstrap_servers="irrelevant")
+    try:
+        await asyncproducer.stop()
+    except RuntimeError:
+        pytest.fail("Stopping raised RuntimeError, unstarted poll thread is not handled")
diff --git a/tests/integration/test_rest.py b/tests/integration/test_rest.py
index d2287f05f..97466836d 100644
--- a/tests/integration/test_rest.py
+++ b/tests/integration/test_rest.py
@@ -227,9 +227,9 @@ async def test_internal(rest_async: KafkaRest | None, admin_client: KafkaAdminCl
         assert len(results) == 1
         for result in results:
             assert "error" in result, "Invalid result missing 'error' key"
-            assert result["error"] == "Unrecognized partition"
+            assert result["error"] == "This request is for a topic or partition that does not exist on this broker."
             assert "error_code" in result, "Invalid result missing 'error_code' key"
-            assert result["error_code"] == 1
+            assert result["error_code"] == 2
     assert rest_async_proxy.all_empty({"records": [{"key": {"foo": "bar"}}]}, "key") is False
     assert rest_async_proxy.all_empty({"records": [{"value": {"foo": "bar"}}]}, "value") is False
diff --git a/tests/unit/test_authentication.py b/tests/unit/test_authentication.py
index 0c6a874b5..43617b6c9 100644
--- a/tests/unit/test_authentication.py
+++ b/tests/unit/test_authentication.py
@@ -120,6 +120,14 @@ async def test_simple_oauth_token_provider_async_returns_configured_token() -> N
     assert await token_provider.token() == "TOKEN"


+def test_simple_oauth_token_provider_async_returns_configured_token_and_expiry() -> None:
+    expiry_timestamp = 1697013997
+    token = jwt.encode({"exp": expiry_timestamp}, "secret")
+    token_provider = SimpleOauthTokenProviderAsync(token)
+
+    assert token_provider.token_with_expiry() == (token, expiry_timestamp)
+
+
 def test_get_client_auth_parameters_from_config_sasl_plain() -> None:
     config = set_config_defaults(
         {"sasl_mechanism": "PLAIN", "sasl_plain_username": "username", "sasl_plain_password": "password"}

From a51027dbc60e09212efb0743c49978bd90ac87f8 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?M=C3=A1ty=C3=A1s=20Kuti?=
Date: Thu, 14 Dec 2023 09:59:27 +0100
Subject: [PATCH 2/2] Improve REST app shutdown

REST webapp improvement: `on_cleanup` is not the correct hook to use
here, as it runs _after_ the event loop has closed, making it
unsuitable for cancelling background tasks, for example. `on_shutdown`
is triggered before the REST app shuts down, so it can properly clean
up e.g. Kafka clients and background tasks.

Before this change, the symptom of the bug is most prevalent in the
Karapace REST proxy and its "idle proxy janitor" background task.
Stopping the application when the janitor task is not running is
straightforward; however, when any `UserRestProxy` is present (i.e.
some requests have already been handled) and the task is running,
stopping the REST proxy hangs or needs multiple signals to shut down.
With the new `AsyncKafkaProducer` implementation (which runs a
poll-thread in the background), this results in an application that is
unable to shut down gracefully; only SIGKILL works.

Using the `on_shutdown` hook fixes this issue, as the event loop is
still available for cancelling background tasks, etc.
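For illustration, a minimal aiohttp sketch of the hook registration
this change relies on (the `close_clients` handler and the
`janitor_task` key are made-up names, not part of this patch):

    import asyncio
    from aiohttp import web

    async def close_clients(app: web.Application) -> None:
        # on_shutdown handlers run while the event loop is still
        # serving, so background tasks can be cancelled and awaited.
        janitor: asyncio.Task = app["janitor_task"]
        janitor.cancel()

    app = web.Application()
    # Registering on on_shutdown (rather than on_cleanup) ensures this
    # runs early enough in aiohttp's teardown sequence.
    app.on_shutdown.append(close_clients)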
--- karapace/kafka_rest_apis/__init__.py | 1 + karapace/rapu.py | 2 +- tests/integration/kafka/test_producer.py | 8 -------- 3 files changed, 2 insertions(+), 9 deletions(-) diff --git a/karapace/kafka_rest_apis/__init__.py b/karapace/kafka_rest_apis/__init__.py index 3c8bfc67c..54929bdf1 100644 --- a/karapace/kafka_rest_apis/__init__.py +++ b/karapace/kafka_rest_apis/__init__.py @@ -72,6 +72,7 @@ def __init__(self, config: Config) -> None: self._idle_proxy_janitor_task: Optional[asyncio.Task] = None async def close(self) -> None: + log.info("Closing REST proxy application") if self._idle_proxy_janitor_task is not None: self._idle_proxy_janitor_task.cancel() self._idle_proxy_janitor_task = None diff --git a/karapace/rapu.py b/karapace/rapu.py index 2b9decf12..b02a29fb1 100644 --- a/karapace/rapu.py +++ b/karapace/rapu.py @@ -167,7 +167,7 @@ def __init__( self.app = self._create_aiohttp_application(config=config) self.log = logging.getLogger(self.app_name) self.stats = StatsClient(config=config) - self.app.on_cleanup.append(self.close_by_app) + self.app.on_shutdown.append(self.close_by_app) self.not_ready_handler = not_ready_handler def _create_aiohttp_application(self, *, config: Config) -> aiohttp.web.Application: diff --git a/tests/integration/kafka/test_producer.py b/tests/integration/kafka/test_producer.py index 30f18eb7f..59c0da0ef 100644 --- a/tests/integration/kafka/test_producer.py +++ b/tests/integration/kafka/test_producer.py @@ -132,11 +132,3 @@ async def test_async_send_raises_for_too_large_message( ) -> None: with pytest.raises(MessageSizeTooLargeError): await asyncproducer.send(new_topic.topic, value=b"x" * 1000001) - - -async def test_stop_unstarted_async_producer() -> None: - asyncproducer = AsyncKafkaProducer(bootstrap_servers="irrelevant") - try: - await asyncproducer.stop() - except RuntimeError: - pytest.fail("Stopping raised RuntimeError, unstarted poll thread is not handled")
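Note: the poll-thread technique from PATCH 1/2 can be reduced to a
small, self-contained sketch following the confluent-kafka asyncio
example linked in the commit message. Everything below is illustrative
only; the `MinimalAsyncProducer` name and its methods are made up here
and are not part of the patches:

    import asyncio
    from concurrent.futures import Future
    from threading import Event, Thread

    from confluent_kafka import KafkaException, Message, Producer


    class MinimalAsyncProducer:
        """Illustrative poll-thread + `asyncio.wrap_future` pattern."""

        def __init__(self, bootstrap_servers: str) -> None:
            self.producer = Producer({"bootstrap.servers": bootstrap_servers})
            self.stopped = Event()
            # Delivery callbacks only fire from `poll`, so a background
            # thread keeps serving them continuously.
            self.poll_thread = Thread(target=self._poll_loop, daemon=True)
            self.poll_thread.start()

        def _poll_loop(self) -> None:
            while not self.stopped.is_set():
                # Blocks for up to 0.1s and fires any pending callbacks.
                self.producer.poll(0.1)

        def send(self, topic: str, value: bytes) -> "asyncio.Future[Message]":
            result: "Future[Message]" = Future()

            def on_delivery(err, msg) -> None:
                # Runs on the poll-thread once the broker acks the
                # message or the send fails.
                if err is not None:
                    result.set_exception(KafkaException(err))
                else:
                    result.set_result(msg)

            # `produce` is non-blocking: it only enqueues the message.
            self.producer.produce(topic, value=value, on_delivery=on_delivery)
            # Bridge the thread-safe future into the asyncio event loop.
            return asyncio.wrap_future(result)

        def stop(self) -> None:
            self.stopped.set()
            self.poll_thread.join()
            self.producer.flush()


    # Usage from a coroutine:
    #     producer = MinimalAsyncProducer("localhost:9092")
    #     message = await producer.send("some-topic", b"value")
    #     print(message.offset())

The key point is that `Producer.produce` only enqueues the message and
delivery callbacks fire from `poll`, so a dedicated poll-thread plus
`asyncio.wrap_future` bridges deliveries into the event loop without
blocking it, which mirrors the structure of `AsyncKafkaProducer` above.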