Skip to content

Commit

Permalink
Merge branch 'main' into test_on_app_change
Browse files Browse the repository at this point in the history
  • Loading branch information
ljstella authored Oct 16, 2024
2 parents 1f15302 + cfda377 commit 23f3742
Show file tree
Hide file tree
Showing 18 changed files with 385 additions and 222 deletions.
25 changes: 14 additions & 11 deletions contentctl/actions/detection_testing/GitService.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,10 +68,12 @@ def getChanges(self, target_branch:str)->List[Detection]:

#Make a filename to content map
filepath_to_content_map = { obj.file_path:obj for (_,obj) in self.director.name_to_content_map.items()}
updated_detections:List[Detection] = []
updated_macros:List[Macro] = []
updated_lookups:List[Lookup] =[]
updated_datasources:List[DataSource] = []

updated_detections:List[Detection] = set()
updated_macros:List[Macro] = set()
updated_lookups:List[Lookup] = set()
updated_datasources:List[DataSource] = set()


for diff in all_diffs:
if type(diff) == pygit2.Patch:
Expand All @@ -82,14 +84,14 @@ def getChanges(self, target_branch:str)->List[Detection]:
if decoded_path.is_relative_to(self.config.path/"detections") and decoded_path.suffix == ".yml":
detectionObject = filepath_to_content_map.get(decoded_path, None)
if isinstance(detectionObject, Detection):
updated_detections.append(detectionObject)
updated_detections.add(detectionObject)
else:
raise Exception(f"Error getting detection object for file {str(decoded_path)}")

elif decoded_path.is_relative_to(self.config.path/"macros") and decoded_path.suffix == ".yml":
macroObject = filepath_to_content_map.get(decoded_path, None)
if isinstance(macroObject, Macro):
updated_macros.append(macroObject)
updated_macros.add(macroObject)
else:
raise Exception(f"Error getting macro object for file {str(decoded_path)}")

Expand All @@ -107,7 +109,7 @@ def getChanges(self, target_branch:str)->List[Detection]:
updatedLookup = filepath_to_content_map.get(decoded_path, None)
if not isinstance(updatedLookup,Lookup):
raise Exception(f"Expected {decoded_path} to be type {type(Lookup)}, but instead if was {(type(updatedLookup))}")
updated_lookups.append(updatedLookup)
updated_lookups.add(updatedLookup)

elif decoded_path.suffix == ".csv":
# If the CSV was updated, we want to make sure that we
Expand All @@ -133,7 +135,7 @@ def getChanges(self, target_branch:str)->List[Detection]:
if updatedLookup is not None and updatedLookup not in updated_lookups:
# It is possible that both the CSV and YML have been modified for the same lookup,
# and we do not want to add it twice.
updated_lookups.append(updatedLookup)
updated_lookups.add(updatedLookup)

else:
pass
Expand All @@ -144,7 +146,8 @@ def getChanges(self, target_branch:str)->List[Detection]:

# If a detection has at least one dependency on changed content,
# then we must test it again
changed_macros_and_lookups_and_datasources = updated_macros + updated_lookups + updated_datasources

changed_macros_and_lookups_and_datasources:set[SecurityContentObject] = updated_macros.union(updated_lookups, updated_datasources)

for detection in self.director.detections:
if detection in updated_detections:
Expand All @@ -154,14 +157,14 @@ def getChanges(self, target_branch:str)->List[Detection]:

for obj in changed_macros_and_lookups_and_datasources:
if obj in detection.get_content_dependencies():
updated_detections.append(detection)
updated_detections.add(detection)
break

#Print out the names of all modified/new content
modifiedAndNewContentString = "\n - ".join(sorted([d.name for d in updated_detections]))

print(f"[{len(updated_detections)}] Pieces of modifed and new content (this may include experimental/deprecated/manual_test content):\n - {modifiedAndNewContentString}")
return updated_detections
return sorted(list(updated_detections))

def getSelected(self, detectionFilenames: List[FilePath]) -> List[Detection]:
filepath_to_content_map: dict[FilePath, SecurityContentObject] = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from shutil import copyfile
from typing import Union, Optional

from pydantic import BaseModel, PrivateAttr, Field, dataclasses
from pydantic import ConfigDict, BaseModel, PrivateAttr, Field, dataclasses
import requests # type: ignore
import splunklib.client as client # type: ignore
from splunklib.binding import HTTPError # type: ignore
Expand Down Expand Up @@ -48,9 +48,9 @@ class SetupTestGroupResults(BaseModel):
success: bool = True
duration: float = 0
start_time: float

class Config:
arbitrary_types_allowed = True
model_config = ConfigDict(
arbitrary_types_allowed=True
)


class CleanupTestGroupResults(BaseModel):
Expand All @@ -68,14 +68,23 @@ class CannotRunBaselineException(Exception):
# exception
pass

class ReplayIndexDoesNotExistOnServer(Exception):
'''
In order to replay data files into the Splunk Server
for testing, they must be replayed into an index that
exists. If that index does not exist, this error will
be generated and raised before we try to do anything else
with that Data File.
'''
pass

@dataclasses.dataclass(frozen=False)
class DetectionTestingManagerOutputDto():
inputQueue: list[Detection] = Field(default_factory=list)
outputQueue: list[Detection] = Field(default_factory=list)
currentTestingQueue: dict[str, Union[Detection, None]] = Field(default_factory=dict)
start_time: Union[datetime.datetime, None] = None
replay_index: str = "CONTENTCTL_TESTING_INDEX"
replay_index: str = "contentctl_testing_index"
replay_host: str = "CONTENTCTL_HOST"
timeout_seconds: int = 60
terminate: bool = False
Expand All @@ -88,12 +97,13 @@ class DetectionTestingInfrastructure(BaseModel, abc.ABC):
sync_obj: DetectionTestingManagerOutputDto
hec_token: str = ""
hec_channel: str = ""
all_indexes_on_server: list[str] = []
_conn: client.Service = PrivateAttr()
pbar: tqdm.tqdm = None
start_time: Optional[float] = None

class Config:
arbitrary_types_allowed = True
model_config = ConfigDict(
arbitrary_types_allowed=True
)

def __init__(self, **data):
super().__init__(**data)
Expand Down Expand Up @@ -131,6 +141,7 @@ def setup(self):
(self.get_conn, "Waiting for App Installation"),
(self.configure_conf_file_datamodels, "Configuring Datamodels"),
(self.create_replay_index, f"Create index '{self.sync_obj.replay_index}'"),
(self.get_all_indexes, "Getting all indexes from server"),
(self.configure_imported_roles, "Configuring Roles"),
(self.configure_delete_indexes, "Configuring Indexes"),
(self.configure_hec, "Configuring HEC"),
Expand Down Expand Up @@ -169,12 +180,11 @@ def configure_hec(self):
pass

try:

res = self.get_conn().inputs.create(
name="DETECTION_TESTING_HEC",
kind="http",
index=self.sync_obj.replay_index,
indexes=f"{self.sync_obj.replay_index},_internal,_audit",
indexes=",".join(self.all_indexes_on_server), # This allows the HEC to write to all indexes
useACK=True,
)
self.hec_token = str(res.token)
Expand All @@ -183,6 +193,23 @@ def configure_hec(self):
except Exception as e:
raise (Exception(f"Failure creating HEC Endpoint: {str(e)}"))

def get_all_indexes(self) -> None:
"""
Retrieve a list of all indexes in the Splunk instance
"""
try:
# We do not include the replay index because by
# the time we get to this function, it has already
# been created on the server.
indexes = []
res = self.get_conn().indexes
for index in res.list():
indexes.append(index.name)
# Retrieve all available indexes on the splunk instance
self.all_indexes_on_server = indexes
except Exception as e:
raise (Exception(f"Failure getting indexes: {str(e)}"))

def get_conn(self) -> client.Service:
try:
if not self._conn:
Expand Down Expand Up @@ -265,11 +292,7 @@ def configure_imported_roles(
self,
imported_roles: list[str] = ["user", "power", "can_delete"],
enterprise_security_roles: list[str] = ["ess_admin", "ess_analyst", "ess_user"],
indexes: list[str] = ["_*", "*"],
):
indexes.append(self.sync_obj.replay_index)
indexes_encoded = ";".join(indexes)

):
try:
# Set which roles should be configured. For Enterprise Security/Integration Testing,
# we must add some extra foles.
Expand All @@ -281,7 +304,7 @@ def configure_imported_roles(
self.get_conn().roles.post(
self.infrastructure.splunk_app_username,
imported_roles=roles,
srchIndexesAllowed=indexes_encoded,
srchIndexesAllowed=";".join(self.all_indexes_on_server),
srchIndexesDefault=self.sync_obj.replay_index,
)
return
Expand All @@ -293,19 +316,17 @@ def configure_imported_roles(
self.get_conn().roles.post(
self.infrastructure.splunk_app_username,
imported_roles=imported_roles,
srchIndexesAllowed=indexes_encoded,
srchIndexesAllowed=";".join(self.all_indexes_on_server),
srchIndexesDefault=self.sync_obj.replay_index,
)

def configure_delete_indexes(self, indexes: list[str] = ["_*", "*"]):
indexes.append(self.sync_obj.replay_index)
def configure_delete_indexes(self):
endpoint = "/services/properties/authorize/default/deleteIndexesAllowed"
indexes_encoded = ";".join(indexes)
try:
self.get_conn().post(endpoint, value=indexes_encoded)
self.get_conn().post(endpoint, value=";".join(self.all_indexes_on_server))
except Exception as e:
self.pbar.write(
f"Error configuring deleteIndexesAllowed with '{indexes_encoded}': [{str(e)}]"
f"Error configuring deleteIndexesAllowed with '{self.all_indexes_on_server}': [{str(e)}]"
)

def wait_for_conf_file(self, app_name: str, conf_file_name: str):
Expand Down Expand Up @@ -654,8 +675,6 @@ def execute_unit_test(
# Set the mode and timeframe, if required
kwargs = {"exec_mode": "blocking"}



# Set earliest_time and latest_time appropriately if FORCE_ALL_TIME is False
if not FORCE_ALL_TIME:
if test.earliest_time is not None:
Expand Down Expand Up @@ -1035,8 +1054,8 @@ def retry_search_until_timeout(
# Get the start time and compute the timeout
search_start_time = time.time()
search_stop_time = time.time() + self.sync_obj.timeout_seconds
# Make a copy of the search string since we may

# Make a copy of the search string since we may
# need to make some small changes to it below
search = detection.search

Expand Down Expand Up @@ -1088,8 +1107,6 @@ def retry_search_until_timeout(
# Initialize the collection of fields that are empty that shouldn't be
present_threat_objects: set[str] = set()
empty_fields: set[str] = set()



# Filter out any messages in the results
for result in results:
Expand Down Expand Up @@ -1119,7 +1136,7 @@ def retry_search_until_timeout(
# not populated and we should throw an error. This can happen if there is a typo
# on a field. In this case, the field will appear but will not contain any values
current_empty_fields: set[str] = set()

for field in observable_fields_set:
if result.get(field, 'null') == 'null':
if field in risk_object_fields_set:
Expand All @@ -1139,9 +1156,7 @@ def retry_search_until_timeout(
if field in threat_object_fields_set:
present_threat_objects.add(field)
continue



# If everything succeeded up until now, and no empty fields are found in the
# current result, then the search was a success
if len(current_empty_fields) == 0:
Expand All @@ -1155,8 +1170,7 @@ def retry_search_until_timeout(

else:
empty_fields = empty_fields.union(current_empty_fields)



missing_threat_objects = threat_object_fields_set - present_threat_objects
# Report a failure if there were empty fields in a threat object in all results
if len(missing_threat_objects) > 0:
Expand All @@ -1172,7 +1186,6 @@ def retry_search_until_timeout(
duration=time.time() - search_start_time,
)
return


test.result.set_job_content(
job.content,
Expand Down Expand Up @@ -1233,9 +1246,19 @@ def replay_attack_data_file(
test_group: TestGroup,
test_group_start_time: float,
):
tempfile = mktemp(dir=tmp_dir)

# Before attempting to replay the file, ensure that the index we want
# to replay into actuall exists. If not, we should throw a detailed
# exception that can easily be interpreted by the user.
if attack_data_file.custom_index is not None and \
attack_data_file.custom_index not in self.all_indexes_on_server:
raise ReplayIndexDoesNotExistOnServer(
f"Unable to replay data file {attack_data_file.data} "
f"into index '{attack_data_file.custom_index}'. "
"The index does not exist on the Splunk Server. "
f"The only valid indexes on the server are {self.all_indexes_on_server}"
)

tempfile = mktemp(dir=tmp_dir)
if not (str(attack_data_file.data).startswith("http://") or
str(attack_data_file.data).startswith("https://")) :
if pathlib.Path(str(attack_data_file.data)).is_file():
Expand Down Expand Up @@ -1280,7 +1303,6 @@ def replay_attack_data_file(
)
)


# Upload the data
self.format_pbar_string(
TestReportingType.GROUP,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
from bottle import template, Bottle, ServerAdapter
from contentctl.actions.detection_testing.views.DetectionTestingView import (
DetectionTestingView,
)
from threading import Thread

from bottle import template, Bottle, ServerAdapter
from wsgiref.simple_server import make_server, WSGIRequestHandler
import jinja2
import webbrowser
from threading import Thread
from pydantic import ConfigDict

from contentctl.actions.detection_testing.views.DetectionTestingView import (
DetectionTestingView,
)

DEFAULT_WEB_UI_PORT = 7999

Expand Down Expand Up @@ -100,9 +102,9 @@ def log_exception(*args, **kwargs):
class DetectionTestingViewWeb(DetectionTestingView):
bottleApp: Bottle = Bottle()
server: SimpleWebServer = SimpleWebServer(host="0.0.0.0", port=DEFAULT_WEB_UI_PORT)

class Config:
arbitrary_types_allowed = True
model_config = ConfigDict(
arbitrary_types_allowed=True
)

def setup(self):
self.bottleApp.route("/", callback=self.showStatus)
Expand Down
10 changes: 6 additions & 4 deletions contentctl/actions/inspect.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,9 +297,11 @@ def check_detection_metadata(self, config: inspect) -> None:
validation_errors[rule_name] = []
# No detections should be removed from build to build
if rule_name not in current_build_conf.detection_stanzas:
validation_errors[rule_name].append(DetectionMissingError(rule_name=rule_name))
if config.suppress_missing_content_exceptions:
print(f"[SUPPRESSED] {DetectionMissingError(rule_name=rule_name).long_message}")
else:
validation_errors[rule_name].append(DetectionMissingError(rule_name=rule_name))
continue

# Pull out the individual stanza for readability
previous_stanza = previous_build_conf.detection_stanzas[rule_name]
current_stanza = current_build_conf.detection_stanzas[rule_name]
Expand Down Expand Up @@ -335,7 +337,7 @@ def check_detection_metadata(self, config: inspect) -> None:
)

# Convert our dict mapping to a flat list of errors for use in reporting
validation_error_list = [x for inner_list in validation_errors.values() for x in inner_list]
validation_error_list = [x for inner_list in validation_errors.values() for x in inner_list]

# Report failure/success
print("\nDetection Metadata Validation:")
Expand All @@ -355,4 +357,4 @@ def check_detection_metadata(self, config: inspect) -> None:
raise ExceptionGroup(
"Validation errors when comparing detection stanzas in current and previous build:",
validation_error_list
)
)
Loading

0 comments on commit 23f3742

Please sign in to comment.