Skip to content

Commit

Permalink
type: make code respect mypy strict mode (#290)
Browse files Browse the repository at this point in the history
Fixes: #258
  • Loading branch information
ssbarnea authored Aug 28, 2024
1 parent 05bd6ed commit 271ad3b
Show file tree
Hide file tree
Showing 35 changed files with 155 additions and 102 deletions.
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ repos:
- types-botocore
- types-mock
- types-requests
- watchdog
- watchdog>=5.0.0
- xxhash

- repo: https://github.com/astral-sh/ruff-pre-commit
Expand Down Expand Up @@ -113,7 +113,7 @@ repos:
- pyyaml
- requests
- types-aiobotocore
- watchdog
- watchdog>=5.0.0
- xxhash
- repo: local
hooks:
Expand Down
5 changes: 4 additions & 1 deletion extensions/eda/plugins/event_filter/dashes_to_underscores.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,12 @@
"""

import multiprocessing as mp
from typing import Any


def main(event: dict, overwrite: bool = True) -> dict: # noqa: FBT001, FBT002
def main(
event: dict[str, Any], overwrite: bool = True
) -> dict[str, Any]: # noqa: FBT001, FBT002
"""Change dashes in keys to underscores."""
logger = mp.get_logger()
logger.info("dashes_to_underscores")
Expand Down
13 changes: 7 additions & 6 deletions extensions/eda/plugins/event_filter/json_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,22 @@
from __future__ import annotations

import fnmatch
from typing import Any, Optional


def _matches_include_keys(include_keys: list, string: str) -> bool:
def _matches_include_keys(include_keys: list[str], string: str) -> bool:
return any(fnmatch.fnmatch(string, pattern) for pattern in include_keys)


def _matches_exclude_keys(exclude_keys: list, string: str) -> bool:
def _matches_exclude_keys(exclude_keys: list[str], string: str) -> bool:
return any(fnmatch.fnmatch(string, pattern) for pattern in exclude_keys)


def main(
event: dict,
exclude_keys: list | None = None,
include_keys: list | None = None,
) -> dict:
event: dict[str, Any],
exclude_keys: Optional[list[str]] = None,
include_keys: Optional[list[str]] = None,
) -> dict[str, Any]:
"""Filter keys out of events."""
if exclude_keys is None:
exclude_keys = []
Expand Down
4 changes: 3 additions & 1 deletion extensions/eda/plugins/event_filter/noop.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
"""noop.py: An event filter that does nothing to the input."""

from typing import Any

def main(event: dict) -> dict:

def main(event: dict[str, Any]) -> dict[str, Any]:
"""Return the input."""
return event
9 changes: 6 additions & 3 deletions extensions/eda/plugins/event_filter/normalize_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,22 +40,25 @@
import logging
import multiprocessing as mp
import re
from typing import Any

normalize_regex = re.compile("[^0-9a-zA-Z_]+")


def main(event: dict, overwrite: bool = True) -> dict: # noqa: FBT001, FBT002
def main(
event: dict[str, Any], overwrite: bool = True
) -> dict[str, Any]: # noqa: FBT001, FBT002
"""Change keys that contain non-alphanumeric characters to underscores."""
logger = mp.get_logger()
logger.info("normalize_keys")
return _normalize_embedded_keys(event, overwrite, logger)


def _normalize_embedded_keys(
obj: dict,
obj: dict[str, Any],
overwrite: bool, # noqa: FBT001
logger: logging.Logger,
) -> dict:
) -> dict[str, Any]:
if isinstance(obj, dict):
new_dict = {}
original_keys = list(obj.keys())
Expand Down
4 changes: 2 additions & 2 deletions extensions/eda/plugins/event_source/alertmanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def clean_host(host: str) -> str:
return host


async def main(queue: asyncio.Queue, args: dict[str, Any]) -> None:
async def main(queue: asyncio.Queue[Any], args: dict[str, Any]) -> None:
"""Receive events via alertmanager webhook."""
app = web.Application()
app["queue"] = queue
Expand Down Expand Up @@ -144,7 +144,7 @@ async def main(queue: asyncio.Queue, args: dict[str, Any]) -> None:
class MockQueue(asyncio.Queue[Any]):
"""A fake queue."""

async def put(self: "MockQueue", event: dict) -> None:
async def put(self: "MockQueue", event: dict[str, Any]) -> None:
"""Print the event."""
print(event) # noqa: T201

Expand Down
25 changes: 18 additions & 7 deletions extensions/eda/plugins/event_source/aws_cloudtrail.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,15 @@
from botocore.client import BaseClient


def _cloudtrail_event_to_dict(event: dict) -> dict:
def _cloudtrail_event_to_dict(event: dict[str, Any]) -> dict[str, Any]:
event["EventTime"] = event["EventTime"].isoformat()
event["CloudTrailEvent"] = json.loads(event["CloudTrailEvent"])
return event


def _get_events(events: list[dict], last_event_ids: list[str]) -> list:
def _get_events(
events: list[dict[str, Any]], last_event_ids: list[str]
) -> tuple[list[dict[str, Any]], Any, list[str]]:
event_time = None
event_ids = []
result = []
Expand All @@ -60,13 +62,22 @@ def _get_events(events: list[dict], last_event_ids: list[str]) -> list:
elif event_time == event["EventTime"]:
event_ids.append(event["EventId"])
result.append(event)
return [result, event_time, event_ids]
return result, event_time, event_ids


async def _get_cloudtrail_events(client: BaseClient, params: dict) -> list[dict]:
async def _get_cloudtrail_events(
client: BaseClient, params: dict[str, Any]
) -> list[dict[str, Any]]:
paginator = client.get_paginator("lookup_events")
results = await paginator.paginate(**params).build_full_result()
return results.get("Events", [])
events = results.get("Events", [])
# type guards:
if not isinstance(events, list):
raise ValueError("Events is not a list")
for event in events:
if not isinstance(event, dict):
raise ValueError("Event is not a dictionary")
return events


ARGS_MAPPING = {
Expand All @@ -75,7 +86,7 @@ async def _get_cloudtrail_events(client: BaseClient, params: dict) -> list[dict]
}


async def main(queue: asyncio.Queue, args: dict[str, Any]) -> None:
async def main(queue: asyncio.Queue[Any], args: dict[str, Any]) -> None:
"""Receive events via AWS CloudTrail."""
delay = int(args.get("delay_seconds", 10))

Expand Down Expand Up @@ -131,7 +142,7 @@ def connection_args(args: dict[str, Any]) -> dict[str, Any]:
class MockQueue(asyncio.Queue[Any]):
"""A fake queue."""

async def put(self: "MockQueue", event: dict) -> None:
async def put(self: "MockQueue", event: dict[str, Any]) -> None:
"""Print the event."""
print(event) # noqa: T201

Expand Down
4 changes: 2 additions & 2 deletions extensions/eda/plugins/event_source/aws_sqs_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@


# pylint: disable=too-many-locals
async def main(queue: asyncio.Queue, args: dict[str, Any]) -> None:
async def main(queue: asyncio.Queue[Any], args: dict[str, Any]) -> None:
"""Receive events via an AWS SQS queue."""
logger = logging.getLogger()

Expand Down Expand Up @@ -117,7 +117,7 @@ def connection_args(args: dict[str, Any]) -> dict[str, Any]:
class MockQueue(asyncio.Queue[Any]):
"""A fake queue."""

async def put(self: "MockQueue", event: dict) -> None:
async def put(self: "MockQueue", event: dict[str, Any]) -> None:
"""Print the event."""
print(event) # noqa: T201

Expand Down
6 changes: 3 additions & 3 deletions extensions/eda/plugins/event_source/azure_service_bus.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@

def receive_events(
loop: asyncio.events.AbstractEventLoop,
queue: asyncio.Queue,
queue: asyncio.Queue[Any],
args: dict[str, Any], # pylint: disable=W0621
) -> None:
"""Receive events from service bus."""
Expand All @@ -53,7 +53,7 @@ def receive_events(


async def main(
queue: asyncio.Queue,
queue: asyncio.Queue[Any],
args: dict[str, Any], # pylint: disable=W0621
) -> None:
"""Receive events from service bus in a loop."""
Expand All @@ -69,7 +69,7 @@ async def main(
class MockQueue(asyncio.Queue[Any]):
"""A fake queue."""

def put_nowait(self: "MockQueue", event: dict) -> None:
def put_nowait(self: "MockQueue", event: dict[str, Any]) -> None:
"""Print the event."""
print(event) # noqa: T201

Expand Down
8 changes: 4 additions & 4 deletions extensions/eda/plugins/event_source/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from watchdog.observers import Observer


def send_facts(queue: Queue, filename: Union[str, bytes]) -> None:
def send_facts(queue: Queue[Any], filename: Union[str, bytes]) -> None:
"""Send facts to the queue."""
if isinstance(filename, bytes):
filename = str(filename, "utf-8")
Expand All @@ -50,7 +50,7 @@ def send_facts(queue: Queue, filename: Union[str, bytes]) -> None:
coroutine = queue.put(item) # noqa: F841


def main(queue: Queue, args: dict) -> None:
def main(queue: Queue[Any], args: dict[str, Any]) -> None:
"""Load facts from YAML files initially and when the file changes."""
files = [pathlib.Path(f).resolve().as_posix() for f in args.get("files", [])]

Expand All @@ -62,7 +62,7 @@ def main(queue: Queue, args: dict) -> None:
_observe_files(queue, files)


def _observe_files(queue: Queue, files: list[str]) -> None:
def _observe_files(queue: Queue[Any], files: list[str]) -> None:
class Handler(RegexMatchingEventHandler):
"""A handler for file events."""

Expand Down Expand Up @@ -104,7 +104,7 @@ def on_moved(self: "Handler", event: FileSystemEvent) -> None:
class MockQueue(Queue[Any]):
"""A fake queue."""

async def put(self: "MockQueue", event: dict) -> None:
async def put(self: "MockQueue", event: dict[str, Any]) -> None:
"""Print the event."""
print(event) # noqa: T201

Expand Down
8 changes: 4 additions & 4 deletions extensions/eda/plugins/event_source/file_watch.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@

def watch(
loop: asyncio.events.AbstractEventLoop,
queue: asyncio.Queue,
args: dict,
queue: asyncio.Queue[Any],
args: dict[str, Any],
) -> None:
"""Watch for changes and put events on the queue."""
root_path = args["path"]
Expand Down Expand Up @@ -96,7 +96,7 @@ def on_moved(self: "Handler", event: FileSystemEvent) -> None:
observer.join()


async def main(queue: asyncio.Queue, args: dict) -> None:
async def main(queue: asyncio.Queue[Any], args: dict[str, Any]) -> None:
"""Watch for changes to a file and put events on the queue."""
loop = asyncio.get_event_loop()

Expand All @@ -110,7 +110,7 @@ async def main(queue: asyncio.Queue, args: dict) -> None:
class MockQueue(asyncio.Queue[Any]):
"""A fake queue."""

def put_nowait(self: "MockQueue", event: dict) -> None:
def put_nowait(self: "MockQueue", event: dict[str, Any]) -> None:
"""Print the event."""
print(event) # noqa: T201

Expand Down
12 changes: 7 additions & 5 deletions extensions/eda/plugins/event_source/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,9 @@ class DelayArgs:
class Generic:
"""Generic source plugin to generate different events."""

def __init__(self: Generic, queue: asyncio.Queue, args: dict[str, Any]) -> None:
def __init__(
self: Generic, queue: asyncio.Queue[Any], args: dict[str, Any]
) -> None:
"""Insert event data into the queue."""
self.queue = queue
field_names = [f.name for f in fields(Args)]
Expand Down Expand Up @@ -164,7 +166,7 @@ async def __call__(self: Generic) -> None:

await asyncio.sleep(self.delay_args.shutdown_after)

async def _post_event(self: Generic, event: dict, index: int) -> None:
async def _post_event(self: Generic, event: dict[str, Any], index: int) -> None:
data = self._create_data(index)

data.update(event)
Expand All @@ -189,7 +191,7 @@ async def _load_payload_from_file(self: Generic) -> None:
def _create_data(
self: Generic,
index: int,
) -> dict:
) -> dict[str, Any]:
data: dict[str, str | int] = {}
if self.my_args.create_index:
data[self.my_args.create_index] = index
Expand All @@ -206,7 +208,7 @@ def _create_data(


async def main( # pylint: disable=R0914
queue: asyncio.Queue,
queue: asyncio.Queue[Any],
args: dict[str, Any],
) -> None:
"""Call the Generic Source Plugin."""
Expand All @@ -218,7 +220,7 @@ async def main( # pylint: disable=R0914
class MockQueue(asyncio.Queue[Any]):
"""A fake queue."""

async def put(self: MockQueue, event: dict) -> None:
async def put(self: MockQueue, event: dict[str, Any]) -> None:
"""Print the event."""
print(event) # noqa: T201

Expand Down
2 changes: 1 addition & 1 deletion extensions/eda/plugins/event_source/journald.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from systemd import journal # type: ignore


async def main(queue: asyncio.Queue, args: dict[str, Any]) -> None: # noqa: D417
async def main(queue: asyncio.Queue[Any], args: dict[str, Any]) -> None: # noqa: D417
"""Read journal entries and add them to the provided queue.
Args:
Expand Down
6 changes: 3 additions & 3 deletions extensions/eda/plugins/event_source/kafka.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@


async def main( # pylint: disable=R0914
queue: asyncio.Queue,
queue: asyncio.Queue[Any],
args: dict[str, Any],
) -> None:
"""Receive events via a kafka topic."""
Expand Down Expand Up @@ -116,7 +116,7 @@ async def main( # pylint: disable=R0914


async def receive_msg(
queue: asyncio.Queue,
queue: asyncio.Queue[Any],
kafka_consumer: AIOKafkaConsumer,
encoding: str,
) -> None:
Expand Down Expand Up @@ -161,7 +161,7 @@ async def receive_msg(
class MockQueue(asyncio.Queue[Any]):
"""A fake queue."""

async def put(self: "MockQueue", event: dict) -> None:
async def put(self: "MockQueue", event: dict[str, Any]) -> None:
"""Print the event."""
print(event) # noqa: T201

Expand Down
Loading

0 comments on commit 271ad3b

Please sign in to comment.