Skip to content

Commit

Permalink
fix(concurrent cdk): Properly call set_initial_state() on the cursor …
Browse files Browse the repository at this point in the history
…that is initialized on the ClientSideIncrementalRecordFilterDecorator (#310)
  • Loading branch information
brianjlai authored Feb 4, 2025
1 parent 126e233 commit ca68c5c
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 2 deletions.
15 changes: 13 additions & 2 deletions airbyte_cdk/sources/declarative/concurrent_declarative_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,10 +475,21 @@ def _get_retriever(
# Also a temporary hack. In the legacy Stream implementation, as part of the read,
# set_initial_state() is called to instantiate incoming state on the cursor. Although we no
# longer rely on the legacy low-code cursor for concurrent checkpointing, low-code components
# like StopConditionPaginationStrategyDecorator and ClientSideIncrementalRecordFilterDecorator
# still rely on a DatetimeBasedCursor that is properly initialized with state.
# like StopConditionPaginationStrategyDecorator still rely on a DatetimeBasedCursor that is
# properly initialized with state.
if retriever.cursor:
retriever.cursor.set_initial_state(stream_state=stream_state)

# Similar to above, the ClientSideIncrementalRecordFilterDecorator cursor is a separate instance
# from the one initialized on the SimpleRetriever, so it must also have state initialized
# for semi-incremental streams using is_client_side_incremental to filter properly
if isinstance(retriever.record_selector, RecordSelector) and isinstance(
retriever.record_selector.record_filter, ClientSideIncrementalRecordFilterDecorator
):
retriever.record_selector.record_filter._cursor.set_initial_state(
stream_state=stream_state
) # type: ignore # After non-concurrent cursors are deprecated we can remove these cursor workarounds

# We zero it out here, but since this is a cursor reference, the state is still properly
# instantiated for the other components that reference it
retriever.cursor = None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@
ConcurrentDeclarativeSource,
)
from airbyte_cdk.sources.declarative.declarative_stream import DeclarativeStream
from airbyte_cdk.sources.declarative.extractors.record_filter import (
ClientSideIncrementalRecordFilterDecorator,
)
from airbyte_cdk.sources.declarative.partition_routers import AsyncJobPartitionRouter
from airbyte_cdk.sources.declarative.stream_slicers.declarative_partition_generator import (
StreamSlicerPartitionGenerator,
Expand Down Expand Up @@ -1647,6 +1650,44 @@ def test_async_incremental_stream_uses_concurrent_cursor_with_state():
assert async_job_partition_router.stream_slicer._concurrent_state == expected_state


def test_stream_using_is_client_side_incremental_has_cursor_state():
    """
    Check that incoming stream state is applied to the cursor held by the
    ClientSideIncrementalRecordFilterDecorator of a semi-incremental stream.

    The "locations" stream is switched to is_client_side_incremental in a copy of the manifest,
    the source is built with an incoming `updated_at` state value, and the record filter's
    cursor is expected to carry that same value.
    """
    expected_cursor_value = "2024-07-01"
    locations_state = AirbyteStateMessage(
        type=AirbyteStateType.STREAM,
        stream=AirbyteStreamState(
            stream_descriptor=StreamDescriptor(name="locations", namespace=None),
            stream_state=AirbyteStateBlob(updated_at=expected_cursor_value),
        ),
    )

    # Work on a deep copy so the module-level manifest is left untouched
    manifest = copy.deepcopy(_MANIFEST)

    # Enable semi-incremental on the locations stream
    incremental_sync = manifest["definitions"]["locations_stream"]["incremental_sync"]
    incremental_sync["is_client_side_incremental"] = True

    source = ConcurrentDeclarativeSource(
        source_config=manifest,
        config=_CONFIG,
        catalog=_CATALOG,
        state=[locations_state],
    )
    # Only the first group of streams is inspected by this test
    concurrent_streams, _ = source._group_streams(config=_CONFIG)

    # NOTE(review): assumes the locations stream sits at index 2 of the grouped streams
    locations_stream = concurrent_streams[2]
    assert isinstance(locations_stream, DefaultStream)

    partition_factory = locations_stream._stream_partition_generator._partition_factory
    record_filter = partition_factory._retriever.record_selector.record_filter
    assert isinstance(record_filter, ClientSideIncrementalRecordFilterDecorator)

    # The filter wraps a cursor object whose own `_cursor` attribute is compared to the state value
    assert record_filter._cursor._cursor == expected_cursor_value


def create_wrapped_stream(stream: DeclarativeStream) -> Stream:
slice_to_records_mapping = get_mocked_read_records_output(stream_name=stream.name)

Expand Down

0 comments on commit ca68c5c

Please sign in to comment.