diff --git a/airbyte-cdk/python/airbyte_cdk/sources/declarative/datetime/min_max_datetime.py b/airbyte-cdk/python/airbyte_cdk/sources/declarative/datetime/min_max_datetime.py index 2e76f49a396a..2694da2762ca 100644 --- a/airbyte-cdk/python/airbyte_cdk/sources/declarative/datetime/min_max_datetime.py +++ b/airbyte-cdk/python/airbyte_cdk/sources/declarative/datetime/min_max_datetime.py @@ -4,7 +4,7 @@ import datetime as dt from dataclasses import InitVar, dataclass, field -from typing import Any, Mapping, Union +from typing import Any, Mapping, Optional, Union from airbyte_cdk.sources.declarative.datetime.datetime_parser import DatetimeParser from airbyte_cdk.sources.declarative.interpolation.interpolated_string import InterpolatedString @@ -37,13 +37,13 @@ class MinMaxDatetime: min_datetime: Union[InterpolatedString, str] = "" max_datetime: Union[InterpolatedString, str] = "" - def __post_init__(self, parameters: Mapping[str, Any]): + def __post_init__(self, parameters: Mapping[str, Any]) -> None: self.datetime = InterpolatedString.create(self.datetime, parameters=parameters or {}) self._parser = DatetimeParser() - self.min_datetime = InterpolatedString.create(self.min_datetime, parameters=parameters) if self.min_datetime else None - self.max_datetime = InterpolatedString.create(self.max_datetime, parameters=parameters) if self.max_datetime else None + self.min_datetime = InterpolatedString.create(self.min_datetime, parameters=parameters) if self.min_datetime else None # type: ignore + self.max_datetime = InterpolatedString.create(self.max_datetime, parameters=parameters) if self.max_datetime else None # type: ignore - def get_datetime(self, config, **additional_parameters) -> dt.datetime: + def get_datetime(self, config: Mapping[str, Any], **additional_parameters: Any) -> dt.datetime: """ Evaluates and returns the datetime :param config: The user-provided configuration as specified by the source's spec @@ -55,29 +55,44 @@ def get_datetime(self, config, **additional_parameters) -> dt.datetime: if not datetime_format: datetime_format = "%Y-%m-%dT%H:%M:%S.%f%z" - time = self._parser.parse(str(self.datetime.eval(config, **additional_parameters)), datetime_format) + time = self._parser.parse(str(self.datetime.eval(config, **additional_parameters)), datetime_format) # type: ignore # datetime is always cast to an interpolated string if self.min_datetime: - min_time = str(self.min_datetime.eval(config, **additional_parameters)) + min_time = str(self.min_datetime.eval(config, **additional_parameters)) # type: ignore # min_datetime is always cast to an interpolated string if min_time: - min_time = self._parser.parse(min_time, datetime_format) - time = max(time, min_time) + min_datetime = self._parser.parse(min_time, datetime_format) + time = max(time, min_datetime) if self.max_datetime: - max_time = str(self.max_datetime.eval(config, **additional_parameters)) + max_time = str(self.max_datetime.eval(config, **additional_parameters)) # type: ignore # max_datetime is always cast to an interpolated string if max_time: - max_time = self._parser.parse(max_time, datetime_format) + max_datetime = self._parser.parse(max_time, datetime_format) + time = min(time, max_datetime) return time - @property + @property # type: ignore # properties don't play well with dataclasses...
def datetime_format(self) -> str: """The format of the string representing the datetime""" return self._datetime_format @datetime_format.setter - def datetime_format(self, value: str): + def datetime_format(self, value: str) -> None: """Setter for the datetime format""" # Covers the case where datetime_format is not provided in the constructor, which causes the property object # to be set which we need to avoid doing if not isinstance(value, property): self._datetime_format = value + + @classmethod + def create( + cls, + interpolated_string_or_min_max_datetime: Union[InterpolatedString, str, "MinMaxDatetime"], + parameters: Optional[Mapping[str, Any]] = None, + ) -> "MinMaxDatetime": + if parameters is None: + parameters = {} + if isinstance(interpolated_string_or_min_max_datetime, InterpolatedString) or isinstance( + interpolated_string_or_min_max_datetime, str + ): + return MinMaxDatetime(datetime=interpolated_string_or_min_max_datetime, parameters=parameters) + else: + return interpolated_string_or_min_max_datetime diff --git a/airbyte-cdk/python/airbyte_cdk/sources/declarative/declarative_stream.py b/airbyte-cdk/python/airbyte_cdk/sources/declarative/declarative_stream.py index f74ed377c4ab..d56e7c99a545 100644 --- a/airbyte-cdk/python/airbyte_cdk/sources/declarative/declarative_stream.py +++ b/airbyte-cdk/python/airbyte_cdk/sources/declarative/declarative_stream.py @@ -10,7 +10,7 @@ from airbyte_cdk.sources.declarative.retrievers.retriever import Retriever from airbyte_cdk.sources.declarative.schema import DefaultSchemaLoader from airbyte_cdk.sources.declarative.schema.schema_loader import SchemaLoader -from airbyte_cdk.sources.declarative.types import Config +from airbyte_cdk.sources.declarative.types import Config, StreamSlice from airbyte_cdk.sources.streams.core import Stream @@ -101,6 +101,8 @@ def read_records( """ :param: stream_state We knowingly avoid using stream_state as we want cursors to manage their own state. """ + if not isinstance(stream_slice, StreamSlice): + raise ValueError(f"DeclarativeStream does not support stream_slices that are not StreamSlice. Got {stream_slice}") yield from self.retriever.read_records(self.get_json_schema(), stream_slice) def get_json_schema(self) -> Mapping[str, Any]: # type: ignore @@ -114,7 +116,7 @@ def get_json_schema(self) -> Mapping[str, Any]: # type: ignore def stream_slices( self, *, sync_mode: SyncMode, cursor_field: Optional[List[str]] = None, stream_state: Optional[Mapping[str, Any]] = None - ) -> Iterable[Optional[Mapping[str, Any]]]: + ) -> Iterable[Optional[StreamSlice]]: """ Override to define the slices for this stream. See the stream slicing section of the docs for more information. 
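Note: most of the hunks below lean on the new StreamSlice type this change introduces in airbyte_cdk/sources/declarative/types.py (see the hunk near the end of this diff). A minimal sketch of its expected behavior, based solely on the class as defined there (the field names "account_id" and "start_time" are illustrative only, not part of the change):

from airbyte_cdk.sources.declarative.types import StreamSlice

# A StreamSlice is an immutable Mapping that merges a partition with a cursor slice.
stream_slice = StreamSlice(partition={"account_id": "123"}, cursor_slice={"start_time": "2022-01-01"})

assert stream_slice["account_id"] == "123"  # reads like a plain mapping
assert stream_slice.partition == {"account_id": "123"}  # each half stays individually addressable
assert stream_slice.cursor_slice == {"start_time": "2022-01-01"}
assert stream_slice == {"account_id": "123", "start_time": "2022-01-01"}  # equality with plain dicts still holds

# Overlapping keys between partition and cursor_slice raise ValueError, as does item assignment.

Keeping dict equality and the full Mapping interface lets existing code and tests that treat slices as plain mappings keep working, while new code can address the partition and the cursor slice separately.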
diff --git a/airbyte-cdk/python/airbyte_cdk/sources/declarative/incremental/datetime_based_cursor.py b/airbyte-cdk/python/airbyte_cdk/sources/declarative/incremental/datetime_based_cursor.py index e2a5f27d1ef3..0124b93e7553 100644 --- a/airbyte-cdk/python/airbyte_cdk/sources/declarative/incremental/datetime_based_cursor.py +++ b/airbyte-cdk/python/airbyte_cdk/sources/declarative/incremental/datetime_based_cursor.py @@ -4,7 +4,7 @@ import datetime from dataclasses import InitVar, dataclass, field -from typing import Any, Iterable, List, Mapping, Optional, Union +from typing import Any, Callable, Iterable, List, Mapping, MutableMapping, Optional, Union from airbyte_cdk.models import AirbyteLogMessage, AirbyteMessage, Level, Type from airbyte_cdk.sources.declarative.datetime.datetime_parser import DatetimeParser @@ -70,10 +70,8 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None: f"If step is defined, cursor_granularity should be as well and vice-versa. " f"Right now, step is `{self.step}` and cursor_granularity is `{self.cursor_granularity}`" ) - if not isinstance(self.start_datetime, MinMaxDatetime): - self.start_datetime = MinMaxDatetime(self.start_datetime, parameters) - if self.end_datetime and not isinstance(self.end_datetime, MinMaxDatetime): - self.end_datetime = MinMaxDatetime(self.end_datetime, parameters) + self._start_datetime = MinMaxDatetime.create(self.start_datetime, parameters) + self._end_datetime = None if not self.end_datetime else MinMaxDatetime.create(self.end_datetime, parameters) self._timezone = datetime.timezone.utc self._interpolation = JinjaInterpolation() @@ -84,23 +82,23 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None: else datetime.timedelta.max ) self._cursor_granularity = self._parse_timedelta(self.cursor_granularity) - self.cursor_field = InterpolatedString.create(self.cursor_field, parameters=parameters) - self.lookback_window = InterpolatedString.create(self.lookback_window, parameters=parameters) - self.partition_field_start = InterpolatedString.create(self.partition_field_start or "start_time", parameters=parameters) - self.partition_field_end = InterpolatedString.create(self.partition_field_end or "end_time", parameters=parameters) + self._cursor_field = InterpolatedString.create(self.cursor_field, parameters=parameters) + self._lookback_window = InterpolatedString.create(self.lookback_window, parameters=parameters) if self.lookback_window else None + self._partition_field_start = InterpolatedString.create(self.partition_field_start or "start_time", parameters=parameters) + self._partition_field_end = InterpolatedString.create(self.partition_field_end or "end_time", parameters=parameters) self._parser = DatetimeParser() # If datetime format is not specified then start/end datetime should inherit it from the stream slicer - if not self.start_datetime.datetime_format: - self.start_datetime.datetime_format = self.datetime_format - if self.end_datetime and not self.end_datetime.datetime_format: - self.end_datetime.datetime_format = self.datetime_format + if not self._start_datetime.datetime_format: + self._start_datetime.datetime_format = self.datetime_format + if self._end_datetime and not self._end_datetime.datetime_format: + self._end_datetime.datetime_format = self.datetime_format if not self.cursor_datetime_formats: self.cursor_datetime_formats = [self.datetime_format] def get_stream_state(self) -> StreamState: - return {self.cursor_field.eval(self.config): self._cursor} if self._cursor else {} + return 
{self._cursor_field.eval(self.config): self._cursor} if self._cursor else {} def set_initial_state(self, stream_state: StreamState) -> None: """ @@ -109,17 +107,22 @@ def set_initial_state(self, stream_state: StreamState) -> None: :param stream_state: The state of the stream as returned by get_stream_state """ - self._cursor = stream_state.get(self.cursor_field.eval(self.config)) if stream_state else None + self._cursor = stream_state.get(self._cursor_field.eval(self.config)) if stream_state else None def close_slice(self, stream_slice: StreamSlice, most_recent_record: Optional[Record]) -> None: - last_record_cursor_value = most_recent_record.get(self.cursor_field.eval(self.config)) if most_recent_record else None - stream_slice_value_end = stream_slice.get(self.partition_field_end.eval(self.config)) + if stream_slice.partition: + raise ValueError(f"Stream slice {stream_slice} should not have a partition. Got {stream_slice.partition}.") + last_record_cursor_value = most_recent_record.get(self._cursor_field.eval(self.config)) if most_recent_record else None + stream_slice_value_end = stream_slice.get(self._partition_field_end.eval(self.config)) + potential_cursor_values = [ + cursor_value for cursor_value in [self._cursor, last_record_cursor_value, stream_slice_value_end] if cursor_value + ] cursor_value_str_by_cursor_value_datetime = dict( map( # we need to ensure the cursor value is preserved as is in the state else the CATs might complain of something like # 2023-01-04T17:30:19.000Z' <= '2023-01-04T17:30:19.000000Z' lambda datetime_str: (self.parse_date(datetime_str), datetime_str), - filter(lambda item: item, [self._cursor, last_record_cursor_value, stream_slice_value_end]), + potential_cursor_values, ) ) self._cursor = ( @@ -142,37 +145,43 @@ def stream_slices(self) -> Iterable[StreamSlice]: return self._partition_daterange(start_datetime, end_datetime, self._step) def _calculate_earliest_possible_value(self, end_datetime: datetime.datetime) -> datetime.datetime: - lookback_delta = self._parse_timedelta(self.lookback_window.eval(self.config) if self.lookback_window else "P0D") - earliest_possible_start_datetime = min(self.start_datetime.get_datetime(self.config), end_datetime) + lookback_delta = self._parse_timedelta(self._lookback_window.eval(self.config) if self.lookback_window else "P0D") + earliest_possible_start_datetime = min(self._start_datetime.get_datetime(self.config), end_datetime) cursor_datetime = self._calculate_cursor_datetime_from_state(self.get_stream_state()) return max(earliest_possible_start_datetime, cursor_datetime) - lookback_delta def _select_best_end_datetime(self) -> datetime.datetime: now = datetime.datetime.now(tz=self._timezone) - if not self.end_datetime: + if not self._end_datetime: return now - return min(self.end_datetime.get_datetime(self.config), now) + return min(self._end_datetime.get_datetime(self.config), now) def _calculate_cursor_datetime_from_state(self, stream_state: Mapping[str, Any]) -> datetime.datetime: - if self.cursor_field.eval(self.config, stream_state=stream_state) in stream_state: - return self.parse_date(stream_state[self.cursor_field.eval(self.config)]) + if self._cursor_field.eval(self.config, stream_state=stream_state) in stream_state: + return self.parse_date(stream_state[self._cursor_field.eval(self.config)]) return datetime.datetime.min.replace(tzinfo=datetime.timezone.utc) def _format_datetime(self, dt: datetime.datetime) -> str: return self._parser.format(dt, self.datetime_format) - def _partition_daterange(self, start: 
datetime.datetime, end: datetime.datetime, step: Union[datetime.timedelta, Duration]): - start_field = self.partition_field_start.eval(self.config) - end_field = self.partition_field_end.eval(self.config) + def _partition_daterange( + self, start: datetime.datetime, end: datetime.datetime, step: Union[datetime.timedelta, Duration] + ) -> List[StreamSlice]: + start_field = self._partition_field_start.eval(self.config) + end_field = self._partition_field_end.eval(self.config) dates = [] while start <= end: next_start = self._evaluate_next_start_date_safely(start, step) end_date = self._get_date(next_start - self._cursor_granularity, end, min) - dates.append({start_field: self._format_datetime(start), end_field: self._format_datetime(end_date)}) + dates.append( + StreamSlice( + partition={}, cursor_slice={start_field: self._format_datetime(start), end_field: self._format_datetime(end_date)} + ) + ) start = next_start return dates - def _evaluate_next_start_date_safely(self, start, step): + def _evaluate_next_start_date_safely(self, start: datetime.datetime, step: datetime.timedelta) -> datetime.datetime: """ Given that we set the default step at datetime.timedelta.max, we will generate an OverflowError when evaluating the next start_date This method assumes that users would never enter a step that would generate an overflow. Given that would be the case, the code @@ -183,7 +192,12 @@ def _evaluate_next_start_date_safely(self, start, step): except OverflowError: return datetime.datetime.max.replace(tzinfo=datetime.timezone.utc) - def _get_date(self, cursor_value, default_date: datetime.datetime, comparator) -> datetime.datetime: + def _get_date( + self, + cursor_value: datetime.datetime, + default_date: datetime.datetime, + comparator: Callable[[datetime.datetime, datetime.datetime], datetime.datetime], + ) -> datetime.datetime: cursor_date = cursor_value or default_date return comparator(cursor_date, default_date) @@ -196,7 +210,7 @@ def parse_date(self, date: str) -> datetime.datetime: raise ValueError(f"No format in {self.cursor_datetime_formats} matching {date}") @classmethod - def _parse_timedelta(cls, time_str) -> Union[datetime.timedelta, Duration]: + def _parse_timedelta(cls, time_str: Optional[str]) -> Union[datetime.timedelta, Duration]: """ :return Parses an ISO 8601 durations into datetime.timedelta or Duration objects. 
""" @@ -244,18 +258,20 @@ def request_kwargs(self) -> Mapping[str, Any]: # Never update kwargs return {} - def _get_request_options(self, option_type: RequestOptionType, stream_slice: StreamSlice): - options = {} + def _get_request_options(self, option_type: RequestOptionType, stream_slice: Optional[StreamSlice]) -> Mapping[str, Any]: + options: MutableMapping[str, Any] = {} + if not stream_slice: + return options if self.start_time_option and self.start_time_option.inject_into == option_type: - options[self.start_time_option.field_name.eval(config=self.config)] = stream_slice.get( - self.partition_field_start.eval(self.config) + options[self.start_time_option.field_name.eval(config=self.config)] = stream_slice.get( # type: ignore # field_name is always casted to an interpolated string + self._partition_field_start.eval(self.config) ) if self.end_time_option and self.end_time_option.inject_into == option_type: - options[self.end_time_option.field_name.eval(config=self.config)] = stream_slice.get(self.partition_field_end.eval(self.config)) + options[self.end_time_option.field_name.eval(config=self.config)] = stream_slice.get(self._partition_field_end.eval(self.config)) # type: ignore # field_name is always casted to an interpolated string return options def should_be_synced(self, record: Record) -> bool: - cursor_field = self.cursor_field.eval(self.config) + cursor_field = self._cursor_field.eval(self.config) record_cursor_value = record.get(cursor_field) if not record_cursor_value: self._send_log( @@ -278,7 +294,7 @@ def _send_log(self, level: Level, message: str) -> None: ) def is_greater_than_or_equal(self, first: Record, second: Record) -> bool: - cursor_field = self.cursor_field.eval(self.config) + cursor_field = self._cursor_field.eval(self.config) first_cursor_value = first.get(cursor_field) second_cursor_value = second.get(cursor_field) if first_cursor_value and second_cursor_value: diff --git a/airbyte-cdk/python/airbyte_cdk/sources/declarative/incremental/per_partition_cursor.py b/airbyte-cdk/python/airbyte_cdk/sources/declarative/incremental/per_partition_cursor.py index 75af991970d3..39dfa8f1fe1f 100644 --- a/airbyte-cdk/python/airbyte_cdk/sources/declarative/incremental/per_partition_cursor.py +++ b/airbyte-cdk/python/airbyte_cdk/sources/declarative/incremental/per_partition_cursor.py @@ -3,7 +3,7 @@ # import json -from typing import Any, Callable, Iterable, Mapping, Optional +from typing import Any, Callable, Iterable, Mapping, MutableMapping, Optional, Union from airbyte_cdk.sources.declarative.incremental.cursor import Cursor from airbyte_cdk.sources.declarative.stream_slicers.stream_slicer import StreamSlicer @@ -24,73 +24,15 @@ def to_partition_key(to_serialize: Any) -> str: return json.dumps(to_serialize, indent=None, separators=(",", ":"), sort_keys=True) @staticmethod - def to_partition(to_deserialize: Any): - return json.loads(to_deserialize) - - -class PerPartitionStreamSlice(StreamSlice): - def __init__(self, partition: Mapping[str, Any], cursor_slice: Mapping[str, Any]): - self._partition = partition - self._cursor_slice = cursor_slice - if partition.keys() & cursor_slice.keys(): - raise ValueError("Keys for partition and incremental sync cursor should not overlap") - self._stream_slice = dict(partition) | dict(cursor_slice) - - @property - def partition(self): - return self._partition - - @property - def cursor_slice(self): - return self._cursor_slice - - def __repr__(self): - return repr(self._stream_slice) - - def __setitem__(self, key: str, value: Any): - raise 
ValueError("PerPartitionStreamSlice is immutable") - - def __getitem__(self, key: str): - return self._stream_slice[key] - - def __len__(self): - return len(self._stream_slice) - - def __iter__(self): - return iter(self._stream_slice) - - def __contains__(self, item: str): - return item in self._stream_slice - - def keys(self): - return self._stream_slice.keys() - - def items(self): - return self._stream_slice.items() - - def values(self): - return self._stream_slice.values() - - def get(self, key: str, default: Any) -> Any: - return self._stream_slice.get(key, default) - - def __eq__(self, other): - if isinstance(other, dict): - return self._stream_slice == other - if isinstance(other, PerPartitionStreamSlice): - # noinspection PyProtectedMember - return self._partition == other._partition and self._cursor_slice == other._cursor_slice - return False - - def __ne__(self, other): - return not self.__eq__(other) + def to_partition(to_deserialize: Any) -> Mapping[str, Any]: + return json.loads(to_deserialize) # type: ignore # The partition is known to be a dict, but the type hint is Any class CursorFactory: - def __init__(self, create_function: Callable[[], StreamSlicer]): + def __init__(self, create_function: Callable[[], Cursor]): self._create_function = create_function - def create(self) -> StreamSlicer: + def create(self) -> Cursor: return self._create_function() @@ -115,27 +57,27 @@ class PerPartitionCursor(Cursor): Therefore, we need to manage state per partition. """ - _NO_STATE = {} - _NO_CURSOR_STATE = {} + _NO_STATE: Mapping[str, Any] = {} + _NO_CURSOR_STATE: Mapping[str, Any] = {} _KEY = 0 _VALUE = 1 def __init__(self, cursor_factory: CursorFactory, partition_router: StreamSlicer): self._cursor_factory = cursor_factory self._partition_router = partition_router - self._cursor_per_partition = {} + self._cursor_per_partition: MutableMapping[str, Cursor] = {} self._partition_serializer = PerPartitionKeySerializer() - def stream_slices(self) -> Iterable[PerPartitionStreamSlice]: + def stream_slices(self) -> Iterable[StreamSlice]: slices = self._partition_router.stream_slices() for partition in slices: - cursor = self._cursor_per_partition.get(self._to_partition_key(partition)) + cursor = self._cursor_per_partition.get(self._to_partition_key(partition.partition)) if not cursor: cursor = self._create_cursor(self._NO_CURSOR_STATE) - self._cursor_per_partition[self._to_partition_key(partition)] = cursor + self._cursor_per_partition[self._to_partition_key(partition.partition)] = cursor for cursor_slice in cursor.stream_slices(): - yield PerPartitionStreamSlice(partition, cursor_slice) + yield StreamSlice(partition=partition, cursor_slice=cursor_slice) def set_initial_state(self, stream_state: StreamState) -> None: if not stream_state: @@ -147,10 +89,12 @@ def set_initial_state(self, stream_state: StreamState) -> None: def close_slice(self, stream_slice: StreamSlice, most_recent_record: Optional[Record]) -> None: try: cursor_most_recent_record = ( - Record(most_recent_record.data, stream_slice.cursor_slice) if most_recent_record else most_recent_record + Record(most_recent_record.data, StreamSlice(partition={}, cursor_slice=stream_slice.cursor_slice)) + if most_recent_record + else most_recent_record ) self._cursor_per_partition[self._to_partition_key(stream_slice.partition)].close_slice( - stream_slice.cursor_slice, cursor_most_recent_record + StreamSlice(partition={}, cursor_slice=stream_slice.cursor_slice), cursor_most_recent_record ) except KeyError as exception: raise ValueError( @@ 
-179,16 +123,16 @@ def _get_state_for_partition(self, partition: Mapping[str, Any]) -> Optional[Str return None @staticmethod - def _is_new_state(stream_state): + def _is_new_state(stream_state: Mapping[str, Any]) -> bool: return not bool(stream_state) - def _to_partition_key(self, partition) -> tuple: + def _to_partition_key(self, partition: Mapping[str, Any]) -> str: return self._partition_serializer.to_partition_key(partition) - def _to_dict(self, partition_key: tuple) -> StreamSlice: + def _to_dict(self, partition_key: str) -> Mapping[str, Any]: return self._partition_serializer.to_partition(partition_key) - def select_state(self, stream_slice: Optional[PerPartitionStreamSlice] = None) -> Optional[StreamState]: + def select_state(self, stream_slice: Optional[StreamSlice] = None) -> Optional[StreamState]: if not stream_slice: raise ValueError("A partition needs to be provided in order to extract a state") @@ -197,7 +141,7 @@ def select_state(self, stream_slice: Optional[PerPartitionStreamSlice] = None) - return self._get_state_for_partition(stream_slice.partition) - def _create_cursor(self, cursor_state: Any) -> StreamSlicer: + def _create_cursor(self, cursor_state: Any) -> Cursor: cursor = self._cursor_factory.create() cursor.set_initial_state(cursor_state) return cursor @@ -209,11 +153,18 @@ def get_request_params( stream_slice: Optional[StreamSlice] = None, next_page_token: Optional[Mapping[str, Any]] = None, ) -> Mapping[str, Any]: - return self._partition_router.get_request_params( - stream_state=stream_state, stream_slice=stream_slice.partition, next_page_token=next_page_token - ) | self._cursor_per_partition[self._to_partition_key(stream_slice.partition)].get_request_params( - stream_state=stream_state, stream_slice=stream_slice.cursor_slice, next_page_token=next_page_token - ) + if stream_slice: + return self._partition_router.get_request_params( # type: ignore # this always returns a mapping + stream_state=stream_state, + stream_slice=StreamSlice(partition=stream_slice.partition, cursor_slice={}), + next_page_token=next_page_token, + ) | self._cursor_per_partition[self._to_partition_key(stream_slice.partition)].get_request_params( + stream_state=stream_state, + stream_slice=StreamSlice(partition={}, cursor_slice=stream_slice.cursor_slice), + next_page_token=next_page_token, + ) + else: + raise ValueError("A partition needs to be provided in order to get request params") def get_request_headers( self, @@ -222,11 +173,18 @@ def get_request_headers( stream_slice: Optional[StreamSlice] = None, next_page_token: Optional[Mapping[str, Any]] = None, ) -> Mapping[str, Any]: - return self._partition_router.get_request_headers( - stream_state=stream_state, stream_slice=stream_slice.partition, next_page_token=next_page_token - ) | self._cursor_per_partition[self._to_partition_key(stream_slice.partition)].get_request_headers( - stream_state=stream_state, stream_slice=stream_slice.cursor_slice, next_page_token=next_page_token - ) + if stream_slice: + return self._partition_router.get_request_headers( # type: ignore # this always returns a mapping + stream_state=stream_state, + stream_slice=StreamSlice(partition=stream_slice.partition, cursor_slice={}), + next_page_token=next_page_token, + ) | self._cursor_per_partition[self._to_partition_key(stream_slice.partition)].get_request_headers( + stream_state=stream_state, + stream_slice=StreamSlice(partition={}, cursor_slice=stream_slice.cursor_slice), + next_page_token=next_page_token, + ) + else: + raise ValueError("A partition needs to be 
provided in order to get request headers") def get_request_body_data( self, @@ -234,12 +192,19 @@ def get_request_body_data( stream_state: Optional[StreamState] = None, stream_slice: Optional[StreamSlice] = None, next_page_token: Optional[Mapping[str, Any]] = None, - ) -> Mapping[str, Any]: - return self._partition_router.get_request_body_data( - stream_state=stream_state, stream_slice=stream_slice.partition, next_page_token=next_page_token - ) | self._cursor_per_partition[self._to_partition_key(stream_slice.partition)].get_request_body_data( - stream_state=stream_state, stream_slice=stream_slice.cursor_slice, next_page_token=next_page_token - ) + ) -> Union[Mapping[str, Any], str]: + if stream_slice: + return self._partition_router.get_request_body_data( # type: ignore # this always returns a mapping + stream_state=stream_state, + stream_slice=StreamSlice(partition=stream_slice.partition, cursor_slice={}), + next_page_token=next_page_token, + ) | self._cursor_per_partition[self._to_partition_key(stream_slice.partition)].get_request_body_data( + stream_state=stream_state, + stream_slice=StreamSlice(partition={}, cursor_slice=stream_slice.cursor_slice), + next_page_token=next_page_token, + ) + else: + raise ValueError("A partition needs to be provided in order to get request body data") def get_request_body_json( self, @@ -248,16 +213,25 @@ def get_request_body_json( stream_slice: Optional[StreamSlice] = None, next_page_token: Optional[Mapping[str, Any]] = None, ) -> Mapping[str, Any]: - return self._partition_router.get_request_body_json( - stream_state=stream_state, stream_slice=stream_slice.partition, next_page_token=next_page_token - ) | self._cursor_per_partition[self._to_partition_key(stream_slice.partition)].get_request_body_json( - stream_state=stream_state, stream_slice=stream_slice.cursor_slice, next_page_token=next_page_token - ) + if stream_slice: + return self._partition_router.get_request_body_json( # type: ignore # this always returns a mapping + stream_state=stream_state, + stream_slice=StreamSlice(partition=stream_slice.partition, cursor_slice={}), + next_page_token=next_page_token, + ) | self._cursor_per_partition[self._to_partition_key(stream_slice.partition)].get_request_body_json( + stream_state=stream_state, + stream_slice=StreamSlice(partition={}, cursor_slice=stream_slice.cursor_slice), + next_page_token=next_page_token, + ) + else: + raise ValueError("A partition needs to be provided in order to get request body json") def should_be_synced(self, record: Record) -> bool: return self._get_cursor(record).should_be_synced(self._convert_record_to_cursor_record(record)) def is_greater_than_or_equal(self, first: Record, second: Record) -> bool: + if not first.associated_slice or not second.associated_slice: + raise ValueError(f"Both records should have an associated slice but got {first.associated_slice} and {second.associated_slice}") if first.associated_slice.partition != second.associated_slice.partition: raise ValueError( f"To compare records, partition should be the same but got {first.associated_slice.partition} and {second.associated_slice.partition}" @@ -268,10 +242,15 @@ def is_greater_than_or_equal(self, first: Record, second: Record) -> bool: ) @staticmethod - def _convert_record_to_cursor_record(record: Record): - return Record(record.data, record.associated_slice.cursor_slice) + def _convert_record_to_cursor_record(record: Record) -> Record: + return Record( + record.data, + StreamSlice(partition={}, cursor_slice=record.associated_slice.cursor_slice) if 
record.associated_slice else None, + ) def _get_cursor(self, record: Record) -> Cursor: + if not record.associated_slice: + raise ValueError("Invalid state as stream slices that are emitted should refer to an existing cursor") partition_key = self._to_partition_key(record.associated_slice.partition) if partition_key not in self._cursor_per_partition: raise ValueError("Invalid state as stream slices that are emitted should refer to an existing cursor") diff --git a/airbyte-cdk/python/airbyte_cdk/sources/declarative/interpolation/jinja.py b/airbyte-cdk/python/airbyte_cdk/sources/declarative/interpolation/jinja.py index 91d52c7579f4..d2ef7a9d0464 100644 --- a/airbyte-cdk/python/airbyte_cdk/sources/declarative/interpolation/jinja.py +++ b/airbyte-cdk/python/airbyte_cdk/sources/declarative/interpolation/jinja.py @@ -3,7 +3,7 @@ # import ast -from typing import Any, Optional, Tuple, Type +from typing import Any, Mapping, Optional, Tuple, Type from airbyte_cdk.sources.declarative.interpolation.filters import filters from airbyte_cdk.sources.declarative.interpolation.interpolation import Interpolation @@ -48,7 +48,7 @@ class JinjaInterpolation(Interpolation): # Please add a unit test to test_jinja.py when adding a restriction. RESTRICTED_BUILTIN_FUNCTIONS = ["range"] # The range function can cause very expensive computations - def __init__(self): + def __init__(self) -> None: self._environment = Environment() self._environment.filters.update(**filters) self._environment.globals.update(**macros) @@ -64,8 +64,8 @@ def eval( config: Config, default: Optional[str] = None, valid_types: Optional[Tuple[Type[Any]]] = None, - **additional_parameters, - ): + **additional_parameters: Any, + ) -> Any: context = {"config": config, **additional_parameters} for alias, equivalent in self.ALIASES.items(): @@ -90,23 +90,23 @@ def eval( # If result is empty or resulted in an undefined error, evaluate and return the default string return self._literal_eval(self._eval(default, context), valid_types) - def _literal_eval(self, result, valid_types: Optional[Tuple[Type[Any]]]): + def _literal_eval(self, result: Optional[str], valid_types: Optional[Tuple[Type[Any]]]) -> Any: try: - evaluated = ast.literal_eval(result) + evaluated = ast.literal_eval(result) # type: ignore # literal_eval is able to handle None except (ValueError, SyntaxError): return result if not valid_types or (valid_types and isinstance(evaluated, valid_types)): return evaluated return result - def _eval(self, s: str, context): + def _eval(self, s: Optional[str], context: Mapping[str, Any]) -> Optional[str]: try: - ast = self._environment.parse(s) + ast = self._environment.parse(s) # type: ignore # parse is able to handle None undeclared = meta.find_undeclared_variables(ast) undeclared_not_in_context = {var for var in undeclared if var not in context} if undeclared_not_in_context: raise ValueError(f"Jinja macro has undeclared variables: {undeclared_not_in_context}. 
Context: {context}") - return self._environment.from_string(s).render(context) + return self._environment.from_string(s).render(context) # type: ignore # from_string is able to handle None except TypeError: # The string is a static value, not a jinja template # It can be returned as is diff --git a/airbyte-cdk/python/airbyte_cdk/sources/declarative/partition_routers/list_partition_router.py b/airbyte-cdk/python/airbyte_cdk/sources/declarative/partition_routers/list_partition_router.py index 5413709d9615..3490c02f7a0c 100644 --- a/airbyte-cdk/python/airbyte_cdk/sources/declarative/partition_routers/list_partition_router.py +++ b/airbyte-cdk/python/airbyte_cdk/sources/declarative/partition_routers/list_partition_router.py @@ -30,11 +30,12 @@ class ListPartitionRouter(StreamSlicer): parameters: InitVar[Mapping[str, Any]] request_option: Optional[RequestOption] = None - def __post_init__(self, parameters: Mapping[str, Any]): + def __post_init__(self, parameters: Mapping[str, Any]) -> None: if isinstance(self.values, str): self.values = InterpolatedString.create(self.values, parameters=parameters).eval(self.config) - if isinstance(self.cursor_field, str): - self.cursor_field = InterpolatedString(string=self.cursor_field, parameters=parameters) + self._cursor_field = ( + InterpolatedString(string=self.cursor_field, parameters=parameters) if isinstance(self.cursor_field, str) else self.cursor_field + ) self._cursor = None @@ -75,13 +76,13 @@ def get_request_body_json( return self._get_request_option(RequestOptionType.body_json, stream_slice) def stream_slices(self) -> Iterable[StreamSlice]: - return [{self.cursor_field.eval(self.config): slice_value} for slice_value in self.values] + return [StreamSlice(partition={self._cursor_field.eval(self.config): slice_value}, cursor_slice={}) for slice_value in self.values] - def _get_request_option(self, request_option_type: RequestOptionType, stream_slice: StreamSlice): + def _get_request_option(self, request_option_type: RequestOptionType, stream_slice: Optional[StreamSlice]) -> Mapping[str, Any]: if self.request_option and self.request_option.inject_into == request_option_type and stream_slice: - slice_value = stream_slice.get(self.cursor_field.eval(self.config)) + slice_value = stream_slice.get(self._cursor_field.eval(self.config)) if slice_value: - return {self.request_option.field_name.eval(self.config): slice_value} + return {self.request_option.field_name.eval(self.config): slice_value} # type: ignore # field_name is always casted to InterpolatedString else: return {} else: diff --git a/airbyte-cdk/python/airbyte_cdk/sources/declarative/partition_routers/single_partition_router.py b/airbyte-cdk/python/airbyte_cdk/sources/declarative/partition_routers/single_partition_router.py index 4697d114eb1a..d1e7bab68e40 100644 --- a/airbyte-cdk/python/airbyte_cdk/sources/declarative/partition_routers/single_partition_router.py +++ b/airbyte-cdk/python/airbyte_cdk/sources/declarative/partition_routers/single_partition_router.py @@ -48,4 +48,4 @@ def get_request_body_json( return {} def stream_slices(self) -> Iterable[StreamSlice]: - yield dict() + yield StreamSlice(partition={}, cursor_slice={}) diff --git a/airbyte-cdk/python/airbyte_cdk/sources/declarative/partition_routers/substream_partition_router.py b/airbyte-cdk/python/airbyte_cdk/sources/declarative/partition_routers/substream_partition_router.py index 3e915168c059..7d93bacb084f 100644 --- a/airbyte-cdk/python/airbyte_cdk/sources/declarative/partition_routers/substream_partition_router.py +++ 
b/airbyte-cdk/python/airbyte_cdk/sources/declarative/partition_routers/substream_partition_router.py @@ -3,7 +3,7 @@ # from dataclasses import InitVar, dataclass -from typing import Any, Iterable, List, Mapping, Optional, Union +from typing import TYPE_CHECKING, Any, Iterable, List, Mapping, Optional, Union import dpath.util from airbyte_cdk.models import AirbyteMessage, SyncMode, Type @@ -11,7 +11,9 @@ from airbyte_cdk.sources.declarative.requesters.request_option import RequestOption, RequestOptionType from airbyte_cdk.sources.declarative.stream_slicers.stream_slicer import StreamSlicer from airbyte_cdk.sources.declarative.types import Config, Record, StreamSlice, StreamState -from airbyte_cdk.sources.streams.core import Stream + +if TYPE_CHECKING: + from airbyte_cdk.sources.declarative.declarative_stream import DeclarativeStream @dataclass @@ -25,14 +27,14 @@ class ParentStreamConfig: request_option: How to inject the slice value on an outgoing HTTP request """ - stream: Stream + stream: "DeclarativeStream" # Parent streams must be DeclarativeStream because we can't know which part of the stream slice is a partition for regular Stream parent_key: Union[InterpolatedString, str] partition_field: Union[InterpolatedString, str] config: Config parameters: InitVar[Mapping[str, Any]] request_option: Optional[RequestOption] = None - def __post_init__(self, parameters: Mapping[str, Any]): + def __post_init__(self, parameters: Mapping[str, Any]) -> None: self.parent_key = InterpolatedString.create(self.parent_key, parameters=parameters) self.partition_field = InterpolatedString.create(self.partition_field, parameters=parameters) @@ -51,7 +53,7 @@ class SubstreamPartitionRouter(StreamSlicer): config: Config parameters: InitVar[Mapping[str, Any]] - def __post_init__(self, parameters: Mapping[str, Any]): + def __post_init__(self, parameters: Mapping[str, Any]) -> None: if not self.parent_stream_configs: raise ValueError("SubstreamPartitionRouter needs at least 1 parent stream") self._parameters = parameters @@ -88,19 +90,19 @@ def get_request_body_json( stream_state: Optional[StreamState] = None, stream_slice: Optional[StreamSlice] = None, next_page_token: Optional[Mapping[str, Any]] = None, - ) -> Optional[Mapping]: + ) -> Mapping[str, Any]: # Pass the stream_slice from the argument, not the cursor because the cursor is updated after processing the response return self._get_request_option(RequestOptionType.body_json, stream_slice) - def _get_request_option(self, option_type: RequestOptionType, stream_slice: StreamSlice): + def _get_request_option(self, option_type: RequestOptionType, stream_slice: Optional[StreamSlice]) -> Mapping[str, Any]: params = {} if stream_slice: for parent_config in self.parent_stream_configs: if parent_config.request_option and parent_config.request_option.inject_into == option_type: - key = parent_config.partition_field.eval(self.config) + key = parent_config.partition_field.eval(self.config) # type: ignore # partition_field is always casted to an interpolated string value = stream_slice.get(key) if value: - params.update({parent_config.request_option.field_name.eval(config=self.config): value}) + params.update({parent_config.request_option.field_name.eval(config=self.config): value}) # type: ignore # field_name is always casted to an interpolated string return params def stream_slices(self) -> Iterable[StreamSlice]: @@ -123,13 +125,13 @@ def stream_slices(self) -> Iterable[StreamSlice]: else: for parent_stream_config in self.parent_stream_configs: parent_stream = 
parent_stream_config.stream - parent_field = parent_stream_config.parent_key.eval(self.config) - stream_state_field = parent_stream_config.partition_field.eval(self.config) + parent_field = parent_stream_config.parent_key.eval(self.config) # type: ignore # parent_key is always casted to an interpolated string + partition_field = parent_stream_config.partition_field.eval(self.config) # type: ignore # partition_field is always casted to an interpolated string for parent_stream_slice in parent_stream.stream_slices( sync_mode=SyncMode.full_refresh, cursor_field=None, stream_state=None ): empty_parent_slice = True - parent_slice = parent_stream_slice + parent_partition = parent_stream_slice.partition if parent_stream_slice else {} for parent_record in parent_stream.read_records( sync_mode=SyncMode.full_refresh, cursor_field=None, stream_slice=parent_stream_slice, stream_state=None @@ -143,12 +145,14 @@ def stream_slices(self) -> Iterable[StreamSlice]: elif isinstance(parent_record, Record): parent_record = parent_record.data try: - stream_state_value = dpath.util.get(parent_record, parent_field) + partition_value = dpath.util.get(parent_record, parent_field) except KeyError: pass else: empty_parent_slice = False - yield {stream_state_field: stream_state_value, "parent_slice": parent_slice} + yield StreamSlice( + partition={partition_field: partition_value, "parent_slice": parent_partition}, cursor_slice={} + ) # If the parent slice contains no records, if empty_parent_slice: yield from [] diff --git a/airbyte-cdk/python/airbyte_cdk/sources/declarative/requesters/request_options/request_options_provider.py b/airbyte-cdk/python/airbyte_cdk/sources/declarative/requesters/request_options/request_options_provider.py index b07ffb3f6f08..c03a232e368e 100644 --- a/airbyte-cdk/python/airbyte_cdk/sources/declarative/requesters/request_options/request_options_provider.py +++ b/airbyte-cdk/python/airbyte_cdk/sources/declarative/requesters/request_options/request_options_provider.py @@ -4,7 +4,7 @@ from abc import abstractmethod from dataclasses import dataclass -from typing import Any, Mapping, MutableMapping, Optional, Union +from typing import Any, Mapping, Optional, Union from airbyte_cdk.sources.declarative.types import StreamSlice, StreamState @@ -28,7 +28,7 @@ def get_request_params( stream_state: Optional[StreamState] = None, stream_slice: Optional[StreamSlice] = None, next_page_token: Optional[Mapping[str, Any]] = None, - ) -> MutableMapping[str, Any]: + ) -> Mapping[str, Any]: """ Specifies the query parameters that should be set on an outgoing HTTP request given the inputs.
@@ -53,7 +53,7 @@ def get_request_body_data( stream_state: Optional[StreamState] = None, stream_slice: Optional[StreamSlice] = None, next_page_token: Optional[Mapping[str, Any]] = None, - ) -> Optional[Union[Mapping[str, Any], str]]: + ) -> Union[Mapping[str, Any], str]: """ Specifies how to populate the body of the request with a non-JSON payload. @@ -71,7 +71,7 @@ def get_request_body_json( stream_state: Optional[StreamState] = None, stream_slice: Optional[StreamSlice] = None, next_page_token: Optional[Mapping[str, Any]] = None, - ) -> Optional[Mapping[str, Any]]: + ) -> Mapping[str, Any]: """ Specifies how to populate the body of the request with a JSON payload. diff --git a/airbyte-cdk/python/airbyte_cdk/sources/declarative/retrievers/retriever.py b/airbyte-cdk/python/airbyte_cdk/sources/declarative/retrievers/retriever.py index bf4247a4f441..cd6310a0948f 100644 --- a/airbyte-cdk/python/airbyte_cdk/sources/declarative/retrievers/retriever.py +++ b/airbyte-cdk/python/airbyte_cdk/sources/declarative/retrievers/retriever.py @@ -6,7 +6,8 @@ from dataclasses import dataclass from typing import Any, Iterable, Mapping, Optional -from airbyte_cdk.sources.declarative.types import StreamSlice, StreamState +from airbyte_cdk.sources.declarative.incremental.per_partition_cursor import StreamSlice +from airbyte_cdk.sources.declarative.types import StreamState from airbyte_cdk.sources.streams.core import StreamData diff --git a/airbyte-cdk/python/airbyte_cdk/sources/declarative/retrievers/simple_retriever.py b/airbyte-cdk/python/airbyte_cdk/sources/declarative/retrievers/simple_retriever.py index a9c946044922..7028850bcaaa 100644 --- a/airbyte-cdk/python/airbyte_cdk/sources/declarative/retrievers/simple_retriever.py +++ b/airbyte-cdk/python/airbyte_cdk/sources/declarative/retrievers/simple_retriever.py @@ -258,7 +258,7 @@ def _next_page_token(self, response: requests.Response) -> Optional[Mapping[str, return self._paginator.next_page_token(response, self._records_from_last_response) def _fetch_next_page( - self, stream_state: Mapping[str, Any], stream_slice: Mapping[str, Any], next_page_token: Optional[Mapping[str, Any]] = None + self, stream_state: Mapping[str, Any], stream_slice: StreamSlice, next_page_token: Optional[Mapping[str, Any]] = None ) -> Optional[requests.Response]: return self.requester.send_request( path=self._paginator_path(), @@ -280,7 +280,7 @@ def _read_pages( self, records_generator_fn: Callable[[Optional[requests.Response]], Iterable[StreamData]], stream_state: Mapping[str, Any], - stream_slice: Mapping[str, Any], + stream_slice: StreamSlice, ) -> Iterable[StreamData]: pagination_complete = False next_page_token = None @@ -310,7 +310,7 @@ def read_records( :param stream_slice: The stream slice to read data for :return: The records read from the API source """ - stream_slice = stream_slice or {} # None-check + _slice = stream_slice or StreamSlice(partition={}, cursor_slice={}) # None-check # Fixing paginator types has a long tail of dependencies self._paginator.reset() @@ -318,15 +318,15 @@ def read_records( record_generator = partial( self._parse_records, stream_state=self.state or {}, - stream_slice=stream_slice, + stream_slice=_slice, records_schema=records_schema, ) - for stream_data in self._read_pages(record_generator, self.state, stream_slice): - most_recent_record_from_slice = self._get_most_recent_record(most_recent_record_from_slice, stream_data, stream_slice) + for stream_data in self._read_pages(record_generator, self.state, _slice): + most_recent_record_from_slice = 
self._get_most_recent_record(most_recent_record_from_slice, stream_data, _slice) yield stream_data if self.cursor: - self.cursor.close_slice(stream_slice, most_recent_record_from_slice) + self.cursor.close_slice(_slice, most_recent_record_from_slice) return def _get_most_recent_record( @@ -356,7 +356,7 @@ def _extract_record(stream_data: StreamData, stream_slice: StreamSlice) -> Optio return None # stream_slices is defined with arguments on http stream and fixing this has a long tail of dependencies. Will be resolved by the decoupling of http stream and simple retriever - def stream_slices(self) -> Iterable[Optional[Mapping[str, Any]]]: # type: ignore + def stream_slices(self) -> Iterable[Optional[StreamSlice]]: # type: ignore """ Specifies the slices for this stream. See the stream slicing section of the docs for more information. @@ -382,7 +382,7 @@ def _parse_records( response: Optional[requests.Response], stream_state: Mapping[str, Any], records_schema: Mapping[str, Any], - stream_slice: Optional[Mapping[str, Any]], + stream_slice: Optional[StreamSlice], ) -> Iterable[StreamData]: yield from self._parse_response( response, @@ -412,11 +412,11 @@ def __post_init__(self, options: Mapping[str, Any]) -> None: ) # stream_slices is defined with arguments on http stream and fixing this has a long tail of dependencies. Will be resolved by the decoupling of http stream and simple retriever - def stream_slices(self) -> Iterable[Optional[Mapping[str, Any]]]: # type: ignore + def stream_slices(self) -> Iterable[Optional[StreamSlice]]: # type: ignore return islice(super().stream_slices(), self.maximum_number_of_slices) def _fetch_next_page( - self, stream_state: Mapping[str, Any], stream_slice: Mapping[str, Any], next_page_token: Optional[Mapping[str, Any]] = None + self, stream_state: Mapping[str, Any], stream_slice: StreamSlice, next_page_token: Optional[Mapping[str, Any]] = None ) -> Optional[requests.Response]: return self.requester.send_request( path=self._paginator_path(), diff --git a/airbyte-cdk/python/airbyte_cdk/sources/declarative/types.py b/airbyte-cdk/python/airbyte_cdk/sources/declarative/types.py index fd0eba51676c..734ff29fffb7 100644 --- a/airbyte-cdk/python/airbyte_cdk/sources/declarative/types.py +++ b/airbyte-cdk/python/airbyte_cdk/sources/declarative/types.py @@ -4,14 +4,13 @@ from __future__ import annotations -from typing import Any, List, Mapping, Optional +from typing import Any, ItemsView, Iterator, KeysView, List, Mapping, Optional, ValuesView # A FieldPointer designates a path to a field inside a mapping. 
For example, retrieving ["k1", "k1.2"] in the object {"k1" :{"k1.2": # "hello"}] returns "hello" FieldPointer = List[str] Config = Mapping[str, Any] ConnectionDefinition = Mapping[str, Any] -StreamSlice = Mapping[str, Any] StreamState = Mapping[str, Any] @@ -51,3 +50,67 @@ def __eq__(self, other: object) -> bool: def __ne__(self, other: object) -> bool: return not self.__eq__(other) + + +class StreamSlice(Mapping[str, Any]): + def __init__(self, *, partition: Mapping[str, Any], cursor_slice: Mapping[str, Any]) -> None: + self._partition = partition + self._cursor_slice = cursor_slice + if partition.keys() & cursor_slice.keys(): + raise ValueError("Keys for partition and incremental sync cursor should not overlap") + self._stream_slice = dict(partition) | dict(cursor_slice) + + @property + def partition(self) -> Mapping[str, Any]: + p = self._partition + while isinstance(p, StreamSlice): + p = p.partition + return p + + @property + def cursor_slice(self) -> Mapping[str, Any]: + c = self._cursor_slice + while isinstance(c, StreamSlice): + c = c.cursor_slice + return c + + def __repr__(self) -> str: + return repr(self._stream_slice) + + def __setitem__(self, key: str, value: Any) -> None: + raise ValueError("StreamSlice is immutable") + + def __getitem__(self, key: str) -> Any: + return self._stream_slice[key] + + def __len__(self) -> int: + return len(self._stream_slice) + + def __iter__(self) -> Iterator[str]: + return iter(self._stream_slice) + + def __contains__(self, item: Any) -> bool: + return item in self._stream_slice + + def keys(self) -> KeysView[str]: + return self._stream_slice.keys() + + def items(self) -> ItemsView[str, Any]: + return self._stream_slice.items() + + def values(self) -> ValuesView[Any]: + return self._stream_slice.values() + + def get(self, key: str, default: Any = None) -> Optional[Any]: + return self._stream_slice.get(key, default) + + def __eq__(self, other: Any) -> bool: + if isinstance(other, dict): + return self._stream_slice == other + if isinstance(other, StreamSlice): + # noinspection PyProtectedMember + return self._partition == other._partition and self._cursor_slice == other._cursor_slice + return False + + def __ne__(self, other: Any) -> bool: + return not self.__eq__(other) diff --git a/airbyte-cdk/python/unit_tests/sources/declarative/datetime/test_min_max_datetime.py b/airbyte-cdk/python/unit_tests/sources/declarative/datetime/test_min_max_datetime.py index b23f6e2fffe9..84a63969cec6 100644 --- a/airbyte-cdk/python/unit_tests/sources/declarative/datetime/test_min_max_datetime.py +++ b/airbyte-cdk/python/unit_tests/sources/declarative/datetime/test_min_max_datetime.py @@ -6,6 +6,7 @@ import pytest from airbyte_cdk.sources.declarative.datetime.min_max_datetime import MinMaxDatetime +from airbyte_cdk.sources.declarative.interpolation.interpolated_string import InterpolatedString date_format = "%Y-%m-%dT%H:%M:%S.%f%z" @@ -110,3 +111,17 @@ def test_min_max_datetime_lazy_eval(): assert datetime.datetime(2021, 1, 1, 0, 0, tzinfo=datetime.timezone.utc) == MinMaxDatetime( **kwargs, parameters={"max_datetime": "2021-01-01T00:00:00"} ).get_datetime({}) + + +@pytest.mark.parametrize( + "input_datetime", [ + pytest.param("2022-01-01T00:00:00", id="test_create_min_max_datetime_from_string"), + pytest.param(InterpolatedString.create("2022-01-01T00:00:00", parameters={}), id="test_create_min_max_datetime_from_interpolated_string"), + pytest.param(MinMaxDatetime("2022-01-01T00:00:00", parameters={}), id="test_create_min_max_datetime_from_minmaxdatetime") + ] +) +def
test_create_min_max_datetime(input_datetime): + minMaxDatetime = MinMaxDatetime.create(input_datetime, parameters={}) + expected_value = "2022-01-01T00:00:00" + + assert minMaxDatetime.datetime.eval(config={}) == expected_value diff --git a/airbyte-cdk/python/unit_tests/sources/declarative/incremental/test_datetime_based_cursor.py b/airbyte-cdk/python/unit_tests/sources/declarative/incremental/test_datetime_based_cursor.py index c128f04f391d..6d93dd50c764 100644 --- a/airbyte-cdk/python/unit_tests/sources/declarative/incremental/test_datetime_based_cursor.py +++ b/airbyte-cdk/python/unit_tests/sources/declarative/incremental/test_datetime_based_cursor.py @@ -10,7 +10,7 @@ from airbyte_cdk.sources.declarative.incremental import DatetimeBasedCursor from airbyte_cdk.sources.declarative.interpolation.interpolated_string import InterpolatedString from airbyte_cdk.sources.declarative.requesters.request_option import RequestOption, RequestOptionType -from airbyte_cdk.sources.declarative.types import Record +from airbyte_cdk.sources.declarative.types import Record, StreamSlice datetime_format = "%Y-%m-%dT%H:%M:%S.%f%z" cursor_granularity = "PT0.000001S" @@ -343,35 +343,35 @@ def test_stream_slices( ( "test_close_slice_previous_cursor_is_highest", "2023-01-01", - {"end_time": "2022-01-01"}, + StreamSlice(partition={}, cursor_slice={"end_time": "2022-01-01"}), {cursor_field: "2021-01-01"}, {cursor_field: "2023-01-01"}, ), ( "test_close_slice_stream_slice_partition_end_is_highest", "2021-01-01", - {"end_time": "2023-01-01"}, + StreamSlice(partition={}, cursor_slice={"end_time": "2023-01-01"}), {cursor_field: "2021-01-01"}, {cursor_field: "2023-01-01"}, ), ( "test_close_slice_latest_record_cursor_value_is_highest", "2021-01-01", - {"end_time": "2022-01-01"}, + StreamSlice(partition={}, cursor_slice={"end_time": "2022-01-01"}), {cursor_field: "2023-01-01"}, {cursor_field: "2023-01-01"}, ), ( "test_close_slice_without_latest_record", "2021-01-01", - {"end_time": "2022-01-01"}, + StreamSlice(partition={}, cursor_slice={"end_time": "2022-01-01"}), None, {cursor_field: "2022-01-01"}, ), ( "test_close_slice_without_cursor", None, - {"end_time": "2022-01-01"}, + StreamSlice(partition={}, cursor_slice={"end_time": "2022-01-01"}), {cursor_field: "2023-01-01"}, {cursor_field: "2023-01-01"}, ), @@ -391,6 +391,19 @@ def test_close_slice(test_name, previous_cursor, stream_slice, latest_record_dat assert updated_state == expected_state +def test_close_slice_fails_if_slice_has_a_partition(): + cursor = DatetimeBasedCursor( + start_datetime=MinMaxDatetime(datetime="2021-01-01T00:00:00.000000+0000", parameters={}), + cursor_field=InterpolatedString(string=cursor_field, parameters={}), + datetime_format="%Y-%m-%d", + config=config, + parameters={}, + ) + stream_slice = StreamSlice(partition={"key": "value"}, cursor_slice={"end_time": "2022-01-01"}) + with pytest.raises(ValueError): + cursor.close_slice(stream_slice, Record({"id": 1}, stream_slice)) + + def test_given_different_format_and_slice_is_highest_when_close_slice_then_slice_datetime_format(): cursor = DatetimeBasedCursor( start_datetime=MinMaxDatetime(datetime="2021-01-01T00:00:00.000000+0000", parameters={}), @@ -401,7 +414,7 @@ def test_given_different_format_and_slice_is_highest_when_close_slice_then_slice parameters={}, ) - _slice = {"end_time": "2023-01-04T17:30:19.000Z"} + _slice = StreamSlice(partition={}, cursor_slice={"end_time": "2023-01-04T17:30:19.000Z"}) record_cursor_value = "2023-01-03" cursor.close_slice(_slice, Record({cursor_field: 
record_cursor_value}, _slice)) @@ -418,7 +431,7 @@ def test_given_partition_end_is_specified_and_greater_than_record_when_close_sli config=config, parameters={}, ) - stream_slice = {partition_field_end: "2025-01-01"} + stream_slice = StreamSlice(partition={}, cursor_slice={partition_field_end: "2025-01-01"}) cursor.close_slice(stream_slice, Record({cursor_field: "2020-01-01"}, stream_slice)) updated_state = cursor.get_stream_state() assert {cursor_field: "2025-01-01"} == updated_state @@ -489,6 +502,31 @@ def test_request_option(test_name, inject_into, field_name, expected_req_params, assert expected_body_data == slicer.get_request_body_data(stream_slice=stream_slice) +@pytest.mark.parametrize( + "stream_slice", [ + pytest.param(None, id="test_none_stream_slice"), + pytest.param({}, id="test_empty_stream_slice"), + ] +) +def test_request_option_with_empty_stream_slice(stream_slice): + start_request_option = RequestOption(inject_into=RequestOptionType.request_parameter, parameters={}, field_name="starttime") + end_request_option = RequestOption(inject_into=RequestOptionType.request_parameter, parameters={}, field_name="endtime") + slicer = DatetimeBasedCursor( + start_datetime=MinMaxDatetime(datetime="2021-01-01T00:00:00.000000+0000", parameters={}), + end_datetime=MinMaxDatetime(datetime="2021-01-10T00:00:00.000000+0000", parameters={}), + step="P1D", + cursor_field=InterpolatedString(string=cursor_field, parameters={}), + datetime_format=datetime_format, + cursor_granularity=cursor_granularity, + lookback_window=InterpolatedString(string="P0D", parameters={}), + start_time_option=start_request_option, + end_time_option=end_request_option, + config=config, + parameters={}, + ) + assert {} == slicer.get_request_params(stream_slice=stream_slice) + + @pytest.mark.parametrize( "test_name, input_date, date_format, date_format_granularity, expected_output_date", [ diff --git a/airbyte-cdk/python/unit_tests/sources/declarative/incremental/test_per_partition_cursor.py b/airbyte-cdk/python/unit_tests/sources/declarative/incremental/test_per_partition_cursor.py index cb7857c9352a..769f3e073fcc 100644 --- a/airbyte-cdk/python/unit_tests/sources/declarative/incremental/test_per_partition_cursor.py +++ b/airbyte-cdk/python/unit_tests/sources/declarative/incremental/test_per_partition_cursor.py @@ -7,11 +7,7 @@ import pytest from airbyte_cdk.sources.declarative.incremental.cursor import Cursor -from airbyte_cdk.sources.declarative.incremental.per_partition_cursor import ( - PerPartitionCursor, - PerPartitionKeySerializer, - PerPartitionStreamSlice, -) +from airbyte_cdk.sources.declarative.incremental.per_partition_cursor import PerPartitionCursor, PerPartitionKeySerializer, StreamSlice from airbyte_cdk.sources.declarative.stream_slicers.stream_slicer import StreamSlicer from airbyte_cdk.sources.declarative.types import Record @@ -78,13 +74,13 @@ def test_given_tuples_in_json_then_deserialization_convert_to_list(): def test_stream_slice_merge_dictionaries(): - stream_slice = PerPartitionStreamSlice({"partition key": "partition value"}, {"cursor key": "cursor value"}) + stream_slice = StreamSlice(partition={"partition key": "partition value"}, cursor_slice={"cursor key": "cursor value"}) assert stream_slice == {"partition key": "partition value", "cursor key": "cursor value"} def test_overlapping_slice_keys_raise_error(): with pytest.raises(ValueError): - PerPartitionStreamSlice({"overlapping key": "partition value"}, {"overlapping key": "cursor value"}) + StreamSlice(partition={"overlapping key":
"partition value"}, cursor_slice={"overlapping key": "cursor value"}) class MockedCursorBuilder: @@ -131,7 +127,7 @@ def test_given_no_partition_when_stream_slices_then_no_slices(mocked_cursor_fact def test_given_partition_router_without_state_has_one_partition_then_return_one_slice_per_cursor_slice( mocked_cursor_factory, mocked_partition_router ): - partition = {"partition_field_1": "a value", "partition_field_2": "another value"} + partition = StreamSlice(partition={"partition_field_1": "a value", "partition_field_2": "another value"}, cursor_slice={}) mocked_partition_router.stream_slices.return_value = [partition] cursor_slices = [{"start_datetime": 1}, {"start_datetime": 2}] mocked_cursor_factory.create.return_value = MockedCursorBuilder().with_stream_slices(cursor_slices).build() @@ -139,19 +135,19 @@ def test_given_partition_router_without_state_has_one_partition_then_return_one_ slices = cursor.stream_slices() - assert list(slices) == [PerPartitionStreamSlice(partition, cursor_slice) for cursor_slice in cursor_slices] + assert list(slices) == [StreamSlice(partition=partition, cursor_slice=cursor_slice) for cursor_slice in cursor_slices] def test_given_partition_associated_with_state_when_stream_slices_then_do_not_recreate_cursor( mocked_cursor_factory, mocked_partition_router ): - partition = {"partition_field_1": "a value", "partition_field_2": "another value"} + partition = StreamSlice(partition={"partition_field_1": "a value", "partition_field_2": "another value"}, cursor_slice={}) mocked_partition_router.stream_slices.return_value = [partition] cursor_slices = [{"start_datetime": 1}] mocked_cursor_factory.create.return_value = MockedCursorBuilder().with_stream_slices(cursor_slices).build() cursor = PerPartitionCursor(mocked_cursor_factory, mocked_partition_router) - cursor.set_initial_state({"states": [{"partition": partition, "cursor": CURSOR_STATE}]}) + cursor.set_initial_state({"states": [{"partition": partition.partition, "cursor": CURSOR_STATE}]}) mocked_cursor_factory.create.assert_called_once() slices = list(cursor.stream_slices()) @@ -161,7 +157,7 @@ def test_given_partition_associated_with_state_when_stream_slices_then_do_not_re def test_given_multiple_partitions_then_each_have_their_state(mocked_cursor_factory, mocked_partition_router): first_partition = {"first_partition_key": "first_partition_value"} - mocked_partition_router.stream_slices.return_value = [first_partition, {"second_partition_key": "second_partition_value"}] + mocked_partition_router.stream_slices.return_value = [StreamSlice(partition=first_partition, cursor_slice={}), StreamSlice(partition={"second_partition_key": "second_partition_value"}, cursor_slice={})] first_cursor = MockedCursorBuilder().with_stream_slices([{CURSOR_SLICE_FIELD: "first slice cursor value"}]).build() second_cursor = MockedCursorBuilder().with_stream_slices([{CURSOR_SLICE_FIELD: "second slice cursor value"}]).build() mocked_cursor_factory.create.side_effect = [first_cursor, second_cursor] @@ -173,10 +169,10 @@ def test_given_multiple_partitions_then_each_have_their_state(mocked_cursor_fact first_cursor.stream_slices.assert_called_once() second_cursor.stream_slices.assert_called_once() assert slices == [ - PerPartitionStreamSlice( + StreamSlice( partition={"first_partition_key": "first_partition_value"}, cursor_slice={CURSOR_SLICE_FIELD: "first slice cursor value"} ), - PerPartitionStreamSlice( + StreamSlice( partition={"second_partition_key": "second_partition_value"}, cursor_slice={CURSOR_SLICE_FIELD: "second slice cursor 
value"} ), ] @@ -187,7 +183,7 @@ def test_given_stream_slices_when_get_stream_state_then_return_updated_state(moc MockedCursorBuilder().with_stream_state({CURSOR_STATE_KEY: "first slice cursor value"}).build(), MockedCursorBuilder().with_stream_state({CURSOR_STATE_KEY: "second slice cursor value"}).build(), ] - mocked_partition_router.stream_slices.return_value = [{"partition key": "first partition"}, {"partition key": "second partition"}] + mocked_partition_router.stream_slices.return_value = [StreamSlice(partition={"partition key": "first partition"}, cursor_slice={}), StreamSlice(partition={"partition key": "second partition"}, cursor_slice={})] cursor = PerPartitionCursor(mocked_cursor_factory, mocked_partition_router) list(cursor.stream_slices()) assert cursor.get_stream_state() == { @@ -201,7 +197,7 @@ def test_given_stream_slices_when_get_stream_state_then_return_updated_state(moc def test_when_get_stream_state_then_delegate_to_underlying_cursor(mocked_cursor_factory, mocked_partition_router): underlying_cursor = MockedCursorBuilder().with_stream_slices([{CURSOR_SLICE_FIELD: "first slice cursor value"}]).build() mocked_cursor_factory.create.side_effect = [underlying_cursor] - mocked_partition_router.stream_slices.return_value = [{"partition key": "first partition"}] + mocked_partition_router.stream_slices.return_value = [StreamSlice(partition={"partition key": "first partition"}, cursor_slice={})] cursor = PerPartitionCursor(mocked_cursor_factory, mocked_partition_router) first_slice = list(cursor.stream_slices())[0] @@ -213,8 +209,8 @@ def test_when_get_stream_state_then_delegate_to_underlying_cursor(mocked_cursor_ def test_close_slice(mocked_cursor_factory, mocked_partition_router): underlying_cursor = MockedCursorBuilder().with_stream_slices([{CURSOR_SLICE_FIELD: "first slice cursor value"}]).build() mocked_cursor_factory.create.side_effect = [underlying_cursor] - stream_slice = PerPartitionStreamSlice(partition={"partition key": "first partition"}, cursor_slice={}) - mocked_partition_router.stream_slices.return_value = [stream_slice.partition] + stream_slice = StreamSlice(partition={"partition key": "first partition"}, cursor_slice={}) + mocked_partition_router.stream_slices.return_value = [stream_slice] cursor = PerPartitionCursor(mocked_cursor_factory, mocked_partition_router) last_record = Mock() list(cursor.stream_slices()) # generate internal state @@ -227,8 +223,8 @@ def test_close_slice(mocked_cursor_factory, mocked_partition_router): def test_given_no_last_record_when_close_slice_then_do_not_raise_error(mocked_cursor_factory, mocked_partition_router): underlying_cursor = MockedCursorBuilder().with_stream_slices([{CURSOR_SLICE_FIELD: "first slice cursor value"}]).build() mocked_cursor_factory.create.side_effect = [underlying_cursor] - stream_slice = PerPartitionStreamSlice(partition={"partition key": "first partition"}, cursor_slice={}) - mocked_partition_router.stream_slices.return_value = [stream_slice.partition] + stream_slice = StreamSlice(partition={"partition key": "first partition"}, cursor_slice={}) + mocked_partition_router.stream_slices.return_value = [stream_slice] cursor = PerPartitionCursor(mocked_cursor_factory, mocked_partition_router) list(cursor.stream_slices()) # generate internal state @@ -241,7 +237,7 @@ def test_given_unknown_partition_when_close_slice_then_raise_error(): any_cursor_factory = Mock() any_partition_router = Mock() cursor = PerPartitionCursor(any_cursor_factory, any_partition_router) - stream_slice = 
     with pytest.raises(ValueError):
         cursor.close_slice(stream_slice, Record({}, stream_slice))
 
@@ -251,7 +247,7 @@ def test_given_unknown_partition_when_should_be_synced_then_raise_error():
     any_partition_router = Mock()
     cursor = PerPartitionCursor(any_cursor_factory, any_partition_router)
     with pytest.raises(ValueError):
-        cursor.should_be_synced(Record({}, PerPartitionStreamSlice(partition={"unknown_partition": "unknown"}, cursor_slice={})))
+        cursor.should_be_synced(Record({}, StreamSlice(partition={"unknown_partition": "unknown"}, cursor_slice={})))
 
 
 def test_given_records_with_different_slice_when_is_greater_than_or_equal_then_raise_error():
@@ -260,8 +256,26 @@ def test_given_records_with_different_slice_when_is_greater_than_or_equal_then_r
     cursor = PerPartitionCursor(any_cursor_factory, any_partition_router)
     with pytest.raises(ValueError):
         cursor.is_greater_than_or_equal(
-            Record({}, PerPartitionStreamSlice(partition={"a slice": "value"}, cursor_slice={})),
-            Record({}, PerPartitionStreamSlice(partition={"another slice": "value"}, cursor_slice={})),
+            Record({}, StreamSlice(partition={"a slice": "value"}, cursor_slice={})),
+            Record({}, StreamSlice(partition={"another slice": "value"}, cursor_slice={})),
+        )
+
+
+@pytest.mark.parametrize(
+    "first_record_slice, second_record_slice",
+    [
+        pytest.param(StreamSlice(partition={"a slice": "value"}, cursor_slice={}), None, id="second record does not have a slice"),
+        pytest.param(None, StreamSlice(partition={"a slice": "value"}, cursor_slice={}), id="first record does not have a slice"),
+    ]
+)
+def test_given_records_without_a_slice_when_is_greater_than_or_equal_then_raise_error(first_record_slice, second_record_slice):
+    any_cursor_factory = Mock()
+    any_partition_router = Mock()
+    cursor = PerPartitionCursor(any_cursor_factory, any_partition_router)
+    with pytest.raises(ValueError):
+        cursor.is_greater_than_or_equal(
+            Record({}, first_record_slice),
+            Record({}, second_record_slice)
         )
@@ -271,16 +285,16 @@ def test_given_slice_is_unknown_when_is_greater_than_or_equal_then_raise_error()
     cursor = PerPartitionCursor(any_cursor_factory, any_partition_router)
     with pytest.raises(ValueError):
         cursor.is_greater_than_or_equal(
-            Record({}, PerPartitionStreamSlice(partition={"a slice": "value"}, cursor_slice={})),
-            Record({}, PerPartitionStreamSlice(partition={"a slice": "value"}, cursor_slice={})),
+            Record({}, StreamSlice(partition={"a slice": "value"}, cursor_slice={})),
+            Record({}, StreamSlice(partition={"a slice": "value"}, cursor_slice={})),
         )
 
 
 def test_when_is_greater_than_or_equal_then_return_underlying_cursor_response(mocked_cursor_factory, mocked_partition_router):
     underlying_cursor = MockedCursorBuilder().with_stream_slices([{CURSOR_SLICE_FIELD: "first slice cursor value"}]).build()
     mocked_cursor_factory.create.side_effect = [underlying_cursor]
-    stream_slice = PerPartitionStreamSlice(partition={"partition key": "first partition"}, cursor_slice={})
-    mocked_partition_router.stream_slices.return_value = [stream_slice.partition]
+    stream_slice = StreamSlice(partition={"partition key": "first partition"}, cursor_slice={})
+    mocked_partition_router.stream_slices.return_value = [stream_slice]
     cursor = PerPartitionCursor(mocked_cursor_factory, mocked_partition_router)
     first_record = Record({"first": "value"}, stream_slice)
     second_record = Record({"second": "value"}, stream_slice)
@@ -290,3 +304,103 @@ def test_when_is_greater_than_or_equal_then_return_underlying_cursor_response(mo
 
     assert result == underlying_cursor.is_greater_than_or_equal.return_value
     underlying_cursor.is_greater_than_or_equal.assert_called_once_with(first_record, second_record)
+
+
+@pytest.mark.parametrize(
+    "stream_slice, expected_output",
+    [
+        pytest.param(StreamSlice(partition={"partition key": "first partition"}, cursor_slice={}), {"cursor": "params", "router": "params"}, id="first partition"),
+        pytest.param(None, None, id="no stream slice"),
+    ]
+)
+def test_get_request_params(mocked_cursor_factory, mocked_partition_router, stream_slice, expected_output):
+    underlying_cursor = MockedCursorBuilder().with_stream_slices([{CURSOR_SLICE_FIELD: "first slice cursor value"}]).build()
+    underlying_cursor.get_request_params.return_value = {"cursor": "params"}
+    mocked_cursor_factory.create.side_effect = [underlying_cursor]
+    mocked_partition_router.stream_slices.return_value = [stream_slice]
+    mocked_partition_router.get_request_params.return_value = {"router": "params"}
+    cursor = PerPartitionCursor(mocked_cursor_factory, mocked_partition_router)
+    if stream_slice:
+        cursor.set_initial_state({"states": [{"partition": stream_slice.partition, "cursor": CURSOR_STATE}]})
+        params = cursor.get_request_params(stream_slice=stream_slice)
+        assert params == expected_output
+        mocked_partition_router.get_request_params.assert_called_once_with(stream_state=None, stream_slice=stream_slice, next_page_token=None)
+        underlying_cursor.get_request_params.assert_called_once_with(stream_state=None, stream_slice={}, next_page_token=None)
+    else:
+        with pytest.raises(ValueError):
+            cursor.get_request_params(stream_slice=stream_slice)
+
+
+@pytest.mark.parametrize(
+    "stream_slice, expected_output",
+    [
+        pytest.param(StreamSlice(partition={"partition key": "first partition"}, cursor_slice={}), {"cursor": "params", "router": "params"}, id="first partition"),
+        pytest.param(None, None, id="no stream slice"),
+    ]
+)
+def test_get_request_headers(mocked_cursor_factory, mocked_partition_router, stream_slice, expected_output):
+    underlying_cursor = MockedCursorBuilder().with_stream_slices([{CURSOR_SLICE_FIELD: "first slice cursor value"}]).build()
+    underlying_cursor.get_request_headers.return_value = {"cursor": "params"}
+    mocked_cursor_factory.create.side_effect = [underlying_cursor]
+    mocked_partition_router.stream_slices.return_value = [stream_slice]
+    mocked_partition_router.get_request_headers.return_value = {"router": "params"}
+    cursor = PerPartitionCursor(mocked_cursor_factory, mocked_partition_router)
+    if stream_slice:
+        cursor.set_initial_state({"states": [{"partition": stream_slice.partition, "cursor": CURSOR_STATE}]})
+        params = cursor.get_request_headers(stream_slice=stream_slice)
+        assert params == expected_output
+        mocked_partition_router.get_request_headers.assert_called_once_with(stream_state=None, stream_slice=stream_slice, next_page_token=None)
+        underlying_cursor.get_request_headers.assert_called_once_with(stream_state=None, stream_slice={}, next_page_token=None)
+    else:
+        with pytest.raises(ValueError):
+            cursor.get_request_headers(stream_slice=stream_slice)
+
+
+@pytest.mark.parametrize(
+    "stream_slice, expected_output",
+    [
+        pytest.param(StreamSlice(partition={"partition key": "first partition"}, cursor_slice={}), {"cursor": "params", "router": "params"}, id="first partition"),
+        pytest.param(None, None, id="no stream slice"),
+    ]
+)
+def test_get_request_body_data(mocked_cursor_factory, mocked_partition_router, stream_slice, expected_output):
+    underlying_cursor = MockedCursorBuilder().with_stream_slices([{CURSOR_SLICE_FIELD: "first slice cursor value"}]).build()
+    underlying_cursor.get_request_body_data.return_value = {"cursor": "params"}
+    mocked_cursor_factory.create.side_effect = [underlying_cursor]
+    mocked_partition_router.stream_slices.return_value = [stream_slice]
+    mocked_partition_router.get_request_body_data.return_value = {"router": "params"}
+    cursor = PerPartitionCursor(mocked_cursor_factory, mocked_partition_router)
+    if stream_slice:
+        cursor.set_initial_state({"states": [{"partition": stream_slice.partition, "cursor": CURSOR_STATE}]})
+        params = cursor.get_request_body_data(stream_slice=stream_slice)
+        assert params == expected_output
+        mocked_partition_router.get_request_body_data.assert_called_once_with(stream_state=None, stream_slice=stream_slice, next_page_token=None)
+        underlying_cursor.get_request_body_data.assert_called_once_with(stream_state=None, stream_slice={}, next_page_token=None)
+    else:
+        with pytest.raises(ValueError):
+            cursor.get_request_body_data(stream_slice=stream_slice)
+
+
+@pytest.mark.parametrize(
+    "stream_slice, expected_output",
+    [
+        pytest.param(StreamSlice(partition={"partition key": "first partition"}, cursor_slice={}), {"cursor": "params", "router": "params"}, id="first partition"),
+        pytest.param(None, None, id="no stream slice"),
+    ]
+)
+def test_get_request_body_json(mocked_cursor_factory, mocked_partition_router, stream_slice, expected_output):
+    underlying_cursor = MockedCursorBuilder().with_stream_slices([{CURSOR_SLICE_FIELD: "first slice cursor value"}]).build()
+    underlying_cursor.get_request_body_json.return_value = {"cursor": "params"}
+    mocked_cursor_factory.create.side_effect = [underlying_cursor]
+    mocked_partition_router.stream_slices.return_value = [stream_slice]
+    mocked_partition_router.get_request_body_json.return_value = {"router": "params"}
+    cursor = PerPartitionCursor(mocked_cursor_factory, mocked_partition_router)
+    if stream_slice:
+        cursor.set_initial_state({"states": [{"partition": stream_slice.partition, "cursor": CURSOR_STATE}]})
+        params = cursor.get_request_body_json(stream_slice=stream_slice)
+        assert params == expected_output
+        mocked_partition_router.get_request_body_json.assert_called_once_with(stream_state=None, stream_slice=stream_slice, next_page_token=None)
+        underlying_cursor.get_request_body_json.assert_called_once_with(stream_state=None, stream_slice={}, next_page_token=None)
+    else:
+        with pytest.raises(ValueError):
+            cursor.get_request_body_json(stream_slice=stream_slice)
diff --git a/airbyte-cdk/python/unit_tests/sources/declarative/incremental/test_per_partition_cursor_integration.py b/airbyte-cdk/python/unit_tests/sources/declarative/incremental/test_per_partition_cursor_integration.py
index ef1f123fd124..e5080f1286a2 100644
--- a/airbyte-cdk/python/unit_tests/sources/declarative/incremental/test_per_partition_cursor_integration.py
+++ b/airbyte-cdk/python/unit_tests/sources/declarative/incremental/test_per_partition_cursor_integration.py
@@ -4,8 +4,9 @@
 
 from unittest.mock import patch
 
-from airbyte_cdk.models import SyncMode
-from airbyte_cdk.sources.declarative.incremental.per_partition_cursor import PerPartitionStreamSlice
+from airbyte_cdk.logger import init_logger
+from airbyte_cdk.models import ConfiguredAirbyteCatalog, SyncMode, Type
+from airbyte_cdk.sources.declarative.incremental.per_partition_cursor import StreamSlice
 from airbyte_cdk.sources.declarative.manifest_declarative_source import ManifestDeclarativeSource
 from airbyte_cdk.sources.declarative.retrievers.simple_retriever import SimpleRetriever
 from airbyte_cdk.sources.declarative.types import Record
@@ -16,19 +17,35 @@
 
 class ManifestBuilder:
     def __init__(self):
-        self._incremental_sync = None
-        self._partition_router = None
+        self._incremental_sync = {}
+        self._partition_router = {}
+        self._substream_partition_router = {}
 
-    def with_list_partition_router(self, cursor_field, partitions):
-        self._partition_router = {
+    def with_list_partition_router(self, stream_name, cursor_field, partitions):
+        self._partition_router[stream_name] = {
             "type": "ListPartitionRouter",
             "cursor_field": cursor_field,
             "values": partitions,
         }
         return self
 
-    def with_incremental_sync(self, start_datetime, end_datetime, datetime_format, cursor_field, step, cursor_granularity):
-        self._incremental_sync = {
+    def with_substream_partition_router(self, stream_name):
+        self._substream_partition_router[stream_name] = {
+            "type": "SubstreamPartitionRouter",
+            "parent_stream_configs": [
+                {
+                    "type": "ParentStreamConfig",
+                    "stream": "#/definitions/Rates",
+                    "parent_key": "id",
+                    "partition_field": "parent_id",
+
+                }
+            ]
+        }
+        return self
+
+    def with_incremental_sync(self, stream_name, start_datetime, end_datetime, datetime_format, cursor_field, step, cursor_granularity):
+        self._incremental_sync[stream_name] = {
             "type": "DatetimeBasedCursor",
             "start_datetime": start_datetime,
             "end_datetime": end_datetime,
@@ -44,8 +61,27 @@ def build(self):
             "version": "0.34.2",
             "type": "DeclarativeSource",
             "check": {"type": "CheckStream", "stream_names": ["Rates"]},
-            "streams": [
-                {
+            "definitions": {
+                "AnotherStream": {
+                    "type": "DeclarativeStream",
+                    "name": "AnotherStream",
+                    "primary_key": [],
+                    "schema_loader": {
+                        "type": "InlineSchemaLoader",
+                        "schema": {"$schema": "http://json-schema.org/schema#", "properties": {"id": {"type": "string"}}, "type": "object"},
+                    },
+                    "retriever": {
+                        "type": "SimpleRetriever",
+                        "requester": {
+                            "type": "HttpRequester",
+                            "url_base": "https://api.apilayer.com",
+                            "path": "/exchangerates_data/latest",
+                            "http_method": "GET",
+                        },
+                        "record_selector": {"type": "RecordSelector", "extractor": {"type": "DpathExtractor", "field_path": []}},
+                    },
+                },
+                "Rates": {
                     "type": "DeclarativeStream",
                     "name": "Rates",
                     "primary_key": [],
@@ -63,7 +99,11 @@ def build(self):
                         },
                         "record_selector": {"type": "RecordSelector", "extractor": {"type": "DpathExtractor", "field_path": []}},
                     },
-                }
+                },
+            },
+            "streams": [
+                {"$ref": "#/definitions/Rates"},
+                {"$ref": "#/definitions/AnotherStream"}
             ],
             "spec": {
                 "connection_specification": {
@@ -77,18 +117,21 @@ def build(self):
                 "type": "Spec",
             },
         }
-        if self._incremental_sync:
-            manifest["streams"][0]["incremental_sync"] = self._incremental_sync
-        if self._partition_router:
-            manifest["streams"][0]["retriever"]["partition_router"] = self._partition_router
+        for stream_name, incremental_sync_definition in self._incremental_sync.items():
+            manifest["definitions"][stream_name]["incremental_sync"] = incremental_sync_definition
+        for stream_name, partition_router_definition in self._partition_router.items():
+            manifest["definitions"][stream_name]["retriever"]["partition_router"] = partition_router_definition
+        for stream_name, partition_router_definition in self._substream_partition_router.items():
+            manifest["definitions"][stream_name]["retriever"]["partition_router"] = partition_router_definition
         return manifest
 
 
 def test_given_state_for_only_some_partition_when_stream_slices_then_create_slices_using_state_or_start_from_start_datetime():
     source = ManifestDeclarativeSource(
         source_config=ManifestBuilder()
-        .with_list_partition_router("partition_field", ["1", "2"])
+        .with_list_partition_router("Rates", "partition_field", ["1", "2"])
         .with_incremental_sync(
+            "Rates",
             start_datetime="2022-01-01",
             end_datetime="2022-02-28",
             datetime_format="%Y-%m-%d",
@@ -123,8 +166,9 @@ def test_given_state_for_only_some_partition_when_stream_slices_then_create_slic
 def test_given_record_for_partition_when_read_then_update_state():
     source = ManifestDeclarativeSource(
         source_config=ManifestBuilder()
-        .with_list_partition_router("partition_field", ["1", "2"])
+        .with_list_partition_router("Rates", "partition_field", ["1", "2"])
         .with_incremental_sync(
+            "Rates",
             start_datetime="2022-01-01",
             end_datetime="2022-02-28",
             datetime_format="%Y-%m-%d",
@@ -137,9 +181,11 @@ def test_given_record_for_partition_when_read_then_update_state():
     stream_instance = source.streams({})[0]
     list(stream_instance.stream_slices(sync_mode=SYNC_MODE))
 
-    stream_slice = PerPartitionStreamSlice({"partition_field": "1"}, {"start_time": "2022-01-01", "end_time": "2022-01-31"})
+    stream_slice = StreamSlice(partition={"partition_field": "1"},
+                               cursor_slice={"start_time": "2022-01-01", "end_time": "2022-01-31"})
     with patch.object(
-        SimpleRetriever, "_read_pages", side_effect=[[Record({"a record key": "a record value", CURSOR_FIELD: "2022-01-15"}, stream_slice)]]
+        SimpleRetriever, "_read_pages",
+        side_effect=[[Record({"a record key": "a record value", CURSOR_FIELD: "2022-01-15"}, stream_slice)]]
     ):
         list(
             stream_instance.read_records(
@@ -158,3 +204,125 @@ def test_given_record_for_partition_when_read_then_update_state():
         }
     ]
 }
+
+
+def test_substream_without_input_state():
+    source = ManifestDeclarativeSource(
+        source_config=ManifestBuilder()
+        .with_substream_partition_router("AnotherStream")
+        .with_incremental_sync(
+            "Rates",
+            start_datetime="2022-01-01",
+            end_datetime="2022-02-28",
+            datetime_format="%Y-%m-%d",
+            cursor_field=CURSOR_FIELD,
+            step="P1M",
+            cursor_granularity="P1D",
+        )
+        .with_incremental_sync(
+            "AnotherStream",
+            start_datetime="2022-01-01",
+            end_datetime="2022-02-28",
+            datetime_format="%Y-%m-%d",
+            cursor_field=CURSOR_FIELD,
+            step="P1M",
+            cursor_granularity="P1D",
+        )
+        .build()
+    )
+
+    stream_instance = source.streams({})[1]
+
+    stream_slice = StreamSlice(partition={"parent_id": "1"},
+                               cursor_slice={"start_time": "2022-01-01", "end_time": "2022-01-31"})
+
+    with patch.object(
+        SimpleRetriever, "_read_pages", side_effect=[[Record({"id": "1", CURSOR_FIELD: "2022-01-15"}, stream_slice)],
+                                                     [Record({"id": "2", CURSOR_FIELD: "2022-01-15"}, stream_slice)]]
+    ):
+        slices = list(stream_instance.stream_slices(sync_mode=SYNC_MODE))
+        assert list(slices) == [
+            StreamSlice(partition={"parent_id": "1", "parent_slice": {}, },
+                        cursor_slice={"start_time": "2022-01-01", "end_time": "2022-01-31"}),
+            StreamSlice(partition={"parent_id": "1", "parent_slice": {}, },
+                        cursor_slice={"start_time": "2022-02-01", "end_time": "2022-02-28"}),
+        ]
+
+
+def test_substream_with_legacy_input_state():
+    source = ManifestDeclarativeSource(
+        source_config=ManifestBuilder()
+        .with_substream_partition_router("AnotherStream")
+        .with_incremental_sync(
+            "Rates",
+            start_datetime="2022-01-01",
+            end_datetime="2022-02-28",
+            datetime_format="%Y-%m-%d",
+            cursor_field=CURSOR_FIELD,
+            step="P1M",
+            cursor_granularity="P1D",
+        )
+        .with_incremental_sync(
"AnotherStream", + start_datetime="2022-01-01", + end_datetime="2022-02-28", + datetime_format="%Y-%m-%d", + cursor_field=CURSOR_FIELD, + step="P1M", + cursor_granularity="P1D", + ) + .build() + ) + + stream_instance = source.streams({})[1] + + input_state = { + "states": [ + { + "partition": {"item_id": "an_item_id", + "parent_slice": {"end_time": "1629640663", "start_time": "1626962264"}, + }, + "cursor": { + "updated_at": "1709058818" + } + } + ] + } + stream_instance.state = input_state + + stream_slice = StreamSlice(partition={"parent_id": "1"}, + cursor_slice={"start_time": "2022-01-01", "end_time": "2022-01-31"}) + + logger = init_logger("airbyte") + configured_catalog = ConfiguredAirbyteCatalog( + streams=[ + { + "stream": {"name": "AnotherStream", "json_schema": {}, "supported_sync_modes": ["incremental"]}, + "sync_mode": "incremental", + "destination_sync_mode": "overwrite", + }, + ] + ) + + with patch.object( + SimpleRetriever, "_read_pages", side_effect=[ + [Record({"id": "1", CURSOR_FIELD: "2022-01-15"}, stream_slice)], + [Record({"parent_id": "1"}, stream_slice)], + [Record({"id": "2", CURSOR_FIELD: "2022-01-15"}, stream_slice)], + [Record({"parent_id": "2", CURSOR_FIELD: "2022-01-15"}, stream_slice)] + ] + ): + messages = list(source.read(logger, {}, configured_catalog, input_state)) + + output_state_message = [message for message in messages if message.type == Type.STATE][0] + + expected_state = {"states": [ + { + "cursor": { + "cursor_field": "2022-01-31" + }, + "partition": {"parent_id": "1", "parent_slice": {}} + } + ]} + + assert output_state_message.state.stream.stream_state == expected_state diff --git a/airbyte-cdk/python/unit_tests/sources/declarative/parsers/test_model_to_component_factory.py b/airbyte-cdk/python/unit_tests/sources/declarative/parsers/test_model_to_component_factory.py index c0eee22f471a..0a5a796566c9 100644 --- a/airbyte-cdk/python/unit_tests/sources/declarative/parsers/test_model_to_component_factory.py +++ b/airbyte-cdk/python/unit_tests/sources/declarative/parsers/test_model_to_component_factory.py @@ -525,23 +525,23 @@ def test_datetime_based_cursor(): assert isinstance(stream_slicer, DatetimeBasedCursor) assert stream_slicer._step == datetime.timedelta(days=10) - assert stream_slicer.cursor_field.string == "created" + assert stream_slicer._cursor_field.string == "created" assert stream_slicer.cursor_granularity == "PT0.000001S" - assert stream_slicer.lookback_window.string == "P5D" + assert stream_slicer._lookback_window.string == "P5D" assert stream_slicer.start_time_option.inject_into == RequestOptionType.request_parameter assert stream_slicer.start_time_option.field_name.eval(config=input_config | {"cursor_field": "updated_at"}) == "since_updated_at" assert stream_slicer.end_time_option.inject_into == RequestOptionType.body_json assert stream_slicer.end_time_option.field_name.eval({}) == "before_created_at" - assert stream_slicer.partition_field_start.eval({}) == "star" - assert stream_slicer.partition_field_end.eval({}) == "en" + assert stream_slicer._partition_field_start.eval({}) == "star" + assert stream_slicer._partition_field_end.eval({}) == "en" - assert isinstance(stream_slicer.start_datetime, MinMaxDatetime) + assert isinstance(stream_slicer._start_datetime, MinMaxDatetime) assert stream_slicer.start_datetime._datetime_format == "%Y-%m-%dT%H:%M:%S.%f%z" assert stream_slicer.start_datetime.datetime.string == "{{ config['start_time'] }}" assert stream_slicer.start_datetime.min_datetime.string == "{{ config['start_time'] + 
day_delta(2) }}" - assert isinstance(stream_slicer.end_datetime, MinMaxDatetime) - assert stream_slicer.end_datetime.datetime.string == "{{ config['end_time'] }}" + assert isinstance(stream_slicer._end_datetime, MinMaxDatetime) + assert stream_slicer._end_datetime.datetime.string == "{{ config['end_time'] }}" def test_stream_with_incremental_and_retriever_with_partition_router(): @@ -636,17 +636,17 @@ def test_stream_with_incremental_and_retriever_with_partition_router(): datetime_stream_slicer = stream.retriever.stream_slicer._cursor_factory.create() assert isinstance(datetime_stream_slicer, DatetimeBasedCursor) - assert isinstance(datetime_stream_slicer.start_datetime, MinMaxDatetime) - assert datetime_stream_slicer.start_datetime.datetime.string == "{{ config['start_time'] }}" - assert isinstance(datetime_stream_slicer.end_datetime, MinMaxDatetime) - assert datetime_stream_slicer.end_datetime.datetime.string == "{{ config['end_time'] }}" + assert isinstance(datetime_stream_slicer._start_datetime, MinMaxDatetime) + assert datetime_stream_slicer._start_datetime.datetime.string == "{{ config['start_time'] }}" + assert isinstance(datetime_stream_slicer._end_datetime, MinMaxDatetime) + assert datetime_stream_slicer._end_datetime.datetime.string == "{{ config['end_time'] }}" assert datetime_stream_slicer.step == "P10D" - assert datetime_stream_slicer.cursor_field.string == "created" + assert datetime_stream_slicer._cursor_field.string == "created" list_stream_slicer = stream.retriever.stream_slicer._partition_router assert isinstance(list_stream_slicer, ListPartitionRouter) assert list_stream_slicer.values == ["airbyte", "airbyte-cloud"] - assert list_stream_slicer.cursor_field.string == "a_key" + assert list_stream_slicer._cursor_field.string == "a_key" def test_incremental_data_feed(): diff --git a/airbyte-cdk/python/unit_tests/sources/declarative/partition_routers/test_list_partition_router.py b/airbyte-cdk/python/unit_tests/sources/declarative/partition_routers/test_list_partition_router.py index 3a83af1eb714..b98f8f82d0b7 100644 --- a/airbyte-cdk/python/unit_tests/sources/declarative/partition_routers/test_list_partition_router.py +++ b/airbyte-cdk/python/unit_tests/sources/declarative/partition_routers/test_list_partition_router.py @@ -5,6 +5,7 @@ import pytest as pytest from airbyte_cdk.sources.declarative.partition_routers.list_partition_router import ListPartitionRouter from airbyte_cdk.sources.declarative.requesters.request_option import RequestOption, RequestOptionType +from airbyte_cdk.sources.declarative.types import StreamSlice partition_values = ["customer", "store", "subscription"] cursor_field = "owner_resource" @@ -17,17 +18,23 @@ ( ["customer", "store", "subscription"], "owner_resource", - [{"owner_resource": "customer"}, {"owner_resource": "store"}, {"owner_resource": "subscription"}], + [StreamSlice(partition={"owner_resource": "customer"}, cursor_slice={}), + StreamSlice(partition={"owner_resource": "store"}, cursor_slice={}), + StreamSlice(partition={"owner_resource": "subscription"}, cursor_slice={})], ), ( '["customer", "store", "subscription"]', "owner_resource", - [{"owner_resource": "customer"}, {"owner_resource": "store"}, {"owner_resource": "subscription"}], + [StreamSlice(partition={"owner_resource": "customer"}, cursor_slice={}), + StreamSlice(partition={"owner_resource": "store"}, cursor_slice={}), + StreamSlice(partition={"owner_resource": "subscription"}, cursor_slice={})], ), ( '["customer", "store", "subscription"]', "{{ parameters['cursor_field'] }}", - 
[{"owner_resource": "customer"}, {"owner_resource": "store"}, {"owner_resource": "subscription"}], + [StreamSlice(partition={"owner_resource": "customer"}, cursor_slice={}), + StreamSlice(partition={"owner_resource": "store"}, cursor_slice={}), + StreamSlice(partition={"owner_resource": "subscription"}, cursor_slice={})], ), ], ids=[ @@ -40,6 +47,7 @@ def test_list_partition_router(partition_values, cursor_field, expected_slices): slicer = ListPartitionRouter(values=partition_values, cursor_field=cursor_field, config={}, parameters=parameters) slices = [s for s in slicer.stream_slices()] assert slices == expected_slices + assert all(isinstance(s, StreamSlice) for s in slices) @pytest.mark.parametrize( @@ -93,6 +101,22 @@ def test_request_option(request_option, expected_req_params, expected_headers, e assert expected_body_data == partition_router.get_request_body_data(stream_slice=stream_slice) +@pytest.mark.parametrize( + "stream_slice", + [ + pytest.param({}, id="test_request_option_is_empty_if_empty_stream_slice"), + pytest.param({"not the cursor": "value"}, id="test_request_option_is_empty_if_the_stream_slice_does_not_have_cursor_field"), + pytest.param(None, id="test_request_option_is_empty_if_no_stream_slice") + ] +) +def test_request_option_is_empty_if_no_stream_slice(stream_slice): + request_option = RequestOption(inject_into=RequestOptionType.body_data, parameters={}, field_name="owner_resource") + partition_router = ListPartitionRouter( + values=partition_values, cursor_field=cursor_field, config={}, request_option=request_option, parameters={} + ) + assert {} == partition_router.get_request_body_data(stream_slice=stream_slice) + + @pytest.mark.parametrize( "field_name_interpolation, expected_request_params", [ diff --git a/airbyte-cdk/python/unit_tests/sources/declarative/partition_routers/test_single_partition_router.py b/airbyte-cdk/python/unit_tests/sources/declarative/partition_routers/test_single_partition_router.py index 1f9570955038..008283d5dced 100644 --- a/airbyte-cdk/python/unit_tests/sources/declarative/partition_routers/test_single_partition_router.py +++ b/airbyte-cdk/python/unit_tests/sources/declarative/partition_routers/test_single_partition_router.py @@ -3,6 +3,7 @@ # from airbyte_cdk.sources.declarative.partition_routers.single_partition_router import SinglePartitionRouter +from airbyte_cdk.sources.declarative.types import StreamSlice def test(): @@ -10,4 +11,4 @@ def test(): stream_slices = iterator.stream_slices() next_slice = next(stream_slices) - assert next_slice == dict() + assert next_slice == StreamSlice(partition={}, cursor_slice={}) diff --git a/airbyte-cdk/python/unit_tests/sources/declarative/partition_routers/test_substream_partition_router.py b/airbyte-cdk/python/unit_tests/sources/declarative/partition_routers/test_substream_partition_router.py index 618a0fdb23e9..664dcaa73421 100644 --- a/airbyte-cdk/python/unit_tests/sources/declarative/partition_routers/test_substream_partition_router.py +++ b/airbyte-cdk/python/unit_tests/sources/declarative/partition_routers/test_substream_partition_router.py @@ -6,10 +6,11 @@ import pytest as pytest from airbyte_cdk.models import AirbyteMessage, AirbyteRecordMessage, SyncMode, Type +from airbyte_cdk.sources.declarative.declarative_stream import DeclarativeStream +from airbyte_cdk.sources.declarative.incremental.per_partition_cursor import StreamSlice from airbyte_cdk.sources.declarative.partition_routers.substream_partition_router import ParentStreamConfig, SubstreamPartitionRouter from 
 from airbyte_cdk.sources.declarative.types import Record
-from airbyte_cdk.sources.streams.core import Stream
 
 parent_records = [{"id": 1, "data": "data1"}, {"id": 2, "data": "data2"}]
 more_records = [{"id": 10, "data": "data10", "slice": "second_parent"}, {"id": 20, "data": "data20", "slice": "second_parent"}]
@@ -19,10 +20,10 @@
 data_third_parent_slice = []
 all_parent_data = data_first_parent_slice + data_second_parent_slice + data_third_parent_slice
 parent_slices = [{"slice": "first"}, {"slice": "second"}, {"slice": "third"}]
-second_parent_stream_slice = [{"slice": "second_parent"}]
+second_parent_stream_slice = [StreamSlice(partition={"slice": "second_parent"}, cursor_slice={})]
 
 
-class MockStream(Stream):
+class MockStream(DeclarativeStream):
     def __init__(self, slices, records, name):
         self._slices = slices
         self._records = records
@@ -38,8 +39,12 @@ def primary_key(self) -> Optional[Union[str, List[str], List[List[str]]]]:
 
     def stream_slices(
         self, *, sync_mode: SyncMode, cursor_field: List[str] = None, stream_state: Mapping[str, Any] = None
-    ) -> Iterable[Optional[Mapping[str, Any]]]:
-        yield from self._slices
+    ) -> Iterable[Optional[StreamSlice]]:
+        for s in self._slices:
+            if isinstance(s, StreamSlice):
+                yield s
+            else:
+                yield StreamSlice(partition=s, cursor_slice={})
 
     def read_records(
         self,
@@ -100,6 +105,22 @@ def read_records(
             {"parent_slice": {"slice": "second"}, "first_stream_id": 2},
         ],
    ),
+    (
+        [
+            ParentStreamConfig(
+                stream=MockStream([StreamSlice(partition=p, cursor_slice={"start": 0, "end": 1}) for p in parent_slices], all_parent_data, "first_stream"),
+                parent_key="id",
+                partition_field="first_stream_id",
+                parameters={},
+                config={},
+            )
+        ],
+        [
+            {"parent_slice": {"slice": "first"}, "first_stream_id": 0},
+            {"parent_slice": {"slice": "first"}, "first_stream_id": 1},
+            {"parent_slice": {"slice": "second"}, "first_stream_id": 2},
+        ],
+    ),
     (
         [
             ParentStreamConfig(
@@ -164,6 +185,7 @@ def read_records(
         "test_single_parent_slices_with_records",
         "test_with_parent_slices_and_records",
         "test_multiple_parent_streams",
+        "test_cursor_values_are_removed_from_parent_slices",
         "test_missed_parent_key",
         "test_dpath_extraction",
     ],
diff --git a/airbyte-cdk/python/unit_tests/sources/declarative/test_declarative_stream.py b/airbyte-cdk/python/unit_tests/sources/declarative/test_declarative_stream.py
index a5da7e092139..e5e16e66044a 100644
--- a/airbyte-cdk/python/unit_tests/sources/declarative/test_declarative_stream.py
+++ b/airbyte-cdk/python/unit_tests/sources/declarative/test_declarative_stream.py
@@ -4,20 +4,21 @@
 
 from unittest.mock import MagicMock
 
+import pytest
 from airbyte_cdk.models import AirbyteLogMessage, AirbyteMessage, AirbyteTraceMessage, Level, SyncMode, TraceType, Type
 from airbyte_cdk.sources.declarative.declarative_stream import DeclarativeStream
+from airbyte_cdk.sources.declarative.types import StreamSlice
 
 SLICE_NOT_CONSIDERED_FOR_EQUALITY = {}
 
+_name = "stream"
+_primary_key = "pk"
+_cursor_field = "created_at"
+_json_schema = {"name": {"type": "string"}}
 
-def test_declarative_stream():
-    name = "stream"
-    primary_key = "pk"
-    cursor_field = "created_at"
-    schema_loader = MagicMock()
-    json_schema = {"name": {"type": "string"}}
-    schema_loader.get_json_schema.return_value = json_schema
 
+def test_declarative_stream():
+    schema_loader = _schema_loader()
     state = MagicMock()
 
     records = [
@@ -27,9 +28,9 @@
         AirbyteMessage(type=Type.TRACE, trace=AirbyteTraceMessage(type=TraceType.ERROR, emitted_at=12345)),
     ]
     stream_slices = [
-        {"date": "2021-01-01"},
-        {"date": "2021-01-02"},
-        {"date": "2021-01-03"},
+        StreamSlice(partition={}, cursor_slice={"date": "2021-01-01"}),
+        StreamSlice(partition={}, cursor_slice={"date": "2021-01-02"}),
+        StreamSlice(partition={}, cursor_slice={"date": "2021-01-03"}),
     ]
 
     retriever = MagicMock()
@@ -40,8 +41,8 @@ def test_declarative_stream():
     config = {"api_key": "open_sesame"}
 
     stream = DeclarativeStream(
-        name=name,
-        primary_key=primary_key,
+        name=_name,
+        primary_key=_primary_key,
         stream_cursor_field="{{ parameters['cursor_field'] }}",
         schema_loader=schema_loader,
         retriever=retriever,
@@ -49,14 +50,37 @@ def test_declarative_stream():
         parameters={"cursor_field": "created_at"},
     )
 
-    assert stream.name == name
-    assert stream.get_json_schema() == json_schema
+    assert stream.name == _name
+    assert stream.get_json_schema() == _json_schema
     assert stream.state == state
     input_slice = stream_slices[0]
-    assert list(stream.read_records(SyncMode.full_refresh, cursor_field, input_slice, state)) == records
-    assert stream.primary_key == primary_key
-    assert stream.cursor_field == cursor_field
-    assert stream.stream_slices(sync_mode=SyncMode.incremental, cursor_field=cursor_field, stream_state=None) == stream_slices
+    assert list(stream.read_records(SyncMode.full_refresh, _cursor_field, input_slice, state)) == records
+    assert stream.primary_key == _primary_key
+    assert stream.cursor_field == _cursor_field
+    assert stream.stream_slices(sync_mode=SyncMode.incremental, cursor_field=_cursor_field, stream_state=None) == stream_slices
+
+
+def test_read_records_raises_exception_if_stream_slice_is_not_per_partition_stream_slice():
+    schema_loader = _schema_loader()
+
+    retriever = MagicMock()
+    retriever.state = MagicMock()
+    retriever.read_records.return_value = []
+    stream_slice = {"date": "2021-01-01"}
+    retriever.stream_slices.return_value = [stream_slice]
+
+    stream = DeclarativeStream(
+        name=_name,
+        primary_key=_primary_key,
+        stream_cursor_field="{{ parameters['cursor_field'] }}",
+        schema_loader=schema_loader,
+        retriever=retriever,
+        config={},
+        parameters={"cursor_field": "created_at"},
+    )
+
+    with pytest.raises(ValueError):
+        list(stream.read_records(SyncMode.full_refresh, _cursor_field, stream_slice, MagicMock()))
 
 
 def test_state_checkpoint_interval():
@@ -71,3 +95,9 @@ def test_state_checkpoint_interval():
     )
 
     assert stream.state_checkpoint_interval is None
+
+
+def _schema_loader():
+    schema_loader = MagicMock()
+    schema_loader.get_json_schema.return_value = _json_schema
+    return schema_loader
diff --git a/airbyte-cdk/python/unit_tests/sources/declarative/test_types.py b/airbyte-cdk/python/unit_tests/sources/declarative/test_types.py
new file mode 100644
index 000000000000..8fea30c128c7
--- /dev/null
+++ b/airbyte-cdk/python/unit_tests/sources/declarative/test_types.py
@@ -0,0 +1,39 @@
+# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
+
+import pytest
+from airbyte_cdk.sources.declarative.types import StreamSlice
+
+
+@pytest.mark.parametrize(
+    "stream_slice, expected_partition",
+    [
+        pytest.param(StreamSlice(partition={}, cursor_slice={}), {}, id="test_partition_with_empty_partition"),
+        pytest.param(StreamSlice(partition=StreamSlice(partition={}, cursor_slice={}), cursor_slice={}), {}, id="test_partition_nested_empty"),
+        pytest.param(StreamSlice(partition={"key": "value"}, cursor_slice={}), {"key": "value"}, id="test_partition_with_mapping_partition"),
+        pytest.param(StreamSlice(partition={}, cursor_slice={"cursor": "value"}), {}, id="test_partition_with_only_cursor"),
+        pytest.param(StreamSlice(partition=StreamSlice(partition={}, cursor_slice={}), cursor_slice={"cursor": "value"}), {}, id="test_partition_nested_empty_and_cursor_value_mapping"),
+        pytest.param(StreamSlice(partition=StreamSlice(partition={}, cursor_slice={"cursor": "value"}), cursor_slice={}), {}, id="test_partition_nested_empty_and_cursor_value"),
+    ]
+)
+def test_partition(stream_slice, expected_partition):
+    partition = stream_slice.partition
+
+    assert partition == expected_partition
+
+
+@pytest.mark.parametrize(
+    "stream_slice, expected_cursor_slice",
+    [
+        pytest.param(StreamSlice(partition={}, cursor_slice={}), {}, id="test_cursor_slice_with_empty_cursor"),
+        pytest.param(StreamSlice(partition={}, cursor_slice=StreamSlice(partition={}, cursor_slice={})), {}, id="test_cursor_slice_nested_empty"),
+
+        pytest.param(StreamSlice(partition={}, cursor_slice={"key": "value"}), {"key": "value"}, id="test_cursor_slice_with_mapping_cursor_slice"),
+        pytest.param(StreamSlice(partition={"partition": "value"}, cursor_slice={}), {}, id="test_cursor_slice_with_only_partition"),
+        pytest.param(StreamSlice(partition={"partition": "value"}, cursor_slice=StreamSlice(partition={}, cursor_slice={})), {}, id="test_cursor_slice_nested_empty_and_partition_mapping"),
+        pytest.param(StreamSlice(partition=StreamSlice(partition={"partition": "value"}, cursor_slice={}), cursor_slice={}), {}, id="test_cursor_slice_nested_empty_and_partition"),
+    ]
+)
+def test_cursor_slice(stream_slice, expected_cursor_slice):
+    cursor_slice = stream_slice.cursor_slice
+
+    assert cursor_slice == expected_cursor_slice
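The new test_types.py tests pin down the Mapping contract that lets StreamSlice stand in for the plain dicts it replaces throughout this changeset: a slice flattens partition and cursor_slice into a single mapping, rejects overlapping keys, and unwraps nested slices so that .partition and .cursor_slice always return plain mappings. The sketch below only illustrates that contract; the class name SketchStreamSlice is hypothetical, and this is not the CDK's implementation.

# A minimal, hypothetical sketch (not the airbyte_cdk implementation) of the
# Mapping contract exercised by test_types.py above.
from collections.abc import Mapping
from typing import Any


class SketchStreamSlice(Mapping):
    def __init__(self, *, partition: Mapping[str, Any], cursor_slice: Mapping[str, Any]) -> None:
        self._partition = partition
        self._cursor_slice = cursor_slice
        if set(partition.keys()) & set(cursor_slice.keys()):
            # Mirrors test_overlapping_slice_keys_raise_error.
            raise ValueError("Keys for partition and cursor_slice should not overlap")
        # Flatten both parts into the single view used for lookup, iteration and equality.
        self._stream_slice = {**partition, **cursor_slice}

    @property
    def partition(self) -> Mapping[str, Any]:
        p = self._partition
        while isinstance(p, SketchStreamSlice):  # unwrap nested slices, as the test_partition_nested_* cases expect
            p = p._partition
        return p

    @property
    def cursor_slice(self) -> Mapping[str, Any]:
        c = self._cursor_slice
        while isinstance(c, SketchStreamSlice):  # unwrap nested slices, as the test_cursor_slice_nested_* cases expect
            c = c._cursor_slice
        return c

    def __getitem__(self, key: str) -> Any:
        return self._stream_slice[key]

    def __iter__(self):
        return iter(self._stream_slice)

    def __len__(self) -> int:
        return len(self._stream_slice)

    def __eq__(self, other: Any) -> bool:
        # Equal to any mapping with the same flattened content, which is what
        # test_stream_slice_merge_dictionaries asserts against a plain dict.
        if isinstance(other, Mapping):
            return self._stream_slice == dict(other)
        return NotImplemented


# Usage: the merged view compares equal to a plain dict of both parts.
assert SketchStreamSlice(partition={"k": "v"}, cursor_slice={"start": "2022-01-01"}) == {"k": "v", "start": "2022-01-01"}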