diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 2f77eb0fd..bd4ea98e0 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -78,7 +78,7 @@ jobs:
           PYTEST_ADDOPTS: "--color=yes"
         run: |
           pytest -vv tests/unit
-
+
       - name: Perform acceptance tests
         env:
           PYTEST_ADDOPTS: "--color=yes"
diff --git a/CHANGELOG.md b/CHANGELOG.md
index fff4cd867..edd78e74e 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -8,6 +8,13 @@
 * retrieve oauth token automatically from different oauth endpoints
 * retrieve configruation with mTLS authentication
 
+
+* reimplementation of the HTTP Input Connector with the following features:
+  * wildcard-based HTTP request routing
+  * regex-based HTTP request routing
+  * improvements in the thread-based runtime
+  * configurable metadata collection
+
 ### Improvements
 
 * remove `versioneer` dependency in favor of `setuptools-scm`
diff --git a/logprep/connector/http/input.py b/logprep/connector/http/input.py
index 129f64739..f9988a7df 100644
--- a/logprep/connector/http/input.py
+++ b/logprep/connector/http/input.py
@@ -13,6 +13,9 @@
     input:
       myhttpinput:
         type: http_input
+        message_backlog_size: 15000
+        collect_meta: False
+        metafield_name: "@metadata"
         uvicorn_config:
           host: 0.0.0.0
           port: 9000
@@ -22,52 +25,123 @@
           /thirdendpoint: jsonl
 """
 
-import contextlib
 import inspect
 import queue
 import threading
-from abc import ABC, abstractmethod
-from typing import Mapping, Tuple, Union
-
+from abc import ABC
+from logging import Logger
+import logging
+import re
+from typing import Mapping, Tuple, Union, Callable
+from attrs import define, field, validators
 import msgspec
 import uvicorn
-from attrs import define, field, validators
-from fastapi import FastAPI, Request
-from pydantic import BaseModel  # pylint: disable=no-name-in-module
-
-from logprep.abc.input import Input
+import falcon.asgi
+from falcon import HTTPTooManyRequests, HTTPMethodNotAllowed  # pylint: disable=no-name-in-module
+from logprep.abc.input import FatalInputError, Input
+from logprep.util import defaults
 
 uvicorn_parameter_keys = inspect.signature(uvicorn.Config).parameters.keys()
 UVICORN_CONFIG_KEYS = [
     parameter for parameter in uvicorn_parameter_keys if parameter not in ["app", "log_level"]
 ]
 
+# config keys that are checked for config changes
+HTTP_INPUT_CONFIG_KEYS = [
+    "preprocessing",
+    "uvicorn_config",
+    "endpoints",
+    "collect_meta",
+    "metafield_name",
+    "message_backlog_size",
+]
+
+
+def decorator_request_exceptions(func: Callable):
+    """Decorator to wrap http calls and raise exceptions"""
+
+    async def func_wrapper(*args, **kwargs):
+        try:
+            if args[1].method == "POST":
+                result = await func(*args, **kwargs)
+            else:
+                raise HTTPMethodNotAllowed(["POST"])
+        except queue.Full as exc:
+            raise HTTPTooManyRequests(description="Logprep Message Queue is full.") from exc
+        return result
+
+    return func_wrapper
+
+
+def decorator_add_metadata(func: Callable):
+    """Decorator to add metadata to the resulting http event.
+    Uses the attribute collect_meta of the endpoint class to decide over metadata collection.
+    Uses the attribute metafield_name to define the key name for the metadata.
+    """
+
+    async def func_wrapper(*args, **kwargs):
+        req = args[1]
+        endpoint = args[0]
+        if endpoint.collect_meta:
+            metadata = {
+                "url": req.url,
+                "remote_addr": req.remote_addr,
+                "user_agent": req.user_agent,
+            }
+            kwargs["metadata"] = {endpoint.metafield_name: metadata}
+        else:
+            kwargs["metadata"] = {}
+        result = await func(*args, **kwargs)
+        return result
+
+    return func_wrapper
+
+
+def route_compile_helper(input_re_str: str):
+    """falcon add_sink handles prefix routes as independent URI elements,
+    therefore we need regex position anchors to ensure beginning and
+    end of a given route and replace * with .* for user-friendliness
+    """
+    input_re_str = input_re_str.replace("*", ".*")
+    input_re_str = "^" + input_re_str + "$"
+    return re.compile(input_re_str)
+
 
 class HttpEndpoint(ABC):
-    """interface for http endpoints"""
+    """Interface for http endpoints.
+    Additional functionality is added to child classes via removable decorators.
 
+    Parameters
+    ----------
     messages: queue.Queue
-
-    def __init__(self, messages: queue.Queue) -> None:
+        Input events are put here
+    collect_meta: bool
+        Collects metadata if set to True (default)
+    metafield_name: str
+        Defines the key name for the metadata
+    """
+
+    def __init__(self, messages: queue.Queue, collect_meta: bool, metafield_name: str) -> None:
         self.messages = messages
-
-    @abstractmethod
-    async def endpoint(self, **kwargs):
-        """callback method for route"""
-        ...  # pragma: no cover
+        self.collect_meta = collect_meta
+        self.metafield_name = metafield_name
 
 
 class JSONHttpEndpoint(HttpEndpoint):
     """:code:`json` endpoint to get json from request"""
 
-    class Event(BaseModel):
-        """model for event"""
-
-        message: str
+    _decoder = msgspec.json.Decoder()
 
-    async def endpoint(self, event: Event):  # pylint: disable=arguments-differ
+    @decorator_request_exceptions
+    @decorator_add_metadata
+    async def __call__(self, req, resp, **kwargs):  # pylint: disable=arguments-differ
         """json endpoint method"""
-        self.messages.put(dict(event))
+        data = await req.stream.read()
+        data = data.decode("utf8")
+        metadata = kwargs.get("metadata", {})
+        if data:
+            event = self._decoder.decode(data)
+            self.messages.put({**event, **metadata}, block=False)
 
 
 class JSONLHttpEndpoint(HttpEndpoint):
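Since falcon's `add_sink` treats its prefix as a bare prefix match, `route_compile_helper` anchors the pattern on both sides; only `*` is rewritten, so the rest of the route keeps full regex semantics. A quick standalone illustration of the resulting matching behavior (plain `re`, mirroring the helper above):

```python
import re


def route_compile_helper(input_re_str: str):
    # same translation as above: "*" -> ".*", then anchor start and end
    return re.compile("^" + input_re_str.replace("*", ".*") + "$")


assert route_compile_helper("/*json").match("/api/wildcard_path/json")
assert not route_compile_helper("/*json").match("/api/wildcard_path/json/another_path")
assert route_compile_helper("/(first|second)/jsonl").match("/first/jsonl")
assert not route_compile_helper("/(first|second)/jsonl").match("/api/first/jsonl")
```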
@@ -75,51 +149,143 @@ class JSONLHttpEndpoint(HttpEndpoint):
 
     _decoder = msgspec.json.Decoder()
 
-    async def endpoint(self, request: Request):  # pylint: disable=arguments-differ
+    @decorator_request_exceptions
+    @decorator_add_metadata
+    async def __call__(self, req, resp, **kwargs):  # pylint: disable=arguments-differ
         """jsonl endpoint method"""
-        data = await request.body()
+        data = await req.stream.read()
         data = data.decode("utf8")
-        for line in data.splitlines():
-            line = line.strip()
-            if line:
-                event = self._decoder.decode(line)
-                self.messages.put(event)
+        metadata = kwargs.get("metadata", {})
+        stripped_lines = map(str.strip, data.splitlines())
+        events = (self._decoder.decode(line) for line in stripped_lines if line)
+        for event in events:
+            self.messages.put({**event, **metadata}, block=False)
 
 
 class PlaintextHttpEndpoint(HttpEndpoint):
-    """:code:`plaintext` endpoint to get the body from request and put it in :code:`message` field"""
+    """:code:`plaintext` endpoint to get the body from request
+    and put it in :code:`message` field"""
 
-    async def endpoint(self, request: Request):  # pylint: disable=arguments-differ
+    @decorator_request_exceptions
+    @decorator_add_metadata
+    async def __call__(self, req, resp, **kwargs):  # pylint: disable=arguments-differ
         """plaintext endpoint method"""
-        data = await request.body()
-        self.messages.put({"message": data.decode("utf8")})
-
-
-class Server(uvicorn.Server):
-    """the uvicorn server"""
+        data = await req.stream.read()
+        metadata = kwargs.get("metadata", {})
+        event = {"message": data.decode("utf8")}
+        self.messages.put({**event, **metadata}, block=False)
+
+
+class ThreadingHTTPServer:  # pylint: disable=too-many-instance-attributes
+    """Singleton wrapper class around the uvicorn thread that controls the
+    lifecycle of the uvicorn HTTP server. At runtime this singleton object
+    is stateful and therefore we need to check for some attributes during
+    __init__ when multiple consecutive reconfigurations are happening.
+
+    Parameters
+    ----------
+    connector_config: Input.Config
+        Holds the full connector config for config change checks
+    endpoints_config: dict
+        Endpoint paths as keys and initiated endpoint objects as
+        values
+    log_level: str
+        Log level to be set for the uvicorn server
+    """
+
+    _instance = None
+    _lock = threading.Lock()
+
+    def __new__(cls, *args, **kwargs):
+        if cls._instance is None:
+            with cls._lock:
+                if not cls._instance:
+                    cls._instance = super(ThreadingHTTPServer, cls).__new__(cls)
+        return cls._instance
+
+    def __init__(
+        self,
+        connector_config: Input.Config,
+        endpoints_config: dict,
+        log_level: str,
+    ) -> None:
+        """Creates object attributes with necessary configuration.
+        As this class creates a singleton object, the existing server
+        will be stopped and restarted on consecutive creations"""
+        super().__init__()
+
+        self.connector_config = connector_config
+        self.endpoints_config = endpoints_config
+        self.log_level = log_level
+
+        if hasattr(self, "thread"):
+            if self.thread.is_alive():  # pylint: disable=access-member-before-definition
+                self._stop()
+        self._start()
+
+    def _start(self):
+        """Collect all configs, initiate application server and webserver
+        and run thread with uvicorn+falcon http server and wait
+        until it is up (started)"""
+        self.uvicorn_config = self.connector_config.uvicorn_config
+        self._init_web_application_server(self.endpoints_config)
+        log_config = self._init_log_config()
+        self.compiled_config = uvicorn.Config(
+            **self.uvicorn_config,
+            app=self.app,
+            log_level=self.log_level,
+            log_config=log_config,
+        )
+        self.server = uvicorn.Server(self.compiled_config)
+        self._override_runtime_logging()
+        self.thread = threading.Thread(daemon=False, target=self.server.run)
+        self.thread.start()
+        while not self.server.started:
+            continue
+
+    def _stop(self):
+        """Stop thread with uvicorn+falcon http server, wait for uvicorn
+        to exit gracefully and join the thread"""
+        if self.thread.is_alive():
+            self.server.should_exit = True
+            while self.thread.is_alive():
+                continue
+        self.thread.join()
+
+    def _init_log_config(self) -> dict:
+        """Use the same log formatter for uvicorn as for Logprep"""
+        log_config = uvicorn.config.LOGGING_CONFIG
+        log_config["formatters"]["default"]["fmt"] = defaults.DEFAULT_LOG_FORMAT
+        log_config["formatters"]["access"]["fmt"] = defaults.DEFAULT_LOG_FORMAT
+        log_config["handlers"]["default"]["stream"] = "ext://sys.stdout"
+        return log_config
+
+    def _override_runtime_logging(self):
+        """Uvicorn doesn't provide an API to change name and handler beforehand,
+        so this needs to be done during runtime"""
+        http_server_name = logging.getLogger("Logprep").name + " HTTPServer"
+        for logger_name in ["uvicorn", "uvicorn.access"]:
+            logging.getLogger(logger_name).removeHandler(logging.getLogger(logger_name).handlers[0])
+            logging.getLogger(logger_name).addHandler(
+                logging.getLogger("Logprep").parent.handlers[0]
+            )
+        logging.getLogger("uvicorn.access").name = http_server_name
+        logging.getLogger("uvicorn.error").name = http_server_name
 
-    def install_signal_handlers(self):
-        pass
+    def _init_web_application_server(self, endpoints_config: dict) -> None:
+        "Init falcon application server and set endpoint routes"
+        self.app = falcon.asgi.App()  # pylint: disable=attribute-defined-outside-init
+        for endpoint_path, endpoint in endpoints_config.items():
+            self.app.add_sink(endpoint, prefix=route_compile_helper(endpoint_path))
 
-    @contextlib.contextmanager
-    def run_in_thread(self):
-        """Context manager to run the server in a separate thread"""
-        thread = threading.Thread(target=self.run)
-        thread.start()
-        try:
-            while not self.started:
-                pass
-            yield
-        finally:
-            self.should_exit = True
-            thread.join()
+    def shut_down(self):
+        """Shutdown method to trigger http server shutdown externally"""
+        self._stop()
 
 
 class HttpConnector(Input):
     """Connector to accept log messages as http post requests"""
 
-    messages: queue.Queue = queue.Queue()
-
     _endpoint_registry: Mapping[str, HttpEndpoint] = {
         "json": JSONHttpEndpoint,
         "plaintext": PlaintextHttpEndpoint,
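The `_start`/`_stop` pair leans on two public `uvicorn.Server` attributes: `started` flips to true once the event loop is serving, and setting `should_exit` asks the server to wind down cooperatively. Reduced to its core and detached from the singleton and logging concerns, the pattern is roughly this (port and app are placeholders):

```python
import threading

import falcon.asgi
import uvicorn

app = falcon.asgi.App()  # placeholder ASGI app without routes
server = uvicorn.Server(uvicorn.Config(app=app, host="127.0.0.1", port=9999))

thread = threading.Thread(target=server.run)
thread.start()
while not server.started:  # busy-wait until uvicorn reports readiness
    continue

# ... server is now accepting requests ...

server.should_exit = True  # cooperative shutdown flag polled by uvicorn
thread.join()
```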
@@ -135,10 +301,12 @@ class Config(Input.Config):
                 validators.instance_of(dict),
                 validators.deep_mapping(
                     key_validator=validators.in_(UVICORN_CONFIG_KEYS),
+                    # three-argument lambda necessary because of the attrs validator signature
                     value_validator=lambda x, y, z: True,
                 ),
             ]
         )
+
         """Configure uvicorn server. For possible settings see
         `uvicorn settings page <https://www.uvicorn.org/settings/>`_.
         """
@@ -162,30 +330,89 @@ class Config(Input.Config):
         :noindex:
     """
 
-    app: FastAPI
-    server: uvicorn.Server
+    message_backlog_size: int = field(
+        validator=validators.instance_of((int, float)), default=15000
+    )
+    """Configures maximum size of input message queue for this connector. When limit is reached
+    the server will answer with 429 Too Many Requests. For reasonable throughput this shouldn't
+    be smaller than the default value of 15,000 messages.
+    """
+
+    collect_meta: bool = field(validator=validators.instance_of(bool), default=True)
+    """Defines if metadata should be collected
+    - :code:`True`: Collect metadata
+    - :code:`False`: Won't collect metadata
+    """
 
-    __slots__ = ["app", "server"]
+    metafield_name: str = field(validator=validators.instance_of(str), default="@metadata")
+    """Defines the name of the key for the collected metadata fields"""
+
+    __slots__ = []
+
+    def __init__(self, name: str, configuration: "HttpConnector.Config", logger: Logger) -> None:
+        super().__init__(name, configuration, logger)
+        internal_uvicorn_config = {
+            "lifespan": "off",
+            "loop": "asyncio",
+            "timeout_graceful_shutdown": 0,
+        }
+        self._config.uvicorn_config.update(internal_uvicorn_config)
+        self.logger = logger
+        self.port = self._config.uvicorn_config["port"]
+        self.host = self._config.uvicorn_config["host"]
+        self.target = "http://" + self.host + ":" + str(self.port)
+        self.messages = queue.Queue(
+            self._config.message_backlog_size
+        )  # pylint: disable=attribute-defined-outside-init
 
     def setup(self):
+        """setup starts the actual functionality of this connector.
+        By checking against pipeline_index we ensure this connector
+        runs only once across multiple processes.
+        """
+
         super().setup()
-        self.app = FastAPI()
+        if not hasattr(self, "pipeline_index"):
+            raise FatalInputError(
+                self, "Necessary instance attribute `pipeline_index` could not be found."
+            )
+        # Start the HTTP input only in the first process
+        if self.pipeline_index != 1:
+            return
+
+        endpoints_config = {}
+        collect_meta = self._config.collect_meta
+        metafield_name = self._config.metafield_name
+        # preparing dict with endpoint paths and initialized endpoint objects
         for endpoint_path, endpoint_name in self._config.endpoints.items():
             endpoint_class = self._endpoint_registry.get(endpoint_name)
-            endpoint = endpoint_class(self.messages)
-            self.app.add_api_route(
-                path=f"{endpoint_path}", endpoint=endpoint.endpoint, methods=["POST"]
+            endpoints_config[endpoint_path] = endpoint_class(
+                self.messages, collect_meta, metafield_name
             )
-        uvicorn_config = uvicorn.Config(
-            **self._config.uvicorn_config, app=self.app, log_level=self._logger.level
+
+        self.http_server = ThreadingHTTPServer(  # pylint: disable=attribute-defined-outside-init
+            connector_config=self._config,
+            endpoints_config=endpoints_config,
+            log_level=self._logger.level,
         )
-        self.server = Server(uvicorn_config)
 
     def _get_event(self, timeout: float) -> Tuple:
-        """returns the first message from the queue"""
+        """Returns the first message from the queue"""
         try:
             message = self.messages.get(timeout=timeout)
             raw_message = str(message).encode("utf8")
            return message, raw_message
         except queue.Empty:
             return None, None
+
+    def get_app_instance(self):
+        """Return the app instance from the webserver thread"""
+        return self.http_server.app
+
+    def get_server_instance(self):
+        """Return the server instance from the webserver thread"""
+        return self.http_server.server
+
+    def shut_down(self):
+        """Raises the uvicorn HTTP server's internal stop flag and waits to join"""
+        self.http_server.shut_down()
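Taken together, a connector configuration exercising the new options could look like this (paths and sizes are only illustrative):

```yaml
input:
  myhttpinput:
    type: http_input
    message_backlog_size: 15000   # queue bound; overflowing requests get a 429
    collect_meta: True            # attach url, remote_addr and user_agent
    metafield_name: "@metadata"   # key under which the metadata is stored
    uvicorn_config:
      host: 0.0.0.0
      port: 9000
    endpoints:
      /firstendpoint: json              # literal route
      /second*: plaintext               # wildcard route
      /(third|fourth)/endpoint: jsonl   # regex route
```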
diff --git a/logprep/framework/pipeline.py b/logprep/framework/pipeline.py
index 03259aa2d..98c1b8f7a 100644
--- a/logprep/framework/pipeline.py
+++ b/logprep/framework/pipeline.py
@@ -100,9 +100,6 @@ class Metrics(Component.Metrics):
     _lock: Lock
     """ the lock for the pipeline process """
 
-    _used_server_ports: dict
-    """ a shard dict for signaling used ports between pipeline processes """
-
     pipeline_index: int
     """ the index of this pipeline """
 
@@ -168,7 +165,6 @@ def __init__(
         pipeline_index: int = None,
         log_queue: multiprocessing.Queue = None,
         lock: Lock = None,
-        used_server_ports: dict = None,
     ) -> None:
         self._log_queue = log_queue
         self.logger = logging.getLogger(f"Logprep Pipeline {pipeline_index}")
@@ -178,7 +174,6 @@ def __init__(
 
         self._continue_iterating = Value(c_bool)
         self._lock = lock
-        self._used_server_ports = used_server_ports
         self.pipeline_index = pipeline_index
         self._encoder = msgspec.msgpack.Encoder()
         self._decoder = msgspec.msgpack.Decoder()
@@ -201,10 +196,6 @@ def _setup(self):
         for _, output in self._output.items():
             output.setup()
-        if hasattr(self._input, "server"):
-            while self._input.server.config.port in self._used_server_ports:
-                self._input.server.config.port += 1
-            self._used_server_ports.update({self._input.server.config.port: self._process_name})
         self.logger.debug("Finished creating connectors")
         self.logger.info("Start building pipeline")
         _ = self._pipeline
 
@@ -227,13 +218,8 @@ def run(self) -> None:
             warnings.simplefilter("default")
         self._setup()
         self.logger.debug("Start iterating")
-        if hasattr(self._input, "server"):
-            with self._input.server.run_in_thread():
-                while self._continue_iterating.value:
-                    self.process_pipeline()
-        else:
-            while self._continue_iterating.value:
-                self.process_pipeline()
+        while self._continue_iterating.value:
+            self.process_pipeline()
         self._shut_down()
 
     @_handle_pipeline_error
@@ -312,8 +298,6 @@ def _store_extra_data(self, extra_data: List[tuple]) -> None:
 
     def _shut_down(self) -> None:
         self._input.shut_down()
-        if hasattr(self._input, "server"):
-            self._used_server_ports.pop(self._input.server.config.port)
         self._drain_input_queues()
         for _, output in self._output.items():
             output.shut_down()
diff --git a/logprep/framework/pipeline_manager.py b/logprep/framework/pipeline_manager.py
index 4132288bc..1c90d7425 100644
--- a/logprep/framework/pipeline_manager.py
+++ b/logprep/framework/pipeline_manager.py
@@ -13,6 +13,7 @@
 from logprep.metrics.exporter import PrometheusExporter
 from logprep.metrics.metrics import CounterMetric
 from logprep.util.configuration import Configuration
+from logprep.util.logging import SingleThreadQueueListener
 
 
 class PipelineManager:
@@ -50,21 +51,18 @@ def __init__(self, configuration: Configuration):
         self.metrics = self.Metrics(labels={"component": "manager"})
         self._logger = logging.getLogger("Logprep PipelineManager")
         self.log_queue = multiprocessing.Queue(-1)
-        self._queue_listener = logging.handlers.QueueListener(self.log_queue)
+        self._queue_listener = SingleThreadQueueListener(self.log_queue)
         self._queue_listener.start()
 
         self._pipelines: list[multiprocessing.Process] = []
         self._configuration = configuration
 
         self._lock = multiprocessing.Lock()
-        self._used_server_ports = None
         prometheus_config = self._configuration.metrics
         if prometheus_config.enabled:
             self.prometheus_exporter = PrometheusExporter(prometheus_config)
         else:
             self.prometheus_exporter = None
-        manager = multiprocessing.Manager()
-        self._used_server_ports = manager.dict()
 
     def get_count(self) -> int:
         """Get the pipeline count.
@@ -145,7 +143,6 @@ def _create_pipeline(self, index) -> multiprocessing.Process:
             config=self._configuration,
             log_queue=self.log_queue,
             lock=self._lock,
-            used_server_ports=self._used_server_ports,
         )
         self._logger.info("Created new pipeline")
         process = multiprocessing.Process(target=pipeline.run, daemon=True)
diff --git a/logprep/run_logprep.py b/logprep/run_logprep.py
index 54bee710b..7f61ad472 100644
--- a/logprep/run_logprep.py
+++ b/logprep/run_logprep.py
@@ -19,6 +19,7 @@
 from logprep.util.configuration import Configuration, InvalidConfigurationError
 from logprep.util.helper import get_versions_string, print_fcolor
 from logprep.util.rule_dry_runner import DryRunner
+from logprep.util import defaults
 
 warnings.simplefilter("always", DeprecationWarning)
 logging.captureWarnings(True)
@@ -38,9 +39,7 @@ def _print_version(config: "Configuration") -> None:
 
 def _get_logger(logger_config: dict) -> logging.Logger:
     log_level = logger_config.get("level", "INFO")
-    logging.basicConfig(
-        level=log_level, format="%(asctime)-15s %(name)-5s %(levelname)-8s: %(message)s"
-    )
+    logging.basicConfig(level=log_level, format=defaults.DEFAULT_LOG_FORMAT)
     logger = logging.getLogger("Logprep")
     logger.setLevel(log_level)
     return logger
diff --git a/logprep/util/defaults.py b/logprep/util/defaults.py
index 829d25d24..6dd2558a5 100644
--- a/logprep/util/defaults.py
+++ b/logprep/util/defaults.py
@@ -1,4 +1,5 @@
 """Default values for logprep."""
 
 DEFAULT_CONFIG_LOCATION = "file:///etc/logprep/pipeline.yml"
+DEFAULT_LOG_FORMAT = "%(asctime)-15s %(name)-5s %(levelname)-8s: %(message)s"
 ENV_NAME_LOGPREP_CREDENTIALS_FILE = "LOGPREP_CREDENTIALS_FILE"
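With the format string centralized, every component renders records identically — the uvicorn override above and `_get_logger` now both read from one place. For instance (the printed line is illustrative):

```python
import logging

from logprep.util.defaults import DEFAULT_LOG_FORMAT

logging.basicConfig(level="INFO", format=DEFAULT_LOG_FORMAT)
logging.getLogger("Logprep").info("example record")
# prints e.g.: 2024-03-01 12:00:00,000 Logprep INFO    : example record
```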
diff --git a/logprep/util/logging.py b/logprep/util/logging.py
new file mode 100644
index 000000000..d6a87b285
--- /dev/null
+++ b/logprep/util/logging.py
@@ -0,0 +1,84 @@
+"""Different helper functions and classes to support logging"""
+
+# pragma: no cover
+
+import time
+import logging
+import logging.handlers
+import threading
+from queue import Empty
+
+
+# gratefully using the implementation
+# from https://medium.com/@augustomen/using-logging-asynchronously-c8e854de874c
+class SingleThreadQueueListener(logging.handlers.QueueListener):
+    """A subclass of QueueListener that uses a single thread for all queues.
+
+    See https://github.com/python/cpython/blob/main/Lib/logging/handlers.py
+    for the implementation of QueueListener.
+    """
+
+    monitor_thread = None
+    listeners = []
+    sleep_time = 0.1
+
+    @classmethod
+    def _start(cls):
+        """Start a single thread, only if none is started."""
+        if cls.monitor_thread is None or not cls.monitor_thread.is_alive():
+            cls.monitor_thread = t = threading.Thread(
+                target=cls._monitor_all, name="logging_monitor"
+            )
+            t.daemon = True
+            t.start()
+        return cls.monitor_thread
+
+    @classmethod
+    def _join(cls):
+        """Waits for the thread to stop.
+        Only call this after stopping all listeners.
+        """
+        if cls.monitor_thread is not None and cls.monitor_thread.is_alive():
+            cls.monitor_thread.join()
+        cls.monitor_thread = None
+
+    @classmethod
+    def _monitor_all(cls):
+        """A monitor function for all the registered listeners.
+        Does not block when obtaining messages from the queue to give all
+        listeners a chance to get an item from the queue. That's why we
+        must sleep at every cycle.
+
+        If a sentinel is sent, the listener is unregistered.
+        When all listeners are unregistered, the thread stops.
+        """
+        noop = None
+        while cls.listeners:
+            time.sleep(cls.sleep_time)  # does not block all threads
+            for listener in cls.listeners:
+                try:
+                    # Gets all messages in this queue without blocking
+                    task_done = getattr(listener.queue, "task_done", noop)
+                    while True:
+                        record = listener.dequeue(False)
+                        if record is listener._sentinel:  # pylint: disable=protected-access
+                            cls.listeners.remove(listener)
+                        else:
+                            listener.handle(record)
+                        task_done()
+                except Empty:
+                    continue
+                except TypeError:
+                    continue
+
+    def start(self):
+        """Override default implementation.
+        Register this listener and call the class' _start() instead.
+        """
+        SingleThreadQueueListener.listeners.append(self)
+        # Start if not already
+        SingleThreadQueueListener._start()
+
+    def stop(self):
+        """Enqueues the sentinel but does not stop the thread."""
+        self.enqueue_sentinel()
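The listener pairs with the standard `QueueHandler`: producers put records on a shared queue and one monitor thread forwards them to the real handlers, no matter how many listeners are registered. A minimal single-process wiring sketch:

```python
import logging
import logging.handlers
import multiprocessing
import time

from logprep.util.logging import SingleThreadQueueListener

log_queue = multiprocessing.Queue(-1)
listener = SingleThreadQueueListener(log_queue, logging.StreamHandler())
listener.start()  # registers this listener; the monitor thread is spawned once

worker_logger = logging.getLogger("worker")
worker_logger.addHandler(logging.handlers.QueueHandler(log_queue))
worker_logger.error("routed through the queue")  # handled by the monitor thread

time.sleep(0.3)  # give the 0.1s polling cycle a chance to drain the queue
listener.stop()  # enqueues the sentinel; the shared thread unregisters the listener
```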
+ """ + noop = None + while cls.listeners: + time.sleep(cls.sleep_time) # does not block all threads + for listener in cls.listeners: + try: + # Gets all messages in this queue without blocking + task_done = getattr(listener.queue, "task_done", noop) + while True: + record = listener.dequeue(False) + if record is listener._sentinel: # pylint: disable=protected-access + cls.listeners.remove(listener) + else: + listener.handle(record) + task_done() + except Empty: + continue + except TypeError: + continue + + def start(self): + """Override default implementation. + Register this listener and call class' _start() instead. + """ + SingleThreadQueueListener.listeners.append(self) + # Start if not already + SingleThreadQueueListener._start() + + def stop(self): + """Enqueues the sentinel but does not stop the thread.""" + self.enqueue_sentinel() diff --git a/pyproject.toml b/pyproject.toml index 38d49b246..77292b1c0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -87,6 +87,7 @@ dependencies = [ "click", "pandas", "tabulate", + "falcon==3.1.3", ] diff --git a/quickstart/exampledata/config/http_pipeline.yml b/quickstart/exampledata/config/http_pipeline.yml index 5be5501c4..b65afac40 100644 --- a/quickstart/exampledata/config/http_pipeline.yml +++ b/quickstart/exampledata/config/http_pipeline.yml @@ -1,16 +1,22 @@ version: 1 +process_count: 2 metrics: enabled: true - port: 8000 - + port: 8003 input: httpinput: type: http_input + message_backlog_size: 1500000 + collect_meta: True + metafield_name: "@metadata" uvicorn_config: host: 0.0.0.0 port: 9000 endpoints: + /json: json + /lab/123/(first|second|third)/js.*: jsonl + /lab/123/(ABC|DEF)/pl.*: plaintext /lab/123/ABC/auditlog: jsonl output: kafka: diff --git a/quickstart/exampledata/config/prometheus/prometheus.yml b/quickstart/exampledata/config/prometheus/prometheus.yml index 8fdd847e4..3913cbba4 100644 --- a/quickstart/exampledata/config/prometheus/prometheus.yml +++ b/quickstart/exampledata/config/prometheus/prometheus.yml @@ -29,7 +29,7 @@ scrape_configs: - targets: ["localhost:9090"] - job_name: "logprep" static_configs: - - targets: ["localhost:8000", "localhost:8001"] + - targets: ["localhost:8001", "localhost:8003"] - job_name: "kafka" metrics_path: "/metrics" static_configs: diff --git a/tests/acceptance/test_config_refresh.py b/tests/acceptance/test_config_refresh.py index 5d5969886..2ab257e4a 100644 --- a/tests/acceptance/test_config_refresh.py +++ b/tests/acceptance/test_config_refresh.py @@ -24,10 +24,10 @@ def test_two_times_config_refresh_after_5_seconds(tmp_path): wait_for_output(proc, "Config refresh interval is set to: 5 seconds", test_timeout=5) config.version = "2" config_path.write_text(config.as_json()) - wait_for_output(proc, "Successfully reloaded configuration", test_timeout=7) + wait_for_output(proc, "Successfully reloaded configuration", test_timeout=12) config.version = "other version" config_path.write_text(config.as_json()) - wait_for_output(proc, "Successfully reloaded configuration", test_timeout=6) + wait_for_output(proc, "Successfully reloaded configuration", test_timeout=12) def test_no_config_refresh_after_5_seconds(tmp_path): diff --git a/tests/acceptance/test_full_configuration.py b/tests/acceptance/test_full_configuration.py index 8466c79d4..d7d892b8e 100644 --- a/tests/acceptance/test_full_configuration.py +++ b/tests/acceptance/test_full_configuration.py @@ -125,7 +125,7 @@ def test_logprep_exposes_prometheus_metrics(tmp_path): config = get_default_logprep_config(pipeline, with_hmac=False) config.version 
= "my_custom_version" config.config_refresh_interval = 300 - config.metrics = {"enabled": True, "port": 8000} + config.metrics = {"enabled": True, "port": 8003} config.input = { "fileinput": { "type": "file_input", @@ -164,7 +164,7 @@ def test_logprep_exposes_prometheus_metrics(tmp_path): assert "exception" not in output.lower(), "error message" if "Finished building pipeline" in output: break - response = requests.get("http://127.0.0.1:8000", timeout=5) + response = requests.get("http://127.0.0.1:8003", timeout=5) response.raise_for_status() metrics = response.text expected_metrics = [ diff --git a/tests/acceptance/test_http_input.py b/tests/acceptance/test_http_input.py index a2706452f..9cd5f42fb 100644 --- a/tests/acceptance/test_http_input.py +++ b/tests/acceptance/test_http_input.py @@ -1,6 +1,5 @@ # pylint: disable=missing-docstring # pylint: disable=line-too-long -import os import time from logging import DEBUG, basicConfig, getLogger from pathlib import Path @@ -44,11 +43,12 @@ def config_fixture(): "endpoints": {"/json": "json", "/jsonl": "jsonl", "/plaintext": "plaintext"}, } } + return config -def setup_function(): - stop_logprep() +# def setup_function(): +# start_logprep() def teardown_function(): @@ -67,3 +67,18 @@ def test_http_input_accepts_message_for_single_pipeline(tmp_path: Path, config: requests.post("https://127.0.0.1:9000/plaintext", data="my message", verify=False, timeout=5) time.sleep(0.5) assert "my message" in output_path.read_text() + + +@pytest.mark.filterwarnings("ignore:Unverified HTTPS request is being made to host '127.0.0.1'") +def test_http_input_accepts_message_for_multiple_pipelines(tmp_path: Path, config: Configuration): + config.process_count = 4 + output_path = tmp_path / "output.jsonl" + config.output = {"testoutput": {"type": "jsonl_output", "output_file": str(output_path)}} + config_path = tmp_path / "generated_config.yml" + config_path.write_text(config.as_yaml()) + proc = start_logprep(config_path) + wait_for_output(proc, "Uvicorn running on https://127.0.0.1:9000", test_timeout=15) + + requests.post("https://127.0.0.1:9000/plaintext", data="my message", verify=False, timeout=5) + time.sleep(0.5) + assert "my message" in output_path.read_text() diff --git a/tests/unit/connector/test_http_input.py b/tests/unit/connector/test_http_input.py index 941655d92..610d46210 100644 --- a/tests/unit/connector/test_http_input.py +++ b/tests/unit/connector/test_http_input.py @@ -2,81 +2,172 @@ # pylint: disable=protected-access # pylint: disable=attribute-defined-outside-init from copy import deepcopy -import json - +from concurrent.futures import ThreadPoolExecutor import requests import uvicorn -from fastapi import FastAPI -from fastapi.testclient import TestClient +import falcon from logprep.connector.http.input import HttpConnector from logprep.factory import Factory +from logprep.abc.input import FatalInputError from tests.unit.connector.base import BaseInputTestCase class TestHttpConnector(BaseInputTestCase): + def setup_method(self): super().setup_method() + self.object.pipeline_index = 1 self.object.setup() # we have to empty the queue for testing while not self.object.messages.empty(): self.object.messages.get(timeout=0.001) - self.client = TestClient(self.object.app) + self.target = self.object.target CONFIG: dict = { "type": "http_input", + "message_backlog_size": 100, + "collect_meta": False, + "metafield_name": "@metadata", "uvicorn_config": {"port": 9000, "host": "127.0.0.1"}, - "endpoints": {"/json": "json", "/jsonl": "jsonl", "/plaintext": 
"plaintext"}, + "endpoints": { + "/json": "json", + "/*json": "json", + "/jsonl": "jsonl", + "/(first|second)/jsonl": "jsonl", + "/(third|fourth)/jsonl*": "jsonl", + "/plaintext": "plaintext", + }, } + def teardown_method(self): + self.object.shut_down() + def test_create_connector(self): assert isinstance(self.object, HttpConnector) - def test_has_fastapi_app(self): - assert isinstance(self.object.app, FastAPI) + def test_has_falcon_asgi_app(self): + assert isinstance(self.object.get_app_instance(), falcon.asgi.App) + + def test_no_pipeline_index(self): + connector_config = deepcopy(self.CONFIG) + connector = Factory.create({"test connector": connector_config}, logger=self.logger) + try: + connector.setup() + assert False + except FatalInputError: + assert True + + def test_not_first_pipeline(self): + connector_config = deepcopy(self.CONFIG) + connector = Factory.create({"test connector": connector_config}, logger=self.logger) + connector.pipeline_index = 2 + connector.setup() + assert not hasattr(connector, "http_server") + + def test_get_error_code_on_get(self): + resp = requests.get(url=f"{self.target}/json", timeout=0.5) + assert resp.status_code == 405 + + def test_get_error_code_too_many_requests(self): + data = {"message": "my log message"} + session = requests.Session() + session.mount( + "http://", + requests.adapters.HTTPAdapter(pool_maxsize=20, max_retries=3, pool_block=True), + ) + + def get_url(url): + for _ in range(100): + _ = session.post(url, json=data) + + with ThreadPoolExecutor(max_workers=100) as executor: + executor.submit(get_url, f"{self.target}/json") + resp = requests.post(url=f"{self.target}/json", json=data, timeout=0.5) + assert resp.status_code == 429 def test_json_endpoint_accepts_post_request(self): data = {"message": "my log message"} - resp = self.client.post(url="/json", content=json.dumps(data)) + resp = requests.post(url=f"{self.target}/json", json=data, timeout=0.5) assert resp.status_code == 200 - def test_json_message_is_put_in_queue(self): + def test_json_endpoint_match_wildcard_route(self): data = {"message": "my log message"} - resp = self.client.post(url="/json", content=json.dumps(data)) + resp = requests.post(url=f"{self.target}/api/wildcard_path/json", json=data, timeout=0.5) + assert resp.status_code == 200 + + def test_json_endpoint_not_match_wildcard_route(self): + data = {"message": "my log message"} + resp = requests.post( + url=f"{self.target}/api/wildcard_path/json/another_path", json=data, timeout=0.5 + ) + assert resp.status_code == 404 + + data = {"message": "my log message"} + resp = requests.post(url=f"{self.target}/json", json=data, timeout=0.5) assert resp.status_code == 200 event_from_queue = self.object.messages.get(timeout=0.001) assert event_from_queue == data def test_plaintext_endpoint_accepts_post_request(self): data = "my log message" - resp = self.client.post(url="/plaintext", content=data) + resp = requests.post(url=f"{self.target}/plaintext", json=data, timeout=0.5) assert resp.status_code == 200 def test_plaintext_message_is_put_in_queue(self): data = "my log message" - resp = self.client.post("/plaintext", content=data) + resp = requests.post(url=f"{self.target}/plaintext", data=data, timeout=0.5) assert resp.status_code == 200 event_from_queue = self.object.messages.get(timeout=0.001) assert event_from_queue.get("message") == data + def test_jsonl_endpoint_match_regex_route(self): + data = {"message": "my log message"} + resp = requests.post(url=f"{self.target}/first/jsonl", json=data, timeout=0.5) + assert 
 
     def test_plaintext_endpoint_accepts_post_request(self):
         data = "my log message"
-        resp = self.client.post(url="/plaintext", content=data)
+        resp = requests.post(url=f"{self.target}/plaintext", json=data, timeout=0.5)
         assert resp.status_code == 200
 
     def test_plaintext_message_is_put_in_queue(self):
         data = "my log message"
-        resp = self.client.post("/plaintext", content=data)
+        resp = requests.post(url=f"{self.target}/plaintext", data=data, timeout=0.5)
         assert resp.status_code == 200
         event_from_queue = self.object.messages.get(timeout=0.001)
         assert event_from_queue.get("message") == data
 
+    def test_jsonl_endpoint_match_regex_route(self):
+        data = {"message": "my log message"}
+        resp = requests.post(url=f"{self.target}/first/jsonl", json=data, timeout=0.5)
+        assert resp.status_code == 200
+
+    def test_jsonl_endpoint_not_match_regex_route(self):
+        data = {"message": "my log message"}
+        resp = requests.post(url=f"{self.target}/firs/jsonl", json=data, timeout=0.5)
+        assert resp.status_code == 404
+
+    def test_jsonl_endpoint_not_match_before_start_regex(self):
+        data = {"message": "my log message"}
+        resp = requests.post(url=f"{self.target}/api/first/jsonl", json=data, timeout=0.5)
+        assert resp.status_code == 404
+
+    def test_jsonl_endpoint_match_wildcard_regex_mix_route(self):
+        data = {"message": "my log message"}
+        resp = requests.post(
+            url=f"{self.target}/third/jsonl/another_path/last_path", json=data, timeout=0.5
+        )
+        assert resp.status_code == 200
+
+    def test_jsonl_endpoint_not_match_wildcard_regex_mix_route(self):
+        data = {"message": "my log message"}
+        resp = requests.post(
+            url=f"{self.target}/api/third/jsonl/another_path", json=data, timeout=0.5
+        )
+        assert resp.status_code == 404
+
     def test_jsonl_messages_are_put_in_queue(self):
         data = """
         {"message": "my first log message"}
         {"message": "my second log message"}
         {"message": "my third log message"}
         """
-        resp = self.client.post("/jsonl", content=data)
+        resp = requests.post(url=f"{self.target}/jsonl", data=data, timeout=0.5)
         assert resp.status_code == 200
         assert self.object.messages.qsize() == 3
+        event_from_queue = self.object.messages.get(timeout=1)
+        assert event_from_queue["message"] == "my first log message"
         event_from_queue = self.object.messages.get(timeout=0.001)
-        assert event_from_queue == {"message": "my first log message"}
+        assert event_from_queue["message"] == "my second log message"
         event_from_queue = self.object.messages.get(timeout=0.001)
-        assert event_from_queue == {"message": "my second log message"}
-        event_from_queue = self.object.messages.get(timeout=0.001)
-        assert event_from_queue == {"message": "my third log message"}
+        assert event_from_queue["message"] == "my third log message"
 
     def test_get_next_returns_message_from_queue(self):
         data = {"message": "my log message"}
-        self.client.post(url="/json", content=json.dumps(data))
+        requests.post(url=f"{self.target}/json", json=data, timeout=0.5)
         assert self.object.get_next(0.001) == (data, None)
 
     def test_get_next_returns_first_in_first_out(self):
@@ -86,7 +176,7 @@ def test_get_next_returns_first_in_first_out(self):
             {"message": "third message"},
         ]
         for message in data:
-            self.client.post(url="/json", content=json.dumps(message))
+            requests.post(url=self.target + "/json", json=message, timeout=0.5)
         assert self.object.get_next(0.001) == (data[0], None)
         assert self.object.get_next(0.001) == (data[1], None)
         assert self.object.get_next(0.001) == (data[2], None)
@@ -100,9 +190,9 @@ def test_get_next_returns_first_in_first_out_for_mixed_endpoints(self):
         for message in data:
             endpoint, post_data = message.values()
             if endpoint == "json":
-                self.client.post(url="/json", content=json.dumps(post_data))
+                requests.post(url=self.target + "/json", json=post_data, timeout=0.5)
             if endpoint == "plaintext":
-                self.client.post("/plaintext", content=post_data)
+                requests.post(url=self.target + "/plaintext", data=post_data, timeout=0.5)
         assert self.object.get_next(0.001)[0] == data[0].get("data")
         assert self.object.get_next(0.001)[0] == {"message": data[1].get("data")}
         assert self.object.get_next(0.001)[0] == data[2].get("data")
@@ -111,16 +201,54 @@ def test_get_next_returns_none_for_empty_queue(self):
         assert self.object.get_next(0.001)[0] is None
 
     def test_server_returns_uvicorn_server_instance(self):
-        assert isinstance(self.object.server, uvicorn.Server)
-
-    def test_server_starts_threaded_server_with_context_manager(self):
-        with self.object.server.run_in_thread():
-            message = {"message": "my message"}
-            for i in range(100):
-                message["message"] = f"message number {i}"
-                requests.post(url="http://127.0.0.1:9000/json", json=message)  # nosemgrep
+        assert isinstance(self.object.get_server_instance(), uvicorn.Server)
+
+    def test_server_starts_threaded_server(self):
+        message = {"message": "my message"}
+        for i in range(100):
+            message["message"] = f"message number {i}"
+            requests.post(url=f"{self.target}/json", json=message, timeout=0.5)  # nosemgrep
         assert self.object.messages.qsize() == 100, "messages are put to queue"
 
+    def test_get_metadata(self):
+        message = {"message": "my message"}
+        connector_config = deepcopy(self.CONFIG)
+        connector_config["collect_meta"] = True
+        connector_config["metafield_name"] = "custom"
+        connector = Factory.create({"test connector": connector_config}, logger=self.logger)
+        connector.pipeline_index = 1
+        connector.setup()
+        target = connector.target
+        resp = requests.post(url=f"{target}/json", json=message, timeout=0.5)  # nosemgrep
+        assert resp.status_code == 200
+        message = connector.messages.get(timeout=0.5)
+        assert message["custom"]["url"] == target + "/json"
+        assert message["custom"]["remote_addr"] == connector.host
+        assert isinstance(message["custom"]["user_agent"], str)
+
+    def test_server_multiple_config_changes(self):
+        message = {"message": "my message"}
+        connector_config = deepcopy(self.CONFIG)
+        connector_config["uvicorn_config"]["port"] = 9001
+        connector = Factory.create({"test connector": connector_config}, logger=self.logger)
+        connector.pipeline_index = 1
+        connector.setup()
+        target = connector.target
+        resp = requests.post(url=f"{target}/json", json=message, timeout=0.5)  # nosemgrep
+        assert resp.status_code == 200
+        target = target.replace(":9001", ":9000")
+        try:
+            resp = requests.post(url=f"{target}/json", json=message, timeout=0.5)  # nosemgrep
+        except requests.exceptions.ConnectionError as e:
+            assert e.response is None
+        connector_config = deepcopy(self.CONFIG)
+        connector = Factory.create({"test connector": connector_config}, logger=self.logger)
+        connector.pipeline_index = 1
+        connector.setup()
+        target = connector.target
+        resp = requests.post(url=f"{target}/json", json=message, timeout=0.5)  # nosemgrep
+        assert resp.status_code == 200
+
     def test_get_next_with_hmac_of_raw_message(self):
         connector_config = deepcopy(self.CONFIG)
         connector_config.update(
@@ -135,10 +263,10 @@ def test_get_next_with_hmac_of_raw_message(self):
             }
         )
         connector = Factory.create({"test connector": connector_config}, logger=self.logger)
+        connector.pipeline_index = 1
         connector.setup()
         test_event = "the content"
-        with connector.server.run_in_thread():
-            requests.post(url="http://127.0.0.1:9000/plaintext", data=test_event)  # nosemgrep
+        requests.post(url=f"{self.target}/plaintext", data=test_event, timeout=0.5)  # nosemgrep
 
         expected_event = {
             "message": "the content",
"This is non critical", {"some": "event"}, None ) - def test_http_input_registers_to_shard_dict(self, _): - self.pipeline._setup() - self.pipeline._input.server.config.port = 9000 - self.pipeline._used_server_ports = {} - self.pipeline._setup() - assert 9000 in self.pipeline._used_server_ports - - def test_http_input_registers_increased_port_to_shard_dict(self, _): - self.pipeline._setup() - self.pipeline._input.server.config.port = 9000 - self.pipeline._used_server_ports = {9000: "other_process_name"} - self.pipeline._setup() - assert 9001 in self.pipeline._used_server_ports - - def test_http_input_removes_port_from_shard_dict_on_shut_down(self, _): - self.pipeline._setup() - self.pipeline._input.server.config.port = 9000 - self.pipeline._used_server_ports = {} - self.pipeline._setup() - assert 9000 in self.pipeline._used_server_ports - self.pipeline._shut_down() - assert 9000 not in self.pipeline._used_server_ports - - def test_http_input_registers_increased_port_to_shard_dict_after_shut_down(self, _): - self.pipeline._setup() - self.pipeline._input.server.config.port = 9000 - self.pipeline._used_server_ports = {9000: "other_process_name"} - self.pipeline._setup() - assert 9001 in self.pipeline._used_server_ports - self.pipeline._shut_down() - assert 9001 not in self.pipeline._used_server_ports - self.pipeline._setup() - assert 9001 in self.pipeline._used_server_ports - def test_shut_down_drains_input_queues(self, _): self.pipeline._setup() input_config = { @@ -513,7 +478,9 @@ def test_shut_down_drains_input_queues(self, _): "endpoints": {"/json": "json", "/jsonl": "jsonl", "/plaintext": "plaintext"}, } } - self.pipeline._input = original_create(input_config, mock.MagicMock()) + self.pipeline._input = original_create(input_config, self.pipeline.logger) + self.pipeline._input.pipeline_index = 1 + self.pipeline._input.setup() self.pipeline._input.messages.put({"message": "test message"}) assert self.pipeline._input.messages.qsize() == 1 self.pipeline._shut_down()