Skip to content

Commit

Permalink
fix: STREAMP-4251: Add state on kafka event receiver. (#353)
Browse files Browse the repository at this point in the history
  • Loading branch information
gaurovojha authored Mar 11, 2024
1 parent 9b3929c commit 6ae5315
Show file tree
Hide file tree
Showing 4 changed files with 138 additions and 37 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/)
and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.html).

### Added
- Added `State` property to `KafkaEventReceiver`

## [2.2.2] 2023-02-07
### Bugfix
- Correctly cascade deletes from stream to schema when `entityView.exist.check.enabled` is `false`.
Expand Down
5 changes: 5 additions & 0 deletions state/kafka-receiver/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,11 @@
<scope>provided</scope>
</dependency>

<dependency>
<groupId>org.awaitility</groupId>
<artifactId>awaitility</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>junit</groupId>
<artifactId>junit</artifactId>
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* Copyright (C) 2018-2023 Expedia, Inc.
* Copyright (C) 2018-2024 Expedia, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -16,6 +16,7 @@
package com.expediagroup.streamplatform.streamregistry.state.kafka;

import static com.expediagroup.streamplatform.streamregistry.state.internal.EventCorrelator.CORRELATION_ID;
import static com.expediagroup.streamplatform.streamregistry.state.kafka.KafkaEventReceiver.State.*;
import static com.expediagroup.streamplatform.streamregistry.state.model.event.Event.LOAD_COMPLETE;
import static io.confluent.kafka.serializers.KafkaAvroDeserializerConfig.SCHEMA_REGISTRY_URL_CONFIG;
import static io.confluent.kafka.serializers.KafkaAvroDeserializerConfig.SPECIFIC_AVRO_READER_CONFIG;
Expand All @@ -35,8 +36,7 @@
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.*;

import lombok.Builder;
import lombok.NonNull;
Expand Down Expand Up @@ -70,27 +70,33 @@ public class KafkaEventReceiver implements EventReceiver {
private static final int THREAD_POOL_SIZE = 2;


@NonNull private final Config config;
@NonNull
private final Config config;
private final EventCorrelator correlator;
@NonNull private final AvroConverter converter;
@NonNull private final KafkaConsumer<AvroKey, AvroValue> consumer;
@NonNull private final ScheduledExecutorService executorService;
@NonNull
private final AvroConverter converter;
@NonNull
private final KafkaConsumer<AvroKey, AvroValue> consumer;
@NonNull
private final ScheduledExecutorService executorService;
private final AtomicReference<State> state = new AtomicReference<>(CREATED);

private volatile boolean shuttingDown = false;
private final AtomicBoolean started = new AtomicBoolean(false);

public KafkaEventReceiver(Config config, EventCorrelator correlator, Configurator<KafkaConsumer<AvroKey, AvroValue>> consumerConfigurator) {
this(
config,
correlator,
new AvroConverter(),
getKafkaConsumer(config, consumerConfigurator),
newScheduledThreadPool(THREAD_POOL_SIZE)
config,
correlator,
new AvroConverter(),
getKafkaConsumer(config, consumerConfigurator),
newScheduledThreadPool(THREAD_POOL_SIZE)
);
}

public KafkaEventReceiver(Config config, EventCorrelator correlator) {
this(config, correlator, kafkaConsumer -> {});
this(config, correlator, kafkaConsumer -> {
});
}

public KafkaEventReceiver(Config config) {
Expand All @@ -105,14 +111,15 @@ private static KafkaConsumer<AvroKey, AvroValue> getKafkaConsumer(Config config,

@Override
public void receive(EventReceiverListener listener) {
if(started.getAndSet(true)) {
if (state.getAndSet(RUNNING) != CREATED) {
throw new IllegalStateException("Only a single EventReceiverListener is supported");
}
executorService.execute(() -> {
try {
consume(listener);
} catch (Exception e) {
log.error("Receiving failed", e);
state.set(ERROR);
throw e;
}
});
Expand All @@ -121,7 +128,7 @@ public void receive(EventReceiverListener listener) {
void consume(EventReceiverListener listener) {
val currentOffset = new AtomicLong(0L);
val progressLogger = executorService
.scheduleAtFixedRate(() -> log.info("Current offset {}", currentOffset.get()), 10, 10, SECONDS);
.scheduleAtFixedRate(() -> log.info("Current offset {}", currentOffset.get()), 10, 10, SECONDS);

val topicPartition = new TopicPartition(config.getTopic(), 0);
val topicPartitions = Collections.singletonList(topicPartition);
Expand All @@ -146,7 +153,7 @@ void consume(EventReceiverListener listener) {
loaded = true;
}

while (!shuttingDown) {
while (state.get() == RUNNING) {
for (ConsumerRecord<AvroKey, AvroValue> record : consumer.poll(Duration.ofMillis(100))) {
val event = converter.toModel(record.key(), record.value());
currentOffset.set(record.offset());
Expand Down Expand Up @@ -179,9 +186,14 @@ private void receiveCorrelationId(ConsumerRecord<?, ?> record) {

@Override
public void close() {
shuttingDown = true;
state.set(PENDING_SHUTDOWN);
executorService.shutdown();
consumer.close();
state.set(NOT_RUNNING);
}

public State getState() {
return state.get();
}

static Map<String, Object> consumerConfig(Config config) {
Expand Down Expand Up @@ -212,4 +224,12 @@ public static class Config {
@NonNull String groupId;
Map<String, Object> properties;
}

public enum State {
CREATED,
RUNNING,
ERROR,
PENDING_SHUTDOWN,
NOT_RUNNING
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
package com.expediagroup.streamplatform.streamregistry.state.kafka;

import static com.expediagroup.streamplatform.streamregistry.state.internal.EventCorrelator.CORRELATION_ID;
import static com.expediagroup.streamplatform.streamregistry.state.kafka.KafkaEventReceiver.State.*;
import static com.expediagroup.streamplatform.streamregistry.state.model.event.Event.LOAD_COMPLETE;
import static io.confluent.kafka.serializers.AbstractKafkaSchemaSerDeConfig.SCHEMA_REGISTRY_URL_CONFIG;
import static io.confluent.kafka.serializers.KafkaAvroDeserializerConfig.SPECIFIC_AVRO_READER_CONFIG;
Expand All @@ -29,19 +30,16 @@
import static org.apache.kafka.clients.consumer.ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG;
import static org.apache.kafka.clients.producer.ProducerConfig.BOOTSTRAP_SERVERS_CONFIG;
import static org.hamcrest.Matchers.containsInAnyOrder;
import static org.hamcrest.Matchers.is;
import static org.junit.Assert.assertThat;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.when;
import static org.mockito.Mockito.*;

import java.time.Duration;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.*;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.atomic.*;

import lombok.val;

Expand All @@ -54,6 +52,8 @@
import org.apache.kafka.common.TopicPartition;
import org.apache.kafka.common.header.internals.RecordHeader;
import org.apache.kafka.common.header.internals.RecordHeaders;
import org.awaitility.*;
import org.awaitility.core.*;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
Expand All @@ -73,16 +73,26 @@

@RunWith(MockitoJUnitRunner.Silent.class)
public class KafkaEventReceiverTest {
@Mock private KafkaEventReceiver.Config config;
@Mock private EventCorrelator correlator;
@Mock private AvroConverter converter;
@Mock private KafkaConsumer<AvroKey, AvroValue> consumer;
@Mock private EventReceiverListener listener;
@Mock private PartitionInfo partitionInfo;
@Mock private ConsumerRecord<AvroKey, AvroValue> record;
@Mock private AvroKey avroKey;
@Mock private AvroValue avroValue;
@Mock private Event event;
@Mock
private KafkaEventReceiver.Config config;
@Mock
private EventCorrelator correlator;
@Mock
private AvroConverter converter;
@Mock
private KafkaConsumer<AvroKey, AvroValue> consumer;
@Mock
private EventReceiverListener listener;
@Mock
private PartitionInfo partitionInfo;
@Mock
private ConsumerRecord<AvroKey, AvroValue> record;
@Mock
private AvroKey avroKey;
@Mock
private AvroValue avroValue;
@Mock
private Event event;

private final ScheduledExecutorService executorService = newScheduledThreadPool(2);

Expand All @@ -91,6 +101,7 @@ public class KafkaEventReceiverTest {
private final String topic = "topic";
private final TopicPartition topicPartition = new TopicPartition(topic, 0);
private final List<TopicPartition> topicPartitions = Collections.singletonList(topicPartition);
private final ConditionFactory await = Awaitility.await().atMost(2, SECONDS);

@Before
public void before() {
Expand All @@ -109,14 +120,16 @@ public void typical() throws Exception {
when(converter.toModel(avroKey, avroValue)).thenReturn(event);
when(record.headers()).thenReturn(new RecordHeaders(Collections.singletonList(new RecordHeader(CORRELATION_ID, "foo".getBytes(UTF_8)))));
val latch = new CountDownLatch(1);
doAnswer((correlationId) ->{
doAnswer((correlationId) -> {
latch.countDown();
return null;
}).when(correlator).received(anyString());

underTest.receive(listener);
assertThat(underTest.getState(), is(RUNNING));
latch.await(1, SECONDS);
underTest.close();
assertThat(underTest.getState(), is(NOT_RUNNING));

val inOrder = Mockito.inOrder(consumer, listener, correlator);
inOrder.verify(consumer).assign(topicPartitions);
Expand All @@ -139,14 +152,16 @@ public void listenerThrowsException() throws Exception {
when(record.headers()).thenReturn(new RecordHeaders(Collections.singletonList(new RecordHeader(CORRELATION_ID, "foo".getBytes(UTF_8)))));
doThrow(new RuntimeException("listener error")).when(listener).onEvent(event);
val latch = new CountDownLatch(1);
doAnswer((correlationId) ->{
doAnswer((correlationId) -> {
latch.countDown();
return null;
}).when(correlator).received(anyString());

underTest.receive(listener);
latch.await(1, SECONDS);
assertThat(underTest.getState(), is(RUNNING));
underTest.close();
assertThat(underTest.getState(), is(NOT_RUNNING));

val inOrder = Mockito.inOrder(consumer, listener, correlator);
inOrder.verify(consumer).assign(topicPartitions);
Expand All @@ -156,6 +171,63 @@ public void listenerThrowsException() throws Exception {
inOrder.verify(correlator).received("foo");
}

@Test
public void errorWhenMoreThanOnePartition() {
when(config.getTopic()).thenReturn(topic);
val multiplePartitions = new ArrayList<PartitionInfo>() {{
add(partitionInfo);
add(partitionInfo);
}};
when(consumer.partitionsFor(topic)).thenReturn(multiplePartitions);

underTest.receive(listener);
verify(consumer, timeout(100)).partitionsFor(topic);
await.untilAsserted(() -> assertThat(underTest.getState(), is(ERROR)));

underTest.close();
assertThat(underTest.getState(), is(NOT_RUNNING));
}

@Test
public void errorWhenRunningButUnableToPoll() throws Exception {
val polls = new AtomicInteger(0);
when(config.getTopic()).thenReturn(topic);
when(consumer.partitionsFor(topic)).thenReturn(Collections.singletonList(partitionInfo));
when(consumer.beginningOffsets(topicPartitions)).thenReturn(Collections.singletonMap(topicPartition, 0L));
when(consumer.endOffsets(topicPartitions)).thenReturn(Collections.singletonMap(topicPartition, 0L));
when(consumer.poll(Duration.ofMillis(100))).thenAnswer(invocation -> {
if (polls.getAndIncrement() < 10) {
return new ConsumerRecords<>(Collections.singletonMap(topicPartition, Collections.singletonList(record)));
} else {
throw new RuntimeException("Some Kafka poll error here");
}
});
when(record.key()).thenReturn(avroKey);
when(record.value()).thenReturn(avroValue);
when(converter.toModel(avroKey, avroValue)).thenReturn(event);
when(record.headers()).thenReturn(new RecordHeaders(Collections.singletonList(new RecordHeader(CORRELATION_ID, "foo".getBytes(UTF_8)))));
val latch = new CountDownLatch(1);
doAnswer((correlationId) -> {
assertThat(underTest.getState(), is(RUNNING));
latch.countDown();
return null;
}).when(correlator).received(anyString());

underTest.receive(listener);
latch.await(1, SECONDS);

val inOrder = Mockito.inOrder(consumer, listener, correlator);
inOrder.verify(consumer).assign(topicPartitions);
inOrder.verify(consumer).seekToBeginning(topicPartitions);
inOrder.verify(listener).onEvent(LOAD_COMPLETE);
inOrder.verify(listener).onEvent(event);

await.untilAsserted(() -> assertThat(underTest.getState(), is(ERROR)));
verify(consumer, times(11)).poll(any());
underTest.close();
assertThat(underTest.getState(), is(NOT_RUNNING));
}

@Test(expected = IllegalStateException.class)
public void incorrectNumberOfPartitions() {
when(consumer.partitionsFor(topic)).thenReturn(Collections.emptyList());
Expand All @@ -167,7 +239,8 @@ public void incorrectNumberOfPartitions() {
public void onlySupportsOneReceiver() {
val doNothingListener = new EventReceiverListener() {
@Override
public <K extends Entity.Key<S>, S extends Specification> void onEvent(Event<K, S> event) { }
public <K extends Entity.Key<S>, S extends Specification> void onEvent(Event<K, S> event) {
}
};
underTest.receive(doNothingListener);
underTest.receive(doNothingListener);
Expand Down

0 comments on commit 6ae5315

Please sign in to comment.