From af5791b0e5bfff140a1a68d7b3a01d3aca9a7822 Mon Sep 17 00:00:00 2001 From: Ryxias Date: Tue, 16 Apr 2019 11:24:31 -0700 Subject: [PATCH] Merges 2.2.0 (#929) * Adds Alert Publisher framework (#900) * Squashed commit of AlertPublishers (Provisional) i aefoiajwoef initial draft to publisher crap awefla aewlfkjhawiehgsv Add some alerts aweklrgjhsakf fadsuhiawef Adds some basic publisher framework pylint Compoiste publisher Assemble_alert_publisher_for_output more code fix tests Replace with publishers WORKING COMMIT This commit is proven to work on Stage. TODOs: - publishers are hard to code, and the class system seems redundant? - merging needs to be tested - required outputs will get published if blanket publication is enabled Fixes a bug where publishers are assigned to required outputs Tidies up docblocks Flesh out publisher documentation * Extends SlackOutput to support custom messages and attachments * Refactor publishers into new directory Functional commit * Continues to fumble with classes and namespaces and stuff * cyclial python module dependency * Missing mock_s3 * workign commit fumbling with more namespaces and directories * pylint * Publishers working? * w * Fixups * Fix comment * fix * Consolidate some functions, rename some Classes to be more compact * more touchiups * i * Update documentation * Refactor to move DefaultPublisher into the core * Add test coverage * Address unused-argument * Remove some extraneous import module things * Move core.py publisher code into shared * fix comment * Catch, log, and reraise keyerror in composite publisher * Fix buggy docuemntationt * Fix docvumentation * @Rule -> @rule * fix documentation string * Reverses the order specific and unspecific output publishers are executed to be more intuitive * Move import_folders to new module * Remove deepcopy, delegating it to CompositePublisher * Clean up docblocks * Remove extraneous return * Write a test for chained inheritance * Move test files * Move incorrectly placed comment * Raise exception when failing to register publisher * Raise exception on invalid arg output * Touch up tests * Gets rid of ugly python \ * Fix some bad list code plus DRY out DefaultPublisher * Renames publish_alert to compose_alert to be less imperative * enforce ssl access only on all S3 buckets (#905) * enforcing ssl only access on all streamalert s3 buckets * updating unit tests for s3 bucket resource creation * fixing duplicate policy bug with cloudtrail bucket (#907) * [Alert Publishers] Add some community Slack Publishers (#904) * Adds some base Slack publishers for the community Remove stuff Fix up fix merge Touchups Remove deepcopy from slack publishers Fix bug and missing test inclusion due to missing init file fix licenses Capitalize fields and remove redundant fix tests Fix a... test? * Convert map() to array syntax. Fix timezone problem * [Alert Publishers] Standardizes magic fields with @-sign prefix (#917) * Prefixes all magic publisher fields with @ * Fix documentation * [Fixup] Alert Publisher PR Feedback (#918) * PR Fixups * I did not end up figuring out how to uncouple this cyclic dependency. optimzied instead * ? * wtf why is consider-using-ternary a mandatory pylint condition? 
* [Alert Publishers] Rebuilds PagerDuty integration + Adds some community pagerduty publishers (#911) * WIP: PagerDuty publishers fixup wip but maybe working commit of refactoring all this pagerduty crap Fixing unit tests unit tests Draft ya wrfg fixed wip Fix ssl verification Working commit; need add more test coverage Fix documentation Add publishers Test for enumerate_fields + fix bug Add publisher that strips out "streamalert:normalization" from the publication Ef * Upgrades integration paths * successfully deployed and tested v1 on staging * Finishes draft * fixup * Yeah... yeah.. * Expand docblocks * WIP * Fixups documents * Fix bug in pd publisher * fix * Fix bug and some documentation * remove * Fix some bugs * Fix some error messages * Fixes a bug causing alerts not to merge correctly * Fix tests * alphabetize enumerate_fields * Add new remove_fields publisher * Add more test for work * y * Pr feedback and add test * Fix some bugs introduced due to "default behavior" in pagerduty integration (#920) * Fix a bug regarding defaults in the pagerduty integration refactor * PR fixup * fix silly bug (#921) * Adds a new publisher, improves to description parser (#922) * Adds a publisher to bubble deep dict fields to top of publication * Improves the rule parser to accommodate for some weird cases * pylint * Fix comment * bubble_fields -> populate_fields * More PR feedback * Adds new publisher for converting array to string. * naisu * Moves some directories to be more consistent * pr feedback * [Publishers] Adds publisher error detection to rule_test.sh (#923) * tmp wip commit * First attempt at adding publishers to rule_tst * Improved format * fixps * Consolidates test logic a little * Improves format * PR Fixup: compose_alert now requires output * PR feedback * Improvements to Slack Publishers (#924) * Adds new publisher to slack * Improves description parser * Add tests * [Publishers] Adds support for images on Pagerduty v2 (#925) * Support images on pagerduty v2 API * Tests, and adds a new pagerduty publisher to attach an image * bumping version to 2.2.0 --- docs/source/index.rst | 1 + docs/source/outputs.rst | 25 +- docs/source/publishers.rst | 276 ++++ publishers/__init__.py | 0 publishers/community/__init__.py | 0 publishers/community/generic.py | 272 ++++ publishers/community/pagerduty/__init__.py | 0 .../community/pagerduty/pagerduty_layout.py | 150 ++ publishers/community/slack/__init__.py | 0 publishers/community/slack/slack_layout.py | 325 ++++ stream_alert/__init__.py | 2 +- stream_alert/alert_processor/helpers.py | 88 ++ stream_alert/alert_processor/outputs/aws.py | 75 +- .../alert_processor/outputs/carbonblack.py | 3 + .../alert_processor/outputs/demisto.py | 51 +- .../alert_processor/outputs/github.py | 25 +- stream_alert/alert_processor/outputs/jira.py | 19 +- .../alert_processor/outputs/komand.py | 7 +- .../alert_processor/outputs/output_base.py | 2 +- .../alert_processor/outputs/pagerduty.py | 1303 ++++++++++++----- .../alert_processor/outputs/phantom.py | 7 +- stream_alert/alert_processor/outputs/slack.py | 218 ++- stream_alert/rules_engine/rules_engine.py | 138 +- stream_alert/shared/alert.py | 42 +- stream_alert/shared/description.py | 164 +++ stream_alert/shared/importer.py | 61 + stream_alert/shared/publisher.py | 266 ++++ stream_alert/shared/rule.py | 47 +- stream_alert/shared/rule_table.py | 3 +- stream_alert_cli/manage_lambda/package.py | 2 + stream_alert_cli/test/handler.py | 50 +- stream_alert_cli/test/results.py | 232 ++- 
tests/integration/rules/duo/duo_fraud.json | 2 +- tests/unit/publishers/__init__.py | 0 tests/unit/publishers/community/__init__.py | 0 .../community/pagerduty/__init__.py | 0 .../pagerduty/test_pagerduty_layout.py | 229 +++ .../unit/publishers/community/test_generic.py | 506 +++++++ .../test_outputs/credentials/test_provider.py | 1 + .../test_outputs/test_demisto.py | 6 +- .../test_outputs/test_pagerduty.py | 1281 ++++++++++++---- .../test_outputs/test_slack.py | 117 +- .../test_publishers/__init__.py | 0 .../test_publishers/slack/__init__.py | 0 .../slack/test_slack_layout.py | 271 ++++ .../stream_alert_shared/test_description.py | 328 +++++ .../unit/stream_alert_shared/test_importer.py | 77 + .../stream_alert_shared/test_publisher.py | 397 +++++ tests/unit/stream_alert_shared/test_rule.py | 61 +- .../rules_engine/test_rules_engine.py | 258 +++- 50 files changed, 6488 insertions(+), 900 deletions(-) create mode 100644 docs/source/publishers.rst create mode 100644 publishers/__init__.py create mode 100644 publishers/community/__init__.py create mode 100644 publishers/community/generic.py create mode 100644 publishers/community/pagerduty/__init__.py create mode 100644 publishers/community/pagerduty/pagerduty_layout.py create mode 100644 publishers/community/slack/__init__.py create mode 100644 publishers/community/slack/slack_layout.py create mode 100644 stream_alert/shared/description.py create mode 100644 stream_alert/shared/importer.py create mode 100644 stream_alert/shared/publisher.py create mode 100644 tests/unit/publishers/__init__.py create mode 100644 tests/unit/publishers/community/__init__.py create mode 100644 tests/unit/publishers/community/pagerduty/__init__.py create mode 100644 tests/unit/publishers/community/pagerduty/test_pagerduty_layout.py create mode 100644 tests/unit/publishers/community/test_generic.py create mode 100644 tests/unit/stream_alert_alert_processor/test_publishers/__init__.py create mode 100644 tests/unit/stream_alert_alert_processor/test_publishers/slack/__init__.py create mode 100644 tests/unit/stream_alert_alert_processor/test_publishers/slack/test_slack_layout.py create mode 100644 tests/unit/stream_alert_shared/test_description.py create mode 100644 tests/unit/stream_alert_shared/test_importer.py create mode 100644 tests/unit/stream_alert_shared/test_publisher.py diff --git a/docs/source/index.rst b/docs/source/index.rst index 775285148..60a66bdc4 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -77,6 +77,7 @@ Table of Contents rules testing outputs + publishers metrics troubleshooting faq diff --git a/docs/source/outputs.rst b/docs/source/outputs.rst index c45c0a3bc..b025e3ba9 100644 --- a/docs/source/outputs.rst +++ b/docs/source/outputs.rst @@ -15,6 +15,7 @@ Out of the box, StreamAlert supports: * **AWS SNS** * **AWS SQS** * **CarbonBlack** +* **Demisto** * **GitHub** * **Jira** * **Komand** @@ -74,6 +75,9 @@ Adding support for a new service involves five steps: .. code-block:: python + from stream_alert.alert_processor.helpers import compose_alert + + def get_user_defined_properties(self): """Returns any properties for this output that must be provided by the user At a minimum, this method should prompt the user for a 'descriptor' value to @@ -88,12 +92,29 @@ Adding support for a new service involves five steps: '(ie: name of integration/channel/service/etc)')) ]) - def _dispatch(self, **kwargs): + def _dispatch(self, alert, descriptor): """Handles the actual sending of alerts to the configured service. 
        Any external API calls for this service should be added here.

        This method should return a boolean where True means the alert was
        successfully sent.
+
+        In general, use the compose_alert() method defined in stream_alert.alert_processor.helpers
+        when presenting the alert in a generic polymorphic format to be rendered on the chosen output
+        integration. This is so that specialized Publishers can modify how the alert is represented
+        on the output.
+
+        In addition, adding output-specific fields can be useful to offer more fine-grained control
+        of the look and feel of an alert.
+
+        For example, an optional field that directly controls a PagerDuty incident's title:
+        - '@pagerduty.incident_title'
+
+
+        When referencing an alert's attributes, reference the alert's field directly (e.g.
+        alert.alert_id). Do not rely on the published alert.
         """
-        ...
+
+        publication = compose_alert(alert, self, descriptor)
+        # ...
         return True

 **Note**: The ``OutputProperty`` object used in ``get_user_defined_properties`` is a namedtuple consisting of a few properties:
diff --git a/docs/source/publishers.rst b/docs/source/publishers.rst
new file mode 100644
index 000000000..d1fceb2f3
--- /dev/null
+++ b/docs/source/publishers.rst
@@ -0,0 +1,276 @@
+Publishers
+==========
+
+Overview
+--------
+
+Publishers are a framework for transforming alerts prior to dispatching them to outputs, on a per-rule basis.
+This allows users to customize the look and feel of alerts.
+
+
+How do Publishers work?
+-----------------------
+
+Publishers are blocks of code that are run during alert processing, immediately prior to dispatching
+an alert to an output.
+
+
+
+Implementing new Publishers
+---------------------------
+
+All publishers must be added to the ``publishers`` directory. Publishers have two valid syntaxes:
+
+
+**Function**
+
+Implement a top-level function that accepts two arguments: an ``Alert`` and a dict. Decorate this function
+with the ``@Register`` decorator.
+
+.. code-block:: python
+
+    from stream_alert.shared.publisher import Register
+
+    @Register
+    def my_publisher(alert: Alert, publication: dict) -> dict:
+        # ...
+        return {}
+
+
+**Class**
+
+Implement a class that inherits from ``AlertPublisher`` and fills in the implementation of ``publish()``.
+Decorate the class with the ``@Register`` decorator.
+
+.. code-block:: python
+
+    from stream_alert.shared.publisher import AlertPublisher, Register
+
+    @Register
+    class MyPublisherClass(AlertPublisher):
+
+        def publish(self, alert: Alert, publication: dict) -> dict:
+            # ...
+            return {}
+
+
+**Recommended Implementation**
+
+Publishers should always return dicts containing only simple types (str, int, list, dict).
+
+Publishers are executed in series, each passing its publication to the next publisher. The ``publication``
+arg is the result of the previous publisher (or ``{}`` if it is the first publisher in the series). Publishers
+should freely add, modify, or delete fields from previous publications. However, publishers should avoid
+doing in-place modifications of the publications, and should prefer to copy-and-modify:
+
+.. code-block:: python
+
+    from stream_alert.shared.publisher import Register
+
+    @Register
+    def sample_publisher(alert, publication):
+        publication['new_field'] = 'new_value'
+        publication.pop('old_field', None)
+
+        return publication
+
+
+Preparing Outputs
+-----------------
+
+In order to take advantage of Publishers, all outputs must be implemented with the following guidelines:
+
+**Use compose_alert()**
+
+When presenting unstructured or miscellaneous data to an output (e.g. an email body, incident details),
+outputs should be implemented to use the ``compose_alert(alert: Alert, output: OutputDispatcher, descriptor: str) -> dict``
+method.
+
+``compose_alert()`` loads all publishers relevant to the given ``Alert`` and executes these publishers in series,
+returning the result of the final publisher.
+
+All data returned by ``compose_alert()`` should be treated as optional.
+
+.. code-block:: python
+
+    from stream_alert.alert_processor.helpers import compose_alert
+
+    def _dispatch(self, alert, descriptor):
+        # ...
+        publication = compose_alert(alert, self, descriptor)
+        make_api_call(misc_data=publication)
+
+
+**"Default" Implementations**
+
+For output-specific fields that are mandatory (such as an incident Title or assignee), each output
+should offer a default implementation:
+
+.. code-block:: python
+
+    def _dispatch(self, alert, descriptor):
+        default_title = 'Incident Title: #{}'.format(alert.alert_id)
+        default_html = 'Rule: {}'.format(alert.rule_description)
+        # ...
+
+
+**Custom fields**
+
+Outputs can be implemented to offer custom fields that can be filled in by Publishers. This (optionally)
+grants fine-grained control of outputs to Publishers. Such fields should adhere to the following conventions:
+
+* They are top level keys on the final publication dictionary
+* Keys are strings, following the format: ``@{output_service}.{field_name}``
+* Keys MUST begin with an at-sign
+* The ``output_service`` should match the current output's ``cls.__service__`` value
+* The ``field_name`` should describe its function
+* Example: ``@slack.attachments``
+
+Below is an example of how you could implement an output:
+
+.. code-block:: python
+
+    def _dispatch(self, alert, descriptor):
+        # ...
+        publication = compose_alert(alert, self, descriptor)
+
+        default_title = 'Incident Title: #{}'.format(alert.alert_id)
+        default_html = 'Rule: {}'.format(alert.rule_description)
+
+        title = publication.get('@pagerduty.title', default_title)
+        body_html = publication.get('@pagerduty.body_html', default_html)
+
+        make_api_call(title, body_html, data=publication)
+
+
+**Alert Fields**
+
+When outputs require mandatory fields that are not subject to publishers, they should reference the ``alert``
+fields directly:
+
+.. code-block:: python
+
+    def _dispatch(self, alert, descriptor):
+        rule_description = alert.rule_description
+        # ...
+
+
+Registering Publishers
+----------------------
+
+Register publishers on a rule using the ``publishers`` argument on the ``@rule`` decorator:
+
+.. code-block:: python
+
+    from publishers import publisher_1, publisher_2
+    from stream_alert.shared.rule import rule
+
+    @rule(
+        logs=['stuff'],
+        outputs=['pagerduty', 'slack'],
+        publishers=[publisher_1, publisher_2]
+    )
+    def my_rule(rec):
+        # ...
+
+The ``publishers`` argument is a structure containing references to **Publishers** and can follow any of the
+following structures:
+
+**Single Publisher**
+
+.. code-block:: python
+
+    publishers=publisher_1
+
+When using this syntax, the given publisher will be applied to all outputs.
+
+
+**List of Publishers**
+
+.. code-block:: python
+
+    publishers=[publisher_1, publisher_2, publisher_3]
+
+When using this syntax, all given publishers will be applied to all outputs.
+
+
+**Dict mapping Output strings to Publisher**
+
+.. code-block:: python
+
+    publishers={
+        'pagerduty:analyst': [publisher_1, publisher_2],
+        'pagerduty': [publisher_3, publisher_4],
+        'demisto': other_publisher,
+    }
+
+When using this syntax, publishers under each key will be applied to their matching outputs. Publisher keys
+with generic outputs (e.g. ``pagerduty``) are loaded first, before publisher keys that pertain to more
+specific outputs (e.g. ``pagerduty:analyst``).
+
+The order in which publishers are loaded dictates the order in which they are executed.
+
+
+DefaultPublisher
+----------------
+
+When the ``publishers`` argument is omitted from a ``@rule``, a ``DefaultPublisher`` is loaded and used. This
+also occurs when the ``publishers`` argument is misconfigured.
+
+The ``DefaultPublisher`` is backward-compatible with old implementations of ``alert.output_dict()``.
+
+
+Putting It All Together...
+--------------------------
+
+Here's a real-world example of how to effectively use Publishers and Outputs:
+
+PagerDuty requires all Incidents to be created with an `Incident Summary`, which appears as the title of every
+incident in its UI. Additionally, you can optionally supply `custom details`, which appear below as a large,
+unstructured body.
+
+By default, the PagerDuty integration sends ``"StreamAlert Rule Triggered - rule_name"`` as the `Incident Summary`,
+along with the entire Alert record in the `custom details`.
+
+However, the entire record can contain mostly irrelevant or redundant data, which can pollute the PagerDuty UI
+and make triage slower, as responders must filter through a large record to find the relevant pieces of
+information. This is especially true for alerts of very limited scope with well-understood remediation steps.
+
+Consider an example where informational alerts are triggered upon login to a machine. Responders only care
+about the **time** of login, the **source IP address**, and the **username** of the login.
+
+You can implement a publisher that returns only those three fields and strips out the rest of the alert.
+The publisher can also simplify the PagerDuty title:
+
+.. code-block:: python
+
+    from stream_alert.shared.publisher import Register
+
+    @Register
+    def simplify_pagerduty_output(alert, publication):
+        return {
+            '@pagerduty.record': {
+                'source_ip': alert.record['source_ip'],
+                'time': alert.record['timestamp'],
+                'username': alert.record['user'],
+            },
+            '@pagerduty.summary': 'Machine SSH: {}'.format(alert.record['user']),
+        }
+
+Suppose this rule is being sent to both PagerDuty and Slack, but you only wish to simplify the PagerDuty
+integration, leaving the Slack integration the same. Registering the publisher can be done as such:
+
+.. code-block:: python
+
+    from publishers.pagerduty import simplify_pagerduty_output
+    from stream_alert.shared.rule import rule
+
+    @rule(
+        logs=['ssh'],
+        outputs=['slack:engineering', 'pagerduty:engineering'],
+        publishers={
+            'pagerduty:engineering': simplify_pagerduty_output,
+        }
+    )
+    def machine_ssh_login(rec):
+        # ...
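+
+With this configuration, the publication handed to the PagerDuty output would look roughly like the
+following (a sketch; the actual values depend on your log schema):
+
+.. code-block:: python
+
+    {
+        '@pagerduty.record': {
+            'source_ip': '10.0.0.1',
+            'time': '2019-04-16T18:24:31Z',
+            'username': 'alice',
+        },
+        '@pagerduty.summary': 'Machine SSH: alice',
+    }
+
+The Slack output, for which no publishers are registered, keeps its default presentation.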
diff --git a/publishers/__init__.py b/publishers/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/publishers/community/__init__.py b/publishers/community/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/publishers/community/generic.py b/publishers/community/generic.py new file mode 100644 index 000000000..b32d72eac --- /dev/null +++ b/publishers/community/generic.py @@ -0,0 +1,272 @@ +""" +Copyright 2017-present, Airbnb Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +from collections import deque, OrderedDict +import re +from stream_alert.shared.publisher import Register, AlertPublisher +from stream_alert.shared.normalize import Normalizer +from stream_alert.shared.utils import get_keys + + +@Register +def add_record(alert, publication): + """Publisher that adds the alert.record to the publication.""" + publication['record'] = alert.record + + return publication + + +@Register +def blank(*_): + """Erases all fields on existing publications and returns a blank dict""" + return {} + + +@Register +def remove_internal_fields(_, publication): + """This publisher removes fields from DefaultPublisher that are only useful internally""" + + publication.pop('staged', None) + publication.pop('publishers', None) + publication.pop('outputs', None) + + return publication + + +def _delete_dictionary_fields(publication, regexp): + """Deeply destroys all nested dict keys matching the given regexp string + + Args: + publication (dict): A publication + regexp (str): A String that is valid regexp + + Returns: + dict + (!) warning, will modify the original publication + """ + # Python is bad at recursion so I managed to tip toe around that with BFS using a queue. + # This heavily takes advantage of internal references being maintained properly as the loop + # does not actually track the "current scope" of the next_item. + fringe = deque() + fringe.append(publication) + while len(fringe) > 0: + next_item = fringe.popleft() + + if isinstance(next_item, dict): + for key in next_item.keys(): + if re.search(regexp, key): + next_item.pop(key, None) + + for key, item in next_item.iteritems(): + fringe.append(item) + elif isinstance(next_item, list): + fringe.extend(next_item) + else: + # It's a leaf node, or it's some strange object that doesn't belong here + pass + + return publication + + +@Register +def remove_fields(alert, publication): + """This publisher deletes fields from the current publication. + + The publisher uses the alert's context to determine which fields to delete. Example: + + context={ + 'remove_fields': ['^field1$', '^field2$', ...] + } + + "remove_fields" should be an array of strings that are valid regular expressions. + + The algorithm deeply searches the publication for any dict key that matches the given regular + expression. Any such key is removed, and if the value is a nested dict, the entire dict + branch underneath is removed. 
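+
+    Example rule registration (a sketch; the rule, log, output, and field names are hypothetical):
+
+        from publishers.community.generic import remove_fields
+        from stream_alert.shared.rule import rule
+
+        @rule(
+            logs=['osquery:differential'],
+            outputs=['slack:security'],
+            publishers=[remove_fields],
+            context={
+                'remove_fields': ['^decorations$', '^columns$']
+            }
+        )
+        def example_rule(rec):
+            # ...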
+ """ + fields = alert.context.get('remove_fields', []) + + for field in fields: + publication = _delete_dictionary_fields(publication, field) + + return publication + + +@Register +def remove_streamalert_normalization(_, publication): + """This publisher removes the super heavyweight 'streamalert:normalization' fields""" + return _delete_dictionary_fields(publication, Normalizer.NORMALIZATION_KEY) + + +@Register +def enumerate_fields(_, publication): + """Flattens all currently published fields. + + By default, publications are deeply nested dict structures. This can be very hard to read + when rendered in certain outputs. PagerDuty is one example where the default UI does a very + poor job rendering nested dicts. + + This publisher collapses deeply nested fields into a single-leveled dict with keys that + correspond to the original path of each value in a deeply nested dict. For example: + + { + "top1": { + "mid1": "low", + "mid2": [ "low1", "low2", "low3" ], + "mid3": { + "low1": "verylow" + } + }, + "top2": "mid" + } + + .. would collapse into the following structure: + + { + "top1.mid1": "low", + "top1.mid2[0]": "low1", + "top1.mid2[1]": "low1", + "top1.mid2[2]": "low1", + "top1.mid3.low1: "verylow", + "top2": "mid" + } + + The output dict is an OrderedDict with keys sorted in alphabetical order. + """ + def _recursive_enumerate_fields(structure, output_reference, path=''): + if isinstance(structure, list): + for index, item in enumerate(structure): + _recursive_enumerate_fields(item, output_reference, '{}[{}]'.format(path, index)) + + elif isinstance(structure, dict): + for key in structure: + _recursive_enumerate_fields(structure[key], output_reference, '{prefix}{key}'.format( + prefix='{}.'.format(path) if path else '', # Omit first period + key=key + )) + + else: + output_reference[path] = structure + + output = {} + _recursive_enumerate_fields(publication, output) + + return OrderedDict(sorted(output.items())) + + +@Register +def populate_fields(alert, publication): + """This publisher moves all requested fields to the top level and ignores everything else. + + It uses the context to determine which fields to keep. Example: + + context={ + 'populate_fields': [ 'field1', 'field2', 'field3' ] + } + + "populate_fields" should be an array of strings that are exact matches to the field names. + + The algorithm deeply searches the publication for any dict key that exactly matches one of the + given fields. It then takes the contents of that field and moves them up to the top level. + It discovers ALL values matching each field, so if a field is returned multiple times, the + resulting top level field will be an array. In the special case where exactly one entry is + returned for a populate_field, the value will instead be equal to that value (instead of an + array with 1 element being that value). In the special case when no entries are returned for + an extract_field, the value will be None. + + Aside from the moved fields, this publisher throws away everything else in the original + publication. + + NOTE: It is possible for moved fields to continue to contain nested dicts, so do not assume + this publisher will result in a flat dictionary publication. 
+ """ + + new_publication = {} + for populate_field in alert.context.get('populate_fields', []): + extractions = get_keys(publication, populate_field) + new_publication[populate_field] = extractions + + return new_publication + + +@Register +class StringifyArrays(AlertPublisher): + """Deeply navigates a dict publication and coverts all scalar arrays to strings + + Any array discovered with only scalar values will be joined into a single string with the + given DELIMITER. Subclass implementations of this can override the delimiter to join the + string differently. + """ + DELIMITER = '\n' + + def publish(self, alert, publication): + fringe = deque() + fringe.append(publication) + while len(fringe) > 0: + next_item = fringe.popleft() + + if isinstance(next_item, dict): + # Check all keys + for key, item in next_item.iteritems(): + if self.is_scalar_array(item): + next_item[key] = self.stringify(item) + else: + fringe.append(item) + + elif isinstance(next_item, list): + # At this point, if the item is a list we assert that it is not a SCALAR array; + # because it is too late to stringify it, since we do not have a back reference + # to the object that contains it + fringe.extend(next_item) + else: + # It's a leaf node, or it's some strange object that doesn't belong here + pass + + return publication + + @staticmethod + def is_scalar_array(item): + """Returns if the given item is a python list containing only scalar elements + + NOTE: This method assumes that the 'item' provided comes from a valid JSON compliant dict. + It does not account for strange or complicated types, such as references to functions + or class definitions or other stuff. + + Args: + item (mixed): The python variable to check + + Returns: + bool + """ + if not isinstance(item, list): + return False + + for element in item: + if isinstance(element, dict) or isinstance(element, list): + return False + + return True + + @classmethod + def stringify(cls, array): + """Given a list of elements, will join them together with the publisher's DELIMITER + + Args: + array (list): The array of elements. + + Returns: + str + """ + return cls.DELIMITER.join([str(elem) for elem in array]) diff --git a/publishers/community/pagerduty/__init__.py b/publishers/community/pagerduty/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/publishers/community/pagerduty/pagerduty_layout.py b/publishers/community/pagerduty/pagerduty_layout.py new file mode 100644 index 000000000..245f7ae27 --- /dev/null +++ b/publishers/community/pagerduty/pagerduty_layout.py @@ -0,0 +1,150 @@ +""" +Copyright 2017-present, Airbnb Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +from publishers.community.generic import StringifyArrays +from stream_alert.shared.publisher import AlertPublisher, Register + + +@Register +class ShortenTitle(AlertPublisher): + """A publisher that shortens the title of PagerDuty incidents. + + By popular demand from our TIR team! 
By default, PagerDuty incidents have a title that look + something like 'StreamAlert Rule Triggered - blah_blah_blah'. If StreamAlert is the only + system producing alerts into PagerDuty, this is a lot of extraneous data. + + Instead, this publisher strips out the 'StreamAlert Rule Triggered' prefix and opts to only + output the rule name. + """ + + def publish(self, alert, publication): + + publication['@pagerduty-v2.summary'] = alert.rule_name + publication['@pagerduty-incident.incident_title'] = alert.rule_name + publication['@pagerduty.description'] = alert.rule_name + + return publication + + +@Register +def as_custom_details(_, publication): + """Takes the current publication and sends the entire thing to custom details. + + It does this for all fields EXCEPT the pagerduty special fields. + """ + def _is_custom_field(key): + return key.startswith('@pagerduty') + + custom_details = { + key: value for key, value in publication.iteritems() if not _is_custom_field(key) + } + + publication['@pagerduty.details'] = custom_details + publication['@pagerduty-v2.custom_details'] = custom_details + + return publication + + +@Register +def v2_high_urgency(_, publication): + """Designates this alert as critical or high urgency + + This only works for pagerduty-v2 and pagerduty-incident Outputs. The original pagerduty + integration uses the Events v1 API which does not support urgency. + """ + publication['@pagerduty-v2.severity'] = 'critical' + publication['@pagerduty-incident.urgency'] = 'high' + return publication + + +@Register +def v2_low_urgency(_, publication): + """Designates this alert as a warning or low urgency + + This only works for pagerduty-v2 and pagerduty-incident Outputs. The original pagerduty + integration uses the Events v1 API which does not support urgency. + """ + publication['@pagerduty-v2.severity'] = 'warning' + publication['@pagerduty-incident.urgency'] = 'low' + return publication + + +@Register +class PrettyPrintArrays(StringifyArrays): + """Deeply navigates a dict publication and coverts all scalar arrays to strings + + Scalar arrays render poorly on PagerDuty's default UI. Newlines are ignored, and the scalar + values are wrapped with quotations: + + [ + "element_here\n with newlines\noh no", + "hello world\nhello world" + ] + + This method searches the publication dict for scalar arrays and transforms them into strings + by joining their values with the provided delimiter. This converts the above array into: + + element here + with newlines + oh no + + ---------- + + hello world + hello world + """ + DELIMITER = '\n\n----------\n\n' + + +@Register +class AttachImage(StringifyArrays): + """Attaches the given image to the PagerDuty request + + Works for both the v1 and v2 event api integrations. + + It is recommended to subclass this class with your own implementation of _image_url(), + _click_url() and _alt_text() so that you can customize your own image. 
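+
+    For example, a subclass could simply override the class constants with its own URLs (the values
+    below are placeholders):
+
+        @Register
+        class AttachCompanyImage(AttachImage):
+            IMAGE_URL = 'https://example.com/security-banner.png'
+            IMAGE_CLICK_URL = 'https://example.com/security/runbooks'
+            IMAGE_ALT_TEXT = 'Security Runbooks'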
+ """ + IMAGE_URL = 'https://streamalert.io/en/stable/_images/sa-banner.png' + IMAGE_CLICK_URL = 'https://streamalert.io/en/stable/' + IMAGE_ALT_TEXT = 'StreamAlert Docs' + + def publish(self, alert, publication): + publication['@pagerduty-v2.images'] = publication.get('@pagerduty-v2.images', []) + publication['@pagerduty-v2.images'].append({ + 'src': self._image_url(), + 'href': self._click_url(), + 'alt': self._alt_text(), + }) + + publication['@pagerduty.contexts'] = publication.get('@pagerduty.contexts', []) + publication['@pagerduty.contexts'].append({ + 'type': 'image', + 'src': self._image_url(), + }) + + return publication + + @classmethod + def _image_url(cls): + return cls.IMAGE_URL + + @classmethod + def _click_url(cls): + return cls.IMAGE_CLICK_URL + + @classmethod + def _alt_text(cls): + return cls.IMAGE_ALT_TEXT diff --git a/publishers/community/slack/__init__.py b/publishers/community/slack/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/publishers/community/slack/slack_layout.py b/publishers/community/slack/slack_layout.py new file mode 100644 index 000000000..7545523a9 --- /dev/null +++ b/publishers/community/slack/slack_layout.py @@ -0,0 +1,325 @@ +""" +Copyright 2017-present, Airbnb Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +import calendar +import cgi +import json +import urllib + +from stream_alert.shared.publisher import AlertPublisher, Register +from stream_alert.shared.description import RuleDescriptionParser + +RAUSCH = '#ff5a5f' +BABU = '#00d1c1' +LIMA = '#8ce071' +HACKBERRY = '#7b0051' +BEACH = '#ffb400' + + +@Register +class Summary(AlertPublisher): + """Adds a brief summary with the rule triggered, author, description, and time + + To customize the behavior of this Publisher, it is recommended to subclass this and override + parameters as necessary. For example, an implementation could override _GITHUB_REPO_URL with + the URL appropriate for the organization using StreamAlert. + """ + + _GITHUB_REPO_URL = 'https://github.com/airbnb/streamalert' + _SEARCH_PATH = '/search' + _RULES_PATH = '/rules' + + def publish(self, alert, publication): + rule_name = alert.rule_name + rule_description = alert.rule_description + rule_presentation = RuleDescriptionParser.present(rule_description) + + author = rule_presentation['author'] + + return { + '@slack.text': 'Rule triggered', + '@slack.attachments': [ + { + 'fallback': 'Rule triggered: {}'.format(rule_name), + 'color': self._color(), + 'author_name': author, + 'author_link': self._author_url(author), + 'author_icon': self._author_icon(author), + 'title': rule_name, + 'title_link': self._title_url(rule_name), + 'text': cgi.escape(rule_presentation['description']), + 'image_url': '', + 'thumb_url': '', + 'footer': '', + 'footer_icon': '', + 'ts': calendar.timegm(alert.created.timetuple()) if alert.created else '', + 'mrkdwn_in': [], + }, + ], + + # This information is passed-through to future publishers. 
+ '@slack._previous_publication': publication, + } + + @staticmethod + def _color(): + """The color of this section""" + return RAUSCH + + @classmethod + def _author_url(cls, _): + """When given an author name, returns a clickable link, if any""" + return '' + + @classmethod + def _author_icon(cls, _): + """When given an author name, returns a URL to an icon, if any""" + return '' + + @classmethod + def _title_url(cls, rule_name): + """When given the rule_name, returns a clickable link, if any""" + + # It's actually super hard to generate a exact link to a file just from the rule_name, + # because the rule/ directory files are not deployed with the publishers in the alert + # processor. + # Instead, we send them to Github with a formatted query string that is LIKELY to + # find the correct file. + # + # If you do not want URLs to show up, simply override this method and return empty string. + return '{}{}?{}'.format( + cls._GITHUB_REPO_URL, + cls._SEARCH_PATH, + urllib.urlencode({ + 'q': '{} path:{}'.format(rule_name, cls._RULES_PATH) + }) + ) + + +@Register +class AttachRuleInfo(AlertPublisher): + """This publisher adds a slack attachment with fields from the rule's description + + It can include such fields as "reference" or "playbook" but will NOT include the description + or the author. + """ + + def publish(self, alert, publication): + publication['@slack.attachments'] = publication.get('@slack.attachments', []) + + rule_description = alert.rule_description + rule_presentation = RuleDescriptionParser.present(rule_description) + + publication['@slack.attachments'].append({ + 'color': self._color(), + 'fields': [ + {'title': key.capitalize(), 'value': rule_presentation['fields'][key]} + for key in rule_presentation['fields'].keys() + ], + }) + + return publication + + @staticmethod + def _color(): + return LIMA + + +@Register +class AttachPublication(AlertPublisher): + """A publisher that attaches previous publications as an attachment + + By default, this publisher needs to be run after the Summary publisher, as it depends on + the magic-magic _previous_publication field. + """ + + def publish(self, alert, publication): + if '@slack._previous_publication' not in publication or '@slack.attachments' not in publication: + # This publisher cannot be run except immediately after the Summary publisher + return publication + + publication_block = '```\n{}\n```'.format( + json.dumps( + self._get_publication(alert, publication), + indent=2, + sort_keys=True, + separators=(',', ': ') + ) + ) + + publication['@slack.attachments'].append({ + 'color': self._color(), + 'title': 'Alert Data:', + 'text': cgi.escape(publication_block), + 'mrkdwn_in': ['text'], + }) + + return publication + + @staticmethod + def _color(): + return BABU + + @staticmethod + def _get_publication(_, publication): + return publication['@slack._previous_publication'] + + +@Register +class AttachStringTemplate(AlertPublisher): + """An extremely flexible publisher that simply renders an attachment as text + + By default, this publisher accepts a template from the alert.context['slack_message_template'] + which is .format()'d with the current publication. This allows individual rules to render + whatever they want. The template is a normal slack message, so it can support newline + characters, and any of slack's pseudo-markdown. + + Subclass implementations of this can decide to override any of the implementation or come + up with their own! 
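+
+    For example, a rule could supply the template through its context (a sketch; the rule and
+    template are illustrative, and the format placeholders must be top-level keys of the
+    publication being rendered):
+
+        @rule(
+            logs=['ssh'],
+            outputs=['slack:security'],
+            publishers=[AttachStringTemplate],
+            context={'slack_message_template': 'SSH alert triggered: {rule_name}'}
+        )
+        def ssh_login(rec):
+            # ...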
+ + If this publisher is run after the Summary publisher, it will correctly pull the original + publication from the @slack._previous_publication, otherwise it uses the default publication. + """ + + def publish(self, alert, publication): + rendered_text = self._render_text(alert, publication) + + publication['@slack.attachments'] = publication.get('@slack.attachments', []) + publication['@slack.attachments'].append({ + 'color': self._color(), + 'text': cgi.escape(rendered_text), + }) + + return publication + + @classmethod + def _render_text(cls, alert, publication): + template = cls._get_format_template(alert, publication) + args = cls._get_template_args(alert, publication) + + return template.format(**args) + + @staticmethod + def _get_format_template(alert, _): + return alert.context.get('slack_message_template', '[MISSING TEMPLATE]') + + @staticmethod + def _get_template_args(_, publication): + return ( + publication['@slack._previous_publication'] + if '@slack._previous_publication' in publication + else publication + ) + + @staticmethod + def _color(): + return BEACH + + +@Register +class AttachFullRecord(AlertPublisher): + """This publisher attaches slack attachments generated from the Alert's full record + + The full record is likely to be significantly longer than the slack max messages size. + So we cut up the record by rows and send it as a series of 1 or more attachments. + The attachments are rendered in slack in a way such that a mouse drag and copy will + copy the entire JSON in-tact. + + The first attachment is slightly different as it includes the source entity where the + record originated from. The last attachment includes a footer. + """ + _SLACK_MAXIMUM_ATTACHMENT_CHARACTER_LENGTH = 4000 + + # Reserve space at the beginning and end of the attachment text for backticks and newlines + _LENGTH_PADDING = 10 + + def publish(self, alert, publication): + publication['@slack.attachments'] = publication.get('@slack.attachments', []) + + # Generate the record and then dice it up into parts + record_document = json.dumps(alert.record, indent=2, sort_keys=True, separators=(',', ': ')) + + # Escape the document FIRST because it can increase character length which can throw off + # document slicing + record_document = cgi.escape(record_document) + record_document_lines = record_document.split('\n') + + def make_attachment(document, is_first, is_last): + + footer = '' + if is_last: + footer_url = self._source_service_url(alert.source_service) + if footer_url: + footer = 'via <{}|{}>'.format(footer_url, alert.source_service) + else: + 'via {}'.format(alert.source_service) + + return { + 'color': self._color(), + 'author': alert.source_entity if is_first else '', + 'title': 'Record' if is_first else '', + 'text': '```\n{}\n```'.format(document), + 'fields': [ + { + "title": "Alert Id", + "value": alert.alert_id, + } + ] if is_last else [], + 'footer': footer, + 'footer_icon': self._footer_icon_from_service(alert.source_service), + 'mrkdwn_in': ['text'], + } + + character_limit = self._SLACK_MAXIMUM_ATTACHMENT_CHARACTER_LENGTH - self._LENGTH_PADDING + is_first_document = True + next_document = '' + while len(record_document_lines) > 0: + # Loop, removing one line at a time and attempting to attach it to the next document + # When the next document nears the maximum attachment size, it is flushed, generating + # a new attachment, and the document is reset before the loop pops off the next line. 
+ + next_item_length = len(record_document_lines[0]) + next_length = next_item_length + len(next_document) + if next_document and next_length > character_limit: + # Do not pop off the item just yet. + publication['@slack.attachments'].append( + make_attachment(next_document, is_first_document, False) + ) + next_document = '' + is_first_document = False + + next_document += '\n' + record_document_lines.pop(0) + + # Attach last document, if any remains + if next_document: + publication['@slack.attachments'].append( + make_attachment(next_document, is_first_document, True) + ) + + return publication + + @staticmethod + def _color(): + return HACKBERRY + + @staticmethod + def _source_service_url(source_service): + """A best-effort guess at the AWS dashboard link for the requested service.""" + return 'https://console.aws.amazon.com/{}/home'.format(source_service) + + @staticmethod + def _footer_icon_from_service(_): + """Returns the URL of an icon, given an AWS service""" + return '' diff --git a/stream_alert/__init__.py b/stream_alert/__init__.py index 9b633e0e3..03464351f 100644 --- a/stream_alert/__init__.py +++ b/stream_alert/__init__.py @@ -1,2 +1,2 @@ """StreamAlert version.""" -__version__ = '2.1.6' +__version__ = '2.2.0' diff --git a/stream_alert/alert_processor/helpers.py b/stream_alert/alert_processor/helpers.py index 9c378012e..ea6ab46c0 100644 --- a/stream_alert/alert_processor/helpers.py +++ b/stream_alert/alert_processor/helpers.py @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. """ +from stream_alert.shared.publisher import AlertPublisherRepository, PublisherAssemblyError def elide_string_middle(text, max_length): @@ -30,3 +31,90 @@ def elide_string_middle(text, max_length): half_len = (max_length - 5) / 2 # Length of text on either side. return '{} ... {}'.format(text[:half_len], text[-half_len:]) + + +def compose_alert(alert, output, descriptor): + """Presents the alert as a dict for output classes to send to their API integrations. + + Args: + alert (Alert): The alert to be dispatched + output (OutputDispatcher|None): Instance of the output class dispatching this alert + descriptor (str): The descriptor for the output + + Returns: + dict + """ + + # FIXME (derek.wang) + # this is here because currently the OutputDispatcher sits in a __init__.py module that + # performs eager loading of the other output classes. If you load this helper before + # the output classes are loaded, it creates a cyclical dependency + # helper.py -> output_base.py -> __init__ -> komand.py -> helper.py + # This is a temporary workaround + # A more permanent solution will involve refactoring the OutputDispatcher to load on-demand + # instead of eagerly. + from stream_alert.alert_processor.outputs.output_base import OutputDispatcher + output_service_name = output.__service__ if isinstance(output, OutputDispatcher) else None + + if not output_service_name: + raise PublisherAssemblyError('Invalid output service') + + publisher = _assemble_alert_publisher_for_output( + alert, + output_service_name, + descriptor + ) + return publisher.publish(alert, {}) + + +def _assemble_alert_publisher_for_output(alert, output_service_name, descriptor): + """Gathers all requested publishers on the alert and returns them as a single Publisher + + Note: The alert.publishers field cannot contain references to actual publisher functions + or classes, because this field is pulled from Dynamo. It is always going to be a JSON-formatted + object. 
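+
+    As a rough illustration (the publisher names below are placeholders, not real registered
+    publishers), alert.publishers may arrive in any of these shapes:
+
+        'some.publisher.name'
+        ['some.publisher.name', 'another.publisher.name']
+        {'pagerduty': 'some.publisher.name',
+         'pagerduty:analyst': ['another.publisher.name']}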
+ + Args: + alert (Alert): The alert that is pulled from DynamoDB + output_service_name (str): The __service__ name of the OutputDispatcher + descriptor (str): The descriptor of the Output + + Returns: + AlertPublisher + """ + + alert_publishers = alert.publishers + publisher_names = [] + if isinstance(alert_publishers, basestring): + # Case 1: The publisher is a single string. + # apply this single publisher to all outputs + descriptors + publisher_names.append(alert_publishers) + elif isinstance(alert_publishers, list): + # Case 2: The publisher is an array of strings. + # apply all publishers to all outputs + descriptors + publisher_names += alert_publishers + elif isinstance(alert_publishers, dict): + # Case 3: The publisher is a dict mapping output strings -> strings or list of strings + # apply only publishers under the correct output key. We look under 2 keys: + # one key that applies publishers to all outputs for a specific output type, and + # another key that applies publishers only to outputs of the type AND matching + # descriptor. + + # Order is important here; we load the output-generic publishers first + if output_service_name and output_service_name in alert_publishers: + publisher_name_or_names = alert_publishers[output_service_name] + if isinstance(publisher_name_or_names, list): + publisher_names = publisher_names + publisher_name_or_names + else: + publisher_names.append(publisher_name_or_names) + + # Then load output+descriptor-specific publishers second + described_output_name = '{}:{}'.format(output_service_name, descriptor) + if described_output_name in alert_publishers: + publisher_name_or_names = alert_publishers[described_output_name] + if isinstance(publisher_name_or_names, list): + publisher_names += publisher_name_or_names + else: + publisher_names.append(publisher_name_or_names) + + return AlertPublisherRepository.create_composite_publisher(publisher_names) diff --git a/stream_alert/alert_processor/outputs/aws.py b/stream_alert/alert_processor/outputs/aws.py index 1f4843e1e..7409d6375 100644 --- a/stream_alert/alert_processor/outputs/aws.py +++ b/stream_alert/alert_processor/outputs/aws.py @@ -23,7 +23,7 @@ from botocore.exceptions import ClientError import boto3 -from stream_alert.alert_processor.helpers import elide_string_middle +from stream_alert.alert_processor.helpers import compose_alert, elide_string_middle from stream_alert.alert_processor.outputs.output_base import ( OutputDispatcher, OutputProperty, @@ -100,6 +100,10 @@ def get_user_defined_properties(cls): def _dispatch(self, alert, descriptor): """Send alert to a Kinesis Firehose Delivery Stream + Publishing: + By default this output sends the current publication in JSON to Kinesis. + There is no "magic" field to "override" it: Simply publish what you want to send! 
+ Args: alert (Alert): Alert instance which triggered a rule descriptor (str): Output descriptor @@ -133,7 +137,7 @@ def _firehose_request_wrapper(json_alert, delivery_stream): if self.__aws_client__ is None: self.__aws_client__ = boto3.client('firehose', region_name=self.region) - publication = alert.publish_for(self, descriptor) + publication = compose_alert(alert, self, descriptor) json_alert = json.dumps(publication, separators=(',', ':')) + '\n' if len(json_alert) > self.MAX_RECORD_SIZE: @@ -187,6 +191,14 @@ def _dispatch(self, alert, descriptor): The alert gets dumped to a JSON string to be sent to the Lambda function + Publishing: + By default this output sends the JSON-serialized alert record as the payload to the + lambda function. You can override this: + + - @aws-lambda.alert_data (dict): + Overrides the alert record. Will instead send this dict, JSON-serialized, to + Lambda as the payload. + Args: alert (Alert): Alert instance which triggered a rule descriptor (str): Output descriptor @@ -194,10 +206,15 @@ def _dispatch(self, alert, descriptor): Returns: bool: True if alert was sent successfully, False otherwise """ - publication = alert.publish_for(self, descriptor) - record = publication.get('record', {}) + publication = compose_alert(alert, self, descriptor) + + # Defaults + default_alert_data = alert.record - alert_string = json.dumps(record, separators=(',', ':')) + # Override with publisher + alert_data = publication.get('@aws-lambda.alert_data', default_alert_data) + + alert_string = json.dumps(alert_data, separators=(',', ':')) function_name = self.config[self.__service__][descriptor] # Check to see if there is an optional qualifier included here @@ -272,6 +289,10 @@ def _dispatch(self, alert, descriptor): service/entity/rule_name/datetime.json The alert gets dumped to a JSON string + Publishing: + By default this output sends the current publication in JSON to S3. + There is no "magic" field to "override" it: Simply publish what you want to send! + Args: alert (Alert): Alert instance which triggered a rule descriptor (str): Output descriptor @@ -295,7 +316,7 @@ def _dispatch(self, alert, descriptor): LOGGER.debug('Sending %s to S3 bucket %s with key %s', alert, bucket, key) - publication = alert.publish_for(self, descriptor) + publication = compose_alert(alert, self, descriptor) client = boto3.client('s3', region_name=self.region) client.put_object(Body=json.dumps(publication), Bucket=bucket, Key=key) @@ -312,6 +333,16 @@ class SNSOutput(AWSOutput): def get_user_defined_properties(cls): """Properties assigned by the user when configuring a new SNS output. + Publishing: + By default this output sets a default subject and sends a message body that is the + JSON-serialized publication including indents/newlines. You can override this behavior: + + - @aws-sns.topic (str): + Sends a custom subject + + - @aws-sns.message (str); + Send a custom message body. 
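+
+            For example, a publisher could override both fields (a sketch; the publisher name is
+            illustrative):
+
+                from stream_alert.shared.publisher import Register
+
+                @Register
+                def sns_presentation(alert, publication):
+                    publication['@aws-sns.topic'] = 'Alert: {}'.format(alert.rule_name)
+                    publication['@aws-sns.message'] = 'Rule description: {}'.format(alert.rule_description)
+                    return publication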
+ Returns: OrderedDict: With 'descriptor' and 'aws_value' OutputProperty tuples """ @@ -336,7 +367,7 @@ def _dispatch(self, alert, descriptor): topic_arn = 'arn:aws:sns:{}:{}:{}'.format(self.region, self.account_id, topic_name) topic = boto3.resource('sns', region_name=self.region).Topic(topic_arn) - publication = alert.publish_for(self, descriptor) + publication = compose_alert(alert, self, descriptor) # Presentation defaults default_subject = '{} triggered alert {}'.format(alert.rule_name, alert.alert_id) @@ -344,8 +375,8 @@ def _dispatch(self, alert, descriptor): # Published presentation fields # Subject must be < 100 characters long; - subject = elide_string_middle(publication.get('aws-sns.topic', default_subject), 99) - message = publication.get('aws-sns.message', default_message) + subject = elide_string_middle(publication.get('@aws-sns.topic', default_subject), 99) + message = publication.get('@aws-sns.message', default_message) topic.publish( Message=message, @@ -380,6 +411,14 @@ def get_user_defined_properties(cls): def _dispatch(self, alert, descriptor): """Send alert to an SQS queue + Publishing: + By default it sends the alert.record to SQS as a JSON string. You can override + it with the following fields: + + - @aws-sqs.message_data (dict): + Replace alert.record with your own JSON-serializable dict. Will send this + as a JSON string to SQS. + Args: alert (Alert): Alert instance which triggered a rule descriptor (str): Output descriptor @@ -391,16 +430,17 @@ def _dispatch(self, alert, descriptor): sqs = boto3.resource('sqs', region_name=self.region) queue = sqs.get_queue_by_name(QueueName=queue_name) - publication = alert.publish_for(self, descriptor) + publication = compose_alert(alert, self, descriptor) # Presentation defaults - record = publication.get('record', {}) - default_message_body = json.dumps(record, separators=(',', ':')) + default_message_data = alert.record # Presentation values - message_body = publication.get('aws-sqs:message_body', default_message_body) + message_data = publication.get('@aws-sqs.message_data', default_message_data) - queue.send_message(MessageBody=message_body) + # Transform the body from a dict to a string for SQS + sqs_message = json.dumps(message_data, separators=(',', ':')) + queue.send_message(MessageBody=sqs_message) return True @@ -413,6 +453,11 @@ class CloudwatchLogOutput(AWSOutput): @classmethod def get_user_defined_properties(cls): """Get properties that must be assigned by the user when configuring a new Lambda + + Publishing: + By default this output sends the current publication in JSON to CloudWatch. + There is no "magic" field to "override" it: Simply publish what you want to send! 
+ Returns: OrderedDict: Contains various OutputProperty items """ @@ -428,7 +473,7 @@ def _dispatch(self, alert, descriptor): alert (Alert): Alert instance which triggered a rule descriptor (str): Output descriptor """ - publication = alert.publish_for(self, descriptor) + publication = compose_alert(alert, self, descriptor) LOGGER.info('New Alert:\n%s', json.dumps(publication, indent=2)) return True diff --git a/stream_alert/alert_processor/outputs/carbonblack.py b/stream_alert/alert_processor/outputs/carbonblack.py index 0b2d1610c..2d2e59913 100644 --- a/stream_alert/alert_processor/outputs/carbonblack.py +++ b/stream_alert/alert_processor/outputs/carbonblack.py @@ -60,6 +60,9 @@ def get_user_defined_properties(cls): def _dispatch(self, alert, descriptor): """Send ban hash command to CarbonBlack + Publishing: + There is currently no method to control carbonblack's behavior with publishers. + Args: alert (Alert): Alert instance which triggered a rule descriptor (str): Output descriptor diff --git a/stream_alert/alert_processor/outputs/demisto.py b/stream_alert/alert_processor/outputs/demisto.py index e4c5dd6ee..15094b436 100644 --- a/stream_alert/alert_processor/outputs/demisto.py +++ b/stream_alert/alert_processor/outputs/demisto.py @@ -15,6 +15,7 @@ """ from collections import OrderedDict +from stream_alert.alert_processor.helpers import compose_alert from stream_alert.alert_processor.outputs.output_base import ( OutputDispatcher, OutputProperty, @@ -59,6 +60,39 @@ def get_user_defined_properties(cls): def _dispatch(self, alert, descriptor): """Send a new Incident to Demisto + Publishing: + Demisto offers a suite of default incident values. You can override any of the + following: + + - @demisto.incident_type (str): + + - @demisto.severity (str): + Controls the severity of the incident. Any of the following: + 'info', 'informational', 'low', 'med', 'medium', 'high', 'critical', 'unknown' + + - @demisto.owner (str): + Controls which name shows up under the owner. This can be any name, even of + users that are not registered on Demisto. Incidents can be filtered by name. + + - @demisto.details (str): + A string that briefly describes the nature of the incident and how to respond. + + - @demisto.incident_name (str): + Incident name shows up as the title of the Incident. + + - @demisto.label_data (dict): + By default, this output sends the entire publication into the Demisto labels + section, where the label names are the keys of the publication and the label + values are the values of the publication. + + For deeply nested dictionary publications, the label names become the full path + of all nest dictionary keys, concatenated with periods ("."). + + By providing this override field, you can send a different dict of data to + Demisto, other than the entire publication. Just like in the default case, + if this provided dict is deeply nested, the keys will be flattened. 
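+
+            For example, a publisher could set a few of these fields (a sketch; the publisher name
+            is illustrative):
+
+                from stream_alert.shared.publisher import Register
+
+                @Register
+                def demisto_presentation(alert, publication):
+                    publication['@demisto.incident_name'] = 'Alert: {}'.format(alert.rule_name)
+                    publication['@demisto.severity'] = 'high'
+                    publication['@demisto.owner'] = 'SecOps'
+                    return publication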
+ + Args: alert (Alert): Alert instance which triggered a rule descriptor (str): Output descriptor @@ -70,7 +104,7 @@ def _dispatch(self, alert, descriptor): if not creds: return False - request = DemistoRequestAssembler.assemble(alert, alert.publish_for(self, descriptor)) + request = DemistoRequestAssembler.assemble(alert, compose_alert(alert, self, descriptor)) integration = DemistoApiIntegration(creds, self) LOGGER.debug('Sending alert to Demisto: %s', creds['url']) @@ -241,7 +275,6 @@ def assemble(alert, alert_publication): Args: alert (Alert): Instance of the alert alert_publication (Dict): Published alert data of the alert that triggered a rule - descriptor (str): Output descriptor Returns: DemistoCreateIncidentRequest @@ -252,15 +285,17 @@ def assemble(alert, alert_publication): default_severity = 'unknown' default_owner = 'StreamAlert' default_details = alert.rule_description + default_label_data = alert_publication # Special keys that publishers can use to modify default presentation - incident_type = alert_publication.get('demisto.incident_type', default_incident_type) + incident_type = alert_publication.get('@demisto.incident_type', default_incident_type) severity = DemistoCreateIncidentRequest.map_severity_string_to_severity_value( - alert_publication.get('demisto.severity', default_severity) + alert_publication.get('@demisto.severity', default_severity) ) - owner = alert_publication.get('demisto.owner', default_owner) - details = alert_publication.get('demisto.details', default_details) - incident_name = alert_publication.get('demisto.incident_name', default_incident_name) + owner = alert_publication.get('@demisto.owner', default_owner) + details = alert_publication.get('@demisto.details', default_details) + incident_name = alert_publication.get('@demisto.incident_name', default_incident_name) + label_data = alert_publication.get('@demisto.label_data', default_label_data) request = DemistoCreateIncidentRequest( incident_name=incident_name, @@ -288,6 +323,6 @@ def enumerate_fields(record, path=''): else: request.add_label(path, record) - enumerate_fields(alert_publication) + enumerate_fields(label_data) return request diff --git a/stream_alert/alert_processor/outputs/github.py b/stream_alert/alert_processor/outputs/github.py index 1ad4db8b2..d9a1986b5 100644 --- a/stream_alert/alert_processor/outputs/github.py +++ b/stream_alert/alert_processor/outputs/github.py @@ -17,6 +17,7 @@ import base64 import json +from stream_alert.alert_processor.helpers import compose_alert from stream_alert.alert_processor.outputs.output_base import ( OutputDispatcher, OutputProperty, @@ -75,6 +76,17 @@ def _get_default_properties(cls): def _dispatch(self, alert, descriptor): """Send alert to Github + Publishing: + This output provides a default issue title and a very basic issue body containing + the alert record. To override: + + - @github.title (str): + Override the Issue's title + + - @github.body (str): + Overrides the default github issue body. Remember: this string is in Github's + syntax, so it supports markdown and respects linebreaks characters (e.g. \n). 
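For illustration, a hedged sketch of a publisher that overrides the GitHub issue title and body described above. Only '@github.title' and '@github.body' come from this patch; the function signature is assumed and the text is illustrative.

    def github_issue_publisher(alert, publication):
        # Hypothetical publisher; signature assumed, content illustrative.
        publication['@github.title'] = '[StreamAlert] {}'.format(alert.rule_name)
        publication['@github.body'] = '### Description\n{}\n\n### Alert\n`{}`'.format(
            alert.rule_description, alert.alert_id
        )
        return publication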
+ Args: alert (Alert): Alert instance which triggered a rule descriptor (str): Output descriptor @@ -93,20 +105,23 @@ def _dispatch(self, alert, descriptor): url = '{}/repos/{}/issues'.format(credentials['api'], credentials['repository']) - publication = alert.publish_for(self, descriptor) - record = publication.get('record', {}) + publication = compose_alert(alert, self, descriptor) # Default presentation to the output default_title = "StreamAlert: {}".format(alert.rule_name) default_body = "### Description\n{}\n\n### Event data\n\n```\n{}\n```".format( alert.rule_description, - json.dumps(record, indent=2, sort_keys=True) + json.dumps(alert.record, indent=2, sort_keys=True) ) + # Override presentation defaults + issue_title = publication.get('@github.title', default_title) + issue_body = publication.get('@github.body', default_body) + # Github Issue to be created issue = { - 'title': publication.get('github.title', default_title), - 'body': publication.get('github.body', default_body), + 'title': issue_title, + 'body': issue_body, 'labels': credentials['labels'].split(',') } diff --git a/stream_alert/alert_processor/outputs/jira.py b/stream_alert/alert_processor/outputs/jira.py index ae5dbecac..ad5c73a49 100644 --- a/stream_alert/alert_processor/outputs/jira.py +++ b/stream_alert/alert_processor/outputs/jira.py @@ -17,6 +17,7 @@ import json import os +from stream_alert.alert_processor.helpers import compose_alert from stream_alert.alert_processor.outputs.output_base import ( OutputDispatcher, OutputProperty, @@ -278,6 +279,18 @@ def _establish_session(self, username, password): def _dispatch(self, alert, descriptor): """Send alert to Jira + Publishing: + This output uses a default issue summary and sends the entire publication into the + issue body as a {{code}} block. To override: + + - @jira.issue_summary (str): + Overrides the issue title that shows up at the top on the JIRA UI + + - @jira.description (str): + Send your own custom description. Remember: This field is in JIRA's syntax, + so it supports their custom markdown-like formatting and respects newline + characters (e.g. \n). 
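For illustration, a hedged sketch of a publisher overriding the Jira fields described above. The {code} markup follows Jira's own formatting as noted in the docstring; the function signature is assumed and the values are illustrative.

    def jira_ticket_publisher(alert, publication):
        # Hypothetical publisher; signature assumed, values illustrative.
        publication['@jira.issue_summary'] = 'StreamAlert: {}'.format(alert.rule_name)
        publication['@jira.description'] = '{}\n\n{{code}}\n{}\n{{code}}'.format(
            alert.rule_description, alert.record
        )
        return publication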
+ Args: alert (Alert): Alert instance which triggered a rule descriptor (str): Output descriptor @@ -289,7 +302,7 @@ def _dispatch(self, alert, descriptor): if not creds: return False - publication = alert.publish_for(self, descriptor) + publication = compose_alert(alert, self, descriptor) # Presentation defaults default_issue_summary = 'StreamAlert {}'.format(alert.rule_name) @@ -298,8 +311,8 @@ def _dispatch(self, alert, descriptor): ) # True Presentation values - issue_summary = publication.get('jira.issue_summary', default_issue_summary) - description = publication.get('jira.description', default_alert_body) + issue_summary = publication.get('@jira.issue_summary', default_issue_summary) + description = publication.get('@jira.description', default_alert_body) issue_id = None comment_id = None diff --git a/stream_alert/alert_processor/outputs/komand.py b/stream_alert/alert_processor/outputs/komand.py index 52ccffdbf..85452b215 100644 --- a/stream_alert/alert_processor/outputs/komand.py +++ b/stream_alert/alert_processor/outputs/komand.py @@ -15,6 +15,7 @@ """ from collections import OrderedDict +from stream_alert.alert_processor.helpers import compose_alert from stream_alert.alert_processor.outputs.output_base import ( OutputDispatcher, OutputProperty, @@ -62,6 +63,10 @@ def get_user_defined_properties(cls): def _dispatch(self, alert, descriptor): """Send alert to Komand + Publishing: + By default this output sends the current publication to Komand. + There is no "magic" field to "override" it: Simply publish what you want to send! + Args: alert (Alert): Alert instance which triggered a rule descriptor (str): Output descriptor @@ -77,7 +82,7 @@ def _dispatch(self, alert, descriptor): LOGGER.debug('sending alert to Komand') - publication = alert.publish_for(self, descriptor) + publication = compose_alert(alert, self, descriptor) resp = self._post_request(creds['url'], {'data': publication}, headers, False) return self._check_http_response(resp) diff --git a/stream_alert/alert_processor/outputs/output_base.py b/stream_alert/alert_processor/outputs/output_base.py index 9c230cda0..ea2d92fe8 100644 --- a/stream_alert/alert_processor/outputs/output_base.py +++ b/stream_alert/alert_processor/outputs/output_base.py @@ -395,7 +395,7 @@ def dispatch(self, alert, output): Args: alert (Alert): Alert instance which triggered a rule - descriptor (str): Output descriptor (e.g. slack channel, pd integration) + output (str): Fully described output (e.g. "demisto:version1", "pagerduty:engineering" Returns: bool: True if alert was sent successfully, False otherwise diff --git a/stream_alert/alert_processor/outputs/pagerduty.py b/stream_alert/alert_processor/outputs/pagerduty.py index b4aa676bc..72c76f674 100644 --- a/stream_alert/alert_processor/outputs/pagerduty.py +++ b/stream_alert/alert_processor/outputs/pagerduty.py @@ -13,10 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. 
""" +# pylint: disable=too-many-lines from collections import OrderedDict -import os import backoff +from stream_alert.alert_processor.helpers import compose_alert from stream_alert.alert_processor.outputs.output_base import ( OutputDispatcher, OutputProperty, @@ -33,48 +34,151 @@ LOGGER = get_logger(__name__) +# https://support.pagerduty.com/docs/dynamic-notifications +SEVERITY_CRITICAL = 'critical' +SEVERITY_ERROR = 'error' +SEVERITY_WARNING = 'warning' +SEVERITY_INFO = 'info' +SEVERITY_UNKNOWN = 'unknown' # empty string and any string not in the above defaults to "unknown" -def events_v2_data(output_dispatcher, descriptor, alert, routing_key, with_record=True): - """Helper method to generate the payload to create an event using PagerDuty Events API v2 - Args: - output_dispatcher (OutputDispatcher): The output sending these data - descriptor (str): The descriptor of the output sending these data - alert (Alert): Alert relevant to the triggered rule - routing_key (str): Routing key for this PagerDuty integration - with_record (boolean): Option to add the record data or not +class PagerdutySearchDelay(Exception): + """PagerdutyAlertDelay handles any delays looking up PagerDuty Incidents""" + + +class EventsV2DataProvider(object): + """This class is meant to be mixed-into pagerduty outputs that integrate with v2 of the API - Returns: - dict: Contains JSON blob to be used as event + This is called the CommonEventFormat (PD-CEF). Documentation can be found here: + https://support.pagerduty.com/docs/pd-cef """ - publication = alert.publish_for(output_dispatcher, descriptor) - - # Presentation defaults - default_summary = 'StreamAlert Rule Triggered - {}'.format(alert.rule_name) - default_description = alert.rule_description - default_record = alert.record - - # Special field that Publishers can use to customize the header - # FIXME (derek.wang) the publication key does not adhere to the convention of __service__ - # as a prefix, since this method is overloaded between two different outputs - summary = publication.get('pagerduty.summary', default_summary) - details = OrderedDict() - details['description'] = publication.get('pagerduty.description', default_description) - if with_record: - details['record'] = publication.get('record', default_record) - - payload = { - 'summary': summary, - 'source': alert.log_source, - 'severity': 'critical', - 'custom_details': details - } - return { - 'routing_key': routing_key, - 'payload': payload, - 'event_action': 'trigger', - 'client': 'StreamAlert' - } + + def events_v2_data(self, alert, descriptor, routing_key, with_record=True): + """Helper method to generate the payload to create an event using PagerDuty Events API v2 + + (!) NOTE: this method will not work unless this class is mixed into an OutputDispatcher + + Publishing: + By default the pagerduty event is setup with a blob of data comprising the rule + description and the record in the custom details. You can customize behavior with + the following fields: + + - @pagerduty-v2:summary (str): + Modifies the title of the event + + - @pagerduty-v2.custom_details (dict): + Fills out the pagerduty customdetails with this structure. + + (!) NOTE: Due to PagerDuty's UI, it is extremely hard to read very deeply + nested JSON dicts. It is also extremely hard to read large blobs of data. + Try to collapse deeply nested structures into single-level keys, and + try to truncate blobs of data. + + - @pagerduty-v2:severity (string): + By default the severity of alerts are "critical". 
You can override this with + any of the following: + 'info', 'warning', 'error', 'critical' + + Args: + descriptor (str): The descriptor of the output sending these data + alert (Alert): Alert relevant to the triggered rule + routing_key (str): Routing key for this PagerDuty integration + with_record (boolean): Option to add the record data or not + + Returns: + dict: Contains JSON blob to be used as event + """ + publication = compose_alert(alert, self, descriptor) + + # Presentation defaults + default_summary = 'StreamAlert Rule Triggered - {}'.format(alert.rule_name) + default_custom_details = OrderedDict() + default_custom_details['description'] = alert.rule_description + if with_record: + default_custom_details['record'] = alert.record + default_severity = SEVERITY_CRITICAL + + # Special field that Publishers can use to customize the header + summary = publication.get('@pagerduty-v2.summary', default_summary) + details = publication.get('@pagerduty-v2.custom_details', default_custom_details) + severity = publication.get('@pagerduty-v2.severity', default_severity) + client_url = publication.get('@pagerduty-v2.client_url', None) + images = self._standardize_images(publication.get('@pagerduty-v2.images', [])) + links = self._standardize_links(publication.get('@pagerduty-v2.links', [])) + component = publication.get('@pagerduty-v2.component', None) + group = publication.get('@pagerduty-v2.group', None) + alert_class = publication.get('@pagerduty-v2.class', None) + + # Structure: https://v2.developer.pagerduty.com/docs/send-an-event-events-api-v2 + return { + 'routing_key': routing_key, + 'event_action': 'trigger', + + # Beware of providing this; when this is provided, even if empty string, this will + # cause the dedup_key to be bound to the ALERT, not the incident. The implication + # is that the incident will no longer be searchable with incident_key=dedup_key + # 'dedup_key': '', + 'payload': { + 'summary': summary, + 'source': alert.log_source, + 'severity': severity, + 'custom_details': details, + + # When provided, must be in valid ISO 8601 format + # 'timestamp': '', + 'component': component, + 'group': group, + 'class': alert_class, + }, + 'client': 'StreamAlert', + 'client_url': client_url, + 'images': images, + 'links': links, + } + + @staticmethod + def _standardize_images(images): + """Strips invalid images out of the images argument + + Images should be dicts with 3 keys: + - src: The full http URL of the image + - href: A URL that the image opens when clicked (Optional) + - alt: Alt text (Optional) + """ + if not isinstance(images, list): + return [] + + return [ + { + # Notably, if href is provided but is an invalid URL, the entire image will + # be entirely omitted from the incident... beware. 
+ 'src': image['src'], + 'href': image['href'] if 'href' in image else '', + 'alt': image['alt'] if 'alt' in image else '', + } + for image in images + if isinstance(image, dict) and 'src' in image + ] + + @staticmethod + def _standardize_links(links): + """Strips invalid links out of the links argument + + Images should be dicts with 2 keys: + - href: A URL of the link + - text: Text of the link (Optional: Defaults to the href if no text given) + """ + if not isinstance(links, list): + return [] + + return [ + { + 'href': link['href'], + 'text': link['text'] if 'text' in link else link['href'], + } + for link in links + if isinstance(link, dict) and 'href' in link + ] @StreamAlertOutput @@ -90,7 +194,7 @@ def _get_default_properties(cls): Returns: dict: Contains various default items for this output (ie: url) """ - return {'url': 'https://events.pagerduty.com/generic/2010-04-15/create_event.json'} + return {'url': PagerDutyEventsV1ApiClient.EVENTS_V1_API_ENDPOINT} @classmethod def get_user_defined_properties(cls): @@ -109,6 +213,9 @@ def get_user_defined_properties(cls): ('descriptor', OutputProperty(description='a short and unique descriptor for this ' 'PagerDuty integration')), + # A version 4 UUID expressed as a 32 digit hexadecimal number. This is the + # integration key for an integration on a given service and can be found on + # the pagerduty integrations UI. ('service_key', OutputProperty(description='the service key for this PagerDuty integration', mask_input=True, @@ -118,6 +225,48 @@ def get_user_defined_properties(cls): def _dispatch(self, alert, descriptor): """Send alert to Pagerduty + Publishing: + This output can be override with the following fields: + + - @pagerduty.description (str): + The provided string will be rendered as the event's title. + + - @pagerduty.details (dict): + By default this output renders rule description and rule record in a deeply + nested json structure. You can override this with your own dict. + + (!) NOTE: Due to PagerDuty's UI, it is extremely hard to read very deeply + nested JSON dicts. It is also extremely hard to read large blobs of data. + Try to collapse deeply nested structures into single-level keys, and + try to truncate blobs of data. + + - @pagerduty.client_url (str): + A URL. It should be a link to the same alert in a different service. + When given, there will be a "view in streamalert" link given at the bottom. + Currently this 'streamalert' string is hardcoded into the api client + as the 'client' field. + + This is not included in the default implementation. + + - @pagerduty.contexts (list[dict]): + This field can be used to automatically attach images and links to the incident + event. This should be a list of dicts. Each dict should follow ONE OF these + formats: + + Link: + { + 'type': 'link', + 'href': 'https://streamalert.io/', + } + + Image embed + { + 'type': 'image', + 'src': 'https://streamalert.io/en/stable/_images/sa-complete-arch.png', + } + + This is not included in the default implementation. 
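For illustration, a hedged sketch of a publisher attaching a link and an image via '@pagerduty.contexts' as described above. The context formats and URLs follow the docstring; the function signature is assumed. Note that the _strip_invalid_contexts helper added below also requires a 'text' key on link contexts, so one is included here.

    def pagerduty_context_publisher(alert, publication):
        # Hypothetical publisher; signature assumed, context values illustrative.
        publication['@pagerduty.contexts'] = [
            {
                'type': 'link',
                'href': 'https://streamalert.io/',
                'text': 'StreamAlert documentation',
            },
            {
                'type': 'image',
                'src': 'https://streamalert.io/en/stable/_images/sa-complete-arch.png',
            },
        ]
        return publication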
+ Args: alert (Alert): Alert instance which triggered a rule descriptor (str): Output descriptor @@ -129,31 +278,67 @@ def _dispatch(self, alert, descriptor): if not creds: return False - publication = alert.publish_for(self, descriptor) - - message = 'StreamAlert Rule Triggered - {}'.format(publication.get('rule_name', '')) - details = { - 'description': publication.get('rule_description', ''), - 'record': publication.get('record', {}) - } - data = { - 'service_key': creds['service_key'], - 'event_type': 'trigger', - 'description': message, - 'details': details, - 'client': 'StreamAlert' + # Presentation defaults + default_description = 'StreamAlert Rule Triggered - {}'.format(alert.rule_name) + default_details = { + 'description': alert.rule_description, + 'record': alert.record, } + default_contexts = [] + default_client_url = '' - try: - self._post_request_retry(creds['url'], data, None, True) - except OutputRequestFailure: - return False + # Override presentation with publisher + publication = compose_alert(alert, self, descriptor) + description = publication.get('@pagerduty.description', default_description) + details = publication.get('@pagerduty.details', default_details) + client_url = publication.get('@pagerduty.client_url', default_client_url) + contexts = publication.get('@pagerduty.contexts', default_contexts) + contexts = self._strip_invalid_contexts(contexts) - return True + http = JsonHttpProvider(self) + client = PagerDutyEventsV1ApiClient(creds['service_key'], http, api_endpoint=creds['url']) + + return client.send_event(description, details, contexts, client_url) + + @staticmethod + def _strip_invalid_contexts(contexts): + """When an array of contexts, will return a new array containing only valid ones.""" + if not isinstance(contexts, list): + LOGGER.warning('Invalid pagerduty.contexts provided: Not an array') + return [] + + def is_valid_context(context): + if not 'type' in context: + return False + + if context['type'] == 'link': + if 'href' not in context or 'text' not in context: + return False + elif context['type'] == 'image': + if 'src' not in context: + return False + else: + return False + + return True + + def standardize_context(context): + if context['type'] == 'link': + return { + 'type': 'link', + 'href': context['href'], + 'text': context['text'], + } + return { + 'type': 'image', + 'src': context['src'], + } + + return [standardize_context(x) for x in contexts if is_valid_context(x)] @StreamAlertOutput -class PagerDutyOutputV2(OutputDispatcher): +class PagerDutyOutputV2(OutputDispatcher, EventsV2DataProvider): """PagerDutyOutput handles all alert dispatching for PagerDuty Events API v2""" __service__ = 'pagerduty-v2' @@ -165,7 +350,7 @@ def _get_default_properties(cls): Returns: dict: Contains various default items for this output (ie: url) """ - return {'url': 'https://events.pagerduty.com/v2/enqueue'} + return {'url': PagerDutyEventsV2ApiClient.EVENTS_V2_API_ENQUEUE_ENDPOINT} @classmethod def get_user_defined_properties(cls): @@ -195,6 +380,9 @@ def get_user_defined_properties(cls): def _dispatch(self, alert, descriptor): """Send alert to Pagerduty + Publishing: + @see EventsV2DataProvider for more details + Args: alert (Alert): Alert instance which triggered a rule descriptor (str): Output descriptor @@ -206,23 +394,32 @@ def _dispatch(self, alert, descriptor): if not creds: return False - data = events_v2_data(self, descriptor, alert, creds['routing_key']) + data = self.events_v2_data(alert, descriptor, creds['routing_key']) - try: - 
self._post_request_retry(creds['url'], data, None, True) - except OutputRequestFailure: + http = JsonHttpProvider(self) + client = PagerDutyEventsV2ApiClient(http, enqueue_endpoint=creds['url']) + + result = client.enqueue_event(data) + + if result is False: return False return True -class PagerdutySearchDelay(Exception): - """PagerdutyAlertDelay handles any delays looking up PagerDuty Incidents""" +@StreamAlertOutput +class PagerDutyIncidentOutput(OutputDispatcher, EventsV2DataProvider): + """PagerDutyIncidentOutput handles all alert dispatching for PagerDuty Incidents REST API + In addition to using the REST API, this PagerDuty implementation also performs automatic + assignment of the incident, based upon context parameters. -@StreamAlertOutput -class PagerDutyIncidentOutput(OutputDispatcher): - """PagerDutyIncidentOutput handles all alert dispatching for PagerDuty Incidents API v2""" + context = { + 'assigned_user': 'somebody@somewhere.somewhere', + 'with_record': True|False, + 'note': 'String goes here' + } + """ __service__ = 'pagerduty-incident' INCIDENTS_ENDPOINT = 'incidents' USERS_ENDPOINT = 'users' @@ -247,7 +444,7 @@ def _get_default_properties(cls): Returns: dict: Contains various default items for this output (ie: url) """ - return {'api': 'https://api.pagerduty.com'} + return {'api': PagerDutyRestApiClient.REST_API_BASE_URL} @classmethod def get_user_defined_properties(cls): @@ -268,6 +465,10 @@ def get_user_defined_properties(cls): ('descriptor', OutputProperty(description='a short and unique descriptor for this ' 'PagerDuty integration')), + # The REST API Access Token. This needs to be generated through the PagerDuty console. + # Unlike the routing key this token is EXTREMELY IMPORTANT NOT TO LOSE as it grants + # access to all resources on PagerDuty, whereas the routing key only allows + # the generation of new events. ('token', OutputProperty(description='the token for this PagerDuty integration', mask_input=True, @@ -275,6 +476,11 @@ def get_user_defined_properties(cls): ('service_name', OutputProperty(description='the service name for this PagerDuty integration', cred_requirement=True)), + # The service ID is the unique resource ID of a PagerDuty service, created through + # the UI. You can find the service id by looking at the URL: + # - www.pagerduty.com/services/PDBBCC9 + # + # In the above case, the service id is 'PDBBCC9' ('service_id', OutputProperty(description='the service ID for this PagerDuty integration', cred_requirement=True)), @@ -282,226 +488,364 @@ def get_user_defined_properties(cls): OutputProperty(description='the name of the default escalation policy', input_restrictions={}, cred_requirement=True)), + # The escalation policy ID is the unique resource ID of a PagerDuty escalation policy, + # created through the UI. You can find it on the URL: + # - www.pagerduty.com/escalation_policies#PDBBBB0 + # + # In the above case, the escalation policy id is PDBBBB0 ('escalation_policy_id', OutputProperty(description='the ID of the default escalation policy', cred_requirement=True)), + # This must exactly match the email address of a user on the PagerDuty account. ('email_from', OutputProperty(description='valid user email from the PagerDuty ' 'account linked to the token', cred_requirement=True)), + # A version 4 UUID expressed as a 32 digit hexadecimal number. This is the same + # as the routing key that is used in the v2 Events API. 
('integration_key', OutputProperty(description='the integration key for this PagerDuty integration', cred_requirement=True)) ]) - @staticmethod - def _get_endpoint(base_url, endpoint): - """Helper to get the full url for a PagerDuty Incidents endpoint. + def _dispatch(self, alert, descriptor): + """Send incident to Pagerduty Incidents REST API v2 + + Context: + + - with_record (bool): + - note (bool): + + Publishing: + This output has a more complex workflow. The magic publisher fields for @pagerduty-v2 + ALSO are respected by this output. + + - @pagerduty-incident.incident_title (str): + The provided string will override the PARENT INCIDENT's title. The child Alert's + title is controlled by other publisher magic fields. + + - @pagerduty-incident.incident_body (str): + This is text that shows up in the body of the newly created incident. + + (!) NOTE: Due to the way incidents are merged, this text is almost never + displayed properly on PagerDuty's UI. The only instance where it + shows up correctly is when incident merging fails and the newly + created incident does not have an alert attached to it. + + - @pagerduty-incident.note (str): + Due to legacy reasons, this PagerDuty services adds a note containing + "Creating SOX Incident" to the final PagerDuty incident. Providing a string + to this magic field will override that behavior. + + - @pagerduty-incident.urgency (str): + Either "low" or "high". By default urgency is "high" for all incidents. + + + In addition, the final event that is merged into the parent incident can be customized + as well. + @see EventsV2DataProvider for more details + Args: - base_url (str): Base URL for the API - endpoint (str): Endpoint that we want the full URL for + alert (Alert): Alert instance which triggered a rule + descriptor (str): Output descriptor Returns: - str: Full URL of the provided endpoint + bool: True if alert was sent successfully, False otherwise """ - return os.path.join(base_url, endpoint) + creds = self._load_creds(descriptor) + if not creds: + return - def _create_event(self, data): - """Helper to create an event in the PagerDuty Events API v2 + work = WorkContext(self, creds) + return work.run(alert, descriptor) - Args: - data (dict): JSON blob with the format of the PagerDuty Events API v2 - Returns: - dict: Contains the HTTP response of the request to the API + +class WorkContext(object): + """Class encapsulating a bunch of self-contained, interdependent PagerDuty work. + + Because PagerDuty work involves a lot of steps that share a lot of data, we carved this + section out. + """ + BACKOFF_MAX = 5 + BACKOFF_TIME = 5 + + def __init__(self, output_dispatcher, credentials): + self._output = output_dispatcher + self._credentials = credentials + + self._email_from = self._credentials['email_from'] + self._default_escalation_policy_id = self._credentials['escalation_policy_id'] + self._incident_service = self._credentials['service_id'] + + http = JsonHttpProvider(output_dispatcher) + self._api_client = PagerDutyRestApiClient( + self._credentials['token'], + self._credentials['email_from'], + http, + url=self._credentials['api'] + ) + self._events_client = PagerDutyEventsV2ApiClient(http) + + def run(self, alert, descriptor): + """Sets up an assigned incident. + + FIXME (derek.wang): + This work routine is a large, non-atomic set of jobs that can sometimes partially fail. + Partial failures can have side effects on PagerDuty, including the creation of + incomplete or partially complete alerts. 
Because the Alert Processor will automatically + retry the entire routine from scratch, this can cause partial alerts to get created + redundantly forever. The temporary solution is to delete the erroneous record from + DynamoDB manually, but in the future we should consider writing back state into the + DynamoDB alert record to track the steps involved in "fulfilling" the dispatch of this + alert. """ - url = 'https://events.pagerduty.com/v2/enqueue' - try: - resp = self._post_request_retry(url, data, None, False) - except OutputRequestFailure: + if not self.verify_user_exists(): return False - response = resp.json() - if not response: + # Extracting context data to assign the incident + rule_context = alert.context + if rule_context: + rule_context = rule_context.get(self._output.__service__, {}) + + publication = compose_alert(alert, self._output, descriptor) + + incident = self._create_base_incident(alert, publication, rule_context) + incident_id = incident.get('id') + if not incident or not incident_id: + LOGGER.error('[%s] Could not create main incident', self._output.__service__) return False - return response + # Create alert to hold all the incident details + event = self._create_base_alert_event(alert, descriptor, rule_context) + if not event: + LOGGER.error('[%s] Could not create incident event', self._output.__service__) + return False - @backoff.on_exception(backoff.constant, - PagerdutySearchDelay, - max_tries=BACKOFF_MAX, - interval=BACKOFF_TIME, - on_backoff=backoff_handler(), - on_success=success_handler(), - on_giveup=giveup_handler()) - def _get_event_incident_id(self, incident_key): - """Helper to lookup an incident using the incident_key and return the id + # FIXME (derek.wang), see above + # At this point, both the incident and the relevant alert event have been successfully + # created. Any further work that fails the dispatch call will cause the alert to retry + # and redundantly create more incidents and alert events. + # Therefore, the hack is to simply let further failures go by always returning True. + # The tradeoff is that incidents can be created on pagerduty in an incomplete state, + # but this is easier to manage than StreamAlert redundantly creating hundreds (or more!) + # redundant alerts. + stable = True - Args: - incident_key (str): Incident key that indentifies uniquely an incident + # Merge the incident with the event, so we can have a rich context incident + # assigned to a specific person, which the PagerDuty REST API v2 does not allow + merged_incident = self._merge_event_into_incident(incident, event) + if not merged_incident: + LOGGER.error( + '[%s] Failed to merge alert [%s] into [%s]', + self._output.__service__, + event.get('dedup_key'), + incident_id + ) + stable = False + + if merged_incident: + note = self._add_incident_note(incident, publication, rule_context) + if not note: + LOGGER.error( + '[%s] Failed to add note to incident (%s)', + self._output.__service__, + incident_id + ) + stable = False + + # If something went wrong, we can't throw an error anymore; log it on the Incident + if not stable: + self._add_instability_note(incident_id) - Returns: - str: ID of the incident after look up the incident_key + return True + + def _add_instability_note(self, incident_id): + instability_note = ''' +StreamAlert failed to correctly setup this incident. Please contact your StreamAlert administrator. 
+ '''.strip() + self._api_client.add_note(incident_id, instability_note) + + def _create_base_incident(self, alert, publication, rule_context): + """Creates a container incident for this alert + In PagerDuty's REST API design, Incidents are designed to behave like containers for many + alerts. Unlike alerts, which must obey service escalation policies, Incidents can be given + custom assignments. + + Returns the newly created incident as a JSON dict. Returns False if anything goes wrong. """ - params = { - 'incident_key': incident_key - } - incidents_url = self._get_endpoint(self._base_url, self.INCIDENTS_ENDPOINT) - response = self._generic_api_get(incidents_url, params) - incident = response.get('incidents', []) - if not incident: - raise PagerdutySearchDelay() + # Presentation defaults + default_incident_title = 'StreamAlert Incident - Rule triggered: {}'.format(alert.rule_name) + default_incident_body = alert.rule_description + default_urgency = None # Assumes the default urgency on the service referenced - return incident[0].get('id') + # Override presentation defaults with publisher fields + incident_title = publication.get( + '@pagerduty-incident.incident_title', + default_incident_title + ) + incident_body = publication.get('@pagerduty-incident.incident_body', default_incident_body) + incident_urgency = publication.get('@pagerduty-incident.urgency', default_urgency) - def _merge_incidents(self, url, to_be_merged_id): - """Helper to merge incidents by id using the PagerDuty REST API v2 + # FIXME (derek.wang) use publisher to set priority instead of context + # Use the priority provided in the context, use it or the incident will be low priority + incident_priority = self.get_standardized_priority(rule_context) - Args: - url (str): The url to send the requests to in the API - to_be_merged_id (str): ID of the incident to merge with + # FIXME (derek.wang) use publisher to set priority instead of context + assigned_key, assigned_value = self.get_incident_assignment(rule_context) - Returns: - dict: Contains the HTTP response of the request to the API - """ - params = { - 'source_incidents': [ - { - 'id': to_be_merged_id, - 'type': 'incident_reference' - } - ] + # https://api-reference.pagerduty.com/#!/Incidents/post_incidents + incident_data = { + 'incident': { + 'type': 'incident', + 'title': incident_title, + 'service': { + 'id': self._incident_service, + 'type': 'service_reference' + }, + 'priority': incident_priority, + 'incident_key': '', + 'body': { + 'type': 'incident_body', + 'details': incident_body, + }, + assigned_key: assigned_value + } } - try: - resp = self._put_request_retry(url, params, self._headers, False) - except OutputRequestFailure: - return False - response = resp.json() - if not response: - return False + # Urgency, if provided, must always be 'high' or 'low' or the API will error + if incident_urgency: + if incident_urgency in ['low', 'high']: + incident_data['incident']['urgency'] = incident_urgency + else: + LOGGER.warn( + '[%s] Invalid pagerduty incident urgency: "%s"', + self._output.__service__, + incident_urgency + ) - return response + return self._api_client.create_incident(incident_data) - def _generic_api_get(self, url, params): - """Helper to submit generic GET requests with parameters to the PagerDuty REST API v2 + def _create_base_alert_event(self, alert, descriptor, rule_context): + """Creates an alert on REST API v2 - Args: - url (str): The url to send the requests to in the API + Returns the JSON representation of the ENQUEUE RESPONSE. 
This actually does not return + either the alert nor the incident itself, but rather a small acknowledgement structure + containing a "dedup_key". This key can be used to find the incident that is created. - Returns: - dict: Contains the HTTP response of the request to the API + Returns False if event was not created. """ - try: - resp = self._get_request_retry(url, params, self._headers, False) - except OutputRequestFailure: - return False - - response = resp.json() - if not response: - return False - - return response + with_record = rule_context.get('with_record', True) + event_data = self._output.events_v2_data( + alert, + descriptor, + self._credentials['integration_key'], + with_record + ) - def _check_exists(self, filter_str, url, target_key, get_id=True): - """Generic method to run a search in the PagerDuty REST API and return the id - of the first occurence from the results. + return self._events_client.enqueue_event(event_data) - Args: - filter_str (str): The query filter to search for in the API - url (str): The url to send the requests to in the API - target_key (str): The key to extract in the returned results - get_id (boolean): Whether to generate a dict with result and reference + def _merge_event_into_incident(self, incident, event): + """Merges the given event into the incident. - Returns: - str: ID of the targeted element that matches the provided filter or - True/False whether a matching element exists or not. + Returns the final, merged incident as a JSON dict. Returns False if anything goes wrong. """ - params = { - 'query': filter_str - } - response = self._generic_api_get(url, params) - if not response: + # Extract the incident id from the incident that was just created + incident_id = incident.get('id') + if not incident_id: + LOGGER.error('[%s] Incident missing Id?', self._output.__service__) return False - if not get_id: - return True + # Lookup the incident_key returned as dedup_key to get the incident id + incident_key = event.get('dedup_key') + if not incident_key: + LOGGER.error('[%s] Event missing dedup_key', self._output.__service__) + return False - # We need the list to have elements - target_element = response.get(target_key, []) - if not target_element: + # Keep that id to be merged later with the created incident + event_incident_id = self.get_incident_id_from_event_incident_key(incident_key) + if not event_incident_id: + LOGGER.error( + '[%s] Failed to retrieve Event Incident Id from dedup_key (%s)', + self._output.__service__, + incident_key + ) return False - # If there are results, get the first occurence from the list - return target_element[0].get('id', False) + # Merge the incident with the event, so we can have a rich context incident + # assigned to a specific person, which the PagerDuty REST API v2 does not allow + return self._api_client.merge_incident(incident_id, event_incident_id) - def _user_verify(self, user, get_id=True): - """Method to verify the existance of an user with the API - Args: - user (str): User to query about in the API. - get_id (boolean): Whether to generate a dict with result and reference - Returns: - dict or False: JSON object be used in the API call, containing the user_id - and user_reference. False if user is not found - """ - return self._item_verify(user, self.USERS_ENDPOINT, 'user_reference', get_id) + def _add_incident_note(self, incident, publication, rule_context): + """Adds a note to the incident, when applicable. 
- def _policy_verify(self, policy, default_policy): - """Method to verify the existance of a escalation policy with the API - Args: - policy (str): Escalation policy to query about in the API - default_policy (str): Escalation policy to use if the first one is not verified Returns: - dict: JSON object be used in the API call, containing the policy_id - and escalation_policy_reference + bool: True if the note was created or no note needed to be created, False on error. """ - verified = self._item_verify(policy, self.POLICIES_ENDPOINT, 'escalation_policy_reference') - # If the escalation policy provided is not verified in the API, use the default - if verified: - return verified + # Add a note to the combined incident to help with triage + merged_id = incident.get('id') + if not merged_id: + LOGGER.error('[%s] Merged incident missing Id?', self._output.__service__) + return False - return self._item_verify(default_policy, self.POLICIES_ENDPOINT, - 'escalation_policy_reference') + default_incident_note = 'Creating SOX Incident' # For reverse compatibility reasons + incident_note = publication.get( + '@pagerduty-incident.note', + rule_context.get( + 'note', + default_incident_note + ) + ) - def _service_verify(self, service): - """Method to verify the existance of a service with the API + if not incident_note: + # Simply return early without adding a note; no need to add a blank one + return True - Args: - service (str): Service to query about in the API + return bool(self._api_client.add_note(merged_id, incident_note)) - Returns: - dict: JSON object be used in the API call, containing the service_id - and the service_reference - """ - return self._item_verify(service, self.SERVICES_ENDPOINT, 'service_reference') - def _item_verify(self, item_str, item_key, item_type, get_id=True): - """Method to verify the existance of an item with the API - Args: - item_str (str): Service to query about in the API - item_key (str): Endpoint/key to be extracted from search results - item_type (str): Type of item reference to be returned - get_id (boolean): Whether to generate a dict with result and reference - Returns: - dict: JSON object be used in the API call, containing the item id - and the item reference, True if it just exists or False if it fails + @backoff.on_exception(backoff.constant, + PagerdutySearchDelay, + max_tries=BACKOFF_MAX, + interval=BACKOFF_TIME, + on_backoff=backoff_handler(), + on_success=success_handler(), + on_giveup=giveup_handler()) + def get_incident_id_from_event_incident_key(self, incident_key): + """Queries the API to get the incident id from an incident key + + When creating an EVENT from the events-v2 API, events are created alongside an incident, + but only an incident_key is returned, which is not the same as the incident's REST API + resource id. 
""" - item_url = self._get_endpoint(self._base_url, item_key) - item_id = self._check_exists(item_str, item_url, item_key, get_id) - if not item_id: - LOGGER.info('%s not found in %s, %s', item_str, item_key, self.__service__) + if not incident_key: return False - if get_id: - return {'id': item_id, 'type': item_type} + event_incident = self._api_client.get_incident_by_key(incident_key) + if not event_incident: + raise PagerdutySearchDelay('Received no PagerDuty response') - return item_id + return event_incident.get('id') - def _priority_verify(self, context): - """Method to verify the existance of a incident priority with the API + def verify_user_exists(self): + """Verifies that the 'email_from' provided in the creds is valid and exists.""" + user = self._api_client.get_user_by_email(self._email_from) + + if not user: + LOGGER.error( + 'Could not verify header From: %s, %s', + self._email_from, + self._output.__service__ + ) + return False + + return True + + def get_standardized_priority(self, context): + """Method to verify the existence of a incident priority with the API Args: context (dict): Context provided in the alert record @@ -517,25 +861,15 @@ def _priority_verify(self, context): if not priority_name: return dict() - priorities_url = self._get_endpoint(self._base_url, self.PRIORITIES_ENDPOINT) - - try: - resp = self._get_request_retry(priorities_url, {}, self._headers, False) - except OutputRequestFailure: - return dict() - - response = resp.json() - if not response: - return dict() - - priorities = response.get('priorities', []) + priorities = self._api_client.get_priorities() if not priorities: return dict() # If the requested priority is in the list, get the id priority_id = next( - (item for item in priorities if item["name"] == priority_name), {}).get('id', False) + (item for item in priorities if item["name"] == priority_name), {} + ).get('id', False) # If the priority id is found, compose the JSON if priority_id: @@ -543,9 +877,12 @@ def _priority_verify(self, context): return dict() - def _incident_assignment(self, context): + def get_incident_assignment(self, context): """Method to determine if the incident gets assigned to a user or an escalation policy + Incident assignment goes in this order: + Provided user -> provided policy -> default escalation policy + Args: context (dict): Context provided in the alert record @@ -558,175 +895,425 @@ def _incident_assignment(self, context): # If provided, verify the user and get the id from API if user_to_assign: - user_assignee = self._user_verify(user_to_assign) - # User is verified, return tuple - if user_assignee: - return 'assignments', [{'assignee': user_assignee}] + user = self._api_client.get_user_by_email(user_to_assign) + if user and user.get('id'): + return 'assignments', [{'assignee': { + 'id': user.get('id'), + 'type': 'user_reference', + }}] # If escalation policy ID was not provided, use default one - policy_id_to_assign = context.get('assigned_policy_id', self._escalation_policy_id) + policy_id_to_assign = context.get( + 'assigned_policy_id', + self._default_escalation_policy_id + ) - # Assinged to escalation policy ID, return tuple + # Assigned to escalation policy ID, return tuple return 'escalation_policy', { 'id': policy_id_to_assign, 'type': 'escalation_policy_reference'} - def _add_incident_note(self, incident_id, note): - """Method to add a text note to the provided incident id +# pylint: disable=protected-access +class JsonHttpProvider(object): + """Wraps and re-uses the HTTP implementation on the 
output dispatcher. + + Intended to de-couple the ApiClient classes from the OutputDispatcher. It re-uses some + HTTP implementation that's baked into the OutputDispatcher. It is safe to ignore the + breach-of-abstraction violations here. + """ + + def __init__(self, output_dispatcher): + self._output_dispatcher = output_dispatcher + + def get(self, url, params, headers=None, verify=False): + """Returns the JSON response of the given request, or FALSE on failure""" + try: + result = self._output_dispatcher._get_request_retry(url, params, headers, verify) + except OutputRequestFailure: + return False + + response = result.json() + if not response: + return False + + return response + + def post(self, url, data, headers=None, verify=False): + """Returns the JSON response of the given request, or FALSE on failure""" + try: + result = self._output_dispatcher._post_request_retry(url, data, headers, verify) + except OutputRequestFailure: + return False + + response = result.json() + if not response: + return False + + return response + + def put(self, url, params, headers=None, verify=False): + """Returns the JSON response of the given request, or FALSE on failure""" + try: + result = self._output_dispatcher._put_request_retry(url, params, headers, verify) + except OutputRequestFailure: + return False + + response = result.json() + if not response: + return False + + return response + + +class SslVerifiable(object): + """Mixin for tracking whether or not this is an SSL verifiable. + + Mix this into API client types of classes. + + The idea is to only do host ssl certificate verification on the very first time a unique + host is hit, since the handshake process is slow. Subsequent requests to the same host + within the current request can void certificate verification to speed things up. + """ + def __init__(self): + self._host_ssl_verified = False + + def _should_do_ssl_verify(self): + """Returns whether or not the client should perform SSL host cert verification""" + return not self._host_ssl_verified + + def _update_ssl_verified(self, response): + """ Args: - incident_id (str): ID of the incident to add the note to + response (dict|bool): A return value from JsonHttpProvider Returns: - str: ID of the note after being added to the incident or False if it fails + dict|bool: Simply returns the response as-is + """ + if response is not False: + self._host_ssl_verified = True + + return response + + +class PagerDutyRestApiClient(SslVerifiable): + """API Client class for the PagerDuty REST API + + API Documentation can be found here: https://v2.developer.pagerduty.com/docs/rest-api + """ + + REST_API_BASE_URL = 'https://api.pagerduty.com' + + def __init__(self, authorization_token, user_email, http_provider, url=None): + super(PagerDutyRestApiClient, self).__init__() + + self._authorization_token = authorization_token + self._user_email = user_email + self._http_provider = http_provider # type: JsonHttpProvider + self._base_url = url if url else self.REST_API_BASE_URL + + def get_user_by_email(self, user_email): + """Fetches a pagerduty user by an email address. + + Returns false on failure or if no matching user is found. 
+ """ + response = self._http_provider.get( + self._get_users_url(), + { + 'query': user_email, + }, + self._construct_headers(omit_email=True), + verify=self._should_do_ssl_verify() + ) + self._update_ssl_verified(response) + + if not response: + return False + + users = response.get('users', []) + + return users[0] if users else False + + def get_incident_by_key(self, incident_key): + """Fetches an incident resource given its key + + Returns False on failure or if no matching incident is found. + """ + incidents = self._http_provider.get( + self._get_incidents_url(), + { + 'incident_key': incident_key # Beware: this key is intentionally not "query" + }, + headers=self._construct_headers(), + verify=self._should_do_ssl_verify() + ) + self._update_ssl_verified(incidents) + + if not incidents: + return False + + incidents = incidents.get('incidents', []) + + return incidents[0] if incidents else False + + def get_priorities(self): + """Returns a list of all valid priorities""" + priorities = self._http_provider.get( + self._get_priorities_url(), + None, + headers=self._construct_headers(), + verify=self._should_do_ssl_verify() + ) + self._update_ssl_verified(priorities) + + if not priorities: + return False + + return priorities.get('priorities', []) + + def get_escalation_policy_by_id(self, escalation_policy_id): + """Given an escalation policy id, returns the resource + + Returns False on failure or if no escalation policy exists with that id + """ + escalation_policies = self._http_provider.get( + self._get_escalation_policies_url(), + { + 'query': escalation_policy_id, + }, + headers=self._construct_headers(), + verify=self._should_do_ssl_verify() + ) + self._update_ssl_verified(escalation_policies) + + if not escalation_policies: + return False + + escalation_policies = escalation_policies.get('escalation_policies', []) + + return escalation_policies[0] if escalation_policies else False + + def merge_incident(self, parent_incident_id, merged_incident_id): + """Given two incident ids, notifies PagerDuty to merge them into a single incident + + Returns the json representation of the merged incident, or False on failure. """ - notes_path = '{}/{}/notes'.format(self.INCIDENTS_ENDPOINT, incident_id) - incident_notes_url = self._get_endpoint(self._base_url, notes_path) data = { - 'note': { - 'content': note - } + 'source_incidents': [ + { + 'id': merged_incident_id, + 'type': 'incident_reference' + } + ] } - try: - resp = self._post_request_retry(incident_notes_url, data, self._headers, True) - except OutputRequestFailure: + merged_incident = self._http_provider.put( + self._get_incident_merge_url(parent_incident_id), + data, + headers=self._construct_headers(), + verify=self._should_do_ssl_verify() + ) + self._update_ssl_verified(merged_incident) + + if not merged_incident: return False - response = resp.json() - if not response: + return merged_incident.get('incident', False) + + def create_incident(self, incident_data): + """Creates a new incident + + Returns the incident json representation on success, or False on failure. + + Reference: https://api-reference.pagerduty.com/#!/Incidents/post_incidents + + (!) FIXME (derek.wang) + The legacy implementation utilizes this POST /incidents endpoint to create + incidents and merge them with events created through the events-v2 API, but + the PagerDuty API documentation explicitly says to NOT use the REST API + to create incidents. Research if our use of the POST /incidents endpoint is + incorrect. 
+ Reference: https://v2.developer.pagerduty.com/docs/getting-started + + Args: + incident_data (dict) + + Returns: + dict + """ + incident = self._http_provider.post( + self._get_incidents_url(), + incident_data, + headers=self._construct_headers(), + verify=self._should_do_ssl_verify() + ) + self._update_ssl_verified(incident) + + if not incident: return False - note_rec = response.get('note', {}) + return incident.get('incident', False) + + def add_note(self, incident_id, note): + """Method to add a text note to the provided incident id - return note_rec.get('id', False) + Returns the note json representation on success, or False on failure. - def _dispatch(self, alert, descriptor): - """Send incident to Pagerduty Incidents API v2 + Reference: https://api-reference.pagerduty.com/#!/Incidents/post_incidents_id_notes Args: - alert (Alert): Alert instance which triggered a rule - descriptor (str): Output descriptor + incident_id (str): ID of the incident to add the note to Returns: - bool: True if alert was sent successfully, False otherwise + str: ID of the note after being added to the incident or False if it fails """ - creds = self._load_creds(descriptor) - if not creds: + note = self._http_provider.post( + self._get_incident_notes_url(incident_id), + { + 'note': { + 'content': note, + } + }, + self._construct_headers(), + verify=self._should_do_ssl_verify() + ) + self._update_ssl_verified(note) + + if not note: return False - # Cache base_url - self._base_url = creds['api'] + return note.get('note', False) + + def _construct_headers(self, omit_email=False): + """Returns a dict containing all headers to send for PagerDuty requests - # Preparing headers for API calls - self._headers = { + PagerDuty performs validation on the email provided in the From: header. PagerDuty will + error if the requested email does not exist. In one specific case, we do not want this to + happen; when we are querying for the existence of a user with this email. 
+ """ + headers = { 'Accept': 'application/vnd.pagerduty+json;version=2', + 'Authorization': 'Token token={}'.format(self._authorization_token), 'Content-Type': 'application/json', - 'Authorization': 'Token token={}'.format(creds['token']) } + if not omit_email: + headers['From'] = self._user_email - # Get user email to be added as From header and verify - user_email = creds['email_from'] - if not self._user_verify(user_email, False): - LOGGER.error('Could not verify header From: %s, %s', user_email, self.__service__) - return False + return headers - # Add From to the headers after verifying - self._headers['From'] = user_email + def _get_escalation_policies_url(self): + return '{base_url}/escalation_policies'.format(base_url=self._base_url) - # Cache default escalation policy - self._escalation_policy_id = creds['escalation_policy_id'] + def _get_priorities_url(self): + return '{base_url}/priorities'.format(base_url=self._base_url) - # Extracting context data to assign the incident - publication = alert.publish_for(self, descriptor) + def _get_incidents_url(self): + return '{base_url}/incidents'.format(base_url=self._base_url) - rule_context = alert.context - if rule_context: - rule_context = rule_context.get(self.__service__, {}) + def _get_incident_url(self, incident_id): + return '{incidents_url}/{incident_id}'.format( + incidents_url=self._get_incidents_url(), + incident_id=incident_id + ) - # Presentation defaults - default_incident_title = 'StreamAlert Incident - Rule triggered: {}'.format( - alert.rule_name + def _get_incident_merge_url(self, incident_id): + return '{incident_url}/merge'.format(incident_url=self._get_incident_url(incident_id)) + + def _get_incident_notes_url(self, incident_id): + return '{incident_url}/notes'.format(incident_url=self._get_incident_url(incident_id)) + + def _get_users_url(self): + return '{base_url}/users'.format(base_url=self._base_url) + + +class PagerDutyEventsV2ApiClient(SslVerifiable): + """Service for finding URLs of various resources on the Events v2 API + + Documentation on Events v2 API: https://v2.developer.pagerduty.com/docs/events-api-v2 + """ + + EVENTS_V2_API_ENQUEUE_ENDPOINT = 'https://events.pagerduty.com/v2/enqueue' + + def __init__(self, http_provider, enqueue_endpoint=None): + super(PagerDutyEventsV2ApiClient, self).__init__() + + self._http_provider = http_provider # type: JsonHttpProvider + self._enqueue_endpoint = ( + enqueue_endpoint if enqueue_endpoint else self.EVENTS_V2_API_ENQUEUE_ENDPOINT ) - default_incident_body = { - 'type': 'incident_body', - 'details': alert.rule_description, - } - # Override presentation defaults with publisher fields - incident_title = publication.get( - 'pagerduty-incident.incident_title', - default_incident_title + def enqueue_event(self, event_data): + """Enqueues a new event. + + Returns the event json representation on success, or False on failure. + + Note: For API v2, all authentication information is baked directly into the event_data, + rather than being provided in the headers. 
+ """ + event = self._http_provider.post( + self._get_event_enqueue_v2_url(), + event_data, + headers=None, + verify=self._should_do_ssl_verify() ) - incident_body = publication.get('pagerduty-incident.incident_body', default_incident_body) + self._update_ssl_verified(event) - # FIXME (derek.wang) use publisher to set priority instead of context - # Use the priority provided in the context, use it or the incident will be low priority - incident_priority = self._priority_verify(rule_context) + return event - # FIXME (derek.wang) use publisher to set priority instead of context - # Incident assignment goes in this order: - # Provided user -> provided policy -> default policy - assigned_key, assigned_value = self._incident_assignment(rule_context) + def _get_event_enqueue_v2_url(self): + if self._enqueue_endpoint: + return self._enqueue_endpoint - # Using the service ID for the PagerDuty API - incident_service = {'id': creds['service_id'], 'type': 'service_reference'} - incident_data = { - 'incident': { - 'type': 'incident', - 'title': incident_title, - 'service': incident_service, - 'priority': incident_priority, - 'body': incident_body, - assigned_key: assigned_value - } - } - incidents_url = self._get_endpoint(self._base_url, self.INCIDENTS_ENDPOINT) + return '{}'.format(self.EVENTS_V2_API_ENQUEUE_ENDPOINT) - try: - incident = self._post_request_retry(incidents_url, incident_data, self._headers, True) - except OutputRequestFailure: - incident = False - if not incident: - LOGGER.error('Could not create main incident, %s', self.__service__) - return False +class PagerDutyEventsV1ApiClient(SslVerifiable): + """Service for finding URLs of various resources on the Events v1 API - # Extract the json blob from the response, returned by self._post_request_retry - incident_json = incident.json() - if not incident_json: - return False + API Documentation can be found here: https://v2.developer.pagerduty.com/docs/events-api + """ - # Extract the incident id from the incident that was just created - incident_id = incident_json.get('incident', {}).get('id') + EVENTS_V1_API_ENDPOINT = 'https://events.pagerduty.com/generic/2010-04-15/create_event.json' - # Create alert to hold all the incident details - with_record = rule_context.get('with_record', True) - event_data = events_v2_data(self, descriptor, alert, creds['integration_key'], with_record) - event = self._create_event(event_data) - if not event: - LOGGER.error('Could not create incident event, %s', self.__service__) - return False + EVENT_TYPE_TRIGGER = 'trigger' + EVENT_TYPE_ACKNOWLEDGE = 'acknowledge' + EVENT_TYPE_RESOLVE = 'resolve' - # Lookup the incident_key returned as dedup_key to get the incident id - incident_key = event.get('dedup_key') + CLIENT_STREAMALERT = 'streamalert' - if not incident_key: - LOGGER.error('Could not get incident key, %s', self.__service__) - return False + def __init__(self, service_key, http_provider, api_endpoint=None): + super(PagerDutyEventsV1ApiClient, self).__init__() - # Keep that id to be merged later with the created incident - event_incident_id = self._get_event_incident_id(incident_key) + self._service_key = service_key + self._http_provider = http_provider # type: JsonHttpProvider + self._api_endpoint = api_endpoint if api_endpoint else self.EVENTS_V1_API_ENDPOINT - # Merge the incident with the event, so we can have a rich context incident - # assigned to a specific person, which the PagerDuty REST API v2 does not allow - merging_url = '{}/{}/merge'.format(incidents_url, incident_id) - merged = 
self._merge_incidents(merging_url, event_incident_id) + def send_event(self, incident_description, incident_details, contexts, client_url=''): + """ + Args: + incident_description (str): The title of the alert + incident_details (dict): Arbitrary JSON object that is rendered in custom details field + contexts (array): Array of context dicts, which can be used to embed links or images. + client_url (string): An external URL that appears as a link on the event. - # Add a note to the combined incident to help with triage - if not merged: - LOGGER.error('Could not add note to incident, %s', self.__service__) - else: - merged_id = merged.get('incident', {}).get('id') - note = rule_context.get('note', 'Creating SOX Incident') - self._add_incident_note(merged_id, note) + Return: + dict: The JSON representation of the created event + """ + # Structure of body: https://v2.developer.pagerduty.com/docs/trigger-events + data = { + 'service_key': self._service_key, + 'event_type': self.EVENT_TYPE_TRIGGER, + + 'description': incident_description, + 'details': incident_details, + 'client': self.CLIENT_STREAMALERT, + 'client_url': client_url, + 'contexts': contexts, + } + result = self._http_provider.post( + self._api_endpoint, + data, + headers=None, + verify=self._should_do_ssl_verify() + ) + self._update_ssl_verified(result) - return True + return result diff --git a/stream_alert/alert_processor/outputs/phantom.py b/stream_alert/alert_processor/outputs/phantom.py index 1ce020748..cbc0b639a 100644 --- a/stream_alert/alert_processor/outputs/phantom.py +++ b/stream_alert/alert_processor/outputs/phantom.py @@ -16,6 +16,7 @@ from collections import OrderedDict import os +from stream_alert.alert_processor.helpers import compose_alert from stream_alert.alert_processor.outputs.output_base import ( OutputDispatcher, OutputProperty, @@ -135,6 +136,10 @@ def _setup_container(cls, rule_name, rule_description, base_url, headers): def _dispatch(self, alert, descriptor): """Send alert to Phantom + Publishing: + By default this output sends the current publication in as JSON to Phantom. + There is no "magic" field to "override" it: Simply publish what you want to send! + Args: alert (Alert): Alert instance which triggered a rule descriptor (str): Output descriptor @@ -146,7 +151,7 @@ def _dispatch(self, alert, descriptor): if not creds: return False - publication = alert.publish_for(self, descriptor) + publication = compose_alert(alert, self, descriptor) record = alert.record headers = {"ph-auth-token": creds['ph_auth_token']} diff --git a/stream_alert/alert_processor/outputs/slack.py b/stream_alert/alert_processor/outputs/slack.py index fab4b5ea4..5865713aa 100644 --- a/stream_alert/alert_processor/outputs/slack.py +++ b/stream_alert/alert_processor/outputs/slack.py @@ -16,6 +16,7 @@ import cgi from collections import OrderedDict +from stream_alert.alert_processor.helpers import compose_alert, elide_string_middle from stream_alert.alert_processor.outputs.output_base import ( OutputDispatcher, OutputProperty, @@ -67,11 +68,11 @@ def _split_attachment_text(cls, alert_record): """Yield messages that should be sent to slack. 
Args: - alert_record (dict): Dictionary represntation of the alert + alert_record (dict): Dictionary representation of the alert relevant to the triggered rule Yields: - str: Properly split messages to be sent as attachemnts to slack + str: Properly split messages to be sent as attachments to slack """ # Convert the alert we have to a nicely formatted string for slack alert_text = '\n'.join(cls._json_to_slack_mrkdwn(alert_record, 0)) @@ -100,23 +101,25 @@ def _split_attachment_text(cls, alert_record): alert_text = alert_text[index+1:] @classmethod - def _format_attachments(cls, alert_publication, header_text): + def _format_default_attachments(cls, alert, alert_publication, fallback_text): """Format the message to be sent to slack. Args: + alert (Alert): The alert alert_publication (dict): Alert relevant to the triggered rule - header_text (str): A formatted rule header to be included with each - attachemnt as fallback text (to be shown on mobile, etc) + fallback_text (str): A formatted rule header to be included with each + attachment as fallback text (to be shown on mobile, etc) - Yields: - dict: A dictionary with the formatted attachemnt to be sent to Slack + Returns: + list(dict): A list of dictionaries with the formatted attachment to be sent to Slack as the text """ - record = alert_publication.get('record', {}) - rule_description = alert_publication.get('rule_description', '') + record = alert.record + rule_description = alert.rule_description messages = list(cls._split_attachment_text(record)) + attachments = [] for index, message in enumerate(messages, start=1): title = 'Record:' if len(messages) > 1: @@ -127,27 +130,175 @@ def _format_attachments(cls, alert_publication, header_text): rule_desc = rule_description rule_desc = '*Rule Description:*\n{}\n'.format(rule_desc) - # Yield this attachemnt to be sent with the list of attachments - yield { - 'fallback': header_text, + # https://api.slack.com/docs/message-attachments#attachment_structure + attachments.append({ + 'fallback': fallback_text, 'color': '#b22222', 'pretext': rule_desc, 'title': title, 'text': message, 'mrkdwn_in': ['text', 'pretext'] - } + }) if index == cls.MAX_ATTACHMENTS: LOGGER.warning('%s: %d-part message truncated to %d parts', alert_publication, len(messages), cls.MAX_ATTACHMENTS) break + return attachments + + @classmethod + def _get_attachment_skeleton(cls): + """Returns a skeleton for a Slack attachment containing various default values. + + Return: + dict + """ + return { + # String + # Plaintext summary of the attachment; renders in non-markdown compliant clients, + # such as push notifications. + 'fallback': '', + + # String, hex color + # Colors the vertical bar to the left of the text. + 'color': '#36a64f', + + # String + # Text that appears above the vertical bar to the left of the attachment. + # Supports markdown if it's included in "mrkdwn_in" + 'pretext': '', + + # String + # The attachment's author name. + # If this field is omitted, then the entire author row is omitted. + 'author_name': '', + + # String, URL + # Provide a URL; Adds a clickable link to the author name + 'author_link': '', + + # String, URL of an image + # The icon appears to the left of the author name + 'author_icon': '', + + # String + # Appears as bold text above the attachment itself. + # If this field is omitted, the entire title row is omitted. 
+            'title': '',
+
+            # String, URL
+            # Adds a clickable link to the title
+            'title_link': '',
+
+            # String
+            # Raw text that appears in the attachment, below the title but above the fields
+            # Supports markdown if it's included in "mrkdwn_in".
+            # Use \n for newline characters.
+            # This field has a field limit of cls.MAX_MESSAGE_SIZE
+            'text': '',
+
+            # Array of dicts; Each dict should have keys "title", "value", "short"
+            # An array of fields that appears below the text. These fields are clearly delineated
+            # with title and value.
+            'fields': [
+                # Sample field:
+                # {
+                #     "title": "Priority",
+                #     "value": "High",
+                #     "short": False
+                # }
+            ],
+
+            # String, URL of an image
+            # Large image that appears as an attachment
+            'image_url': '',
+
+            # String, URL of an image
+            # When image_url is omitted, this one renders a smaller image to the right
+            'thumb_url': '',
+
+            # String
+            # Appears at the very bottom
+            # If this field is omitted, also omits the footer icon
+            'footer': '',
+
+            # String, URL
+            # This icon appears to the left of the footer
+            'footer_icon': '',
+
+            # Integer, Unix timestamp
+            # This will show up next to the footer at the bottom.
+            # This timestamp does not change the time the message is actually sent.
+            'ts': '',
+
+            # List of strings
+            # Defines which of the above fields will support Slack's simple markdown (with special
+            # characters like *, ~, _, `, or ```... etc)
+            # By default, we respect markdown in "text" and "pretext"
+            "mrkdwn_in": [
+                'text',
+                'pretext',
+            ],
+        }
+
+    @classmethod
+    def _standardize_custom_attachments(cls, custom_slack_attachments):
+        """Supplies default fields to given attachments and validates their structure.
+
+        You can test out custom attachments using this tool:
+        https://api.slack.com/docs/messages/builder
+
+        When publishers provide custom Slack attachments to the SlackOutput, this offers increased
+        flexibility but requires more work. Publishers need to pay attention to the following:
+
+        - Slack requires escaping the characters: '&', '>' and '<'
+        - Slack messages have a limit of 4000 characters
+        - Individual slack messages support a maximum of 20 attachments
+
+        Args:
+            custom_slack_attachments (list): A list of dicts provided by the publisher.
+
+        Returns:
+            list: The value for the "attachments" Slack API argument
+        """
+
+        attachments = []
+
+        for custom_slack_attachment in custom_slack_attachments:
+            attachment = cls._get_attachment_skeleton()
+            attachment.update(custom_slack_attachment)
+
+            # Enforce maximum text length; make sure to check size AFTER escaping in case
+            # extra escape characters push it over the limit
+            if len(attachment['text']) > cls.MAX_MESSAGE_SIZE:
+                LOGGER.warning(
+                    'Custom attachment was truncated to length %d. Full message: %s',
+                    cls.MAX_MESSAGE_SIZE,
+                    attachment['text']
+                )
+                attachment['text'] = elide_string_middle(attachment['text'], cls.MAX_MESSAGE_SIZE)
+
+            attachments.append(attachment)
+
+            # Enforce maximum number of attachments
+            if len(attachments) >= cls.MAX_ATTACHMENTS:
+                LOGGER.warning(
+                    'Message with %d custom attachments was truncated to %d attachments',
+                    len(custom_slack_attachments),
+                    cls.MAX_ATTACHMENTS
+                )
+                break
+
+        return attachments
+
     @classmethod
-    def _format_message(cls, rule_name, alert_publication):
+    def _format_message(cls, alert, alert_publication):
         """Format the message to be sent to slack.
Args: - rule_name (str): The name of the rule that triggered the alert + alert (Alert): The alert alert_publication (dict): Alert relevant to the triggered rule Returns: @@ -160,8 +311,17 @@ def _format_message(cls, rule_name, alert_publication): Record (Part 1 of 2): ... """ - header_text = '*StreamAlert Rule Triggered: {}*'.format(rule_name) - attachments = list(cls._format_attachments(alert_publication, header_text)) + default_header_text = '*StreamAlert Rule Triggered: {}*'.format(alert.rule_name) + header_text = alert_publication.get('@slack.text', default_header_text) + + if '@slack.attachments' in alert_publication: + attachments = cls._standardize_custom_attachments( + alert_publication.get('@slack.attachments') + ) + else: + # Default attachments + attachments = cls._format_default_attachments(alert, alert_publication, header_text) + full_message = { 'text': header_text, 'mrkdwn': True, @@ -257,6 +417,25 @@ def _json_list_to_text(cls, json_values, tab, indent_count): def _dispatch(self, alert, descriptor): """Send alert text to Slack + Publishing: + By default the slack output sends a slack message comprising some default intro text + and a series of attachments containing: + * alert description + * alert record, chunked into pieces if it's too long + + To override this behavior use the following fields: + + - @slack.text (str): + Replaces the text that appears as the first line in the slack message. + + - @slack.attachments (list[dict]): + A list of individual slack attachments to include in the message. Each + element of this list is a dict that must adhere to the syntax of attachments + on Slack's API. + + @see cls._standardize_custom_attachments() for some insight into how individual + attachments can be written. + Args: alert (Alert): Alert instance which triggered a rule descriptor (str): Output descriptor @@ -268,10 +447,9 @@ def _dispatch(self, alert, descriptor): if not creds: return False - publication = alert.publish_for(self, descriptor) - rule_name = publication.get('rule_name', '') + publication = compose_alert(alert, self, descriptor) - slack_message = self._format_message(rule_name, publication) + slack_message = self._format_message(alert, publication) try: self._post_request_retry(creds['url'], slack_message) diff --git a/stream_alert/rules_engine/rules_engine.py b/stream_alert/rules_engine/rules_engine.py index 1196fa3e0..e04d73844 100644 --- a/stream_alert/rules_engine/rules_engine.py +++ b/stream_alert/rules_engine/rules_engine.py @@ -16,15 +16,17 @@ from datetime import datetime, timedelta from os import environ as env +from stream_alert.shared.publisher import AlertPublisherRepository from stream_alert.rules_engine.alert_forwarder import AlertForwarder from stream_alert.rules_engine.threat_intel import ThreatIntel from stream_alert.shared import resources, RULES_ENGINE_FUNCTION_NAME as FUNCTION_NAME from stream_alert.shared.alert import Alert from stream_alert.shared.config import load_config -from stream_alert.shared.rule import import_folders, Rule +from stream_alert.shared.importer import import_folders from stream_alert.shared.logger import get_logger from stream_alert.shared.lookup_tables import LookupTables from stream_alert.shared.metrics import MetricLogger +from stream_alert.shared.rule import Rule from stream_alert.shared.rule_table import RuleTable from stream_alert.shared.stats import get_rule_stats @@ -178,27 +180,22 @@ def _rule_analysis(self, payload, rule): """Analyze a record with the rule, adding a new alert if applicable Args: - record 
(dict): Record to perform rule analysis against + payload (dict): Representation of event to perform rule analysis against rule (rule.Rule): Attributes for the rule which triggered the alert """ rule_result = rule.process(payload['record']) if not rule_result: return - # Check if the rule is staged and, if so, only use the required alert outputs - if rule.is_staged(self._rule_table): - all_outputs = self._required_outputs_set - else: # Otherwise, combine the required alert outputs with the ones for this rule - all_outputs = self._required_outputs_set.union(rule.outputs_set) - alert = Alert( - rule.name, payload['record'], all_outputs, + rule.name, payload['record'], self._configure_outputs(rule), cluster=payload['cluster'], context=rule.context, log_source=payload['log_schema_type'], log_type=payload['data_type'], merge_by_keys=rule.merge_by_keys, merge_window=timedelta(minutes=rule.merge_window_mins), + publishers=self._configure_publishers(rule), rule_description=rule.description, source_entity=payload['resource'], source_service=payload['service'], @@ -211,6 +208,129 @@ def _rule_analysis(self, payload, rule): return alert + def _configure_outputs(self, rule): + # Check if the rule is staged and, if so, only use the required alert outputs + if rule.is_staged(self._rule_table): + all_outputs = self._required_outputs_set + else: # Otherwise, combine the required alert outputs with the ones for this rule + all_outputs = self._required_outputs_set.union(rule.outputs_set) + + return all_outputs + + @classmethod + def _configure_publishers(cls, rule): + """Assigns publishers to each output. + + The @Rule publisher syntax accepts several formats, including a more permissive blanket + option. + + In this configuration we DELIBERATELY do not include required_outputs as required outputs + should never have their alerts transformed. + + Args: + rule (Rule): The rule to create publishers for + + Returns: + dict: Maps string outputs names to lists of strings of their fully qualified publishers + """ + requested_outputs = rule.outputs_set + requested_publishers = rule.publishers + if not requested_publishers: + return None + + configured_publishers = {} + + for output in requested_outputs: + assigned_publishers = [] + + if cls.is_publisher_declaration(requested_publishers): + # Case 1: The publisher is a single string. + # apply this single publisher to all outputs + descriptors + cls.add_publisher(requested_publishers, assigned_publishers) + elif isinstance(requested_publishers, list): + # Case 2: The publisher is an array of strings. + # apply all publishers to all outputs + descriptors + cls.add_publishers(requested_publishers, assigned_publishers) + elif isinstance(requested_publishers, dict): + # Case 3: The publisher is a dict mapping output strings -> strings or list of + # strings. Apply only publishers under a matching output key. + # + # We look under 2 keys: + # - [Output]: Applies publishers to all outputs for a specific output type. + # - [Output+Descriptor]: Applies publishers only to the specific output that + # exactly matches the output+descriptor key. 
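+                # Illustrative example of this dict form (hypothetical rule declaration; the
+                # 'some_publisher' path is a placeholder, not a real publisher in this repo):
+                #
+                #   @rule(
+                #       outputs=['slack:alerts', 'pagerduty:csirt'],
+                #       publishers={
+                #           'slack': 'publishers.community.generic.some_publisher',
+                #           'pagerduty:csirt': [
+                #               'publishers.community.pagerduty.pagerduty_layout.ShortenTitle'
+                #           ],
+                #       }
+                #   )
+                #
+                # Here the 'slack' entry applies to every slack:* output, while the
+                # 'pagerduty:csirt' entry applies only to that exact output+descriptor.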
+ output_service = output.split(':')[0] + + # Order is important here; We load output-specific publishers first + if output_service in requested_publishers: + specific_publishers = requested_publishers[output_service] + if cls.is_publisher_declaration(specific_publishers): + cls.add_publisher(specific_publishers, assigned_publishers) + elif isinstance(specific_publishers, list): + cls.add_publishers(specific_publishers, assigned_publishers) + + # Then we load the output+descriptor-specific publishers second + if output in requested_publishers: + specific_publishers = requested_publishers[output] + if cls.is_publisher_declaration(specific_publishers): + cls.add_publisher(specific_publishers, assigned_publishers) + elif isinstance(specific_publishers, list): + cls.add_publishers(specific_publishers, assigned_publishers) + else: + LOGGER.error('Invalid publisher argument: %s', requested_publishers) + + configured_publishers[output] = assigned_publishers + + return configured_publishers + + @classmethod + def standardize_publisher_list(cls, list_of_references): + """Standardizes a list of requested publishers""" + publisher_names = [cls.standardize_publisher_name(x) for x in list_of_references] + + # Filter out None from the array + return [x for x in publisher_names if x is not None] + + @classmethod + def standardize_publisher_name(cls, string_or_reference): + """Standardizes a requested publisher into a string name + + Requested publishers can be either the fully qualified string name, OR it can be a + direct reference to the function or class. + """ + if not cls.is_publisher_declaration(string_or_reference): + LOGGER.error('Invalid publisher requested: %s', string_or_reference) + return None + + if isinstance(string_or_reference, basestring): + publisher_name = string_or_reference + else: + publisher_name = AlertPublisherRepository.get_publisher_name( + string_or_reference + ) + + if AlertPublisherRepository.has_publisher(publisher_name): + return publisher_name + + LOGGER.warning('Requested publisher named (%s) is not registered.', publisher_name) + + @classmethod + def is_publisher_declaration(cls, string_or_reference): + """Returns TRUE if the requested publisher is valid (a string name or reference)""" + return ( + isinstance(string_or_reference, basestring) or + AlertPublisherRepository.is_valid_publisher(string_or_reference) + ) + + @classmethod + def add_publisher(cls, publisher_reference, current_list): + _publisher = cls.standardize_publisher_name(publisher_reference) + current_list += [_publisher] if _publisher is not None else [] + + @classmethod + def add_publishers(cls, publisher_references, current_list): + current_list += cls.standardize_publisher_list(publisher_references) + def run(self, records): """Run rules against the records sent from the Classifier function diff --git a/stream_alert/shared/alert.py b/stream_alert/shared/alert.py index 706bc9b0b..f0a503380 100644 --- a/stream_alert/shared/alert.py +++ b/stream_alert/shared/alert.py @@ -29,8 +29,8 @@ class Alert(object): _EXPECTED_INIT_KWARGS = { 'alert_id', 'attempts', 'cluster', 'context', 'created', 'dispatched', 'log_source', - 'log_type', 'merge_by_keys', 'merge_window', 'outputs_sent', 'rule_description', - 'source_entity', 'source_service', 'staged' + 'log_type', 'merge_by_keys', 'merge_window', 'outputs_sent', 'publishers', + 'rule_description', 'source_entity', 'source_service', 'staged' } DATETIME_FORMAT = '%Y-%m-%dT%H:%M:%S.%fZ' @@ -55,6 +55,16 @@ def __init__(self, rule_name, record, outputs, **kwargs): 
                 keys are equal. Keys can be present at any depth in the record.
             merge_window (timedelta): Merged alerts are sent at this interval.
             outputs_sent (set): Subset of outputs which have sent successfully.
+            publishers (str|list|dict): A structure of strings, representing either fully qualified
+                function names, or publisher classes. Adopts one of the following formats:
+
+                - None, or empty array; DefaultPublisher is run on all outputs.
+                - Single string; One publisher is run on all outputs.
+                - List of strings; All publishers are run on all outputs in order of declaration.
+                - Dict mapping to strings or lists: The dict maps output service keys to a string
+                  or a list of strings. These strings correspond to all publishers that are run,
+                  in order, for only that specific output service.
+
             rule_description (str): Description associated with the triggering rule.
             source_entity (str): Name of location from which the record originated. E.g. "mychannel"
             source_service (str): Input type from which the record originated. E.g. "slack"
@@ -83,6 +93,7 @@ def __init__(self, rule_name, record, outputs, **kwargs):
         self.attempts = int(kwargs.get('attempts', 0)) or 0  # Convert possible Decimal to int
         self.cluster = kwargs.get('cluster') or None
         self.context = kwargs.get('context') or {}
+        self.publishers = kwargs.get('publishers') or {}
 
         # datetime.min isn't supported by strftime, so use Unix epoch instead for default value
         self.dispatched = kwargs.get('dispatched') or datetime(year=1970, month=1, day=1)
@@ -149,6 +160,7 @@ def dynamo_record(self):
             'MergeWindowMins': int(self.merge_window.total_seconds() / 60),
             'Outputs': self.outputs,
             'OutputsSent': self.outputs_sent or None,  # Empty sets not allowed by Dynamo
+            'Publishers': self.publishers or None,
             # Compact JSON encoding (no spaces). We have to JSON-encode here
             # (instead of just passing the dict) because Dynamo does not allow empty string values.
             'Record': json.dumps(self.record, separators=(',', ':'), default=list),
@@ -188,6 +200,7 @@ def create_from_dynamo_record(cls, record):
                 merge_by_keys=record.get('MergeByKeys'),
                 merge_window=timedelta(minutes=int(record.get('MergeWindowMins', 0))),
                 outputs_sent=set(record.get('OutputsSent') or []),
+                publishers=record.get('Publishers'),
                 rule_description=record.get('RuleDescription'),
                 source_entity=record.get('SourceEntity'),
                 source_service=record.get('SourceService'),
@@ -199,6 +212,10 @@ def create_from_dynamo_record(cls, record):
     def output_dict(self):
         """Convert the alert into a dictionary ready to send to an output.
 
+        (!) This method is deprecated. Going forward, try to use the method:
+
+            stream_alert.alert_processor.helpers.compose_alert
+
         Returns:
             dict: An alert dictionary for sending to outputs.
The result is JSON-compatible, backwards-compatible (existing keys are not changed), @@ -214,30 +231,16 @@ def output_dict(self): 'id': self.alert_id, 'log_source': self.log_source or '', 'log_type': self.log_type or '', - 'outputs': list(sorted(self.outputs)), # List instead of set for JSON-compatibility + 'outputs': sorted(self.outputs), # List instead of set for JSON-compatibility + 'publishers': self.publishers or {}, 'record': self.record, 'rule_description': self.rule_description or '', 'rule_name': self.rule_name or '', 'source_entity': self.source_entity or '', 'source_service': self.source_service or '', - 'staged': self.staged + 'staged': self.staged, } - def publish_for(self, output_class, descriptor): # pylint: disable=unused-argument - """Presents the current alert as a dict of information for OutputDispatchers to send. - - Args: - output_class (OutputDispatcher): The output dispatching service - descriptor (string): The output's descriptor - - Returns: - dict: A dict of published data - """ - # FIXME (derek.wang) Currently, this completely disregards the output_class and descriptor - # as this Alert entity does not yet have the "publishers" field available to determine - # how to publish itself. - return self.output_dict() - # ---------- Alert Merging ---------- def can_merge(self, other): @@ -422,6 +425,7 @@ def merge(cls, alerts): context=alerts[0].context, log_source=alerts[0].log_source, log_type=alerts[0].log_type, + publishers=alerts[0].publishers, rule_description=alerts[0].rule_description, source_entity=alerts[0].source_entity, source_service=alerts[0].source_service, diff --git a/stream_alert/shared/description.py b/stream_alert/shared/description.py new file mode 100644 index 000000000..e0634aac9 --- /dev/null +++ b/stream_alert/shared/description.py @@ -0,0 +1,164 @@ +""" +Copyright 2017-present, Airbnb Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +import re + + +class RuleDescriptionParser(object): + """Class that does fuzzy parsing information out of the rule description + + In general, rule descriptions follow a very FUZZY scheme where they are newline-delimited + and have at most one field per line (although it's possible for a single field to span + multiple lines). Each field is one or more words preceding a colon. + + Example: + + author: Derek + description: Blah lorem ipsum + bacon bleu cheese + playbook: etc + + Another possible format is to have a string description preceding the set of (optional) fields: + + Example: + + This rule is triggered when the speed hits over 9000 + + author: Derek + playbook: etc + + Additionally, certain fields can have URL values. Long URLs are split across multiple lines + but are conjoined in the final parsed product + + Example: + + author: Derek + reference: https://this.is.a.really.really/long/url + ?that=does+not+fit+on+one+line#but=gets%53eventually+smushed+together + + Lastly, by default all line breaks are stripped out. 
However, if there is a double-line-break
+    in the middle of a text, this double-line-break will appear in the final text as two newline
+    characters.
+
+    Example:
+
+        description:
+            This is paragraph 1 and remains unbroken despite having
+            a linebreak in the middle of it.
+
+            However, this paragraph 2 is broken from paragraph 1 because
+            it has a double break in between.
+    """
+
+    # Match alphanumeric, plus underscores, dashes, spaces, and & signs
+    # Labels are a maximum of 20 characters long. They also never start with http or https
+    _FIELD_REGEX = re.compile(
+        r'^(?!http:|https:)(?P<field>[a-zA-Z\d\-_&\s]{0,20}):(?P<remainder>.*)$'
+    )
+    _URL_REGEX = re.compile(
+        r'^(?:http(s)?://)?[\w.-]+(?:\.[\w\.-]+)+[\w\-\._~:/?#[\]@!\$&\'\(\)\*\+,;=.]+$'
+    )
+
+    @classmethod
+    def parse(cls, rule_description):
+        """Parses a multiline rule description string
+
+        Args:
+            rule_description (str): The rule's description
+
+        Return:
+            dict: A dict mapping fields to lists of strings, each corresponding to a line belonging
+                to that field. All field names are lowercase.
+        """
+        rule_description = '' if not rule_description else rule_description
+        tokens = [line.strip() for line in rule_description.strip().split('\n')]
+
+        field_lines = {}
+
+        current_field = 'description'
+        for token in tokens:
+            if current_field not in field_lines:
+                field_lines[current_field] = []
+
+            if not token:
+                # preserve an empty line
+                field_lines[current_field].append('')
+                continue
+
+            # Python regex does not support possessive qualifiers, which means it is not easy
+            # to write a regex that checks for \s++(?:http:) because the operator will attempt to
+            # give up characters to the negative lookahead. So, we strip the line first before
+            # doing the negative lookahead.
+            match = cls._FIELD_REGEX.match(token)
+
+            if match is not None:
+                current_field = match.group('field').strip().lower()
+                value = match.group('remainder').strip()
+            else:
+                value = token
+
+            if current_field not in field_lines:
+                field_lines[current_field] = []
+            field_lines[current_field].append(value)
+
+        return field_lines
+
+    @classmethod
+    def present(cls, rule_description):
+        def join_lines(lines):
+            if not isinstance(lines, list) or len(lines) <= 0:
+                return ''
+
+            document = None
+            buffered_newlines = ''
+            for line in lines:
+                if not line:
+                    buffered_newlines += '\n'
+                    continue
+
+                if document is None:
+                    buffered_newlines = ''
+                    document = line
+                else:
+                    match = cls._URL_REGEX.match(document + line)
+                    if match is not None:
+                        document += line
+                    else:
+                        space = buffered_newlines if buffered_newlines else ' '
+                        buffered_newlines = ''
+                        document += space + line
+
+            if document is None:
+                document = ''
+
+            return document
+
+        fragments = cls.parse(rule_description)
+
+        presentation = {
+            'author': '',
+            'description': '',
+            'fields': {},
+        }
+
+        for key, value in fragments.iteritems():
+            if key in ['author', 'maintainer']:
+                presentation['author'] = join_lines(value)
+            elif key in ['description']:
+                presentation['description'] = join_lines(value)
+            else:
+                presentation['fields'][key] = join_lines(value)
+
+        return presentation
diff --git a/stream_alert/shared/importer.py b/stream_alert/shared/importer.py
new file mode 100644
index 000000000..101d59e45
--- /dev/null
+++ b/stream_alert/shared/importer.py
@@ -0,0 +1,61 @@
+"""
+Copyright 2017-present, Airbnb Inc.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+import importlib
+import os
+
+
+def import_folders(*paths):
+    """Dynamically import all rules files.
+
+    Args:
+        *paths (string): Variable length tuple of paths from within which any
+            .py files should be imported
+    """
+    for path in _python_file_paths(*paths):
+        importlib.import_module(_path_to_module(path))
+
+
+def _python_file_paths(*paths):
+    """Yields all .py files in the passed paths
+
+    Args:
+        *paths (string): Variable length tuple of paths from within which any
+            .py files should be imported
+
+    Yields:
+        str: Relative path to .py file to be imported using importlib
+    """
+    for folder in paths:
+        for root, _, files in os.walk(folder):
+            for file_name in files:
+                if file_name.endswith('.py') and not file_name.startswith('__'):
+                    yield os.path.join(root, file_name)
+
+
+def _path_to_module(path):
+    """Convert a Python rules file path to an importable module name.
+
+    For example, "rules/community/cloudtrail_critical_api_calls.py" becomes
+    "rules.community.cloudtrail_critical_api_calls"
+
+    Raises:
+        NameError if a '.' appears anywhere in the path except the file extension.
+    """
+    base_name = os.path.splitext(path)[0]
+    if '.' in base_name:
+        raise NameError('Python file "{}" cannot be imported '
+                        'because of "." in the name'.format(path))
+    return base_name.replace('/', '.')
diff --git a/stream_alert/shared/publisher.py b/stream_alert/shared/publisher.py
new file mode 100644
index 000000000..0473972fc
--- /dev/null
+++ b/stream_alert/shared/publisher.py
@@ -0,0 +1,266 @@
+"""
+Copyright 2017-present, Airbnb Inc.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+from abc import abstractmethod
+from copy import deepcopy
+from inspect import isclass
+
+from stream_alert.shared.importer import import_folders
+from stream_alert.shared.logger import get_logger
+
+LOGGER = get_logger(__name__)
+
+
+class PublisherError(Exception):
+    """Exception to raise for any errors with invalid publishers"""
+
+
+class PublisherRegistrationError(PublisherError):
+    """Exception to raise when an error occurs during the @Register step of a publisher"""
+
+
+class PublisherAssemblyError(PublisherError):
+    """Exception to raise when a publisher fails lookup or assembly"""
+
+
+class Register(object):
+    """This is a decorator used to register publishers into the AlertPublisherRepository."""
+
+    def __new__(cls, class_or_function):
+        AlertPublisherRepository.register_publisher(class_or_function)
+
+        return class_or_function  # Return the definition, not the instantiated object
+
+
+class AlertPublisher(object):
+    """Interface for a Publisher.
All class-based publishers must inherit from this class.""" + + @abstractmethod + def publish(self, alert, publication): + """Publishes the given alert. + + As a general rule of thumb, published fields that are specific to a certain output are + published as top-level keys of the following format: + + [output service name].[field name] + + E.g. "demisto.blah" + + Args: + alert (Alert): The alert instance to publish. + publication (dict): An existing publication generated by previous publishers in the + series of publishers, or {}. + + Returns: + dict: The published alert. + """ + + +class CompositePublisher(AlertPublisher): + """A publisher class that combines the logic of multiple other publishers together in series + + To reduce the chance that one publisher has side effects in other publishers in the chain, + we use deepcopy between the publishers. + + Note: This publisher is not meant to be @Register'd as it does not have any logic on its own. + It is only meant to be composed by AlertPublisherRepository to give a common interface to + multiple publishers chained in sequence. + """ + + def __init__(self, publishers): + self._publishers = publishers # Type list(AlertPublisher) + + for publisher in self._publishers: + if not isinstance(publisher, AlertPublisher): + LOGGER.error('CompositePublisher given invalid publisher') + + def publish(self, alert, publication): + for publisher in self._publishers: + try: + publication = deepcopy(publication) + publication = publisher.publish(alert, publication) + except KeyError: + LOGGER.exception( + 'CompositePublisher encountered KeyError with publisher: %s', + publisher.__class__.__name__ + ) + raise + + return publication + + +class WrappedFunctionPublisher(AlertPublisher): + """A class only used to wrap a function publisher.""" + + def __init__(self, function): + self._function = function + + def publish(self, alert, publication): + return self._function(alert, publication) + + +class AlertPublisherRepository(object): + """A repository mapping names -> publishers + + As a usability optimization, using this Repository will eagerly load and register all + publishers in the application. + """ + _PUBLISHERS_DIRECTORY = 'publishers' + _publishers = {} + _is_imported = False + + @classmethod + def import_publishers(cls): + if not cls._is_imported: + import_folders(cls._PUBLISHERS_DIRECTORY) + cls._is_imported = True + + + @staticmethod + def is_valid_publisher(thing): + """Returns TRUE if the given reference can be registered as a publisher + + Publishers are valid if and only if they fall into one of the following categories: + + * They are a python function that accepts 2 arguments: (alert: Alert, publication: dict) + * They are a python class that extends AlertPublisher + + Args: + thing (mixed): Any primitive or reference to be checked + + Returns: + bool + """ + + # We have to put the isclass() check BEFORE the callable() check because classes are also + # callable! + return issubclass(thing, AlertPublisher) if isclass(thing) else callable(thing) + + @staticmethod + def get_publisher_name(class_or_function): + """Given a class or function, will return its fully qualified name. + + This is useful for assigning a unique string name for a publisher. 
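+
+        For example (illustrative):
+
+            AlertPublisherRepository.get_publisher_name(DefaultPublisher)
+            # => 'stream_alert.shared.publisher.DefaultPublisher'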
+
+        Args:
+            class_or_function (callable|Class): A reference to a python function or class
+
+        Returns:
+            string
+        """
+        return '{}.{}'.format(class_or_function.__module__, class_or_function.__name__)
+
+    @classmethod
+    def register_publisher(cls, publisher):
+        """Registers the publisher into the repository.
+
+        To standardize the interface of publishers, if a function publisher is given, it will be
+        wrapped with a WrappedFunctionPublisher instance prior to being registered into the
+        Repository.
+
+        Args:
+            publisher (callable|AlertPublisher): An instance of a publisher class or a function
+        """
+        if not AlertPublisherRepository.is_valid_publisher(publisher):
+            error = (
+                'Could not register publisher {}; Not callable nor subclass of AlertPublisher'
+            ).format(publisher)
+            raise PublisherRegistrationError(error)
+
+        elif isclass(publisher):
+            # If the provided publisher is a Class, then we simply need to instantiate an instance
+            # of the class and register it.
+            publisher_instance = publisher()
+        else:
+            # If the provided publisher is a function, we wrap it with a WrappedFunctionPublisher
+            # to make them easier to handle.
+            publisher_instance = WrappedFunctionPublisher(publisher)
+
+        name = AlertPublisherRepository.get_publisher_name(publisher)
+
+        if name in cls._publishers:
+            error = 'Publisher with name [{}] has already been registered.'.format(name)
+            raise PublisherRegistrationError(error)
+
+        cls._publishers[name] = publisher_instance
+
+    @classmethod
+    def get_publisher(cls, name):
+        """Returns the publisher with the given name
+
+        Args:
+            name (str): The name of the publisher.
+
+        Returns:
+            AlertPublisher|None
+        """
+        if cls.has_publisher(name):
+            return cls._publishers[name]
+
+        LOGGER.error('Publisher [%s] does not exist', name)
+
+    @classmethod
+    def has_publisher(cls, name):
+        """Returns true if the given publisher name has been registered in this Repository
+        """
+        cls.import_publishers()
+        return name in cls._publishers
+
+    @classmethod
+    def all_publishers(cls):
+        """Returns all registered publishers in a dict mapping their unique name to instances.
+
+        Remember: Function publishers are wrapped with WrappedFunctionPublisher
+        Also remember: These publishers are INSTANCES of the publisher classes, not the classes
+            themselves.
+
+        Returns:
+            dict
+        """
+        return cls._publishers
+
+    @classmethod
+    def create_composite_publisher(cls, publisher_names):
+        """Assembles a single publisher that combines logic from multiple publishers
+
+        Args:
+            publisher_names (list(str)): A list of string names of publishers
+
+        Return:
+            CompositePublisher|DefaultPublisher
+        """
+        publisher_names = publisher_names or []
+        publishers = []
+
+        for publisher_name in publisher_names:
+            publisher = cls.get_publisher(publisher_name)
+            if publisher:
+                publishers.append(publisher)
+
+        if not publishers:
+            # If no publishers were given, or if all of the publishers failed to load, then we
+            # load a default publisher.
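+            #
+            # For example (illustrative): create_composite_publisher(None) or
+            # create_composite_publisher(['some.unregistered.publisher']) both fall through to
+            # here and return the registered DefaultPublisher instance, whereas a list of valid
+            # names such as ['stream_alert.shared.publisher.DefaultPublisher',
+            # 'publishers.community.generic.populate_fields'] yields a CompositePublisher that
+            # runs them in order.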
+ default_publisher_name = cls.get_publisher_name(DefaultPublisher) + return cls.get_publisher(default_publisher_name) + + return CompositePublisher(publishers) + + +@Register +class DefaultPublisher(AlertPublisher): + """The default publisher that is used when no other publishers are provided""" + + def publish(self, alert, publication): + return alert.output_dict() diff --git a/stream_alert/shared/rule.py b/stream_alert/shared/rule.py index b21474b48..1adcb4367 100644 --- a/stream_alert/shared/rule.py +++ b/stream_alert/shared/rule.py @@ -16,9 +16,7 @@ import ast from copy import deepcopy import hashlib -import importlib import inspect -import os from stream_alert.shared.logger import get_logger from stream_alert.shared.stats import time_rule @@ -27,50 +25,6 @@ LOGGER = get_logger(__name__) -def _python_file_paths(*paths): - """Yields all .py files in the passed paths - - Args: - *paths (string): Variable length tuple of paths from within which any - .py files should be imported - - Yields: - str: Relative path to .py file to me imported using importlib - """ - for folder in paths: - for root, _, files in os.walk(folder): - for file_name in files: - if file_name.endswith('.py') and not file_name.startswith('__'): - yield os.path.join(root, file_name) - - -def _path_to_module(path): - """Convert a Python rules file path to an importable module name. - - For example, "rules/community/cloudtrail_critical_api_calls.py" becomes - "rules.community.cloudtrail_critical_api_calls" - - Raises: - NameError if a '.' appears anywhere in the path except the file extension. - """ - base_name = os.path.splitext(path)[0] - if '.' in base_name: - raise NameError('Python file "{}" cannot be imported ' - 'because of "." in the name'.format(path)) - return base_name.replace('/', '.') - - -def import_folders(*paths): - """Dynamically import all rules files. 
- - Args: - *paths (string): Variable length tuple of paths from within which any - .py files should be imported - """ - for path in _python_file_paths(*paths): - importlib.import_module(_path_to_module(path)) - - class RuleCreationError(Exception): """Exception to raise for any errors with invalid rules""" @@ -105,6 +59,7 @@ def __init__(self, func, **kwargs): self.merge_by_keys = kwargs.get('merge_by_keys') self.merge_window_mins = kwargs.get('merge_window_mins') or 0 self.outputs = kwargs.get('outputs') + self.publishers = kwargs.get('publishers') self.req_subkeys = kwargs.get('req_subkeys') self.initial_context = kwargs.get('context') self.context = None diff --git a/stream_alert/shared/rule_table.py b/stream_alert/shared/rule_table.py index 22d415a1d..f9e0f28ea 100644 --- a/stream_alert/shared/rule_table.py +++ b/stream_alert/shared/rule_table.py @@ -18,8 +18,9 @@ import boto3 from stream_alert.shared.helpers.dynamodb import ignore_conditional_failure +from stream_alert.shared.importer import import_folders from stream_alert.shared.logger import get_logger -from stream_alert.shared.rule import import_folders, Rule +from stream_alert.shared.rule import Rule LOGGER = get_logger(__name__) diff --git a/stream_alert_cli/manage_lambda/package.py b/stream_alert_cli/manage_lambda/package.py index 18e144309..d3d8df193 100644 --- a/stream_alert_cli/manage_lambda/package.py +++ b/stream_alert_cli/manage_lambda/package.py @@ -203,6 +203,7 @@ class RulesEnginePackage(LambdaPackage): lambda_handler = 'stream_alert.rules_engine.main.handler' package_files = { 'conf', + 'publishers', 'rules', 'stream_alert/__init__.py', 'stream_alert/rules_engine', @@ -218,6 +219,7 @@ class AlertProcessorPackage(LambdaPackage): lambda_handler = 'stream_alert.alert_processor.main.handler' package_files = { 'conf', + 'publishers', 'stream_alert/__init__.py', 'stream_alert/alert_processor', 'stream_alert/shared' diff --git a/stream_alert_cli/test/handler.py b/stream_alert_cli/test/handler.py index e44f4d915..86b0ecb7d 100644 --- a/stream_alert_cli/test/handler.py +++ b/stream_alert_cli/test/handler.py @@ -23,9 +23,11 @@ import time import zlib -from mock import patch +from mock import patch, MagicMock from stream_alert.alert_processor import main as alert_processor +from stream_alert.alert_processor.helpers import compose_alert +from stream_alert.alert_processor.outputs.output_base import OutputDispatcher from stream_alert.classifier import classifier from stream_alert.classifier.parsers import ParserBase from stream_alert.rules_engine import rules_engine @@ -111,9 +113,14 @@ def _run_rules_engine(self, record): with patch.object(rules_engine.ThreatIntel, '_query') as ti_mock, \ patch.object(rules_engine.LookupTables, 'load_lookup_tables') as lt_mock, \ patch.object(rules_engine, 'AlertForwarder'), \ - patch.object(rules_engine, 'RuleTable'), \ + patch.object(rules_engine, 'RuleTable') as rule_table, \ patch('rules.helpers.base.random_bool', return_value=True): + # Emptying out the rule table will force all rules to be unstaged, which causes + # non-required outputs to get properly populated on the Alerts that are generated + # when running the Rules Engine. 
+ rule_table.return_value = False + ti_mock.side_effect = self._threat_intel_mock # pylint: disable=protected-access @@ -142,7 +149,7 @@ def _finalize(self): format_underline('\nSummary:\n'), 'Total Tests: {}'.format(self._passed + self._failed), format_green('Pass: {}'.format(self._passed)) if self._passed else 'Pass: 0', - format_red('Fail: {}\n'.format(self._failed)) if self._failed else 'Fail: 0\n' + format_red('Fail: {}\n'.format(self._failed)) if self._failed else 'Fail: 0\n', ] print('\n'.join(summary)) @@ -218,6 +225,11 @@ def run(self): alerts = self._run_rules_engine(classifier_result[0].sqs_messages) test_result.alerts = alerts + if not original_event.get('skip_publishers'): + for alert in alerts: + publication_results = self._run_publishers(alert) + test_result.set_publication_results(publication_results) + if self._type == self.Types.LIVE: for alert in alerts: alert_result = self._run_alerting(alert) @@ -236,6 +248,38 @@ def run(self): return self._failed == 0 + @staticmethod + def _run_publishers(alert): + """Runs publishers for all currently configured outputs on the given alert + + Args: + - alert (Alert): The alert + + Returns: + dict: A dict keyed by output:descriptor strings, mapped to nested dicts. + The nested dicts have 2 keys: + - publication (dict): The dict publication + - success (bool): True if the publishing finished, False if it errored. + """ + configured_outputs = alert.outputs + + results = {} + for configured_output in configured_outputs: + [output_name, descriptor] = configured_output.split(':') + + try: + output = MagicMock(spec=OutputDispatcher, __service__=output_name) + results[configured_output] = { + 'publication': compose_alert(alert, output, descriptor), + 'success': True, + } + except (RuntimeError, TypeError, NameError) as err: + results[configured_output] = { + 'success': False, + 'error': err, + } + return results + def _get_test_files(self): """Helper to get rule files to be tested diff --git a/stream_alert_cli/test/results.py b/stream_alert_cli/test/results.py index 729321354..b9aa1f190 100644 --- a/stream_alert_cli/test/results.py +++ b/stream_alert_cli/test/results.py @@ -61,34 +61,45 @@ class TestResult(object): """TestResult contains information useful for tracking test results""" _NONE_STRING = '' - _SIMPLE_TEMPLATE = '{header}: {status}' - _VERBOSE_TEMPLATE = ( + _PASS_STRING = format_green('Pass') + _FAIL_STRING = format_red('Fail') + _SIMPLE_TEMPLATE = '{header}:' + _PASS_TEMPLATE = '{header}: {pass}' + _DESCRIPTION_LINE = ( ''' - Description: {description} - Classified Type: {classified_type} - Expected Type: {expected_type}''' + Description: {description}''' ) - - _VALIDATION_ONLY = ( + _CLASSIFICATION_STATUS_TEMPLATE = ( ''' - Validation Only: True''' + Classification: {classification_status} + Classified Type: {classified_type} + Expected Type: {expected_type}''' ) - - _RULES_TEMPLATE = ( + _RULES_STATUS_TEMPLATE = ( ''' - Triggered Rules: {triggered_rules} - Expected Rules: {expected_rules}''' + Rules: {rules_status} + Triggered Rules: {triggered_rules} + Expected Rules: {expected_rules}''' ) - _DISABLED_RULES_TEMPLATE = ( ''' - Disabled Rules: {disabled_rules}''' + Disabled Rules: {disabled_rules}''' + ) + _PUBLISHERS_STATUS_TEMPLATE = ( + ''' + Publishers: {publishers_status} + Errors: +{publisher_errors}''' + ) + _VALIDATION_ONLY = ( + ''' + Validation Only: True''' ) - _ALERTS_TEMPLATE = ( ''' - Sent Alerts: {sent_alerts} - Failed Alerts: {failed_alerts}''' + Live Alerts: + Sent Alerts: {sent_alerts} + Failed Alerts: 
{failed_alerts}''' ) _DEFAULT_INDENT = 4 @@ -99,6 +110,7 @@ def __init__(self, index, test_event, classified_result, with_rules=False, verbo self._with_rules = with_rules self._verbose = verbose self._live_test_results = {} + self._publication_results = {} self.alerts = [] def __nonzero__(self): @@ -108,23 +120,24 @@ def __nonzero__(self): __bool__ = __nonzero__ def __str__(self): - - # Store the computed property - passed = self.passed - fmt = { 'header': 'Test #{idx:02d}'.format(idx=self._idx + 1), - 'status': format_green('Pass') if passed else format_red('Fail') } + if self.passed and not self._verbose: + # Simply render "Test #XYZ: Pass" if the whole test case passed + template = self._PASS_TEMPLATE + fmt['pass'] = self._PASS_STRING + return template.format(**fmt) + + # Otherwise, expand the entire test with verbose details + template = self._SIMPLE_TEMPLATE + '\n' + self._DESCRIPTION_LINE + fmt['description'] = self._test_event['description'] - if passed and not self._verbose: - return self._SIMPLE_TEMPLATE.format(**fmt) - - template = '{}{}'.format( - self._SIMPLE_TEMPLATE.rjust(len(self._SIMPLE_TEMPLATE) + self._DEFAULT_INDENT * 2), - self._VERBOSE_TEMPLATE + # First, render classification + template += '\n' + self._CLASSIFICATION_STATUS_TEMPLATE + fmt['classification_status'] = ( + self._PASS_STRING if self.classification_tests_passed else self._FAIL_STRING ) - fmt['description'] = self._test_event['description'] fmt['expected_type'] = self._test_event['log'] fmt['classified_type'] = ( self._classified_result.log_schema_type @@ -134,11 +147,16 @@ def __str__(self): ) ) - if self._test_event.get('validate_schema_only'): - line = 'Validation Only: True' - template += '\n' + line.rjust(len(line) + self._DEFAULT_INDENT * 3) - elif self._with_rules: - template += self._RULES_TEMPLATE + # If it was classification-only, note it down + if self.validate_schema_only: + template += self._VALIDATION_ONLY + + # Render the result of rules engine run + if self.rule_tests_were_run: + template += '\n' + self._RULES_STATUS_TEMPLATE + fmt['rules_status'] = ( + self._PASS_STRING if self.rule_tests_passed else self._FAIL_STRING + ) fmt['triggered_rules'] = self._format_rules( self._triggered_rules, self.expected_rules @@ -154,10 +172,34 @@ def __str__(self): template += self._DISABLED_RULES_TEMPLATE fmt['disabled_rules'] = ', '.join(disabled) - if self._live_test_results: + # Render live test results + if self.has_live_tests: template += self._ALERTS_TEMPLATE fmt['sent_alerts'], fmt['failed_alerts'] = self._format_alert_results() + # Render any publisher errors + if self.publisher_tests_were_run: + template += '\n' + self._PUBLISHERS_STATUS_TEMPLATE + + num_pass = 0 + num_total = 0 + for _, result in self._publication_results.iteritems(): + num_total += 1 + num_pass += 1 if result['success'] else 0 + fmt['publishers_status'] = ( + format_green('{}/{} Passed'.format(num_pass, num_total)) + if num_pass == num_total + else format_red('{}/{} Passed'.format(num_pass, num_total)) + ) + pad = ' ' * self._DEFAULT_INDENT * 3 + fmt['publisher_errors'] = ( + format_red('\n'.join([ + '{}{}'.format(pad, error) for error in self.publisher_errors + ])) + if self.publisher_errors + else '{}{}'.format(pad, self._NONE_STRING) + ) + return textwrap.dedent(template.format(**fmt)).rstrip() + '\n' __repr__ = __str__ @@ -241,6 +283,104 @@ def _alert_result_block(self, values, failed=False): return self._NONE_STRING if not result_block else '\n{}'.format('\n'.join(result_block)) + @property + def validate_schema_only(self): 
+        """Returns True if the testcase only requires classification and skips rules"""
+        return self._test_event.get('validate_schema_only')
+
+    @property
+    def skip_publishers(self):
+        """Returns True if the testcase skips running publisher tests"""
+        return self._test_event.get('skip_publishers')
+
+    @property
+    def rule_tests_were_run(self):
+        """Returns True if this testcase ran Rules Engine tests"""
+        return not self.validate_schema_only and self._with_rules
+
+    @property
+    def publisher_tests_were_run(self):
+        """Returns True if this test ran Publisher tests for each output"""
+        return (
+            self.rule_tests_were_run
+            and not self.skip_publishers
+            and self._publication_results
+        )
+
+    @property
+    def classification_tests_passed(self):
+        """Returns True if all classification tests passed"""
+        return self._classified
+
+    @property
+    def rule_tests_passed(self):
+        """Returns True if all rules engine tests passed
+
+        Also returns False if the rules engine tests were not run
+        """
+        return self.rule_tests_were_run and (self._triggered_rules == self.expected_rules)
+
+    @property
+    def has_live_tests(self):
+        """Returns True if this testcase ran any live tests"""
+        return self._live_test_results
+
+    @property
+    def live_tests_passed(self):
+        """Returns True if all live tests passed
+
+        Also returns False if live tests were not run
+        """
+        if not self.has_live_tests:
+            return False
+        for result in self._live_test_results.itervalues():
+            if not all(status for status in result.itervalues()):
+                return False
+        return True
+
+    @property
+    def publisher_tests_passed(self):
+        """Returns True if all publisher tests passed
+
+        Also returns False if publisher tests were not run
+        """
+        if not self.publisher_tests_were_run:
+            return False
+
+        for _, result in self._publication_results.iteritems():
+            if not result['success']:
+                return False
+
+        return True
+
+    @property
+    def publisher_errors(self):
+        """Returns an array of strings describing errors in the publisher tests
+
+        The strings take the form:
+
+            [output:descriptor]: (Error Type) Error message
+        """
+        if not self.publisher_tests_were_run:
+            return []
+
+        return [
+            "{}: ({}) {}".format(output_descriptor, type(item['error']).__name__, item['error'])
+            for output_descriptor, item
+            in self._publication_results.iteritems()
+            if not item['success']
+        ]
+
+    @property
+    def count_publisher_tests_passed(self):
+        """Returns the number of publisher tests that passed"""
+        return sum(1 for _, result in self._publication_results.iteritems() if result['success'])
+
+    @property
+    def count_publisher_tests_run(self):
+        """Returns the total number of publisher tests"""
+        return len(self._publication_results)
+
     @property
     def passed(self):
         """A test has passed if it meets the following criteria:
@@ -249,23 +389,25 @@ def passed(self):
             2) If rules are being tested, all triggered rules match expected rules
             3) If a live test is being performed, all alerts sent to outputs successfully
         """
-        if not self._classified:
+        if not self.classification_tests_passed:
             return False
 
-        if self._test_event.get('validate_schema_only'):
-            return True
-
-        if not self._with_rules:
-            return True
+        if self.rule_tests_were_run:
+            if not self.rule_tests_passed:
+                return False
 
-        if not self._triggered_rules == self.expected_rules:
-            return False
+        if self.has_live_tests:
+            if not self.live_tests_passed:
+                return False
 
-        for result in self._live_test_results.itervalues():
-            if not all(status for status in result.itervalues()):
+        if self.publisher_tests_were_run:
+            if not
self.publisher_tests_passed: return False return True + def set_publication_results(self, publication_results): + self._publication_results = publication_results + def add_live_test_result(self, rule_name, result): self._live_test_results[rule_name] = result diff --git a/tests/integration/rules/duo/duo_fraud.json b/tests/integration/rules/duo/duo_fraud.json index f5de12b5a..003734591 100644 --- a/tests/integration/rules/duo/duo_fraud.json +++ b/tests/integration/rules/duo/duo_fraud.json @@ -65,4 +65,4 @@ "source": "prefix_cluster_duo_auth_sm-app-name_app", "trigger_rules": [] } -] \ No newline at end of file +] diff --git a/tests/unit/publishers/__init__.py b/tests/unit/publishers/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/unit/publishers/community/__init__.py b/tests/unit/publishers/community/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/unit/publishers/community/pagerduty/__init__.py b/tests/unit/publishers/community/pagerduty/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/unit/publishers/community/pagerduty/test_pagerduty_layout.py b/tests/unit/publishers/community/pagerduty/test_pagerduty_layout.py new file mode 100644 index 000000000..630f57045 --- /dev/null +++ b/tests/unit/publishers/community/pagerduty/test_pagerduty_layout.py @@ -0,0 +1,229 @@ +""" +Copyright 2017-present, Airbnb Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" +# pylint: disable=protected-access,attribute-defined-outside-init +from datetime import datetime + +from mock import MagicMock +from nose.tools import assert_equal, assert_true, assert_false + +from stream_alert.alert_processor.helpers import compose_alert +from stream_alert.alert_processor.outputs.output_base import OutputDispatcher +from tests.unit.stream_alert_alert_processor.helpers import get_alert + + +def test_shorten_title(): + """Publishers - PagerDuty - ShortenTitle""" + alert = get_alert(context={'context': 'value'}) + alert.created = datetime(2019, 1, 1) + alert.publishers = { + 'pagerduty': 'publishers.community.pagerduty.pagerduty_layout.ShortenTitle', + } + output = MagicMock(spec=OutputDispatcher) + output.__service__ = 'pagerduty' + descriptor = 'unit_test_channel' + + publication = compose_alert(alert, output, descriptor) + + expectation = { + '@pagerduty.description': 'cb_binarystore_file_added', + '@pagerduty-v2.summary': 'cb_binarystore_file_added', + '@pagerduty-incident.incident_title': 'cb_binarystore_file_added' + } + assert_equal(publication, expectation) + + +def test_as_custom_details_default(): + """Publishers - PagerDuty - as_custom_details - Default""" + alert = get_alert(context={'context': 'value'}) + alert.created = datetime(2019, 1, 1) + alert.publishers = { + 'pagerduty': [ + 'stream_alert.shared.publisher.DefaultPublisher', + 'publishers.community.pagerduty.pagerduty_layout.as_custom_fields' + ] + } + output = MagicMock(spec=OutputDispatcher) + output.__service__ = 'pagerduty' + descriptor = 'unit_test_channel' + + publication = compose_alert(alert, output, descriptor) + + expectation = { + 'publishers': { + 'pagerduty': [ + 'stream_alert.shared.publisher.DefaultPublisher', + 'publishers.community.pagerduty.pagerduty_layout.as_custom_fields' + ] + }, + 'source_entity': 'corp-prefix.prod.cb.region', + 'outputs': ['slack:unit_test_channel'], + 'cluster': '', + 'rule_description': 'Info about this rule and what actions to take', + 'log_type': 'json', + 'rule_name': 'cb_binarystore_file_added', + 'source_service': 's3', + 'created': '2019-01-01T00:00:00.000000Z', + 'log_source': 'carbonblack:binarystore.file.added', + 'id': '79192344-4a6d-4850-8d06-9c3fef1060a4', + 'record': { + 'compressed_size': '9982', 'node_id': '1', 'cb_server': 'cbserver', + 'timestamp': '1496947381.18', 'md5': '0F9AA55DA3BDE84B35656AD8911A22E1', + 'type': 'binarystore.file.added', + 'file_path': '/tmp/5DA/AD8/0F9AA55DA3BDE84B35656AD8911A22E1.zip', + 'size': '21504' + }, + 'context': {'context': 'value'}, 'staged': False + } + assert_equal(publication, expectation) + + +def test_as_custom_details_ignores_custom_fields(): + """Publishers - PagerDuty - as_custom_details - Ignore Magic Keys""" + alert = get_alert(context={'context': 'value'}) + alert.created = datetime(2019, 1, 1) + alert.publishers = { + 'pagerduty': [ + 'stream_alert.shared.publisher.DefaultPublisher', + 'publishers.community.pagerduty.pagerduty_layout.ShortenTitle', + 'publishers.community.pagerduty.pagerduty_layout.as_custom_details', + ] + } + output = MagicMock(spec=OutputDispatcher) + output.__service__ = 'pagerduty' + descriptor = 'unit_test_channel' + + publication = compose_alert(alert, output, descriptor) + + # We don't care about the entire payload; let's check a few top-level keys we know + # are supposed to be here.. 
+ assert_true(publication['source_entity']) + assert_true(publication['outputs']) + assert_true(publication['log_source']) + + # Check that the title keys exists + assert_true(publication['@pagerduty.description']) + + # now check that the details key exists + assert_true(publication['@pagerduty.details']) + + # And check that it has no magic keys + assert_false('@pagerduty.description' in publication['@pagerduty.details']) + assert_false('@pagerduty-v2.summary' in publication['@pagerduty.details']) + + +def test_v2_high_urgency(): + """Publishers - PagerDuty - v2_high_urgency""" + alert = get_alert(context={'context': 'value'}) + alert.created = datetime(2019, 1, 1) + alert.publishers = { + 'pagerduty': [ + 'publishers.community.pagerduty.pagerduty_layout.v2_high_urgency' + ] + } + output = MagicMock(spec=OutputDispatcher) + output.__service__ = 'pagerduty' + descriptor = 'unit_test_channel' + + publication = compose_alert(alert, output, descriptor) + + expectation = {'@pagerduty-incident.urgency': 'high', '@pagerduty-v2.severity': 'critical'} + assert_equal(publication, expectation) + + +def test_v2_low_urgency(): + """Publishers - PagerDuty - v2_low_urgency""" + alert = get_alert(context={'context': 'value'}) + alert.created = datetime(2019, 1, 1) + alert.publishers = { + 'pagerduty': [ + 'publishers.community.pagerduty.pagerduty_layout.v2_low_urgency' + ] + } + output = MagicMock(spec=OutputDispatcher) + output.__service__ = 'pagerduty' + descriptor = 'unit_test_channel' + + publication = compose_alert(alert, output, descriptor) + + expectation = {'@pagerduty-incident.urgency': 'low', '@pagerduty-v2.severity': 'warning'} + assert_equal(publication, expectation) + + +def test_pretty_print_arrays(): + """Publishers - PagerDuty - PrettyPrintArrays""" + alert = get_alert(context={'populate_fields': ['publishers', 'cb_server', 'staged']}) + alert.created = datetime(2019, 1, 1) + alert.publishers = { + 'pagerduty': [ + 'stream_alert.shared.publisher.DefaultPublisher', + 'publishers.community.generic.populate_fields', + 'publishers.community.pagerduty.pagerduty_layout.PrettyPrintArrays' + ] + } + output = MagicMock(spec=OutputDispatcher) + output.__service__ = 'pagerduty' + descriptor = 'unit_test_channel' + + publication = compose_alert(alert, output, descriptor) + + expectation = { + 'publishers': [ + { + 'pagerduty': ( + 'stream_alert.shared.publisher.DefaultPublisher\n\n----------\n\n' + 'publishers.community.generic.populate_fields\n\n----------\n\n' + 'publishers.community.pagerduty.pagerduty_layout.PrettyPrintArrays' + ) + } + ], + 'staged': 'False', + 'cb_server': 'cbserver' + } + assert_equal(publication, expectation) + + +def test_attach_image(): + """Publishers - PagerDuty - AttachImage""" + alert = get_alert() + alert.created = datetime(2019, 1, 1) + alert.publishers = { + 'pagerduty': [ + 'publishers.community.pagerduty.pagerduty_layout.AttachImage' + ] + } + + output = MagicMock(spec=OutputDispatcher) + output.__service__ = 'pagerduty' + descriptor = 'unit_test_channel' + + publication = compose_alert(alert, output, descriptor) + + expectation = { + '@pagerduty-v2.images': [ + { + 'src': 'https://streamalert.io/en/stable/_images/sa-banner.png', + 'alt': 'StreamAlert Docs', + 'href': 'https://streamalert.io/en/stable/' + } + ], + '@pagerduty.contexts': [ + { + 'src': 'https://streamalert.io/en/stable/_images/sa-banner.png', + 'type': 'image' + } + ] + } + assert_equal(publication, expectation) diff --git a/tests/unit/publishers/community/test_generic.py 
b/tests/unit/publishers/community/test_generic.py new file mode 100644 index 000000000..bc13ccc1d --- /dev/null +++ b/tests/unit/publishers/community/test_generic.py @@ -0,0 +1,506 @@ +""" +Copyright 2017-present, Airbnb Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +# pylint: disable=protected-access,attribute-defined-outside-init,invalid-name +from datetime import datetime + +from mock import MagicMock +from nose.tools import assert_equal, assert_true, assert_false + +from publishers.community.generic import _delete_dictionary_fields, StringifyArrays +from stream_alert.alert_processor.helpers import compose_alert +from stream_alert.alert_processor.outputs.output_base import OutputDispatcher +from stream_alert.alert_processor.outputs.slack import SlackOutput +from tests.unit.stream_alert_alert_processor.helpers import get_alert + + +class TestPublishersForOutput(object): + + @staticmethod + def test_publisher_for_output(): + alert = get_alert(context={'context': 'value'}) + alert.created = datetime(2019, 1, 1) + alert.publishers = { + 'slack': 'stream_alert.shared.publisher.DefaultPublisher', + 'slack:unit_test_channel': 'publishers.community.generic.remove_internal_fields', + 'demisto': 'publishers.community.generic.blank', + } + output = MagicMock(spec=OutputDispatcher) + output.__service__ = 'slack' + descriptor = 'unit_test_channel' + + publication = compose_alert(alert, output, descriptor) + + expectation = { + 'source_entity': 'corp-prefix.prod.cb.region', + 'rule_name': 'cb_binarystore_file_added', + 'created': '2019-01-01T00:00:00.000000Z', + 'log_source': 'carbonblack:binarystore.file.added', + 'log_type': 'json', + 'cluster': '', + 'context': {'context': 'value'}, + 'source_service': 's3', + 'id': '79192344-4a6d-4850-8d06-9c3fef1060a4', + 'rule_description': 'Info about this rule and what actions to take', + 'record': { + 'compressed_size': '9982', + 'timestamp': '1496947381.18', + 'node_id': '1', + 'cb_server': 'cbserver', + 'size': '21504', + 'type': 'binarystore.file.added', + 'file_path': '/tmp/5DA/AD8/0F9AA55DA3BDE84B35656AD8911A22E1.zip', + 'md5': '0F9AA55DA3BDE84B35656AD8911A22E1' + } + } + assert_equal(publication, expectation) + + +class TestDefaultPublisher(object): + PUBLISHER_NAME = 'stream_alert.shared.publisher.DefaultPublisher' + + def setup(self): + self._alert = get_alert(context={'context': 'value'}) + self._alert.created = datetime(2019, 1, 1) + self._alert.publishers = [self.PUBLISHER_NAME] + self._output = MagicMock(spec=SlackOutput) # Just use some random output + + def test_default_publisher(self): + """AlertPublisher - DefaultPublisher - Positive Case""" + publication = compose_alert(self._alert, self._output, 'test') + expectation = { + 'publishers': ['stream_alert.shared.publisher.DefaultPublisher'], + 'source_entity': 'corp-prefix.prod.cb.region', + 'outputs': ['slack:unit_test_channel'], + 'cluster': '', + 'rule_description': 'Info about this rule and what actions to take', + 'log_type': 'json', + 'rule_name': 'cb_binarystore_file_added', + 'source_service': 
's3', + 'created': '2019-01-01T00:00:00.000000Z', + 'log_source': 'carbonblack:binarystore.file.added', + 'id': '79192344-4a6d-4850-8d06-9c3fef1060a4', + 'record': { + 'compressed_size': '9982', + 'node_id': '1', + 'cb_server': 'cbserver', + 'timestamp': '1496947381.18', + 'md5': '0F9AA55DA3BDE84B35656AD8911A22E1', + 'type': 'binarystore.file.added', + 'file_path': '/tmp/5DA/AD8/0F9AA55DA3BDE84B35656AD8911A22E1.zip', + 'size': '21504' + }, + 'context': {'context': 'value'}, + 'staged': False + } + assert_equal(publication, expectation) + + +class TestRecordPublisher(object): + PUBLISHER_NAME = 'publishers.community.generic.add_record' + + def setup(self): + self._alert = get_alert(context={'context': 'value'}) + self._alert.created = datetime(2019, 1, 1) + self._alert.publishers = [self.PUBLISHER_NAME] + self._output = MagicMock(spec=SlackOutput) # Just use some random output + + def test_default_publisher(self): + """AlertPublisher - add_record - Positive Case""" + publication = compose_alert(self._alert, self._output, 'test') + expectation = { + 'record': { + 'compressed_size': '9982', + 'node_id': '1', + 'cb_server': 'cbserver', + 'timestamp': '1496947381.18', + 'md5': '0F9AA55DA3BDE84B35656AD8911A22E1', + 'type': 'binarystore.file.added', + 'file_path': '/tmp/5DA/AD8/0F9AA55DA3BDE84B35656AD8911A22E1.zip', + 'size': '21504' + }, + } + assert_equal(publication, expectation) + + +class TestRemoveInternalFieldsPublisher(object): + PUBLISHER_NAME = 'publishers.community.generic.remove_internal_fields' + + def setup(self): + self._alert = get_alert(context={'context': 'value'}) + self._alert.created = datetime(2019, 1, 1) + self._alert.publishers = [TestDefaultPublisher.PUBLISHER_NAME, self.PUBLISHER_NAME] + self._output = MagicMock(spec=SlackOutput) # Just use some random output + + def test_remove_internal_fields(self): + """AlertPublisher - remove_internal_fields""" + publication = compose_alert(self._alert, self._output, 'test') + + expectation = { + 'source_entity': 'corp-prefix.prod.cb.region', + 'rule_name': 'cb_binarystore_file_added', + 'source_service': 's3', + 'created': '2019-01-01T00:00:00.000000Z', + 'log_source': 'carbonblack:binarystore.file.added', + 'id': '79192344-4a6d-4850-8d06-9c3fef1060a4', + 'cluster': '', + 'context': {'context': 'value'}, + 'record': { + 'compressed_size': '9982', + 'timestamp': '1496947381.18', + 'node_id': '1', + 'cb_server': 'cbserver', + 'md5': '0F9AA55DA3BDE84B35656AD8911A22E1', + 'type': 'binarystore.file.added', + 'file_path': '/tmp/5DA/AD8/0F9AA55DA3BDE84B35656AD8911A22E1.zip', + 'size': '21504' + }, + 'log_type': 'json', + 'rule_description': 'Info about this rule and what actions to take' + } + assert_equal(publication, expectation) + + +class TestRemoveStreamAlertNormalizationFields(object): + PUBLISHER_NAME = 'publishers.community.generic.remove_streamalert_normalization' + + def setup(self): + self._alert = get_alert(context={'context': 'value'}) + self._alert.created = datetime(2019, 1, 1) + self._alert.record['added_fields'] = { + 'streamalert': { + 'yay': 'no', + }, + 'oof': [ + { + 'streamalert:normalization': '/////', + 'other': 'key' + } + ], + 'streamalert:normalization': { + 'bunch of stuff': 'that does not belong' + }, + } + self._alert.publishers = [TestDefaultPublisher.PUBLISHER_NAME, self.PUBLISHER_NAME] + self._output = MagicMock(spec=SlackOutput) # Just use some random output + + def test_works(self): + """AlertPublisher - FilterFields - Nothing""" + publication = compose_alert(self._alert, self._output, 'test') + + 
expectation = { + 'staged': False, + 'publishers': [ + 'stream_alert.shared.publisher.DefaultPublisher', + 'publishers.community.generic.remove_streamalert_normalization' + ], + 'source_entity': 'corp-prefix.prod.cb.region', + 'rule_name': 'cb_binarystore_file_added', + 'outputs': ['slack:unit_test_channel'], + 'created': '2019-01-01T00:00:00.000000Z', + 'log_source': 'carbonblack:binarystore.file.added', + 'log_type': 'json', 'cluster': '', + 'context': {'context': 'value'}, + 'source_service': 's3', + 'id': '79192344-4a6d-4850-8d06-9c3fef1060a4', + 'rule_description': 'Info about this rule and what actions to take', + 'record': { + 'compressed_size': '9982', + 'added_fields': { + 'streamalert': {'yay': 'no'}, + 'oof': [{'other': 'key'}], + }, + 'timestamp': '1496947381.18', + 'node_id': '1', + 'cb_server': 'cbserver', + 'size': '21504', + 'type': 'binarystore.file.added', + 'file_path': '/tmp/5DA/AD8/0F9AA55DA3BDE84B35656AD8911A22E1.zip', + 'md5': '0F9AA55DA3BDE84B35656AD8911A22E1' + } + } + assert_equal(publication, expectation) + + +class TestEnumerateFields(object): + PUBLISHER_NAME = 'publishers.community.generic.enumerate_fields' + + def setup(self): + self._alert = get_alert(context={ + 'context1': 'value', + 'attribs': [ + {'type': 'Name', 'value': 'Bob'}, + {'type': 'Age', 'value': '42'}, + {'type': 'Profession', 'value': 'Software engineer'}, + ] + }) + self._alert.created = datetime(2019, 1, 1) + self._alert.publishers = [TestDefaultPublisher.PUBLISHER_NAME, self.PUBLISHER_NAME] + self._output = MagicMock(spec=SlackOutput) # Just use some random output + + def test_enumerate_fields(self): + """AlertPublisher - enumerate_fields""" + publication = compose_alert(self._alert, self._output, 'test') + + expectation = { + 'cluster': '', + 'context.context1': 'value', + 'context.attribs[0].type': 'Name', + 'context.attribs[0].value': 'Bob', + 'context.attribs[1].type': 'Age', + 'context.attribs[1].value': '42', + 'context.attribs[2].value': 'Software engineer', + 'context.attribs[2].type': 'Profession', + 'created': '2019-01-01T00:00:00.000000Z', + 'id': '79192344-4a6d-4850-8d06-9c3fef1060a4', + 'log_source': 'carbonblack:binarystore.file.added', + 'log_type': 'json', + 'outputs[0]': 'slack:unit_test_channel', + 'publishers[0]': 'stream_alert.shared.publisher.DefaultPublisher', + 'publishers[1]': 'publishers.community.generic.enumerate_fields', + 'record.timestamp': '1496947381.18', + 'record.compressed_size': '9982', + 'record.cb_server': 'cbserver', + 'record.file_path': '/tmp/5DA/AD8/0F9AA55DA3BDE84B35656AD8911A22E1.zip', + 'record.md5': '0F9AA55DA3BDE84B35656AD8911A22E1', + 'record.node_id': '1', + 'record.size': '21504', + 'record.type': 'binarystore.file.added', + 'rule_description': 'Info about this rule and what actions to take', + 'rule_name': 'cb_binarystore_file_added', + 'source_entity': 'corp-prefix.prod.cb.region', + 'source_service': 's3', + 'staged': False, + } + assert_equal(publication, expectation) + + def test_enumerate_fields_alphabetical_order(self): + """AlertPublisher - enumerate_fields - enforce alphabetical order""" + publication = compose_alert(self._alert, self._output, 'test') + + expectation = [ + 'cluster', + 'context.attribs[0].type', + 'context.attribs[0].value', + 'context.attribs[1].type', + 'context.attribs[1].value', + 'context.attribs[2].type', + 'context.attribs[2].value', + 'context.context1', + 'created', + 'id', + 'log_source', + 'log_type', + 'outputs[0]', + 'publishers[0]', + 'publishers[1]', + 'record.cb_server', + 'record.compressed_size', 
+ 'record.file_path',
+ 'record.md5',
+ 'record.node_id',
+ 'record.size',
+ 'record.timestamp',
+ 'record.type',
+ 'rule_description',
+ 'rule_name',
+ 'source_entity',
+ 'source_service',
+ 'staged',
+ ]
+
+ assert_equal(publication.keys(), expectation)
+
+
+def test_delete_dictionary_fields():
+ """Generic - _delete_dictionary_fields"""
+ pub = {
+ 'level1-1': {
+ 'level2-1': [
+ {
+ 'level3-1': 'level4',
+ 'level3-2': 'level4',
+ }
+ ],
+ 'level2-2': {
+ 'level3': 'level4',
+ }
+ },
+ 'level1-2': [
+ {
+ 'thereisno': 'spoon'
+ }
+ ]
+ }
+
+ result = _delete_dictionary_fields(pub, '^level3-1$')
+
+ expectation = {
+ 'level1-1': {
+ 'level2-1': [
+ {
+ 'level3-2': 'level4',
+ }
+ ],
+ 'level2-2': {
+ 'level3': 'level4',
+ }
+ },
+ 'level1-2': [
+ {
+ 'thereisno': 'spoon'
+ }
+ ]
+ }
+
+ assert_equal(result, expectation)
+
+
+class TestRemoveFields(object):
+ PUBLISHER_NAME = 'publishers.community.generic.remove_fields'
+
+ def setup(self):
+ self._alert = get_alert(context={
+ 'remove_fields': [
+ 'streamalert', '^publishers', 'type$',
+ '^outputs$', '^cluster$', '^context$'
+ ]
+ })
+ self._alert.created = datetime(2019, 1, 1)
+ self._alert.publishers = [TestDefaultPublisher.PUBLISHER_NAME, self.PUBLISHER_NAME]
+ self._output = MagicMock(spec=SlackOutput) # Just use some random output
+
+ def test_remove_fields(self):
+ """AlertPublisher - remove_fields"""
+ publication = compose_alert(self._alert, self._output, 'test')
+
+ expectation = {
+ 'staged': False,
+ 'source_entity': 'corp-prefix.prod.cb.region',
+ 'rule_name': 'cb_binarystore_file_added',
+ 'created': '2019-01-01T00:00:00.000000Z',
+ 'log_source': 'carbonblack:binarystore.file.added',
+ 'source_service': 's3',
+ 'id': '79192344-4a6d-4850-8d06-9c3fef1060a4',
+ 'rule_description': 'Info about this rule and what actions to take',
+ 'record': {
+ 'compressed_size': '9982',
+ 'timestamp': '1496947381.18',
+ 'node_id': '1',
+ 'cb_server': 'cbserver',
+ 'size': '21504',
+ 'file_path': '/tmp/5DA/AD8/0F9AA55DA3BDE84B35656AD8911A22E1.zip',
+ 'md5': '0F9AA55DA3BDE84B35656AD8911A22E1'
+ }
+ }
+
+ assert_equal(publication, expectation)
+
+
+class TestPopulateFields(object):
+ PUBLISHER_NAME = 'publishers.community.generic.populate_fields'
+
+ def setup(self):
+ self._alert = get_alert(context={
+ 'populate_fields': [
+ 'compressed_size', 'id', 'oof', 'multi_field'
+ ],
+ 'other_field': 'a',
+ 'container': {
+ 'multi_field': 1,
+ 'depth2': {
+ 'multi_field': 2,
+ }
+ }
+ })
+ self._alert.created = datetime(2019, 1, 1)
+ self._alert.publishers = [TestDefaultPublisher.PUBLISHER_NAME, self.PUBLISHER_NAME]
+ self._output = MagicMock(spec=SlackOutput) # Just use some random output
+
+ def test_populate_fields(self):
+ """AlertPublisher - populate_fields"""
+ publication = compose_alert(self._alert, self._output, 'test')
+
+ expectation = {
+ 'compressed_size': ['9982'],
+ 'oof': [],
+ 'id': ['79192344-4a6d-4850-8d06-9c3fef1060a4'],
+ 'multi_field': [1, 2]
+ }
+
+ assert_equal(publication, expectation)
+
+
+class TestStringifyArrays(object):
+ PUBLISHER_NAME = 'publishers.community.generic.StringifyArrays'
+
+ def setup(self):
+ self._alert = get_alert(context={
+ 'array': ['a', 'b', 'c'],
+ 'not_array': ['a', {'b': 'c'}, 'd'],
+ 'nest': {
+ 'deep_array': ['a', 'b', 'c'],
+ }
+ })
+ self._alert.created = datetime(2019, 1, 1)
+ self._alert.publishers = [TestDefaultPublisher.PUBLISHER_NAME, self.PUBLISHER_NAME]
+ self._output = MagicMock(spec=SlackOutput) # Just use some random output
+
+ def test_publish(self):
+
"""AlertPublisher - StringifyArrays - publish""" + publication = compose_alert(self._alert, self._output, 'test') + + expectation = { + 'not_array': ['a', {'b': 'c'}, 'd'], + 'array': 'a\nb\nc', + 'nest': {'deep_array': 'a\nb\nc'} + } + + assert_equal(publication['context'], expectation) + + +def test_stringifyarrays_is_scalar_array_none(): + """AlertPublisher - StringifyArrays - is_scalar_array - None""" + assert_false(StringifyArrays.is_scalar_array(None)) + + +def test_stringifyarrays_is_scalar_array_dict(): + """AlertPublisher - StringifyArrays - is_scalar_array - Dict""" + assert_false(StringifyArrays.is_scalar_array({'a': 'b'})) + + +def test_stringifyarrays_is_scalar_array_string(): + """AlertPublisher - StringifyArrays - is_scalar_array - String""" + assert_false(StringifyArrays.is_scalar_array('aaa')) + + +def test_stringifyarrays_is_scalar_array_array_string(): + """AlertPublisher - StringifyArrays - is_scalar_array - Array[str]""" + assert_true(StringifyArrays.is_scalar_array(['a', 'b'])) + + +def test_stringifyarrays_is_scalar_array_array_int(): + """AlertPublisher - StringifyArrays - is_scalar_array - Array[int]""" + assert_true(StringifyArrays.is_scalar_array([1, 2])) + + +def test_stringifyarrays_is_scalar_array_array_mixed(): + """AlertPublisher - StringifyArrays - is_scalar_array - Array[mixed]""" + assert_true(StringifyArrays.is_scalar_array([1, 'a'])) + + +def test_stringifyarrays_is_scalar_array_array_mixed_invalid(): + """AlertPublisher - StringifyArrays - is_scalar_array - Array[mixed], invalid""" + assert_false(StringifyArrays.is_scalar_array([1, 'a', {}])) diff --git a/tests/unit/stream_alert_alert_processor/test_outputs/credentials/test_provider.py b/tests/unit/stream_alert_alert_processor/test_outputs/credentials/test_provider.py index a37650967..640d89629 100644 --- a/tests/unit/stream_alert_alert_processor/test_outputs/credentials/test_provider.py +++ b/tests/unit/stream_alert_alert_processor/test_outputs/credentials/test_provider.py @@ -392,6 +392,7 @@ def test_save_credentials_into_s3(self): assert_equal(loaded_creds, creds) + @mock_s3 def test_save_credentials_into_s3_blank_credentials(self): """S3Driver - Save Credentials does nothing when Credentials are Blank""" input_credentials = Credentials('', is_encrypted=False, region=REGION) diff --git a/tests/unit/stream_alert_alert_processor/test_outputs/test_demisto.py b/tests/unit/stream_alert_alert_processor/test_outputs/test_demisto.py index 9d1d12a0c..9ed5e4cfe 100644 --- a/tests/unit/stream_alert_alert_processor/test_outputs/test_demisto.py +++ b/tests/unit/stream_alert_alert_processor/test_outputs/test_demisto.py @@ -20,6 +20,7 @@ from mock import patch, Mock, MagicMock from nose.tools import assert_is_instance, assert_true, assert_false, assert_equal +from stream_alert.alert_processor.helpers import compose_alert from stream_alert.alert_processor.outputs.demisto import DemistoOutput, DemistoRequestAssembler from stream_alert.alert_processor.outputs.output_base import OutputRequestFailure @@ -76,6 +77,7 @@ {'type': 'staged', 'value': 'False'}, ] + class TestDemistoOutput(object): """Test class for SlackOutput""" DESCRIPTOR = 'unit_test_demisto' @@ -126,6 +128,7 @@ def test_dispatch(self, request_mock): 'details': 'Info about this rule and what actions to take', 'createInvestigation': True, } + class Matcher(object): def __eq__(self, other): if other == expected_data: @@ -180,7 +183,8 @@ def test_assemble(): alert = get_alert(context=SAMPLE_CONTEXT) alert.created = datetime(2019, 1, 1) - alert_publication = 
alert.publish_for(None, None) # FIXME (derek.wang) + output = MagicMock(spec=DemistoOutput) + alert_publication = compose_alert(alert, output, 'asdf') request = DemistoRequestAssembler.assemble(alert, alert_publication) diff --git a/tests/unit/stream_alert_alert_processor/test_outputs/test_pagerduty.py b/tests/unit/stream_alert_alert_processor/test_outputs/test_pagerduty.py index 1d88d1d2e..cebd0948a 100644 --- a/tests/unit/stream_alert_alert_processor/test_outputs/test_pagerduty.py +++ b/tests/unit/stream_alert_alert_processor/test_outputs/test_pagerduty.py @@ -13,16 +13,17 @@ See the License for the specific language governing permissions and limitations under the License. """ -# pylint: disable=protected-access,attribute-defined-outside-init -from mock import patch, PropertyMock, Mock, MagicMock +# pylint: disable=protected-access,attribute-defined-outside-init,too-many-lines,invalid-name +from collections import OrderedDict +from mock import patch, PropertyMock, Mock, MagicMock, call from nose.tools import assert_equal, assert_false, assert_true -# import cProfile, pstats, StringIO +from stream_alert.alert_processor.outputs.output_base import OutputDispatcher, OutputRequestFailure from stream_alert.alert_processor.outputs.pagerduty import ( PagerDutyOutput, PagerDutyOutputV2, - PagerDutyIncidentOutput -) + PagerDutyIncidentOutput, + WorkContext, PagerDutyRestApiClient, JsonHttpProvider) from tests.unit.stream_alert_alert_processor.helpers import get_alert @@ -64,6 +65,34 @@ def test_dispatch_success(self, post_mock, log_mock): log_mock.assert_called_with('Successfully sent alert to %s:%s', self.SERVICE, self.DESCRIPTOR) + post_mock.assert_called_with( + 'http://pagerduty.foo.bar/create_event.json', + headers=None, + json={ + 'client_url': '', + 'event_type': 'trigger', + 'contexts': [], + 'client': 'streamalert', + 'details': { + 'record': { + 'compressed_size': '9982', + 'node_id': '1', + 'cb_server': 'cbserver', + 'timestamp': '1496947381.18', + 'md5': '0F9AA55DA3BDE84B35656AD8911A22E1', + 'type': 'binarystore.file.added', + 'file_path': '/tmp/5DA/AD8/0F9AA55DA3BDE84B35656AD8911A22E1.zip', + 'size': '21504' + }, + 'description': 'Info about this rule and what actions to take' + }, + 'service_key': 'mocked_service_key', + 'description': 'StreamAlert Rule Triggered - cb_binarystore_file_added' + }, + timeout=3.05, + verify=True + ) + @patch('logging.Logger.error') @patch('requests.post') def test_dispatch_failure(self, post_mock, log_mock): @@ -109,6 +138,48 @@ def test_get_default_properties(self): assert_equal(len(props), 1) assert_equal(props['url'], 'https://events.pagerduty.com/v2/enqueue') + @patch('requests.post') + def test_dispatch_sends_correct_request(self, post_mock): + """PagerDutyOutputV2 - Dispatch Sends Correct Request""" + post_mock.return_value.status_code = 200 + + self._dispatcher.dispatch(get_alert(), self.OUTPUT) + + post_mock.assert_called_with( + 'http://pagerduty.foo.bar/create_event.json', + headers=None, + json={ + 'event_action': 'trigger', + 'client': 'StreamAlert', + 'client_url': None, + 'payload': { + 'custom_details': OrderedDict( + [ + ('description', 'Info about this rule and what actions to take'), + ('record', { + 'compressed_size': '9982', 'node_id': '1', 'cb_server': 'cbserver', + 'timestamp': '1496947381.18', + 'md5': '0F9AA55DA3BDE84B35656AD8911A22E1', + 'type': 'binarystore.file.added', + 'file_path': '/tmp/5DA/AD8/0F9AA55DA3BDE84B35656AD8911A22E1.zip', + 'size': '21504' + }) + ] + ), + 'source': 'carbonblack:binarystore.file.added', + 
'severity': 'critical', + 'summary': 'StreamAlert Rule Triggered - cb_binarystore_file_added', + 'component': None, + 'group': None, + 'class': None, + }, + 'routing_key': 'mocked_routing_key', + 'images': [], + 'links': [], + }, + timeout=3.05, verify=True + ) + @patch('logging.Logger.info') @patch('requests.post') def test_dispatch_success(self, post_mock, log_mock): @@ -141,7 +212,7 @@ def test_dispatch_bad_descriptor(self, log_mock): log_mock.assert_called_with('Failed to send alert to %s:%s', self.SERVICE, 'bad_descriptor') -#pylint: disable=too-many-public-methods +# pylint: disable=too-many-public-methods @patch('stream_alert.alert_processor.outputs.output_base.OutputDispatcher.MAX_RETRY_ATTEMPTS', 1) @patch('stream_alert.alert_processor.outputs.pagerduty.PagerDutyIncidentOutput.BACKOFF_MAX', 0) @patch('stream_alert.alert_processor.outputs.pagerduty.PagerDutyIncidentOutput.BACKOFF_TIME', 0) @@ -177,296 +248,318 @@ def test_get_default_properties(self): assert_equal(len(props), 1) assert_equal(props['api'], 'https://api.pagerduty.com') - def test_get_endpoint(self): - """PagerDutyIncidentOutput - Get Endpoint""" - endpoint = self._dispatcher._get_endpoint(self.CREDS['api'], 'testtest') - assert_equal(endpoint, 'https://api.pagerduty.com/testtest') - - @patch('requests.get') - def test_check_exists_get_id(self, get_mock): - """PagerDutyIncidentOutput - Check Exists Get ID""" - # GET /check - get_mock.return_value.status_code = 200 - json_check = {'check': [{'id': 'checked_id'}]} - get_mock.return_value.json.return_value = json_check - - checked = self._dispatcher._check_exists('filter', 'http://mock_url', 'check') - assert_equal(checked, 'checked_id') - - @patch('requests.get') - def test_check_exists_get_id_fail(self, get_mock): - """PagerDutyIncidentOutput - Check Exists Get Id Fail""" - get_mock.return_value.status_code = 200 - get_mock.return_value.json.return_value = dict() - - checked = self._dispatcher._check_exists('filter', 'http://mock_url', 'check') - assert_false(checked) - - @patch('requests.get') - def test_check_exists_no_get_id(self, get_mock): - """PagerDutyIncidentOutput - Check Exists No Get Id""" - # GET /check - get_mock.return_value.status_code = 200 - json_check = {'check': [{'id': 'checked_id'}]} - get_mock.return_value.json.return_value = json_check - - assert_true(self._dispatcher._check_exists('filter', 'http://mock_url', 'check', False)) - - @patch('requests.get') - def test_user_verify_success(self, get_mock): - """PagerDutyIncidentOutput - User Verify Success""" - get_mock.return_value.status_code = 200 - json_check = {'users': [{'id': 'verified_user_id'}]} - get_mock.return_value.json.return_value = json_check - - user_verified = self._dispatcher._user_verify('valid_user') - assert_equal(user_verified['id'], 'verified_user_id') - assert_equal(user_verified['type'], 'user_reference') - - @patch('requests.get') - def test_user_verify_fail(self, get_mock): - """PagerDutyIncidentOutput - User Verify Fail""" - get_mock.return_value.status_code = 200 - json_check = {'not_users': [{'not_id': 'verified_user_id'}]} - get_mock.return_value.json.return_value = json_check - - user_verified = self._dispatcher._user_verify('valid_user') - assert_false(user_verified) - - @patch('requests.get') - def test_policy_verify_success_no_default(self, get_mock): - """PagerDutyIncidentOutput - Policy Verify Success (No Default)""" - # GET /escalation_policies - get_mock.return_value.status_code = 200 - json_check = {'escalation_policies': [{'id': 'good_policy_id'}]} - 
get_mock.return_value.json.return_value = json_check - - policy_verified = self._dispatcher._policy_verify('valid_policy', '') - assert_equal(policy_verified['id'], 'good_policy_id') - assert_equal(policy_verified['type'], 'escalation_policy_reference') - - @patch('requests.get') - def test_policy_verify_success_default(self, get_mock): - """PagerDutyIncidentOutput - Policy Verify Success (Default)""" - # GET /escalation_policies - type(get_mock.return_value).status_code = PropertyMock(side_effect=[200, 200]) - json_check_bad = {'no_escalation_policies': [{'id': 'bad_policy_id'}]} - json_check_good = {'escalation_policies': [{'id': 'good_policy_id'}]} - get_mock.return_value.json.side_effect = [json_check_bad, json_check_good] - - policy_verified = self._dispatcher._policy_verify('valid_policy', 'default_policy') - assert_equal(policy_verified['id'], 'good_policy_id') - assert_equal(policy_verified['type'], 'escalation_policy_reference') - + @patch('requests.put') + @patch('requests.post') @patch('requests.get') - def test_policy_verify_fail_default(self, get_mock): - """PagerDutyIncidentOutput - Policy Verify Fail (Default)""" - # GET /not_escalation_policies - type(get_mock.return_value).status_code = PropertyMock(side_effect=[400, 400]) - json_check_bad = {'escalation_policies': [{'id': 'bad_policy_id'}]} - json_check_bad_default = {'escalation_policies': [{'id': 'good_policy_id'}]} - get_mock.return_value.json.side_effect = [json_check_bad, json_check_bad_default] + def test_dispatch_sends_correct_create_request(self, get_mock, post_mock, put_mock): + """PagerDutyIncidentOutput - Dispatch Success, Good User, Sends Correct Create Request""" + # GET /users, /users + json_user = {'users': [{'id': 'valid_user_id'}]} - assert_false(self._dispatcher._policy_verify('valid_policy', 'default_policy')) + # GET /incidents + json_lookup = {'incidents': [{'id': 'incident_id'}]} - @patch('requests.get') - def test_policy_verify_fail_no_default(self, get_mock): - """PagerDutyIncidentOutput - Policy Verify Fail (No Default)""" - # GET /not_escalation_policies get_mock.return_value.status_code = 200 - json_check = {'not_escalation_policies': [{'not_id': 'verified_policy_id'}]} - get_mock.return_value.json.return_value = json_check - - assert_false(self._dispatcher._policy_verify('valid_policy', 'default_policy')) + get_mock.return_value.json.side_effect = [json_user, json_user, json_lookup] - @patch('requests.get') - def test_service_verify_success(self, get_mock): - """PagerDutyIncidentOutput - Service Verify Success""" - # GET /services - get_mock.return_value.status_code = 200 - json_check = {'services': [{'id': 'verified_service_id'}]} - get_mock.return_value.json.return_value = json_check + # POST /incidents, /v2/enqueue, /incidents/incident_id/notes + post_mock.return_value.status_code = 200 + json_incident = {'incident': {'id': 'incident_id'}} + json_event = {'dedup_key': 'returned_dedup_key'} + json_note = {'note': {'id': 'note_id'}} + post_mock.return_value.json.side_effect = [json_incident, json_event, json_note] - service_verified = self._dispatcher._service_verify('valid_service') - assert_equal(service_verified['id'], 'verified_service_id') - assert_equal(service_verified['type'], 'service_reference') + # PUT /incidents/indicent_id/merge + put_mock.return_value.status_code = 200 - @patch('requests.get') - def test_service_verify_fail(self, get_mock): - """PagerDutyIncidentOutput - Service Verify Fail""" - get_mock.return_value.status_code = 200 - json_check = {'not_services': [{'not_id': 
'verified_service_id'}]} - get_mock.return_value.json.return_value = json_check + ctx = {'pagerduty-incident': {'assigned_user': 'valid_user'}} - assert_false(self._dispatcher._service_verify('valid_service')) + self._dispatcher.dispatch(get_alert(context=ctx), self.OUTPUT) + + # Useful tidbit; for writing fixtures for implement multiple sequential calls, you can use + # mock.assert_has_calls() to render out all of the calls in order: + # post_mock.assert_has_calls([call()]) + post_mock.assert_any_call( + 'https://api.pagerduty.com/incidents', + headers={ + 'From': 'email@domain.com', + 'Content-Type': 'application/json', + 'Authorization': 'Token token=mocked_token', + 'Accept': 'application/vnd.pagerduty+json;version=2' + }, + json={ + 'incident': { + 'body': { + 'type': 'incident_body', + 'details': 'Info about this rule and what actions to take' + }, + 'service': { + 'type': 'service_reference', + 'id': 'mocked_service_id' + }, + 'title': 'StreamAlert Incident - Rule triggered: cb_binarystore_file_added', + 'priority': {}, + 'assignments': [ + { + 'assignee': { + 'type': 'user_reference', 'id': 'valid_user_id' + } + } + ], + 'type': 'incident', + 'incident_key': '', + } + }, + timeout=3.05, verify=False + ) + @patch('stream_alert.alert_processor.outputs.pagerduty.compose_alert') + @patch('requests.put') + @patch('requests.post') @patch('requests.get') - def test_item_verify_success(self, get_mock): - """PagerDutyIncidentOutput - Item Verify Success""" - # GET /items - get_mock.return_value.status_code = 200 - json_check = {'items': [{'id': 'verified_item_id'}]} - get_mock.return_value.json.return_value = json_check + def test_dispatch_sends_correct_with_urgency(self, get_mock, post_mock, put_mock, + compose_alert): + """PagerDutyIncidentOutput - Dispatch Success, Good User, Sends Correct Urgency""" + compose_alert.return_value = { + '@pagerduty-incident.urgency': 'low' + } - item_verified = self._dispatcher._item_verify('valid_item', 'items', 'item_reference') + # GET /users, /users + json_user = {'users': [{'id': 'valid_user_id'}]} - assert_equal(item_verified['id'], 'verified_item_id') - assert_equal(item_verified['type'], 'item_reference') + # GET /incidents + json_lookup = {'incidents': [{'id': 'incident_id'}]} - @patch('requests.get') - def test_item_verify_no_get_id_success(self, get_mock): - """PagerDutyIncidentOutput - Item Verify No Get Id Success""" - # GET /items get_mock.return_value.status_code = 200 - json_check = {'items': [{'id': 'verified_item_id'}]} - get_mock.return_value.json.return_value = json_check + get_mock.return_value.json.side_effect = [json_user, json_user, json_lookup] - assert_true(self._dispatcher._item_verify('valid_item', 'items', 'item_reference', False)) + # POST /incidents, /v2/enqueue, /incidents/incident_id/notes + post_mock.return_value.status_code = 200 + json_incident = {'incident': {'id': 'incident_id'}} + json_event = {'dedup_key': 'returned_dedup_key'} + json_note = {'note': {'id': 'note_id'}} + post_mock.return_value.json.side_effect = [json_incident, json_event, json_note] - @patch('requests.get') - def test_priority_verify_success(self, get_mock): - """PagerDutyIncidentOutput - Priority Verify Success""" - priority_name = 'priority_name' - # GET /priorities - get_mock.return_value.status_code = 200 - json_check = {'priorities': [{'id': 'verified_priority_id', 'name': priority_name}]} - get_mock.return_value.json.return_value = json_check + # PUT /incidents/indicent_id/merge + put_mock.return_value.status_code = 200 - context = 
{'incident_priority': priority_name} + ctx = {'pagerduty-incident': {'assigned_user': 'valid_user'}} - priority_verified = self._dispatcher._priority_verify(context) - assert_equal(priority_verified['id'], 'verified_priority_id') - assert_equal(priority_verified['type'], 'priority_reference') + self._dispatcher.dispatch(get_alert(context=ctx), self.OUTPUT) + + # Useful tidbit; for writing fixtures for implement multiple sequential calls, you can use + # mock.assert_has_calls() to render out all of the calls in order: + # post_mock.assert_has_calls([call()]) + post_mock.assert_any_call( + 'https://api.pagerduty.com/incidents', + headers={ + 'From': 'email@domain.com', + 'Content-Type': 'application/json', + 'Authorization': 'Token token=mocked_token', + 'Accept': 'application/vnd.pagerduty+json;version=2' + }, + json={ + 'incident': { + 'body': { + 'type': 'incident_body', + 'details': 'Info about this rule and what actions to take' + }, + 'service': { + 'type': 'service_reference', + 'id': 'mocked_service_id' + }, + 'title': 'StreamAlert Incident - Rule triggered: cb_binarystore_file_added', + 'priority': {}, + 'assignments': [ + { + 'assignee': { + 'type': 'user_reference', 'id': 'valid_user_id' + } + } + ], + 'urgency': 'low', + 'type': 'incident', + 'incident_key': '', + } + }, + timeout=3.05, verify=False + ) + @patch('logging.Logger.warn') + @patch('stream_alert.alert_processor.outputs.pagerduty.compose_alert') + @patch('requests.put') + @patch('requests.post') @patch('requests.get') - def test_priority_verify_fail(self, get_mock): - """PagerDutyIncidentOutput - Priority Verify Fail""" - # GET /priorities - get_mock.return_value.status_code = 404 + def test_dispatch_sends_correct_bad_urgency(self, get_mock, post_mock, put_mock, + compose_alert, log_mock): + """PagerDutyIncidentOutput - Dispatch Success, Good User, Sends Correct Urgency""" + compose_alert.return_value = { + '@pagerduty-incident.urgency': 'asdf' + } - context = {'incident_priority': 'priority_name'} + # GET /users, /users + json_user = {'users': [{'id': 'valid_user_id'}]} - priority_not_verified = self._dispatcher._priority_verify(context) - assert_equal(priority_not_verified, dict()) + # GET /incidents + json_lookup = {'incidents': [{'id': 'incident_id'}]} - @patch('requests.get') - def test_priority_verify_empty(self, get_mock): - """PagerDutyIncidentOutput - Priority Verify Empty""" - # GET /priorities get_mock.return_value.status_code = 200 - json_check = {} - get_mock.return_value.json.return_value = json_check - - context = {'incident_priority': 'priority_name'} + get_mock.return_value.json.side_effect = [json_user, json_user, json_lookup] - priority_not_verified = self._dispatcher._priority_verify(context) - assert_equal(priority_not_verified, dict()) + # POST /incidents, /v2/enqueue, /incidents/incident_id/notes + post_mock.return_value.status_code = 200 + json_incident = {'incident': {'id': 'incident_id'}} + json_event = {'dedup_key': 'returned_dedup_key'} + json_note = {'note': {'id': 'note_id'}} + post_mock.return_value.json.side_effect = [json_incident, json_event, json_note] - @patch('requests.get') - def test_priority_verify_not_found(self, get_mock): - """PagerDutyIncidentOutput - Priority Verify Not Found""" - # GET /priorities - get_mock.return_value.status_code = 200 - json_check = {'priorities': [{'id': 'verified_priority_id', 'name': 'not_priority_name'}]} - get_mock.return_value.json.return_value = json_check + # PUT /incidents/indicent_id/merge + put_mock.return_value.status_code = 200 - context = 
{'incident_priority': 'priority_name'} + ctx = {'pagerduty-incident': {'assigned_user': 'valid_user'}} - priority_not_verified = self._dispatcher._priority_verify(context) - assert_equal(priority_not_verified, dict()) + self._dispatcher.dispatch(get_alert(context=ctx), self.OUTPUT) + + # Useful tidbit; for writing fixtures for implement multiple sequential calls, you can use + # mock.assert_has_calls() to render out all of the calls in order: + # post_mock.assert_has_calls([call()]) + post_mock.assert_any_call( + 'https://api.pagerduty.com/incidents', + headers={ + 'From': 'email@domain.com', + 'Content-Type': 'application/json', + 'Authorization': 'Token token=mocked_token', + 'Accept': 'application/vnd.pagerduty+json;version=2' + }, + json={ + 'incident': { + 'body': { + 'type': 'incident_body', + 'details': 'Info about this rule and what actions to take' + }, + 'service': { + 'type': 'service_reference', + 'id': 'mocked_service_id' + }, + 'title': 'StreamAlert Incident - Rule triggered: cb_binarystore_file_added', + 'priority': {}, + 'assignments': [ + { + 'assignee': { + 'type': 'user_reference', 'id': 'valid_user_id' + } + } + ], + 'type': 'incident', + 'incident_key': '', + } + }, + timeout=3.05, verify=False + ) + log_mock.assert_called_with('[%s] Invalid pagerduty incident urgency: "%s"', + 'pagerduty-incident', 'asdf') + @patch('requests.put') + @patch('requests.post') @patch('requests.get') - def test_priority_verify_invalid(self, get_mock): - """PagerDutyIncidentOutput - Priority Verify Invalid""" - # GET /priorities - get_mock.return_value.status_code = 200 - json_check = {'not_priorities': [{'id': 'verified_priority_id', 'name': 'priority_name'}]} - get_mock.return_value.json.return_value = json_check - - context = {'incident_priority': 'priority_name'} + def test_dispatch_sends_correct_enqueue_event_request(self, get_mock, post_mock, put_mock): + """PagerDutyIncidentOutput - Dispatch Success, Good User, Sends Correct Event Request""" + # GET /users, /users + json_user = {'users': [{'id': 'valid_user_id'}]} - priority_not_verified = self._dispatcher._priority_verify(context) - assert_equal(priority_not_verified, dict()) + # GET /incidents + json_lookup = {'incidents': [{'id': 'incident_id'}]} - @patch('requests.get') - def test_incident_assignment_user(self, get_mock): - """PagerDutyIncidentOutput - Incident Assignment User""" - context = {'assigned_user': 'user_to_assign'} get_mock.return_value.status_code = 200 - json_user = {'users': [{'id': 'verified_user_id'}]} - get_mock.return_value.json.return_value = json_user - - assigned_key, assigned_value = self._dispatcher._incident_assignment(context) - - assert_equal(assigned_key, 'assignments') - assert_equal(assigned_value[0]['assignee']['id'], 'verified_user_id') - assert_equal(assigned_value[0]['assignee']['type'], 'user_reference') - - def test_incident_assignment_policy_no_default(self): - """PagerDutyIncidentOutput - Incident Assignment Policy (No Default)""" - context = {'assigned_policy_id': 'policy_id_to_assign'} - - assigned_key, assigned_value = self._dispatcher._incident_assignment(context) - - assert_equal(assigned_key, 'escalation_policy') - assert_equal(assigned_value['id'], 'policy_id_to_assign') - assert_equal(assigned_value['type'], 'escalation_policy_reference') + get_mock.return_value.json.side_effect = [json_user, json_user, json_lookup] - @patch('requests.post') - def test_add_note_incident_success(self, post_mock): - """PagerDutyIncidentOutput - Add Note to Incident Success""" + # POST /incidents, 
/v2/enqueue, /incidents/incident_id/notes post_mock.return_value.status_code = 200 - json_note = {'note': {'id': 'created_note_id'}} - post_mock.return_value.json.return_value = json_note - - note_id = self._dispatcher._add_incident_note('incident_id', 'this is the note') - - assert_equal(note_id, 'created_note_id') + json_incident = {'incident': {'id': 'incident_id'}} + json_event = {'dedup_key': 'returned_dedup_key'} + json_note = {'note': {'id': 'note_id'}} + post_mock.return_value.json.side_effect = [json_incident, json_event, json_note] - @patch('requests.post') - def test_add_note_incident_fail(self, post_mock): - """PagerDutyIncidentOutput - Add Note to Incident Fail""" - post_mock.return_value.status_code = 200 - json_note = {'note': {'not_id': 'created_note_id'}} - post_mock.return_value.json.return_value = json_note + # PUT /incidents/indicent_id/merge + put_mock.return_value.status_code = 200 - note_id = self._dispatcher._add_incident_note('incident_id', 'this is the note') + ctx = {'pagerduty-incident': {'assigned_user': 'valid_user'}} - assert_false(note_id) + self._dispatcher.dispatch(get_alert(context=ctx), self.OUTPUT) + + # post_mock.assert_has_calls([call()]) + post_mock.assert_any_call( + 'https://events.pagerduty.com/v2/enqueue', + headers=None, + json={ + 'event_action': 'trigger', + 'client': 'StreamAlert', + 'client_url': None, + 'payload': { + 'custom_details': OrderedDict( + [ + ('description', 'Info about this rule and what actions to take'), + ('record', { + 'compressed_size': '9982', 'node_id': '1', 'cb_server': 'cbserver', + 'timestamp': '1496947381.18', + 'md5': '0F9AA55DA3BDE84B35656AD8911A22E1', + 'type': 'binarystore.file.added', + 'file_path': '/tmp/5DA/AD8/0F9AA55DA3BDE84B35656AD8911A22E1.zip', + 'size': '21504' + }) + ] + ), + 'source': 'carbonblack:binarystore.file.added', + 'severity': 'critical', + 'summary': 'StreamAlert Rule Triggered - cb_binarystore_file_added', + 'component': None, + 'group': None, + 'class': None, + }, + 'routing_key': 'mocked_key', + 'images': [], + 'links': [], + }, + timeout=3.05, verify=True + ) + @patch('requests.put') @patch('requests.post') - def test_add_note_incident_bad_request(self, post_mock): - """PagerDutyIncidentOutput - Add Note to Incident Bad Request""" - post_mock.return_value.status_code = 400 - json_note = {'note': {'id': 'created_note_id'}} - post_mock.return_value.json.return_value = json_note + @patch('requests.get') + def test_dispatch_sends_correct_merge_request(self, get_mock, post_mock, put_mock): + """PagerDutyIncidentOutput - Dispatch Success, Good User, Sends Correct Merge Request""" + # GET /users, /users + json_user = {'users': [{'id': 'valid_user_id'}]} - note_id = self._dispatcher._add_incident_note('incident_id', 'this is the note') + # GET /incidents + json_lookup = {'incidents': [{'id': 'incident_id'}]} - assert_false(note_id) + get_mock.return_value.status_code = 200 + get_mock.return_value.json.side_effect = [json_user, json_user, json_lookup] - @patch('requests.post') - def test_add_note_incident_no_response(self, post_mock): - """PagerDutyIncidentOutput - Add Note to Incident No Response""" + # POST /incidents, /v2/enqueue, /incidents/incident_id/notes post_mock.return_value.status_code = 200 - json_note = {} - post_mock.return_value.json.return_value = json_note + json_incident = {'incident': {'id': 'incident_id'}} + json_event = {'dedup_key': 'returned_dedup_key'} + json_note = {'note': {'id': 'note_id'}} + post_mock.return_value.json.side_effect = [json_incident, json_event, 
json_note] - note_id = self._dispatcher._add_incident_note('incident_id', 'this is the note') + # PUT /incidents/indicent_id/merge + put_mock.return_value.status_code = 200 - assert_false(note_id) + ctx = {'pagerduty-incident': {'assigned_user': 'valid_user'}} - @patch('requests.get') - def test_item_verify_fail(self, get_mock): - """PagerDutyIncidentOutput - Item Verify Fail""" - # /not_items - get_mock.return_value.status_code = 200 - json_check = {'not_items': [{'not_id': 'verified_item_id'}]} - get_mock.return_value.json.return_value = json_check + self._dispatcher.dispatch(get_alert(context=ctx), self.OUTPUT) - item_verified = self._dispatcher._item_verify('http://mock_url', 'valid_item', - 'items', 'item_reference') - assert_false(item_verified) + put_mock.assert_called_with( + 'https://api.pagerduty.com/incidents/incident_id/merge', + headers={'From': 'email@domain.com', 'Content-Type': 'application/json', + 'Authorization': 'Token token=mocked_token', + 'Accept': 'application/vnd.pagerduty+json;version=2'}, + json={'source_incidents': [{'type': 'incident_reference', 'id': 'incident_id'}]}, + timeout=3.05, + verify=False + ) @patch('logging.Logger.info') @patch('requests.put') @@ -506,7 +599,7 @@ def test_dispatch_success_good_user(self, get_mock, post_mock, put_mock, log_moc @patch('requests.get') def test_dispatch_success_good_policy(self, get_mock, post_mock, put_mock, log_mock): """PagerDutyIncidentOutput - Dispatch Success, Good Policy""" - # GET /users + # GET /users json_user = {'users': [{'id': 'user_id'}]} # GET /incidents @@ -522,7 +615,7 @@ def test_dispatch_success_good_policy(self, get_mock, post_mock, put_mock, log_m json_note = {'note': {'id': 'note_id'}} post_mock.return_value.json.side_effect = [json_incident, json_event, json_note] - # PUT /incidents/indicent_id/merge + # PUT /incidents/{incident_id}/merge put_mock.return_value.status_code = 200 ctx = {'pagerduty-incident': {'assigned_policy_id': 'valid_policy_id'}} @@ -538,23 +631,45 @@ def test_dispatch_success_good_policy(self, get_mock, post_mock, put_mock, log_m @patch('requests.get') def test_dispatch_success_with_priority(self, get_mock, post_mock, put_mock, log_mock): """PagerDutyIncidentOutput - Dispatch Success With Priority""" - # GET /priorities, /users + # GET /priorities, /users json_user = {'users': [{'id': 'user_id'}]} json_priority = {'priorities': [{'id': 'priority_id', 'name': 'priority_name'}]} - - # GET /incidents json_lookup = {'incidents': [{'id': 'incident_id'}]} + + def setup_post_mock(mock, json_incident, json_event, json_note): + def post(*args, **_): + url = args[0] + if url == 'https://api.pagerduty.com/incidents': + response = json_incident + elif url == 'https://events.pagerduty.com/v2/enqueue': + response = json_event + elif ( + url.startswith('https://api.pagerduty.com/incidents/') and + url.endswith('/notes') + ): + response = json_note + else: + raise RuntimeError('Misconfigured mock: {}'.format(url)) + + _mock = MagicMock() + _mock.status_code = 200 + _mock.json.return_value = response + return _mock + + mock.side_effect = post + get_mock.return_value.status_code = 200 get_mock.return_value.json.side_effect = [json_user, json_priority, json_lookup] - # POST /incidents, /v2/enqueue, /incidents/incident_id/notes - post_mock.return_value.status_code = 200 + # POST /incidents, /v2/enqueue, /incidents/{incident_id}/notes json_incident = {'incident': {'id': 'incident_id'}} json_event = {'dedup_key': 'returned_dedup_key'} json_note = {'note': {'id': 'note_id'}} - 
post_mock.return_value.json.side_effect = [json_incident, json_event, json_note] + # post_mock.return_value.status_code = 200 + # post_mock.return_value.json.side_effect = [json_incident, json_event, json_note] + setup_post_mock(post_mock, json_incident, json_event, json_note) - # PUT /incidents/indicent_id/merge + # PUT /incidents/{incident_id}/merge put_mock.return_value.status_code = 200 ctx = { @@ -573,29 +688,150 @@ def test_dispatch_success_with_priority(self, get_mock, post_mock, put_mock, log @patch('requests.put') @patch('requests.post') @patch('requests.get') - def test_dispatch_success_bad_user(self, get_mock, post_mock, put_mock, log_mock): - """PagerDutyIncidentOutput - Dispatch Success, Bad User""" - # GET /users, /users + def test_dispatch_success_with_note(self, get_mock, post_mock, put_mock, log_mock): + """PagerDutyIncidentOutput - Dispatch Success With Note""" + # GET /priorities, /users json_user = {'users': [{'id': 'user_id'}]} - json_not_user = {'not_users': [{'id': 'user_id'}]} - - # GET /incidents json_lookup = {'incidents': [{'id': 'incident_id'}]} + def setup_post_mock(mock, json_incident, json_event, json_note): + def post(*args, **_): + url = args[0] + if url == 'https://api.pagerduty.com/incidents': + response = json_incident + elif url == 'https://events.pagerduty.com/v2/enqueue': + response = json_event + elif ( + url.startswith('https://api.pagerduty.com/incidents/') and + url.endswith('/notes') + ): + response = json_note + else: + raise RuntimeError('Misconfigured mock: {}'.format(url)) + + _mock = MagicMock() + _mock.status_code = 200 + _mock.json.return_value = response + return _mock + + mock.side_effect = post + get_mock.return_value.status_code = 200 - get_mock.return_value.json.side_effect = [json_user, json_not_user, json_lookup] + get_mock.return_value.json.side_effect = [json_user, json_lookup] - # POST /incidents, /v2/enqueue, /incidents/incident_id/notes - post_mock.return_value.status_code = 200 + # POST /incidents, /v2/enqueue, /incidents/{incident_id}/notes json_incident = {'incident': {'id': 'incident_id'}} json_event = {'dedup_key': 'returned_dedup_key'} json_note = {'note': {'id': 'note_id'}} - post_mock.return_value.json.side_effect = [json_incident, json_event, json_note] + setup_post_mock(post_mock, json_incident, json_event, json_note) - # PUT /incidents/indicent_id/merge + # PUT /incidents/{incident_id}/merge put_mock.return_value.status_code = 200 - ctx = {'pagerduty-incident': {'assigned_user': 'invalid_user'}} + ctx = { + 'pagerduty-incident': { + 'note': 'This is just a note' + } + } + + assert_true(self._dispatcher.dispatch(get_alert(context=ctx), self.OUTPUT)) + + # post_mock.assert_has_calls([call()]) + post_mock.assert_any_call( + 'https://api.pagerduty.com/incidents/incident_id/notes', + headers={'From': 'email@domain.com', + 'Content-Type': 'application/json', + 'Authorization': 'Token token=mocked_token', + 'Accept': 'application/vnd.pagerduty+json;version=2'}, + json={'note': {'content': 'This is just a note'}}, + timeout=3.05, verify=False + ) + + log_mock.assert_called_with('Successfully sent alert to %s:%s', + self.SERVICE, self.DESCRIPTOR) + + @patch('logging.Logger.info') + @patch('requests.put') + @patch('requests.post') + @patch('requests.get') + def test_dispatch_success_none_note(self, get_mock, post_mock, put_mock, log_mock): + """PagerDutyIncidentOutput - Dispatch Success With No Note""" + # GET /priorities, /users + json_user = {'users': [{'id': 'user_id'}]} + json_lookup = {'incidents': [{'id': 
'incident_id'}]} + + def setup_post_mock(mock, json_incident, json_event): + def post(*args, **_): + url = args[0] + if url == 'https://api.pagerduty.com/incidents': + response = json_incident + elif url == 'https://events.pagerduty.com/v2/enqueue': + response = json_event + elif ( + url.startswith('https://api.pagerduty.com/incidents/') and + url.endswith('/notes') + ): + # assert the /notes endpoint is never called + raise RuntimeError('This endpoint is not intended to be called') + else: + raise RuntimeError('Misconfigured mock: {}'.format(url)) + + _mock = MagicMock() + _mock.status_code = 200 + _mock.json.return_value = response + return _mock + + mock.side_effect = post + + get_mock.return_value.status_code = 200 + get_mock.return_value.json.side_effect = [json_user, json_lookup] + + # POST /incidents, /v2/enqueue, /incidents/{incident_id}/notes + json_incident = {'incident': {'id': 'incident_id'}} + json_event = {'dedup_key': 'returned_dedup_key'} + setup_post_mock(post_mock, json_incident, json_event) + + # PUT /incidents/{incident_id}/merge + put_mock.return_value.status_code = 200 + + ctx = { + 'pagerduty-incident': { + 'note': None + } + } + + assert_true(self._dispatcher.dispatch(get_alert(context=ctx), self.OUTPUT)) + + log_mock.assert_called_with('Successfully sent alert to %s:%s', + self.SERVICE, self.DESCRIPTOR) + + @patch('logging.Logger.info') + @patch('requests.put') + @patch('requests.post') + @patch('requests.get') + def test_dispatch_success_bad_user(self, get_mock, post_mock, put_mock, log_mock): + """PagerDutyIncidentOutput - Dispatch Success, Bad User""" + # GET /users, /users + json_user = {'users': [{'id': 'user_id'}]} + json_not_user = {'not_users': [{'id': 'user_id'}]} + + # GET /incidents + json_lookup = {'incidents': [{'id': 'incident_id'}]} + + get_mock.return_value.status_code = 200 + get_mock.return_value.json.side_effect = [json_user, json_not_user, json_lookup] + + # POST /incidents, /v2/enqueue, /incidents/incident_id/notes + post_mock.return_value.status_code = 200 + json_incident = {'incident': {'id': 'incident_id'}} + json_event = {'dedup_key': 'returned_dedup_key'} + json_note = {'note': {'id': 'note_id'}} + post_mock.return_value.json.side_effect = [json_incident, json_event, json_note] + + # PUT /incidents/indicent_id/merge + put_mock.return_value.status_code = 200 + + ctx = {'pagerduty-incident': {'assigned_user': 'invalid_user'}} assert_true(self._dispatcher.dispatch(get_alert(context=ctx), self.OUTPUT)) @@ -667,9 +903,10 @@ def test_dispatch_success_no_merge_response(self, get_mock, post_mock, put_mock, post_mock.return_value.status_code = 200 json_incident = {'incident': {'id': 'incident_id'}} json_event = {'dedup_key': 'returned_dedup_key'} - post_mock.return_value.json.side_effect = [json_incident, json_event] + json_note = {'note': {'aa'}} + post_mock.return_value.json.side_effect = [json_incident, json_event, json_note] - # PUT /incidents/indicent_id/merge + # PUT /incidents/{incident_id}/merge put_mock.return_value.status_code = 200 put_mock.return_value.json.return_value = {} @@ -755,7 +992,6 @@ def test_dispatch_bad_dispatch(self, get_mock, post_mock, log_mock): log_mock.assert_called_with('Failed to send alert to %s:%s', self.SERVICE, self.DESCRIPTOR) - @patch('logging.Logger.error') @patch('requests.get') def test_dispatch_bad_email(self, get_mock, log_mock): @@ -776,3 +1012,504 @@ def test_dispatch_bad_descriptor(self, log_mock): self._dispatcher.dispatch(get_alert(), ':'.join([self.SERVICE, 'bad_descriptor']))) 
log_mock.assert_called_with('Failed to send alert to %s:%s', self.SERVICE, 'bad_descriptor') + + +@patch('stream_alert.alert_processor.outputs.output_base.OutputDispatcher.MAX_RETRY_ATTEMPTS', 1) +@patch('stream_alert.alert_processor.outputs.pagerduty.PagerDutyIncidentOutput.BACKOFF_MAX', 0) +@patch('stream_alert.alert_processor.outputs.pagerduty.PagerDutyIncidentOutput.BACKOFF_TIME', 0) +class TestWorkContext(object): + """Test class for WorkContext""" + DESCRIPTOR = 'unit_test_pagerduty-incident' + SERVICE = 'pagerduty-incident' + OUTPUT = ':'.join([SERVICE, DESCRIPTOR]) + CREDS = {'api': 'https://api.pagerduty.com', + 'token': 'mocked_token', + 'service_name': 'mocked_service_name', + 'service_id': 'mocked_service_id', + 'escalation_policy': 'mocked_escalation_policy', + 'escalation_policy_id': 'mocked_escalation_policy_id', + 'email_from': 'email@domain.com', + 'integration_key': 'mocked_key'} + + @patch('stream_alert.alert_processor.outputs.output_base.OutputCredentialsProvider') + def setup(self, provider_constructor): + """Setup before each method""" + provider = MagicMock() + provider_constructor.return_value = provider + provider.load_credentials = Mock( + side_effect=lambda x: self.CREDS if x == self.DESCRIPTOR else None + ) + dispatcher = PagerDutyIncidentOutput(None) + self._work = WorkContext(dispatcher, self.CREDS) + + @patch('requests.get') + def test_get_standardized_priority_sends_correct_request(self, get_mock): + """PagerDutyIncidentOutput - Priority Verify Sends Correct Request""" + priority_name = 'priority_name' + # GET /priorities + get_mock.return_value.status_code = 200 + context = {'incident_priority': priority_name} + + self._work.get_standardized_priority(context) + + get_mock.assert_called_with( + 'https://api.pagerduty.com/priorities', + headers={ + 'From': 'email@domain.com', + 'Content-Type': 'application/json', + 'Authorization': 'Token token=mocked_token', + 'Accept': 'application/vnd.pagerduty+json;version=2' + }, + params=None, + timeout=3.05, + # verify=False # FIXME (derek.wang) Before the refactor this was False. Why? 
+ verify=True + ) + + @patch('requests.get') + def test_get_standardized_priority_success(self, get_mock): + """PagerDutyIncidentOutput - Priority Verify Success""" + priority_name = 'priority_name' + # GET /priorities + get_mock.return_value.status_code = 200 + json_check = {'priorities': [{'id': 'verified_priority_id', 'name': priority_name}]} + get_mock.return_value.json.return_value = json_check + + context = {'incident_priority': priority_name} + + priority_verified = self._work.get_standardized_priority(context) + assert_equal(priority_verified['id'], 'verified_priority_id') + assert_equal(priority_verified['type'], 'priority_reference') + + @patch('requests.get') + def test_get_standardized_priority_fail(self, get_mock): + """PagerDutyIncidentOutput - Priority Verify Fail""" + # GET /priorities + get_mock.return_value.status_code = 404 + + context = {'incident_priority': 'priority_name'} + + priority_not_verified = self._work.get_standardized_priority(context) + assert_equal(priority_not_verified, dict()) + + @patch('requests.get') + def test_get_standardized_priority_empty(self, get_mock): + """PagerDutyIncidentOutput - Priority Verify Empty""" + # GET /priorities + get_mock.return_value.status_code = 200 + json_check = {} + get_mock.return_value.json.return_value = json_check + + context = {'incident_priority': 'priority_name'} + + priority_not_verified = self._work.get_standardized_priority(context) + assert_equal(priority_not_verified, dict()) + + @patch('requests.get') + def test_get_standardized_priority_not_found(self, get_mock): + """PagerDutyIncidentOutput - Priority Verify Not Found""" + # GET /priorities + get_mock.return_value.status_code = 200 + json_check = {'priorities': [{'id': 'verified_priority_id', 'name': 'not_priority_name'}]} + get_mock.return_value.json.return_value = json_check + + context = {'incident_priority': 'priority_name'} + + priority_not_verified = self._work.get_standardized_priority(context) + assert_equal(priority_not_verified, dict()) + + @patch('requests.get') + def test_get_standardized_priority_invalid(self, get_mock): + """PagerDutyIncidentOutput - Priority Verify Invalid""" + # GET /priorities + get_mock.return_value.status_code = 200 + json_check = {'not_priorities': [{'id': 'verified_priority_id', 'name': 'priority_name'}]} + get_mock.return_value.json.return_value = json_check + + context = {'incident_priority': 'priority_name'} + + priority_not_verified = self._work.get_standardized_priority(context) + assert_equal(priority_not_verified, dict()) + + @patch('requests.get') + def test_get_incident_assignment_user_sends_correct_request(self, get_mock): + """PagerDutyIncidentOutput - Incident Assignment User Sends Correct Request""" + context = {'assigned_user': 'user_to_assign'} + get_mock.return_value.status_code = 400 + + self._work.get_incident_assignment(context) + + get_mock.assert_called_with( + 'https://api.pagerduty.com/users', + headers={ + 'Content-Type': 'application/json', + 'Authorization': 'Token token=mocked_token', + 'Accept': 'application/vnd.pagerduty+json;version=2' + }, + params={'query': 'user_to_assign'}, + timeout=3.05, + # verify=False # FIXME (derek.wang) before the refactor, this was False. Why? 
+ verify=True + ) + + @patch('requests.get') + def test_get_incident_assignment_user(self, get_mock): + """PagerDutyIncidentOutput - Incident Assignment User""" + context = {'assigned_user': 'user_to_assign'} + get_mock.return_value.status_code = 200 + json_user = {'users': [{'id': 'verified_user_id'}]} + get_mock.return_value.json.return_value = json_user + + assigned_key, assigned_value = self._work.get_incident_assignment(context) + + assert_equal(assigned_key, 'assignments') + assert_equal(assigned_value[0]['assignee']['id'], 'verified_user_id') + assert_equal(assigned_value[0]['assignee']['type'], 'user_reference') + + def test_get_incident_assignment_policy_no_default(self): + """PagerDutyIncidentOutput - Incident Assignment Policy (No Default)""" + context = {'assigned_policy_id': 'policy_id_to_assign'} + + assigned_key, assigned_value = self._work.get_incident_assignment(context) + + assert_equal(assigned_key, 'escalation_policy') + assert_equal(assigned_value['id'], 'policy_id_to_assign') + assert_equal(assigned_value['type'], 'escalation_policy_reference') + + @patch('requests.get') + def test_user_verify_success(self, get_mock): + """PagerDutyIncidentOutput - User Verify Success""" + get_mock.return_value.status_code = 200 + json_check = {'users': [{'id': 'verified_user_id'}]} + get_mock.return_value.json.return_value = json_check + + user_verified = self._work.verify_user_exists() + assert_true(user_verified) + + @patch('requests.get') + def test_user_verify_fail(self, get_mock): + """PagerDutyIncidentOutput - User Verify Fail""" + get_mock.return_value.status_code = 200 + json_check = {'not_users': [{'not_id': 'verified_user_id'}]} + get_mock.return_value.json.return_value = json_check + + user_verified = self._work.verify_user_exists() + assert_false(user_verified) + + +@patch('stream_alert.alert_processor.outputs.output_base.OutputDispatcher.MAX_RETRY_ATTEMPTS', 1) +@patch('stream_alert.alert_processor.outputs.pagerduty.PagerDutyIncidentOutput.BACKOFF_MAX', 0) +@patch('stream_alert.alert_processor.outputs.pagerduty.PagerDutyIncidentOutput.BACKOFF_TIME', 0) +class TestPagerDutyRestApiClient(object): + + @patch('stream_alert.alert_processor.outputs.output_base.OutputCredentialsProvider') + def setup(self, _): + dispatcher = PagerDutyIncidentOutput(None) + http = JsonHttpProvider(dispatcher) + self._api_client = PagerDutyRestApiClient('mocked_token', 'user@email.com', http) + + @patch('requests.post') + def test_multiple_requests_verify_ssl_once(self, post_mock): + """PagerDutyIncidentOutput - Multiple Requests Verify SSL Once""" + post_mock.return_value.status_code = 200 + + self._api_client.add_note('incident_id', 'this is the note') + self._api_client.add_note('incident_id', 'this is another note') + self._api_client.add_note('incident_id', 'this is a third note') + + class Anything(object): + def __eq__(self, _): + return True + + class VerifyIsCalledWith(object): + def __init__(self, expected_verify_value): + self._expected_verify_value = expected_verify_value + + def __eq__(self, other): + return self._expected_verify_value == other + + post_mock.assert_has_calls( + [ + call( + Anything(), + headers=Anything(), json=Anything(), timeout=Anything(), + verify=VerifyIsCalledWith(True) + ), + call( + Anything(), + headers=Anything(), json=Anything(), timeout=Anything(), + verify=VerifyIsCalledWith(False) + ), + call( + Anything(), + headers=Anything(), json=Anything(), timeout=Anything(), + verify=VerifyIsCalledWith(False) + ), + ], + # So the problem with assert_has_calls() 
is that it requires you to declare all calls + # including chained calls. This doesn't work because we do a bunch of random stuff + # in between with the return value (such as .json() calls) and it's not really feasible + # to declare ALL of the calls. + # + # By setting any_order=True, we ensure all of the above calls are made at least once. + # We lose out on the ability to detect that we called verify=True FIRST (before the + # two verify=False calls)... but, oh well? + any_order=True + ) + + @patch('requests.post') + def test_add_note_incident_sends_correct_request(self, post_mock): + """PagerDutyIncidentOutput - Add Note to Incident Sends Correct Request""" + post_mock.return_value.status_code = 200 + + self._api_client.add_note('incident_id', 'this is the note') + + post_mock.assert_called_with( + 'https://api.pagerduty.com/incidents/incident_id/notes', + headers={ + 'From': 'user@email.com', + 'Content-Type': 'application/json', + 'Authorization': 'Token token=mocked_token', + 'Accept': 'application/vnd.pagerduty+json;version=2' + }, + json={'note': {'content': 'this is the note'}}, + timeout=3.05, + verify=True + ) + + @patch('requests.post') + def test_add_note_incident_success(self, post_mock): + """PagerDutyIncidentOutput - Add Note to Incident Success""" + post_mock.return_value.status_code = 200 + json_note = {'note': {'id': 'created_note_id'}} + post_mock.return_value.json.return_value = json_note + + note = self._api_client.add_note('incident_id', 'this is the note') + + assert_equal(note.get('id'), 'created_note_id') + + @patch('requests.post') + def test_add_note_incident_fail(self, post_mock): + """PagerDutyIncidentOutput - Add Note to Incident Fail""" + post_mock.return_value.status_code = 200 + json_note = {'note': {'not_id': 'created_note_id'}} + post_mock.return_value.json.return_value = json_note + + note = self._api_client.add_note('incident_id', 'this is the note') + + assert_false(note.get('id')) + + @patch('requests.post') + def test_add_note_incident_bad_request(self, post_mock): + """PagerDutyIncidentOutput - Add Note to Incident Bad Request""" + post_mock.return_value.status_code = 400 + json_note = {'note': {'id': 'created_note_id'}} + post_mock.return_value.json.return_value = json_note + + note = self._api_client.add_note('incident_id', 'this is the note') + + assert_false(note) + + @patch('requests.post') + def test_add_note_incident_no_response(self, post_mock): + """PagerDutyIncidentOutput - Add Note to Incident No Response""" + post_mock.return_value.status_code = 200 + json_note = {} + post_mock.return_value.json.return_value = json_note + + note = self._api_client.add_note('incident_id', 'this is the note') + + assert_false(note) + + @patch('requests.get') + def test_get_escalation_policy_sends_correct_request(self, get_mock): + """PagerDutyIncidentOutput - Get Escalation Policies Sends Correct Request""" + get_mock.return_value.status_code = 200 + + self._api_client.get_escalation_policy_by_id('PDUDOHF') + + get_mock.assert_called_with( + 'https://api.pagerduty.com/escalation_policies', + headers={ + 'From': 'user@email.com', 'Content-Type': 'application/json', + 'Authorization': 'Token token=mocked_token', + 'Accept': 'application/vnd.pagerduty+json;version=2' + }, + params={ + 'query': 'PDUDOHF' + }, + timeout=3.05, verify=True + ) + + @patch('requests.get') + def test_get_escalation_policy_success(self, get_mock): + """PagerDutyIncidentOutput - Get Escalation Policies Success""" + get_mock.return_value.status_code = 200 + json_note = {'escalation_policies': 
[{'id': 'PDUDOHF'}]} + get_mock.return_value.json.return_value = json_note + + policy = self._api_client.get_escalation_policy_by_id('PDUDOHF') + + assert_equal(policy.get('id'), 'PDUDOHF') + + +class TestJsonHttpProvider(object): + + def setup(self): + self._dispatcher = MagicMock(spec=OutputDispatcher) + self._http = JsonHttpProvider(self._dispatcher) + + def test_get_sends_correct_arguments(self): + """JsonHttpProvider - Get - Arguments""" + self._http.get( + 'http://airbnb.com', + {'q': 'zz'}, + headers={'Accept': 'application/tofu'}, + verify=True + ) + self._dispatcher._get_request_retry.assert_called_with( + 'http://airbnb.com', + {'q': 'zz'}, + {'Accept': 'application/tofu'}, + True + ) + + def test_get_returns_false_on_error(self): + """JsonHttpProvider - Get - Error""" + self._dispatcher._get_request_retry.side_effect = OutputRequestFailure('?') + assert_false(self._http.get('http://airbnb.com', {'q': 'zz'})) + + def test_post_sends_correct_arguments(self): + """JsonHttpProvider - Post - Arguments""" + self._http.post( + 'http://airbnb.com', + {'q': 'zz'}, + headers={'Accept': 'application/tofu'}, + verify=True + ) + self._dispatcher._post_request_retry.assert_called_with( + 'http://airbnb.com', + {'q': 'zz'}, + {'Accept': 'application/tofu'}, + True + ) + + def test_post_returns_false_on_error(self): + """JsonHttpProvider - Post - Error""" + self._dispatcher._post_request_retry.side_effect = OutputRequestFailure('?') + assert_false(self._http.post('http://airbnb.com', {'q': 'zz'})) + + def test_put_sends_correct_arguments(self): + """JsonHttpProvider - Put - Arguments""" + self._http.put( + 'http://airbnb.com', + {'q': 'zz'}, + headers={'Accept': 'application/tofu'}, + verify=True + ) + self._dispatcher._put_request_retry.assert_called_with( + 'http://airbnb.com', + {'q': 'zz'}, + {'Accept': 'application/tofu'}, + True + ) + + def test_put_returns_false_on_error(self): + """JsonHttpProvider - Put - Error""" + self._dispatcher._put_request_retry.side_effect = OutputRequestFailure('?') + assert_false(self._http.put('http://airbnb.com', {})) + + +class TestWorkContextUnit(object): + """This test class focuses on corner cases instead of top-down testing. + + This class does not mock out entire requests but rather mocks out behavior on the Work class. 
+ """ + + def setup(self): + incident = {'id': 'ABCDEFGH'} + event = {'dedup_key': '000000ppppdpdpdpdpd'} + merged_incident = {'id': '12345678'} + note = {'id': 'notepaid'} + work = WorkContext( + MagicMock( + spec=OutputDispatcher, + __service__='test' + ), + { + 'email_from': 'test@test.test', + 'escalation_policy_id': 'EP123123', + 'service_id': 'SP123123', + 'token': 'zzzzzzzzzz', + 'api': 'https://api.pagerduty.com', + } + ) + work.verify_user_exists = MagicMock(return_value=True) + work._create_base_incident = MagicMock(return_value=incident) + work._create_base_alert_event = MagicMock(return_value=event) + work._merge_event_into_incident = MagicMock(return_value=merged_incident) + work._add_incident_note = MagicMock(return_value=note) + work._add_instability_note = MagicMock(return_value=note) + + self._work = work + + @patch('logging.Logger.error') + @patch('stream_alert.alert_processor.outputs.pagerduty.compose_alert') + def test_positive_case(self, compose_alert_mock, log_error): + """PagerDuty WorkContext - Minimum Positive Case""" + publication = {} + compose_alert_mock.return_value = publication + + alert = get_alert() + result = self._work.run(alert, 'descriptor') + assert_true(result) + + log_error.assert_not_called() + + @patch('logging.Logger.error') + @patch('stream_alert.alert_processor.outputs.pagerduty.compose_alert') + def test_unstable_merge_fail(self, compose_alert_mock, log_error): + """PagerDuty WorkContext - Unstable - Merge Failed""" + publication = {} + compose_alert_mock.return_value = publication + + self._work._merge_event_into_incident = MagicMock(return_value=False) + + alert = get_alert() + result = self._work.run(alert, 'descriptor') + assert_true(result) + + log_error.assert_called_with( + '[%s] Failed to merge alert [%s] into [%s]', 'test', '000000ppppdpdpdpdpd', 'ABCDEFGH' + ) + + @patch('logging.Logger.error') + @patch('stream_alert.alert_processor.outputs.pagerduty.compose_alert') + def test_unstable_note_fail(self, compose_alert_mock, log_error): + """PagerDuty WorkContext - Unstable - Add Node Failed""" + publication = {} + compose_alert_mock.return_value = publication + + self._work._add_incident_note = MagicMock(return_value=False) + + alert = get_alert() + result = self._work.run(alert, 'descriptor') + assert_true(result) + + log_error.assert_called_with( + '[%s] Failed to add note to incident (%s)', 'test', 'ABCDEFGH' + ) + + @patch('stream_alert.alert_processor.outputs.pagerduty.compose_alert') + def test_unstable_adds_instability_note(self, compose_alert_mock): + """PagerDuty WorkContext - Unstable - Add Instability Note""" + publication = {} + compose_alert_mock.return_value = publication + + self._work._add_incident_note = MagicMock(return_value=False) + + alert = get_alert() + result = self._work.run(alert, 'descriptor') + assert_true(result) + + self._work._add_instability_note.assert_called_with('ABCDEFGH') diff --git a/tests/unit/stream_alert_alert_processor/test_outputs/test_slack.py b/tests/unit/stream_alert_alert_processor/test_outputs/test_slack.py index 0079ca068..d80114d1d 100644 --- a/tests/unit/stream_alert_alert_processor/test_outputs/test_slack.py +++ b/tests/unit/stream_alert_alert_processor/test_outputs/test_slack.py @@ -18,6 +18,7 @@ from mock import patch, Mock, MagicMock from nose.tools import assert_equal, assert_false, assert_true, assert_set_equal +from stream_alert.alert_processor.helpers import compose_alert from stream_alert.alert_processor.outputs.slack import SlackOutput from 
tests.unit.stream_alert_alert_processor.helpers import ( get_random_alert, @@ -49,8 +50,9 @@ def test_format_message_single(self): """SlackOutput - Format Single Message - Slack""" rule_name = 'test_rule_single' alert = get_random_alert(25, rule_name) - alert_publication = alert.publish_for(None, None) # FIXME (derek.wang) - loaded_message = SlackOutput._format_message(rule_name, alert_publication) + output = MagicMock(spec=SlackOutput) + alert_publication = compose_alert(alert, output, 'asdf') + loaded_message = SlackOutput._format_message(alert, alert_publication) # tests assert_set_equal(set(loaded_message.keys()), {'text', 'mrkdwn', 'attachments'}) @@ -59,12 +61,107 @@ def test_format_message_single(self): '*StreamAlert Rule Triggered: test_rule_single*') assert_equal(len(loaded_message['attachments']), 1) - def test_format_message_mutliple(self): + def test_format_message_custom_text(self): + """SlackOutput - Format Single Message - Custom Text""" + rule_name = 'test_rule_single' + alert = get_random_alert(25, rule_name) + output = MagicMock(spec=SlackOutput) + alert_publication = compose_alert(alert, output, 'asdf') + alert_publication['@slack.text'] = 'Lorem ipsum foobar' + + loaded_message = SlackOutput._format_message(alert, alert_publication) + + # tests + assert_set_equal(set(loaded_message.keys()), {'text', 'mrkdwn', 'attachments'}) + assert_equal(loaded_message['text'], 'Lorem ipsum foobar') + assert_equal(len(loaded_message['attachments']), 1) + + def test_format_message_custom_attachment(self): + """SlackOutput - Format Message, Custom Attachment""" + rule_name = 'test_empty_rule_description' + alert = get_random_alert(10, rule_name, True) + output = MagicMock(spec=SlackOutput) + alert_publication = compose_alert(alert, output, 'asdf') + alert_publication['@slack.attachments'] = [ + {'text': 'aasdfkjadfj'} + ] + + loaded_message = SlackOutput._format_message(alert, alert_publication) + + # tests + assert_equal(len(loaded_message['attachments']), 1) + assert_equal(loaded_message['attachments'][0]['text'], 'aasdfkjadfj') + + @patch('logging.Logger.warning') + def test_format_message_custom_attachment_limit(self, log_warning): + """SlackOutput - Format Message, Custom Attachment is Truncated""" + rule_name = 'test_empty_rule_description' + alert = get_random_alert(10, rule_name, True) + output = MagicMock(spec=SlackOutput) + alert_publication = compose_alert(alert, output, 'asdf') + + long_message = 'a'*(SlackOutput.MAX_MESSAGE_SIZE + 1) + alert_publication['@slack.attachments'] = [ + {'text': long_message} + ] + + loaded_message = SlackOutput._format_message(alert, alert_publication) + + # tests + assert_equal(len(loaded_message['attachments'][0]['text']), 3999) # bug in elide + log_warning.assert_called_with( + 'Custom attachment was truncated to length %d. 
Full message: %s', + SlackOutput.MAX_MESSAGE_SIZE, + long_message + ) + + def test_format_message_custom_attachment_multi(self): + """SlackOutput - Format Message, Multiple Custom Attachments""" + rule_name = 'test_empty_rule_description' + alert = get_random_alert(10, rule_name, True) + output = MagicMock(spec=SlackOutput) + alert_publication = compose_alert(alert, output, 'asdf') + alert_publication['@slack.attachments'] = [ + {'text': 'attachment text1'}, + {'text': 'attachment text2'}, + ] + + loaded_message = SlackOutput._format_message(alert, alert_publication) + + # tests + assert_equal(len(loaded_message['attachments']), 2) + assert_equal(loaded_message['attachments'][0]['text'], 'attachment text1') + assert_equal(loaded_message['attachments'][1]['text'], 'attachment text2') + + @patch('logging.Logger.warning') + def test_format_message_custom_attachment_multi_limit(self, log_warning): + """SlackOutput - Format Message, Too many Custom Attachments is truncated""" + rule_name = 'test_empty_rule_description' + alert = get_random_alert(10, rule_name, True) + output = MagicMock(spec=SlackOutput) + alert_publication = compose_alert(alert, output, 'asdf') + alert_publication['@slack.attachments'] = [] + for _ in range(SlackOutput.MAX_ATTACHMENTS + 1): + alert_publication['@slack.attachments'].append({'text': 'yay'}) + + loaded_message = SlackOutput._format_message(alert, alert_publication) + + # tests + assert_equal(len(loaded_message['attachments']), SlackOutput.MAX_ATTACHMENTS) + assert_equal(loaded_message['attachments'][19]['text'], 'yay') + log_warning.assert_called_with( + 'Message with %d custom attachments was truncated to %d attachments', + SlackOutput.MAX_ATTACHMENTS + 1, + SlackOutput.MAX_ATTACHMENTS + ) + + def test_format_message_multiple(self): """SlackOutput - Format Multi-Message""" rule_name = 'test_rule_multi-part' alert = get_random_alert(30, rule_name) - alert_publication = alert.publish_for(None, None) # FIXME (derek.wang) - loaded_message = SlackOutput._format_message(rule_name, alert_publication) + output = MagicMock(spec=SlackOutput) + alert_publication = compose_alert(alert, output, 'asdf') + loaded_message = SlackOutput._format_message(alert, alert_publication) # tests assert_set_equal(set(loaded_message.keys()), {'text', 'mrkdwn', 'attachments'}) @@ -76,8 +173,9 @@ def test_format_message_default_rule_description(self): """SlackOutput - Format Message, Default Rule Description""" rule_name = 'test_empty_rule_description' alert = get_random_alert(10, rule_name, True) - alert_publication = alert.publish_for(None, None) # FIXME (derek.wang) - loaded_message = SlackOutput._format_message(rule_name, alert_publication) + output = MagicMock(spec=SlackOutput) + alert_publication = compose_alert(alert, output, 'asdf') + loaded_message = SlackOutput._format_message(alert, alert_publication) # tests default_rule_description = '*Rule Description:*\nNo rule description provided\n' @@ -194,8 +292,9 @@ def test_max_attachments(self, log_mock): """SlackOutput - Max Attachment Reached""" alert = get_alert() alert.record = {'info': 'test' * 20000} - alert_publication = alert.publish_for(None, None) # FIXME (derek.wang) - list(SlackOutput._format_attachments(alert_publication, 'foo')) + output = MagicMock(spec=SlackOutput) + alert_publication = compose_alert(alert, output, 'asdf') + SlackOutput._format_default_attachments(alert, alert_publication, 'foo') log_mock.assert_called_with( '%s: %d-part message truncated to %d parts', alert_publication, diff --git 
a/tests/unit/stream_alert_alert_processor/test_publishers/__init__.py b/tests/unit/stream_alert_alert_processor/test_publishers/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/unit/stream_alert_alert_processor/test_publishers/slack/__init__.py b/tests/unit/stream_alert_alert_processor/test_publishers/slack/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/unit/stream_alert_alert_processor/test_publishers/slack/test_slack_layout.py b/tests/unit/stream_alert_alert_processor/test_publishers/slack/test_slack_layout.py new file mode 100644 index 000000000..e5810a50e --- /dev/null +++ b/tests/unit/stream_alert_alert_processor/test_publishers/slack/test_slack_layout.py @@ -0,0 +1,271 @@ +""" +Copyright 2017-present, Airbnb Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +# pylint: disable=protected-access,attribute-defined-outside-init +from datetime import datetime +from nose.tools import assert_equal, assert_less_equal + +from publishers.community.slack.slack_layout import ( + AttachFullRecord, + AttachPublication, + AttachRuleInfo, + AttachStringTemplate, + Summary +) +from tests.unit.stream_alert_alert_processor.helpers import get_alert + + +class TestSummary(object): + + def setup(self): + self._publisher = Summary() + + def test_simple(self): + """Publishers - Slack - Summary""" + alert = get_alert() + + alert.created = datetime.utcfromtimestamp(1546329600) + + publication = self._publisher.publish(alert, {}) + + expectation = { + '@slack.text': 'Rule triggered', + '@slack._previous_publication': {}, + '@slack.attachments': [ + { + 'author_link': '', + 'color': '#ff5a5f', + 'text': 'Info about this rule and what actions to take', + 'author_name': '', + 'mrkdwn_in': [], + 'thumb_url': '', + 'title': 'cb_binarystore_file_added', + 'footer': '', + 'ts': 1546329600, + 'title_link': ( + 'https://github.com/airbnb/streamalert/search' + '?q=cb_binarystore_file_added+path%3A%2Frules' + ), + 'image_url': '', + 'fallback': 'Rule triggered: cb_binarystore_file_added', + 'author_icon': '', + 'footer_icon': '', + } + ] + } + + assert_equal(publication['@slack.text'], expectation['@slack.text']) + assert_equal( + publication['@slack._previous_publication'], + expectation['@slack._previous_publication'] + ) + assert_equal(len(publication['@slack.attachments']), len(expectation['@slack.attachments'])) + assert_equal( + publication['@slack.attachments'][0].keys(), + expectation['@slack.attachments'][0].keys() + ) + assert_equal(publication['@slack.attachments'][0], expectation['@slack.attachments'][0]) + + +class TestAttachRuleInfo(object): + + def setup(self): + self._publisher = AttachRuleInfo() + + def test_simple(self): + """Publishers - Slack - AttachRuleInfo""" + alert = get_alert() + alert.created = datetime(2019, 1, 1) + alert.rule_description = ''' +Author: unit_test +Reference: somewhere_over_the_rainbow +Description: ? 
+Att&ck vector: Assuming direct control +''' + + publication = self._publisher.publish(alert, {}) + + expectation = { + '@slack.attachments': [ + { + 'color': '#8ce071', + 'fields': [ + { + 'title': 'Att&ck vector', + 'value': 'Assuming direct control', + }, + { + 'title': 'Reference', + 'value': 'somewhere_over_the_rainbow', + } + ] + } + ] + } + + assert_equal(publication, expectation) + + +class TestAttachPublication(object): + + def setup(self): + self._publisher = AttachPublication() + + def test_simple(self): + """Publishers - Slack - AttachPublication""" + alert = get_alert() + alert.created = datetime(2019, 1, 1) + + previous = { + '@slack._previous_publication': {'foo': 'bar'}, + '@slack.attachments': [ + { + 'text': 'attachment1', + }, + ] + } + publication = self._publisher.publish(alert, previous) + + expectation = { + '@slack._previous_publication': {'foo': 'bar'}, + '@slack.attachments': [ + {'text': 'attachment1'}, + { + 'color': '#00d1c1', + 'text': '```\n{\n "foo": "bar"\n}\n```', + 'mrkdwn_in': ['text'], + 'title': 'Alert Data:' + } + ] + } + + assert_equal(publication, expectation) + + +class TestAttachStringTemplate(object): + def setup(self): + self._publisher = AttachStringTemplate() + + def test_from_publication(self): + """Publishers - Slack - AttachStringTemplate - from publication""" + alert = get_alert(context={ + 'slack_message_template': 'Foo {bar} baz {buzz}' + }) + alert.created = datetime(2019, 1, 1) + + publication = self._publisher.publish(alert, {'bar': 'BAR?', 'buzz': 'BUZZ?'}) + + expectation = { + '@slack.attachments': [ + {'color': '#ffb400', 'text': 'Foo BAR? baz BUZZ?'} + ], + 'bar': 'BAR?', + 'buzz': 'BUZZ?', + } + assert_equal(publication, expectation) + + def test_from_previous_publication(self): + """Publishers - Slack - AttachStringTemplate - from previous publication""" + alert = get_alert(context={ + 'slack_message_template': 'Foo {bar} baz {buzz}' + }) + alert.created = datetime(2019, 1, 1) + + publication = self._publisher.publish(alert, { + '@slack._previous_publication': { + 'bar': 'BAR?', 'buzz': 'BUZZ?', + }, + 'bar': 'wrong', + 'buzz': 'wrong', + }) + + expectation = { + '@slack._previous_publication': {'bar': 'BAR?', 'buzz': 'BUZZ?'}, + '@slack.attachments': [{'color': '#ffb400', 'text': 'Foo BAR? 
baz BUZZ?'}], + 'bar': 'wrong', + 'buzz': 'wrong', + } + assert_equal(publication, expectation) + + +class TestAttachFullRecord(object): + + def setup(self): + self._publisher = AttachFullRecord() + + def test_simple(self): + """Publishers - Slack - AttachFullRecord""" + alert = get_alert() + alert.created = datetime(2019, 1, 1) + + publication = self._publisher.publish(alert, {}) + + expectation = { + '@slack.attachments': [ + { + 'footer': 'via ', + 'fields': [ + {'value': '79192344-4a6d-4850-8d06-9c3fef1060a4', 'title': 'Alert Id'} + ], + 'mrkdwn_in': ['text'], + 'author': 'corp-prefix.prod.cb.region', + 'color': '#7b0051', + 'text': ( + '```\n\n{\n "cb_server": "cbserver",\n "compressed_size": "9982",' + '\n "file_path": "/tmp/5DA/AD8/0F9AA55DA3BDE84B35656AD8911A22E1.zip",' + '\n "md5": "0F9AA55DA3BDE84B35656AD8911A22E1",\n "node_id": "1",' + '\n "size": "21504",\n "timestamp": "1496947381.18",' + '\n "type": "binarystore.file.added"\n}\n```' + ), + 'title': 'Record', + 'footer_icon': '' + } + ] + } + assert_equal(publication, expectation) + + def test_record_splitting(self): + """Publishers - Slack - AttachFullRecord - Split Record""" + alert = get_alert() + alert.created = datetime(2019, 1, 1) + + alert.record = { + 'massive_record': [] + } + for index in range(0, 999): + alert.record['massive_record'].append({ + 'index': index, + 'value': 'foo' + }) + + publication = self._publisher.publish(alert, {}) + + attachments = publication['@slack.attachments'] + + assert_equal(len(attachments), 14) + for attachment in attachments: + assert_less_equal(len(attachment['text']), 4000) + + assert_equal(attachments[0]['title'], 'Record') + assert_equal(len(attachments[0]['fields']), 0) + assert_equal(attachments[0]['footer'], '') + + assert_equal(attachments[1]['title'], '') + assert_equal(len(attachments[1]['fields']), 0) + assert_equal(attachments[1]['footer'], '') + + assert_equal(attachments[13]['title'], '') + assert_equal(len(attachments[13]['fields']), 1) + assert_equal(attachments[13]['footer'], 'via ') diff --git a/tests/unit/stream_alert_shared/test_description.py b/tests/unit/stream_alert_shared/test_description.py new file mode 100644 index 000000000..5ed714314 --- /dev/null +++ b/tests/unit/stream_alert_shared/test_description.py @@ -0,0 +1,328 @@ +""" +Copyright 2017-present, Airbnb Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +from nose.tools import assert_equal + +from stream_alert.shared.description import RuleDescriptionParser + + +class TestRuleDescriptionParserParse(object): + + @staticmethod + def test_simple(): + """RuleDescriptionParser - One Field""" + + # Should be able to parse the author out + case = ''' +author: Derek Wang +''' + + data = RuleDescriptionParser.parse(case) + assert_equal(data, {'author': ['Derek Wang'], 'description': []}) + + @staticmethod + def test_strange_spacing(): + """RuleDescriptionParser - Spacing""" + + # This string contains random spaces before and after the author field. 
+ case = ''' + + author: Derek Wang + ''' + + data = RuleDescriptionParser.parse(case) + assert_equal(data, {'author': ['Derek Wang'], 'description': []}) + + @staticmethod + def test_no_fields(): + """RuleDescriptionParser - No Fields""" + case = ''' +This rule has no format and thus the entire + string is considered to be lines of the + description. +''' + + data = RuleDescriptionParser.parse(case) + assert_equal(data, { + 'description': [ + 'This rule has no format and thus the entire', + 'string is considered to be lines of the', + 'description.', + ] + }) + + @staticmethod + def test_misleading_fields(): + """RuleDescriptionParser - Misleading Fields""" + case = ''' + This rule has some colons in it in strange places. For example: right here + But should not have fields because... reasons. + ''' + + data = RuleDescriptionParser.parse(case) + assert_equal(data, { + 'description': [ + 'This rule has some colons in it in strange places. For example: right here', + 'But should not have fields because... reasons.', + ] + }) + + @staticmethod + def test_multiple_fields(): + """RuleDescriptionParser - Multiple Fields""" + case = ''' +author: Derek Wang +owner: Bobby Tables +''' + + data = RuleDescriptionParser.parse(case) + assert_equal(data, {'author': ['Derek Wang'], 'description': [], 'owner': ['Bobby Tables']}) + + @staticmethod + def test_multiple_fields_multiple_lines(): + """RuleDescriptionParser - Multiple Fields and Multiple Lines""" + case = ''' +author: Derek Wang (CSIRT) +reference: There is no cow level + Greed is good +''' + + data = RuleDescriptionParser.parse(case) + assert_equal(data, { + 'author': ['Derek Wang (CSIRT)'], + 'description': [], + 'reference': ['There is no cow level', 'Greed is good'], + }) + + @staticmethod + def test_indentations(): + """RuleDescriptionParser - Indentations""" + case = ''' + author: Derek Wang (CSIRT) + description: Lorem ipsum bacon jalapeno cheeseburger + I'm clearly hungry + Planet pied piper forest windmill +''' + + data = RuleDescriptionParser.parse(case) + assert_equal(data, { + 'author': ['Derek Wang (CSIRT)'], + 'description': [ + 'Lorem ipsum bacon jalapeno cheeseburger', + "I'm clearly hungry", + 'Planet pied piper forest windmill', + ] + }) + + @staticmethod + def test_description_prefix(): + """RuleDescriptionParser - Description Prefix""" + case = ''' +This rule triggers when the temperature of the boiler exceeds 9000 + +author: Derek Wang (CSIRT) +reference: https://www.google.com +''' + + data = RuleDescriptionParser.parse(case) + assert_equal(data, { + 'description': [ + 'This rule triggers when the temperature of the boiler exceeds 9000', + '' + ], + 'author': ['Derek Wang (CSIRT)'], + 'reference': ['https://www.google.com'], + }) + + @staticmethod + def test_special_characters(): + """RuleDescriptionParser - Special characters""" + case = ''' + author: Derek Wang (CSIRT) + + ATT&CK Tactic: Defense Evasion + ATT&CK Technique: Obfuscated Files or Information + ATT&CK URL: https://attack.mitre.org/wiki/Technique/T1027 +''' + + data = RuleDescriptionParser.parse(case) + assert_equal(data, { + 'author': ['Derek Wang (CSIRT)', ''], + 'description': [], + 'att&ck tactic': ['Defense Evasion'], + 'att&ck technique': ['Obfuscated Files or Information'], + 'att&ck url': ['https://attack.mitre.org/wiki/Technique/T1027'], + }) + + +class TestRuleDescriptionParserPresent(object): + + @staticmethod + def test_simple(): + """RuleDescriptionParser - present - One Field""" + case = ''' +author: Derek Wang +''' + + data = 
RuleDescriptionParser.present(case) + assert_equal(data, {'author': 'Derek Wang', 'description': '', 'fields': {}}) + + @staticmethod + def test_multiple_fields_multiple_lines(): + """RuleDescriptionParser - present - Multi Line""" + case = ''' +author: Derek Wang +description: This description + has multiple lines + with inconsistent indentation +''' + + data = RuleDescriptionParser.present(case) + assert_equal(data, { + 'author': 'Derek Wang', + 'description': 'This description has multiple lines with inconsistent indentation', + 'fields': {} + }) + + @staticmethod + def test_fields_with_multiline_urls(): + """RuleDescriptionParser - present - Multi Line Urls""" + case = ''' +author: Derek Wang +description: Lorem ipsum bacon + Cheeseburger +reference: https://www.airbnb.com/ + users/notifications +''' + + data = RuleDescriptionParser.present(case) + assert_equal(data, { + 'author': 'Derek Wang', + 'description': 'Lorem ipsum bacon Cheeseburger', + 'fields': { + 'reference': 'https://www.airbnb.com/users/notifications' + } + }) + + @staticmethod + def test_fields_with_multiline_complex_urls(): + """RuleDescriptionParser - present - Multi Line Complex Urls""" + case = ''' +reference: https://www.airbnb.com/ + users/notifications + ?a=b&$=b20L#hash=value[0] +''' + + data = RuleDescriptionParser.present(case) + assert_equal(data, { + 'author': '', + 'description': '', + 'fields': { + 'reference': 'https://www.airbnb.com/users/notifications?a=b&$=b20L#hash=value[0]' + } + }) + + @staticmethod + def test_fields_with_multiline_invalid_urls(): + """RuleDescriptionParser - present - Do not concat invalid URLs""" + case = ''' +reference: https://www.airbnb.com/users/notifications + Gets concatenated with this line with a space inbetween. +''' + + data = RuleDescriptionParser.present(case) + assert_equal(data, { + 'author': '', + 'description': '', + 'fields': { + 'reference': ( + 'https://www.airbnb.com/users/notifications ' + 'Gets concatenated with this line with a space inbetween.' + ) + } + }) + + @staticmethod + def test_handle_multiple_urls(): + """RuleDescriptionParser - present - Do not use http: as field""" + case = ''' +reference: https://www.airbnb.com/users/notifications + https://www.airbnb.com/account/profile + HTTP URL: https://www.airbnb.com/account/haha +''' + + data = RuleDescriptionParser.present(case) + assert_equal(data, { + 'author': '', + 'description': '', + 'fields': { + 'reference': ( + 'https://www.airbnb.com/users/notifications' + 'https://www.airbnb.com/account/profile' + ), + 'http url': 'https://www.airbnb.com/account/haha' + } + }) + + @staticmethod + def test_two_linebreaks_equals_newline(): + """RuleDescriptionParser - present - Handle linebreaks""" + case = ''' +description: + This is a long description where normal linebreaks like + this one will simply cause the sentence to continue flowing + as normal. + + However a double linebreak will cause a real newline character + to appear in the final product. + + + And double linebreaks cause double newlines. +''' + + data = RuleDescriptionParser.present(case) + assert_equal(data, { + 'author': '', + 'description': ( + 'This is a long description where normal linebreaks like ' + 'this one will simply cause the sentence to continue flowing ' + 'as normal.\n' + 'However a double linebreak will cause a real newline character ' + 'to appear in the final product.' + '\n\n' + 'And double linebreaks cause double newlines.' 
+ ), + 'fields': {} + }) + + @staticmethod + def test_url_plus_string(): + """RuleDescriptionParser - present - URL Plus String""" + case = ''' + description: + https://airbnb.com + + The above url is line broken from this comment. + ''' + + data = RuleDescriptionParser.present(case) + assert_equal(data, { + 'author': '', + 'description': ( + 'https://airbnb.com\n' + 'The above url is line broken from this comment.' + ), + 'fields': {} + }) diff --git a/tests/unit/stream_alert_shared/test_importer.py b/tests/unit/stream_alert_shared/test_importer.py new file mode 100644 index 000000000..1e0e91c85 --- /dev/null +++ b/tests/unit/stream_alert_shared/test_importer.py @@ -0,0 +1,77 @@ +""" +Copyright 2017-present, Airbnb Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +# pylint: disable=no-self-use,protected-access + +from mock import call, patch +from nose.tools import assert_equal, assert_raises +from pyfakefs import fake_filesystem_unittest + +from stream_alert.shared.importer import import_folders, _path_to_module, _python_file_paths + +class RuleImportTest(fake_filesystem_unittest.TestCase): + """Test rule import logic with a mocked filesystem.""" + # pylint: disable=protected-access + + def setUp(self): + self.setUpPyfakefs() + + # Add rules files which should be imported. + self.fs.create_file('rules/matchers/matchers.py') + self.fs.create_file('rules/example.py') + self.fs.create_file('rules/community/cloudtrail/critical_api.py') + + # Add other files which should NOT be imported. + self.fs.create_file('rules/matchers/README.md') + self.fs.create_file('rules/__init__.py') + self.fs.create_file('rules/example.pyc') + self.fs.create_file('rules/community/REVIEWERS') + + @staticmethod + def test_python_rule_paths(): + """Rule - Python File Paths""" + result = set(_python_file_paths('rules')) + expected = { + 'rules/matchers/matchers.py', + 'rules/example.py', + 'rules/community/cloudtrail/critical_api.py' + } + assert_equal(expected, result) + + @staticmethod + def test_path_to_module(): + """Rule - Path to Module""" + assert_equal('name', _path_to_module('name.py')) + assert_equal('a.b.c.name', _path_to_module('a/b/c/name.py')) + + @staticmethod + def test_path_to_module_invalid(): + """Rule - Path to Module, Raises Exception""" + with assert_raises(NameError): + _path_to_module('a.b.py') + + with assert_raises(NameError): + _path_to_module('a/b/old.name.py') + + @staticmethod + @patch('importlib.import_module') + def test_import_rules(mock_import): + """Rule - Import Folders""" + import_folders('rules') + mock_import.assert_has_calls([ + call('rules.matchers.matchers'), + call('rules.example'), + call('rules.community.cloudtrail.critical_api') + ], any_order=True) diff --git a/tests/unit/stream_alert_shared/test_publisher.py b/tests/unit/stream_alert_shared/test_publisher.py new file mode 100644 index 000000000..50b806033 --- /dev/null +++ b/tests/unit/stream_alert_shared/test_publisher.py @@ -0,0 +1,397 @@ +""" +Copyright 2017-present, Airbnb Inc. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +# pylint: disable=protected-access,attribute-defined-outside-init,invalid-name +from mock import patch +from nose.tools import assert_true, assert_equal, assert_false + +from stream_alert.alert_processor.helpers import _assemble_alert_publisher_for_output +from stream_alert.shared.publisher import ( + AlertPublisherRepository, + AlertPublisher, + CompositePublisher, + DefaultPublisher, + Register, + WrappedFunctionPublisher, +) +from tests.unit.stream_alert_alert_processor.helpers import get_alert + + +@Register +class SamplePublisher1(AlertPublisher): + + def publish(self, alert, publication): + new_publication = publication.copy() + new_publication['test1'] = True + return new_publication + + +@Register +class SamplePublisher2(AlertPublisher): + + def publish(self, alert, publication): + new_publication = publication.copy() + new_publication['test2'] = True + return new_publication + + +@Register +class SamplePublisher3(AlertPublisher): + _FIELD = 'test3' + + def publish(self, alert, publication): + new_publication = publication.copy() + new_publication[self._FIELD] = True + return new_publication + + +@Register +class SamplePublisher4(SamplePublisher3): + _FIELD = 'test4' + + +@Register +def sample_publisher_5(_, publication): + new_publication = publication.copy() + new_publication['test4'] = True + return new_publication + + +@Register +def sample_publisher_blank(*_): + return {} + + +class TestRegister(object): + + @staticmethod + def test_register_works_properly(): + """AlertPublisher - @Register - Works properly""" + assert_true(AlertPublisherRepository.has_publisher( + AlertPublisherRepository.get_publisher_name(SamplePublisher1) + )) + + +class TestCompositePublisher(object): + + @staticmethod + def test_composite_publisher_ordering(): + """CompositePublisher - Ensure publishers executed in correct order""" + publisher = CompositePublisher([ + SamplePublisher1(), + WrappedFunctionPublisher(sample_publisher_blank), + SamplePublisher2(), + ]) + + alert = get_alert() + publication = publisher.publish(alert, {}) + + expectation = {'test2': True} + assert_equal(publication, expectation) + + +class TestWrappedFunctionPublisher(object): + + @staticmethod + def test_wrapped_function_publisher(): + """WrappedFunctionPublisher - Ensure function is executed properly""" + publisher = WrappedFunctionPublisher(sample_publisher_5) + + alert = get_alert() + publication = publisher.publish(alert, {}) + + expectation = {'test4': True} + assert_equal(publication, expectation) + + +class TestAlertPublisherRepository(object): + + @staticmethod + def test_is_valid_publisher_class(): + """AlertPublisherRepository - is_valid_publisher() - Class""" + assert_true(AlertPublisherRepository.is_valid_publisher(SamplePublisher1)) + + @staticmethod + def test_is_valid_publisher_function(): + """AlertPublisherRepository - is_valid_publisher() - Function""" + assert_true(AlertPublisherRepository.is_valid_publisher(sample_publisher_5)) + + @staticmethod + def 
test_is_valid_publisher_invalid(): + """AlertPublisherRepository - is_valid_publisher() - Invalid""" + assert_false(AlertPublisherRepository.is_valid_publisher('aaa')) + + @staticmethod + def test_get_publisher_name_class(): + """AlertPublisherRepository - get_publisher_name() - Class""" + + name = AlertPublisherRepository.get_publisher_name(SamplePublisher1) + assert_equal( + name, + 'tests.unit.stream_alert_shared.test_publisher.SamplePublisher1' + ) + + @staticmethod + def test_get_publisher_name_function(): + """AlertPublisherRepository - get_publisher_name() - Function""" + + name = AlertPublisherRepository.get_publisher_name(sample_publisher_5) + assert_equal( + name, + 'tests.unit.stream_alert_shared.test_publisher.sample_publisher_5' + ) + + @staticmethod + def test_registers_default_publishers(): + """AlertPublisher - AlertPublisherRepository - all_publishers()""" + publishers = AlertPublisherRepository.all_publishers() + + assert_true(len(publishers) > 0) + + @staticmethod + def test_has_publisher(): + """AlertPublisher - AlertPublisherRepository - has_publisher() - SamplePublisher1""" + assert_true(AlertPublisherRepository.has_publisher( + 'tests.unit.stream_alert_shared.test_publisher.SamplePublisher1' + )) + + @staticmethod + def test_get_publisher(): + """AlertPublisher - AlertPublisherRepository - get_publisher() - SamplePublisher1""" + publisher = AlertPublisherRepository.get_publisher( + 'tests.unit.stream_alert_shared.test_publisher.SamplePublisher1' + ) + + assert_true(isinstance(publisher, SamplePublisher1)) + + @staticmethod + def test_create_composite_publisher(): + """AlertPublisher - AlertPublisherRepository - create_composite_publisher() - Valid""" + publisher = AlertPublisherRepository.create_composite_publisher([ + 'tests.unit.stream_alert_shared.test_publisher.SamplePublisher1', + 'tests.unit.stream_alert_shared.test_publisher.SamplePublisher2', + ]) + + assert_true(isinstance(publisher, CompositePublisher)) + assert_equal(len(publisher._publishers), 2) + assert_true(isinstance(publisher._publishers[0], SamplePublisher1)) + assert_true(isinstance(publisher._publishers[1], SamplePublisher2)) + + @staticmethod + def test_create_composite_publisher_default(): + """AlertPublisher - AlertPublisherRepository - create_composite_publisher() - Default""" + publisher = AlertPublisherRepository.create_composite_publisher([]) + + assert_true(isinstance(publisher, DefaultPublisher)) + + @staticmethod + @patch('logging.Logger.error') + def test_create_composite_publisher_noexist(error_log): + """AlertPublisher - AlertPublisherRepository - create_composite_publisher() - No Exist""" + publisher = AlertPublisherRepository.create_composite_publisher(['no_exist']) + + assert_true(isinstance(publisher, DefaultPublisher)) + error_log.assert_called_with('Publisher [%s] does not exist', 'no_exist') + + +class TestAlertPublisherRepositoryAssemblePublisher(object): + def setup(self): + self._alert = get_alert(context={'this_context': 'that_value'}) + self._descriptor = 'some_descriptor' + self._output = 'demisto' + + def test_assemble_alert_publisher_for_output_none(self): + """AlertPublisher - AlertPublisherRepository - assemble() - None""" + self._alert.publishers = None + + publisher = _assemble_alert_publisher_for_output( + self._alert, + self._output, + self._descriptor + ) + + assert_true(isinstance(publisher, DefaultPublisher)) + + def test_assemble_alert_publisher_for_output_single_string(self): + """AlertPublisher - AlertPublisherRepository - assemble() - String""" + 
self._alert.publishers = 'tests.unit.stream_alert_shared.test_publisher.SamplePublisher1' + + publisher = _assemble_alert_publisher_for_output( + self._alert, + self._output, + self._descriptor + ) + + assert_true(isinstance(publisher, CompositePublisher)) + assert_equal(len(publisher._publishers), 1) + assert_true(isinstance(publisher._publishers[0], SamplePublisher1)) + + def test_assemble_alert_publisher_for_output_list_string(self): + """AlertPublisher - AlertPublisherRepository - assemble() - List of Strings""" + self._alert.publishers = [ + 'tests.unit.stream_alert_shared.test_publisher.SamplePublisher1', + 'tests.unit.stream_alert_shared.test_publisher.SamplePublisher2', + ] + + publisher = _assemble_alert_publisher_for_output( + self._alert, + self._output, + self._descriptor + ) + + assert_true(isinstance(publisher, CompositePublisher)) + assert_equal(len(publisher._publishers), 2) + assert_true(isinstance(publisher._publishers[0], SamplePublisher1)) + assert_true(isinstance(publisher._publishers[1], SamplePublisher2)) + + def test_assemble_alert_publisher_for_output_dict_empty(self): + """AlertPublisher - AlertPublisherRepository - assemble() - Empty Dict""" + self._alert.publishers = {} + + publisher = _assemble_alert_publisher_for_output( + self._alert, + self._output, + self._descriptor + ) + + assert_true(isinstance(publisher, DefaultPublisher)) + + def test_assemble_alert_publisher_for_output_dict_irrelevant_key(self): + """AlertPublisher - AlertPublisherRepository - assemble() - Dict with Irrelevant Key""" + self._alert.publishers = { + 'pagerduty': [ + 'tests.unit.stream_alert_shared.test_publisher.SamplePublisher1' + ] + } + + publisher = _assemble_alert_publisher_for_output( + self._alert, + self._output, + self._descriptor + ) + + assert_true(isinstance(publisher, DefaultPublisher)) + + def test_assemble_alert_publisher_for_output_dict_key_string(self): + """AlertPublisher - AlertPublisherRepository - assemble() - Dict with Key -> String""" + self._alert.publishers = { + 'demisto': 'tests.unit.stream_alert_shared.test_publisher.SamplePublisher1', + 'pagerduty': [ + 'tests.unit.stream_alert_shared.test_publisher.SamplePublisher2' + ] + } + + publisher = _assemble_alert_publisher_for_output( + self._alert, + self._output, + self._descriptor + ) + + assert_true(isinstance(publisher, CompositePublisher)) + assert_equal(len(publisher._publishers), 1) + assert_true(isinstance(publisher._publishers[0], SamplePublisher1)) + + def test_assemble_alert_publisher_for_output_dict_key_array(self): + """AlertPublisher - AlertPublisherRepository - assemble() - Dict with Key -> List""" + self._alert.publishers = { + 'demisto': [ + 'tests.unit.stream_alert_shared.test_publisher.SamplePublisher1', + 'tests.unit.stream_alert_shared.test_publisher.SamplePublisher2', + ], + 'pagerduty': [ + 'tests.unit.stream_alert_shared.test_publisher.SamplePublisher3' + ], + } + + publisher = _assemble_alert_publisher_for_output( + self._alert, + self._output, + self._descriptor + ) + + assert_true(isinstance(publisher, CompositePublisher)) + assert_equal(len(publisher._publishers), 2) + + def test_assemble_alert_publisher_for_output_dict_key_descriptor_string(self): + """AlertPublisher - AlertPublisherRepository - assemble() - Dict matches Desc String""" + self._alert.publishers = { + 'demisto:some_descriptor': ( + 'tests.unit.stream_alert_shared.test_publisher.SamplePublisher1' + ), + 'pagerduty': [ + 'tests.unit.stream_alert_shared.test_publisher.SamplePublisher2' + ], + } + + publisher = 
_assemble_alert_publisher_for_output( + self._alert, + self._output, + self._descriptor + ) + + assert_true(isinstance(publisher, CompositePublisher)) + assert_equal(len(publisher._publishers), 1) + + def test_assemble_alert_publisher_for_output_dict_key_descriptor_list(self): + """AlertPublisher - AlertPublisherRepository - assemble() - Dict matches Desc List""" + self._alert.publishers = { + 'demisto:some_descriptor': [ + 'tests.unit.stream_alert_shared.test_publisher.SamplePublisher1', + 'tests.unit.stream_alert_shared.test_publisher.SamplePublisher2', + ], + 'pagerduty': [ + 'tests.unit.stream_alert_shared.test_publisher.SamplePublisher3', + ] + } + + publisher = _assemble_alert_publisher_for_output( + self._alert, + self._output, + self._descriptor + ) + + assert_true(isinstance(publisher, CompositePublisher)) + assert_equal(len(publisher._publishers), 2) + + def test_assemble_alert_publisher_for_output_dict_key_both_descriptor_output_list(self): + """AlertPublisher - AlertPublisherRepository - assemble() - Dict full match Lists""" + self._alert.publishers = { + 'demisto': [ + 'tests.unit.stream_alert_shared.test_publisher.SamplePublisher1', + 'tests.unit.stream_alert_shared.test_publisher.SamplePublisher2', + ], + 'demisto:some_descriptor': [ + 'tests.unit.stream_alert_shared.test_publisher.SamplePublisher3', + 'tests.unit.stream_alert_shared.test_publisher.SamplePublisher4', + ], + 'pagerduty': [ + 'tests.unit.stream_alert_shared.test_publisher.sample_publisher_5', + ] + } + + publisher = _assemble_alert_publisher_for_output( + self._alert, + self._output, + self._descriptor + ) + + assert_true(isinstance(publisher, CompositePublisher)) + assert_equal(len(publisher._publishers), 4) + + # Order is important; the generic ones are loaded first then the specific ones are last + assert_true(isinstance(publisher._publishers[0], SamplePublisher1)) + assert_true(isinstance(publisher._publishers[1], SamplePublisher2)) + assert_true(isinstance(publisher._publishers[2], SamplePublisher3)) + assert_true(isinstance(publisher._publishers[3], SamplePublisher4)) diff --git a/tests/unit/stream_alert_shared/test_rule.py b/tests/unit/stream_alert_shared/test_rule.py index c10246546..4d19f1182 100644 --- a/tests/unit/stream_alert_shared/test_rule.py +++ b/tests/unit/stream_alert_shared/test_rule.py @@ -16,9 +16,8 @@ # pylint: disable=no-self-use,protected-access import hashlib -from mock import call, patch -from nose.tools import assert_equal, assert_raises, raises -from pyfakefs import fake_filesystem_unittest +from mock import patch +from nose.tools import assert_equal, raises from stream_alert.shared import rule, rule_table @@ -291,59 +290,3 @@ def test_get_rules_for_log_type(self): result = rule.Rule.rules_for_log_type('log_type_03') assert_equal(len(result), 1) assert_equal(result[0].name, 'rule_04') - - -class RuleImportTest(fake_filesystem_unittest.TestCase): - """Test rule import logic with a mocked filesystem.""" - # pylint: disable=protected-access - - def setUp(self): - self.setUpPyfakefs() - - # Add rules files which should be imported. - self.fs.create_file('rules/matchers/matchers.py') - self.fs.create_file('rules/example.py') - self.fs.create_file('rules/community/cloudtrail/critical_api.py') - - # Add other files which should NOT be imported. 
-        self.fs.create_file('rules/matchers/README.md')
-        self.fs.create_file('rules/__init__.py')
-        self.fs.create_file('rules/example.pyc')
-        self.fs.create_file('rules/community/REVIEWERS')
-
-    @staticmethod
-    def test_python_rule_paths():
-        """Rule - Python File Paths"""
-        result = set(rule._python_file_paths('rules'))
-        expected = {
-            'rules/matchers/matchers.py',
-            'rules/example.py',
-            'rules/community/cloudtrail/critical_api.py'
-        }
-        assert_equal(expected, result)
-
-    @staticmethod
-    def test_path_to_module():
-        """Rule - Path to Module"""
-        assert_equal('name', rule._path_to_module('name.py'))
-        assert_equal('a.b.c.name', rule._path_to_module('a/b/c/name.py'))
-
-    @staticmethod
-    def test_path_to_module_invalid():
-        """Rule - Path to Module, Raises Exception"""
-        with assert_raises(NameError):
-            rule._path_to_module('a.b.py')
-
-        with assert_raises(NameError):
-            rule._path_to_module('a/b/old.name.py')
-
-    @staticmethod
-    @patch('importlib.import_module')
-    def test_import_rules(mock_import):
-        """Rule - Import Folders"""
-        rule.import_folders('rules')
-        mock_import.assert_has_calls([
-            call('rules.matchers.matchers'),
-            call('rules.example'),
-            call('rules.community.cloudtrail.critical_api')
-        ], any_order=True)
diff --git a/tests/unit/streamalert/rules_engine/test_rules_engine.py b/tests/unit/streamalert/rules_engine/test_rules_engine.py
index 8947c2ac7..cb23b2645 100644
--- a/tests/unit/streamalert/rules_engine/test_rules_engine.py
+++ b/tests/unit/streamalert/rules_engine/test_rules_engine.py
@@ -13,16 +13,46 @@
 See the License for the specific language governing permissions and
 limitations under the License.
 """
+# pylint: disable=invalid-name
+
 from datetime import datetime, timedelta
 
-# from botocore.exceptions import ClientError, ParamValidationError
 from mock import Mock, patch, PropertyMock
 from nose.tools import assert_equal
 
+from publishers.community.generic import remove_internal_fields
+from stream_alert.shared.publisher import AlertPublisher, Register, DefaultPublisher
 import stream_alert.rules_engine.rules_engine as rules_engine_module
 from stream_alert.rules_engine.rules_engine import RulesEngine
 
 
+def mock_conf():
+    return {
+        'global': {
+            'general': {
+                'rule_locations': [],
+                'matcher_locations': []
+            },
+            'infrastructure': {
+                'rule_staging': {
+                    'enabled': True
+                }
+            }
+        }
+    }
+
+
+@Register
+def that_publisher(_, __):
+    return {}
+
+
+@Register
+class ThisPublisher(AlertPublisher):
+    def publish(self, alert, publication):
+        return {}
+
+
 # Without this time.sleep patch, backoff performs sleep
 # operations and drastically slows down testing
 # @patch('time.sleep', Mock())
@@ -37,7 +67,7 @@ def setup(self):
             patch.object(rules_engine_module, 'ThreatIntel'), \
             patch.dict('os.environ', {'STREAMALERT_PREFIX': 'test_prefix'}), \
             patch('stream_alert.rules_engine.rules_engine.load_config',
-                  Mock(return_value=self._mock_conf())):
+                  Mock(return_value=mock_conf())):
             self._rules_engine = RulesEngine()
 
     def teardown(self):
@@ -49,27 +79,12 @@ def teardown(self):
         RulesEngine._alert_forwarder = None
         RulesEngine._RULE_TABLE_LAST_REFRESH = datetime(year=1970, month=1, day=1)
 
-    @classmethod
-    def _mock_conf(cls):
-        return {
-            'global': {
-                'general': {
-                    'rule_locations': [],
-                    'matcher_locations': []
-                },
-                'infrastructure': {
-                    'rule_staging': {
-                        'enabled': True
-                    }
-                }
-            }
-        }
-
     def test_load_rule_table_disabled(self):
         """RulesEngine - Load Rule Table, Disabled"""
         RulesEngine._rule_table = None
         RulesEngine._RULE_TABLE_LAST_REFRESH = datetime(year=1970, month=1, day=1)
-        config = self._mock_conf()
+        config = mock_conf()
         config['global']['infrastructure']['rule_staging']['enabled'] = False
         RulesEngine._load_rule_table(config)
         assert_equal(RulesEngine._rule_table, None)
@@ -78,7 +93,7 @@ def test_load_rule_table_disabled(self):
     @patch('logging.Logger.debug')
     def test_load_rule_table_no_refresh(self, log_mock):
         """RulesEngine - Load Rule Table, No Refresh"""
-        config = self._mock_conf()
+        config = mock_conf()
         RulesEngine._RULE_TABLE_LAST_REFRESH = datetime.utcnow()
         RulesEngine._rule_table = 'table'
         self._rules_engine._load_rule_table(config)
@@ -89,7 +104,7 @@ def test_load_rule_table_no_refresh(self, log_mock):
     @patch('logging.Logger.info')
     def test_load_rule_table_refresh(self, log_mock):
         """RulesEngine - Load Rule Table, Refresh"""
-        config = self._mock_conf()
+        config = mock_conf()
         config['global']['infrastructure']['rule_staging']['cache_refresh_minutes'] = 5
 
         fake_date_now = datetime.utcnow()
@@ -177,6 +192,8 @@ def test_process_subkeys(self):
         result = RulesEngine._process_subkeys(record, rule)
         assert_equal(result, True)
 
+    # -- Tests for _rule_analysis()
+
     def test_rule_analysis(self):
         """RulesEngine - Rule Analysis"""
         rule = Mock(
@@ -184,6 +201,7 @@ def test_rule_analysis(self):
             is_staged=Mock(return_value=False),
             outputs_set={'slack:test'},
             description='rule description',
+            publishers=None,
             context=None,
             merge_by_keys=None,
             merge_window_mins=0
@@ -211,6 +229,7 @@ def test_rule_analysis(self):
                 log_type='json',
                 merge_by_keys=None,
                 merge_window=timedelta(minutes=0),
+                publishers=None,
                 rule_description='rule description',
                 source_entity='test_stream',
                 source_service='kinesis',
@@ -226,6 +245,7 @@ def test_rule_analysis_staged(self):
             is_staged=Mock(return_value=True),
             outputs_set={'slack:test'},
             description='rule description',
+            publishers=None,
             context=None,
             merge_by_keys=None,
             merge_window_mins=0
@@ -253,6 +273,7 @@ def test_rule_analysis_staged(self):
                 log_type='json',
                 merge_by_keys=None,
                 merge_window=timedelta(minutes=0),
+                publishers=None,
                 rule_description='rule description',
                 source_entity='test_stream',
                 source_service='kinesis',
@@ -269,6 +290,203 @@ def test_rule_analysis_false(self):
         result = self._rules_engine._rule_analysis({'record': {'foo': 'bar'}}, rule)
         assert_equal(result is None, True)
 
+    def test_rule_analysis_with_publishers(self):
+        """RulesEngine - Rule Analysis, Publishers"""
+        rule = Mock(
+            process=Mock(return_value=True),
+            is_staged=Mock(return_value=False),
+            outputs_set={'slack:test', 'demisto:test'},
+            description='rule description',
+            publishers={
+                'demisto': 'stream_alert.shared.publisher.DefaultPublisher',
+                'slack': [that_publisher],
+                'slack:test': [ThisPublisher],
+            },
+            context=None,
+            merge_by_keys=None,
+            merge_window_mins=0
+        )
+
+        # Override the Mock name attribute
+        type(rule).name = PropertyMock(return_value='test_rule')
+        record = {'foo': 'bar'}
+        payload = {
+            'cluster': 'prod',
+            'log_schema_type': 'log_type',
+            'data_type': 'json',
+            'resource': 'test_stream',
+            'service': 'kinesis',
+            'record': record
+        }
+
+        with patch.object(rules_engine_module, 'Alert') as alert_mock:
+            result = self._rules_engine._rule_analysis(payload, rule)
+            alert_mock.assert_called_with(
+                'test_rule', record, {'aws-firehose:alerts', 'slack:test', 'demisto:test'},
+                cluster='prod',
+                context=None,
+                log_source='log_type',
+                log_type='json',
+                merge_by_keys=None,
+                merge_window=timedelta(minutes=0),
+                publishers={
+                    'slack:test': [
+                        'tests.unit.streamalert.rules_engine.test_rules_engine.that_publisher',
+                        'tests.unit.streamalert.rules_engine.test_rules_engine.ThisPublisher',
+                    ],
+                    'demisto:test': [
+                        'stream_alert.shared.publisher.DefaultPublisher'
+                    ],
+                },
+                rule_description='rule description',
+                source_entity='test_stream',
+                source_service='kinesis',
+                staged=False
+            )
+
+            assert_equal(result is not None, True)
+
+    # --- Tests for _configure_publishers()
+
+    def test_configure_publishers_empty(self):
+        """RulesEngine - _configure_publishers, Empty"""
+        rule = Mock(
+            outputs_set={'slack:test'},
+            publishers=None,
+        )
+
+        publishers = self._rules_engine._configure_publishers(rule)
+        expectation = None
+
+        assert_equal(publishers, expectation)
+
+    def test_configure_publishers_single_string(self):
+        """RulesEngine - _configure_publishers, Single string"""
+        rule = Mock(
+            outputs_set={'slack:test'},
+            publishers='stream_alert.shared.publisher.DefaultPublisher'
+        )
+
+        publishers = self._rules_engine._configure_publishers(rule)
+        expectation = {'slack:test': ['stream_alert.shared.publisher.DefaultPublisher']}
+
+        assert_equal(publishers, expectation)
+
+    def test_configure_publishers_single_reference(self):
+        """RulesEngine - _configure_publishers, Single reference"""
+        rule = Mock(
+            outputs_set={'slack:test'},
+            publishers=DefaultPublisher
+        )
+
+        publishers = self._rules_engine._configure_publishers(rule)
+        expectation = {'slack:test': ['stream_alert.shared.publisher.DefaultPublisher']}
+
+        assert_equal(publishers, expectation)
+
+    @patch('logging.Logger.warning')
+    def test_configure_publishers_single_invalid_string(self, log_warn):
+        """RulesEngine - _configure_publishers, Invalid string"""
+        rule = Mock(
+            outputs_set={'slack:test'},
+            publishers='blah'
+        )
+
+        publishers = self._rules_engine._configure_publishers(rule)
+        expectation = {'slack:test': []}
+
+        assert_equal(publishers, expectation)
+        log_warn.assert_called_with('Requested publisher named (%s) is not registered.', 'blah')
+
+    @patch('logging.Logger.error')
+    def test_configure_publishers_single_invalid_object(self, log_error):
+        """RulesEngine - _configure_publishers, Invalid object"""
+        rule = Mock(
+            outputs_set={'slack:test'},
+            publishers=self  # just some random object that's not a publisher
+        )
+
+        publishers = self._rules_engine._configure_publishers(rule)
+        expectation = {'slack:test': []}
+
+        assert_equal(publishers, expectation)
+        log_error.assert_called_with('Invalid publisher argument: %s', self)
+
+    def test_configure_publishers_single_applies_to_multiple_outputs(self):
+        """RulesEngine - _configure_publishers, Multiple outputs"""
+        rule = Mock(
+            outputs_set={'slack:test', 'demisto:test', 'pagerduty:test'},
+            publishers=DefaultPublisher
+        )
+
+        publishers = self._rules_engine._configure_publishers(rule)
+        expectation = {
+            'slack:test': ['stream_alert.shared.publisher.DefaultPublisher'],
+            'demisto:test': ['stream_alert.shared.publisher.DefaultPublisher'],
+            'pagerduty:test': ['stream_alert.shared.publisher.DefaultPublisher'],
+        }
+
+        assert_equal(publishers, expectation)
+
+    def test_configure_publishers_list(self):
+        """RulesEngine - _configure_publishers, List"""
+        rule = Mock(
+            outputs_set={'slack:test'},
+            publishers=[DefaultPublisher, remove_internal_fields]
+        )
+
+        publishers = self._rules_engine._configure_publishers(rule)
+        expectation = {'slack:test': [
+            'stream_alert.shared.publisher.DefaultPublisher',
+            'publishers.community.generic.remove_internal_fields',
+        ]}
+
+        assert_equal(publishers, expectation)
+
+    def test_configure_publishers_mixed_list(self):
+        """RulesEngine - _configure_publishers, Mixed List"""
+        rule = Mock(
+            outputs_set={'slack:test', 'demisto:test'},
+            publishers={
+                'demisto': 'stream_alert.shared.publisher.DefaultPublisher',
+                'slack': [that_publisher],
+                'slack:test': [ThisPublisher],
+            },
+        )
+
+        publishers = self._rules_engine._configure_publishers(rule)
+        expectation = {
+            'slack:test': [
+                'tests.unit.streamalert.rules_engine.test_rules_engine.that_publisher',
+                'tests.unit.streamalert.rules_engine.test_rules_engine.ThisPublisher'
+            ],
+            'demisto:test': ['stream_alert.shared.publisher.DefaultPublisher']
+        }
+
+        assert_equal(publishers, expectation)
+
+    def test_configure_publishers_mixed_single(self):
+        """RulesEngine - _configure_publishers, Mixed Single"""
+        rule = Mock(
+            outputs_set={'slack:test', 'demisto:test'},
+            publishers={
+                'demisto': 'stream_alert.shared.publisher.DefaultPublisher',
+                'slack': that_publisher,
+                'slack:test': ThisPublisher,
+            },
+        )
+
+        publishers = self._rules_engine._configure_publishers(rule)
+        expectation = {
+            'slack:test': [
+                'tests.unit.streamalert.rules_engine.test_rules_engine.that_publisher',
+                'tests.unit.streamalert.rules_engine.test_rules_engine.ThisPublisher',
+            ],
+            'demisto:test': ['stream_alert.shared.publisher.DefaultPublisher']
+        }
+
+        assert_equal(publishers, expectation)
+
     def test_run_subkey_failure(self):
         """RulesEngine - Run, Fail Subkey Check"""
         self._rules_engine._threat_intel = None