From 9cb2de62c5ff231e940ac8fe7a2749fde1afe70f Mon Sep 17 00:00:00 2001 From: Ali Tavallaie Date: Sun, 4 Aug 2024 09:23:00 +0330 Subject: [PATCH] topic creation within kafka protocol --- connectiva/protocols/kafka_protocol.py | 54 ++++++++++++++++++++++++-- tests/test_kafka_protocol.py | 2 + 2 files changed, 52 insertions(+), 4 deletions(-) diff --git a/connectiva/protocols/kafka_protocol.py b/connectiva/protocols/kafka_protocol.py index 803e72a..38a1b6d 100644 --- a/connectiva/protocols/kafka_protocol.py +++ b/connectiva/protocols/kafka_protocol.py @@ -1,5 +1,7 @@ -from kafka import KafkaProducer, KafkaConsumer -from kafka.errors import KafkaError +from kafka import KafkaProducer, KafkaConsumer, KafkaAdminClient +from kafka.admin import NewTopic +from kafka.errors import KafkaError, TopicAlreadyExistsError +from kafka.structs import TopicPartition from typing import Dict, Any from connectiva import CommunicationMethod, Message import json @@ -19,8 +21,11 @@ def __init__(self, **kwargs): self.endpoint = kwargs.get("endpoint") self.topic = kwargs.get("topic") self.group_id = kwargs.get("group_id") + self.partitions = kwargs.get("partitions", 1) + self.replication_factor = kwargs.get("replication_factor", 1) self.producer = None self.consumer = None + self.admin_client = None # Parse the endpoint to get broker list self.broker_list = self._parse_endpoint(self.endpoint) @@ -40,9 +45,39 @@ def _parse_endpoint(self, endpoint: str) -> list: self.logger.debug(f"Parsed brokers: {brokers}") return brokers + def create_topic(self): + """ + Create Kafka topic if it does not exist. + """ + try: + self.admin_client = KafkaAdminClient(bootstrap_servers=self.broker_list) + topic_list = self.admin_client.list_topics() + if self.topic not in topic_list: + self.logger.info(f"Creating topic {self.topic}...") + new_topic = NewTopic( + name=self.topic, + num_partitions=self.partitions, + replication_factor=self.replication_factor + ) + self.admin_client.create_topics([new_topic]) + self.logger.info(f"Topic {self.topic} created successfully!") + else: + self.logger.info(f"Topic {self.topic} already exists.") + except TopicAlreadyExistsError: + self.logger.info(f"Topic {self.topic} already exists.") + except KafkaError as e: + self.logger.error(f"Failed to create topic: {e}") + raise + finally: + if self.admin_client: + self.admin_client.close() + def connect(self): self.logger.info(f"Connecting to Kafka brokers at {self.broker_list}...") try: + # Create the topic if it doesn't exist + self.create_topic() + # Initialize Kafka producer self.producer = KafkaProducer( bootstrap_servers=self.broker_list, @@ -61,6 +96,7 @@ def connect(self): value_deserializer=lambda x: json.loads(x.decode('utf-8')) ) self.logger.info("Kafka consumer connected.") + self.consumer.subscribe([self.topic]) # Subscribe to the topic else: self.logger.info("No consumer group ID provided; skipping consumer initialization.") @@ -95,8 +131,18 @@ def receive(self) -> Message: def seek_to_end(self): """Move the consumer to the end of the log for the current topic.""" if self.consumer: - self.consumer.seek_to_end() - self.logger.info("Moved consumer to the end of the log.") + try: + partitions = self.consumer.partitions_for_topic(self.topic) + if not partitions: + self.logger.error(f"No partitions found for topic {self.topic}.") + return + + topic_partitions = [TopicPartition(self.topic, p) for p in partitions] + self.consumer.assign(topic_partitions) # Ensure partitions are assigned + self.consumer.seek_to_end() + self.logger.info("Moved consumer to the end of the log.") + except Exception as e: + self.logger.error(f"Failed to seek to end: {e}") def disconnect(self): self.logger.info("Disconnecting from Kafka...") diff --git a/tests/test_kafka_protocol.py b/tests/test_kafka_protocol.py index 83f6559..0b503ba 100644 --- a/tests/test_kafka_protocol.py +++ b/tests/test_kafka_protocol.py @@ -17,6 +17,8 @@ def setUpClass(cls): endpoint='kafka://localhost:9092', # Correct Kafka endpoint format topic='test_topic', group_id='test_group', + partitions=1, # Define the number of partitions + replication_factor=1, # Define the replication factor log=True, # Enable logging to stdout log_file=cls.log_file, # Enable logging to file log_level="DEBUG"