Skip to content

Commit

Permalink
Improve incomplete type hints (#515)
Browse files Browse the repository at this point in the history
  • Loading branch information
disrupted authored Jul 29, 2024
1 parent 7ffe5d7 commit b565e5e
Show file tree
Hide file tree
Showing 21 changed files with 134 additions and 122 deletions.
17 changes: 0 additions & 17 deletions .github/ruff-matcher.json

This file was deleted.

3 changes: 1 addition & 2 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,7 @@ jobs:
run: |
if [[ "$RUNNER_OS" == "Linux" && "${{ matrix.python-version }}" == "3.10" ]]
then
echo "::add-matcher::.github/ruff-matcher.json"
poetry run ruff check . --config pyproject.toml --output-format text --no-fix
poetry run ruff check . --config pyproject.toml --output-format=github --no-fix
else
poetry run pre-commit run ruff-lint --all-files --show-diff-on-failure
fi;
Expand Down
12 changes: 6 additions & 6 deletions kpops/component_handlers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
from kpops.component_handlers.kafka_connect.kafka_connect_handler import (
Expand All @@ -15,11 +15,11 @@ class ComponentHandlers:

def __new__(
cls,
schema_handler,
connector_handler,
topic_handler,
*args,
**kwargs,
schema_handler: SchemaHandler | None,
connector_handler: KafkaConnectHandler,
topic_handler: TopicHandler,
*args: Any,
**kwargs: Any,
) -> ComponentHandlers:
if not cls._instance:
cls._instance = super().__new__(cls, *args, **kwargs)
Expand Down
2 changes: 1 addition & 1 deletion kpops/component_handlers/helm_wrapper/helm_diff.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def __init__(self, config: HelmDiffConfig) -> None:
def calculate_changes(
current_release: Iterable[HelmTemplate],
new_release: Iterable[HelmTemplate],
) -> Iterator[Change[KubernetesManifest]]:
) -> Iterator[Change[KubernetesManifest, KubernetesManifest]]:
"""Compare 2 releases and generate a Change object for each difference.
:param current_release: Iterable containing HelmTemplate objects for the current release
Expand Down
11 changes: 7 additions & 4 deletions kpops/component_handlers/kafka_connect/connect_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import asyncio
import logging
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any

import httpx

Expand Down Expand Up @@ -46,8 +46,11 @@ async def create_connector(
:param connector_config: The config of the connector
:return: The current connector info if successful.
"""
config_json = connector_config.model_dump()
connect_data = {"name": connector_config.name, "config": config_json}
config_json: dict[str, Any] = connector_config.model_dump()
connect_data: dict[str, Any] = {
"name": connector_config.name,
"config": config_json,
}
response = await self._client.post(
url=f"{self.url}connectors", headers=HEADERS, json=connect_data
)
Expand Down Expand Up @@ -112,7 +115,7 @@ async def update_connector_config(
json=config_json,
)

data: dict = response.json()
data: dict[str, Any] = response.json()
if response.status_code == httpx.codes.OK:
log.info(f"Config for connector {connector_name} updated.")
log.debug(data)
Expand Down
9 changes: 6 additions & 3 deletions kpops/component_handlers/kafka_connect/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from pydantic import (
BaseModel,
ConfigDict,
SerializationInfo,
field_validator,
model_serializer,
)
Expand Down Expand Up @@ -77,8 +76,12 @@ def serialize_topics(self, topics: list[KafkaTopic]) -> str | None:

# TODO(Ivan Yordanov): Currently hacky and potentially unsafe. Find cleaner solution
@model_serializer(mode="wrap", when_used="always")
def serialize_model(self, handler, info: SerializationInfo) -> dict[str, Any]:
result = exclude_by_value(handler(self), None)
def serialize_model(
self,
default_serialize_handler: pydantic.SerializerFunctionWrapHandler,
info: pydantic.SerializationInfo,
) -> dict[str, Any]:
result = exclude_by_value(default_serialize_handler(self), None)
return {by_alias(self, name): value for name, value in result.items()}


Expand Down
7 changes: 4 additions & 3 deletions kpops/component_handlers/topic/handler.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import logging
from typing import Any

from kpops.component_handlers.topic.exception import (
TopicNotFoundException,
Expand Down Expand Up @@ -118,7 +119,7 @@ async def __execute_topic_creation(
)

if differences:
json_body = []
json_body: list[dict[str, str]] = []
for difference in differences:
if difference.diff_type is DiffType.REMOVE:
json_body.append(
Expand Down Expand Up @@ -216,15 +217,15 @@ def __prepare_body(cls, topic: KafkaTopic) -> TopicSpec:
:param topic_config: The topic config
:return: Topic specification
"""
topic_spec_json: dict = topic.config.model_dump(
topic_spec_json: dict[str, Any] = topic.config.model_dump(
include={
"partitions_count": True,
"replication_factor": True,
"configs": True,
},
exclude_none=True,
)
configs = []
configs: list[dict[str, Any]] = []
for config_name, config_value in topic_spec_json["configs"].items():
configs.append({"name": config_name, "value": config_value})
topic_spec_json["configs"] = configs
Expand Down
4 changes: 2 additions & 2 deletions kpops/component_handlers/topic/proxy_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import logging
from functools import cached_property
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any

import httpx

Expand Down Expand Up @@ -163,7 +163,7 @@ async def get_topic_config(self, topic_name: str) -> TopicConfigResponse:
raise KafkaRestProxyError(response)

async def batch_alter_topic_config(
self, topic_name: str, json_body: list[dict]
self, topic_name: str, json_body: list[dict[str, Any]]
) -> None:
"""Reset config of given config_name param to the default value on the Kafka server.
Expand Down
18 changes: 10 additions & 8 deletions kpops/component_handlers/topic/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Any

from kpops.component_handlers.topic.model import (
BrokerConfigResponse,
KafkaTopicConfigSource,
Expand All @@ -6,8 +8,8 @@


def parse_and_compare_topic_configs(
topic_config_in_cluster: TopicConfigResponse, topic_config: dict
) -> tuple[dict, dict]:
topic_config_in_cluster: TopicConfigResponse, topic_config: dict[str, Any]
) -> tuple[dict[str, str], dict[str, Any]]:
comparable_in_cluster_config_dict, default_configs = parse_rest_proxy_topic_config(
topic_config_in_cluster
)
Expand Down Expand Up @@ -36,9 +38,9 @@ def parse_and_compare_topic_configs(


def populate_default_configs(
config_overwrites: set,
default_configs: dict,
config_to_populate: dict,
config_overwrites: set[str],
default_configs: dict[str, str],
config_to_populate: dict[str, str],
description_text: str,
):
for overwrite in config_overwrites:
Expand All @@ -52,9 +54,9 @@ def populate_default_configs(

def parse_rest_proxy_topic_config(
topic_config_in_cluster: TopicConfigResponse,
) -> tuple[dict, dict]:
comparable_in_cluster_config_dict = {}
default_configs = {}
) -> tuple[dict[str, str], dict[str, str]]:
comparable_in_cluster_config_dict: dict[str, str] = {}
default_configs: dict[str, str] = {}
for config in topic_config_in_cluster.data:
if config.source == KafkaTopicConfigSource.DYNAMIC_TOPIC_CONFIG:
comparable_in_cluster_config_dict[config.name] = config.value
Expand Down
8 changes: 4 additions & 4 deletions kpops/components/base_components/base_defaults_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ class BaseDefaultsComponent(DescConfigModel, ABC):

model_config = ConfigDict(
arbitrary_types_allowed=True,
ignored_types=(cached_property, cached_classproperty),
ignored_types=(cached_property, cached_classproperty), # pyright: ignore[reportArgumentType]
)
enrich: SkipJsonSchema[bool] = Field(
default=True,
Expand Down Expand Up @@ -189,7 +189,7 @@ def load_defaults(cls, *defaults_file_paths: Path) -> dict[str, Any]:
"""
defaults: dict[str, Any] = {}
for base in (cls, *cls.parents):
component_type: str = base.type
component_type = base.type
defaults = update_nested(
defaults,
*(
Expand All @@ -204,7 +204,7 @@ def _validate_custom(self) -> None:
"""Run custom validation on component."""


def defaults_from_yaml(path: Path, key: str) -> dict:
def defaults_from_yaml(path: Path, key: str) -> dict[str, Any]:
"""Read component-specific settings from a ``defaults*.yaml`` file and return @default if not found.
:param path: Path to ``defaults*.yaml`` file
Expand Down Expand Up @@ -246,7 +246,7 @@ def get_defaults_file_paths(
:param environment: Optional. The environment for which default configuration files are sought.
:returns: A list of Path objects representing the default configuration file paths.
"""
default_paths = []
default_paths: list[Path] = []

if not pipeline_path.is_file():
msg = f"{pipeline_path} is not a valid pipeline file."
Expand Down
12 changes: 8 additions & 4 deletions kpops/components/base_components/helm_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from typing import Annotated, Any

import pydantic
from pydantic import Field, SerializationInfo, model_serializer
from pydantic import Field, model_serializer
from typing_extensions import override

from kpops.component_handlers.helm_wrapper.dry_run_handler import DryRunHandler
Expand Down Expand Up @@ -53,7 +53,7 @@ class HelmAppValues(KubernetesAppValues):
# BEWARE! All default values are enforced, hard to replicate without
# access to ``model_dump``
@override
def model_dump(self, **_) -> dict[str, Any]:
def model_dump(self, **_: Any) -> dict[str, Any]:
return super().model_dump(
by_alias=True, exclude_none=True, exclude_defaults=True
)
Expand Down Expand Up @@ -208,5 +208,9 @@ def print_helm_diff(self, stdout: str) -> None:
# HACK: workaround for Pydantic to exclude cached properties during model export
# TODO(Ivan Yordanov): Currently hacky and potentially unsafe. Find cleaner solution
@model_serializer(mode="wrap", when_used="always")
def serialize_model(self, handler, info: SerializationInfo) -> dict[str, Any]:
return exclude_by_name(handler(self), "helm", "helm_diff")
def serialize_model(
self,
default_serialize_handler: pydantic.SerializerFunctionWrapHandler,
info: pydantic.SerializationInfo,
) -> dict[str, Any]:
return exclude_by_name(default_serialize_handler(self), "helm", "helm_diff")
11 changes: 7 additions & 4 deletions kpops/components/base_components/kafka_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import logging
from abc import ABC
from collections.abc import Callable
from typing import Any

import pydantic
Expand Down Expand Up @@ -58,7 +57,7 @@ class KafkaStreamsConfig(CamelCaseConfigModel, DescConfigModel):
@pydantic.field_validator("extra_output_topics", mode="before")
@classmethod
def deserialize_extra_output_topics(
cls, extra_output_topics: Any
cls, extra_output_topics: dict[str, str] | Any
) -> dict[str, KafkaTopic] | Any:
if isinstance(extra_output_topics, dict):
return {
Expand All @@ -76,9 +75,13 @@ def serialize_extra_output_topics(
# TODO(Ivan Yordanov): Currently hacky and potentially unsafe. Find cleaner solution
@pydantic.model_serializer(mode="wrap", when_used="always")
def serialize_model(
self, handler: Callable, info: pydantic.SerializationInfo
self,
default_serialize_handler: pydantic.SerializerFunctionWrapHandler,
info: pydantic.SerializationInfo,
) -> dict[str, Any]:
return exclude_defaults(self, exclude_by_value(handler(self), None))
return exclude_defaults(
self, exclude_by_value(default_serialize_handler(self), None)
)


class KafkaAppValues(HelmAppValues):
Expand Down
4 changes: 3 additions & 1 deletion kpops/components/common/topic.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,9 @@ def deduplicate(topics: Iterable[KafkaTopic]) -> list[KafkaTopic]:
return list({topic.name: topic for topic in topics}.values())


def deserialize_kafka_topic_from_str(topic: Any) -> KafkaTopic | dict:
def deserialize_kafka_topic_from_str(
topic: str | dict[str, str] | Any,
) -> KafkaTopic | dict[str, str]:
if topic and isinstance(topic, str):
return KafkaTopic(name=topic)
if isinstance(topic, KafkaTopic | dict):
Expand Down
6 changes: 4 additions & 2 deletions kpops/components/streams_bootstrap/streams/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,15 +55,17 @@ class StreamsConfig(KafkaStreamsConfig):

@pydantic.field_validator("input_topics", mode="before")
@classmethod
def deserialize_input_topics(cls, input_topics: Any) -> list[KafkaTopic] | Any:
def deserialize_input_topics(
cls, input_topics: list[str] | Any
) -> list[KafkaTopic] | Any:
if isinstance(input_topics, list):
return [KafkaTopic(name=topic_name) for topic_name in input_topics]
return input_topics

@pydantic.field_validator("extra_input_topics", mode="before")
@classmethod
def deserialize_extra_input_topics(
cls, extra_input_topics: Any
cls, extra_input_topics: dict[str, str] | Any
) -> dict[str, list[KafkaTopic]] | Any:
if isinstance(extra_input_topics, dict):
return {
Expand Down
12 changes: 6 additions & 6 deletions kpops/pipeline/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def build_execution_graph(
async def run_parallel_tasks(
coroutines: list[Coroutine[Any, Any, None]],
) -> None:
tasks = []
tasks: list[asyncio.Task[None]] = []
for coro in coroutines:
tasks.append(asyncio.create_task(coro))
await asyncio.gather(*tasks)
Expand All @@ -132,7 +132,7 @@ async def run_graph_tasks(pending_tasks: list[Awaitable[None]]) -> None:

layers_graph: list[list[str]] = list(nx.bfs_layers(graph, root_node))

sorted_tasks = []
sorted_tasks: list[Awaitable[None]] = []
for layer in layers_graph[1:]:
if parallel_tasks := self.__get_parallel_tasks_from(layer, runner):
sorted_tasks.append(run_parallel_tasks(parallel_tasks))
Expand Down Expand Up @@ -220,8 +220,8 @@ class PipelineGenerator:

def parse(
self,
components: list[dict],
environment_components: list[dict],
components: list[dict[str, Any]],
environment_components: list[dict[str, Any]],
) -> Pipeline:
"""Parse pipeline from sequence of component dictionaries.
Expand Down Expand Up @@ -271,7 +271,7 @@ def load_yaml(self, path: Path, environment: str | None) -> Pipeline:

return self.parse(main_content, env_content)

def parse_components(self, components: list[dict]) -> None:
def parse_components(self, components: list[dict[str, Any]]) -> None:
"""Instantiate, enrich and inflate a list of components.
:param components: List of components
Expand All @@ -288,7 +288,7 @@ def parse_components(self, components: list[dict]) -> None:
raise ValueError(msg) from ke
component_class = self.registry[component_type]
self.apply_component(component_class, component_data)
except Exception as ex: # noqa: BLE001
except Exception as ex:
if "name" in component_data:
msg = f"Error enriching {component_data['type']} component {component_data['name']}"
raise ParsingException(msg) from ex
Expand Down
Loading

0 comments on commit b565e5e

Please sign in to comment.