Merge branch 'main' into ruff_config
ljstella authored Oct 16, 2024
2 parents a5cd630 + cfda377 commit 45a6cb0
Showing 16 changed files with 364 additions and 210 deletions.
20 changes: 10 additions & 10 deletions contentctl/actions/detection_testing/GitService.py
@@ -67,9 +67,9 @@ def getChanges(self, target_branch:str)->List[Detection]:

#Make a filename to content map
filepath_to_content_map = { obj.file_path:obj for (_,obj) in self.director.name_to_content_map.items()}
-updated_detections:List[Detection] = []
-updated_macros:List[Macro] = []
-updated_lookups:List[Lookup] =[]
+updated_detections:set[Detection] = set()
+updated_macros:set[Macro] = set()
+updated_lookups:set[Lookup] = set()

for diff in all_diffs:
if type(diff) == pygit2.Patch:
@@ -80,14 +80,14 @@ def getChanges(self, target_branch:str)->List[Detection]:
if decoded_path.is_relative_to(self.config.path/"detections") and decoded_path.suffix == ".yml":
detectionObject = filepath_to_content_map.get(decoded_path, None)
if isinstance(detectionObject, Detection):
-updated_detections.append(detectionObject)
+updated_detections.add(detectionObject)
else:
raise Exception(f"Error getting detection object for file {str(decoded_path)}")

elif decoded_path.is_relative_to(self.config.path/"macros") and decoded_path.suffix == ".yml":
macroObject = filepath_to_content_map.get(decoded_path, None)
if isinstance(macroObject, Macro):
-updated_macros.append(macroObject)
+updated_macros.add(macroObject)
else:
raise Exception(f"Error getting macro object for file {str(decoded_path)}")

@@ -98,7 +98,7 @@ def getChanges(self, target_branch:str)->List[Detection]:
updatedLookup = filepath_to_content_map.get(decoded_path, None)
if not isinstance(updatedLookup,Lookup):
raise Exception(f"Expected {decoded_path} to be type {type(Lookup)}, but instead if was {(type(updatedLookup))}")
-updated_lookups.append(updatedLookup)
+updated_lookups.add(updatedLookup)

elif decoded_path.suffix == ".csv":
# If the CSV was updated, we want to make sure that we
@@ -125,7 +125,7 @@ def getChanges(self, target_branch:str)->List[Detection]:
if updatedLookup is not None and updatedLookup not in updated_lookups:
# It is possible that both the CSV and YML have been modified for the same lookup,
# and we do not want to add it twice.
-updated_lookups.append(updatedLookup)
+updated_lookups.add(updatedLookup)

else:
pass
@@ -136,7 +136,7 @@ def getChanges(self, target_branch:str)->List[Detection]:

# If a detection has at least one dependency on changed content,
# then we must test it again
-changed_macros_and_lookups = updated_macros + updated_lookups
+changed_macros_and_lookups:set[SecurityContentObject] = updated_macros.union(updated_lookups)

for detection in self.director.detections:
if detection in updated_detections:
@@ -146,14 +146,14 @@ def getChanges(self, target_branch:str)->List[Detection]:

for obj in changed_macros_and_lookups:
if obj in detection.get_content_dependencies():
-updated_detections.append(detection)
+updated_detections.add(detection)
break

#Print out the names of all modified/new content
modifiedAndNewContentString = "\n - ".join(sorted([d.name for d in updated_detections]))

print(f"[{len(updated_detections)}] Pieces of modifed and new content (this may include experimental/deprecated/manual_test content):\n - {modifiedAndNewContentString}")
-return updated_detections
+return sorted(list(updated_detections))

def getSelected(self, detectionFilenames: List[FilePath]) -> List[Detection]:
filepath_to_content_map: dict[FilePath, SecurityContentObject] = {
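The hunks above switch the updated-content collections from lists to sets, so content reachable through more than one diff path (for example, a lookup whose YML and CSV both changed) is collected only once, and the final sorted(list(...)) keeps the return order deterministic. A minimal sketch of the pattern, using a hypothetical stand-in class, since the real hash/ordering behavior lives on contentctl's SecurityContentObject:

```python
# Sketch only: "Content" is a hypothetical stand-in for SecurityContentObject.
# Sets require hashable members, and sorted() requires ordering; the
# frozen/order dataclass options provide both here.
from dataclasses import dataclass


@dataclass(frozen=True, order=True)
class Content:
    name: str


updated: set[Content] = set()
updated.add(Content("Suspicious Curl Usage"))
updated.add(Content("Suspicious Curl Usage"))  # duplicate: silently dropped
updated.add(Content("Attempted Credential Dump"))

# Deterministic output order regardless of insertion order
print(sorted(updated))
```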
@@ -13,7 +13,7 @@
from shutil import copyfile
from typing import Union, Optional

-from pydantic import BaseModel, PrivateAttr, Field, dataclasses
+from pydantic import ConfigDict, BaseModel, PrivateAttr, Field, dataclasses
import requests # type: ignore
import splunklib.client as client # type: ignore
from splunklib.binding import HTTPError # type: ignore
@@ -48,9 +48,9 @@ class SetupTestGroupResults(BaseModel):
success: bool = True
duration: float = 0
start_time: float

-class Config:
-arbitrary_types_allowed = True
+model_config = ConfigDict(
+arbitrary_types_allowed=True
+)


class CleanupTestGroupResults(BaseModel):
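This file also migrates Pydantic configuration from the v1 nested `class Config` to the v2 `model_config = ConfigDict(...)` idiom, a change repeated in several files below. A minimal before/after sketch under Pydantic v2:

```python
from pydantic import BaseModel, ConfigDict


class SetupResultsV2(BaseModel):
    # Pydantic v2 style. arbitrary_types_allowed permits fields whose types
    # Pydantic cannot validate natively (e.g., splunklib's client.Service).
    model_config = ConfigDict(arbitrary_types_allowed=True)

    success: bool = True
    duration: float = 0


# The Pydantic v1 equivalent that this commit removes:
# class SetupResultsV1(BaseModel):
#     class Config:
#         arbitrary_types_allowed = True
```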
@@ -68,14 +68,23 @@ class CannotRunBaselineException(Exception):
# exception
pass

+class ReplayIndexDoesNotExistOnServer(Exception):
+'''
+In order to replay data files into the Splunk Server
+for testing, they must be replayed into an index that
+exists. If that index does not exist, this error will
+be generated and raised before we try to do anything else
+with that Data File.
+'''
+pass

@dataclasses.dataclass(frozen=False)
class DetectionTestingManagerOutputDto():
inputQueue: list[Detection] = Field(default_factory=list)
outputQueue: list[Detection] = Field(default_factory=list)
currentTestingQueue: dict[str, Union[Detection, None]] = Field(default_factory=dict)
start_time: Union[datetime.datetime, None] = None
replay_index: str = "CONTENTCTL_TESTING_INDEX"
replay_index: str = "contentctl_testing_index"
replay_host: str = "CONTENTCTL_HOST"
timeout_seconds: int = 60
terminate: bool = False
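The replay index default also changes from CONTENTCTL_TESTING_INDEX to contentctl_testing_index. Splunk index names may only contain numbers, lowercase letters, underscores, and hyphens (and may not begin with an underscore or hyphen), so an uppercase default could never name a real index. A hedged sketch of that rule; the regex below is an approximation of Splunk's documented constraint, not code from this repo:

```python
import re

# Approximation of Splunk's index-naming rule: numbers, lowercase letters,
# underscores, and hyphens, not beginning with an underscore or hyphen.
VALID_INDEX_NAME = re.compile(r"^[a-z0-9][a-z0-9_-]*$")


def is_valid_index_name(name: str) -> bool:
    return VALID_INDEX_NAME.match(name) is not None


assert is_valid_index_name("contentctl_testing_index")
assert not is_valid_index_name("CONTENTCTL_TESTING_INDEX")  # uppercase rejected
```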
@@ -88,12 +97,13 @@ class DetectionTestingInfrastructure(BaseModel, abc.ABC):
sync_obj: DetectionTestingManagerOutputDto
hec_token: str = ""
hec_channel: str = ""
+all_indexes_on_server: list[str] = []
_conn: client.Service = PrivateAttr()
pbar: tqdm.tqdm = None
start_time: Optional[float] = None

-class Config:
-arbitrary_types_allowed = True
+model_config = ConfigDict(
+arbitrary_types_allowed=True
+)

def __init__(self, **data):
super().__init__(**data)
@@ -131,6 +141,7 @@ def setup(self):
(self.get_conn, "Waiting for App Installation"),
(self.configure_conf_file_datamodels, "Configuring Datamodels"),
(self.create_replay_index, f"Create index '{self.sync_obj.replay_index}'"),
+(self.get_all_indexes, "Getting all indexes from server"),
(self.configure_imported_roles, "Configuring Roles"),
(self.configure_delete_indexes, "Configuring Indexes"),
(self.configure_hec, "Configuring HEC"),
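setup() drives an ordered list of (callable, description) steps, and the new get_all_indexes step is deliberately placed after create_replay_index so that later steps (HEC and role configuration) see the complete index list. A small sketch of that dispatch pattern with placeholder steps:

```python
from typing import Callable


# Placeholder steps; the real methods talk to a live Splunk instance.
def create_replay_index() -> None: ...
def get_all_indexes() -> None: ...
def configure_hec() -> None: ...


setup_steps: list[tuple[Callable[[], None], str]] = [
    (create_replay_index, "Create index"),
    (get_all_indexes, "Getting all indexes from server"),  # must precede HEC/roles
    (configure_hec, "Configuring HEC"),
]

for func, message in setup_steps:
    print(message)  # the real code surfaces this through a tqdm progress bar
    func()
```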
@@ -169,12 +180,11 @@ def configure_hec(self):
pass

try:
-
res = self.get_conn().inputs.create(
name="DETECTION_TESTING_HEC",
kind="http",
index=self.sync_obj.replay_index,
indexes=f"{self.sync_obj.replay_index},_internal,_audit",
indexes=",".join(self.all_indexes_on_server), # This allows the HEC to write to all indexes
useACK=True,
)
self.hec_token = str(res.token)
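Because the HEC input is created with useACK=True and an indexes list covering every index on the server, replayed events may target any existing index, but each request must carry a channel. A hedged sketch of exercising such a token over Splunk's standard HEC HTTP API (host, token, and index values are placeholders):

```python
import json
import uuid

import requests

# Placeholder connection details; a real run would use the token returned by
# inputs.create(...) above.
HEC_URL = "https://localhost:8088/services/collector/event"
HEC_TOKEN = "00000000-0000-0000-0000-000000000000"

event = {
    "index": "contentctl_testing_index",  # must be one of the allowed indexes
    "sourcetype": "_json",
    "event": {"message": "replayed test event"},
}

response = requests.post(
    HEC_URL,
    headers={
        "Authorization": f"Splunk {HEC_TOKEN}",
        # useACK=True means every request needs a channel identifier
        "X-Splunk-Request-Channel": str(uuid.uuid4()),
    },
    data=json.dumps(event),
    verify=False,  # test instances commonly run with self-signed certs
)
response.raise_for_status()
```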
@@ -183,6 +193,23 @@
except Exception as e:
raise (Exception(f"Failure creating HEC Endpoint: {str(e)}"))

+def get_all_indexes(self) -> None:
+"""
+Retrieve a list of all indexes in the Splunk instance
+"""
+try:
+# We do not include the replay index because by
+# the time we get to this function, it has already
+# been created on the server.
+indexes = []
+res = self.get_conn().indexes
+for index in res.list():
+indexes.append(index.name)
+# Retrieve all available indexes on the splunk instance
+self.all_indexes_on_server = indexes
+except Exception as e:
+raise (Exception(f"Failure getting indexes: {str(e)}"))

def get_conn(self) -> client.Service:
try:
if not self._conn:
@@ -265,11 +292,7 @@ def configure_imported_roles(
self,
imported_roles: list[str] = ["user", "power", "can_delete"],
enterprise_security_roles: list[str] = ["ess_admin", "ess_analyst", "ess_user"],
-indexes: list[str] = ["_*", "*"],
-):
-indexes.append(self.sync_obj.replay_index)
-indexes_encoded = ";".join(indexes)
-
+):
try:
# Set which roles should be configured. For Enterprise Security/Integration Testing,
we must add some extra roles.
@@ -281,7 +304,7 @@
self.get_conn().roles.post(
self.infrastructure.splunk_app_username,
imported_roles=roles,
-srchIndexesAllowed=indexes_encoded,
+srchIndexesAllowed=";".join(self.all_indexes_on_server),
srchIndexesDefault=self.sync_obj.replay_index,
)
return
@@ -293,19 +316,17 @@
self.get_conn().roles.post(
self.infrastructure.splunk_app_username,
imported_roles=imported_roles,
-srchIndexesAllowed=indexes_encoded,
+srchIndexesAllowed=";".join(self.all_indexes_on_server),
srchIndexesDefault=self.sync_obj.replay_index,
)

-def configure_delete_indexes(self, indexes: list[str] = ["_*", "*"]):
-indexes.append(self.sync_obj.replay_index)
+def configure_delete_indexes(self):
endpoint = "/services/properties/authorize/default/deleteIndexesAllowed"
-indexes_encoded = ";".join(indexes)
try:
-self.get_conn().post(endpoint, value=indexes_encoded)
+self.get_conn().post(endpoint, value=";".join(self.all_indexes_on_server))
except Exception as e:
self.pbar.write(
f"Error configuring deleteIndexesAllowed with '{indexes_encoded}': [{str(e)}]"
f"Error configuring deleteIndexesAllowed with '{self.all_indexes_on_server}': [{str(e)}]"
)

def wait_for_conf_file(self, app_name: str, conf_file_name: str):
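Both role hunks above, and the configure_delete_indexes hunk, replace hand-maintained index lists (["_*", "*"] plus the replay index) with the list discovered from the server. Splunk's srchIndexesAllowed and deleteIndexesAllowed settings take a semicolon-separated string; a quick sketch with illustrative index names:

```python
# Illustrative index names; the real list comes from get_all_indexes().
all_indexes_on_server = ["main", "_internal", "_audit", "contentctl_testing_index"]

# These authorize.conf-backed settings expect one semicolon-separated string.
srch_indexes_allowed = ";".join(all_indexes_on_server)
print(srch_indexes_allowed)  # main;_internal;_audit;contentctl_testing_index
```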
@@ -654,8 +675,6 @@ def execute_unit_test(
# Set the mode and timeframe, if required
kwargs = {"exec_mode": "blocking"}

-
-
# Set earliest_time and latest_time appropriately if FORCE_ALL_TIME is False
if not FORCE_ALL_TIME:
if test.earliest_time is not None:
@@ -1035,8 +1054,8 @@ def retry_search_until_timeout(
# Get the start time and compute the timeout
search_start_time = time.time()
search_stop_time = time.time() + self.sync_obj.timeout_seconds
-# Make a copy of the search string since we may
+
+# Make a copy of the search string since we may
# need to make some small changes to it below
search = detection.search

@@ -1088,8 +1107,6 @@ def retry_search_until_timeout(
# Initialize the collection of fields that are empty that shouldn't be
present_threat_objects: set[str] = set()
empty_fields: set[str] = set()
-
-

# Filter out any messages in the results
for result in results:
@@ -1119,7 +1136,7 @@ def retry_search_until_timeout(
# not populated and we should throw an error. This can happen if there is a typo
# on a field. In this case, the field will appear but will not contain any values
current_empty_fields: set[str] = set()

for field in observable_fields_set:
if result.get(field, 'null') == 'null':
if field in risk_object_fields_set:
@@ -1139,9 +1156,7 @@ def retry_search_until_timeout(
if field in threat_object_fields_set:
present_threat_objects.add(field)
continue
-
-

# If everything succeeded up until now, and no empty fields are found in the
# current result, then the search was a success
if len(current_empty_fields) == 0:
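The loop above collects observable fields that come back as the literal string 'null': any such field is recorded as empty for that result, while a threat-object field only counts as a failure if it is empty in every result. A self-contained sketch of that rule with made-up field names:

```python
# Hypothetical results and field sets, mirroring the validation rule above.
results = [
    {"user": "alice", "parent_process_name": "null"},
    {"user": "bob", "parent_process_name": "cmd.exe"},
]
risk_object_fields = {"user"}
threat_object_fields = {"parent_process_name"}

present_threat_objects: set[str] = set()
empty_fields: set[str] = set()

for result in results:
    current_empty_fields: set[str] = set()
    for field in risk_object_fields | threat_object_fields:
        if result.get(field, "null") == "null":
            current_empty_fields.add(field)
        elif field in threat_object_fields:
            present_threat_objects.add(field)
    empty_fields |= current_empty_fields

# A threat-object field fails the test only if no result populated it.
missing_threat_objects = threat_object_fields - present_threat_objects
assert not missing_threat_objects  # parent_process_name appeared in one result
```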
@@ -1155,8 +1170,7 @@ def retry_search_until_timeout(

else:
empty_fields = empty_fields.union(current_empty_fields)
-


missing_threat_objects = threat_object_fields_set - present_threat_objects
# Report a failure if there were empty fields in a threat object in all results
if len(missing_threat_objects) > 0:
@@ -1172,7 +1186,6 @@ def retry_search_until_timeout(
duration=time.time() - search_start_time,
)
return
-

test.result.set_job_content(
job.content,
@@ -1233,9 +1246,19 @@ def replay_attack_data_file(
test_group: TestGroup,
test_group_start_time: float,
):
-tempfile = mktemp(dir=tmp_dir)

+# Before attempting to replay the file, ensure that the index we want
+# to replay into actually exists. If not, we should throw a detailed
+# exception that can easily be interpreted by the user.
+if attack_data_file.custom_index is not None and \
+attack_data_file.custom_index not in self.all_indexes_on_server:
+raise ReplayIndexDoesNotExistOnServer(
+f"Unable to replay data file {attack_data_file.data} "
+f"into index '{attack_data_file.custom_index}'. "
+"The index does not exist on the Splunk Server. "
+f"The only valid indexes on the server are {self.all_indexes_on_server}"
+)
+
+tempfile = mktemp(dir=tmp_dir)
if not (str(attack_data_file.data).startswith("http://") or
str(attack_data_file.data).startswith("https://")) :
if pathlib.Path(str(attack_data_file.data)).is_file():
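The new guard raises ReplayIndexDoesNotExistOnServer before any download or upload work starts. A condensed, standalone sketch of the check (the exception class here is a stand-in for the one defined earlier in this commit):

```python
class ReplayIndexDoesNotExistOnServer(Exception):
    """Stand-in for the exception class added earlier in this commit."""


def check_replay_index(custom_index: str | None, all_indexes_on_server: list[str]) -> None:
    # Fail fast, before any data is downloaded or replayed.
    if custom_index is not None and custom_index not in all_indexes_on_server:
        raise ReplayIndexDoesNotExistOnServer(
            f"Unable to replay into index '{custom_index}'. "
            f"The only valid indexes on the server are {all_indexes_on_server}"
        )


check_replay_index(None, ["main"])    # ok: no custom index, default is used
check_replay_index("main", ["main"])  # ok: index exists on the server
# check_replay_index("missing_idx", ["main"])  # raises ReplayIndexDoesNotExistOnServer
```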
@@ -1280,7 +1303,6 @@ def replay_attack_data_file(
)
)
-

# Upload the data
self.format_pbar_string(
TestReportingType.GROUP,
@@ -1,12 +1,14 @@
-from bottle import template, Bottle, ServerAdapter
-from contentctl.actions.detection_testing.views.DetectionTestingView import (
-DetectionTestingView,
-)
+from threading import Thread
+
+from bottle import template, Bottle, ServerAdapter
from wsgiref.simple_server import make_server, WSGIRequestHandler
import jinja2
import webbrowser
-from threading import Thread
+from pydantic import ConfigDict
+
+from contentctl.actions.detection_testing.views.DetectionTestingView import (
+DetectionTestingView,
+)

DEFAULT_WEB_UI_PORT = 7999

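The reshuffled imports above follow isort-style grouping, which ruff can enforce via its "I" rules (consistent with this branch's name, ruff_config): standard library first, then third-party, then first-party, with blank lines between groups. The convention, annotated:

```python
# Standard library
from threading import Thread
from wsgiref.simple_server import make_server

# Third-party
import jinja2
from bottle import Bottle
from pydantic import ConfigDict

# First-party
from contentctl.actions.detection_testing.views.DetectionTestingView import (
    DetectionTestingView,
)
```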
@@ -100,9 +102,9 @@ def log_exception(*args, **kwargs):
class DetectionTestingViewWeb(DetectionTestingView):
bottleApp: Bottle = Bottle()
server: SimpleWebServer = SimpleWebServer(host="0.0.0.0", port=DEFAULT_WEB_UI_PORT)

-class Config:
-arbitrary_types_allowed = True
+model_config = ConfigDict(
+arbitrary_types_allowed=True
+)

def setup(self):
self.bottleApp.route("/", callback=self.showStatus)
13 changes: 6 additions & 7 deletions contentctl/enrichments/cve_enrichment.py
@@ -5,7 +5,7 @@
import shelve
import time
from typing import Annotated, Any, Union, TYPE_CHECKING
-from pydantic import BaseModel,Field, computed_field
+from pydantic import ConfigDict, BaseModel,Field, computed_field
from decimal import Decimal
from requests.exceptions import ReadTimeout
from contentctl.objects.annotated_types import CVE_TYPE
@@ -32,13 +32,12 @@ def url(self)->str:
class CveEnrichment(BaseModel):
use_enrichment: bool = True
cve_api_obj: Union[CVESearch,None] = None

-
-class Config:
-# Arbitrary_types are allowed to let us use the CVESearch Object
-arbitrary_types_allowed = True
-frozen = True
+# Arbitrary_types are allowed to let us use the CVESearch Object
+model_config = ConfigDict(
+arbitrary_types_allowed=True,
+frozen=True
+)

@staticmethod
def getCveEnrichment(config:validate, timeout_seconds:int=10, force_disable_enrichment:bool=True)->CveEnrichment:
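CveEnrichment's new model_config sets frozen=True in addition to arbitrary_types_allowed, so instances are immutable (and therefore hashable) once constructed. A small sketch of the Pydantic v2 behavior:

```python
from pydantic import BaseModel, ConfigDict, ValidationError


class FrozenExample(BaseModel):
    model_config = ConfigDict(frozen=True)

    use_enrichment: bool = True


obj = FrozenExample()
try:
    obj.use_enrichment = False  # assignment to a frozen model is rejected
except ValidationError as err:
    print(type(err).__name__)  # ValidationError
```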