Implement async confluent-kafka Kafka producer
Mátyás Kuti committed Dec 6, 2023
1 parent 0039127 commit 334035e
Showing 7 changed files with 154 additions and 20 deletions.
11 changes: 10 additions & 1 deletion karapace/kafka/common.py
@@ -71,9 +71,12 @@ def token_with_expiry(self, config: str | None) -> tuple[str, int | None]:


class KafkaClientParams(TypedDict, total=False):
acks: int | Literal["all"] | None
client_id: str | None
connections_max_idle_ms: int | None
max_block_ms: int | None
compression_type: str | None
linger_ms: int | None
message_max_bytes: int | None
metadata_max_age_ms: int | None
retries: int | None
sasl_mechanism: str | None
@@ -83,6 +86,7 @@ class KafkaClientParams(TypedDict, total=False):
socket_timeout_ms: int | None
ssl_cafile: str | None
ssl_certfile: str | None
ssl_crlfile: str | None
ssl_keyfile: str | None
sasl_oauth_token_provider: TokenWithExpiryProvider
# Consumer-only
@@ -121,8 +125,12 @@ def _get_config_from_params(self, bootstrap_servers: Iterable[str] | str, **para

config: dict[str, int | str | Callable | None] = {
"bootstrap.servers": bootstrap_servers,
"acks": params.get("acks"),
"client.id": params.get("client_id"),
"connections.max.idle.ms": params.get("connections_max_idle_ms"),
"compression.type": params.get("compression_type"),
"linger.ms": params.get("linger_ms"),
"message.max.bytes": params.get("message_max_bytes"),
"metadata.max.age.ms": params.get("metadata_max_age_ms"),
"retries": params.get("retries"),
"sasl.mechanism": params.get("sasl_mechanism"),
@@ -132,6 +140,7 @@
"socket.timeout.ms": params.get("socket_timeout_ms"),
"ssl.ca.location": params.get("ssl_cafile"),
"ssl.certificate.location": params.get("ssl_certfile"),
"ssl.crl.location": params.get("ssl_crlfile"),
"ssl.key.location": params.get("ssl_keyfile"),
"error_cb": self._error_callback,
# Consumer-only
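
For illustration, the mapping above turns the snake_case client parameters into librdkafka's dotted configuration keys. A minimal sketch of the idea (the example values are made up, and the None-filtering step is an assumption; the hunk above does not show how unset options are handled):

params = {"acks": "all", "compression_type": "gzip", "linger_ms": 5}
config = {
    "acks": params.get("acks"),                          # int, "all", or None
    "compression.type": params.get("compression_type"),  # e.g. "gzip", "snappy", "lz4"
    "linger.ms": params.get("linger_ms"),                # batching delay in milliseconds
}
# Assumption: None entries are dropped so librdkafka falls back to its own defaults.
config = {key: value for key, value in config.items() if value is not None}
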
55 changes: 54 additions & 1 deletion karapace/kafka/producer.py
@@ -5,14 +5,17 @@

from __future__ import annotations

from collections.abc import Iterable
from concurrent.futures import Future
from confluent_kafka import Producer
from functools import partial
from karapace.kafka.common import _KafkaConfigMixin, raise_from_kafkaexception, translate_from_kafkaerror
from karapace.kafka.common import _KafkaConfigMixin, KafkaClientParams, raise_from_kafkaexception, translate_from_kafkaerror
from karapace.kafka.types import KafkaError, KafkaException, Message, PartitionMetadata
from threading import Thread
from typing import cast, TypedDict
from typing_extensions import Unpack

import asyncio
import logging

LOG = logging.getLogger(__name__)
@@ -58,3 +61,53 @@ def partitions_for(self, topic: str) -> dict[int, PartitionMetadata]:
return self.list_topics(topic).topics[topic].partitions
except KafkaException as exc:
raise_from_kafkaexception(exc)


class AIOKafkaProducer:
"""An async wrapper around `KafkaProducer` built on top of confluent-kafka.
Starting `AIOKafkaProducer` instantiates a `KafkaProducer` and starts a poll-thread.
The poll-thread continuously polls the underlying producer so buffered messages
are sent and asyncio futures returned by the `send` method can be awaited.
"""

def __init__(
self,
bootstrap_servers: Iterable[str] | str,
loop: asyncio.AbstractEventLoop | None = None,
**params: Unpack[KafkaClientParams],
) -> None:
self.loop = loop or asyncio.get_event_loop()

self.producer: KafkaProducer | None = None
self._bootstrap_servers = bootstrap_servers
self._producer_params = params

self.stopped = False

self.poll_thread = Thread(target=self.poll_loop)

async def start(self) -> None:
self.producer = KafkaProducer(self._bootstrap_servers, **self._producer_params)
self.poll_thread.start()

async def stop(self) -> None:
self.stopped = True
self.poll_thread.join()
self.producer = None

def poll_loop(self) -> None:
"""Target of the poll-thread."""
assert self.producer is not None

while not self.stopped:
self.producer.poll(0.1)

async def send(self, topic: str, **params: Unpack[ProducerSendParams]) -> asyncio.Future[Message]:
assert self.producer is not None

return asyncio.wrap_future(
self.producer.send(topic, **params),
loop=self.loop,
)
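
For context, a minimal usage sketch of the new AIOKafkaProducer (broker address and topic name are placeholders):

import asyncio

from karapace.kafka.producer import AIOKafkaProducer


async def main() -> None:
    producer = AIOKafkaProducer(bootstrap_servers="localhost:9092")
    await producer.start()  # creates the underlying KafkaProducer and starts the poll-thread
    try:
        # send() returns an asyncio future almost immediately; awaiting that
        # future blocks until the broker confirms (or rejects) delivery.
        fut = await producer.send("test-topic", key=b"key", value=b"value")
        message = await fut
        print(message.partition(), message.offset())
    finally:
        await producer.stop()  # stops the poll loop and joins the poll-thread


asyncio.run(main())
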
26 changes: 11 additions & 15 deletions karapace/kafka_rest_apis/__init__.py
@@ -1,5 +1,3 @@
from aiokafka import AIOKafkaProducer
from aiokafka.errors import KafkaConnectionError
from binascii import Error as B64DecodeError
from collections import namedtuple
from contextlib import AsyncExitStack
@@ -12,9 +10,10 @@
TopicAuthorizationFailedError,
UnknownTopicOrPartitionError,
)
from karapace.config import Config, create_client_ssl_context
from karapace.config import Config
from karapace.errors import InvalidSchema
from karapace.kafka.admin import KafkaAdminClient
from karapace.kafka.producer import AIOKafkaProducer
from karapace.kafka.types import KafkaException
from karapace.kafka_rest_apis.authentication import (
get_auth_config_from_header,
@@ -36,7 +35,7 @@
SchemaRetrievalError,
)
from karapace.typing import NameStrategy, SchemaId, Subject, SubjectType
from karapace.utils import convert_to_int, json_encode, KarapaceKafkaClient
from karapace.utils import convert_to_int, json_encode
from typing import Callable, Dict, List, Optional, Tuple, Union

import asyncio
@@ -477,27 +476,26 @@ async def _maybe_create_async_producer(self) -> AIOKafkaProducer:

log.info("Creating async producer")

# Don't retry if creating the SSL context fails, likely a configuration issue with
# ciphers or certificate chains
ssl_context = create_client_ssl_context(self.config)

# Don't retry if instantiating the producer fails, likely a configuration error.
producer = AIOKafkaProducer(
acks=acks,
bootstrap_servers=self.config["bootstrap_uri"],
compression_type=self.config["producer_compression_type"],
connections_max_idle_ms=self.config["connections_max_idle_ms"],
linger_ms=self.config["producer_linger_ms"],
max_request_size=self.config["producer_max_request_size"],
message_max_bytes=self.config["producer_max_request_size"],
metadata_max_age_ms=self.config["metadata_max_age_ms"],
security_protocol=self.config["security_protocol"],
ssl_context=ssl_context,
ssl_cafile=self.config["ssl_cafile"],
ssl_certfile=self.config["ssl_certfile"],
ssl_keyfile=self.config["ssl_keyfile"],
ssl_crlfile=self.config["ssl_crlfile"],
**get_kafka_client_auth_parameters_from_config(self.config),
)

try:
await producer.start()
except KafkaConnectionError:
except (NoBrokersAvailable, AuthenticationFailedError):
if retry:
log.exception("Unable to connect to the bootstrap servers, retrying")
else:
@@ -645,10 +643,8 @@ def init_admin_client(self):
ssl_cafile=self.config["ssl_cafile"],
ssl_certfile=self.config["ssl_certfile"],
ssl_keyfile=self.config["ssl_keyfile"],
api_version=(1, 0, 0),
metadata_max_age_ms=self.config["metadata_max_age_ms"],
connections_max_idle_ms=self.config["connections_max_idle_ms"],
kafka_client=KarapaceKafkaClient,
**get_kafka_client_auth_parameters_from_config(self.config, async_client=False),
)
break
@@ -1069,8 +1065,8 @@ async def produce_messages(self, *, topic: str, prepared_records: List) -> List:
if not isinstance(result, Exception):
produce_results.append(
{
"offset": result.offset if result else -1,
"partition": result.topic_partition.partition if result else 0,
"offset": result.offset() if result else -1,
"partition": result.partition() if result else 0,
}
)

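
Worth noting about the produce-result change above: aiokafka resolved send futures to a RecordMetadata with plain attributes, whereas confluent-kafka's Message exposes accessors as methods, hence the switch to result.offset() and result.partition(). A sketch distilling the change into a hypothetical helper (to_produce_result is not part of the diff):

from typing import Optional

from karapace.kafka.types import Message


def to_produce_result(result: Optional[Message]) -> dict:
    # confluent-kafka Message accessors are methods, not attributes
    # (aiokafka's RecordMetadata used result.offset and result.topic_partition.partition).
    return {
        "offset": result.offset() if result else -1,
        "partition": result.partition() if result else 0,
    }
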
3 changes: 3 additions & 0 deletions karapace/kafka_rest_apis/authentication.py
@@ -142,6 +142,9 @@ class SimpleOauthTokenProviderAsync(AbstractTokenProviderAsync):
async def token(self) -> str:
return self._token

def token_with_expiry(self, _config: str | None = None) -> tuple[str, int | None]:
return (self._token, get_expiration_timestamp_from_jwt(self._token))


class SASLOauthParams(TypedDict):
sasl_mechanism: str
67 changes: 66 additions & 1 deletion tests/integration/kafka/test_producer.py
@@ -6,9 +6,12 @@
from __future__ import annotations

from kafka.errors import MessageSizeTooLargeError, UnknownTopicOrPartitionError
from karapace.kafka.producer import KafkaProducer
from karapace.kafka.producer import AIOKafkaProducer, KafkaProducer
from karapace.kafka.types import NewTopic, Timestamp
from tests.integration.utils.kafka_server import KafkaServers
from typing import Iterator

import asyncio
import pytest
import time

@@ -70,3 +73,65 @@ def test_partitions_for(self, producer: KafkaProducer, new_topic: NewTopic) -> N
assert partitions[0].id == 0
assert partitions[0].replicas == [1]
assert partitions[0].isrs == [1]


@pytest.fixture(scope="function", name="aioproducer")
async def fixture_aioproducer(
kafka_servers: KafkaServers,
loop: asyncio.AbstractEventLoop,
) -> Iterator[AIOKafkaProducer]:
try:
aioproducer = AIOKafkaProducer(bootstrap_servers=kafka_servers.bootstrap_servers, loop=loop)
await aioproducer.start()
yield aioproducer
finally:
await aioproducer.stop()


class TestAsyncSend:
async def test_async_send(self, aioproducer: AIOKafkaProducer, new_topic: NewTopic) -> None:
key = b"key"
value = b"value"
partition = 0
timestamp = int(time.time() * 1000)
headers = [("something", b"123")]

aiofut = await aioproducer.send(
new_topic.topic,
key=key,
value=value,
partition=partition,
timestamp=timestamp,
headers=headers,
)
message = await aiofut

assert message.offset() == 0
assert message.partition() == partition
assert message.topic() == new_topic.topic
assert message.key() == key
assert message.value() == value
assert message.timestamp()[0] == Timestamp.CREATE_TIME
assert message.timestamp()[1] == timestamp

async def test_async_send_raises_for_unknown_topic(self, aioproducer: AIOKafkaProducer) -> None:
aiofut = await aioproducer.send("nonexistent")

with pytest.raises(UnknownTopicOrPartitionError):
_ = await aiofut

async def test_async_send_raises_for_unknown_partition(self, aioproducer: AIOKafkaProducer, new_topic: NewTopic) -> None:
aiofut = await aioproducer.send(new_topic.topic, partition=99)

with pytest.raises(UnknownTopicOrPartitionError):
_ = await aiofut

async def test_async_send_raises_for_too_large_message(self, aioproducer: AIOKafkaProducer, new_topic: NewTopic) -> None:
with pytest.raises(MessageSizeTooLargeError):
await aioproducer.send(new_topic.topic, value=b"x" * 1000001)

async def test_send_cannot_be_called_on_unstarted_aioproducer(self) -> None:
aioproducer = AIOKafkaProducer(bootstrap_servers="irrelevant")

with pytest.raises(AssertionError):
await aioproducer.send("irrelevant")
4 changes: 2 additions & 2 deletions tests/integration/test_rest.py
@@ -227,9 +227,9 @@ async def test_internal(rest_async: KafkaRest | None, admin_client: KafkaAdminCl
assert len(results) == 1
for result in results:
assert "error" in result, "Invalid result missing 'error' key"
assert result["error"] == "Unrecognized partition"
assert result["error"] == "This request is for a topic or partition that does not exist on this broker."
assert "error_code" in result, "Invalid result missing 'error_code' key"
assert result["error_code"] == 1
assert result["error_code"] == 2

assert rest_async_proxy.all_empty({"records": [{"key": {"foo": "bar"}}]}, "key") is False
assert rest_async_proxy.all_empty({"records": [{"value": {"foo": "bar"}}]}, "value") is False
8 changes: 8 additions & 0 deletions tests/unit/test_authentication.py
@@ -120,6 +120,14 @@ async def test_simple_oauth_token_provider_async_returns_configured_token() -> N
assert await token_provider.token() == "TOKEN"


def test_simple_oauth_token_provider_async_returns_configured_token_and_expiry() -> None:
expiry_timestamp = 1697013997
token = jwt.encode({"exp": expiry_timestamp}, "secret")
token_provider = SimpleOauthTokenProviderAsync(token)

assert token_provider.token_with_expiry() == (token, expiry_timestamp)


def test_get_client_auth_parameters_from_config_sasl_plain() -> None:
config = set_config_defaults(
{"sasl_mechanism": "PLAIN", "sasl_plain_username": "username", "sasl_plain_password": "password"}