diff --git a/airbyte_cdk/sources/declarative/incremental/global_substream_cursor.py b/airbyte_cdk/sources/declarative/incremental/global_substream_cursor.py index 3b3636236..e67f1f9e1 100644 --- a/airbyte_cdk/sources/declarative/incremental/global_substream_cursor.py +++ b/airbyte_cdk/sources/declarative/incremental/global_substream_cursor.py @@ -112,7 +112,9 @@ def stream_slices(self) -> Iterable[StreamSlice]: * Yield the last slice. At that point, once there are as many slices yielded as closes, the global slice will be closed too """ slice_generator = ( - StreamSlice(partition=partition, cursor_slice=cursor_slice) + StreamSlice( + partition=partition, cursor_slice=cursor_slice, extra_fields=partition.extra_fields + ) for partition in self._partition_router.stream_slices() for cursor_slice in self._stream_cursor.stream_slices() ) @@ -128,7 +130,9 @@ def stream_slices(self) -> Iterable[StreamSlice]: def generate_slices_from_partition(self, partition: StreamSlice) -> Iterable[StreamSlice]: slice_generator = ( - StreamSlice(partition=partition, cursor_slice=cursor_slice) + StreamSlice( + partition=partition, cursor_slice=cursor_slice, extra_fields=partition.extra_fields + ) for cursor_slice in self._stream_cursor.stream_slices() ) diff --git a/unit_tests/sources/declarative/incremental/test_per_partition_cursor.py b/unit_tests/sources/declarative/incremental/test_per_partition_cursor.py index 8073b2b12..f689dcf05 100644 --- a/unit_tests/sources/declarative/incremental/test_per_partition_cursor.py +++ b/unit_tests/sources/declarative/incremental/test_per_partition_cursor.py @@ -8,6 +8,9 @@ import pytest from airbyte_cdk.sources.declarative.incremental.declarative_cursor import DeclarativeCursor +from airbyte_cdk.sources.declarative.incremental.global_substream_cursor import ( + GlobalSubstreamCursor, +) from airbyte_cdk.sources.declarative.incremental.per_partition_cursor import ( PerPartitionCursor, PerPartitionKeySerializer, @@ -715,3 +718,63 @@ def test_per_partition_state_when_set_initial_global_state( }, ] assert cursor.get_stream_state()["states"] == expected_state + + +def test_per_partition_cursor_partition_router_extra_fields( + mocked_cursor_factory, mocked_partition_router +): + first_partition = {"first_partition_key": "first_partition_value"} + mocked_partition_router.stream_slices.return_value = [ + StreamSlice( + partition=first_partition, cursor_slice={}, extra_fields={"extra_field": "extra_value"} + ), + ] + cursor = ( + MockedCursorBuilder() + .with_stream_slices([{CURSOR_SLICE_FIELD: "first slice cursor value"}]) + .build() + ) + + mocked_cursor_factory.create.return_value = cursor + cursor = PerPartitionCursor(mocked_cursor_factory, mocked_partition_router) + + cursor.set_initial_state({"states": [{"partition": first_partition, "cursor": CURSOR_STATE}]}) + slices = list(cursor.stream_slices()) + + assert slices[0].extra_fields == {"extra_field": "extra_value"} + assert slices == [ + StreamSlice( + partition={"first_partition_key": "first_partition_value"}, + cursor_slice={CURSOR_SLICE_FIELD: "first slice cursor value"}, + extra_fields={"extra_field": "extra_value"}, + ) + ] + + +def test_global_cursor_partition_router_extra_fields( + mocked_cursor_factory, mocked_partition_router +): + first_partition = {"first_partition_key": "first_partition_value"} + mocked_partition_router.stream_slices.return_value = [ + StreamSlice( + partition=first_partition, cursor_slice={}, extra_fields={"extra_field": "extra_value"} + ), + ] + cursor = ( + MockedCursorBuilder() + .with_stream_slices([{CURSOR_SLICE_FIELD: "first slice cursor value"}]) + .build() + ) + + global_cursor = GlobalSubstreamCursor(cursor, mocked_partition_router) + + slices = list(global_cursor.stream_slices()) + + assert slices[0].extra_fields == {"extra_field": "extra_value"} + assert slices == [ + StreamSlice( + partition=first_partition, + cursor_slice={CURSOR_SLICE_FIELD: "first slice cursor value"}, + extra_fields={"extra_field": "extra_value"}, + ) + ]