Add Flush Timeout to s3 Connector (#553)
* Add flush timeout to s3 connector

* Move s3 resource creation into setup method.

* Make `_s3_resource` a cached property.
* Move s3 session and resource creation into the `setup()` method so that
  the `flush_timeout` task can be scheduled properly.

* Extend `setup()` to verify bucket exists and can be accessed.

* Adapt task scheduling unit test to not interfere with local opensearch
  instance

* Add decorator to handle boto exceptions

---------

Co-authored-by: Marco Herzog <[email protected]>
saegel and clumsy9 authored May 8, 2024
1 parent 7eada41 commit 28dbcec
Showing 4 changed files with 152 additions and 100 deletions.
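
The core of this change is a second flush trigger for the s3 message backlog: in addition to the existing message_backlog_size threshold, a scheduled flush_timeout task now writes whatever has accumulated. Below is a minimal, self-contained sketch of that behaviour for orientation only; it is not logprep's actual class, and it uses a plain threading.Timer in place of whatever logprep's _schedule_task() uses internally.

import threading
from collections import defaultdict
from time import time
from uuid import uuid4


class BacklogFlusher:
    """Simplified stand-in for the connector's backlog handling."""

    def __init__(self, write_batch, message_backlog_size=500, flush_timeout=60):
        self._write_batch = write_batch              # callable(identifier, documents)
        self._message_backlog_size = message_backlog_size
        self._flush_timeout = flush_timeout          # mirrors the new config option
        self._message_backlog = defaultdict(list)
        self._lock = threading.Lock()
        self._schedule_flush()

    def _schedule_flush(self):
        # re-arm a timer so the backlog is written even if the size threshold
        # is never reached -- the role of the new flush_timeout task
        timer = threading.Timer(self._flush_timeout, self._timed_flush)
        timer.daemon = True
        timer.start()

    def _timed_flush(self):
        self.flush()
        self._schedule_flush()

    def add(self, prefix, document):
        with self._lock:
            self._message_backlog[prefix].append(document)
            backlog_size = sum(map(len, self._message_backlog.values()))
        if backlog_size >= self._message_backlog_size:
            self.flush()

    def flush(self):
        with self._lock:
            if not self._message_backlog:
                return
            for prefix, batch in self._message_backlog.items():
                self._write_batch(f"{prefix}/{time()}-{uuid4()}", batch)
            self._message_backlog.clear()


# usage with placeholder values:
# flusher = BacklogFlusher(print, message_backlog_size=2, flush_timeout=5)
# flusher.add("default_prefix", {"field": "content"})
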
2 changes: 1 addition & 1 deletion logprep/connector/elasticsearch/output.py
@@ -102,7 +102,7 @@ class Config(Output.Config):
ca_cert: Optional[str] = field(validator=validators.instance_of(str), default="")
"""(Optional) Path to a SSL ca certificate to verify the ssl context."""
flush_timeout: Optional[int] = field(validator=validators.instance_of(int), default=60)
"""(Optional) Timout after :code:`message_backlog` is flushed if
"""(Optional) Timeout after :code:`message_backlog` is flushed if
:code:`message_backlog_size` is not reached."""
loglevel: Optional[str] = field(validator=validators.instance_of(str), default="INFO")
"""(Optional) Log level for the underlying library. Enables fine-grained control over the
194 changes: 101 additions & 93 deletions logprep/connector/s3/output.py
@@ -38,13 +38,12 @@
"""

import json
import re
from collections import defaultdict
from functools import cached_property
from logging import Logger
from time import time
from typing import DefaultDict, Optional
from typing import Any, DefaultDict, Optional
from uuid import uuid4

import boto3
@@ -64,6 +63,25 @@
from logprep.util.time import TimeParser


def _handle_s3_error(func):
def _inner(self: "S3Output", *args) -> Any:
try:
return func(self, *args)
except EndpointConnectionError as error:
raise FatalOutputError(self, "Could not connect to the endpoint URL") from error
except ConnectionClosedError as error:
raise FatalOutputError(
self,
"Connection was closed before we received a valid response from endpoint URL",
) from error
except (BotoCoreError, ClientError) as error:
raise FatalOutputError(self, str(error)) from error

return None

return _inner


class S3Output(Output):
"""An s3 output connector."""

@@ -120,6 +138,9 @@ class Config(Output.Config):
)
"""The input callback is called after the maximum backlog size has been reached
if this is set to True (optional)"""
flush_timeout: Optional[int] = field(validator=validators.instance_of(int), default=60)
"""(Optional) Timeout after :code:`message_backlog` is flushed if
:code:`message_backlog_size` is not reached."""

@define(kw_only=True)
class Metrics(Output.Metrics):
@@ -133,24 +154,19 @@ class Metrics(Output.Metrics):
)
"""Number of events that were successfully written to s3"""

__slots__ = ["_message_backlog", "_index_cache"]
__slots__ = ["_message_backlog", "_base_prefix"]

_message_backlog: DefaultDict

_s3_resource: Optional["boto3.resources.factory.s3.ServiceResource"]

_encoder: msgspec.json.Encoder = msgspec.json.Encoder()

_base_prefix: str

def __init__(self, name: str, configuration: "S3Output.Config", logger: Logger):
super().__init__(name, configuration, logger)
self._message_backlog = defaultdict(list)
self._base_prefix = f"{self._config.base_prefix}/" if self._config.base_prefix else ""
self._s3_resource = None
self._setup_s3_resource()

def _setup_s3_resource(self):
@cached_property
def _s3_resource(self) -> boto3.resources.factory.ServiceResource:
session = boto3.Session(
aws_access_key_id=self._config.aws_access_key_id,
aws_secret_access_key=self._config.aws_secret_access_key,
@@ -160,7 +176,7 @@ def _setup_s3_resource(self):
connect_timeout=self._config.connect_timeout,
retries={"max_attempts": self._config.max_retries},
)
self._s3_resource = session.resource(
return session.resource(
"s3",
endpoint_url=f"{self._config.endpoint_url}",
verify=self._config.ca_cert,
Expand All @@ -169,16 +185,11 @@ def _setup_s3_resource(self):
)

@property
def s3_resource(self):
"""Return s3 resource"""
return self._s3_resource

@property
def _backlog_size(self):
def _backlog_size(self) -> int:
return sum(map(len, self._message_backlog.values()))

@cached_property
def _replace_pattern(self):
def _replace_pattern(self) -> re.Pattern[str]:
return re.compile(r"%{\S+?}")

def describe(self) -> str:
Expand All @@ -193,69 +204,15 @@ def describe(self) -> str:
base_description = super().describe()
return f"{base_description} - S3 Output: {self._config.endpoint_url}"

def _add_dates(self, prefix):
date_format_matches = self._replace_pattern.findall(prefix)
if date_format_matches:
now = TimeParser.now()
for date_format_match in date_format_matches:
formatted_date = now.strftime(date_format_match[2:-1])
prefix = re.sub(date_format_match, formatted_date, prefix)
return prefix

@Metric.measure_time()
def _write_to_s3_resource(self):
"""Writes a document into s3 bucket using given prefix."""
if self._backlog_size >= self._config.message_backlog_size:
self._write_backlog()

def _add_to_backlog(self, document: dict, prefix: str):
"""Adds document to backlog and adds a a prefix.
Parameters
----------
document : dict
Document to store in backlog.
"""
prefix = self._add_dates(prefix)
prefix = f"{self._base_prefix}{prefix}"
self._message_backlog[prefix].append(document)

def _write_backlog(self):
"""Write to s3 if it is not already writing."""
if not self._message_backlog:
return

self._logger.info("Writing %s documents to s3", self._backlog_size)
for prefix_mb, document_batch in self._message_backlog.items():
self._write_document_batch(document_batch, f"{prefix_mb}/{time()}-{uuid4()}")
self._message_backlog.clear()

if not self._config.call_input_callback:
return

if self.input_connector and hasattr(self.input_connector, "batch_finished_callback"):
self.input_connector.batch_finished_callback()

def _write_document_batch(self, document_batch: dict, identifier: str):
try:
self._write_to_s3(document_batch, identifier)
except EndpointConnectionError as error:
raise FatalOutputError(self, "Could not connect to the endpoint URL") from error
except ConnectionClosedError as error:
raise FatalOutputError(
self,
"Connection was closed before we received a valid response from endpoint URL",
) from error
except (BotoCoreError, ClientError) as error:
raise FatalOutputError(self, str(error)) from error
@_handle_s3_error
def setup(self) -> None:
super().setup()
flush_timeout = self._config.flush_timeout
self._schedule_task(task=self._write_backlog, seconds=flush_timeout)

def _write_to_s3(self, document_batch: dict, identifier: str):
self._logger.debug(f'Writing "{identifier}" to s3 bucket "{self._config.bucket}"')
s3_obj = self.s3_resource.Object(self._config.bucket, identifier)
s3_obj.put(Body=self._encoder.encode(document_batch), ContentType="application/json")
self.metrics.number_of_successful_writes += len(document_batch)
_ = self._s3_resource.meta.client.head_bucket(Bucket=self._config.bucket)

def store(self, document: dict):
def store(self, document: dict) -> None:
"""Store a document into s3 bucket.
Parameters
@@ -273,19 +230,7 @@ def store(self, document: dict):
self._add_to_backlog(document, prefix_value)
self._write_to_s3_resource()

@staticmethod
def _build_no_prefix_document(message_document: dict, reason: str):
document = {
"reason": reason,
"@timestamp": TimeParser.now().isoformat(),
}
try:
document["message"] = json.dumps(message_document)
except TypeError:
document["message"] = str(message_document)
return document

def store_custom(self, document: dict, target: str):
def store_custom(self, document: dict, target: str) -> None:
"""Store document into backlog to be written into s3 bucket using the target prefix.
Only add to backlog instead of writing the batch and calling batch_finished_callback,
@@ -304,7 +249,9 @@ def store_custom(self, document: dict, target: str):
self.metrics.number_of_processed_events += 1
self._add_to_backlog(document, target)

def store_failed(self, error_message: str, document_received: dict, document_processed: dict):
def store_failed(
self, error_message: str, document_received: dict, document_processed: dict
) -> None:
"""Write errors into s3 bucket using error prefix for documents that failed processing.
Parameters
@@ -326,3 +273,64 @@ def store_failed(self, error_message: str, document_received: dict, document_pro
}
self._add_to_backlog(error_document, self._config.error_prefix)
self._write_to_s3_resource()

def _add_dates(self, prefix: str) -> str:
date_format_matches = self._replace_pattern.findall(prefix)
if date_format_matches:
now = TimeParser.now()
for date_format_match in date_format_matches:
formatted_date = now.strftime(date_format_match[2:-1])
prefix = re.sub(date_format_match, formatted_date, prefix)
return prefix

@Metric.measure_time()
def _write_to_s3_resource(self) -> None:
"""Writes a document into s3 bucket using given prefix."""
if self._backlog_size >= self._config.message_backlog_size:
self._write_backlog()

def _add_to_backlog(self, document: dict, prefix: str) -> None:
"""Adds document to backlog and adds a a prefix.
Parameters
----------
document : dict
Document to store in backlog.
"""
prefix = self._add_dates(prefix)
prefix = f"{self._base_prefix}{prefix}"
self._message_backlog[prefix].append(document)

def _write_backlog(self) -> None:
"""Write to s3 if it is not already writing."""
if not self._message_backlog:
return

self._logger.info("Writing %s documents to s3", self._backlog_size)
for prefix_mb, document_batch in self._message_backlog.items():
self._write_document_batch(document_batch, f"{prefix_mb}/{time()}-{uuid4()}")
self._message_backlog.clear()

if not self._config.call_input_callback:
return

if self.input_connector and hasattr(self.input_connector, "batch_finished_callback"):
self.input_connector.batch_finished_callback()

@_handle_s3_error
def _write_document_batch(self, document_batch: dict, identifier: str) -> None:
self._logger.debug(f'Writing "{identifier}" to s3 bucket "{self._config.bucket}"')
s3_obj = self._s3_resource.Object(self._config.bucket, identifier)
s3_obj.put(Body=self._encoder.encode(document_batch), ContentType="application/json")
self.metrics.number_of_successful_writes += len(document_batch)

def _build_no_prefix_document(self, message_document: dict, reason: str) -> dict:
document = {
"reason": reason,
"@timestamp": TimeParser.now().isoformat(),
}
try:
document["message"] = self._encoder.encode(message_document).decode("utf-8")
except (msgspec.EncodeError, TypeError):
document["message"] = str(message_document)
return document
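
With _s3_resource now a lazily created cached_property, setup() can both schedule the flush task and verify bucket access with a single head_bucket call, while the _handle_s3_error decorator turns any boto failure into a FatalOutputError. A hedged, standalone sketch of the same reachability check done directly against boto3 follows; the bucket name, endpoint, and timeout values are placeholder assumptions, not logprep code.

from typing import Optional

import boto3
from botocore.config import Config as BotoConfig
from botocore.exceptions import BotoCoreError, ClientError


def bucket_is_reachable(bucket: str, endpoint_url: Optional[str] = None) -> bool:
    """Return True if the bucket exists and the current credentials may access it."""
    s3 = boto3.Session().resource(
        "s3",
        endpoint_url=endpoint_url,
        config=BotoConfig(connect_timeout=5, retries={"max_attempts": 1}),
    )
    try:
        # HeadBucket is a cheap existence/permission probe; it raises
        # ClientError (403/404) or a connection error if the check fails.
        s3.meta.client.head_bucket(Bucket=bucket)
        return True
    except (BotoCoreError, ClientError):
        return False


# usage with placeholder values:
# bucket_is_reachable("logprep-test-bucket", endpoint_url="https://s3.example.com")
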
2 changes: 1 addition & 1 deletion tests/unit/connector/test_opensearch_output.py
@@ -352,7 +352,7 @@ def test_setup_raises_fatal_output_error_if_opensearch_error_is_raised(self):
self.object.setup()

def test_setup_registers_flush_timout_tasks(self):
# this test fails if opensearch is running on localhost
self.object._config.hosts = ["opensearch:9092"]
job_count = len(Component._scheduler.jobs)
with pytest.raises(FatalOutputError):
self.object.setup()
54 changes: 49 additions & 5 deletions tests/unit/connector/test_s3_output.py
@@ -4,6 +4,7 @@
# pylint: disable=wrong-import-order
# pylint: disable=attribute-defined-outside-init
import logging
from collections import defaultdict
from copy import deepcopy
from datetime import datetime
from math import isclose
@@ -65,7 +66,7 @@ def test_store_sends_with_default_prefix(self, base_prefix):
expected = {
default_prefix: [
{
"message": '{"field": "content"}',
"message": '{"field":"content"}',
"reason": "Prefix field 'foo_prefix_field' empty or missing in document",
}
]
@@ -89,7 +90,7 @@ def test_store_sends_event_to_with_expected_prefix_if_prefix_missing_in_event(
event = {"field": "content"}
default_prefix = f"{base_prefix}/default_prefix" if base_prefix else "default_prefix"
expected = {
"message": '{"field": "content"}',
"message": '{"field":"content"}',
"reason": "Prefix field 'foo_prefix_field' empty or missing in document",
}
s3_config = deepcopy(self.CONFIG)
@@ -132,7 +133,6 @@ def test_store_failed(self, base_prefix):

s3_output.store_failed(error_message, event_received, event)

print(s3_output._message_backlog)
error_document = s3_output._message_backlog[error_prefix][0]
# timestamp is compared to be approximately the same,
# since it is variable and then removed to compare the rest
@@ -177,8 +177,9 @@ def test_create_s3_building_prefix_with_invalid_json(self):
)
def test_write_document_batch_calls_handles_errors(self, caplog, error, message):
with caplog.at_level(logging.WARNING):
with mock.patch(
"logprep.connector.s3.output.S3Output._write_to_s3",
with mock.patch.object(
self.object._s3_resource,
"Object",
side_effect=error,
):
with pytest.raises(FatalOutputError, match=message):
@@ -267,10 +268,53 @@ def test_message_backlog_is_not_written_if_message_backlog_size_not_reached(self
self.object.store({"test": "event"})
mock_write_backlog.assert_not_called()

def test_write_backlog_executed_on_empty_message_backlog(self):
with mock.patch(
"logprep.connector.s3.output.S3Output._backlog_size", new_callable=mock.PropertyMock
) as mock_backlog_size:
self.object._write_backlog()
mock_backlog_size.assert_not_called()

def test_store_failed_counts_failed_events(self):
self.object._write_backlog = mock.MagicMock()
super().test_store_failed_counts_failed_events()

def test_setup_registers_flush_timeout_tasks(self):
job_count = len(self.object._scheduler.jobs)
with pytest.raises(FatalOutputError):
self.object.setup()
assert len(self.object._scheduler.jobs) == job_count + 1

@pytest.mark.parametrize(
"error, message",
[
(
EndpointConnectionError(endpoint_url="http://xdfzy:123"),
r".*Could not connect to the endpoint URL.*",
),
(
ConnectionClosedError(endpoint_url="http://xdfzy:123"),
r".*Connection was closed before we received a valid response from endpoint URL.*",
),
(
ClientError(error_response={"foo": "bar"}, operation_name="HeadBucket"),
r".*An error occurred \(\w+\) when calling the HeadBucket operation: \w+.*",
),
(
BotoCoreError(),
r".*An unspecified error occurred.*",
),
],
)
def test_setup_raises_fataloutputerror_if_boto_exception_is_raised(self, error, message):
with mock.patch.object(
self.object._s3_resource.meta.client,
"head_bucket",
side_effect=error,
):
with pytest.raises(FatalOutputError, match=message):
self.object.setup()

@staticmethod
def _calculate_backlog_size(s3_output):
return sum(len(values) for values in s3_output._message_backlog.values())
