Skip to content

Commit

Permalink
feat: implement row filtering mapper (#14333)
Browse files Browse the repository at this point in the history
  • Loading branch information
subodh1810 committed Oct 29, 2024
1 parent 7d88645 commit 08e9b56
Show file tree
Hide file tree
Showing 27 changed files with 948 additions and 17 deletions.
1 change: 1 addition & 0 deletions airbyte-api/server-api/src/main/openapi/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11665,6 +11665,7 @@ components:
enum:
- hashing
- field-renaming
- row-filtering
x-sdk-component: true
ConfiguredStreamMapper:
type: object
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -387,7 +387,7 @@ class ReplicationWorkerHelper(
}

@VisibleForTesting
fun internalProcessMessageFromSource(sourceRawMessage: AirbyteMessage): AirbyteMessage {
fun internalProcessMessageFromSource(sourceRawMessage: AirbyteMessage): AirbyteMessage? {
val context = requireNotNull(ctx)

fieldSelector.filterSelectedFields(sourceRawMessage)
Expand All @@ -413,7 +413,12 @@ class ReplicationWorkerHelper(
}

if (sourceRawMessage.type == Type.RECORD) {
applyTransformationMappers(AirbyteJsonRecordAdapter(sourceRawMessage))
val airbyteJsonRecordAdapter = AirbyteJsonRecordAdapter(sourceRawMessage)
applyTransformationMappers(airbyteJsonRecordAdapter)
if (!airbyteJsonRecordAdapter.shouldInclude()) {
messageTracker.syncStatsTracker.updateFilteredOutRecordsStats(sourceRawMessage.record)
return null
}
}

return sourceRawMessage
Expand Down Expand Up @@ -458,8 +463,8 @@ class ReplicationWorkerHelper(
// source, so we only modify the state message after processing it, right before we send it to the
// destination
return internalProcessMessageFromSource(attachIdToStateMessageFromSource(sourceRawMessage))
.let { mapper.mapMessage(it) }
.let { Optional.ofNullable(it) }
?.let { mapper.mapMessage(it) }
?.let { Optional.ofNullable(it) } ?: Optional.empty()
}

fun getSourceDefinitionIdForSourceId(sourceId: UUID): UUID =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,7 @@ class StateCheckSumCountEventHandler(
checksumValidationEnabled: Boolean,
includeStreamInLogs: Boolean = true,
streamPlatformRecordCounts: Map<AirbyteStreamNameNamespacePair, Long> = emptyMap(),
filteredOutRecords: Double = 0.0,
) {
if (!isStateTypeSupported(stateMessage)) {
return
Expand All @@ -203,7 +204,12 @@ class StateCheckSumCountEventHandler(
val sourceStats: AirbyteStateStats? = stateMessage.sourceStats
if (sourceStats != null) {
sourceStats.recordCount?.let { sourceRecordCount ->
if (sourceRecordCount != stateRecordCount || platformRecordCount != stateRecordCount) {
if ((
sourceRecordCount.minus(
filteredOutRecords,
)
) != stateRecordCount || (platformRecordCount.minus(filteredOutRecords)) != stateRecordCount
) {
misMatchWhenAllThreeCountsArePresent(
origin,
sourceRecordCount,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,10 @@ class ParallelStreamStatsTracker(
@Volatile
private var checksumValidationEnabled = true

/**
 * Forwards a filtered-out record to the stats tracker of the stream it belongs to,
 * creating that tracker if it does not exist yet.
 */
override fun updateFilteredOutRecordsStats(recordMessage: AirbyteRecordMessage) {
  val streamKey = getNameNamespacePair(recordMessage)
  val streamTracker = getOrCreateStreamStatsTracker(streamKey)
  streamTracker.updateFilteredOutRecordsStats(recordMessage)
}

override fun updateStats(recordMessage: AirbyteRecordMessage) {
getOrCreateStreamStatsTracker(getNameNamespacePair(recordMessage))
.trackRecord(recordMessage)
Expand Down Expand Up @@ -136,13 +140,15 @@ class ParallelStreamStatsTracker(
}
else -> {
val statsTracker = getOrCreateStreamStatsTracker(getNameNamespacePair(stateMessage))
val filteredOutRecords = statsTracker.getTrackedFilteredOutRecordsSinceLastStateMessage(stateMessage)
stateCheckSumEventHandler.validateStateChecksum(
stateMessage = stateMessage,
platformRecordCount = statsTracker.getTrackedCommittedRecordsSinceLastStateMessage(stateMessage).toDouble(),
platformRecordCount = statsTracker.getTrackedEmittedRecordsSinceLastStateMessage(stateMessage).toDouble(),
origin = AirbyteMessageOrigin.DESTINATION,
failOnInvalidChecksum = failOnInvalidChecksum,
checksumValidationEnabled = checksumValidationEnabled,
streamPlatformRecordCounts = getStreamToCommittedRecords(),
filteredOutRecords = filteredOutRecords.toDouble(),
)
statsTracker.trackStateFromDestination(stateMessage)
}
Expand Down Expand Up @@ -174,6 +180,12 @@ class ParallelStreamStatsTracker(
failOnInvalidChecksum: Boolean,
) {
val expectedRecordCount = streamTrackers.values.sumOf { getEmittedCount(origin, stateMessage, it).toDouble() }
val filteredOutRecords =
if (origin == AirbyteMessageOrigin.DESTINATION) {
streamTrackers.values.sumOf { it.getTrackedFilteredOutRecordsSinceLastStateMessage(stateMessage) }
} else {
0
}
stateCheckSumEventHandler.validateStateChecksum(
stateMessage = stateMessage,
platformRecordCount = expectedRecordCount,
Expand All @@ -182,6 +194,7 @@ class ParallelStreamStatsTracker(
checksumValidationEnabled = checksumValidationEnabled,
includeStreamInLogs = false,
streamPlatformRecordCounts = getStreamToEmittedRecords(),
filteredOutRecords = filteredOutRecords.toDouble(),
)
}

Expand All @@ -192,7 +205,7 @@ class ParallelStreamStatsTracker(
): Long {
return when (origin) {
AirbyteMessageOrigin.SOURCE -> tracker.getTrackedEmittedRecordsSinceLastStateMessage()
AirbyteMessageOrigin.DESTINATION -> tracker.getTrackedCommittedRecordsSinceLastStateMessage(stateMessage)
AirbyteMessageOrigin.DESTINATION -> tracker.getTrackedEmittedRecordsSinceLastStateMessage(stateMessage)
AirbyteMessageOrigin.INTERNAL -> 0
}
}
Expand All @@ -210,6 +223,8 @@ class ParallelStreamStatsTracker(
// [sumOf] methods handle null values as 0, which is a change that we don't want to make at this time.
val streamSyncStats = getAllStreamSyncStats(hasReplicationCompleted).takeIf { it.isNotEmpty() }
val bytesCommitted = streamSyncStats?.sumOf { it.stats.bytesCommitted }
val recordsFilteredOut = streamSyncStats?.sumOf { it.stats.recordsFilteredOut }
val bytesFilteredOut = streamSyncStats?.sumOf { it.stats.bytesFilteredOut }
val recordsCommitted = streamSyncStats?.sumOf { it.stats.recordsCommitted }
val bytesEmitted = streamSyncStats?.sumOf { it.stats.bytesEmitted }
val recordsEmitted = streamSyncStats?.sumOf { it.stats.recordsEmitted }
Expand Down Expand Up @@ -241,6 +256,8 @@ class ParallelStreamStatsTracker(
.withRecordsCommitted(recordsCommitted)
.withBytesEmitted(bytesEmitted)
.withRecordsEmitted(recordsEmitted)
.withRecordsFilteredOut(recordsFilteredOut)
.withBytesFilteredOut(bytesFilteredOut)
.withEstimatedBytes(estimatedBytes)
.withEstimatedRecords(estimatedRecords)
}
Expand Down Expand Up @@ -282,6 +299,16 @@ class ParallelStreamStatsTracker(
.filterValues { it.nameNamespacePair.name != null }
.mapValues { it.value.streamStats.emittedRecordsCount.get() }

/**
 * Per-stream count of filtered-out records. Trackers whose stream name is null are left out.
 */
override fun getStreamToFilteredOutRecords(): Map<AirbyteStreamNameNamespacePair, Long> {
  val namedStreamTrackers = streamTrackers.filterValues { tracker -> tracker.nameNamespacePair.name != null }
  return namedStreamTrackers.mapValues { (_, tracker) -> tracker.streamStats.filteredOutRecords.get() }
}

/**
 * Per-stream count of filtered-out bytes. Trackers whose stream name is null are left out.
 */
override fun getStreamToFilteredOutBytes(): Map<AirbyteStreamNameNamespacePair, Long> {
  val namedStreamTrackers = streamTrackers.filterValues { tracker -> tracker.nameNamespacePair.name != null }
  return namedStreamTrackers.mapValues { (_, tracker) -> tracker.streamStats.filteredOutBytesCount.get() }
}

override fun getStreamToEstimatedRecords(): Map<AirbyteStreamNameNamespacePair, Long> =
if (hasEstimatesErrors) {
mapOf()
Expand All @@ -302,6 +329,10 @@ class ParallelStreamStatsTracker(

// Each total below reads one field of the aggregated stats and falls back to 0 when that field is null.
override fun getTotalRecordsEmitted(): Long = getTotalStats().recordsEmitted ?: 0L

override fun getTotalRecordsFilteredOut(): Long = getTotalStats().recordsFilteredOut ?: 0L

override fun getTotalBytesFilteredOut(): Long = getTotalStats().bytesFilteredOut ?: 0L

override fun getTotalRecordsEstimated(): Long = getTotalStats().estimatedRecords ?: 0L

override fun getTotalBytesEmitted(): Long = getTotalStats().bytesEmitted ?: 0L
Expand Down Expand Up @@ -370,10 +401,12 @@ class ParallelStreamStatsTracker(
SyncStats()
.withBytesEmitted(streamStats.emittedBytesCount.get())
.withRecordsEmitted(streamStats.emittedRecordsCount.get())
.withRecordsFilteredOut(streamStats.filteredOutRecords.get())
.withBytesFilteredOut(streamStats.filteredOutBytesCount.get())
.apply {
if (hasReplicationCompleted) {
withBytesCommitted(streamStats.emittedBytesCount.get())
withRecordsCommitted(streamStats.emittedRecordsCount.get())
withBytesCommitted(streamStats.emittedBytesCount.get().minus(bytesFilteredOut))
withRecordsCommitted(streamStats.emittedRecordsCount.get().minus(recordsFilteredOut))
} else {
withBytesCommitted(streamStats.committedBytesCount.get())
withRecordsCommitted(streamStats.committedRecordsCount.get())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ import java.util.concurrent.atomic.LongAccumulator
*/
data class StreamStatsCounters(
val emittedRecordsCount: AtomicLong = AtomicLong(),
val filteredOutRecords: AtomicLong = AtomicLong(),
val filteredOutBytesCount: AtomicLong = AtomicLong(),
val emittedBytesCount: AtomicLong = AtomicLong(),
val committedRecordsCount: AtomicLong = AtomicLong(),
val committedBytesCount: AtomicLong = AtomicLong(),
Expand All @@ -63,6 +65,8 @@ data class StreamStatsCounters(
data class EmittedStatsCounters(
// Returned as the emitted-record count by the getTrackedEmittedRecordsSinceLastStateMessage accessors.
// NOTE(review): the name looks like a typo of "emittedRecordsCount" — confirm before renaming, it is part of the public constructor.
val remittedRecordsCount: AtomicLong = AtomicLong(),
val emittedBytesCount: AtomicLong = AtomicLong(),
// Incremented once per record passed to StreamStatsTracker.updateFilteredOutRecordsStats.
val filteredOutRecords: AtomicLong = AtomicLong(),
// Accumulates the estimated byte size of each filtered-out record's data.
val filteredOutBytesCount: AtomicLong = AtomicLong(),
)

/**
Expand Down Expand Up @@ -104,6 +108,19 @@ class StreamStatsTracker(
private var previousEmittedStats = EmittedStatsCounters()
private var previousStateMessageReceivedAt: LocalDateTime? = null

/**
 * Bookkeeping for when a record is filtered out.
 *
 * Adds one record and the estimated byte size of the record's data to the filtered-out counters
 * of both the current emitted-stats window and the overall stream stats.
 */
fun updateFilteredOutRecordsStats(recordMessage: AirbyteRecordMessage) {
  // Capture the current window once so both increments land on the same counters instance.
  val currentWindowStats = emittedStats
  val recordSizeInBytes = Jsons.getEstimatedByteSize(recordMessage.data).toLong()

  currentWindowStats.filteredOutRecords.incrementAndGet()
  currentWindowStats.filteredOutBytesCount.addAndGet(recordSizeInBytes)

  streamStats.filteredOutRecords.incrementAndGet()
  streamStats.filteredOutBytesCount.addAndGet(recordSizeInBytes)
}

/**
* Bookkeeping for when a record message is read.
*
Expand Down Expand Up @@ -253,8 +270,12 @@ class StreamStatsTracker(
stateIds.remove(stagedStats.stateId)

// Increment committed stats as we are un-staging stats
streamStats.committedBytesCount.addAndGet(stagedStats.emittedStatsCounters.emittedBytesCount.get())
streamStats.committedRecordsCount.addAndGet(stagedStats.emittedStatsCounters.remittedRecordsCount.get())
streamStats.committedBytesCount.addAndGet(
stagedStats.emittedStatsCounters.emittedBytesCount.get().minus(stagedStats.emittedStatsCounters.filteredOutBytesCount.get()),
)
// Committed records = staged emitted records minus staged filtered-out *records*.
// (Subtracting filteredOutBytesCount here would mix bytes into a record count.)
streamStats.committedRecordsCount.addAndGet(
  stagedStats.emittedStatsCounters.remittedRecordsCount.get().minus(stagedStats.emittedStatsCounters.filteredOutRecords.get()),
)

if (stagedStats.stateId == stateId) {
break
Expand Down Expand Up @@ -287,7 +308,7 @@ class StreamStatsTracker(
return previousEmittedStats.remittedRecordsCount.get()
}

fun getTrackedCommittedRecordsSinceLastStateMessage(stateMessage: AirbyteStateMessage): Long {
fun getTrackedEmittedRecordsSinceLastStateMessage(stateMessage: AirbyteStateMessage): Long {
val stateId = stateMessage.getStateIdForStatsTracking()
val stagedStats: StagedStats? = stagedStatsList.find { it.stateId == stateId }
if (stagedStats == null) {
Expand All @@ -296,6 +317,15 @@ class StreamStatsTracker(
return stagedStats?.emittedStatsCounters?.remittedRecordsCount?.get() ?: 0
}

/**
 * Returns the filtered-out record count staged under the tracking id of [stateMessage].
 *
 * Logs a warning and returns 0 when no staged stats are found for that id.
 */
fun getTrackedFilteredOutRecordsSinceLastStateMessage(stateMessage: AirbyteStateMessage): Long {
  val stateId = stateMessage.getStateIdForStatsTracking()
  val matchingStats = stagedStatsList.firstOrNull { staged -> staged.stateId == stateId }
  if (matchingStats == null) {
    logger.warn { "Could not find the state message with id $stateId in the stagedStatsList" }
    return 0L
  }
  return matchingStats.emittedStatsCounters.filteredOutRecords.get()
}

/** True as long as the stream's `unreliableStateOperations` flag has not been set. */
fun areStreamStatsReliable(): Boolean = !streamStats.unreliableStateOperations.get()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,14 @@ fun SyncStatsTracker.getPerStreamStats(hasReplicationCompleted: Boolean): List<S
SyncStats()
.withBytesEmitted(tracker.getStreamToEmittedBytes()[stream])
.withRecordsEmitted(records)
.withRecordsFilteredOut(tracker.getStreamToFilteredOutRecords()[stream])
.withBytesFilteredOut(tracker.getStreamToFilteredOutBytes()[stream])
.withSourceStateMessagesEmitted(null)
.withDestinationStateMessagesEmitted(null)
.apply {
if (hasReplicationCompleted) {
bytesCommitted = tracker.getStreamToEmittedBytes()[stream]
recordsCommitted = tracker.getStreamToEmittedRecords()[stream]
bytesCommitted = tracker.getStreamToEmittedBytes()[stream]?.minus(bytesFilteredOut)
recordsCommitted = tracker.getStreamToEmittedRecords()[stream]?.minus(recordsFilteredOut)
} else {
bytesCommitted = tracker.getStreamToCommittedBytes()[stream]
recordsCommitted = tracker.getStreamToCommittedRecords()[stream]
Expand All @@ -44,7 +46,9 @@ fun SyncStatsTracker.getPerStreamStats(hasReplicationCompleted: Boolean): List<S
fun SyncStatsTracker.getTotalStats(hasReplicationCompleted: Boolean): SyncStats {
return SyncStats()
.withRecordsEmitted(getTotalRecordsEmitted())
.withRecordsFilteredOut(getTotalRecordsFilteredOut())
.withBytesEmitted(getTotalBytesEmitted())
.withBytesFilteredOut(getTotalBytesFilteredOut())
.withSourceStateMessagesEmitted(getTotalSourceStateMessagesEmitted())
.withDestinationStateMessagesEmitted(getTotalDestinationStateMessagesEmitted())
.withMaxSecondsBeforeSourceStateMessageEmitted(getMaxSecondsToReceiveSourceStateMessage())
Expand All @@ -53,8 +57,8 @@ fun SyncStatsTracker.getTotalStats(hasReplicationCompleted: Boolean): SyncStats
.withMeanSecondsBetweenStateMessageEmittedandCommitted(getMeanSecondsBetweenStateMessageEmittedAndCommitted())
.apply {
if (hasReplicationCompleted) {
bytesCommitted = bytesEmitted
recordsCommitted = recordsEmitted
bytesCommitted = bytesEmitted.minus(bytesFilteredOut)
recordsCommitted = recordsEmitted.minus(recordsFilteredOut)
} else {
bytesCommitted = getTotalBytesCommitted()
recordsCommitted = getTotalRecordsCommitted()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ import io.airbyte.workers.context.ReplicationFeatureFlags
* Track stats during a sync.
*/
interface SyncStatsTracker {
/**
 * Update the filtered-out stats (record and byte counts) with data from a recordMessage
 * that was filtered out.
 */
fun updateFilteredOutRecordsStats(recordMessage: AirbyteRecordMessage)

/**
* Update the stats count with data from recordMessage.
*/
Expand Down Expand Up @@ -63,6 +65,10 @@ interface SyncStatsTracker {
*/
fun getStreamToEmittedRecords(): Map<AirbyteStreamNameNamespacePair, Long>

/**
 * Get the per-stream filtered-out record count.
 */
fun getStreamToFilteredOutRecords(): Map<AirbyteStreamNameNamespacePair, Long>

/**
 * Get the per-stream filtered-out byte count.
 */
fun getStreamToFilteredOutBytes(): Map<AirbyteStreamNameNamespacePair, Long>

/**
* Get the per-stream estimated record count provided by
* [io.airbyte.protocol.models.AirbyteEstimateTraceMessage].
Expand Down Expand Up @@ -95,6 +101,10 @@ interface SyncStatsTracker {
*/
fun getTotalRecordsEmitted(): Long

/**
 * Get the overall filtered-out record count.
 */
fun getTotalRecordsFilteredOut(): Long

/**
 * Get the overall filtered-out byte count.
 */
fun getTotalBytesFilteredOut(): Long

/**
* Get the overall estimated record count.
*
Expand Down
Loading

0 comments on commit 08e9b56

Please sign in to comment.