Skip to content

Commit

Permalink
destination-async-framework: move the state emission logic into Globa…
Browse files Browse the repository at this point in the history
…lAsyncStateManager (#35240)
  • Loading branch information
subodh1810 authored and xiaohansong committed Feb 27, 2024
1 parent ee38f14 commit baee6fd
Show file tree
Hide file tree
Showing 13 changed files with 312 additions and 157 deletions.
1 change: 1 addition & 0 deletions airbyte-cdk/java/airbyte-cdk/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ MavenLocal debugging steps:

| Version | Date | Pull Request | Subject |
|:--------|:-----------|:-----------------------------------------------------------|:---------------------------------------------------------------------------------------------------------------------------------------------------------------|
| 0.20.9 | 2024-02-15 | [\#35240](https://github.com/airbytehq/airbyte/pull/35240) | Move state emission to the platform into the state manager itself. |
| 0.20.8 | 2024-02-15 | [\#35285](https://github.com/airbytehq/airbyte/pull/35285) | Improve blobstore module structure. |
| 0.20.7 | 2024-02-13 | [\#35236](https://github.com/airbytehq/airbyte/pull/35236) | Output logs to files in addition to stdout when running tests. |
| 0.20.6 | 2024-02-12 | [\#35036](https://github.com/airbytehq/airbyte/pull/35036) | Add trace utility to emit analytics messages. |
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,8 @@
import io.airbyte.cdk.integrations.destination_async.buffers.StreamAwareQueue.MessageWithMeta;
import io.airbyte.cdk.integrations.destination_async.state.FlushFailure;
import io.airbyte.cdk.integrations.destination_async.state.GlobalAsyncStateManager;
import io.airbyte.cdk.integrations.destination_async.state.PartialStateWithDestinationStats;
import io.airbyte.commons.json.Jsons;
import io.airbyte.protocol.models.v0.AirbyteMessage;
import io.airbyte.protocol.models.v0.StreamDescriptor;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.UUID;
Expand Down Expand Up @@ -67,8 +64,6 @@ public class FlushWorkers implements AutoCloseable {
private final AtomicBoolean isClosing;
private final GlobalAsyncStateManager stateManager;

private final Object LOCK = new Object();

public FlushWorkers(final BufferDequeue bufferDequeue,
final DestinationFlushFunction flushFunction,
final Consumer<AirbyteMessage> outputRecordCollector,
Expand Down Expand Up @@ -172,7 +167,7 @@ private void flush(final StreamDescriptor desc, final UUID flushWorkerId) {
AirbyteFileUtils.byteCountToDisplaySize(batch.getSizeInBytes()));

flusher.flush(desc, batch.getData().stream().map(MessageWithMeta::message));
emitStateMessages(batch.flushStates(stateIdToCount));
batch.flushStates(stateIdToCount, outputRecordCollector);
}

log.info("Flush Worker ({}) -- Worker finished flushing. Current queue size: {}",
Expand Down Expand Up @@ -222,7 +217,7 @@ public void close() throws Exception {
log.info("Closing flush workers -- all buffers flushed");

// before shutting down the supervisor, flush all state.
emitStateMessages(stateManager.flushStates());
stateManager.flushStates(outputRecordCollector);
supervisorThread.shutdown();
while (!supervisorThread.awaitTermination(5L, TimeUnit.MINUTES)) {
log.info("Waiting for flush worker supervisor to shut down");
Expand All @@ -239,17 +234,6 @@ public void close() throws Exception {
debugLoop.shutdownNow();
}

/**
 * Turns each partial state into a full {@link AirbyteMessage}, attaches its destination stats and
 * passes it to {@code outputRecordCollector}.
 * <p>
 * The entire list is processed while holding {@code LOCK}; presumably this keeps state emissions
 * from concurrent callers from interleaving -- confirm against the callers.
 *
 * @param partials partial states, with their stats and arrival numbers, to emit in list order
 */
private void emitStateMessages(final List<PartialStateWithDestinationStats> partials) {
  synchronized (LOCK) {
    for (final PartialStateWithDestinationStats partialState : partials) {
      // Rebuild the complete message from the serialized form carried by the partial state.
      final String serializedState = partialState.stateMessage().getSerialized();
      final AirbyteMessage stateMessage = Jsons.deserialize(serializedState, AirbyteMessage.class);

      stateMessage.getState().setDestinationStats(partialState.stats());

      log.info("State with arrival number {} emitted from thread {}", partialState.stateArrivalNumber(), Thread.currentThread().getName());
      outputRecordCollector.accept(stateMessage);
    }
  }
}

/**
 * Returns a shortened form of a flush worker id: the first five characters of the UUID's string
 * representation.
 *
 * @param flushWorkerId id of the flush worker
 * @return the first five characters of {@code flushWorkerId.toString()}
 */
private static String humanReadableFlushWorkerId(final UUID flushWorkerId) {
  final String fullId = flushWorkerId.toString();
  return fullId.substring(0, 5);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@
import io.airbyte.cdk.integrations.destination_async.GlobalMemoryManager;
import io.airbyte.cdk.integrations.destination_async.buffers.StreamAwareQueue.MessageWithMeta;
import io.airbyte.cdk.integrations.destination_async.state.GlobalAsyncStateManager;
import io.airbyte.cdk.integrations.destination_async.state.PartialStateWithDestinationStats;
import io.airbyte.protocol.models.v0.AirbyteMessage;
import java.util.List;
import java.util.Map;
import java.util.function.Consumer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand Down Expand Up @@ -57,16 +58,13 @@ public void close() throws Exception {
}

/**
* For the batch, marks all the states that have now been flushed. Also returns states that can be
* flushed. This method is descriptrive, it assumes that whatever consumes the state messages emits
* them, internally it purges the states it returns. message that it can.
* For the batch, marks all the states that have now been flushed. Also writes the states that can
* be flushed back to the platform via stateManager.
* <p>
*
* @return list of states that can be flushed
*/
public List<PartialStateWithDestinationStats> flushStates(final Map<Long, Long> stateIdToCount) {
public void flushStates(final Map<Long, Long> stateIdToCount, final Consumer<AirbyteMessage> outputRecordCollector) {
stateIdToCount.forEach(stateManager::decrement);
return stateManager.flushStates();
stateManager.flushStates(outputRecordCollector);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,13 @@
import com.google.common.base.Strings;
import io.airbyte.cdk.integrations.destination_async.GlobalMemoryManager;
import io.airbyte.cdk.integrations.destination_async.partial_messages.PartialAirbyteMessage;
import io.airbyte.commons.json.Jsons;
import io.airbyte.protocol.models.v0.AirbyteMessage;
import io.airbyte.protocol.models.v0.AirbyteStateMessage;
import io.airbyte.protocol.models.v0.AirbyteStateStats;
import io.airbyte.protocol.models.v0.StreamDescriptor;
import java.util.ArrayList;
import java.time.Instant;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
Expand All @@ -25,6 +25,7 @@
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.LinkedBlockingDeque;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Consumer;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.FileUtils;
import org.apache.commons.lang3.tuple.ImmutablePair;
Expand Down Expand Up @@ -153,16 +154,12 @@ public void decrement(final long stateId, final long count) {
}

/**
* Returns state messages with no more inflight records i.e. counter = 0 across all streams.
* Flushes state messages with no more inflight records i.e. counter = 0 across all streams.
* Intended to be called by {@link io.airbyte.cdk.integrations.destination_async.FlushWorkers} after
* a worker has finished flushing its record batch.
* <p>
* The return list of states should be emitted back to the platform.
*
* @return list of state messages with no more inflight records.
*/
public List<PartialStateWithDestinationStats> flushStates() {
final List<PartialStateWithDestinationStats> output = new ArrayList<>();
public void flushStates(final Consumer<AirbyteMessage> outputRecordCollector) {
Long bytesFlushed = 0L;
synchronized (LOCK) {
for (final Map.Entry<StreamDescriptor, LinkedBlockingDeque<Long>> entry : descToStateIdQ.entrySet()) {
Expand Down Expand Up @@ -195,8 +192,13 @@ public List<PartialStateWithDestinationStats> flushStates() {
if (allRecordsCommitted) {
final StateMessageWithArrivalNumber stateMessage = oldestState.getLeft();
final double flushedRecordsAssociatedWithState = stateIdToCounterForPopulatingDestinationStats.get(oldestStateId).doubleValue();
output.add(new PartialStateWithDestinationStats(stateMessage.partialAirbyteStateMessage(),
new AirbyteStateStats().withRecordCount(flushedRecordsAssociatedWithState), stateMessage.arrivalNumber()));

log.info("State with arrival number {} emitted from thread {} at {}", stateMessage.arrivalNumber(), Thread.currentThread().getName(),
Instant.now().toString());
final AirbyteMessage message = Jsons.deserialize(stateMessage.partialAirbyteStateMessage.getSerialized(), AirbyteMessage.class);
message.getState().setDestinationStats(new AirbyteStateStats().withRecordCount(flushedRecordsAssociatedWithState));
outputRecordCollector.accept(message);

bytesFlushed += oldestState.getRight();

// cleanup
Expand All @@ -212,7 +214,6 @@ public List<PartialStateWithDestinationStats> flushStates() {
}

freeBytes(bytesFlushed);
return output;
}

private Long getStateIdAndIncrement(final StreamDescriptor streamDescriptor, final long increment) {
Expand Down
Original file line number Diff line number Diff line change
@@ -1 +1 @@
version=0.20.8
version=0.20.9
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ public class BufferDequeueTest {

private static final int RECORD_SIZE_20_BYTES = 20;
private static final String DEFAULT_NAMESPACE = "foo_namespace";
public static final String RECORD_20_BYTES = "abc";
private static final String STREAM_NAME = "stream1";
private static final StreamDescriptor STREAM_DESC = new StreamDescriptor().withName(STREAM_NAME);
private static final PartialAirbyteMessage RECORD_MSG_20_BYTES = new PartialAirbyteMessage()
Expand Down
Loading

0 comments on commit baee6fd

Please sign in to comment.