Commit
chore: fix or noqa remaining mypy issues (now 100% pass)
aaronsteers committed Dec 3, 2024
1 parent 92e6994 commit 3eff135
Showing 16 changed files with 86 additions and 61 deletions.
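Most of the fixes below follow one pattern: blanket `# type: ignore` comments are narrowed to error-code-scoped ones (`union-attr`, `call-arg`, `arg-type`, and so on), and multi-line expressions are reflowed so each ignore lands on the exact physical line where mypy reports the error, since mypy only matches an ignore against that line. A minimal sketch of the shape (class and attribute names are invented for illustration):

```python
from typing import Optional

class Log:
    message: str

class Message:
    log: Optional[Log]

def first_line(msg: Message) -> str:
    return (
        # `msg.log` is Optional, so mypy reports [union-attr] on this line;
        # an ignore placed on the closing paren below would not be matched.
        msg.log.message.splitlines()[0]  # type: ignore[union-attr]
    )
```

This is why several comments in the hunks below move off closing parentheses and onto the offending call.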
4 changes: 3 additions & 1 deletion airbyte_cdk/cli/source_declarative_manifest/_run.py
@@ -152,7 +152,9 @@ def handle_remote_manifest_command(args: list[str]) -> None:
     )
 
 
-def create_declarative_source(args: list[str]) -> ConcurrentDeclarativeSource:
+def create_declarative_source(
+    args: list[str],
+) -> ConcurrentDeclarativeSource:  # type: ignore [type-arg]
     """Creates the source with the injected config.
 
     This essentially does what other low-code sources do at build time, but at runtime,
8 changes: 4 additions & 4 deletions airbyte_cdk/connector_builder/message_grouper.py
@@ -232,7 +232,7 @@ def _get_message_groups(
                 current_slice_descriptor = self._parse_slice_description(message.log.message)  # type: ignore[union-attr] # AirbyteMessage with MessageType.LOG has log.message
                 current_slice_pages = []
                 at_least_one_page_in_group = False
-            elif message.type == MessageType.LOG and message.log.message.startswith(
+            elif message.type == MessageType.LOG and message.log.message.startswith(  # type: ignore[union-attr] # None doesn't have 'message'
                 SliceLogger.SLICE_LOG_PREFIX
             ):  # type: ignore[union-attr] # AirbyteMessage with MessageType.LOG has log.message
                 # parsing the first slice
@@ -280,7 +280,7 @@ def _get_message_groups(
                 datetime_format_inferrer.accumulate(message.record)
             elif (
                 message.type == MessageType.CONTROL
-                and message.control.type == OrchestratorType.CONNECTOR_CONFIG
+                and message.control.type == OrchestratorType.CONNECTOR_CONFIG  # type: ignore[union-attr] # None doesn't have 'type'
             ):  # type: ignore[union-attr] # AirbyteMessage with MessageType.CONTROL has control.type
                 yield message.control
             elif message.type == MessageType.STATE:
@@ -310,8 +310,8 @@ def _need_to_close_page(
             and message.type == MessageType.LOG
             and (
                 MessageGrouper._is_page_http_request(json_message)
-                or message.log.message.startswith("slice:")
-            )  # type: ignore[union-attr] # AirbyteMessage with MessageType.LOG has log.message
+                or message.log.message.startswith("slice:")  # type: ignore[union-attr] # AirbyteMessage with MessageType.LOG has log.message
+            )
         )
     )
 
     @staticmethod
4 changes: 2 additions & 2 deletions airbyte_cdk/destinations/vector_db_based/embedder.py
@@ -107,7 +107,7 @@ def embedding_dimensions(self) -> int:
 class OpenAIEmbedder(BaseOpenAIEmbedder):
     def __init__(self, config: OpenAIEmbeddingConfigModel, chunk_size: int):
         super().__init__(
-            OpenAIEmbeddings(
+            OpenAIEmbeddings(  # type: ignore [call-arg]
                 openai_api_key=config.openai_key, max_retries=15, disallowed_special=()
             ),
             chunk_size,
@@ -118,7 +118,7 @@ class AzureOpenAIEmbedder(BaseOpenAIEmbedder):
     def __init__(self, config: AzureOpenAIEmbeddingConfigModel, chunk_size: int):
         # Azure OpenAI API has — as of 20230927 — a limit of 16 documents per request
         super().__init__(
-            OpenAIEmbeddings(
+            OpenAIEmbeddings(  # type: ignore [call-arg]
                 openai_api_key=config.openai_key,
                 chunk_size=16,
                 max_retries=15,
9 changes: 5 additions & 4 deletions airbyte_cdk/entrypoint.py
@@ -248,17 +248,18 @@ def handle_record_counts(
             case Type.RECORD:
                 stream_message_count[
                     HashableStreamDescriptor(
-                        name=message.record.stream, namespace=message.record.namespace
+                        name=message.record.stream,  # type: ignore[union-attr] # record has `stream`
+                        namespace=message.record.namespace,  # type: ignore[union-attr] # record has `namespace`
                     )
-                ] += 1.0  # type: ignore[union-attr] # record has `stream` and `namespace`
+                ] += 1.0
             case Type.STATE:
                 stream_descriptor = message_utils.get_stream_descriptor(message)
 
                 # Set record count from the counter onto the state message
                 message.state.sourceStats = message.state.sourceStats or AirbyteStateStats()  # type: ignore[union-attr] # state has `sourceStats`
-                message.state.sourceStats.recordCount = stream_message_count.get(
+                message.state.sourceStats.recordCount = stream_message_count.get(  # type: ignore[union-attr] # state has `sourceStats`
                     stream_descriptor, 0.0
-                )  # type: ignore[union-attr] # state has `sourceStats`
+                )
 
                 # Reset the counter
                 stream_message_count[stream_descriptor] = 0.0
@@ -111,7 +111,7 @@ def create(
         if isinstance(interpolated_string_or_min_max_datetime, InterpolatedString) or isinstance(
             interpolated_string_or_min_max_datetime, str
         ):
-            return MinMaxDatetime(
+            return MinMaxDatetime(  # type: ignore [call-arg]
                 datetime=interpolated_string_or_min_max_datetime, parameters=parameters
             )
         else:
@@ -133,7 +133,7 @@ def set_initial_state(self, stream_state: StreamState) -> None:
         :param stream_state: The state of the stream as returned by get_stream_state
         """
         self._cursor = (
-            stream_state.get(self.cursor_field.eval(self.config)) if stream_state else None
+            stream_state.get(self.cursor_field.eval(self.config)) if stream_state else None  # type: ignore [union-attr]
         )  # type: ignore # cursor_field is converted to an InterpolatedString in __post_init__
 
     def observe(self, stream_slice: StreamSlice, record: Record) -> None:
@@ -158,7 +158,9 @@ def observe(self, stream_slice: StreamSlice, record: Record) -> None:
             )
             if (
                 self._is_within_daterange_boundaries(
-                    record, stream_slice.get(start_field), stream_slice.get(end_field)
+                    record,
+                    stream_slice.get(start_field),  # type: ignore [arg-type]
+                    stream_slice.get(end_field),  # type: ignore [arg-type]
                 )  # type: ignore # we know that stream_slices for these cursors will use a string representing an unparsed date
                 and is_highest_observed_cursor_value
             ):
@@ -368,7 +370,7 @@ def _get_request_options(
                 self._partition_field_start.eval(self.config)
             )
         if self.end_time_option and self.end_time_option.inject_into == option_type:
-            options[self.end_time_option.field_name.eval(config=self.config)] = stream_slice.get(
+            options[self.end_time_option.field_name.eval(config=self.config)] = stream_slice.get(  # type: ignore [union-attr]
                 self._partition_field_end.eval(self.config)
             )  # type: ignore # field_name is always casted to an interpolated string
         return options
@@ -130,7 +130,7 @@ def _get_request_option(
             if value:
                 params.update(
                     {
-                        parent_config.request_option.field_name.eval(
+                        parent_config.request_option.field_name.eval(  # type: ignore [union-attr]
                             config=self.config
                         ): value
                     }
@@ -162,7 +162,7 @@ def stream_slices(self) -> Iterable[StreamSlice]:
             extra_fields = None
             if parent_stream_config.extra_fields:
                 extra_fields = [
-                    [field_path_part.eval(self.config) for field_path_part in field_path]
+                    [field_path_part.eval(self.config) for field_path_part in field_path]  # type: ignore [union-attr]
                     for field_path in parent_stream_config.extra_fields
                 ]  # type: ignore # extra_fields is always casted to an interpolated string
 
@@ -192,7 +192,10 @@ def stream_slices(self) -> Iterable[StreamSlice]:
                         message=f"Parent stream returned records as invalid type {type(parent_record)}"
                     )
                 try:
-                    partition_value = dpath.get(parent_record, parent_field)
+                    partition_value = dpath.get(
+                        parent_record,  # type: ignore [arg-type]
+                        parent_field,
+                    )
                 except KeyError:
                     continue
 
@@ -228,7 +231,10 @@ def _extract_extra_fields(
         if extra_fields:
            for extra_field_path in extra_fields:
                 try:
-                    extra_field_value = dpath.get(parent_record, extra_field_path)
+                    extra_field_value = dpath.get(
+                        parent_record,  # type: ignore [arg-type]
+                        extra_field_path,
+                    )
                     self.logger.debug(
                         f"Extracted extra_field_path: {extra_field_path} with value: {extra_field_value}"
                     )
@@ -291,7 +297,7 @@ def set_initial_state(self, stream_state: StreamState) -> None:
         if not parent_state and incremental_dependency:
             # Attempt to retrieve child state
             substream_state = list(stream_state.values())
-            substream_state = substream_state[0] if substream_state else {}
+            substream_state = substream_state[0] if substream_state else {}  # type: ignore [assignment] # Incorrect type for assignment
             parent_state = {}
 
             # Copy child state to parent streams with incremental dependencies
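For context on the two `dpath.get` call sites above: `dpath.get` raises `KeyError` when the path is absent, which is why both lookups sit inside `try`/`except KeyError` blocks. A small standalone sketch (the record shape and path are invented):

```python
import dpath  # assumes dpath >= 2.0, where top-level dpath.get is available

parent_record = {"id": 7, "owner": {"team": {"id": "t-42"}}}

try:
    partition_value = dpath.get(parent_record, "owner/team/id")
except KeyError:
    partition_value = None  # absent path: the router skips such records

print(partition_value)  # t-42
```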
@@ -151,16 +151,18 @@ def _create_error_message(self, response: requests.Response) -> Optional[str]:
         :param response: The HTTP response which can be used during interpolation
         :return: The evaluated error message string to be emitted
         """
-        return self.error_message.eval(
+        return self.error_message.eval(  # type: ignore [no-any-return, union-attr]
             self.config, response=self._safe_response_json(response), headers=response.headers
         )  # type: ignore # error_message is always cast to an interpolated string
 
     def _response_matches_predicate(self, response: requests.Response) -> bool:
         return (
             bool(
-                self.predicate.condition
-                and self.predicate.eval(
-                    None, response=self._safe_response_json(response), headers=response.headers
+                self.predicate.condition  # type: ignore [union-attr]
+                and self.predicate.eval(  # type: ignore [union-attr]
+                    None,  # type: ignore [arg-type]
+                    response=self._safe_response_json(response),  # type: ignore [arg-type]
+                    headers=response.headers,
                 )
             )
             if self.predicate
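As `_create_error_message` above shows, one ignore can carry several comma-separated error codes; mypy then suppresses only those codes on that line while every other check stays active. A self-contained illustration (the function and keys are invented):

```python
from typing import Any, Optional

def lookup(payload: dict[str, Any], key: str) -> str:
    nested: Optional[dict[str, Any]] = payload.get("nested")
    # [union-attr]: `nested` may be None; [no-any-return]: the dict value is
    # typed Any but the function promises str (with warn_return_any enabled).
    return nested.get(key)  # type: ignore [no-any-return, union-attr]
```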
@@ -194,7 +194,7 @@ def _get_request_options(self, option_type: RequestOptionType) -> MutableMapping
             and self.pagination_strategy.get_page_size()
             and self.page_size_option.inject_into == option_type
         ):
-            options[self.page_size_option.field_name.eval(config=self.config)] = (
+            options[self.page_size_option.field_name.eval(config=self.config)] = (  # type: ignore [union-attr]
                 self.pagination_strategy.get_page_size()
             )  # type: ignore # field_name is always cast to an interpolated string
         return options
@@ -85,7 +85,7 @@ def _get_request_options(
                 self._partition_field_start.eval(self.config)
             )
         if self.end_time_option and self.end_time_option.inject_into == option_type:
-            options[self.end_time_option.field_name.eval(config=self.config)] = stream_slice.get(
+            options[self.end_time_option.field_name.eval(config=self.config)] = stream_slice.get(  # type: ignore [union-attr]
                 self._partition_field_end.eval(self.config)
             )  # type: ignore # field_name is always casted to an interpolated string
         return options
9 changes: 6 additions & 3 deletions airbyte_cdk/sources/file_based/file_types/excel_parser.py
@@ -69,7 +69,7 @@ async def infer_schema(
                 df = self.open_and_parse_file(fp)
                 for column, df_type in df.dtypes.items():
                     # Choose the broadest data type if the column's data type differs in dataframes
-                    prev_frame_column_type = fields.get(column)
+                    prev_frame_column_type = fields.get(column)  # type: ignore [call-overload]
                     fields[column] = self.dtype_to_json_type(  # type: ignore [index]
                         prev_frame_column_type,
                         df_type,
@@ -139,7 +139,10 @@ def file_read_mode(self) -> FileReadMode:
         return FileReadMode.READ_BINARY
 
     @staticmethod
-    def dtype_to_json_type(current_type: Optional[str], dtype: dtype_) -> str:
+    def dtype_to_json_type(
+        current_type: Optional[str],
+        dtype: dtype_,  # type: ignore [type-arg]
+    ) -> str:
         """
         Convert Pandas DataFrame types to Airbyte Types.
@@ -190,4 +193,4 @@ def open_and_parse_file(fp: Union[IOBase, str, Path]) -> pd.DataFrame:
         Returns:
             pd.DataFrame: Parsed data from the Excel file.
         """
-        return pd.ExcelFile(fp, engine="calamine").parse()  # type: ignore [arg-type]
+        return pd.ExcelFile(fp, engine="calamine").parse()  # type: ignore [arg-type, call-overload, no-any-return]
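The `type-arg` ignore on the `dtype` parameter above fits a generic class used without a type argument; assuming `dtype_` aliases `numpy.dtype` (which matches that error code), the alternative would be spelling it `np.dtype[Any]`. A minimal sketch:

```python
import numpy as np

def describe(dtype: np.dtype) -> str:  # type: ignore [type-arg]
    # Equivalent without the ignore: def describe(dtype: np.dtype[Any]) -> str:
    return "number" if np.issubdtype(dtype, np.number) else "string"

print(describe(np.dtype("int64")))   # number
print(describe(np.dtype("object")))  # string
```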
@@ -21,7 +21,7 @@ class DefaultFileBasedCursor(AbstractFileBasedCursor):
     CURSOR_FIELD = "_ab_source_file_last_modified"
 
     def __init__(self, stream_config: FileBasedStreamConfig, **_: Any):
-        super().__init__(stream_config)
+        super().__init__(stream_config)  # type: ignore [safe-super]
         self._file_to_datetime_history: MutableMapping[str, str] = {}
         self._time_window_if_history_is_full = timedelta(
             days=stream_config.days_to_sync_if_history_is_full
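The `safe-super` ignore above addresses mypy's check against calling an abstract method through `super()` when the parent method has a trivial body; the explicit call is kept for readability. A sketch of the situation (class names simplified):

```python
from abc import ABC, abstractmethod
from typing import Any

class AbstractCursor(ABC):
    @abstractmethod
    def __init__(self, stream_config: Any) -> None:
        ...

class DefaultCursor(AbstractCursor):
    def __init__(self, stream_config: Any) -> None:
        # mypy: "Call to abstract method ... via super() ... with trivial body"
        super().__init__(stream_config)  # type: ignore [safe-super]
        self.stream_config = stream_config
```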
6 changes: 3 additions & 3 deletions airbyte_cdk/sources/streams/core.py
@@ -217,7 +217,7 @@ def read( # type: ignore # ignoring typing for ConnectorStateManager because o
                         # should be fixed on the stream implementation, but we should also protect against this in the CDK as well
                         stream_state_tracker = self.get_updated_state(
                             stream_state_tracker,
-                            record_data,
+                            record_data,  # type: ignore [arg-type]
                         )
                         self._observe_state(checkpoint_reader, stream_state_tracker)
                     record_counter += 1
@@ -277,7 +277,7 @@ def read_only_records(self, state: Optional[Mapping[str, Any]] = None) -> Iterab
             if state
             else {},  # read() expects MutableMapping instead of Mapping which is used more often
             state_manager=None,
-            internal_config=InternalConfig(),
+            internal_config=InternalConfig(),  # type: ignore [call-arg]
         )
 
     @abstractmethod
@@ -653,7 +653,7 @@ def _checkpoint_state( # type: ignore # ignoring typing for ConnectorStateMana
         # todo: This can be consolidated into one ConnectorStateManager.update_and_create_state_message() method, but I want
         # to reduce changes right now and this would span concurrent as well
         state_manager.update_state_for_stream(self.name, self.namespace, stream_state)
-        return state_manager.create_state_message(self.name, self.namespace)
+        return state_manager.create_state_message(self.name, self.namespace)  # type: ignore [no-any-return]
 
     @property
     def configured_json_schema(self) -> Optional[Dict[str, Any]]:
2 changes: 1 addition & 1 deletion airbyte_cdk/sources/streams/http/http.py
@@ -594,7 +594,7 @@ def stream_slices(
                 # Skip non-records (eg AirbyteLogMessage)
                 if isinstance(parent_record, AirbyteMessage):
                     if parent_record.type == MessageType.RECORD:
-                        parent_record = parent_record.record.data
+                        parent_record = parent_record.record.data  # type: ignore [assignment, union-attr] # Incorrect type for assignment
                     else:
                         continue
                 elif isinstance(parent_record, Record):