From 71dddf0fe6c5841b56cd5f675bb89204203b00c8 Mon Sep 17 00:00:00 2001 From: Aman Date: Thu, 23 Jan 2025 11:43:45 -0500 Subject: [PATCH 01/16] Added code for new service --- moto/backend_index.py | 1 + moto/securityhub/__init__.py | 1 + moto/securityhub/exceptions.py | 21 ++ moto/securityhub/models.py | 238 +++++++++++++++++++++ moto/securityhub/responses.py | 64 ++++++ moto/securityhub/urls.py | 12 ++ tests/test_securityhub/__init__.py | 0 tests/test_securityhub/test_securityhub.py | 77 +++++++ tests/test_securityhub/test_server.py | 13 ++ 9 files changed, 427 insertions(+) create mode 100644 moto/securityhub/__init__.py create mode 100644 moto/securityhub/exceptions.py create mode 100644 moto/securityhub/models.py create mode 100644 moto/securityhub/responses.py create mode 100644 moto/securityhub/urls.py create mode 100644 tests/test_securityhub/__init__.py create mode 100644 tests/test_securityhub/test_securityhub.py create mode 100644 tests/test_securityhub/test_server.py diff --git a/moto/backend_index.py b/moto/backend_index.py index 609fe0dae2a4..ab9740fb1d72 100644 --- a/moto/backend_index.py +++ b/moto/backend_index.py @@ -183,6 +183,7 @@ ("scheduler", re.compile("https?://scheduler\\.(.+)\\.amazonaws\\.com")), ("sdb", re.compile("https?://sdb\\.(.+)\\.amazonaws\\.com")), ("secretsmanager", re.compile("https?://secretsmanager\\.(.+)\\.amazonaws\\.com")), + ("securityhub", re.compile("https?://securityhub\\.(.+)\\.amazonaws\\.com")), ( "servicediscovery", re.compile("https?://(data-)?servicediscovery\\.(.+)\\.amazonaws\\.com"), diff --git a/moto/securityhub/__init__.py b/moto/securityhub/__init__.py new file mode 100644 index 000000000000..68a7cc517c9e --- /dev/null +++ b/moto/securityhub/__init__.py @@ -0,0 +1 @@ +from .models import securityhub_backends # noqa: F401 diff --git a/moto/securityhub/exceptions.py b/moto/securityhub/exceptions.py new file mode 100644 index 000000000000..1ff50c3d8573 --- /dev/null +++ b/moto/securityhub/exceptions.py @@ -0,0 +1,21 @@ +"""Exceptions raised by the securityhub service.""" + +from moto.core.exceptions import JsonRESTError + + +class SecurityHubClientError(JsonRESTError): + code = 400 + + +class _InvalidOperationException(SecurityHubClientError): + def __init__(self, error_type: str, op: str, msg: str): + super().__init__( + error_type, + "An error occurred (%s) when calling the %s operation: %s" + % (error_type, op, msg), + ) + + +class InvalidInputException(_InvalidOperationException): + def __init__(self, op: str, msg: str): + super().__init__("InvalidInputException", op, msg) diff --git a/moto/securityhub/models.py b/moto/securityhub/models.py new file mode 100644 index 000000000000..7628ff67db27 --- /dev/null +++ b/moto/securityhub/models.py @@ -0,0 +1,238 @@ +"""SecurityHubBackend class with methods for supported APIs.""" + +from typing import Any, Dict, List, Optional + +from moto.core.base_backend import BackendDict, BaseBackend +from moto.core.common_models import BaseModel +from moto.securityhub.exceptions import InvalidInputException + + +class Finding(BaseModel): + def __init__(self, finding_id: str, finding_data: Dict[str, Any]): + self.id = finding_id + self.data = finding_data + + # Ensure required fields exist with default values + self.data.setdefault("Id", finding_id) + self.data.setdefault("AwsAccountId", "") + self.data.setdefault("CreatedAt", "") + self.data.setdefault("Description", "") + self.data.setdefault("GeneratorId", "") + self.data.setdefault("ProductArn", "") + self.data.setdefault("SchemaVersion", "") + self.data.setdefault("Title", "") + self.data.setdefault("Types", []) + + # Required but with nested structure + self.data.setdefault("Severity", {"Label": ""}) + self.data.setdefault("Resources", []) + + # Optional fields with default values + self.data.setdefault("UpdatedAt", "") + self.data.setdefault("FirstObservedAt", "") + self.data.setdefault("LastObservedAt", "") + self.data.setdefault("Confidence", 0) + self.data.setdefault("Criticality", 0) + self.data.setdefault("RecordState", "ACTIVE") + self.data.setdefault("WorkflowState", "NEW") + self.data.setdefault("VerificationState", "UNKNOWN") + + def _get_sortable_value(self, field: str) -> Any: + """Get a value from the finding data using dot notation""" + if "." in field: + parent, child = field.split(".") + return self.data.get(parent, {}).get(child) + elif "/" in field: + parent, child = field.split("/") + return self.data.get(parent, {}).get(child) + return self.data.get(field) + + def as_dict(self) -> Dict[str, Any]: + return self.data + + +class SecurityHubBackend(BaseBackend): + """Implementation of SecurityHub APIs.""" + + def __init__(self, region_name: str, account_id: str): + super().__init__(region_name, account_id) + self.findings: List[Finding] = [] + + def get_findings( + self, + filters: Optional[Dict[str, Any]] = None, + sort_criteria: Optional[List[Dict[str, str]]] = None, + next_token: Optional[str] = None, + max_results: Optional[int] = None, + ) -> Dict[str, Any]: + """Gets findings from SecurityHub based on provided filters and sorting criteria""" + findings = self.findings + + # Validate max_results if provided + if max_results is not None: + try: + max_results = int(max_results) + if max_results < 1: + raise InvalidInputException( + "MaxResults must be a number greater than 0" + ) + except ValueError: + raise InvalidInputException( + "MaxResults must be a number greater than 0" + ) + + # Validate sort criteria if provided + if sort_criteria: + allowed_orders = ["asc", "desc"] + + for criterion in sort_criteria: + if "Field" not in criterion or "SortOrder" not in criterion: + raise InvalidInputException( + "SortCriteria must contain Field and SortOrder" + ) + if criterion["SortOrder"].lower() not in allowed_orders: + raise InvalidInputException("SortOrder must be either asc or desc") + + # Apply filters if provided + if filters: + findings = self._apply_filters(findings, filters) + + # Apply sorting if provided + if sort_criteria: + findings = self._sort_findings(findings, sort_criteria) + + # Handle pagination + if next_token: + start_idx = int(next_token) + else: + start_idx = 0 + + end_idx = len(findings) + if max_results: + end_idx = min(start_idx + max_results, len(findings)) + + paginated_findings = findings[start_idx:end_idx] + + # Generate next token if there are more results + next_token = str(end_idx) if end_idx < len(findings) else None + + return { + "Findings": [f.as_dict() for f in paginated_findings], + "NextToken": next_token, + } + + def _apply_filters( + self, findings: List[Finding], filters: Dict[str, Any] + ) -> List[Finding]: + """Apply the provided filters to the findings""" + filtered_findings = findings + # Implementation would go through each filter type and apply the corresponding filtering logic + return filtered_findings + + def _sort_findings( + self, findings: List[Finding], sort_criteria: List[Dict[str, str]] + ) -> List[Finding]: + """Sort findings based on the provided sort criteria""" + for criterion in reversed(sort_criteria): + field = criterion["Field"] + reverse = criterion["SortOrder"].lower() == "desc" + + findings.sort(key=lambda x: self._get_sort_key(x, field), reverse=reverse) + + return findings + + def _get_sort_key(self, finding: Finding, field: str) -> Any: + """Get the sort key for a finding based on the field""" + value = finding._get_sortable_value(field) + + # Handle different data types for sorting + if field in [ + "Confidence", + "Criticality", + "Severity.Normalized", + "Severity.Product", + ]: + return float(value) if value is not None else 0 + elif field in ["FirstObservedAt", "LastObservedAt", "CreatedAt", "UpdatedAt"]: + return value or "" # Sort empty dates last + return str(value) if value is not None else "" + + def batch_import_findings( + self, findings: List[Dict[str, Any]] + ) -> tuple[int, int, List[Dict[str, Any]]]: + """ + Import findings in batch to SecurityHub. + + Args: + findings: List of finding dictionaries to import + + Returns: + Tuple of (failed_count, success_count, failed_findings) + """ + failed_count = 0 + success_count = 0 + failed_findings = [] + + for finding_data in findings: + try: + # Validate required fields + required_fields = [ + "AwsAccountId", + "CreatedAt", + "Description", + "GeneratorId", + "Id", + "ProductArn", + "Resources", + "SchemaVersion", + "Severity", + "Title", + "Types", + ] + + missing_fields = [ + field for field in required_fields if field not in finding_data + ] + if missing_fields: + raise InvalidInputException( + f"Finding must contain the following required fields: {', '.join(missing_fields)}" + ) + + if ( + not isinstance(finding_data["Resources"], list) + or len(finding_data["Resources"]) == 0 + ): + raise InvalidInputException( + "Finding must contain at least one resource in the Resources array" + ) + + finding_id = finding_data["Id"] + + existing_finding = next( + (f for f in self.findings if f.id == finding_id), None + ) + + if existing_finding: + # Update existing finding + existing_finding.data.update(finding_data) + else: + # Create new finding + new_finding = Finding(finding_id, finding_data) + self.findings.append(new_finding) + + success_count += 1 + + except Exception as e: + failed_count += 1 + failed_findings.append( + { + "Id": finding_data.get("Id", ""), + "ErrorCode": "InvalidInput", + "ErrorMessage": str(e), + } + ) + + return failed_count, success_count, failed_findings + + +securityhub_backends = BackendDict(SecurityHubBackend, "securityhub") diff --git a/moto/securityhub/responses.py b/moto/securityhub/responses.py new file mode 100644 index 000000000000..36839ffbee7f --- /dev/null +++ b/moto/securityhub/responses.py @@ -0,0 +1,64 @@ +"""Handles incoming securityhub requests, invokes methods, returns responses.""" + +import json + +from moto.core.responses import BaseResponse + +from .models import securityhub_backends + + +class SecurityHubResponse(BaseResponse): + def __init__(self): + super().__init__(service_name="securityhub") + + @property + def securityhub_backend(self): + return securityhub_backends[self.current_account][self.region] + + def get_findings(self): + params = self._get_params() + filters = params.get("Filters") + sort_criteria = params.get("SortCriteria") + next_token = params.get("NextToken") + max_results = params.get("MaxResults") + + result = self.securityhub_backend.get_findings( + filters=filters, + sort_criteria=sort_criteria, + next_token=next_token, + max_results=max_results, + ) + + if "NextToken" not in result: + result["NextToken"] = None + + return json.dumps(result) + + def batch_import_findings(self): + raw_body = self.body + if isinstance(raw_body, bytes): + raw_body = raw_body.decode("utf-8") + body = json.loads(raw_body) + + findings = body.get("Findings", []) + + failed_count, success_count, failed_findings = ( + self.securityhub_backend.batch_import_findings( + findings=findings, + ) + ) + + return json.dumps( + { + "FailedCount": failed_count, + "FailedFindings": [ + { + "ErrorCode": finding.get("ErrorCode"), + "ErrorMessage": finding.get("ErrorMessage"), + "Id": finding.get("Id"), + } + for finding in failed_findings + ], + "SuccessCount": success_count, + } + ) diff --git a/moto/securityhub/urls.py b/moto/securityhub/urls.py new file mode 100644 index 000000000000..162e66d8ad4e --- /dev/null +++ b/moto/securityhub/urls.py @@ -0,0 +1,12 @@ +"""securityhub base URL and path.""" + +from .responses import SecurityHubResponse + +url_bases = [ + r"https?://securityhub\.(.+)\.amazonaws\.com", +] + +url_paths = { + "{0}/findings$": SecurityHubResponse.dispatch, + "{0}/findings/import$": SecurityHubResponse.dispatch, +} diff --git a/tests/test_securityhub/__init__.py b/tests/test_securityhub/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/test_securityhub/test_securityhub.py b/tests/test_securityhub/test_securityhub.py new file mode 100644 index 000000000000..b4418e999cde --- /dev/null +++ b/tests/test_securityhub/test_securityhub.py @@ -0,0 +1,77 @@ +"""Unit tests for securityhub-supported APIs.""" + +import boto3 + +from moto import mock_aws +from moto.core import DEFAULT_ACCOUNT_ID + + +@mock_aws +def test_get_findings(): + client = boto3.client("securityhub", region_name="us-east-1") + + test_finding = { + "AwsAccountId": DEFAULT_ACCOUNT_ID, + "CreatedAt": "2024-01-01T00:00:00.000Z", + "UpdatedAt": "2024-01-01T00:00:00.000Z", + "Description": "Test finding description", + "GeneratorId": "test-generator", + "Id": "test-finding-001", + "ProductArn": f"arn:aws:securityhub:{client.meta.region_name}:{DEFAULT_ACCOUNT_ID}:product/{DEFAULT_ACCOUNT_ID}/default", + "Resources": [{"Id": "test-resource", "Type": "AwsEc2Instance"}], + "SchemaVersion": "2018-10-08", + "Severity": {"Label": "HIGH"}, + "Title": "Test Finding", + "Types": ["Software and Configuration Checks"], + } + + # Import the finding + import_response = client.batch_import_findings(Findings=[test_finding]) + assert import_response["SuccessCount"] == 1 + + # Get the findings + response = client.get_findings() + + assert "Findings" in response + assert isinstance(response["Findings"], list) + assert len(response["Findings"]) == 1 + finding = response["Findings"][0] + assert finding["Id"] == "test-finding-001" + assert finding["SchemaVersion"] == "2018-10-08" + assert finding["WorkflowState"] == "NEW" + assert finding["RecordState"] == "ACTIVE" + + +@mock_aws +def test_batch_import_findings(): + client = boto3.client("securityhub", region_name="us-east-2") + + valid_finding = { + "AwsAccountId": DEFAULT_ACCOUNT_ID, + "CreatedAt": "2024-01-01T00:00:00.000Z", + "UpdatedAt": "2024-01-01T00:00:00.000Z", + "Description": "Test finding description", + "GeneratorId": "test-generator", + "Id": "test-finding-001", + "ProductArn": f"arn:aws:securityhub:{client.meta.region_name}:{DEFAULT_ACCOUNT_ID}:product/{DEFAULT_ACCOUNT_ID}/default", + "Resources": [{"Id": "test-resource", "Type": "AwsEc2Instance"}], + "SchemaVersion": "2018-10-08", + "Severity": {"Label": "HIGH"}, + "Title": "Test Finding", + "Types": ["Software and Configuration Checks"], + } + + response = client.batch_import_findings(Findings=[valid_finding]) + assert response["SuccessCount"] == 1 + assert response["FailedCount"] == 0 + assert response["FailedFindings"] == [] + + invalid_finding = valid_finding.copy() + invalid_finding["Id"] = "test-finding-002" + invalid_finding["Severity"]["Label"] = "INVALID_LABEL" + + response = client.batch_import_findings(Findings=[invalid_finding]) + + assert response["SuccessCount"] == 1 + assert response["FailedCount"] == 0 + assert len(response["FailedFindings"]) == 0 diff --git a/tests/test_securityhub/test_server.py b/tests/test_securityhub/test_server.py new file mode 100644 index 000000000000..4de28f2edf30 --- /dev/null +++ b/tests/test_securityhub/test_server.py @@ -0,0 +1,13 @@ +"""Test different server responses.""" + +import moto.server as server + + +def test_securityhub_list(): + backend = server.create_backend_app("securityhub") + test_client = backend.test_client() + + resp = test_client.get("/") + + assert resp.status_code == 200 + assert "?" in str(resp.data) From e3f23975955fa97f1cec60df790ff0967cf3f80c Mon Sep 17 00:00:00 2001 From: Aman Date: Thu, 23 Jan 2025 14:17:14 -0500 Subject: [PATCH 02/16] Added op and msg --- moto/securityhub/models.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/moto/securityhub/models.py b/moto/securityhub/models.py index 7628ff67db27..7c767e696daa 100644 --- a/moto/securityhub/models.py +++ b/moto/securityhub/models.py @@ -74,11 +74,12 @@ def get_findings( max_results = int(max_results) if max_results < 1: raise InvalidInputException( - "MaxResults must be a number greater than 0" + op="GetFindings", + msg="MaxResults must be a number greater than 0", ) except ValueError: raise InvalidInputException( - "MaxResults must be a number greater than 0" + op="GetFindings", msg="MaxResults must be a number greater than 0" ) # Validate sort criteria if provided @@ -88,10 +89,13 @@ def get_findings( for criterion in sort_criteria: if "Field" not in criterion or "SortOrder" not in criterion: raise InvalidInputException( - "SortCriteria must contain Field and SortOrder" + op="GetFindings", + msg="SortCriteria must contain Field and SortOrder", ) if criterion["SortOrder"].lower() not in allowed_orders: - raise InvalidInputException("SortOrder must be either asc or desc") + raise InvalidInputException( + op="GetFindings", msg="SortOrder must be either asc or desc" + ) # Apply filters if provided if filters: @@ -195,7 +199,8 @@ def batch_import_findings( ] if missing_fields: raise InvalidInputException( - f"Finding must contain the following required fields: {', '.join(missing_fields)}" + op="BatchImportFindings", + msg=f"Finding must contain the following required fields: {', '.join(missing_fields)}", ) if ( @@ -203,7 +208,8 @@ def batch_import_findings( or len(finding_data["Resources"]) == 0 ): raise InvalidInputException( - "Finding must contain at least one resource in the Resources array" + op="BatchImportFindings", + msg="Finding must contain at least one resource in the Resources array", ) finding_id = finding_data["Id"] From b8f02a27156f35211e5fc5953acac2fe97f79324 Mon Sep 17 00:00:00 2001 From: Aman Date: Thu, 23 Jan 2025 14:29:34 -0500 Subject: [PATCH 03/16] Added Return Type --- moto/securityhub/responses.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/moto/securityhub/responses.py b/moto/securityhub/responses.py index 36839ffbee7f..4b5eb391cea6 100644 --- a/moto/securityhub/responses.py +++ b/moto/securityhub/responses.py @@ -1,6 +1,7 @@ """Handles incoming securityhub requests, invokes methods, returns responses.""" import json +from typing import Any from moto.core.responses import BaseResponse @@ -8,14 +9,14 @@ class SecurityHubResponse(BaseResponse): - def __init__(self): + def __init__(self) -> None: super().__init__(service_name="securityhub") @property - def securityhub_backend(self): + def securityhub_backend(self) -> Any: return securityhub_backends[self.current_account][self.region] - def get_findings(self): + def get_findings(self) -> str: params = self._get_params() filters = params.get("Filters") sort_criteria = params.get("SortCriteria") @@ -34,7 +35,7 @@ def get_findings(self): return json.dumps(result) - def batch_import_findings(self): + def batch_import_findings(self) -> str: raw_body = self.body if isinstance(raw_body, bytes): raw_body = raw_body.decode("utf-8") From d1563c1569f104185956352d0e149ef420cf2fda Mon Sep 17 00:00:00 2001 From: Aman Date: Thu, 23 Jan 2025 15:12:39 -0500 Subject: [PATCH 04/16] Changed type to Backend --- moto/securityhub/responses.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/moto/securityhub/responses.py b/moto/securityhub/responses.py index 4b5eb391cea6..28f2563f8b5e 100644 --- a/moto/securityhub/responses.py +++ b/moto/securityhub/responses.py @@ -5,7 +5,7 @@ from moto.core.responses import BaseResponse -from .models import securityhub_backends +from .models import securityhub_backends, SecurityHubBackend class SecurityHubResponse(BaseResponse): @@ -13,7 +13,7 @@ def __init__(self) -> None: super().__init__(service_name="securityhub") @property - def securityhub_backend(self) -> Any: + def securityhub_backend(self) -> SecurityHubBackend: return securityhub_backends[self.current_account][self.region] def get_findings(self) -> str: From 08462ca89b4e7f5789c15906ccee73dfa8ace251 Mon Sep 17 00:00:00 2001 From: Aman Date: Thu, 23 Jan 2025 15:22:03 -0500 Subject: [PATCH 05/16] Removed Any --- moto/securityhub/responses.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/moto/securityhub/responses.py b/moto/securityhub/responses.py index 28f2563f8b5e..c84bf1328d93 100644 --- a/moto/securityhub/responses.py +++ b/moto/securityhub/responses.py @@ -1,11 +1,10 @@ """Handles incoming securityhub requests, invokes methods, returns responses.""" import json -from typing import Any from moto.core.responses import BaseResponse -from .models import securityhub_backends, SecurityHubBackend +from .models import SecurityHubBackend, securityhub_backends class SecurityHubResponse(BaseResponse): From 8d757f9b8a6189bdfde02621ef73107fc56c8044 Mon Sep 17 00:00:00 2001 From: Aman Date: Thu, 23 Jan 2025 15:38:18 -0500 Subject: [PATCH 06/16] Changed tuple --- moto/securityhub/models.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/moto/securityhub/models.py b/moto/securityhub/models.py index 7c767e696daa..35c34fe8a14e 100644 --- a/moto/securityhub/models.py +++ b/moto/securityhub/models.py @@ -1,6 +1,6 @@ """SecurityHubBackend class with methods for supported APIs.""" -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Tuple from moto.core.base_backend import BackendDict, BaseBackend from moto.core.common_models import BaseModel @@ -163,7 +163,7 @@ def _get_sort_key(self, finding: Finding, field: str) -> Any: def batch_import_findings( self, findings: List[Dict[str, Any]] - ) -> tuple[int, int, List[Dict[str, Any]]]: + ) -> Tuple[int, int, List[Dict[str, Any]]]: """ Import findings in batch to SecurityHub. From 7610731c698ba3b8da2d83ef5624739f92dff6d1 Mon Sep 17 00:00:00 2001 From: Aman Date: Fri, 24 Jan 2025 10:49:22 -0500 Subject: [PATCH 07/16] Removed unrequired file --- tests/test_securityhub/test_server.py | 13 ------------- 1 file changed, 13 deletions(-) delete mode 100644 tests/test_securityhub/test_server.py diff --git a/tests/test_securityhub/test_server.py b/tests/test_securityhub/test_server.py deleted file mode 100644 index 4de28f2edf30..000000000000 --- a/tests/test_securityhub/test_server.py +++ /dev/null @@ -1,13 +0,0 @@ -"""Test different server responses.""" - -import moto.server as server - - -def test_securityhub_list(): - backend = server.create_backend_app("securityhub") - test_client = backend.test_client() - - resp = test_client.get("/") - - assert resp.status_code == 200 - assert "?" in str(resp.data) From 7a93a1bf5c82313a7d1f0a25aa507a1e3a9a6fdd Mon Sep 17 00:00:00 2001 From: Aman Date: Mon, 27 Jan 2025 10:59:07 -0500 Subject: [PATCH 08/16] Took out unrequired code --- moto/securityhub/models.py | 174 +++++++-------------- moto/securityhub/responses.py | 16 +- tests/test_securityhub/test_securityhub.py | 56 ++++++- 3 files changed, 122 insertions(+), 124 deletions(-) diff --git a/moto/securityhub/models.py b/moto/securityhub/models.py index 35c34fe8a14e..bfb9eee1db0c 100644 --- a/moto/securityhub/models.py +++ b/moto/securityhub/models.py @@ -12,40 +12,39 @@ def __init__(self, finding_id: str, finding_data: Dict[str, Any]): self.id = finding_id self.data = finding_data - # Ensure required fields exist with default values - self.data.setdefault("Id", finding_id) - self.data.setdefault("AwsAccountId", "") - self.data.setdefault("CreatedAt", "") - self.data.setdefault("Description", "") - self.data.setdefault("GeneratorId", "") - self.data.setdefault("ProductArn", "") - self.data.setdefault("SchemaVersion", "") - self.data.setdefault("Title", "") - self.data.setdefault("Types", []) - - # Required but with nested structure - self.data.setdefault("Severity", {"Label": ""}) - self.data.setdefault("Resources", []) - - # Optional fields with default values - self.data.setdefault("UpdatedAt", "") - self.data.setdefault("FirstObservedAt", "") - self.data.setdefault("LastObservedAt", "") - self.data.setdefault("Confidence", 0) - self.data.setdefault("Criticality", 0) - self.data.setdefault("RecordState", "ACTIVE") - self.data.setdefault("WorkflowState", "NEW") - self.data.setdefault("VerificationState", "UNKNOWN") - - def _get_sortable_value(self, field: str) -> Any: - """Get a value from the finding data using dot notation""" - if "." in field: - parent, child = field.split(".") - return self.data.get(parent, {}).get(child) - elif "/" in field: - parent, child = field.split("/") - return self.data.get(parent, {}).get(child) - return self.data.get(field) + # # Ensure required fields exist with default values + # self.data.setdefault("Id", finding_id) + # self.data.setdefault("AwsAccountId", "") + # self.data.setdefault("CreatedAt", "") + # self.data.setdefault("Description", "") + # self.data.setdefault("GeneratorId", "") + # self.data.setdefault("ProductArn", "") + # self.data.setdefault("Title", "") + # self.data.setdefault("Types", []) + + # # Required but with nested structure + # self.data.setdefault("Severity", {"Label": ""}) + # self.data.setdefault("Resources", []) + + # # Optional fields with default values + # self.data.setdefault("UpdatedAt", "") + # self.data.setdefault("FirstObservedAt", "") + # self.data.setdefault("LastObservedAt", "") + # self.data.setdefault("Confidence", 0) + # self.data.setdefault("Criticality", 0) + # self.data.setdefault("RecordState", "ACTIVE") + # self.data.setdefault("WorkflowState", "NEW") + # self.data.setdefault("VerificationState", "UNKNOWN") + + # def _get_sortable_value(self, field: str) -> Any: + # """Get a value from the finding data using dot notation""" + # if "." in field: + # parent, child = field.split(".") + # return self.data.get(parent, {}).get(child) + # elif "/" in field: + # parent, child = field.split("/") + # return self.data.get(parent, {}).get(child) + # return self.data.get(field) def as_dict(self) -> Dict[str, Any]: return self.data @@ -82,29 +81,6 @@ def get_findings( op="GetFindings", msg="MaxResults must be a number greater than 0" ) - # Validate sort criteria if provided - if sort_criteria: - allowed_orders = ["asc", "desc"] - - for criterion in sort_criteria: - if "Field" not in criterion or "SortOrder" not in criterion: - raise InvalidInputException( - op="GetFindings", - msg="SortCriteria must contain Field and SortOrder", - ) - if criterion["SortOrder"].lower() not in allowed_orders: - raise InvalidInputException( - op="GetFindings", msg="SortOrder must be either asc or desc" - ) - - # Apply filters if provided - if filters: - findings = self._apply_filters(findings, filters) - - # Apply sorting if provided - if sort_criteria: - findings = self._sort_findings(findings, sort_criteria) - # Handle pagination if next_token: start_idx = int(next_token) @@ -125,42 +101,6 @@ def get_findings( "NextToken": next_token, } - def _apply_filters( - self, findings: List[Finding], filters: Dict[str, Any] - ) -> List[Finding]: - """Apply the provided filters to the findings""" - filtered_findings = findings - # Implementation would go through each filter type and apply the corresponding filtering logic - return filtered_findings - - def _sort_findings( - self, findings: List[Finding], sort_criteria: List[Dict[str, str]] - ) -> List[Finding]: - """Sort findings based on the provided sort criteria""" - for criterion in reversed(sort_criteria): - field = criterion["Field"] - reverse = criterion["SortOrder"].lower() == "desc" - - findings.sort(key=lambda x: self._get_sort_key(x, field), reverse=reverse) - - return findings - - def _get_sort_key(self, finding: Finding, field: str) -> Any: - """Get the sort key for a finding based on the field""" - value = finding._get_sortable_value(field) - - # Handle different data types for sorting - if field in [ - "Confidence", - "Criticality", - "Severity.Normalized", - "Severity.Product", - ]: - return float(value) if value is not None else 0 - elif field in ["FirstObservedAt", "LastObservedAt", "CreatedAt", "UpdatedAt"]: - return value or "" # Sort empty dates last - return str(value) if value is not None else "" - def batch_import_findings( self, findings: List[Dict[str, Any]] ) -> Tuple[int, int, List[Dict[str, Any]]]: @@ -179,37 +119,35 @@ def batch_import_findings( for finding_data in findings: try: - # Validate required fields - required_fields = [ - "AwsAccountId", - "CreatedAt", - "Description", - "GeneratorId", - "Id", - "ProductArn", - "Resources", - "SchemaVersion", - "Severity", - "Title", - "Types", - ] - - missing_fields = [ - field for field in required_fields if field not in finding_data - ] - if missing_fields: - raise InvalidInputException( - op="BatchImportFindings", - msg=f"Finding must contain the following required fields: {', '.join(missing_fields)}", - ) + # # Validate required fields + # required_fields = [ + # "AwsAccountId", + # "CreatedAt", + # "UpdatedAt", + # "Description", + # "GeneratorId", + # "Id", + # "ProductArn", + # "Severity", + # "Title", + # "Types", + # ] + + # missing_fields = [ + # field for field in required_fields if field not in finding_data + # ] + # if missing_fields: + # raise InvalidInputException( + # op="BatchImportFindings", + # msg=f"Finding must contain the following required fields: {', '.join(missing_fields)}", + # ) if ( not isinstance(finding_data["Resources"], list) or len(finding_data["Resources"]) == 0 ): raise InvalidInputException( - op="BatchImportFindings", - msg="Finding must contain at least one resource in the Resources array", + "Finding must contain at least one resource in the Resources array", ) finding_id = finding_data["Id"] diff --git a/moto/securityhub/responses.py b/moto/securityhub/responses.py index c84bf1328d93..05be3fb322c0 100644 --- a/moto/securityhub/responses.py +++ b/moto/securityhub/responses.py @@ -17,8 +17,19 @@ def securityhub_backend(self) -> SecurityHubBackend: def get_findings(self) -> str: params = self._get_params() + + # Don't try to parse JSON if we already have a dict with the right keys + if "SortCriteria" in params: + sort_criteria = params["SortCriteria"] + else: + # Try to parse JSON only if needed + try: + json_params = json.loads(list(params.keys())[0]) + sort_criteria = json_params.get("SortCriteria") + except (json.JSONDecodeError, IndexError): + sort_criteria = None + filters = params.get("Filters") - sort_criteria = params.get("SortCriteria") next_token = params.get("NextToken") max_results = params.get("MaxResults") @@ -29,9 +40,6 @@ def get_findings(self) -> str: max_results=max_results, ) - if "NextToken" not in result: - result["NextToken"] = None - return json.dumps(result) def batch_import_findings(self) -> str: diff --git a/tests/test_securityhub/test_securityhub.py b/tests/test_securityhub/test_securityhub.py index b4418e999cde..e03f8c61f0fd 100644 --- a/tests/test_securityhub/test_securityhub.py +++ b/tests/test_securityhub/test_securityhub.py @@ -38,8 +38,8 @@ def test_get_findings(): finding = response["Findings"][0] assert finding["Id"] == "test-finding-001" assert finding["SchemaVersion"] == "2018-10-08" - assert finding["WorkflowState"] == "NEW" - assert finding["RecordState"] == "ACTIVE" + # assert finding["WorkflowState"] == "NEW" + # assert finding["RecordState"] == "ACTIVE" @mock_aws @@ -75,3 +75,55 @@ def test_batch_import_findings(): assert response["SuccessCount"] == 1 assert response["FailedCount"] == 0 assert len(response["FailedFindings"]) == 0 + + +# @mock_aws +# def test_get_findings_invalid_parameters(): +# """Test getting findings with invalid parameters.""" +# client = boto3.client("securityhub", region_name="us-east-1") + +# # Test invalid MaxResults +# with pytest.raises(ClientError) as exc: +# client.get_findings(MaxResults=0) +# err = exc.value.response["Error"] +# assert err["Code"] == "InvalidInputException" +# assert "MaxResults must be a number greater than 0" in err["Message"] + +# @mock_aws +# def test_batch_import_findings_validation(): +# """Test batch import findings with invalid input.""" +# client = boto3.client("securityhub", region_name="us-east-1") + +# # Test missing required fields +# invalid_finding = { +# "Id": "test-finding-001", +# # Missing other required fields +# } + +# response = client.batch_import_findings(Findings=[invalid_finding]) +# assert response["FailedCount"] == 1 +# assert response["SuccessCount"] == 0 +# assert len(response["FailedFindings"]) == 1 +# assert "required fields" in response["FailedFindings"][0]["ErrorMessage"] + +# # Test empty resources array +# invalid_finding = { +# "AwsAccountId": DEFAULT_ACCOUNT_ID, +# "CreatedAt": "2024-01-01T00:00:00.000Z", +# "UpdatedAt": "2024-01-01T00:00:00.000Z", +# "Description": "Test finding", +# "GeneratorId": "test-generator", +# "Id": "test-finding-001", +# "ProductArn": f"arn:aws:securityhub:{client.meta.region_name}:{DEFAULT_ACCOUNT_ID}:product/{DEFAULT_ACCOUNT_ID}/default", +# "Resources": [], # Empty resources array +# "SchemaVersion": "2018-10-08", +# "Severity": {"Label": "HIGH"}, +# "Title": "Test Finding", +# "Types": ["Software and Configuration Checks"], +# } + +# response = client.batch_import_findings(Findings=[invalid_finding]) +# assert response["FailedCount"] == 1 +# assert response["SuccessCount"] == 0 +# assert len(response["FailedFindings"]) == 1 +# assert "must contain at least one resource" in response["FailedFindings"][0]["ErrorMessage"] From cc27af445084581305f44e0d5ce87eeedadacc72 Mon Sep 17 00:00:00 2001 From: Aman Date: Mon, 27 Jan 2025 11:00:23 -0500 Subject: [PATCH 09/16] Uncommented SortCriteria --- moto/securityhub/responses.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/moto/securityhub/responses.py b/moto/securityhub/responses.py index 05be3fb322c0..75bfa8e7be90 100644 --- a/moto/securityhub/responses.py +++ b/moto/securityhub/responses.py @@ -18,17 +18,17 @@ def securityhub_backend(self) -> SecurityHubBackend: def get_findings(self) -> str: params = self._get_params() - # Don't try to parse JSON if we already have a dict with the right keys - if "SortCriteria" in params: - sort_criteria = params["SortCriteria"] - else: - # Try to parse JSON only if needed - try: - json_params = json.loads(list(params.keys())[0]) - sort_criteria = json_params.get("SortCriteria") - except (json.JSONDecodeError, IndexError): - sort_criteria = None - + # # Don't try to parse JSON if we already have a dict with the right keys + # if "SortCriteria" in params: + # sort_criteria = params["SortCriteria"] + # else: + # # Try to parse JSON only if needed + # try: + # json_params = json.loads(list(params.keys())[0]) + # sort_criteria = json_params.get("SortCriteria") + # except (json.JSONDecodeError, IndexError): + # sort_criteria = None + sort_criteria = params.get("SortCriteria") filters = params.get("Filters") next_token = params.get("NextToken") max_results = params.get("MaxResults") From b18f515e41d5711e3120ecf34734f4b62763daa2 Mon Sep 17 00:00:00 2001 From: Aman Date: Mon, 27 Jan 2025 11:39:04 -0500 Subject: [PATCH 10/16] Added back message --- moto/securityhub/models.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/moto/securityhub/models.py b/moto/securityhub/models.py index bfb9eee1db0c..8f4531783dfd 100644 --- a/moto/securityhub/models.py +++ b/moto/securityhub/models.py @@ -147,7 +147,8 @@ def batch_import_findings( or len(finding_data["Resources"]) == 0 ): raise InvalidInputException( - "Finding must contain at least one resource in the Resources array", + op="BatchImportFindings", + msg="Finding must contain at least one resource in the Resources array", ) finding_id = finding_data["Id"] From 4b720b85069755b8afaa48d523ec988b9faa69fb Mon Sep 17 00:00:00 2001 From: Aman Date: Mon, 27 Jan 2025 14:22:46 -0500 Subject: [PATCH 11/16] Tests passing --- moto/securityhub/models.py | 28 +--------- moto/securityhub/responses.py | 23 +++----- tests/test_securityhub/test_securityhub.py | 64 +++++----------------- 3 files changed, 25 insertions(+), 90 deletions(-) diff --git a/moto/securityhub/models.py b/moto/securityhub/models.py index 8f4531783dfd..7262a9e39bfd 100644 --- a/moto/securityhub/models.py +++ b/moto/securityhub/models.py @@ -71,10 +71,11 @@ def get_findings( if max_results is not None: try: max_results = int(max_results) - if max_results < 1: + if max_results < 1 or max_results > 100: + print("max_results", max_results) raise InvalidInputException( op="GetFindings", - msg="MaxResults must be a number greater than 0", + msg="MaxResults must be a number between 1 and 100", ) except ValueError: raise InvalidInputException( @@ -119,29 +120,6 @@ def batch_import_findings( for finding_data in findings: try: - # # Validate required fields - # required_fields = [ - # "AwsAccountId", - # "CreatedAt", - # "UpdatedAt", - # "Description", - # "GeneratorId", - # "Id", - # "ProductArn", - # "Severity", - # "Title", - # "Types", - # ] - - # missing_fields = [ - # field for field in required_fields if field not in finding_data - # ] - # if missing_fields: - # raise InvalidInputException( - # op="BatchImportFindings", - # msg=f"Finding must contain the following required fields: {', '.join(missing_fields)}", - # ) - if ( not isinstance(finding_data["Resources"], list) or len(finding_data["Resources"]) == 0 diff --git a/moto/securityhub/responses.py b/moto/securityhub/responses.py index 75bfa8e7be90..a06e0d42b492 100644 --- a/moto/securityhub/responses.py +++ b/moto/securityhub/responses.py @@ -16,22 +16,15 @@ def securityhub_backend(self) -> SecurityHubBackend: return securityhub_backends[self.current_account][self.region] def get_findings(self) -> str: - params = self._get_params() + raw_params = self._get_params() - # # Don't try to parse JSON if we already have a dict with the right keys - # if "SortCriteria" in params: - # sort_criteria = params["SortCriteria"] - # else: - # # Try to parse JSON only if needed - # try: - # json_params = json.loads(list(params.keys())[0]) - # sort_criteria = json_params.get("SortCriteria") - # except (json.JSONDecodeError, IndexError): - # sort_criteria = None - sort_criteria = params.get("SortCriteria") - filters = params.get("Filters") - next_token = params.get("NextToken") - max_results = params.get("MaxResults") + # Parse the JSON string that's being used as a key + params = json.loads(next(iter(raw_params.keys()), "{}")) + + sort_criteria = params.get("SortCriteria", []) + filters = params.get("Filters", {}) + next_token = params.get("NextToken", None) + max_results = params.get("MaxResults", 100) result = self.securityhub_backend.get_findings( filters=filters, diff --git a/tests/test_securityhub/test_securityhub.py b/tests/test_securityhub/test_securityhub.py index e03f8c61f0fd..d0a8769822a1 100644 --- a/tests/test_securityhub/test_securityhub.py +++ b/tests/test_securityhub/test_securityhub.py @@ -1,6 +1,8 @@ """Unit tests for securityhub-supported APIs.""" import boto3 +import pytest +from botocore.exceptions import ClientError from moto import mock_aws from moto.core import DEFAULT_ACCOUNT_ID @@ -77,53 +79,15 @@ def test_batch_import_findings(): assert len(response["FailedFindings"]) == 0 -# @mock_aws -# def test_get_findings_invalid_parameters(): -# """Test getting findings with invalid parameters.""" -# client = boto3.client("securityhub", region_name="us-east-1") - -# # Test invalid MaxResults -# with pytest.raises(ClientError) as exc: -# client.get_findings(MaxResults=0) -# err = exc.value.response["Error"] -# assert err["Code"] == "InvalidInputException" -# assert "MaxResults must be a number greater than 0" in err["Message"] - -# @mock_aws -# def test_batch_import_findings_validation(): -# """Test batch import findings with invalid input.""" -# client = boto3.client("securityhub", region_name="us-east-1") - -# # Test missing required fields -# invalid_finding = { -# "Id": "test-finding-001", -# # Missing other required fields -# } - -# response = client.batch_import_findings(Findings=[invalid_finding]) -# assert response["FailedCount"] == 1 -# assert response["SuccessCount"] == 0 -# assert len(response["FailedFindings"]) == 1 -# assert "required fields" in response["FailedFindings"][0]["ErrorMessage"] - -# # Test empty resources array -# invalid_finding = { -# "AwsAccountId": DEFAULT_ACCOUNT_ID, -# "CreatedAt": "2024-01-01T00:00:00.000Z", -# "UpdatedAt": "2024-01-01T00:00:00.000Z", -# "Description": "Test finding", -# "GeneratorId": "test-generator", -# "Id": "test-finding-001", -# "ProductArn": f"arn:aws:securityhub:{client.meta.region_name}:{DEFAULT_ACCOUNT_ID}:product/{DEFAULT_ACCOUNT_ID}/default", -# "Resources": [], # Empty resources array -# "SchemaVersion": "2018-10-08", -# "Severity": {"Label": "HIGH"}, -# "Title": "Test Finding", -# "Types": ["Software and Configuration Checks"], -# } - -# response = client.batch_import_findings(Findings=[invalid_finding]) -# assert response["FailedCount"] == 1 -# assert response["SuccessCount"] == 0 -# assert len(response["FailedFindings"]) == 1 -# assert "must contain at least one resource" in response["FailedFindings"][0]["ErrorMessage"] +@mock_aws +def test_get_findings_invalid_parameters(): + """Test getting findings with invalid parameters.""" + client = boto3.client("securityhub", region_name="us-east-1") + + # Test invalid MaxResults (must be between 1 and 100) + with pytest.raises(ClientError) as exc: + client.get_findings(MaxResults=101) + + err = exc.value.response["Error"] + assert err["Code"] == "InvalidInputException" + assert "MaxResults must be a number between 1 and 100" in err["Message"] From 6e130845180968e30f37e7d35405dee37c00cff6 Mon Sep 17 00:00:00 2001 From: Aman Date: Mon, 27 Jan 2025 14:27:46 -0500 Subject: [PATCH 12/16] Took out print --- moto/securityhub/models.py | 1 - 1 file changed, 1 deletion(-) diff --git a/moto/securityhub/models.py b/moto/securityhub/models.py index 7262a9e39bfd..ba97aea41106 100644 --- a/moto/securityhub/models.py +++ b/moto/securityhub/models.py @@ -72,7 +72,6 @@ def get_findings( try: max_results = int(max_results) if max_results < 1 or max_results > 100: - print("max_results", max_results) raise InvalidInputException( op="GetFindings", msg="MaxResults must be a number between 1 and 100", From 0365a40d51642e9dd304d5c1cb53f0db9cfebaa5 Mon Sep 17 00:00:00 2001 From: Aman Date: Mon, 27 Jan 2025 14:54:27 -0500 Subject: [PATCH 13/16] Removed unused code --- moto/securityhub/models.py | 42 ++-------------------- moto/securityhub/responses.py | 1 - tests/test_securityhub/test_securityhub.py | 6 ---- 3 files changed, 2 insertions(+), 47 deletions(-) diff --git a/moto/securityhub/models.py b/moto/securityhub/models.py index ba97aea41106..8283ca148400 100644 --- a/moto/securityhub/models.py +++ b/moto/securityhub/models.py @@ -12,40 +12,6 @@ def __init__(self, finding_id: str, finding_data: Dict[str, Any]): self.id = finding_id self.data = finding_data - # # Ensure required fields exist with default values - # self.data.setdefault("Id", finding_id) - # self.data.setdefault("AwsAccountId", "") - # self.data.setdefault("CreatedAt", "") - # self.data.setdefault("Description", "") - # self.data.setdefault("GeneratorId", "") - # self.data.setdefault("ProductArn", "") - # self.data.setdefault("Title", "") - # self.data.setdefault("Types", []) - - # # Required but with nested structure - # self.data.setdefault("Severity", {"Label": ""}) - # self.data.setdefault("Resources", []) - - # # Optional fields with default values - # self.data.setdefault("UpdatedAt", "") - # self.data.setdefault("FirstObservedAt", "") - # self.data.setdefault("LastObservedAt", "") - # self.data.setdefault("Confidence", 0) - # self.data.setdefault("Criticality", 0) - # self.data.setdefault("RecordState", "ACTIVE") - # self.data.setdefault("WorkflowState", "NEW") - # self.data.setdefault("VerificationState", "UNKNOWN") - - # def _get_sortable_value(self, field: str) -> Any: - # """Get a value from the finding data using dot notation""" - # if "." in field: - # parent, child = field.split(".") - # return self.data.get(parent, {}).get(child) - # elif "/" in field: - # parent, child = field.split("/") - # return self.data.get(parent, {}).get(child) - # return self.data.get(field) - def as_dict(self) -> Dict[str, Any]: return self.data @@ -64,10 +30,9 @@ def get_findings( next_token: Optional[str] = None, max_results: Optional[int] = None, ) -> Dict[str, Any]: - """Gets findings from SecurityHub based on provided filters and sorting criteria""" findings = self.findings - # Validate max_results if provided + # Max Results Parameter if max_results is not None: try: max_results = int(max_results) @@ -81,7 +46,7 @@ def get_findings( op="GetFindings", msg="MaxResults must be a number greater than 0" ) - # Handle pagination + # next_token Parameter if next_token: start_idx = int(next_token) else: @@ -93,7 +58,6 @@ def get_findings( paginated_findings = findings[start_idx:end_idx] - # Generate next token if there are more results next_token = str(end_idx) if end_idx < len(findings) else None return { @@ -135,10 +99,8 @@ def batch_import_findings( ) if existing_finding: - # Update existing finding existing_finding.data.update(finding_data) else: - # Create new finding new_finding = Finding(finding_id, finding_data) self.findings.append(new_finding) diff --git a/moto/securityhub/responses.py b/moto/securityhub/responses.py index a06e0d42b492..3252333509bd 100644 --- a/moto/securityhub/responses.py +++ b/moto/securityhub/responses.py @@ -18,7 +18,6 @@ def securityhub_backend(self) -> SecurityHubBackend: def get_findings(self) -> str: raw_params = self._get_params() - # Parse the JSON string that's being used as a key params = json.loads(next(iter(raw_params.keys()), "{}")) sort_criteria = params.get("SortCriteria", []) diff --git a/tests/test_securityhub/test_securityhub.py b/tests/test_securityhub/test_securityhub.py index d0a8769822a1..c0e117c7dc4c 100644 --- a/tests/test_securityhub/test_securityhub.py +++ b/tests/test_securityhub/test_securityhub.py @@ -27,11 +27,9 @@ def test_get_findings(): "Types": ["Software and Configuration Checks"], } - # Import the finding import_response = client.batch_import_findings(Findings=[test_finding]) assert import_response["SuccessCount"] == 1 - # Get the findings response = client.get_findings() assert "Findings" in response @@ -40,8 +38,6 @@ def test_get_findings(): finding = response["Findings"][0] assert finding["Id"] == "test-finding-001" assert finding["SchemaVersion"] == "2018-10-08" - # assert finding["WorkflowState"] == "NEW" - # assert finding["RecordState"] == "ACTIVE" @mock_aws @@ -81,10 +77,8 @@ def test_batch_import_findings(): @mock_aws def test_get_findings_invalid_parameters(): - """Test getting findings with invalid parameters.""" client = boto3.client("securityhub", region_name="us-east-1") - # Test invalid MaxResults (must be between 1 and 100) with pytest.raises(ClientError) as exc: client.get_findings(MaxResults=101) From de317a47dc8fb2ff69348e675dce38efccf5af56 Mon Sep 17 00:00:00 2001 From: Aman Date: Mon, 27 Jan 2025 16:01:15 -0500 Subject: [PATCH 14/16] Dummy change --- tests/test_securityhub/test_securityhub.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_securityhub/test_securityhub.py b/tests/test_securityhub/test_securityhub.py index c0e117c7dc4c..07731742cf2f 100644 --- a/tests/test_securityhub/test_securityhub.py +++ b/tests/test_securityhub/test_securityhub.py @@ -14,7 +14,7 @@ def test_get_findings(): test_finding = { "AwsAccountId": DEFAULT_ACCOUNT_ID, - "CreatedAt": "2024-01-01T00:00:00.000Z", + "CreatedAt": "2024-01-01T00:00:00.001Z", "UpdatedAt": "2024-01-01T00:00:00.000Z", "Description": "Test finding description", "GeneratorId": "test-generator", From 847a4f9a096587c95d3600ed1975231938bd4202 Mon Sep 17 00:00:00 2001 From: Aman Date: Mon, 27 Jan 2025 16:02:59 -0500 Subject: [PATCH 15/16] Added more tests --- tests/test_securityhub/test_securityhub.py | 37 ++++++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/tests/test_securityhub/test_securityhub.py b/tests/test_securityhub/test_securityhub.py index 07731742cf2f..b4898805c529 100644 --- a/tests/test_securityhub/test_securityhub.py +++ b/tests/test_securityhub/test_securityhub.py @@ -85,3 +85,40 @@ def test_get_findings_invalid_parameters(): err = exc.value.response["Error"] assert err["Code"] == "InvalidInputException" assert "MaxResults must be a number between 1 and 100" in err["Message"] + + +@mock_aws +def test_batch_import_multiple_findings(): + client = boto3.client("securityhub", region_name="us-east-1") + + findings = [ + { + "AwsAccountId": DEFAULT_ACCOUNT_ID, + "CreatedAt": "2024-01-01T00:00:00.000Z", + "UpdatedAt": "2024-01-01T00:00:00.000Z", + "Description": f"Test finding description {i}", + "GeneratorId": "test-generator", + "Id": f"test-finding-{i:03d}", + "ProductArn": f"arn:aws:securityhub:{client.meta.region_name}:{DEFAULT_ACCOUNT_ID}:product/{DEFAULT_ACCOUNT_ID}/default", + "Resources": [{"Id": f"test-resource-{i}", "Type": "AwsEc2Instance"}], + "SchemaVersion": "2018-10-08", + "Severity": {"Label": "HIGH"}, + "Title": f"Test Finding {i}", + "Types": ["Software and Configuration Checks"], + } + for i in range(1, 4) + ] + + import_response = client.batch_import_findings(Findings=findings) + assert import_response["SuccessCount"] == 3 + assert import_response["FailedCount"] == 0 + assert import_response["FailedFindings"] == [] + + get_response = client.get_findings() + assert "Findings" in get_response + assert isinstance(get_response["Findings"], list) + assert len(get_response["Findings"]) == 3 + + imported_ids = {finding["Id"] for finding in get_response["Findings"]} + expected_ids = {f"test-finding-{i:03d}" for i in range(1, 4)} + assert imported_ids == expected_ids From 541b7a38720b08bee7559e2076d5cd5afa0b630b Mon Sep 17 00:00:00 2001 From: Aman Date: Tue, 28 Jan 2025 10:46:42 -0500 Subject: [PATCH 16/16] Added more tests and changed to Paginator --- moto/securityhub/models.py | 20 ++++++-------- tests/test_securityhub/test_securityhub.py | 32 ++++++++++++++++++++++ 2 files changed, 40 insertions(+), 12 deletions(-) diff --git a/moto/securityhub/models.py b/moto/securityhub/models.py index 8283ca148400..6c813af209cc 100644 --- a/moto/securityhub/models.py +++ b/moto/securityhub/models.py @@ -5,6 +5,7 @@ from moto.core.base_backend import BackendDict, BaseBackend from moto.core.common_models import BaseModel from moto.securityhub.exceptions import InvalidInputException +from moto.utilities.paginator import Paginator class Finding(BaseModel): @@ -46,19 +47,14 @@ def get_findings( op="GetFindings", msg="MaxResults must be a number greater than 0" ) - # next_token Parameter - if next_token: - start_idx = int(next_token) - else: - start_idx = 0 + paginator = Paginator( + max_results=max_results or 100, + unique_attribute=["id"], + starting_token=next_token, + fail_on_invalid_token=True, + ) - end_idx = len(findings) - if max_results: - end_idx = min(start_idx + max_results, len(findings)) - - paginated_findings = findings[start_idx:end_idx] - - next_token = str(end_idx) if end_idx < len(findings) else None + paginated_findings, next_token = paginator.paginate(findings) return { "Findings": [f.as_dict() for f in paginated_findings], diff --git a/tests/test_securityhub/test_securityhub.py b/tests/test_securityhub/test_securityhub.py index b4898805c529..3a95c07a4923 100644 --- a/tests/test_securityhub/test_securityhub.py +++ b/tests/test_securityhub/test_securityhub.py @@ -122,3 +122,35 @@ def test_batch_import_multiple_findings(): imported_ids = {finding["Id"] for finding in get_response["Findings"]} expected_ids = {f"test-finding-{i:03d}" for i in range(1, 4)} assert imported_ids == expected_ids + + +@mock_aws +def test_get_findings_max_results(): + client = boto3.client("securityhub", region_name="us-east-1") + + findings = [ + { + "AwsAccountId": DEFAULT_ACCOUNT_ID, + "CreatedAt": "2024-01-01T00:00:00.000Z", + "UpdatedAt": "2024-01-01T00:00:00.000Z", + "Description": f"Test finding description {i}", + "GeneratorId": "test-generator", + "Id": f"test-finding-{i:03d}", + "ProductArn": f"arn:aws:securityhub:{client.meta.region_name}:{DEFAULT_ACCOUNT_ID}:product/{DEFAULT_ACCOUNT_ID}/default", + "Resources": [{"Id": f"test-resource-{i}", "Type": "AwsEc2Instance"}], + "SchemaVersion": "2018-10-08", + "Severity": {"Label": "HIGH"}, + "Title": f"Test Finding {i}", + "Types": ["Software and Configuration Checks"], + } + for i in range(1, 4) + ] + + import_response = client.batch_import_findings(Findings=findings) + assert import_response["SuccessCount"] == 3 + + get_response = client.get_findings(MaxResults=1) + assert "Findings" in get_response + assert isinstance(get_response["Findings"], list) + assert len(get_response["Findings"]) == 1 + assert "NextToken" in get_response