Implement async confluent-kafka producer #775

Merged (2 commits) on Jan 22, 2024
11 changes: 10 additions & 1 deletion karapace/kafka/common.py
@@ -71,9 +71,12 @@ def token_with_expiry(self, config: str | None) -> tuple[str, int | None]:


class KafkaClientParams(TypedDict, total=False):
+ acks: int | None
client_id: str | None
connections_max_idle_ms: int | None
- max_block_ms: int | None
+ compression_type: str | None
+ linger_ms: int | None
+ message_max_bytes: int | None
metadata_max_age_ms: int | None
retries: int | None
sasl_mechanism: str | None
@@ -83,6 +86,7 @@ class KafkaClientParams(TypedDict, total=False):
socket_timeout_ms: int | None
ssl_cafile: str | None
ssl_certfile: str | None
+ ssl_crlfile: str | None
ssl_keyfile: str | None
sasl_oauth_token_provider: TokenWithExpiryProvider
# Consumer-only
@@ -121,8 +125,12 @@ def _get_config_from_params(self, bootstrap_servers: Iterable[str] | str, **para

config: dict[str, int | str | Callable | None] = {
"bootstrap.servers": bootstrap_servers,
"acks": params.get("acks"),
"client.id": params.get("client_id"),
"connections.max.idle.ms": params.get("connections_max_idle_ms"),
"compression.type": params.get("compression_type"),
"linger.ms": params.get("linger_ms"),
"message.max.bytes": params.get("message_max_bytes"),
"metadata.max.age.ms": params.get("metadata_max_age_ms"),
"retries": params.get("retries"),
"sasl.mechanism": params.get("sasl_mechanism"),
@@ -132,6 +140,7 @@ def _get_config_from_params(self, bootstrap_servers: Iterable[str] | str, **para
"socket.timeout.ms": params.get("socket_timeout_ms"),
"ssl.ca.location": params.get("ssl_cafile"),
"ssl.certificate.location": params.get("ssl_certfile"),
"ssl.crl.location": params.get("ssl_crlfile"),
"ssl.key.location": params.get("ssl_keyfile"),
"error_cb": self._error_callback,
# Consumer-only
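
The pattern above translates Karapace's snake_case client parameters into librdkafka's dotted configuration keys. A minimal sketch of that translation, assuming (as the surrounding code suggests) that unset `None` values are filtered out before the config reaches confluent-kafka; `to_confluent_config` and `SPECIAL_KEYS` are illustrative names, not Karapace helpers:

from typing import Any

# The ssl_* file options do not follow the underscore-to-dot rewrite and map
# to librdkafka's ssl.*.location keys instead.
SPECIAL_KEYS = {
    "ssl_cafile": "ssl.ca.location",
    "ssl_certfile": "ssl.certificate.location",
    "ssl_crlfile": "ssl.crl.location",
    "ssl_keyfile": "ssl.key.location",
}

def to_confluent_config(**params: Any) -> dict[str, Any]:
    config = {SPECIAL_KEYS.get(key, key.replace("_", ".")): value for key, value in params.items()}
    # Drop unset options so confluent-kafka only sees explicit values.
    return {key: value for key, value in config.items() if value is not None}

# to_confluent_config(client_id="karapace", linger_ms=100, ssl_cafile=None)
# -> {"client.id": "karapace", "linger.ms": 100}
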
72 changes: 71 additions & 1 deletion karapace/kafka/producer.py
@@ -5,15 +5,18 @@

from __future__ import annotations

from collections.abc import Iterable
from concurrent.futures import Future
from confluent_kafka import Message, Producer
from confluent_kafka.admin import PartitionMetadata
from confluent_kafka.error import KafkaError, KafkaException
from functools import partial
- from karapace.kafka.common import _KafkaConfigMixin, raise_from_kafkaexception, translate_from_kafkaerror
+ from karapace.kafka.common import _KafkaConfigMixin, KafkaClientParams, raise_from_kafkaexception, translate_from_kafkaerror
from threading import Event, Thread
from typing import cast, TypedDict
from typing_extensions import Unpack

import asyncio
import logging

LOG = logging.getLogger(__name__)
@@ -59,3 +62,70 @@ def partitions_for(self, topic: str) -> dict[int, PartitionMetadata]:
return self.list_topics(topic).topics[topic].partitions
except KafkaException as exc:
raise_from_kafkaexception(exc)


class AsyncKafkaProducer:
"""An async wrapper around `KafkaProducer` built on top of confluent-kafka.

Calling `start` on an `AsyncKafkaProducer` instantiates a `KafkaProducer`
and starts a poll-thread.

The poll-thread continuously polls the underlying producer so buffered messages
are sent and asyncio futures returned by the `send` method can be awaited.
"""

def __init__(
self,
bootstrap_servers: Iterable[str] | str,
loop: asyncio.AbstractEventLoop | None = None,
**params: Unpack[KafkaClientParams],
) -> None:
self.loop = loop or asyncio.get_running_loop()

self.stopped = Event()
self.poll_thread = Thread(target=self.poll_loop)

self.producer: KafkaProducer | None = None
self._bootstrap_servers = bootstrap_servers
self._producer_params = params

def _start(self) -> None:
assert not self.stopped.is_set(), "The async producer cannot be restarted"

self.producer = KafkaProducer(self._bootstrap_servers, **self._producer_params)
self.poll_thread.start()

async def start(self) -> None:
# The `KafkaProducer` instantiation tries to establish a connection with
# retries, and thus can block for a relatively long time. Running it in the
# default executor and awaiting the result keeps this method async-compatible.
await self.loop.run_in_executor(None, self._start)

def _stop(self) -> None:
self.stopped.set()
if self.poll_thread.is_alive():
self.poll_thread.join()
self.producer = None

async def stop(self) -> None:
# Run all actions needed to stop in the default executor, since some
# of them can block.
await self.loop.run_in_executor(None, self._stop)

def poll_loop(self) -> None:
"""Target of the poll-thread."""
assert self.producer is not None, "The async producer must be started"

while not self.stopped.is_set():
# The call to `poll` is blocking, necessitating running this loop in its own thread.
# If there are messages to be sent, this loop sends them (equivalent to a
# `flush` call); otherwise it sleeps for the given timeout (in seconds).
self.producer.poll(timeout=0.1)

async def send(self, topic: str, **params: Unpack[ProducerSendParams]) -> asyncio.Future[Message]:
assert self.producer is not None, "The async producer must be started"

return asyncio.wrap_future(
self.producer.send(topic, **params),
loop=self.loop,
)
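
Taken together: `start` instantiates the blocking producer off the event loop and spins up the poll-thread, `send` returns an asyncio future, and `stop` joins the thread. A minimal usage sketch, assuming a reachable broker at localhost:9092 and an existing topic named "example" (both illustrative):

import asyncio

from karapace.kafka.producer import AsyncKafkaProducer

async def main() -> None:
    producer = AsyncKafkaProducer(bootstrap_servers="localhost:9092")
    await producer.start()
    try:
        # `send` returns an asyncio future that resolves to the delivered
        # `confluent_kafka.Message` once the poll-thread has flushed the
        # internal buffer and the delivery callback has fired.
        future = await producer.send("example", key=b"key", value=b"value")
        message = await future
        print(message.topic(), message.partition(), message.offset())
    finally:
        await producer.stop()

asyncio.run(main())
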
44 changes: 23 additions & 21 deletions karapace/kafka_rest_apis/__init__.py
@@ -1,5 +1,3 @@
- from aiokafka import AIOKafkaProducer
- from aiokafka.errors import KafkaConnectionError
from binascii import Error as B64DecodeError
from collections import namedtuple
from confluent_kafka.error import KafkaException
@@ -13,9 +11,10 @@
TopicAuthorizationFailedError,
UnknownTopicOrPartitionError,
)
- from karapace.config import Config, create_client_ssl_context
+ from karapace.config import Config
from karapace.errors import InvalidSchema
from karapace.kafka.admin import KafkaAdminClient
from karapace.kafka.producer import AsyncKafkaProducer
from karapace.kafka_rest_apis.authentication import (
get_auth_config_from_header,
get_expiration_time_from_header,
@@ -36,7 +35,7 @@
SchemaRetrievalError,
)
from karapace.typing import NameStrategy, SchemaId, Subject, SubjectType
- from karapace.utils import convert_to_int, json_encode, KarapaceKafkaClient
+ from karapace.utils import convert_to_int, json_encode
from typing import Callable, Dict, List, Optional, Tuple, Union

import asyncio
@@ -73,6 +72,7 @@ def __init__(self, config: Config) -> None:
self._idle_proxy_janitor_task: Optional[asyncio.Task] = None

async def close(self) -> None:
log.info("Closing REST proxy application")
if self._idle_proxy_janitor_task is not None:
self._idle_proxy_janitor_task.cancel()
self._idle_proxy_janitor_task = None
@@ -441,7 +441,7 @@ def __init__(
self._auth_expiry = auth_expiry

self._async_producer_lock = asyncio.Lock()
- self._async_producer: Optional[AIOKafkaProducer] = None
+ self._async_producer: Optional[AsyncKafkaProducer] = None
self.naming_strategy = NameStrategy(self.config["name_strategy"])

def __str__(self) -> str:
@@ -461,12 +461,12 @@ def auth_expiry(self) -> datetime.datetime:
def num_consumers(self) -> int:
return len(self.consumer_manager.consumers)

- async def _maybe_create_async_producer(self) -> AIOKafkaProducer:
+ async def _maybe_create_async_producer(self) -> AsyncKafkaProducer:
if self._async_producer is not None:
return self._async_producer

if self.config["producer_acks"] == "all":
acks = "all"
acks = -1
else:
acks = int(self.config["producer_acks"])

@@ -477,33 +477,34 @@ async def _maybe_create_async_producer(self) -> AIOKafkaProducer:

log.info("Creating async producer")

- # Don't retry if creating the SSL context fails, likely a configuration issue with
- # ciphers or certificate chains
- ssl_context = create_client_ssl_context(self.config)

# Don't retry if instantiating the producer fails, likely a configuration error.
- producer = AIOKafkaProducer(
+ producer = AsyncKafkaProducer(
acks=acks,
bootstrap_servers=self.config["bootstrap_uri"],
compression_type=self.config["producer_compression_type"],
connections_max_idle_ms=self.config["connections_max_idle_ms"],
linger_ms=self.config["producer_linger_ms"],
- max_request_size=self.config["producer_max_request_size"],
+ message_max_bytes=self.config["producer_max_request_size"],
metadata_max_age_ms=self.config["metadata_max_age_ms"],
security_protocol=self.config["security_protocol"],
- ssl_context=ssl_context,
+ ssl_cafile=self.config["ssl_cafile"],
+ ssl_certfile=self.config["ssl_certfile"],
+ ssl_keyfile=self.config["ssl_keyfile"],
+ ssl_crlfile=self.config["ssl_crlfile"],
**get_kafka_client_auth_parameters_from_config(self.config),
)

try:
await producer.start()
- except KafkaConnectionError:
+ except (NoBrokersAvailable, AuthenticationFailedError):
await producer.stop()
if retry:
log.exception("Unable to connect to the bootstrap servers, retrying")
else:
log.exception("Giving up after trying to connect to the bootstrap servers")
raise
await asyncio.sleep(1)
except Exception:
await producer.stop()
raise
else:
self._async_producer = producer

@@ -645,10 +646,8 @@ def init_admin_client(self):
ssl_cafile=self.config["ssl_cafile"],
ssl_certfile=self.config["ssl_certfile"],
ssl_keyfile=self.config["ssl_keyfile"],
- api_version=(1, 0, 0),
metadata_max_age_ms=self.config["metadata_max_age_ms"],
connections_max_idle_ms=self.config["connections_max_idle_ms"],
- kafka_client=KarapaceKafkaClient,
**get_kafka_client_auth_parameters_from_config(self.config, async_client=False),
)
break
@@ -1069,8 +1068,11 @@ async def produce_messages(self, *, topic: str, prepared_records: List) -> List:
if not isinstance(result, Exception):
produce_results.append(
{
"offset": result.offset if result else -1,
"partition": result.topic_partition.partition if result else 0,
# In case the offset is not available, `confluent_kafka.Message.offset()` is
# `None`. To preserve backwards compatibility, we replace this with -1.
# -1 was the default `aiokafka` behaviour.
"offset": result.offset() if result and result.offset() is not None else -1,
"partition": result.partition() if result else 0,
}
)

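
The offset handling above is the one subtle behavioural shim in this file. A standalone sketch of the same normalization:

from typing import Optional

from confluent_kafka import Message

def produce_result(result: Optional[Message]) -> dict:
    # confluent-kafka reports a missing offset as None, while aiokafka
    # used -1; keep returning -1 for backwards compatibility.
    offset = result.offset() if result is not None else None
    return {
        "offset": offset if offset is not None else -1,
        "partition": result.partition() if result is not None else 0,
    }
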
3 changes: 3 additions & 0 deletions karapace/kafka_rest_apis/authentication.py
@@ -142,6 +142,9 @@ class SimpleOauthTokenProviderAsync(AbstractTokenProviderAsync):
async def token(self) -> str:
return self._token

def token_with_expiry(self, _config: str | None = None) -> tuple[str, int | None]:
Review comment (Contributor Author): This is necessary for now, as we still have the consumers from aiokafka present. Once those are replaced, we can simplify these token providers into a single one with this method that doesn't really depend on the abstract classes of any library (confluent-kafka only needs a callable as the OAuth callback).

return (self._token, get_expiration_timestamp_from_jwt(self._token))


class SASLOauthParams(TypedDict):
sasl_mechanism: str
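
As the review comment above notes, confluent-kafka only needs a callable as its OAuth callback, expecting a (token, expiry-unix-timestamp) tuple back. A hedged sketch of the future simplification the comment suggests, assuming `get_expiration_timestamp_from_jwt` is importable from this module; the class name is illustrative and not part of this PR:

from typing import Optional, Tuple

from karapace.kafka_rest_apis.authentication import get_expiration_timestamp_from_jwt

class SimpleOauthTokenProvider:
    """A plain callable usable as confluent-kafka's OAuth callback."""

    def __init__(self, token: str) -> None:
        self._token = token

    def __call__(self, _config: Optional[str] = None) -> Tuple[str, Optional[int]]:
        # confluent-kafka invokes the callback with the value of
        # `sasl.oauthbearer.config` and expects a (token, expiry) tuple.
        return (self._token, get_expiration_timestamp_from_jwt(self._token))
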
2 changes: 1 addition & 1 deletion karapace/rapu.py
@@ -167,7 +167,7 @@ def __init__(
self.app = self._create_aiohttp_application(config=config)
self.log = logging.getLogger(self.app_name)
self.stats = StatsClient(config=config)
- self.app.on_cleanup.append(self.close_by_app)
+ self.app.on_shutdown.append(self.close_by_app)
self.not_ready_handler = not_ready_handler

def _create_aiohttp_application(self, *, config: Config) -> aiohttp.web.Application:
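
This one-line change moves teardown from aiohttp's `on_cleanup` signal to `on_shutdown`, which fires earlier, while the server is still draining connections. A minimal sketch of the ordering on a bare aiohttp application (handler names are illustrative):

import aiohttp.web

async def close_by_app(app: aiohttp.web.Application) -> None:
    print("on_shutdown: runs first, while connections drain")

async def final_cleanup(app: aiohttp.web.Application) -> None:
    print("on_cleanup: runs last, after shutdown completes")

app = aiohttp.web.Application()
app.on_shutdown.append(close_by_app)
app.on_cleanup.append(final_cleanup)
# aiohttp.web.run_app(app) would invoke close_by_app before final_cleanup on exit.
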
63 changes: 62 additions & 1 deletion tests/integration/kafka/test_producer.py
@@ -7,9 +7,12 @@

from confluent_kafka.admin import NewTopic
from kafka.errors import MessageSizeTooLargeError, UnknownTopicOrPartitionError
- from karapace.kafka.producer import KafkaProducer
+ from karapace.kafka.producer import AsyncKafkaProducer, KafkaProducer
from karapace.kafka.types import Timestamp
from tests.integration.utils.kafka_server import KafkaServers
from typing import Iterator

import asyncio
import pytest
import time

@@ -71,3 +74,61 @@ def test_partitions_for(self, producer: KafkaProducer, new_topic: NewTopic) -> N
assert partitions[0].id == 0
assert partitions[0].replicas == [1]
assert partitions[0].isrs == [1]


@pytest.fixture(scope="function", name="asyncproducer")
async def fixture_asyncproducer(
kafka_servers: KafkaServers,
loop: asyncio.AbstractEventLoop,
) -> Iterator[AsyncKafkaProducer]:
asyncproducer = AsyncKafkaProducer(bootstrap_servers=kafka_servers.bootstrap_servers, loop=loop)
await asyncproducer.start()
yield asyncproducer
await asyncproducer.stop()


class TestAsyncSend:
async def test_async_send(self, asyncproducer: AsyncKafkaProducer, new_topic: NewTopic) -> None:
key = b"key"
value = b"value"
partition = 0
timestamp = int(time.time() * 1000)
headers = [("something", b"123"), (None, "foobar")]

aiofut = await asyncproducer.send(
new_topic.topic,
key=key,
value=value,
partition=partition,
timestamp=timestamp,
headers=headers,
)
message = await aiofut

assert message.offset() == 0
assert message.partition() == partition
assert message.topic() == new_topic.topic
assert message.key() == key
assert message.value() == value
assert message.timestamp()[0] == Timestamp.CREATE_TIME
assert message.timestamp()[1] == timestamp

async def test_async_send_raises_for_unknown_topic(self, asyncproducer: AsyncKafkaProducer) -> None:
aiofut = await asyncproducer.send("nonexistent")

with pytest.raises(UnknownTopicOrPartitionError):
_ = await aiofut

async def test_async_send_raises_for_unknown_partition(
self, asyncproducer: AsyncKafkaProducer, new_topic: NewTopic
) -> None:
aiofut = await asyncproducer.send(new_topic.topic, partition=99)

with pytest.raises(UnknownTopicOrPartitionError):
_ = await aiofut

async def test_async_send_raises_for_too_large_message(
self, asyncproducer: AsyncKafkaProducer, new_topic: NewTopic
) -> None:
with pytest.raises(MessageSizeTooLargeError):
await asyncproducer.send(new_topic.topic, value=b"x" * 1000001)
4 changes: 2 additions & 2 deletions tests/integration/test_rest.py
@@ -227,9 +227,9 @@ async def test_internal(rest_async: KafkaRest | None, admin_client: KafkaAdminCl
assert len(results) == 1
for result in results:
assert "error" in result, "Invalid result missing 'error' key"
assert result["error"] == "Unrecognized partition"
assert result["error"] == "This request is for a topic or partition that does not exist on this broker."
assert "error_code" in result, "Invalid result missing 'error_code' key"
assert result["error_code"] == 1
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

aiokafka raises an AssertionError in this case, which is considered non-retriable, however from confluent-kafka an UnknownTopicOrPartitionError equivalent is raised, which is considered retriable - hence the difference in the error codes.

I'm not entirely sure which behaviour we'd like to go forward with, the "correct" one with code 2, or preserve the previous. 🤔

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would go with what the confluent-kafka implements.

assert result["error_code"] == 2

assert rest_async_proxy.all_empty({"records": [{"key": {"foo": "bar"}}]}, "key") is False
assert rest_async_proxy.all_empty({"records": [{"value": {"foo": "bar"}}]}, "value") is False
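
For context on the assertion change: the REST proxy appears to map non-retriable produce failures to error code 1 and retriable ones to code 2, and kafka-python's `UnknownTopicOrPartitionError` is flagged retriable. A hedged sketch of that mapping; the constants and helper are illustrative, not Karapace's actual code:

from kafka.errors import UnknownTopicOrPartitionError

NON_RETRIABLE_ERROR_CODE = 1  # where aiokafka's AssertionError previously landed
RETRIABLE_ERROR_CODE = 2

def produce_error_code(exc: Exception) -> int:
    # kafka-python error classes carry a `retriable` flag.
    return RETRIABLE_ERROR_CODE if getattr(exc, "retriable", False) else NON_RETRIABLE_ERROR_CODE

assert produce_error_code(UnknownTopicOrPartitionError()) == RETRIABLE_ERROR_CODE
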
8 changes: 8 additions & 0 deletions tests/unit/test_authentication.py
@@ -120,6 +120,14 @@ async def test_simple_oauth_token_provider_async_returns_configured_token() -> N
assert await token_provider.token() == "TOKEN"


def test_simple_oauth_token_provider_async_returns_configured_token_and_expiry() -> None:
expiry_timestamp = 1697013997
token = jwt.encode({"exp": expiry_timestamp}, "secret")
token_provider = SimpleOauthTokenProviderAsync(token)

assert token_provider.token_with_expiry() == (token, expiry_timestamp)


def test_get_client_auth_parameters_from_config_sasl_plain() -> None:
config = set_config_defaults(
{"sasl_mechanism": "PLAIN", "sasl_plain_username": "username", "sasl_plain_password": "password"}