From 8f0b232dd16a292d1504feaf473eb25c22e7a33e Mon Sep 17 00:00:00 2001 From: Eric Hibbs Date: Thu, 7 Nov 2024 14:44:30 -0800 Subject: [PATCH 1/6] about to start big Core class refactor --- pyproject.toml | 107 +++++++++++++++++- pytest.ini | 3 + socketsecurity/core/__init__.py | 185 ++------------------------------ socketsecurity/core/config.py | 32 ++++++ socketsecurity/core/utils.py | 133 +++++++++++++++++++++++ tests/__init__.py | 0 tests/unit/__init__.py | 25 +++++ tests/unit/test_config.py | 48 +++++++++ tests/unit/test_core.py | 115 ++++++++++++++++++++ tests/unit/test_utils.py | 29 +++++ 10 files changed, 499 insertions(+), 178 deletions(-) create mode 100644 pytest.ini create mode 100644 socketsecurity/core/config.py create mode 100644 socketsecurity/core/utils.py create mode 100644 tests/__init__.py create mode 100644 tests/unit/__init__.py create mode 100644 tests/unit/test_config.py create mode 100644 tests/unit/test_core.py create mode 100644 tests/unit/test_utils.py diff --git a/pyproject.toml b/pyproject.toml index 2c8ddf8..b7e80be 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,6 +32,18 @@ classifiers = [ "Programming Language :: Python :: 3.12", ] +[project.optional-dependencies] +test = [ + "pytest>=7.4.0", + "pytest-cov>=4.1.0", + "pytest-mock>=3.12.0", + "pytest-asyncio>=0.23.0", + "pytest-watch >=4.2.0" +] +dev = [ + "ruff>=0.3.0", +] + [project.scripts] socketcli = "socketsecurity.socketcli:cli" @@ -45,4 +57,97 @@ include = [ ] [tool.setuptools.dynamic] -version = {attr = "socketsecurity.__version__"} \ No newline at end of file +version = {attr = "socketsecurity.__version__"} + +[tool.pytest.ini_options] +minversion = "7.0" +addopts = "-ra -q --cov=socketsecurity --cov-report=term-missing" +testpaths = [ + "tests", +] +pythonpath = "." + +[tool.coverage.run] +source = ["socketsecurity"] +omit = ["tests/*", "**/__init__.py"] + +[tool.coverage.report] +exclude_lines = [ + "pragma: no cover", + "def __repr__", + "if __name__ == .__main__.:", + "raise NotImplementedError", + "if TYPE_CHECKING:", +] + +[tool.ruff] +# Exclude a variety of commonly ignored directories. +exclude = [ + ".bzr", + ".direnv", + ".eggs", + ".git", + ".git-rewrite", + ".hg", + ".ipynb_checkpoints", + ".mypy_cache", + ".nox", + ".pants.d", + ".pyenv", + ".pytest_cache", + ".pytype", + ".ruff_cache", + ".svn", + ".tox", + ".venv", + ".vscode", + "__pypackages__", + "_build", + "buck-out", + "build", + "dist", + "node_modules", + "site-packages", + "venv", +] + +[tool.ruff.lint] +# Enable Pyflakes (`F`) and a subset of the pycodestyle (`E`) codes by default. +# Unlike Flake8, Ruff doesn't enable pycodestyle warnings (`W`) or +# McCabe complexity (`C901`) by default. +select = ["E4", "E7", "E9", "F"] +ignore = [] + +# Allow fix for all enabled rules (when `--fix`) is provided. +fixable = ["ALL"] +unfixable = [] + +# Allow unused variables when underscore-prefixed. +dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" + +[tool.ruff.format] +# Like Black, use double quotes for strings. +quote-style = "double" + +# Like Black, indent with spaces, rather than tabs. +indent-style = "space" + +# Like Black, respect magic trailing commas. +skip-magic-trailing-comma = false + +# Like Black, automatically detect the appropriate line ending. +line-ending = "auto" + +# Enable auto-formatting of code examples in docstrings. Markdown, +# reStructuredText code/literal blocks and doctests are all supported. 
+# +# This is currently disabled by default, but it is planned for this +# to be opt-out in the future. +docstring-code-format = false + +# Set the line length limit used when formatting code snippets in +# docstrings. +# +# This only has an effect when the `docstring-code-format` setting is +# enabled. +docstring-code-line-length = "dynamic" \ No newline at end of file diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 0000000..3276d11 --- /dev/null +++ b/pytest.ini @@ -0,0 +1,3 @@ +[pytest] +; addopts = -v --no-cov --capture=no +addopts = -v --no-cov --tb=short -ra diff --git a/socketsecurity/core/__init__.py b/socketsecurity/core/__init__.py index 8dea47a..60c369b 100644 --- a/socketsecurity/core/__init__.py +++ b/socketsecurity/core/__init__.py @@ -1,6 +1,8 @@ import logging from pathlib import PurePath - +from typing import Optional +from .config import CoreConfig +from .utils import encode_key, do_request, socket_globs import requests from urllib.parse import urlencode import base64 @@ -25,6 +27,7 @@ import platform from glob import glob import time +import sys __all__ = [ "Core", @@ -49,182 +52,6 @@ log = logging.getLogger("socketdev") log.addHandler(logging.NullHandler()) -socket_globs = { - "spdx": { - "spdx.json": { - "pattern": "*[-.]spdx.json" - } - }, - "cdx": { - "cyclonedx.json": { - "pattern": "{bom,*[-.]c{yclone,}dx}.json" - }, - "xml": { - "pattern": "{bom,*[-.]c{yclone,}dx}.xml" - } - }, - "npm": { - "package.json": { - "pattern": "package.json" - }, - "package-lock.json": { - "pattern": "package-lock.json" - }, - "npm-shrinkwrap.json": { - "pattern": "npm-shrinkwrap.json" - }, - "yarn.lock": { - "pattern": "yarn.lock" - }, - "pnpm-lock.yaml": { - "pattern": "pnpm-lock.yaml" - }, - "pnpm-lock.yml": { - "pattern": "pnpm-lock.yml" - }, - "pnpm-workspace.yaml": { - "pattern": "pnpm-workspace.yaml" - }, - "pnpm-workspace.yml": { - "pattern": "pnpm-workspace.yml" - } - }, - "pypi": { - "pipfile": { - "pattern": "pipfile" - }, - "pyproject.toml": { - "pattern": "pyproject.toml" - }, - "poetry.lock": { - "pattern": "poetry.lock" - }, - "requirements.txt": { - "pattern": "*requirements.txt" - }, - "requirements": { - "pattern": "requirements/*.txt" - }, - "requirements-*.txt": { - "pattern": "requirements-*.txt" - }, - "requirements_*.txt": { - "pattern": "requirements_*.txt" - }, - "requirements.frozen": { - "pattern": "requirements.frozen" - }, - "setup.py": { - "pattern": "setup.py" - } - }, - "golang": { - "go.mod": { - "pattern": "go.mod" - }, - "go.sum": { - "pattern": "go.sum" - } - }, - "java": { - "pom.xml": { - "pattern": "pom.xml" - } - } -} - - -def encode_key(token: str) -> None: - """ - encode_key takes passed token string and does a base64 encoding. It sets this as a global variable - :param token: str of the Socket API Security Token - :return: - """ - global encoded_key - encoded_key = base64.b64encode(token.encode()).decode('ascii') - - -def do_request( - path: str, - headers: dict = None, - payload: [dict, str] = None, - files: list = None, - method: str = "GET", - base_url: str = None, -) -> requests.request: - """ - do_requests is the shared function for making HTTP calls - :param base_url: - :param path: Required path for the request - :param headers: Optional dictionary of headers. 
If not set will use a default set - :param payload: Optional dictionary or string of the payload to pass - :param files: Optional list of files to upload - :param method: Optional method to use, defaults to GET - :return: - """ - - if base_url is not None: - url = f"{base_url}/{path}" - else: - if encoded_key is None or encoded_key == "": - raise APIKeyMissing - url = f"{api_url}/{path}" - - if headers is None: - headers = { - 'Authorization': f"Basic {encoded_key}", - 'User-Agent': f'SocketPythonCLI/{__version__}', - "accept": "application/json" - } - verify = True - if allow_unverified_ssl: - verify = False - response = requests.request( - method.upper(), - url, - headers=headers, - data=payload, - files=files, - timeout=timeout, - verify=verify - ) - output_headers = headers.copy() - output_headers['Authorization'] = "API_KEY_REDACTED" - output = { - "url": url, - "headers": output_headers, - "status_code": response.status_code, - "body": response.text, - "payload": payload, - "files": files, - "timeout": timeout - } - log.debug(output) - if response.status_code <= 399: - return response - elif response.status_code == 400: - raise APIFailure(output) - elif response.status_code == 401: - raise APIAccessDenied("Unauthorized") - elif response.status_code == 403: - raise APIInsufficientQuota("Insufficient max_quota for API method") - elif response.status_code == 404: - raise APIResourceNotFound(f"Path not found {path}") - elif response.status_code == 429: - raise APIInsufficientQuota("Insufficient quota for API route") - elif response.status_code == 524: - raise APICloudflareError(response.text) - else: - msg = { - "status_code": response.status_code, - "UnexpectedError": "There was an unexpected error using the API", - "error": response.text, - "payload": payload, - "url": url - } - raise APIFailure(msg) - - class Core: token: str base_api_url: str @@ -291,7 +118,11 @@ def set_timeout(request_timeout: int): :return: """ global timeout + print(f"Setting timeout in module {__name__} at {id(sys.modules[__name__])}") + print(f"Current timeout value: {timeout}") + print(f"Setting to: {request_timeout}") timeout = request_timeout + print(f"New timeout value: {timeout}") @staticmethod def get_org_id_slug() -> (str, str): diff --git a/socketsecurity/core/config.py b/socketsecurity/core/config.py new file mode 100644 index 0000000..ba8d15f --- /dev/null +++ b/socketsecurity/core/config.py @@ -0,0 +1,32 @@ +from dataclasses import dataclass +from typing import ClassVar + +@dataclass +class CoreConfig: + """Configuration for the Socket Security Core class""" + + # Required + token: str + + # Optional with defaults + api_url: str = "https://api.socket.dev/v0" + timeout: int = 30 + enable_all_alerts: bool = False + allow_unverified_ssl: bool = False + + # Constants + SOCKET_DATE_FORMAT: ClassVar[str] = "%Y-%m-%dT%H:%M:%S.%fZ" + DEFAULT_API_URL: ClassVar[str] = "https://api.socket.dev/v0" + DEFAULT_TIMEOUT: ClassVar[int] = 30 + + def __post_init__(self) -> None: + """Validate and process config after initialization""" + # Business rule validations + if not self.token: + raise ValueError("Token is required") + if self.timeout <= 0: + raise ValueError("Timeout must be positive") + + # Business logic + if not self.token.endswith(':'): + self.token = f"{self.token}:" diff --git a/socketsecurity/core/utils.py b/socketsecurity/core/utils.py new file mode 100644 index 0000000..6fab6c5 --- /dev/null +++ b/socketsecurity/core/utils.py @@ -0,0 +1,133 @@ +import base64 +import requests +from typing import 
Optional, Dict, List, Union +from socketsecurity import __version__ +from socketsecurity.core.exceptions import APIKeyMissing + +def encode_key(token: str) -> str: + """Encode API token in base64""" + return base64.b64encode(token.encode()).decode('ascii') + +def do_request( + path: str, + headers: Optional[Dict] = None, + payload: Optional[Union[Dict, str]] = None, + files: Optional[List] = None, + method: str = "GET", + base_url: Optional[str] = None, + api_key: Optional[str] = None, + timeout: int = 30, + verify_ssl: bool = True +) -> requests.Response: + """Make HTTP requests to Socket API""" + + if base_url is not None: + url = f"{base_url}/{path}" + else: + if not api_key: + raise APIKeyMissing + url = f"https://api.socket.dev/v0/{path}" + + if headers is None: + headers = { + 'Authorization': f"Basic {api_key}", + 'User-Agent': f'SocketPythonCLI/{__version__}', + "accept": "application/json" + } + + response = requests.request( + method.upper(), + url, + headers=headers, + data=payload, + files=files, + timeout=timeout, + verify=verify_ssl + ) + + return response + +# File pattern definitions +socket_globs = { + "spdx": { + "spdx.json": { + "pattern": "*[-.]spdx.json" + } + }, + "cdx": { + "cyclonedx.json": { + "pattern": "{bom,*[-.]c{yclone,}dx}.json" + }, + "xml": { + "pattern": "{bom,*[-.]c{yclone,}dx}.xml" + } + }, + "npm": { + "package.json": { + "pattern": "package.json" + }, + "package-lock.json": { + "pattern": "package-lock.json" + }, + "npm-shrinkwrap.json": { + "pattern": "npm-shrinkwrap.json" + }, + "yarn.lock": { + "pattern": "yarn.lock" + }, + "pnpm-lock.yaml": { + "pattern": "pnpm-lock.yaml" + }, + "pnpm-lock.yml": { + "pattern": "pnpm-lock.yml" + }, + "pnpm-workspace.yaml": { + "pattern": "pnpm-workspace.yaml" + }, + "pnpm-workspace.yml": { + "pattern": "pnpm-workspace.yml" + } + }, + "pypi": { + "pipfile": { + "pattern": "pipfile" + }, + "pyproject.toml": { + "pattern": "pyproject.toml" + }, + "poetry.lock": { + "pattern": "poetry.lock" + }, + "requirements.txt": { + "pattern": "*requirements.txt" + }, + "requirements": { + "pattern": "requirements/*.txt" + }, + "requirements-*.txt": { + "pattern": "requirements-*.txt" + }, + "requirements_*.txt": { + "pattern": "requirements_*.txt" + }, + "requirements.frozen": { + "pattern": "requirements.frozen" + }, + "setup.py": { + "pattern": "setup.py" + } + }, + "golang": { + "go.mod": { + "pattern": "go.mod" + }, + "go.sum": { + "pattern": "go.sum" + } + }, + "java": { + "pom.xml": { + "pattern": "pom.xml" + } + } +} \ No newline at end of file diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/unit/__init__.py b/tests/unit/__init__.py new file mode 100644 index 0000000..7a025e5 --- /dev/null +++ b/tests/unit/__init__.py @@ -0,0 +1,25 @@ +import pytest + +# Test data +TEST_API_TOKEN = "test-token" +TEST_API_KEY_ENCODED = "dGVzdC10b2tlbjo=" +DEFAULT_API_URL = "https://api.socket.dev/v0" +DEFAULT_TIMEOUT = 30 + +@pytest.fixture(autouse=True) +def reset_globals(): + """Reset global state after each test""" + from socketsecurity.core import Core + yield + Core.set_api_url(DEFAULT_API_URL) + Core.set_timeout(DEFAULT_TIMEOUT) + +@pytest.fixture +def mock_org_response(): + return { + "organizations": { + "test-org-123": { + "slug": "test-org" + } + } + } diff --git a/tests/unit/test_config.py b/tests/unit/test_config.py new file mode 100644 index 0000000..5ac1b67 --- /dev/null +++ b/tests/unit/test_config.py @@ -0,0 +1,48 @@ +import pytest +from 
socketsecurity.core.config import CoreConfig + +def test_config_initialization() -> None: + """Test basic config initialization with defaults""" + config = CoreConfig(token="test-token") + + assert config.token == "test-token:" + assert config.api_url == CoreConfig.DEFAULT_API_URL + assert config.timeout == CoreConfig.DEFAULT_TIMEOUT + assert config.enable_all_alerts is False + assert config.allow_unverified_ssl is False + +def test_config_custom_values() -> None: + """Test config with custom values""" + config = CoreConfig( + token="test-token", + api_url="https://custom.api", + timeout=60, + enable_all_alerts=True, + allow_unverified_ssl=True + ) + + assert config.token == "test-token:" + assert config.api_url == "https://custom.api" + assert config.timeout == 60 + assert config.enable_all_alerts is True + assert config.allow_unverified_ssl is True + +def test_config_validation() -> None: + """Test business rule validation""" + # Test empty token + with pytest.raises(ValueError, match="Token is required"): + CoreConfig(token="") + + # Test invalid timeout + with pytest.raises(ValueError, match="Timeout must be positive"): + CoreConfig(token="test", timeout=0) + +def test_token_formatting() -> None: + """Test token colon suffix business rule""" + # Without colon + config = CoreConfig(token="test-token") + assert config.token == "test-token:" + + # Already has colon + config = CoreConfig(token="test-token:") + assert config.token == "test-token:" diff --git a/tests/unit/test_core.py b/tests/unit/test_core.py new file mode 100644 index 0000000..7f2751f --- /dev/null +++ b/tests/unit/test_core.py @@ -0,0 +1,115 @@ +from unittest.mock import Mock, patch +from socketsecurity.core import Core, timeout, api_url +from socketsecurity.core.classes import Package, Purl +from tests.unit import TEST_API_TOKEN, mock_org_response +import sys + +# Basic initialization and utility tests +@patch('socketsecurity.core.do_request') +def test_core_initialization(mock_do_request): + # Mock responses for both API calls + mock_responses = [ + # First call for get_org_id_slug + Mock(json=lambda: {"organizations": {"test-org-123": {"slug": "test-org"}}}), + # Second call for get_security_policy + Mock(json=lambda: { + "defaults": { + "issueRules": { + "noTests": {"action": "warn"}, + "noV1": {"action": "error"} + } + }, + "entries": [] + }) + ] + mock_do_request.side_effect = mock_responses + + core = Core(token=TEST_API_TOKEN) + assert core.token == f"{TEST_API_TOKEN}:" + assert core.base_api_url is None + assert core.request_timeout is None + + # Verify both API calls were made + assert mock_do_request.call_count == 2 + +def test_core_set_timeout(): + current_module = sys.modules[__name__] + core_module = sys.modules['socketsecurity.core'] + + print(f"Test module id: {id(current_module)}") + print(f"Core module id: {id(core_module)}") + print(f"Timeout in test: {id(timeout)}") + print(f"Timeout in core: {id(core_module.timeout)}") + + Core.set_timeout(60) + assert timeout == 60 + +def test_core_set_api_url(): + test_url = "https://test.api.com" + Core.set_api_url(test_url) + assert api_url == test_url + +# File handling tests +def test_save_file(tmp_path): + test_file = tmp_path / "test.txt" + Core.save_file(str(test_file), "test content") + assert test_file.read_text() == "test content" + +# Package handling tests +def test_create_sbom_dict(): + test_sbom = [{ + "id": "test-pkg@1.0.0", + "name": "test-pkg", + "version": "1.0.0", + "type": "npm", + "direct": True, + "topLevelAncestors": [], + "manifestFiles": 
[{"file": "package.json"}], + "alerts": [] + }] + + result = Core.create_sbom_dict(test_sbom) + assert "test-pkg@1.0.0" in result + assert result["test-pkg@1.0.0"].name == "test-pkg" + +# Capability checking tests +def test_check_alert_capabilities(): + package = Package(**{ + "id": "test-pkg@1.0.0", + "alerts": [{"type": "envVars"}] + }) + + capabilities = {} + result = Core.check_alert_capabilities(package, capabilities, package.id) + assert result[package.id] == ["Environment"] + +# PURL creation tests +def test_create_purl(): + packages = { + "test-pkg@1.0.0": Package(**{ + "id": "test-pkg@1.0.0", + "name": "test-pkg", + "version": "1.0.0", + "type": "npm", + "direct": True, + "topLevelAncestors": [], + "manifestFiles": [{"file": "package.json"}], + "alerts": [] + }) + } + + purl, package = Core.create_purl("test-pkg@1.0.0", packages) + assert isinstance(purl, Purl) + assert purl.name == "test-pkg" + assert purl.version == "1.0.0" + +# API interaction tests +@patch('socketsecurity.core.do_request') +def test_get_org_id_slug(mock_do_request): + mock_response = Mock() + mock_response.json.return_value = mock_org_response + mock_do_request.return_value = mock_response + + org_id, org_slug = Core.get_org_id_slug() + assert org_id == "test-org-123" + assert org_slug == "test-org" diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py new file mode 100644 index 0000000..d21b5b9 --- /dev/null +++ b/tests/unit/test_utils.py @@ -0,0 +1,29 @@ +import pytest +from unittest.mock import patch +from socketsecurity.core.utils import encode_key, do_request +from socketsecurity.core.exceptions import APIKeyMissing + +def test_encode_key(): + """Test API key encoding""" + token = "test-token:" + encoded = encode_key(token) + assert encoded == "dGVzdC10b2tlbjo=" + +@patch('requests.request') +def test_do_request(mock_request): + """Test API request utility""" + do_request( + path="test/path", + api_key="encoded-key", + timeout=30 + ) + + mock_request.assert_called_once() + args = mock_request.call_args + assert args[1]['timeout'] == 30 + assert args[1]['headers']['Authorization'] == "Basic encoded-key" + +def test_do_request_missing_key(): + """Test API request fails without key""" + with pytest.raises(APIKeyMissing): + do_request(path="test/path") \ No newline at end of file From 5e6c6bc7fcb3446cd80f8e7fa58a0a95c8406617 Mon Sep 17 00:00:00 2001 From: Eric Hibbs Date: Fri, 8 Nov 2024 18:02:40 -0800 Subject: [PATCH 2/6] Core class fully covered but non-functional. 
Need to update other pieces --- .gitignore | 4 +- pyproject.toml | 7 +- pytest.ini | 5 +- socketsecurity/core/__init__.py | 563 ++++++++++--------- socketsecurity/core/classes.py | 6 +- socketsecurity/core/client.py | 54 ++ socketsecurity/core/config.py | 72 ++- socketsecurity/core/logging.py | 32 ++ socketsecurity/core/utils.py | 48 -- socketsecurity/socketcli.py | 16 +- tests/unit/__init__.py | 25 - tests/unit/test_client.py | 125 +++++ tests/unit/test_config.py | 104 ++-- tests/unit/test_core.py | 788 +++++++++++++++++++++++---- tests/unit/test_core_instance.py | 906 +++++++++++++++++++++++++++++++ 15 files changed, 2259 insertions(+), 496 deletions(-) create mode 100644 socketsecurity/core/client.py create mode 100644 socketsecurity/core/logging.py create mode 100644 tests/unit/test_client.py create mode 100644 tests/unit/test_core_instance.py diff --git a/.gitignore b/.gitignore index 5738fef..7054427 100644 --- a/.gitignore +++ b/.gitignore @@ -19,4 +19,6 @@ markdown_security_temp.md *.pyc test.py *.cpython-312.pyc` -file_generator.py \ No newline at end of file +file_generator.py +.coverage +.env.local \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index b7e80be..a571939 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -59,12 +59,7 @@ include = [ [tool.setuptools.dynamic] version = {attr = "socketsecurity.__version__"} -[tool.pytest.ini_options] -minversion = "7.0" -addopts = "-ra -q --cov=socketsecurity --cov-report=term-missing" -testpaths = [ - "tests", -] + pythonpath = "." [tool.coverage.run] diff --git a/pytest.ini b/pytest.ini index 3276d11..02a8000 100644 --- a/pytest.ini +++ b/pytest.ini @@ -1,3 +1,4 @@ [pytest] -; addopts = -v --no-cov --capture=no -addopts = -v --no-cov --tb=short -ra +testpaths = tests/unit +addopts = -vv --no-cov --tb=short -ra +python_files = test_*.py diff --git a/socketsecurity/core/__init__.py b/socketsecurity/core/__init__.py index 60c369b..f554938 100644 --- a/socketsecurity/core/__init__.py +++ b/socketsecurity/core/__init__.py @@ -1,8 +1,8 @@ import logging from pathlib import PurePath -from typing import Optional -from .config import CoreConfig -from .utils import encode_key, do_request, socket_globs +from .utils import socket_globs +from .config import SocketConfig +from .client import CliClient import requests from urllib.parse import urlencode import base64 @@ -27,7 +27,6 @@ import platform from glob import glob import time -import sys __all__ = [ "Core", @@ -41,97 +40,165 @@ version = __version__ api_url = "https://api.socket.dev/v0" timeout = 30 -full_scan_path = "" -repository_path = "" all_issues = AllIssues() org_id = None org_slug = None all_new_alerts = False -security_policy = {} allow_unverified_ssl = False log = logging.getLogger("socketdev") -log.addHandler(logging.NullHandler()) -class Core: - token: str - base_api_url: str - request_timeout: int - reports: list - - def __init__( - self, - token: str, - base_api_url: str = None, - request_timeout: int = None, - enable_all_alerts: bool = False, - allow_unverified: bool = False - ): - global allow_unverified_ssl - allow_unverified_ssl = allow_unverified - self.token = token + ":" - encode_key(self.token) - self.socket_date_format = "%Y-%m-%dT%H:%M:%S.%fZ" - self.base_api_url = base_api_url - if self.base_api_url is not None: - Core.set_api_url(self.base_api_url) - self.request_timeout = request_timeout - if self.request_timeout is not None: - Core.set_timeout(self.request_timeout) - if enable_all_alerts: - global all_new_alerts - all_new_alerts = True - 
Core.set_org_vars() - @staticmethod - def enable_debug_log(level: int): - global log - log.setLevel(level) +def encode_key(token: str) -> None: + """ + encode_key takes passed token string and does a base64 encoding. It sets this as a global variable + :param token: str of the Socket API Security Token + :return: + """ + global encoded_key + encoded_key = base64.b64encode(token.encode()).decode('ascii') + + +def do_request( + path: str, + headers: dict = None, + payload: [dict, str] = None, + files: list = None, + method: str = "GET", + base_url: str = None, +) -> requests.request: + """ + do_requests is the shared function for making HTTP calls + :param base_url: + :param path: Required path for the request + :param headers: Optional dictionary of headers. If not set will use a default set + :param payload: Optional dictionary or string of the payload to pass + :param files: Optional list of files to upload + :param method: Optional method to use, defaults to GET + :return: + """ + + if base_url is not None: + url = f"{base_url}/{path}" + else: + if encoded_key is None or encoded_key == "": + raise APIKeyMissing + url = f"{api_url}/{path}" + + if headers is None: + headers = { + 'Authorization': f"Basic {encoded_key}", + 'User-Agent': f'SocketPythonCLI/{__version__}', + "accept": "application/json" + } + verify = True + if allow_unverified_ssl: + verify = False + response = requests.request( + method.upper(), + url, + headers=headers, + data=payload, + files=files, + timeout=timeout, + verify=verify + ) + output_headers = headers.copy() + output_headers['Authorization'] = "API_KEY_REDACTED" + output = { + "url": url, + "headers": output_headers, + "status_code": response.status_code, + "body": response.text, + "payload": payload, + "files": files, + "timeout": timeout + } + log.debug(output) + if response.status_code <= 399: + return response + elif response.status_code == 400: + raise APIFailure(output) + elif response.status_code == 401: + raise APIAccessDenied("Unauthorized") + elif response.status_code == 403: + raise APIInsufficientQuota("Insufficient max_quota for API method") + elif response.status_code == 404: + raise APIResourceNotFound(f"Path not found {path}") + elif response.status_code == 429: + raise APIInsufficientQuota("Insufficient quota for API route") + elif response.status_code == 524: + raise APICloudflareError(response.text) + else: + msg = { + "status_code": response.status_code, + "UnexpectedError": "There was an unexpected error using the API", + "error": response.text, + "payload": payload, + "url": url + } + raise APIFailure(msg) - @staticmethod - def set_org_vars() -> None: - """ - Sets the main shared global variables - :return: - """ - global org_id, org_slug, full_scan_path, repository_path, security_policy - org_id, org_slug = Core.get_org_id_slug() - base_path = f"orgs/{org_slug}" - full_scan_path = f"{base_path}/full-scans" - repository_path = f"{base_path}/repos" - security_policy = Core.get_security_policy() - @staticmethod - def set_api_url(base_url: str): - """ - Set the global API URl if provided - :param base_url: - :return: - """ - global api_url - api_url = base_url +class Core: + # token: str + # base_api_url: str + # request_timeout: int + # reports: list + + client: CliClient + config: SocketConfig + + def __init__(self, config: SocketConfig, client: CliClient): + self.config = config + self.client = client + self.set_org_vars() + + # def __init__( + # self, + # token: str, + # base_api_url: str = None, + # request_timeout: int = None, + # 
enable_all_alerts: bool = False, + # allow_unverified: bool = False + # ): + # global allow_unverified_ssl + # allow_unverified_ssl = allow_unverified + # self.token = token + ":" + # encode_key(self.token) + # self.socket_date_format = "%Y-%m-%dT%H:%M:%S.%fZ" + # self.base_api_url = base_api_url + # if self.base_api_url is not None: + # Core.set_api_url(self.base_api_url) + # self.request_timeout = request_timeout + # if self.request_timeout is not None: + # Core.set_timeout(self.request_timeout) + # if enable_all_alerts: + # global all_new_alerts + # all_new_alerts = True + # Core.set_org_vars() + + + def set_org_vars(self) -> None: + """Sets the main shared configuration variables""" + # Get org details + org_id, org_slug = self.get_org_id_slug() + + # Update config with org details FIRST + self.config.org_id = org_id + self.config.org_slug = org_slug + + # Set paths + base_path = f"orgs/{org_slug}" + self.config.full_scan_path = f"{base_path}/full-scans" + self.config.repository_path = f"{base_path}/repos" - @staticmethod - def set_timeout(request_timeout: int): - """ - Set the global Requests timeout - :param request_timeout: - :return: - """ - global timeout - print(f"Setting timeout in module {__name__} at {id(sys.modules[__name__])}") - print(f"Current timeout value: {timeout}") - print(f"Setting to: {request_timeout}") - timeout = request_timeout - print(f"New timeout value: {timeout}") + # Get security policy AFTER org_id is updated + self.config.security_policy = self.get_security_policy() - @staticmethod - def get_org_id_slug() -> (str, str): - """ - Gets the Org ID and Org Slug for the API Token - :return: - """ + def get_org_id_slug(self) -> tuple[str, str]: + """Gets the Org ID and Org Slug for the API Token""" path = "organizations" - response = do_request(path) + response = self.client.request(path) data = response.json() organizations = data.get("organizations") new_org_id = None @@ -142,60 +209,79 @@ def get_org_id_slug() -> (str, str): new_org_slug = organizations[key].get('slug') return new_org_id, new_org_slug - @staticmethod - def get_sbom_data(full_scan_id: str) -> list: - path = f"orgs/{org_slug}/full-scans/{full_scan_id}" - response = do_request(path) + def get_sbom_data(self, full_scan_id: str) -> list: + """ + Return the list of SBOM artifacts for a full scan + """ + path = f"orgs/{self.config.org_slug}/full-scans/{full_scan_id}" + response = self.client.request(path) results = [] - try: - data = response.json() - results = data.get("sbom_artifacts") or [] - except Exception as error: - log.debug("Failed with old style full-scan API using new format") - log.debug(error) - data = response.text - data.strip('"') - data.strip() - for line in data.split("\n"): - if line != '"' and line != "" and line is not None: - item = json.loads(line) - results.append(item) + + if response.status_code != 200: + log.debug(f"Failed to get SBOM data for full-scan {full_scan_id}") + log.debug(response.text) + return [] + data = response.text + data.strip('"') + data.strip() + for line in data.split("\n"): + if line != '"' and line != "" and line is not None: + item = json.loads(line) + results.append(item) + return results - @staticmethod - def get_security_policy() -> dict: - """ - Get the Security policy and determine the effective Org security policy - :return: - """ - path = "settings" - payload = [ - { - "organization": org_id - } - ] - response = do_request(path, payload=json.dumps(payload), method="POST") + # ORIGINAL - remove after verification + # def get_sbom_data(self, 
full_scan_id: str) -> list: + # path = f"orgs/{self.config.org_slug}/full-scans/{full_scan_id}" + # response = self.client.request(path) + # results = [] + # try: + # data = response.json() + # results = data.get("sbom_artifacts") or [] + # except Exception as error: + # log.debug("Failed with old style full-scan API using new format") + # log.debug(error) + # data = response.text + # data.strip('"') + # data.strip() + # for line in data.split("\n"): + # if line != '"' and line != "" and line is not None: + # item = json.loads(line) + # results.append(item) + # return results + + def get_security_policy(self) -> dict: + """Get the Security policy and determine the effective Org security policy""" + payload = [{"organization": self.config.org_id}] + + response = self.client.request( + path="settings", + method="POST", + payload=json.dumps(payload) + ) + data = response.json() - defaults = data.get("defaults") - default_rules = defaults.get("issueRules") - entries = data.get("entries") + defaults = data.get("defaults", {}) + default_rules = defaults.get("issueRules", {}) + entries = data.get("entries", []) + org_rules = {} + + # Get organization-specific rules for org_set in entries: settings = org_set.get("settings") - if settings is not None: - org_details = settings.get("organization") - org_rules = org_details.get("issueRules") + if settings: + org_details = settings.get("organization", {}) + org_rules.update(org_details.get("issueRules", {})) + + # Apply default rules where no org-specific rule exists for default in default_rules: if default not in org_rules: action = default_rules[default]["action"] - org_rules[default] = { - "action": action - } - return org_rules + org_rules[default] = {"action": action} - # @staticmethod - # def get_supported_file_types() -> dict: - # path = "report/supported" + return org_rules @staticmethod def get_manifest_files(package: Package, packages: dict) -> str: @@ -217,11 +303,10 @@ def get_manifest_files(package: Package, packages: dict) -> str: manifest_files = ";".join(manifests) return manifest_files - @staticmethod - def create_sbom_output(diff: Diff) -> dict: - base_path = f"orgs/{org_slug}/export/cdx" + def create_sbom_output(self, diff: Diff) -> dict: + base_path = f"orgs/{self.config.org_slug}/export/cdx" path = f"{base_path}/{diff.id}" - result = do_request(path=path) + result = self.client.request(path=path) try: sbom = result.json() except Exception as error: @@ -230,6 +315,7 @@ def create_sbom_output(diff: Diff) -> dict: sbom = {} return sbom + # TODO: verify what this does. 
It looks like it should be named "all_files_unsupported" @staticmethod def match_supported_files(files: list) -> bool: matched_files = [] @@ -281,8 +367,7 @@ def find_files(path: str) -> list: log.info(f"Found {len(files)} in {total_time:.2f} seconds") return list(files) - @staticmethod - def create_full_scan(files: list, params: FullScanParams, workspace: str) -> FullScan: + def create_full_scan(self, files: list, params: FullScanParams, workspace: str) -> FullScan: """ Calls the full scan API to create a new Full Scan :param files: list - Globbed files of manifest files @@ -317,11 +402,11 @@ def create_full_scan(files: list, params: FullScanParams, workspace: str) -> Ful ) send_files.append(payload) query_params = urlencode(params.__dict__) - full_uri = f"{full_scan_path}?{query_params}" - response = do_request(full_uri, method="POST", files=send_files) + full_uri = f"{self.config.full_scan_path}?{query_params}" + response = self.client.request(full_uri, method="POST", files=send_files) results = response.json() full_scan = FullScan(**results) - full_scan.sbom_artifacts = Core.get_sbom_data(full_scan.id) + full_scan.sbom_artifacts = self.get_sbom_data(full_scan.id) create_full_end = time.time() total_time = create_full_end - create_full_start log.debug(f"New Full Scan created in {total_time:.2f} seconds") @@ -337,93 +422,103 @@ def get_license_details(package: Package) -> Package: package.license_text = license_obj.licenseText return package - @staticmethod - def get_head_scan_for_repo(repo_slug: str): - """ - Get the head scan ID for a repository to use for the diff - :param repo_slug: Str - Repo slug for the repository that is being diffed - :return: - """ - repo_path = f"{repository_path}/{repo_slug}" - response = do_request(repo_path) - results = response.json() - repository = Repository(**results) + def get_head_scan_for_repo(self, repo_slug: str) -> str: + """Get the head scan ID for a repository""" + print(f"\nGetting head scan for repo: {repo_slug}") + repo_path = f"{self.config.repository_path}/{repo_slug}" + print(f"Repository path: {repo_path}") + + response = self.client.request(repo_path) + response_data = response.json() + print(f"Raw API Response: {response_data}") # Debug raw response + print(f"Response type: {type(response_data)}") # Debug response type + + if "repository" in response_data: + print(f"Repository data: {response_data['repository']}") # Debug repository data + else: + print("No 'repository' key in response data!") + + repository = Repository(**response_data["repository"]) + print(f"Created repository object: {repository.__dict__}") # Debug final object + return repository.head_full_scan_id - @staticmethod - def get_full_scan(full_scan_id: str) -> FullScan: + def get_full_scan(self, full_scan_id: str) -> FullScan: """ Get the specified full scan and return a FullScan object :param full_scan_id: str - ID of the full scan to pull :return: """ - full_scan_url = f"{full_scan_path}/{full_scan_id}" - response = do_request(full_scan_url) + full_scan_url = f"{self.config.full_scan_path}/{full_scan_id}" + response = self.client.request(full_scan_url) results = response.json() full_scan = FullScan(**results) - full_scan.sbom_artifacts = Core.get_sbom_data(full_scan.id) + full_scan.sbom_artifacts = self.get_sbom_data(full_scan.id) return full_scan - @staticmethod - def create_new_diff( - path: str, - params: FullScanParams, - workspace: str, - no_change: bool = False - ) -> Diff: - """ - 1. Get the head full scan. 
If it isn't present because this repo doesn't exist yet return an Empty full scan. - 2. Create a new Full scan for the current run - 3. Compare the head and new Full scan - 4. Return a Diff report - :param path: Str - path of where to look for manifest files for the new Full Scan - :param params: FullScanParams - Query params for the Full Scan endpoint - :param workspace: str - Path for workspace - :param no_change: - :return: + def create_new_diff(self, path: str, params: FullScanParams, workspace: str, no_change: bool = False) -> Diff: + """Creates a new diff by comparing a new scan against the head scan of a repository. + + Args: + path: Path to the directory containing files to scan + params: Parameters for the full scan including repo, branch, commit details + workspace: Working directory path + no_change: If True, returns an empty diff without scanning + + Returns: + Diff: A diff object with one of: + - id="no_diff_id" if no_change=True or no files found + - New scan with report_url=diff_url if no head scan exists or repository not found + - New scan compared against head scan with separate report_url and diff_url """ + start_time = time.time() + log.info(f"Starting new diff for {params.repo}") + + # Return empty diff if no changes requested if no_change: - diff = Diff() - diff.id = "no_diff_id" - return diff - files = Core.find_files(path) - if files is None or len(files) == 0: - diff = Diff() - diff.id = "no_diff_id" - return diff + log.info("No change requested, returning empty diff") + return Diff(id="no_diff_id") + + # Return empty diff if no files to scan + files = self.find_files(path) + if not files: + log.info("No files found to scan, returning empty diff") + return Diff(id="no_diff_id") + + # Get head scan ID for the repository try: - head_full_scan_id = Core.get_head_scan_for_repo(params.repo) - if head_full_scan_id is None or head_full_scan_id == "": - head_full_scan = [] - else: - head_start = time.time() - head_full_scan = Core.get_sbom_data(head_full_scan_id) - head_end = time.time() - total_head_time = head_end - head_start - log.info(f"Total time to get head full-scan {total_head_time: .2f}") + head_full_scan_id = self.get_head_scan_for_repo(params.repo) except APIResourceNotFound: + log.info("Repository not found, creating new scan without comparison") head_full_scan_id = None - head_full_scan = [] - new_scan_start = time.time() - new_full_scan = Core.create_full_scan(files, params, workspace) - new_full_scan.packages = Core.create_sbom_dict(new_full_scan.sbom_artifacts) - new_scan_end = time.time() - total_new_time = new_scan_end - new_scan_start - log.info(f"Total time to get new full-scan {total_new_time: .2f}") - diff_report = Core.compare_sboms(new_full_scan.sbom_artifacts, head_full_scan) - diff_report.packages = new_full_scan.packages - # Set the diff ID and URLs - base_socket = "https://socket.dev/dashboard/org" - diff_report.id = new_full_scan.id - diff_report.report_url = f"{base_socket}/{org_slug}/sbom/{diff_report.id}" - if head_full_scan_id is not None: - diff_report.diff_url = f"{base_socket}/{org_slug}/diff/{diff_report.id}/{head_full_scan_id}" + + # Create new scan with no comparison if no head scan exists + if not head_full_scan_id: + log.info("No head scan found, creating new scan without comparison") + new_scan = self.create_full_scan(path, params, workspace) + base_url = f"https://socket.dev/dashboard/org/{self.config.org_slug}" + diff = Diff( + id=new_scan.id, + report_url=f"{base_url}/sbom/{new_scan.id}", + 
diff_url=f"{base_url}/sbom/{new_scan.id}" + ) else: - diff_report.diff_url = diff_report.report_url - return diff_report + # Create new scan and compare against head scan + log.info(f"Creating new scan and comparing against head scan {head_full_scan_id}") + new_scan = self.create_full_scan(path, params, workspace) + base_url = f"https://socket.dev/dashboard/org/{self.config.org_slug}" + diff = Diff( + id=new_scan.id, + report_url=f"{base_url}/sbom/{new_scan.id}", + diff_url=f"{base_url}/diff/{new_scan.id}/{head_full_scan_id}" + ) - @staticmethod - def compare_sboms(new_scan: list, head_scan: list) -> Diff: + end_time = time.time() + duration = end_time - start_time + log.info(f"Completed diff creation in {duration:.2f} seconds") + return diff + + def compare_sboms(self, new_scan: list, head_scan: list) -> Diff: """ compare the SBOMs of the new full Scan and the head full scan. Return a Diff report with new packages, removed packages, and new alerts for the new full scan compared to the head. @@ -444,12 +539,12 @@ def compare_sboms(new_scan: list, head_scan: list) -> Diff: if package_id not in head_packages and package.direct and base_purl not in consolidated: diff.new_packages.append(purl) consolidated.add(base_purl) - new_scan_alerts = Core.create_issue_alerts(package, new_scan_alerts, new_packages) + new_scan_alerts = self.create_issue_alerts(package, new_scan_alerts, new_packages) for package_id in head_packages: purl, package = Core.create_purl(package_id, head_packages) if package_id not in new_packages and package.direct: diff.removed_packages.append(purl) - head_scan_alerts = Core.create_issue_alerts(package, head_scan_alerts, head_packages) + head_scan_alerts = self.create_issue_alerts(package, head_scan_alerts, head_packages) diff.new_alerts = Core.compare_issue_alerts(new_scan_alerts, head_scan_alerts, diff.new_alerts) diff.new_capabilities = Core.compare_capabilities(new_packages, head_packages) diff = Core.add_capabilities_to_purl(diff) @@ -505,8 +600,12 @@ def check_alert_capabilities( new_alert = True if head_package is not None and alert in head_package.alerts: new_alert = False - if alert["type"] in alert_types and new_alert: - value = alert_types[alert["type"]] + + # Support both dictionary and Alert object access + alert_type = alert.type if hasattr(alert, 'type') else alert["type"] + + if alert_type in alert_types and new_alert: + value = alert_types[alert_type] if package_id not in capabilities: capabilities[package_id] = [value] else: @@ -547,51 +646,29 @@ def compare_issue_alerts(new_scan_alerts: dict, head_scan_alerts: dict, alerts: consolidated_alerts.append(alert_str) return alerts - @staticmethod - def create_issue_alerts(package: Package, alerts: dict, packages: dict) -> dict: - """ - Create the Issue Alerts from the package and base alert data. 
- :param package: Package - Current package that is being looked at for Alerts - :param alerts: Dict - All found Issue Alerts across all packages - :param packages: Dict - All packages detected in the SBOM and needed to find top level packages - :return: - """ - for item in package.alerts: - alert = Alert(**item) - try: - props = getattr(all_issues, alert.type) - except AttributeError: - props = None - if props is not None: - description = props.description - title = props.title - suggestion = props.suggestion - next_step_title = props.nextStepTitle - else: - description = "" - title = "" - suggestion = "" - next_step_title = "" - introduced_by = Core.get_source_data(package, packages) + def create_issue_alerts(self, package: Package, alerts: dict, packages: dict) -> dict: + """Create issue alerts for a package""" + for alert in package.alerts: + if not hasattr(self.config.all_issues, alert.type): + continue + props = getattr(self.config.all_issues, alert.type) + introduced_by = self.get_source_data(package, packages) + suggestion = getattr(props, 'suggestion', None) + next_step_title = getattr(props, 'nextStepTitle', None) issue_alert = Issue( - pkg_type=package.type, - pkg_name=package.name, - pkg_version=package.version, - pkg_id=package.id, + key=alert.key, type=alert.type, severity=alert.severity, - key=alert.key, - props=alert.props, - description=description, - title=title, + description=props.description, + title=props.title, suggestion=suggestion, next_step_title=next_step_title, introduced_by=introduced_by, purl=package.purl, url=package.url ) - if alert.type in security_policy: - action = security_policy[alert.type]['action'] + if alert.type in self.config.security_policy: + action = self.config.security_policy[alert.type]['action'] setattr(issue_alert, action, True) if issue_alert.type != 'licenseSpdxDisj': if issue_alert.key not in alerts: diff --git a/socketsecurity/core/classes.py b/socketsecurity/core/classes.py index 8b93826..005c145 100644 --- a/socketsecurity/core/classes.py +++ b/socketsecurity/core/classes.py @@ -124,7 +124,8 @@ def __init__(self, **kwargs): if not hasattr(self, "license_text"): self.license_text = "" self.url = f"https://socket.dev/{self.type}/package/{self.name}/overview/{self.version}" - self.purl = f"{self.type}/{self.name}@{self.version}" + if not hasattr(self, "purl"): + self.purl = f"pkg:{self.type}/{self.name}@{self.version}" def __str__(self): return json.dumps(self.__dict__) @@ -282,9 +283,12 @@ class Repository: default_branch: str def __init__(self, **kwargs): + print(f"Repository.__init__ called with kwargs: {kwargs}") # Debug if kwargs: for key, value in kwargs.items(): + print(f"Setting {key}={value}") # Debug setattr(self, key, value) + print(f"Final Repository object dict: {self.__dict__}") # Debug def __str__(self): return json.dumps(self.__dict__) diff --git a/socketsecurity/core/client.py b/socketsecurity/core/client.py new file mode 100644 index 0000000..8a5731d --- /dev/null +++ b/socketsecurity/core/client.py @@ -0,0 +1,54 @@ +import base64 +import requests +from typing import Optional, Dict, Union, List +import logging +from .config import SocketConfig +from .exceptions import APIFailure + +logger = logging.getLogger("socketdev") + +class CliClient: + def __init__(self, config: SocketConfig): + self.config = config + self._encoded_key = self._encode_key(config.api_key) + + @staticmethod + def _encode_key(token: str) -> str: + return base64.b64encode(f"{token}:".encode()).decode('ascii') + + def request( + self, + path: str, + 
method: str = "GET", + headers: Optional[Dict] = None, + payload: Optional[Union[Dict, str]] = None, + files: Optional[List] = None, + base_url: Optional[str] = None + ) -> requests.Response: + url = f"{base_url or self.config.api_url}/{path}" + + default_headers = { + 'Authorization': f"Basic {self._encoded_key}", + 'User-Agent': 'SocketPythonCLI/0.0.1', + "accept": "application/json" + } + + headers = headers or default_headers + + try: + response = requests.request( + method=method.upper(), + url=url, + headers=headers, + data=payload, + files=files, + timeout=self.config.timeout, + verify=not self.config.allow_unverified_ssl + ) + + response.raise_for_status() + return response + + except requests.exceptions.RequestException as e: + logger.error(f"API request failed: {str(e)}") + raise APIFailure(f"Request failed: {str(e)}") diff --git a/socketsecurity/core/config.py b/socketsecurity/core/config.py index ba8d15f..04a8881 100644 --- a/socketsecurity/core/config.py +++ b/socketsecurity/core/config.py @@ -1,32 +1,60 @@ from dataclasses import dataclass -from typing import ClassVar +from typing import Optional, Dict +from urllib.parse import urlparse +from socketsecurity.core.issues import AllIssues @dataclass -class CoreConfig: - """Configuration for the Socket Security Core class""" - - # Required - token: str - - # Optional with defaults +class SocketConfig: + api_key: str api_url: str = "https://api.socket.dev/v0" timeout: int = 30 - enable_all_alerts: bool = False allow_unverified_ssl: bool = False + org_id: Optional[str] = None + org_slug: Optional[str] = None + full_scan_path: Optional[str] = None + repository_path: Optional[str] = None + security_policy: Dict = None + all_issues: Optional['AllIssues'] = None - # Constants - SOCKET_DATE_FORMAT: ClassVar[str] = "%Y-%m-%dT%H:%M:%S.%fZ" - DEFAULT_API_URL: ClassVar[str] = "https://api.socket.dev/v0" - DEFAULT_TIMEOUT: ClassVar[int] = 30 + def __post_init__(self): + """Validate configuration after initialization""" + if not self.api_key: + raise ValueError("API key is required") - def __post_init__(self) -> None: - """Validate and process config after initialization""" - # Business rule validations - if not self.token: - raise ValueError("Token is required") if self.timeout <= 0: - raise ValueError("Timeout must be positive") + raise ValueError("Timeout must be a positive integer") + + self._validate_api_url(self.api_url) + + # Initialize empty dict for security policy if None + if self.security_policy is None: + self.security_policy = {} + + # Initialize AllIssues if None + if self.all_issues is None: + from socketsecurity.core.issues import AllIssues + self.all_issues = AllIssues() + + @staticmethod + def _validate_api_url(url: str) -> None: + """Validate that the API URL is a valid HTTPS URL""" + try: + parsed = urlparse(url) + if not all([parsed.scheme, parsed.netloc]): + raise ValueError("Invalid URL format") + if parsed.scheme != "https": + raise ValueError("API URL must use HTTPS") + except Exception as e: + raise ValueError(f"Invalid API URL: {str(e)}") + + def update_org_details(self, org_id: str, org_slug: str) -> None: + """Update organization details and related paths""" + self.org_id = org_id + self.org_slug = org_slug + base_path = f"orgs/{org_slug}" + self.full_scan_path = f"{base_path}/full-scans" + self.repository_path = f"{base_path}/repos" - # Business logic - if not self.token.endswith(':'): - self.token = f"{self.token}:" + def update_security_policy(self, policy: Dict) -> None: + """Update security policy""" + 
self.security_policy = policy \ No newline at end of file diff --git a/socketsecurity/core/logging.py b/socketsecurity/core/logging.py new file mode 100644 index 0000000..c0ff12d --- /dev/null +++ b/socketsecurity/core/logging.py @@ -0,0 +1,32 @@ +import logging + +def initialize_logging( + level: int = logging.INFO, + format: str = "%(asctime)s: %(message)s", + socket_logger_name: str = "socketdev", + cli_logger_name: str = "socketcli" +) -> tuple[logging.Logger, logging.Logger]: + """Initialize logging for Socket Security + + Returns both the socket and CLI loggers for convenience, though they can also + be accessed via logging.getLogger() elsewhere + """ + # Configure root logger + logging.basicConfig(level=level, format=format) + + # Configure Socket logger + socket_logger = logging.getLogger(socket_logger_name) + socket_logger.setLevel(level) + socket_logger.addHandler(logging.NullHandler()) + + # Configure CLI logger + cli_logger = logging.getLogger(cli_logger_name) + cli_logger.setLevel(level) + + return socket_logger, cli_logger + +def set_debug_mode(enable: bool = True) -> None: + """Toggle debug logging across all loggers""" + level = logging.DEBUG if enable else logging.INFO + logging.getLogger("socketdev").setLevel(level) + logging.getLogger("socketcli").setLevel(level) \ No newline at end of file diff --git a/socketsecurity/core/utils.py b/socketsecurity/core/utils.py index 6fab6c5..c7a45b0 100644 --- a/socketsecurity/core/utils.py +++ b/socketsecurity/core/utils.py @@ -1,51 +1,3 @@ -import base64 -import requests -from typing import Optional, Dict, List, Union -from socketsecurity import __version__ -from socketsecurity.core.exceptions import APIKeyMissing - -def encode_key(token: str) -> str: - """Encode API token in base64""" - return base64.b64encode(token.encode()).decode('ascii') - -def do_request( - path: str, - headers: Optional[Dict] = None, - payload: Optional[Union[Dict, str]] = None, - files: Optional[List] = None, - method: str = "GET", - base_url: Optional[str] = None, - api_key: Optional[str] = None, - timeout: int = 30, - verify_ssl: bool = True -) -> requests.Response: - """Make HTTP requests to Socket API""" - - if base_url is not None: - url = f"{base_url}/{path}" - else: - if not api_key: - raise APIKeyMissing - url = f"https://api.socket.dev/v0/{path}" - - if headers is None: - headers = { - 'Authorization': f"Basic {api_key}", - 'User-Agent': f'SocketPythonCLI/{__version__}', - "accept": "application/json" - } - - response = requests.request( - method.upper(), - url, - headers=headers, - data=payload, - files=files, - timeout=timeout, - verify=verify_ssl - ) - - return response # File pattern definitions socket_globs = { diff --git a/socketsecurity/socketcli.py b/socketsecurity/socketcli.py index 381ebbd..3496c9e 100644 --- a/socketsecurity/socketcli.py +++ b/socketsecurity/socketcli.py @@ -3,6 +3,7 @@ import socketsecurity.core from socketsecurity.core import Core, __version__ +from socketsecurity.logging import initialize_logging, set_debug_mode from socketsecurity.core.classes import FullScanParams, Diff, Package, Issue from socketsecurity.core.messages import Messages from socketsecurity.core.scm_comments import Comments @@ -12,10 +13,12 @@ import sys import logging +socket_logger, cli_logger = initialize_logging() + log_format = "%(asctime)s: %(message)s" logging.basicConfig(level=logging.INFO, format=log_format) socketsecurity.core.log.setLevel(level=logging.INFO) -log = logging.getLogger("socketcli") +log = cli_logger blocking_disabled = False 
parser = argparse.ArgumentParser( @@ -229,12 +232,11 @@ def cli(): def main_code(): arguments = parser.parse_args() - debug = arguments.enable_debug - if debug: - logging.basicConfig(level=logging.DEBUG, format=log_format) - log.setLevel(logging.DEBUG) - Core.enable_debug_log(logging.DEBUG) + + if arguments.enable_debug: + set_debug_mode(True) log.debug("Debug logging enabled") + repo = arguments.repo branch = arguments.branch commit_message = arguments.commit_message @@ -252,9 +254,11 @@ def main_code(): ignore_commit_files = arguments.ignore_commit_files disable_blocking = arguments.disable_blocking allow_unverified = arguments.allow_unverified + if disable_blocking: global blocking_disabled blocking_disabled = True + files = arguments.files log.info(f"Starting Socket Security Scan version {__version__}") api_token = os.getenv("SOCKET_SECURITY_API_KEY") or arguments.api_token diff --git a/tests/unit/__init__.py b/tests/unit/__init__.py index 7a025e5..e69de29 100644 --- a/tests/unit/__init__.py +++ b/tests/unit/__init__.py @@ -1,25 +0,0 @@ -import pytest - -# Test data -TEST_API_TOKEN = "test-token" -TEST_API_KEY_ENCODED = "dGVzdC10b2tlbjo=" -DEFAULT_API_URL = "https://api.socket.dev/v0" -DEFAULT_TIMEOUT = 30 - -@pytest.fixture(autouse=True) -def reset_globals(): - """Reset global state after each test""" - from socketsecurity.core import Core - yield - Core.set_api_url(DEFAULT_API_URL) - Core.set_timeout(DEFAULT_TIMEOUT) - -@pytest.fixture -def mock_org_response(): - return { - "organizations": { - "test-org-123": { - "slug": "test-org" - } - } - } diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py new file mode 100644 index 0000000..0077687 --- /dev/null +++ b/tests/unit/test_client.py @@ -0,0 +1,125 @@ +import pytest +from unittest.mock import Mock, patch +import requests +from socketsecurity.core.client import CliClient +from socketsecurity.core.config import SocketConfig +from socketsecurity.core.exceptions import APIFailure + +@pytest.fixture +def config(): + return SocketConfig( + api_key="test_key", + timeout=30, + allow_unverified_ssl=False + ) + +@pytest.fixture +def client(config): + return CliClient(config) + +def test_encode_key(): + """Test the static key encoding method""" + encoded = CliClient._encode_key("test_key") + assert encoded == "dGVzdF9rZXk6" # base64 of "test_key:" + +def test_request_builds_correct_url(client): + """Test URL construction""" + with patch('requests.request') as mock_request: + mock_response = Mock() + mock_response.status_code = 200 + mock_request.return_value = mock_response + + client.request("test/path") + + mock_request.assert_called_once() + args, kwargs = mock_request.call_args + assert kwargs['url'] == "https://api.socket.dev/v0/test/path" + +def test_request_uses_config_timeout(client): + """Test timeout is passed from config""" + with patch('requests.request') as mock_request: + mock_response = Mock() + mock_response.status_code = 200 + mock_request.return_value = mock_response + + client.request("test/path") + + mock_request.assert_called_once() + args, kwargs = mock_request.call_args + assert kwargs['timeout'] == 30 + +def test_request_handles_api_error(): + """Test error handling""" + config = SocketConfig(api_key="test_key") + client = CliClient(config) + + with patch('requests.request') as mock_request: + mock_response = Mock() + mock_response.status_code = 400 + mock_response.raise_for_status.side_effect = requests.exceptions.RequestException("Test error") + mock_request.return_value = mock_response + + with 
pytest.raises(APIFailure): + client.request("test/path") + +def test_request_uses_custom_headers(client): + """Test that custom headers override defaults""" + custom_headers = {"Authorization": "Bearer token", "Custom": "Value"} + + with patch('requests.request') as mock_request: + mock_response = Mock() + mock_response.status_code = 200 + mock_request.return_value = mock_response + + client.request("test/path", headers=custom_headers) + + args, kwargs = mock_request.call_args + assert kwargs['headers'] == custom_headers + +def test_request_uses_custom_base_url(client): + """Test that custom base_url overrides default""" + custom_base = "https://custom.api.com" + + with patch('requests.request') as mock_request: + mock_response = Mock() + mock_response.status_code = 200 + mock_request.return_value = mock_response + + client.request("test/path", base_url=custom_base) + + args, kwargs = mock_request.call_args + assert kwargs['url'] == f"{custom_base}/test/path" + +def test_request_ssl_verification(client): + """Test SSL verification setting from config""" + with patch('requests.request') as mock_request: + mock_response = Mock() + mock_response.status_code = 200 + mock_request.return_value = mock_response + + client.request("test/path") + + args, kwargs = mock_request.call_args + assert kwargs['verify'] == True # Default is True + + # Test with SSL verification disabled + client.config.allow_unverified_ssl = True + client.request("test/path") + + args, kwargs = mock_request.call_args + assert kwargs['verify'] == False + +def test_request_with_payload(client): + """Test request with payload data""" + payload = {"key": "value"} + + with patch('requests.request') as mock_request: + mock_response = Mock() + mock_response.status_code = 200 + mock_request.return_value = mock_response + + client.request("test/path", method="POST", payload=payload) + + args, kwargs = mock_request.call_args + assert kwargs['method'] == "POST" + assert kwargs['data'] == payload \ No newline at end of file diff --git a/tests/unit/test_config.py b/tests/unit/test_config.py index 5ac1b67..2bfb963 100644 --- a/tests/unit/test_config.py +++ b/tests/unit/test_config.py @@ -1,48 +1,80 @@ import pytest -from socketsecurity.core.config import CoreConfig +from socketsecurity.core.config import SocketConfig -def test_config_initialization() -> None: - """Test basic config initialization with defaults""" - config = CoreConfig(token="test-token") +def test_config_default_values(): + """Test that config initializes with correct default values""" + config = SocketConfig(api_key="test_key") - assert config.token == "test-token:" - assert config.api_url == CoreConfig.DEFAULT_API_URL - assert config.timeout == CoreConfig.DEFAULT_TIMEOUT - assert config.enable_all_alerts is False + assert config.api_key == "test_key" + assert config.api_url == "https://api.socket.dev/v0" + assert config.timeout == 30 assert config.allow_unverified_ssl is False + assert config.org_id is None + assert config.org_slug is None + assert config.full_scan_path is None + assert config.repository_path is None + assert config.security_policy == {} -def test_config_custom_values() -> None: - """Test config with custom values""" - config = CoreConfig( - token="test-token", - api_url="https://custom.api", +def test_config_custom_values(): + """Test that config accepts custom values""" + config = SocketConfig( + api_key="test_key", + api_url="https://custom.api.dev/v1", timeout=60, - enable_all_alerts=True, allow_unverified_ssl=True ) - assert config.token == 
"test-token:" - assert config.api_url == "https://custom.api" + assert config.api_key == "test_key" + assert config.api_url == "https://custom.api.dev/v1" assert config.timeout == 60 - assert config.enable_all_alerts is True assert config.allow_unverified_ssl is True -def test_config_validation() -> None: - """Test business rule validation""" - # Test empty token - with pytest.raises(ValueError, match="Token is required"): - CoreConfig(token="") - - # Test invalid timeout - with pytest.raises(ValueError, match="Timeout must be positive"): - CoreConfig(token="test", timeout=0) - -def test_token_formatting() -> None: - """Test token colon suffix business rule""" - # Without colon - config = CoreConfig(token="test-token") - assert config.token == "test-token:" - - # Already has colon - config = CoreConfig(token="test-token:") - assert config.token == "test-token:" +def test_config_api_key_required(): + """Test that api_key is required""" + with pytest.raises(ValueError): + SocketConfig(api_key=None) + + with pytest.raises(ValueError): + SocketConfig(api_key="") + +def test_config_invalid_timeout(): + """Test that timeout must be positive""" + with pytest.raises(ValueError): + SocketConfig(api_key="test_key", timeout=0) + + with pytest.raises(ValueError): + SocketConfig(api_key="test_key", timeout=-1) + +def test_config_invalid_api_url(): + """Test that api_url must be valid HTTPS URL""" + with pytest.raises(ValueError): + SocketConfig(api_key="test_key", api_url="not_a_url") + + with pytest.raises(ValueError): + SocketConfig(api_key="test_key", api_url="http://insecure.com") # Must be HTTPS + +def test_config_update_org_details(): + """Test updating org details""" + config = SocketConfig(api_key="test_key") + + config.org_id = "test_org_id" + config.org_slug = "test-org" + config.full_scan_path = "orgs/test-org/full-scans" + config.repository_path = "orgs/test-org/repos" + + assert config.org_id == "test_org_id" + assert config.org_slug == "test-org" + assert config.full_scan_path == "orgs/test-org/full-scans" + assert config.repository_path == "orgs/test-org/repos" + +def test_config_update_security_policy(): + """Test updating security policy""" + config = SocketConfig(api_key="test_key") + + test_policy = { + "rule1": {"action": "block"}, + "rule2": {"action": "warn"} + } + + config.security_policy = test_policy + assert config.security_policy == test_policy diff --git a/tests/unit/test_core.py b/tests/unit/test_core.py index 7f2751f..bb40c5d 100644 --- a/tests/unit/test_core.py +++ b/tests/unit/test_core.py @@ -1,115 +1,691 @@ -from unittest.mock import Mock, patch -from socketsecurity.core import Core, timeout, api_url -from socketsecurity.core.classes import Package, Purl -from tests.unit import TEST_API_TOKEN, mock_org_response -import sys - -# Basic initialization and utility tests -@patch('socketsecurity.core.do_request') -def test_core_initialization(mock_do_request): - # Mock responses for both API calls - mock_responses = [ - # First call for get_org_id_slug - Mock(json=lambda: {"organizations": {"test-org-123": {"slug": "test-org"}}}), - # Second call for get_security_policy - Mock(json=lambda: { - "defaults": { - "issueRules": { - "noTests": {"action": "warn"}, - "noV1": {"action": "error"} - } - }, - "entries": [] - }) +import pytest +from unittest.mock import mock_open, patch, MagicMock +from socketsecurity.core import Core +from socketsecurity.core.classes import Package +from socketsecurity.core.classes import Purl +from socketsecurity.core.classes import Diff +from 
socketsecurity.core.classes import Issue + +@pytest.fixture +def sample_package(): + return Package( + id="pkg1", + name="test-package", + version="1.0.0", + type="npm", + direct=True, + manifestFiles=[{"file": "package.json"}], + alerts=[], + author=["Test Author"], + size=1000, + url="https://example.com", + purl="pkg:npm/test-package@1.0.0" + ) + +def test_match_supported_files(): + """Test matching of supported file types""" + files = [ + "package.json", + "requirements.txt", + "unsupported.xyz" ] - mock_do_request.side_effect = mock_responses - - core = Core(token=TEST_API_TOKEN) - assert core.token == f"{TEST_API_TOKEN}:" - assert core.base_api_url is None - assert core.request_timeout is None - - # Verify both API calls were made - assert mock_do_request.call_count == 2 - -def test_core_set_timeout(): - current_module = sys.modules[__name__] - core_module = sys.modules['socketsecurity.core'] - - print(f"Test module id: {id(current_module)}") - print(f"Core module id: {id(core_module)}") - print(f"Timeout in test: {id(timeout)}") - print(f"Timeout in core: {id(core_module.timeout)}") - - Core.set_timeout(60) - assert timeout == 60 - -def test_core_set_api_url(): - test_url = "https://test.api.com" - Core.set_api_url(test_url) - assert api_url == test_url - -# File handling tests -def test_save_file(tmp_path): - test_file = tmp_path / "test.txt" - Core.save_file(str(test_file), "test content") - assert test_file.read_text() == "test content" - -# Package handling tests -def test_create_sbom_dict(): - test_sbom = [{ - "id": "test-pkg@1.0.0", - "name": "test-pkg", - "version": "1.0.0", - "type": "npm", - "direct": True, - "topLevelAncestors": [], - "manifestFiles": [{"file": "package.json"}], - "alerts": [] - }] - - result = Core.create_sbom_dict(test_sbom) - assert "test-pkg@1.0.0" in result - assert result["test-pkg@1.0.0"].name == "test-pkg" - -# Capability checking tests -def test_check_alert_capabilities(): - package = Package(**{ - "id": "test-pkg@1.0.0", - "alerts": [{"type": "envVars"}] - }) - capabilities = {} - result = Core.check_alert_capabilities(package, capabilities, package.id) - assert result[package.id] == ["Environment"] + result = Core.match_supported_files(files) + assert result is False # Found supported files -# PURL creation tests -def test_create_purl(): - packages = { - "test-pkg@1.0.0": Package(**{ - "id": "test-pkg@1.0.0", - "name": "test-pkg", + files = ["unsupported.xyz"] + result = Core.match_supported_files(files) + assert result is True # No supported files + +def test_save_file(): + """Test file saving functionality""" + with patch('builtins.open', mock_open()) as mock_file: + Core.save_file("test.txt", "test content") + mock_file.assert_called_once_with("test.txt", "w") + mock_file().write.assert_called_once_with("test content") + +def test_get_manifest_files(): + """Test manifest file handling for all branches""" + # Branch 1: Direct package with single manifest + direct_pkg_single = Package( + id="pkg1", + name="test-package", + version="1.0.0", + type="npm", + direct=True, + manifestFiles=[{"file": "package.json"}], + alerts=[] + ) + packages = {"pkg1": direct_pkg_single} + result = Core.get_manifest_files(direct_pkg_single, packages) + assert result == "package.json" + + # Branch 2: Direct package with multiple manifests + direct_pkg_multiple = Package( + id="pkg2", + name="test-package", + version="1.0.0", + type="npm", + direct=True, + manifestFiles=[ + {"file": "package.json"}, + {"file": "package-lock.json"} + ], + alerts=[] + ) + 
packages["pkg2"] = direct_pkg_multiple + result = Core.get_manifest_files(direct_pkg_multiple, packages) + assert result == "package.json;package-lock.json" + + # Branch 3: Transitive package with single top-level ancestor + transitive_single = Package( + id="pkg3", + name="transitive-pkg", + version="1.0.0", + type="npm", + direct=False, + topLevelAncestors=["pkg1"], + manifestFiles=[], + alerts=[] + ) + packages["pkg3"] = transitive_single + result = Core.get_manifest_files(transitive_single, packages) + assert result == "transitive-pkg@1.0.0(package.json)" + + # Branch 4: Transitive package with multiple top-level ancestors + transitive_multiple = Package( + id="pkg4", + name="transitive-pkg", + version="1.0.0", + type="npm", + direct=False, + topLevelAncestors=["pkg2"], + manifestFiles=[], + alerts=[] + ) + packages["pkg4"] = transitive_multiple + result = Core.get_manifest_files(transitive_multiple, packages) + assert result == "transitive-pkg@1.0.0(package.json);transitive-pkg@1.0.0(package-lock.json)" + +def test_create_sbom_dict_all_branches(): + """Test SBOM dictionary creation covering all conditional branches""" + sbom_data = [ + { # Top-level package with transitives + "id": "pkg1", + "name": "root-pkg", "version": "1.0.0", "type": "npm", "direct": True, - "topLevelAncestors": [], "manifestFiles": [{"file": "package.json"}], - "alerts": [] - }) + "alerts": [], + "topLevelAncestors": [] + }, + { # First transitive + "id": "pkg2", + "name": "dep-pkg", + "version": "2.0.0", + "type": "npm", + "direct": False, + "manifestFiles": [], + "alerts": [], + "topLevelAncestors": ["pkg1"] + }, + { # Second transitive for same top-level + "id": "pkg3", + "name": "another-dep", + "version": "1.0.0", + "type": "npm", + "direct": False, + "manifestFiles": [], + "alerts": [], + "topLevelAncestors": ["pkg1"] + }, + { # Duplicate package (same ID as pkg1) + "id": "pkg1", + "name": "root-pkg", + "version": "1.0.0", + "type": "npm", + "direct": True, + "manifestFiles": [{"file": "package.json"}], + "alerts": [], + "topLevelAncestors": [] + }, + { # Package with no transitives + "id": "pkg4", + "name": "standalone", + "version": "1.0.0", + "type": "npm", + "direct": True, + "manifestFiles": [{"file": "package.json"}], + "alerts": [], + "topLevelAncestors": [] + } + ] + + with patch('builtins.print') as mock_print: + result = Core.create_sbom_dict(sbom_data) + + mock_print.assert_called_once_with("Duplicate package?") + assert len(result) == 4 + + # Verify transitive counting + assert result["pkg1"].transitives == 2 # Two transitive dependencies + assert result["pkg4"].transitives == 0 # No transitives, but property exists + + # Verify all packages present + assert all(pkg_id in result for pkg_id in ["pkg1", "pkg2", "pkg3", "pkg4"]) + +def test_check_alert_capabilities(): + """Test all branches of alert capability checking""" + package = Package( + id="pkg1", + name="test-pkg", + version="1.0.0", + type="npm", + direct=True, + alerts=[ + {"type": "envVars"}, + {"type": "networkAccess"}, + {"type": "unsupportedType"}, # Should be ignored + {"type": "shellAccess"} # Will be duplicate in capabilities + ] + ) + + # Test new package (no head_package) + capabilities = {} + result = Core.check_alert_capabilities(package, capabilities, "pkg1") + assert "pkg1" in result + assert set(result["pkg1"]) == {"Environment", "Network", "Shell"} + + # Test with existing capabilities + capabilities = {"pkg1": ["Shell"]} # Existing capability + result = Core.check_alert_capabilities(package, capabilities, "pkg1") + 
assert len(result["pkg1"]) == 3 + assert "Shell" in result["pkg1"] # Existing capability + assert "Environment" in result["pkg1"] # New capability + assert "Network" in result["pkg1"] # New capability + + # Test with head_package having some matching alerts + head_package = Package( + id="pkg1", + name="test-pkg", + version="1.0.0", + type="npm", + direct=True, + alerts=[ + {"type": "envVars"}, # Should be skipped as existing + {"type": "filesystemAccess"} # Different alert + ] + ) + + capabilities = {} + result = Core.check_alert_capabilities(package, capabilities, "pkg1", head_package) + assert "pkg1" in result + assert "Environment" not in result["pkg1"] # Should be skipped + assert "Network" in result["pkg1"] # Should be included + assert "File System" not in result["pkg1"] # Not in new package + +def test_add_capabilities_to_purl(): + """Test adding capabilities to PURLs in a diff""" + diff = Diff() + + # Create PURLs with and without capabilities + purl1 = Purl( + id="pkg1", + name="test-pkg", + version="1.0.0", + ecosystem="npm", + direct=True, + introduced_by=[("direct", "package.json")], + author=["Test Author"], + size=1000, + url="https://example.com", + purl="pkg:npm/test-pkg@1.0.0" + ) + purl2 = Purl( + id="pkg2", + name="other-pkg", + version="2.0.0", + ecosystem="npm", + direct=True, + introduced_by=[("direct", "package.json")], + author=["Other Author"], + size=2000, + url="https://example.com/other", + purl="pkg:npm/other-pkg@2.0.0" + ) + + diff.new_packages = [purl1, purl2] + diff.new_capabilities = { + "pkg1": ["Network", "Shell"], # Has capabilities + # pkg2 intentionally missing from capabilities + } + + result = Core.add_capabilities_to_purl(diff) + + # Verify purl1 got its capabilities + assert result.new_packages[0].capabilities == ["Network", "Shell"] + + # Verify purl2 has empty capabilities + assert result.new_packages[1].capabilities == {} + + # Verify both PURLs are still present + assert len(result.new_packages) == 2 + +def test_compare_capabilities(): + """Test comparison of capabilities between package sets""" + # Create test packages + new_pkg1 = Package( + id="pkg1", + name="test-pkg", + version="1.0.0", + type="npm", + direct=True, + alerts=[ + {"type": "envVars"}, + {"type": "networkAccess"} + ] + ) + + new_pkg2 = Package( + id="pkg2", + name="other-pkg", + version="2.0.0", + type="npm", + direct=True, + alerts=[ + {"type": "shellAccess"} + ] + ) + + head_pkg1 = Package( + id="pkg1", + name="test-pkg", + version="1.0.0", + type="npm", + direct=True, + alerts=[ + {"type": "envVars"} # Only environment vars in head + ] + ) + + # Test cases: + # 1. Package exists in head with some matching alerts + # 2. Package exists in head with no matching alerts + # 3. 
Package doesn't exist in head + + new_packages = { + "pkg1": new_pkg1, + "pkg2": new_pkg2 } - purl, package = Core.create_purl("test-pkg@1.0.0", packages) - assert isinstance(purl, Purl) - assert purl.name == "test-pkg" - assert purl.version == "1.0.0" - -# API interaction tests -@patch('socketsecurity.core.do_request') -def test_get_org_id_slug(mock_do_request): - mock_response = Mock() - mock_response.json.return_value = mock_org_response - mock_do_request.return_value = mock_response - - org_id, org_slug = Core.get_org_id_slug() - assert org_id == "test-org-123" - assert org_slug == "test-org" + head_packages = { + "pkg1": head_pkg1 + } + + result = Core.compare_capabilities(new_packages, head_packages) + + # Verify pkg1 only shows new capability + assert "Environment" not in result["pkg1"] # Exists in head + assert "Network" in result["pkg1"] # New in current + + # Verify pkg2 shows all capabilities (not in head) + assert "Shell" in result["pkg2"] + + # Verify no unexpected capabilities + assert len(result["pkg1"]) == 1 + assert len(result["pkg2"]) == 1 + +def test_get_source_data(): + """Test source data generation for direct and transitive packages""" + # Test Case 1: Direct package with single manifest + direct_pkg = Package( + id="pkg1", + name="test-package", + version="1.0.0", + type="npm", + direct=True, + manifestFiles=[{"file": "package.json"}], + alerts=[] + ) + packages = {"pkg1": direct_pkg} + result = Core.get_source_data(direct_pkg, packages) + assert result == [("direct", "package.json")] + + # Test Case 2: Direct package with multiple manifests + direct_multi = Package( + id="pkg2", + name="test-package", + version="1.0.0", + type="npm", + direct=True, + manifestFiles=[ + {"file": "package.json"}, + {"file": "package-lock.json"} + ], + alerts=[] + ) + packages["pkg2"] = direct_multi + result = Core.get_source_data(direct_multi, packages) + assert result == [("direct", "package.json;package-lock.json")] + + # Test Case 3: Transitive package with single top-level ancestor + top_pkg = Package( + id="top1", + name="top-package", + version="1.0.0", + type="npm", + direct=True, + manifestFiles=[{"file": "package.json"}], + alerts=[] + ) + transitive = Package( + id="trans1", + name="trans-package", + version="2.0.0", + type="npm", + direct=False, + topLevelAncestors=["top1"], + manifestFiles=[], + alerts=[] + ) + packages.update({ + "top1": top_pkg, + "trans1": transitive + }) + result = Core.get_source_data(transitive, packages) + assert result == [("npm/top-package@1.0.0", "package.json")] + + # Test Case 4: Transitive package with multiple top-level ancestors + top_pkg2 = Package( + id="top2", + name="top-package-2", + version="3.0.0", + type="npm", + direct=True, + manifestFiles=[{"file": "other-package.json"}], + alerts=[] + ) + transitive_multi = Package( + id="trans2", + name="trans-package-2", + version="4.0.0", + type="npm", + direct=False, + topLevelAncestors=["top1", "top2"], + manifestFiles=[], + alerts=[] + ) + packages.update({ + "top2": top_pkg2, + "trans2": transitive_multi + }) + result = Core.get_source_data(transitive_multi, packages) + assert result == [ + ("npm/top-package@1.0.0", "package.json"), + ("npm/top-package-2@3.0.0", "other-package.json") + ] + +def test_create_purl(): + """Test PURL creation with all required fields""" + # Test Case 1: Direct package with provided PURL + direct_pkg = Package( + id="pkg1", + name="test-package", + version="1.0.0", + type="npm", + direct=True, + manifestFiles=[{"file": "package.json"}], + alerts=[], + 
author=["Test Author"], + size=1000, + transitives=0, + url="https://socket.dev/npm/package/test-package/overview/1.0.0", + purl="pkg:npm/test-package@1.0.0" # Explicitly provided PURL + ) + packages = {"pkg1": direct_pkg} + purl, package = Core.create_purl("pkg1", packages) + + # Verify PURL format is preserved + assert purl.purl == "pkg:npm/test-package@1.0.0" + + # Test Case 2: Package without provided PURL (should be generated) + no_purl_pkg = Package( + id="pkg2", + name="auto-package", + version="2.0.0", + type="npm", + direct=True, + manifestFiles=[{"file": "package.json"}], + alerts=[], + author=["Test Author"], + size=1000, + transitives=0, + url="https://socket.dev/npm/package/auto-package/overview/2.0.0" + # No purl provided - should be auto-generated + ) + packages["pkg2"] = no_purl_pkg + purl, package = Core.create_purl("pkg2", packages) + + # Verify auto-generated PURL has correct format + assert purl.purl == "pkg:npm/auto-package@2.0.0" + + # Rest of existing test cases... + +def test_find_files(): + """Test file discovery with glob patterns""" + time_calls = [] # Track time.time() calls + + def mock_time(): + val = len(time_calls) + time_calls.append(val) + return val + + with patch('socketsecurity.core.glob') as mock_glob, \ + patch('socketsecurity.core.log') as mock_log, \ + patch('time.time', side_effect=mock_time): + + # Mock glob to return different files for different patterns + def mock_glob_side_effect(pattern, recursive=True): + if "package.json" in pattern: + return [ + "/path/to/package.json", + "/path/to/nested/package.json", + "C:/path/with/windows/style/package.json" # This is actually a unique path + ] + elif "requirements.txt" in pattern: + return [ + "/path/to/requirements.txt", + "/path/to/requirements.txt" # Duplicate that will be removed + ] + elif "go.mod" in pattern: + return ["/path/to/go.mod"] + return [] + + mock_glob.side_effect = mock_glob_side_effect + + # Test file discovery + result = Core.find_files("/path/to") + + # Print debug info for verification + print(f"Total time.time() calls: {len(time_calls)}") + print("Found files:", result) + + # Verify results contain all unique files (5 total) + assert len(result) == 5 # Updated to expect 5 unique files + assert "/path/to/package.json" in result + assert "/path/to/nested/package.json" in result + assert "/path/to/requirements.txt" in result + assert "/path/to/go.mod" in result + assert "C:/path/with/windows/style/package.json" in result # Windows path is unique + + # Verify glob was called with correct patterns + glob_patterns = [call[0][0] for call in mock_glob.call_args_list] + assert any("package.json" in pattern for pattern in glob_patterns) + assert any("requirements.txt" in pattern for pattern in glob_patterns) + assert any("go.mod" in pattern for pattern in glob_patterns) + + # Verify logging with actual number of time calls + mock_log.debug.assert_any_call("Starting Find Files") + mock_log.debug.assert_any_call("Finished Find Files") + final_time = len(time_calls) - 1 + mock_log.info.assert_called_with(f"Found 5 in {final_time:.2f} seconds") + +def test_create_purl_edge_cases(): + """Test PURL creation with edge cases""" + # Test with missing optional fields + minimal_pkg = Package( + id="pkg1", + name="test-pkg", + version="1.0.0", + type="npm", + direct=True, + manifestFiles=[{"file": "package.json"}], + alerts=[] + # Missing: author, size, transitives, url, purl + ) + packages = {"pkg1": minimal_pkg} + purl, package = Core.create_purl("pkg1", packages) + + # Verify defaults for missing 
fields + assert purl.author == [] # Should default to empty list + assert purl.size == 0 # Should default to 0 + assert purl.transitives == 0 # Should default to 0 + assert purl.url == "https://socket.dev/npm/package/test-pkg/overview/1.0.0" # URL is auto-generated for all packages + assert purl.purl == "pkg:npm/test-pkg@1.0.0" # Should generate purl + + # Test with different package type + pip_pkg = Package( + id="pkg2", + name="test-pkg", + version="1.0.0", + type="pip", # Different package type + direct=True, + manifestFiles=[{"file": "requirements.txt"}], + alerts=[] + ) + packages = {"pkg2": pip_pkg} + purl, package = Core.create_purl("pkg2", packages) + assert purl.url == "https://socket.dev/pip/package/test-pkg/overview/1.0.0" # URL is auto-generated with pip type + assert purl.purl == "pkg:pip/test-pkg@1.0.0" # Should generate purl with pip type + +def test_get_license_details(): + """Test license details handling with mocked Licenses""" + with patch('socketsecurity.core.Licenses') as MockLicenses: + # Setup mock license object with licenseText property + mock_license = type('MockLicense', (), {'licenseText': 'Mock License Text'})() + + # Configure the mock Licenses class + mock_licenses_instance = MagicMock() + mock_licenses_instance.MIT = mock_license + MockLicenses.return_value = mock_licenses_instance + + # Test package with valid license + package = Package( + id="pkg1", + name="test-pkg", + version="1.0.0", + type="npm", + direct=True, + license="MIT", + manifestFiles=[], + alerts=[] + ) + + # First test: valid license + MockLicenses.make_python_safe = MagicMock(return_value="MIT") + result = Core.get_license_details(package) + assert result.license_text == "Mock License Text" + + # Second test: unknown license + package = Package( # Create fresh package without license_text + id="pkg1", + name="test-pkg", + version="1.0.0", + type="npm", + direct=True, + license="Unknown-License", + manifestFiles=[], + alerts=[] + ) + MockLicenses.make_python_safe = MagicMock(return_value=None) + mock_licenses_instance = MagicMock(spec=[]) # Empty spec means no attributes + MockLicenses.return_value = mock_licenses_instance + + result = Core.get_license_details(package) + assert result.license_text == "" # Check for empty string instead of missing attribute + +def test_compare_issue_alerts(): + """Test comparison of issue alerts between scans, covering all branches""" + + def create_issue(key, error=False, warn=False, purl="pkg:npm/test@1.0.0", type="security"): + return Issue( + key=key, + type=type, + severity="high", + description="Test desc", + title="Test title", + purl=purl, + manifests="package.json", + error=error, + warn=warn + ) + + # Branch 1: alert_key not in head_scan_alerts + issue_error = create_issue("key1", error=True, purl="pkg:npm/test1@1.0.0") + issue_warn = create_issue("key2", warn=True, purl="pkg:npm/test2@1.0.0") + issue_no_alert = create_issue("key3", purl="pkg:npm/test3@1.0.0") # Neither error nor warn + + new_alerts = { + "key1": [issue_error], + "key2": [issue_warn], + "key3": [issue_no_alert] + } + head_alerts = {} + result = Core.compare_issue_alerts(new_alerts, head_alerts, []) + assert len(result) == 2 # Only error and warn issues should be included + assert {i.key for i in result} == {"key1", "key2"} + assert {i.purl for i in result} == {"pkg:npm/test1@1.0.0", "pkg:npm/test2@1.0.0"} + + # Branch 1a: Duplicate consolidated alerts (same purl/manifests/type) + duplicate_issue = create_issue( + "key4", + error=True, + purl=issue_error.purl, + 
type=issue_error.type + ) + new_alerts = { + "key1": [issue_error], + "key4": [duplicate_issue] + } + result = Core.compare_issue_alerts(new_alerts, head_alerts, []) + assert len(result) == 1 # Duplicate should be consolidated + assert result[0].purl == issue_error.purl + + # Branch 2: alert_key exists in head_scan_alerts but with different purl + new_issue = create_issue("key5", error=True, purl="pkg:npm/new@1.0.0") + head_issue = create_issue("key5", error=True, purl="pkg:npm/old@1.0.0") + + new_alerts = {"key5": [new_issue]} + head_alerts = {"key5": [head_issue]} + result = Core.compare_issue_alerts(new_alerts, head_alerts, []) + assert len(result) == 1 # Different purl should be treated as new alert + assert result[0].purl == "pkg:npm/new@1.0.0" + + # Branch 2a & 2b: Multiple alerts with mixed conditions + new_alerts = { + "key6": [ + create_issue("key6", error=True, purl="pkg:npm/test1@1.0.0"), # New error + create_issue("key6", warn=True, purl="pkg:npm/test2@1.0.0"), # New warning + create_issue("key6", purl="pkg:npm/test3@1.0.0"), # No error/warn + create_issue("key6", error=True, purl="pkg:npm/test4@1.0.0") # Will be in head + ] + } + head_alerts = { + "key6": [ + create_issue("key6", error=True, purl="pkg:npm/test4@1.0.0") # Existing in head + ] + } + result = Core.compare_issue_alerts(new_alerts, head_alerts, []) + assert len(result) == 2 # Should only include new error and warning alerts + assert {i.purl for i in result} == {"pkg:npm/test1@1.0.0", "pkg:npm/test2@1.0.0"} diff --git a/tests/unit/test_core_instance.py b/tests/unit/test_core_instance.py new file mode 100644 index 0000000..302865d --- /dev/null +++ b/tests/unit/test_core_instance.py @@ -0,0 +1,906 @@ +import pytest +from unittest.mock import MagicMock +from socketsecurity.core import Core +from socketsecurity.core.config import SocketConfig +from socketsecurity.core.client import CliClient +from socketsecurity.core.classes import Package, Alert, FullScan, FullScanParams + +from socketsecurity.core.exceptions import APIResourceNotFound +from unittest import mock +import json + +@pytest.fixture +def mock_config(): + """Fixture for a mocked SocketConfig""" + config = SocketConfig(api_key="test-key") + config.org_slug = "test-org" + config.org_id = "test-id" + config.repository_path = "orgs/test-org/repos" + # Add mock issues for capabilities + class MockIssueProps: + description = "Test description" + title = "Test title" + suggestion = "Test suggestion" + nextStepTitle = "Test next step" + + config.all_issues.envVars = MockIssueProps() + config.all_issues.networkAccess = MockIssueProps() + config.security_policy = { + 'envVars': {'action': 'warn'}, + 'networkAccess': {'action': 'error'} + } + return config + +@pytest.fixture +def mock_client(): + """Fixture for a mocked CliClient""" + client = MagicMock(spec=CliClient) + client.request.return_value = MagicMock() # Ensure request always returns a MagicMock + return client + +@pytest.fixture +def core_instance(mock_config, mock_client): + """Fixture for a Core instance with mocked dependencies""" + # Prevent set_org_vars from running in __init__ + with mock.patch('socketsecurity.core.Core.set_org_vars'): + instance = Core(mock_config, mock_client) + return instance + +def test_create_issue_alerts(core_instance): + """Test creation of issue alerts with different scenarios""" + # Create mock issue properties + class MockVulnProps: + description = "Known vulnerability found" + title = "Vulnerability Alert" + suggestion = "Update package" + nextStepTitle = "Fix Now" + + # 
Set up the mock issues + core_instance.config.all_issues.knownVulnerability = MockVulnProps() + core_instance.config.security_policy = {'knownVulnerability': {'action': 'error'}} + + # Create test package with Alert objects instead of dicts + package = Package( + id="test-pkg", + name="test-package", + version="1.0.0", + type="npm", + direct=True, + manifestFiles=[{"file": "package.json"}], + alerts=[ + Alert( + type="knownVulnerability", + severity="high", + key="vuln-1", + props={"details": "CVE-2023-1234"} + ), + Alert( + type="licenseSpdxDisj", + severity="low", + key="license-1", + props={} + ) + ], + purl="pkg:npm/test-package@1.0.0", + url="https://example.com" + ) + + packages = {"test-pkg": package} + alerts = {} + + # Test alert creation + result = core_instance.create_issue_alerts(package, alerts, packages) + + # Verify results + assert len(result) == 1 + assert "vuln-1" in result + + created_alert = result["vuln-1"][0] + assert created_alert.type == "knownVulnerability" + assert created_alert.severity == "high" + assert created_alert.description == "Known vulnerability found" + assert created_alert.title == "Vulnerability Alert" + assert created_alert.suggestion == "Update package" + assert created_alert.next_step_title == "Fix Now" + assert created_alert.error is True + assert created_alert.purl == "pkg:npm/test-package@1.0.0" + +def test_get_org_id_slug_single_org(core_instance, mock_client): + """Test getting org ID and slug when there is a single organization""" + # Setup mock response + mock_response = MagicMock() + mock_response.json.return_value = { + "organizations": { + "org123": {"slug": "test-org"} + } + } + mock_client.request.return_value = mock_response + + # Test the method + org_id, org_slug = core_instance.get_org_id_slug() + + # Verify results + assert org_id == "org123" + assert org_slug == "test-org" + mock_client.request.assert_called_once_with("organizations") + +def test_get_org_id_slug_no_orgs(core_instance, mock_client): + """Test getting org ID and slug when there are no organizations""" + # Setup mock response + mock_response = MagicMock() + mock_response.json.return_value = {"organizations": {}} + mock_client.request.return_value = mock_response + + # Test the method + org_id, org_slug = core_instance.get_org_id_slug() + + # Verify results + assert org_id is None + assert org_slug is None + mock_client.request.assert_called_once_with("organizations") + +def test_set_org_vars(core_instance, mock_client): + """Test setting organization variables""" + # Reset mock before test + mock_client.reset_mock() + + # Setup mock responses + org_response = MagicMock() + org_response.json.return_value = { + "organizations": { + "org123": {"slug": "test-org"} + } + } + + security_response = MagicMock() + security_response.json.return_value = { + "defaults": { + "issueRules": { + "rule1": {"action": "warn"} + } + }, + "entries": [ + { + "settings": { + "organization": { + "issueRules": { + "rule2": {"action": "error"} + } + } + } + } + ] + } + + # Setup mock client to return different responses for different calls + def mock_request(path, **kwargs): + if path == "organizations": + return org_response + elif path == "settings": + return security_response + raise ValueError(f"Unexpected path: {path}") + + mock_client.request.side_effect = mock_request + + # Test the method + core_instance.set_org_vars() + + # Verify results + assert core_instance.config.org_id == "org123" # From organizations API response + assert core_instance.config.org_slug == "test-org" + + 
expected_base_path = "orgs/test-org" + assert core_instance.config.full_scan_path == f"{expected_base_path}/full-scans" + assert core_instance.config.repository_path == f"{expected_base_path}/repos" + + assert core_instance.config.security_policy == { + "rule1": {"action": "warn"}, + "rule2": {"action": "error"} + } + + # Verify API calls + expected_calls = [ + mock.call("organizations"), + mock.call( + path="settings", + method="POST", + payload=json.dumps([{"organization": "org123"}]) # Using org_id from organizations API response + ) + ] + mock_client.request.assert_has_calls(expected_calls, any_order=False) + +def test_get_security_policy(core_instance, mock_client): + """Test getting security policy with different scenarios""" + # Reset mock before test + mock_client.reset_mock() + + # Setup mock response + mock_response = MagicMock() + mock_response.json.return_value = { + "defaults": { + "issueRules": { + "rule1": {"action": "warn"}, + "rule3": {"action": "ignore"} + } + }, + "entries": [ + { + "settings": { + "organization": { + "issueRules": { + "rule1": {"action": "error"}, # Override default + "rule2": {"action": "warn"} # New rule + } + } + } + } + ] + } + mock_client.request.return_value = mock_response + + # Test the method + result = core_instance.get_security_policy() + + # Verify results + expected_policy = { + "rule1": {"action": "error"}, # Org rule overrides default + "rule2": {"action": "warn"}, # Org-specific rule + "rule3": {"action": "ignore"} # Default rule (no override) + } + assert result == expected_policy + + # Verify API call + mock_client.request.assert_called_once_with( + path="settings", + method="POST", + payload=json.dumps([{"organization": core_instance.config.org_id}]) + ) + +def test_get_sbom_data(core_instance, mock_client): + """Test getting SBOM data with different response scenarios""" + # Reset mock before test + mock_client.reset_mock() + + # Test case 1: Happy path with multiple packages + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.text = ( + '{"type":"pypi","name":"click","version":"8.1.7","id":"12453","topLevelAncestors":["6381179126"]}\n' + '{"type":"pypi","name":"chardet","version":"5.2.0","id":"25259","topLevelAncestors":["6381179126"]}\n' + '\n' # Empty line should be skipped + '"' # Quote should be skipped + ) + mock_client.request.return_value = mock_response + + result = core_instance.get_sbom_data("test-scan-id") + + assert len(result) == 2 + assert result[0]["name"] == "click" + assert result[1]["name"] == "chardet" + mock_client.request.assert_called_once_with("orgs/test-org/full-scans/test-scan-id") + + # Test case 2: Failed API response + mock_client.reset_mock() + mock_response.status_code = 404 + mock_response.text = "Not found" + + result = core_instance.get_sbom_data("bad-scan-id") + + assert result == [] + mock_client.request.assert_called_once() + + # Test case 3: Empty response + mock_client.reset_mock() + mock_response.status_code = 200 + mock_response.text = "" + + result = core_instance.get_sbom_data("empty-scan-id") + + assert result == [] + mock_client.request.assert_called_once() + +def test_create_sbom_output(core_instance, mock_client): + """Test creating SBOM output with different scenarios""" + # Reset mock before test + mock_client.reset_mock() + + # Create mock diff object + class MockDiff: + id = "test-diff-id" + + diff = MockDiff() + + # Test case 1: Successful JSON response + mock_response = MagicMock() + mock_response.json.return_value = { + "bomFormat": "CycloneDX", + 
"specVersion": "1.4", + "components": [ + { + "type": "library", + "name": "test-package", + "version": "1.0.0" + } + ] + } + mock_client.request.return_value = mock_response + + result = core_instance.create_sbom_output(diff) + + assert result == mock_response.json.return_value + mock_client.request.assert_called_once_with( + path="orgs/test-org/export/cdx/test-diff-id" + ) + + # Test case 2: JSON parsing error + mock_client.reset_mock() + mock_response.json.side_effect = json.JSONDecodeError("Invalid JSON", "", 0) + + result = core_instance.create_sbom_output(diff) + + assert result == {} # Should return empty dict on error + mock_client.request.assert_called_once_with( + path="orgs/test-org/export/cdx/test-diff-id" + ) + +def test_create_full_scan(core_instance, mock_client, tmp_path): + """Test creating full scan with different scenarios""" + # Reset mock before test + mock_client.reset_mock() + + # Set up the full scan path in config + core_instance.config.full_scan_path = "orgs/test-org/full-scans" + + # Create test files and store their full paths + workspace = str(tmp_path) + test_files = [] + file_contents = { + "package.json": "{}", + "nested/package.json": "{}", + "requirements.txt": "requests==2.0.0" + } + + for rel_path, content in file_contents.items(): + full_path = tmp_path / rel_path + full_path.parent.mkdir(parents=True, exist_ok=True) + full_path.write_text(content) + test_files.append(str(full_path)) + + # Create test params + class MockParams: + repo = "test-repo" + branch = "main" + commit = "abc123" + message = "test commit" + pr = None + + params = MockParams() + + # Test case 1: Successful scan with multiple files + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "id": "scan123", + "created_at": "2024-01-01T00:00:00Z", + "updated_at": "2024-01-01T00:00:00Z", + "organization_id": "org123", + "repository_id": "repo123", + "branch": "main", + "commit_message": "test commit", + "commit_hash": "abc123", + "pull_request": None + } + + sbom_response = MagicMock() + sbom_response.status_code = 200 + sbom_response.text = ( + '{"type":"npm","name":"test-pkg","version":"1.0.0","id":"12453","license":"MIT","direct":true,"manifestFiles":[{"file":"package.json"}]}\n' + '{"type":"pypi","name":"requests","version":"2.0.0","id":"25259","license":"Apache-2.0","direct":true,"manifestFiles":[{"file":"requirements.txt"}]}\n' + ) + + def mock_request(path, **kwargs): + if "orgs/test-org/full-scans?" in path: # Creating new scan + return mock_response + elif "orgs/test-org/full-scans/scan123" in path: # Getting SBOM data + return sbom_response + raise ValueError(f"Unexpected path: {path}") + + mock_client.request.side_effect = mock_request + + # Test the method + result = core_instance.create_full_scan(test_files, params, workspace) + + # Verify results + assert isinstance(result, FullScan) + assert result.id == "scan123" + assert result.branch == "main" + assert result.commit_hash == "abc123" + assert len(result.sbom_artifacts) == 2 + assert result.sbom_artifacts[0]["name"] == "test-pkg" + assert result.sbom_artifacts[1]["name"] == "requests" + + # Verify API calls + assert mock_client.request.call_count == 2 + create_call = mock_client.request.call_args_list[0] + + # Check the first positional argument (path) + assert "orgs/test-org/full-scans?" 
in create_call[0][0] + assert create_call[1]["method"] == "POST" + assert len(create_call[1]["files"]) == 3 # All test files included + + # Verify file paths in request are relative to workspace + files_in_request = [f[0] for f in create_call[1]["files"]] + assert "package.json" in files_in_request + assert "nested/package.json" in files_in_request + assert "requirements.txt" in files_in_request + + # Test case 2: Empty file list + mock_client.reset_mock() + def mock_request_empty_files(path, **kwargs): + if "orgs/test-org/full-scans?" in path: # Creating new scan + response = MagicMock() + response.status_code = 400 + response.json.return_value = { + "id": None, + "error": "No files provided" + } + return response + elif "orgs/test-org/full-scans/None" in path: # Getting SBOM data for failed scan + response = MagicMock() + response.status_code = 404 + response.text = "" + return response + raise ValueError(f"Unexpected path: {path}") + + mock_client.request.side_effect = mock_request_empty_files + + result = core_instance.create_full_scan([], params, workspace) + assert result.id is None + assert result.sbom_artifacts == [] # Empty list for failed scan + + # Test case 3: Failed API response + mock_client.reset_mock() + def mock_request_failed_api(path, **kwargs): + if "orgs/test-org/full-scans?" in path: # Creating new scan + response = MagicMock() + response.status_code = 500 + response.json.return_value = { + "id": None, + "error": "Server error" + } + return response + elif "orgs/test-org/full-scans/None" in path: # Getting SBOM data for failed scan + response = MagicMock() + response.status_code = 404 + response.text = "" + return response + raise ValueError(f"Unexpected path: {path}") + + mock_client.request.side_effect = mock_request_failed_api + + result = core_instance.create_full_scan(test_files, params, workspace) + assert result.id is None + assert result.sbom_artifacts == [] # Empty list for failed scan + +def test_get_head_scan_for_repo(core_instance, mock_client): + """Test getting head scan ID for a repository with different scenarios""" + # Reset mock before test + mock_client.reset_mock() + + # Set up the repository path in config + core_instance.config.repository_path = "orgs/test-org/repos" + + # Test case 1: Repository exists with head scan + mock_response = MagicMock() + mock_response.json.return_value = { + "repository": { + "id": "repo123", + "created_at": "2024-01-01T00:00:00Z", + "updated_at": "2024-01-01T00:00:00Z", + "head_full_scan_id": "scan123", + "name": "test-repo", + "description": "Test repository", + "homepage": "https://example.com", + "visibility": "public", + "archived": False, + "default_branch": "main" + } + } + mock_client.request.return_value = mock_response + + result = core_instance.get_head_scan_for_repo("test-repo") + + assert result == "scan123" + mock_client.request.assert_called_once_with("orgs/test-org/repos/test-repo") + + # Test case 2: Repository exists but no head scan + mock_client.reset_mock() + mock_response.json.return_value = { + "repository": { + "id": "repo123", + "head_full_scan_id": None, + # ... 
other fields omitted for brevity + } + } + + result = core_instance.get_head_scan_for_repo("test-repo") + + assert result is None + mock_client.request.assert_called_once() + + # Test case 3: Repository not found + mock_client.reset_mock() + mock_response.status_code = 404 + mock_response.json.side_effect = APIResourceNotFound("Repository not found") + mock_client.request.side_effect = APIResourceNotFound("Repository not found") + + with pytest.raises(APIResourceNotFound) as exc_info: + core_instance.get_head_scan_for_repo("nonexistent-repo") + + assert "Repository not found" in str(exc_info.value) + mock_client.request.assert_called_once() + +def test_get_full_scan(core_instance, mock_client): + """Test getting full scan data with different scenarios""" + mock_client.reset_mock() + core_instance.config.full_scan_path = "orgs/test-org/full-scans" + + # Test case 1: Successful scan with SBOM data + mock_scan_response = MagicMock() + mock_scan_response.status_code = 200 + mock_scan_response.json.return_value = { + "id": "scan123", + "created_at": "2024-01-01T00:00:00Z", + "updated_at": "2024-01-01T00:00:00Z", + "organization_id": "org123", + "repository_id": "repo123", + "branch": "main", + "commit_message": "test commit", + "commit_hash": "abc123", + "pull_request": None + } + + # Second request should return SBOM data + mock_sbom_response = MagicMock() + mock_sbom_response.status_code = 200 + mock_sbom_response.text = ( + '{"type":"npm","name":"test-pkg","version":"1.0.0","id":"12453","license":"MIT","direct":true,"manifestFiles":[{"file":"package.json"}]}\n' + '{"type":"pypi","name":"requests","version":"2.0.0","id":"25259","license":"Apache-2.0","direct":true,"manifestFiles":[{"file":"requirements.txt"}]}\n' + ) + + call_count = 0 + def mock_request(path, **kwargs): + nonlocal call_count + print(f"Mock request called with path: {path}, kwargs: {kwargs}") + if path == f"{core_instance.config.full_scan_path}/scan123": + call_count += 1 + if call_count == 1: # First call returns scan data + print("Returning scan response") + return mock_scan_response + else: # Second call returns SBOM data + print("Returning SBOM response") + return mock_sbom_response + raise ValueError(f"Unexpected path: {path}") + + mock_client.request.side_effect = mock_request + + result = core_instance.get_full_scan("scan123") + print(f"Result SBOM artifacts: {result.sbom_artifacts}") + + assert isinstance(result, FullScan) + assert result.id == "scan123" + assert result.branch == "main" + assert result.commit_hash == "abc123" + assert len(result.sbom_artifacts) == 2 + assert result.sbom_artifacts[0]["name"] == "test-pkg" + assert result.sbom_artifacts[1]["name"] == "requests" + + # Test case 2: Scan not found + mock_client.reset_mock() + def mock_request_not_found(path, **kwargs): + raise APIResourceNotFound("Scan not found") + + mock_client.request.side_effect = mock_request_not_found + + with pytest.raises(APIResourceNotFound) as exc_info: + core_instance.get_full_scan("nonexistent-scan") + + assert "Scan not found" in str(exc_info.value) + mock_client.request.assert_called_once() + + # Test case 3: Invalid response format + mock_client.reset_mock() + mock_scan_invalid = MagicMock() + mock_scan_invalid.status_code = 200 + mock_scan_invalid.json.return_value = { + "id": "scan123", + # Missing required fields + } + + def mock_request_invalid(path, **kwargs): + return mock_scan_invalid + + mock_client.request.side_effect = mock_request_invalid + + result = core_instance.get_full_scan("scan123") + + # Should still 
create a FullScan object with default values + assert isinstance(result, FullScan) + assert result.id == "scan123" + assert not hasattr(result, "branch") + assert not hasattr(result, "commit_hash") + assert result.sbom_artifacts == [] # Default empty list + +def test_compare_sboms(core_instance, mock_client): + """Test SBOM comparison with different scenarios""" + print("Setting up test data...") + + # Add debug logging to see what's happening with alerts + def create_package_dict(pkg_id, name, version, alerts=None): + return { + "id": pkg_id, + "name": name, + "version": version, + "type": "npm", + "direct": True, + "license": "MIT", + "manifestFiles": [{"file": "package.json"}], + "alerts": [Alert(**alert) for alert in (alerts or [])], # Convert to Alert objects here + "author": ["Test Author"], + "size": 1000, + "url": f"https://example.com/{name}", + "purl": f"pkg:npm/{name}@{version}", + "topLevelAncestors": [] + } + + # Test case 1: New packages and alerts + new_scan_data = [ + create_package_dict("pkg1", "test-pkg", "1.0.0", alerts=[ + { + "type": "envVars", + "severity": "high", + "key": "env-1", + "category": "capability", + "props": {} + } + ]), + create_package_dict("pkg2", "new-pkg", "2.0.0", alerts=[ + { + "type": "networkAccess", + "severity": "high", + "key": "net-1", + "category": "capability", + "props": {} + } + ]) + ] + + head_scan_data = [ + create_package_dict("pkg1", "test-pkg", "1.0.0"), # No alerts + create_package_dict("pkg3", "removed-pkg", "3.0.0") # No alerts + ] + + print("Comparing SBOMs...") + result = core_instance.compare_sboms(new_scan_data, head_scan_data) + print(f"New packages: {result.new_packages}") + print(f"Removed packages: {result.removed_packages}") + print(f"New alerts: {result.new_alerts}") + print(f"New capabilities: {result.new_capabilities}") + + # Verify new package was added + assert len(result.new_packages) == 1 + assert result.new_packages[0].name == "new-pkg" + + # Verify package was removed + assert len(result.removed_packages) == 1 + assert result.removed_packages[0].name == "removed-pkg" + + # Verify new alerts were detected + assert len(result.new_alerts) > 0 + + # Verify new capabilities were detected + assert len(result.new_capabilities) > 0 + assert "Environment" in result.new_capabilities["pkg1"] + assert "Network" in result.new_capabilities["pkg2"] + + # Additional assertions for alert handling + assert any(alert.type == "envVars" for alert in result.new_alerts), "Should detect new envVars alert" + assert any(alert.type == "networkAccess" for alert in result.new_alerts), "Should detect new networkAccess alert" + + # Verify alerts are properly created with all required fields + for alert in result.new_alerts: + assert hasattr(alert, 'key'), "Alert should have a key" + assert hasattr(alert, 'severity'), "Alert should have severity" + assert hasattr(alert, 'description'), "Alert should have description" + assert hasattr(alert, 'title'), "Alert should have title" + +def test_create_new_diff_no_change(core_instance, mock_client): + """Test create_new_diff when no_change is True""" + params = FullScanParams( + repo="test-repo", + branch="main", + commit_message="test commit", + commit_hash="abc123", + pull_request=1, + committer="test@example.com", + make_default_branch=False, + set_as_pending_head=False + ) + + result = core_instance.create_new_diff("test/path", params, "workspace", no_change=True) + assert result.id == "no_diff_id" + +def test_create_new_diff_no_files(core_instance, mock_client): + """Test create_new_diff when no 
files are found""" + params = FullScanParams( + repo="test-repo", + branch="main", + commit_message="test commit", + commit_hash="abc123", + pull_request=1, + committer="test@example.com", + make_default_branch=False, + set_as_pending_head=False + ) + + with mock.patch('socketsecurity.core.Core.find_files', return_value=[]): + result = core_instance.create_new_diff("test/path", params, "workspace") + assert result.id == "no_diff_id" + +def test_create_new_diff_new_repository(core_instance, mock_client): + """Test create_new_diff for a new repository (no head scan)""" + params = FullScanParams( + repo="test-repo", + branch="main", + commit_message="test commit", + commit_hash="abc123", + pull_request=1, + committer="test@example.com", + make_default_branch=False, + set_as_pending_head=False + ) + + test_package = { + "id": "test-pkg-1", + "name": "test-pkg", + "version": "1.0.0", + "type": "npm", + "direct": True, + "license": "MIT", + "manifestFiles": [{"file": "package.json"}], + "alerts": [], + "author": ["Test Author"], + "size": 1000, + "url": "https://example.com/test-pkg", + "purl": "pkg:npm/test-pkg@1.0.0", + "topLevelAncestors": [] + } + + # Mock the repository request to return 404 + mock_response = MagicMock() + mock_response.json.return_value = {"error": "Not Found"} + mock_response.status_code = 404 + mock_client.request.side_effect = [APIResourceNotFound("Repository not found")] + + with mock.patch('socketsecurity.core.Core.find_files', return_value=["package.json"]): + with mock.patch('socketsecurity.core.Core.create_full_scan') as mock_create_scan: + mock_create_scan.return_value = FullScan( + id="new-scan-id", + sbom_artifacts=[test_package] + ) + result = core_instance.create_new_diff("test/path", params, "workspace") + assert result.id == "new-scan-id" + assert result.diff_url == result.report_url + assert result.report_url == f"https://socket.dev/dashboard/org/{core_instance.config.org_slug}/sbom/new-scan-id" + +def test_create_new_diff_existing_head_scan(core_instance, mock_client): + """Test create_new_diff with an existing head scan""" + params = FullScanParams( + repo="test-repo", + branch="main", + commit_message="test commit", + commit_hash="abc123", + pull_request=1, + committer="test@example.com", + make_default_branch=False, + set_as_pending_head=False + ) + + test_package = { + "id": "test-pkg-1", + "name": "test-pkg", + "version": "1.0.0", + "type": "npm", + "direct": True, + "license": "MIT", + "manifestFiles": [{"file": "package.json"}], + "alerts": [], + "author": ["Test Author"], + "size": 1000, + "url": "https://example.com/test-pkg", + "purl": "pkg:npm/test-pkg@1.0.0", + "topLevelAncestors": [] + } + + mock_response = MagicMock() + mock_response.json.return_value = { + "repository": { + "id": "repo-123", + "created_at": "2024-01-01T00:00:00Z", + "updated_at": "2024-01-02T00:00:00Z", + "head_full_scan_id": "head-scan-id", + "name": "test-repo", + "description": "Test repository", + "homepage": "https://example.com", + "visibility": "public", + "archived": False, + "default_branch": "main" + } + } + mock_client.request.return_value = mock_response + + with mock.patch('socketsecurity.core.Core.find_files', return_value=["package.json"]): + with mock.patch('socketsecurity.core.Core.create_full_scan') as mock_create_scan: + mock_create_scan.return_value = FullScan( + id="new-scan-id", + sbom_artifacts=[test_package] + ) + result = core_instance.create_new_diff("test/path", params, "workspace") + assert result.id == "new-scan-id" + assert "head-scan-id" in 
result.diff_url + +def test_create_new_diff_empty_head_scan(core_instance, mock_client): + """Test create_new_diff with an empty head scan""" + params = FullScanParams( + repo="test-repo", + branch="main", + commit_message="test commit", + commit_hash="abc123", + pull_request=1, + committer="test@example.com", + make_default_branch=False, + set_as_pending_head=False + ) + + test_package = { + "id": "test-pkg-1", + "name": "test-pkg", + "version": "1.0.0", + "type": "npm", + "direct": True, + "license": "MIT", + "manifestFiles": [{"file": "package.json"}], + "alerts": [], + "author": ["Test Author"], + "size": 1000, + "url": "https://example.com/test-pkg", + "purl": "pkg:npm/test-pkg@1.0.0", + "topLevelAncestors": [] + } + + mock_response = MagicMock() + mock_response.json.return_value = { + "repository": { + "id": "repo-123", + "created_at": "2024-01-01T00:00:00Z", + "updated_at": "2024-01-02T00:00:00Z", + "head_full_scan_id": "", # Empty head scan ID + "name": "test-repo", + "description": "Test repository", + "homepage": "https://example.com", + "visibility": "public", + "archived": False, + "default_branch": "main" + } + } + mock_client.request.return_value = mock_response + + with mock.patch('socketsecurity.core.Core.find_files', return_value=["package.json"]): + with mock.patch('socketsecurity.core.Core.create_full_scan') as mock_create_scan: + mock_create_scan.return_value = FullScan( + id="new-scan-id", + sbom_artifacts=[test_package] + ) + result = core_instance.create_new_diff("test/path", params, "workspace") + assert result.id == "new-scan-id" + assert result.diff_url == result.report_url From be86980829e1ab77678ec511bedbefaeded2000a Mon Sep 17 00:00:00 2001 From: Eric Hibbs Date: Tue, 12 Nov 2024 21:01:54 -0800 Subject: [PATCH 3/6] about to big refactor socketcli --- socketsecurity/config.py | 186 ++++++++++++++++++++++++++++++++ socketsecurity/core/__init__.py | 48 --------- socketsecurity/output.py | 89 +++++++++++++++ socketsecurity/socketcli.py | 6 -- tests/unit/test_cli_config.py | 31 ++++++ tests/unit/test_output.py | 50 +++++++++ 6 files changed, 356 insertions(+), 54 deletions(-) create mode 100644 socketsecurity/config.py create mode 100644 socketsecurity/output.py create mode 100644 tests/unit/test_cli_config.py create mode 100644 tests/unit/test_output.py diff --git a/socketsecurity/config.py b/socketsecurity/config.py new file mode 100644 index 0000000..97bfad9 --- /dev/null +++ b/socketsecurity/config.py @@ -0,0 +1,186 @@ +from dataclasses import dataclass +from typing import List, Optional +import argparse +import os + +@dataclass +class CliConfig: + api_token: str + repo: Optional[str] + branch: str = "" + committer: Optional[List[str]] = None + pr_number: str = "0" + commit_message: Optional[str] = None + default_branch: bool = False + target_path: str = "./" + scm: str = "api" + sbom_file: Optional[str] = None + commit_sha: str = "" + generate_license: bool = False + enable_debug: bool = False + allow_unverified: bool = False + enable_json: bool = False + disable_overview: bool = False + disable_security_issue: bool = False + files: str = "[]" + ignore_commit_files: bool = False + disable_blocking: bool = False + + @classmethod + def from_args(cls, args_list: Optional[List[str]] = None) -> 'CliConfig': + parser = create_argument_parser() + args = parser.parse_args(args_list) + + # Get API token from env or args + api_token = os.getenv("SOCKET_SECURITY_API_KEY") or args.api_token + + return cls( + api_token=api_token, + repo=args.repo, + branch=args.branch, + 
committer=args.committer, + pr_number=args.pr_number, + commit_message=args.commit_message, + default_branch=args.default_branch, + target_path=args.target_path, + scm=args.scm, + sbom_file=args.sbom_file, + commit_sha=args.commit_sha, + generate_license=args.generate_license, + enable_debug=args.enable_debug, + allow_unverified=args.allow_unverified, + enable_json=args.enable_json, + disable_overview=args.disable_overview, + disable_security_issue=args.disable_security_issue, + files=args.files, + ignore_commit_files=args.ignore_commit_files, + disable_blocking=args.disable_blocking + ) + +def create_argument_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser( + prog="socketcli", + description="Socket Security CLI" + ) + + parser.add_argument( + "--api-token", + help="Socket Security API token (can also be set via SOCKET_SECURITY_API_KEY env var)", + required=False + ) + + parser.add_argument( + "--repo", + help="Repository name in owner/repo format", + required=False + ) + + parser.add_argument( + "--branch", + help="Branch name", + default="" + ) + + parser.add_argument( + "--committer", + help="Committer(s) to filter by", + nargs="*" + ) + + parser.add_argument( + "--pr-number", + help="Pull request number", + default="0" + ) + + parser.add_argument( + "--commit-message", + help="Commit message" + ) + + # Boolean flags + parser.add_argument( + "--default-branch", + action="store_true", + help="Use default branch" + ) + + parser.add_argument( + "--generate-license", + action="store_true", + help="Generate license information" + ) + + parser.add_argument( + "--enable-debug", + action="store_true", + help="Enable debug logging" + ) + + parser.add_argument( + "--allow-unverified", + action="store_true", + help="Allow unverified packages" + ) + + parser.add_argument( + "--enable-json", + action="store_true", + help="Output in JSON format" + ) + + parser.add_argument( + "--disable-overview", + action="store_true", + help="Disable overview output" + ) + + parser.add_argument( + "--disable-security-issue", + action="store_true", + help="Disable security issue checks" + ) + + parser.add_argument( + "--ignore-commit-files", + action="store_true", + help="Ignore commit files" + ) + + parser.add_argument( + "--disable-blocking", + action="store_true", + help="Disable blocking mode" + ) + + # Path and file related arguments + parser.add_argument( + "--target-path", + default="./", + help="Target path for analysis" + ) + + parser.add_argument( + "--scm", + default="api", + help="Source control management type" + ) + + parser.add_argument( + "--sbom-file", + help="SBOM file path" + ) + + parser.add_argument( + "--commit-sha", + default="", + help="Commit SHA" + ) + + parser.add_argument( + "--files", + default="[]", + help="Files to analyze (JSON array string)" + ) + + return parser \ No newline at end of file diff --git a/socketsecurity/core/__init__.py b/socketsecurity/core/__init__.py index f554938..35810a9 100644 --- a/socketsecurity/core/__init__.py +++ b/socketsecurity/core/__init__.py @@ -140,10 +140,6 @@ def do_request( class Core: - # token: str - # base_api_url: str - # request_timeout: int - # reports: list client: CliClient config: SocketConfig @@ -153,30 +149,6 @@ def __init__(self, config: SocketConfig, client: CliClient): self.client = client self.set_org_vars() - # def __init__( - # self, - # token: str, - # base_api_url: str = None, - # request_timeout: int = None, - # enable_all_alerts: bool = False, - # allow_unverified: bool = False - # ): - # global 
allow_unverified_ssl - # allow_unverified_ssl = allow_unverified - # self.token = token + ":" - # encode_key(self.token) - # self.socket_date_format = "%Y-%m-%dT%H:%M:%S.%fZ" - # self.base_api_url = base_api_url - # if self.base_api_url is not None: - # Core.set_api_url(self.base_api_url) - # self.request_timeout = request_timeout - # if self.request_timeout is not None: - # Core.set_timeout(self.request_timeout) - # if enable_all_alerts: - # global all_new_alerts - # all_new_alerts = True - # Core.set_org_vars() - def set_org_vars(self) -> None: """Sets the main shared configuration variables""" @@ -231,26 +203,6 @@ def get_sbom_data(self, full_scan_id: str) -> list: return results - # ORIGINAL - remove after verification - # def get_sbom_data(self, full_scan_id: str) -> list: - # path = f"orgs/{self.config.org_slug}/full-scans/{full_scan_id}" - # response = self.client.request(path) - # results = [] - # try: - # data = response.json() - # results = data.get("sbom_artifacts") or [] - # except Exception as error: - # log.debug("Failed with old style full-scan API using new format") - # log.debug(error) - # data = response.text - # data.strip('"') - # data.strip() - # for line in data.split("\n"): - # if line != '"' and line != "" and line is not None: - # item = json.loads(line) - # results.append(item) - # return results - def get_security_policy(self) -> dict: """Get the Security policy and determine the effective Org security policy""" payload = [{"organization": self.config.org_id}] diff --git a/socketsecurity/output.py b/socketsecurity/output.py new file mode 100644 index 0000000..25e15da --- /dev/null +++ b/socketsecurity/output.py @@ -0,0 +1,89 @@ +from typing import Optional, Dict, Any +import json +import sys +import logging +from pathlib import Path +from .core.classes import Diff, Issue + + +class OutputHandler: + blocking_disabled: bool + logger: logging.Logger + + def __init__(self, blocking_disabled: bool): + self.blocking_disabled = blocking_disabled + self.logger = logging.getLogger("socketcli") + + def handle_output(self, diff_report: Diff, sbom_file_name: Optional[str] = None, json_output: bool = False) -> int: + """Main output handler that determines output format and returns exit code""" + if json_output: + self.output_console_json(diff_report, sbom_file_name) + else: + self.output_console_comments(diff_report, sbom_file_name) + + self.save_sbom_file(diff_report, sbom_file_name) + return 0 if self.report_pass(diff_report) else 1 + + def output_console_comments(self, diff_report: Diff, sbom_file_name: Optional[str] = None) -> None: + """Outputs formatted console comments""" + if not diff_report.issues: + self.logger.info("No issues found") + return + + for issue in diff_report.issues: + self._output_issue(issue) + + def output_console_json(self, diff_report: Diff, sbom_file_name: Optional[str] = None) -> None: + """Outputs JSON formatted results""" + output = { + "issues": [self._format_issue(issue) for issue in diff_report.issues], + "pass": self.report_pass(diff_report) + } + if sbom_file_name: + output["sbom_file"] = sbom_file_name + + json.dump(output, sys.stdout, indent=2) + sys.stdout.write("\n") + + def report_pass(self, diff_report: Diff) -> bool: + """Determines if the report passes security checks""" + if not diff_report.issues: + return True + + if self.blocking_disabled: + return True + + return not any(issue.blocking for issue in diff_report.issues) + + def save_sbom_file(self, diff_report: Diff, sbom_file_name: Optional[str] = None) -> None: + """Saves SBOM 
file if filename is provided""" + if not sbom_file_name or not diff_report.sbom: + return + + sbom_path = Path(sbom_file_name) + sbom_path.parent.mkdir(parents=True, exist_ok=True) + + with open(sbom_path, "w") as f: + json.dump(diff_report.sbom, f, indent=2) + + def _output_issue(self, issue: Issue) -> None: + """Helper method to format and output a single issue""" + severity = issue.severity.upper() if issue.severity else "UNKNOWN" + status = "🚫 Blocking" if issue.blocking else "⚠️ Warning" + + self.logger.warning(f"\n{status} - Severity: {severity}") + self.logger.warning(f"Title: {issue.title}") + if issue.description: + self.logger.warning(f"Description: {issue.description}") + if issue.recommendation: + self.logger.warning(f"Recommendation: {issue.recommendation}") + + def _format_issue(self, issue: Issue) -> Dict[str, Any]: + """Helper method to format an issue for JSON output""" + return { + "title": issue.title, + "description": issue.description, + "severity": issue.severity, + "blocking": issue.blocking, + "recommendation": issue.recommendation + } diff --git a/socketsecurity/socketcli.py b/socketsecurity/socketcli.py index 3496c9e..72caff9 100644 --- a/socketsecurity/socketcli.py +++ b/socketsecurity/socketcli.py @@ -1,7 +1,6 @@ import argparse import json -import socketsecurity.core from socketsecurity.core import Core, __version__ from socketsecurity.logging import initialize_logging, set_debug_mode from socketsecurity.core.classes import FullScanParams, Diff, Package, Issue @@ -11,13 +10,8 @@ from git import InvalidGitRepositoryError, NoSuchPathError import os import sys -import logging socket_logger, cli_logger = initialize_logging() - -log_format = "%(asctime)s: %(message)s" -logging.basicConfig(level=logging.INFO, format=log_format) -socketsecurity.core.log.setLevel(level=logging.INFO) log = cli_logger blocking_disabled = False diff --git a/tests/unit/test_cli_config.py b/tests/unit/test_cli_config.py new file mode 100644 index 0000000..db7b1f5 --- /dev/null +++ b/tests/unit/test_cli_config.py @@ -0,0 +1,31 @@ +import pytest +from socketsecurity.config import CliConfig + +class TestCliConfig: + def test_api_token_from_env(self, monkeypatch): + monkeypatch.setenv("SOCKET_SECURITY_API_KEY", "test-token") + config = CliConfig.from_args([]) # Empty args list + assert config.api_token == "test-token" + + def test_required_args(self): + """Test that api token is required if not in environment""" + with pytest.raises(ValueError, match="API token is required"): + config = CliConfig.from_args([]) + if not config.api_token: + raise ValueError("API token is required") + + def test_default_values(self): + # Test that default values are set correctly + config = CliConfig.from_args(["--api-token", "test"]) + assert config.branch == "" + assert config.target_path == "./" + assert config.files == "[]" + + @pytest.mark.parametrize("flag,attr", [ + ("--enable-debug", "enable_debug"), + ("--disable-blocking", "disable_blocking"), + ("--allow-unverified", "allow_unverified") + ]) + def test_boolean_flags(self, flag, attr): + config = CliConfig.from_args(["--api-token", "test", flag]) + assert getattr(config, attr) is True \ No newline at end of file diff --git a/tests/unit/test_output.py b/tests/unit/test_output.py new file mode 100644 index 0000000..5c170ef --- /dev/null +++ b/tests/unit/test_output.py @@ -0,0 +1,50 @@ +import pytest +from socketsecurity.output import OutputHandler +from socketsecurity.core.classes import Diff, Issue +import json + +class TestOutputHandler: + 
@pytest.fixture + def handler(self): + return OutputHandler(blocking_disabled=False) + + def test_report_pass_with_blocking_issues(self, handler): + diff = Diff() + diff.issues = [Issue(blocking=True)] + assert not handler.report_pass(diff) + + def test_report_pass_with_blocking_disabled(self): + handler = OutputHandler(blocking_disabled=True) + diff = Diff() + diff.issues = [Issue(blocking=True)] + assert handler.report_pass(diff) + + def test_json_output_format(self, handler, capsys): + diff = Diff() + test_issue = Issue( + title="Test", + severity="high", + blocking=True, + description="Test description", + recommendation=None + ) + diff.issues = [test_issue] + + handler.output_console_json(diff) + captured = capsys.readouterr() + + # Parse the JSON output and verify structure + output = json.loads(captured.out) + assert output["issues"][0]["title"] == "Test" + assert output["issues"][0]["severity"] == "high" + assert output["issues"][0]["blocking"] is True + assert output["issues"][0]["description"] == "Test description" + assert output["issues"][0]["recommendation"] is None + + def test_sbom_file_saving(self, handler, tmp_path): + # Test SBOM file is created correctly + diff = Diff() + diff.sbom = {"test": "data"} + sbom_path = tmp_path / "test.json" + handler.save_sbom_file(diff, str(sbom_path)) + assert sbom_path.exists() \ No newline at end of file From c2f4b459c809cb85f1fde3203946fc34a24b3285 Mon Sep 17 00:00:00 2001 From: Eric Hibbs Date: Tue, 12 Nov 2024 21:25:11 -0800 Subject: [PATCH 4/6] post-socketcli refactor --- socketsecurity/socketcli.py | 382 ++++++++++-------------------------- 1 file changed, 101 insertions(+), 281 deletions(-) diff --git a/socketsecurity/socketcli.py b/socketsecurity/socketcli.py index 72caff9..6b0858f 100644 --- a/socketsecurity/socketcli.py +++ b/socketsecurity/socketcli.py @@ -1,334 +1,149 @@ -import argparse import json -from socketsecurity.core import Core, __version__ +from socketsecurity.core import Core from socketsecurity.logging import initialize_logging, set_debug_mode -from socketsecurity.core.classes import FullScanParams, Diff, Package, Issue +from socketsecurity.core.classes import FullScanParams, Diff from socketsecurity.core.messages import Messages from socketsecurity.core.scm_comments import Comments from socketsecurity.core.git_interface import Git from git import InvalidGitRepositoryError, NoSuchPathError -import os +from socketsecurity.config import CliConfig +from socketsecurity.output import OutputHandler +from socketsecurity.core.config import SocketConfig +from socketsecurity.core.client import CliClient + import sys socket_logger, cli_logger = initialize_logging() log = cli_logger blocking_disabled = False -parser = argparse.ArgumentParser( - prog="socketcli", - description="Socket Security CLI" -) -parser.add_argument( - '--api_token', - help='The Socket API token can be set via SOCKET_SECURITY_API_KEY', - required=False -) -parser.add_argument( - '--repo', - help='The name of the repository', - required=False -) -parser.add_argument( - '--branch', - default='', - help='The name of the branch', - required=False -) -parser.add_argument( - - '--committer', - help='The name of the person or bot running this', - action="append", - required=False -) -parser.add_argument( - '--pr_number', - default="0", - help='The pr or build number', - required=False -) -parser.add_argument( - '--commit_message', - help='Commit or build message for the run', - required=False -) -parser.add_argument( - '--default_branch', - default=False, - 
action='store_true', - help='Whether this is the default/head for run' -) -parser.add_argument( - '--target_path', - default='./', - help='Path to look for manifest files', - required=False -) - -parser.add_argument( - '--scm', - default='api', - help='Integration mode choices are api, github, gitlab, and bitbucket', - choices=["api", "github", "gitlab"], - required=False -) - -parser.add_argument( - '--sbom-file', - default=None, - help='If soecified save the SBOM details to the specified file', - required=False -) - -parser.add_argument( - '--commit-sha', - default="", - help='Optional git commit sha', - required=False -) - -parser.add_argument( - '--generate-license', - default=False, - help='Run in license mode to generate license output', - required=False -) - -parser.add_argument( - '-v', - '--version', - action="version", - version=f'%(prog)s {__version__}', - help='Display the version', -) - -parser.add_argument( - '--enable-debug', - help='Enable debug mode', - action='store_true', - default=False -) - -parser.add_argument( - '--allow-unverified', - help='Allow unverified SSL Connections', - action='store_true', - default=False -) - -parser.add_argument( - '--enable-json', - help='Enable json output of results instead of table formatted', - action='store_true', - default=False -) - -parser.add_argument( - '--disable-overview', - help='Disables Dependency Overview comments', - action='store_true', - default=False -) - -parser.add_argument( - '--disable-security-issue', - help='Disables Security Issues comment', - action='store_true', - default=False -) - -parser.add_argument( - '--files', - help='Specify a list of files in the format of ["file1", "file2"]', - default="[]" -) - -parser.add_argument( - '--ignore-commit-files', - help='Ignores only looking for changed files form the commit. 
Will find any supported manifest file type', - action='store_true', - default=False -) - -parser.add_argument( - '--disable-blocking', - help='Disables failing checks and will only exit with an exit code of 0', - action='store_true', - default=False -) - - -def output_console_comments(diff_report: Diff, sbom_file_name: str = None) -> None: - if diff_report.id != "NO_DIFF_RAN": - console_security_comment = Messages.create_console_security_alert_table(diff_report) - save_sbom_file(diff_report, sbom_file_name) - log.info(f"Socket Full Scan ID: {diff_report.id}") - if len(diff_report.new_alerts) > 0: - log.info("Security issues detected by Socket Security") - msg = f"\n{console_security_comment}" - log.info(msg) - if not report_pass(diff_report) and not blocking_disabled: - sys.exit(1) - else: - # Means only warning alerts with no blocked - if not blocking_disabled: - sys.exit(5) - else: - log.info("No New Security issues detected by Socket Security") - - -def output_console_json(diff_report: Diff, sbom_file_name: str = None) -> None: - if diff_report.id != "NO_DIFF_RAN": - console_security_comment = Messages.create_security_comment_json(diff_report) - save_sbom_file(diff_report, sbom_file_name) - print(json.dumps(console_security_comment)) - if not report_pass(diff_report) and not blocking_disabled: - sys.exit(1) - elif len(diff_report.new_alerts) > 0 and not blocking_disabled: - # Means only warning alerts with no blocked - sys.exit(5) - -def report_pass(diff_report: Diff) -> bool: - report_passed = True - if len(diff_report.new_alerts) > 0: - for alert in diff_report.new_alerts: - alert: Issue - if report_passed and alert.error: - report_passed = False - break - return report_passed - - -def save_sbom_file(diff_report: Diff, sbom_file_name: str = None): - if diff_report is not None and sbom_file_name is not None: - Core.save_file(sbom_file_name, json.dumps(Core.create_sbom_output(diff_report))) def cli(): try: main_code() except KeyboardInterrupt: - log.info("Keyboard Interrupt detected, exiting") - if not blocking_disabled: + cli_logger.info("Keyboard Interrupt detected, exiting") + config = CliConfig.from_args() # Get current config + if not config.disable_blocking: sys.exit(2) else: sys.exit(0) except Exception as error: - log.error("Unexpected error when running the cli") - log.error(error) - if not blocking_disabled: + cli_logger.error("Unexpected error when running the cli") + cli_logger.error(error) + config = CliConfig.from_args() # Get current config + if not config.disable_blocking: sys.exit(3) else: sys.exit(0) def main_code(): - arguments = parser.parse_args() + config = CliConfig.from_args() + output_handler = OutputHandler(blocking_disabled=config.disable_blocking) - if arguments.enable_debug: + if config.enable_debug: set_debug_mode(True) log.debug("Debug logging enabled") - repo = arguments.repo - branch = arguments.branch - commit_message = arguments.commit_message - committer = arguments.committer - default_branch = arguments.default_branch - pr_number = arguments.pr_number - target_path = arguments.target_path - scm_type = arguments.scm - commit_sha = arguments.commit_sha - sbom_file = arguments.sbom_file - license_mode = arguments.generate_license - enable_json = arguments.enable_json - disable_overview = arguments.disable_overview - disable_security_issue = arguments.disable_security_issue - ignore_commit_files = arguments.ignore_commit_files - disable_blocking = arguments.disable_blocking - allow_unverified = arguments.allow_unverified - - if disable_blocking: + if 
config.disable_blocking: global blocking_disabled blocking_disabled = True - files = arguments.files - log.info(f"Starting Socket Security Scan version {__version__}") - api_token = os.getenv("SOCKET_SECURITY_API_KEY") or arguments.api_token + # Validate API token + if not config.api_token: + cli_logger.info("Unable to find Socket API Token") + sys.exit(3) + + # Initialize Socket core components + socket_config = SocketConfig( + api_key=config.api_token, + allow_unverified_ssl=config.allow_unverified + ) + client = CliClient(socket_config) + core = Core(socket_config, client) + + # Load files try: - files = json.loads(files) + files = json.loads(config.files) is_repo = True except Exception as error: - log.error(f"Unable to parse {files}") + log.error(f"Unable to parse {config.files}") log.error(error) sys.exit(3) - if api_token is None: - log.info("Unable to find Socket API Token") - sys.exit(3) + + # Git setup try: - git_repo = Git(target_path) - if repo is None: - repo = git_repo.repo_name - if commit_sha is None or commit_sha == '': - commit_sha = git_repo.commit - if branch is None or branch == '': - branch = git_repo.branch - if committer is None or committer == '': - committer = git_repo.committer - if commit_message is None or commit_message == '': - commit_message = git_repo.commit_message - if len(files) == 0 and not ignore_commit_files: + git_repo = Git(config.target_path) + if not config.repo: + config.repo = git_repo.repo_name + if not config.commit_sha: + config.commit_sha = git_repo.commit + if not config.branch: + config.branch = git_repo.branch + if not config.committer: + config.committer = git_repo.committer + if not config.commit_message: + config.commit_message = git_repo.commit_message + if len(files) == 0 and not config.ignore_commit_files: files = git_repo.changed_files is_repo = True except InvalidGitRepositoryError: is_repo = False - ignore_commit_files = True - pass + config.ignore_commit_files = True except NoSuchPathError: - raise Exception(f"Unable to find path {target_path}") - # git_repo = None - if repo is None: + raise Exception(f"Unable to find path {config.target_path}") + + if not config.repo: log.info("Repo name needs to be set") sys.exit(2) - license_file = f"{repo}" - if branch is not None: - license_file += f"_{branch}" - license_file += ".json" + + # license_file = f"{repo}" + # if branch is not None: + # license_file += f"_{branch}" + # license_file += ".json" + scm = None - if scm_type == "github": + if config.scm == "github": from socketsecurity.core.github import Github scm = Github() - elif scm_type == 'gitlab': + elif config.scm == 'gitlab': from socketsecurity.core.gitlab import Gitlab scm = Gitlab() if scm is not None: - default_branch = scm.is_default_branch + config.default_branch = scm.is_default_branch - base_api_url = os.getenv("BASE_API_URL") or None - core = Core(token=api_token, request_timeout=1200, base_api_url=base_api_url, allow_unverified=allow_unverified) + + # Check for manifest changes no_change = True - if ignore_commit_files: + if config.ignore_commit_files: no_change = False elif is_repo and files is not None and len(files) > 0: log.info(files) no_change = core.match_supported_files(files) - set_as_pending_head = False - if default_branch: - set_as_pending_head = True + # Set up scan params + set_as_pending_head = config.default_branch + params = FullScanParams( - repo=repo, - branch=branch, - commit_message=commit_message, - commit_hash=commit_sha, - pull_request=pr_number, - committers=committer, - 
make_default_branch=default_branch, + repo=config.repo, + branch=config.branch, + commit_message=config.commit_message, + commit_hash=config.commit_sha, + pull_request=config.pr_number, + committers=config.committer, + make_default_branch=config.default_branch, set_as_pending_head=set_as_pending_head ) + + # Initialize diff diff = Diff() diff.id = "NO_DIFF_RAN" + + # Handle SCM-specific flows if scm is not None and scm.check_event_type() == "comment": log.info("Comment initiated flow") log.debug(f"Getting comments for Repo {scm.repository} for PR {scm.pr_number}") @@ -340,12 +155,11 @@ def main_code(): diff: Diff if no_change: log.info("No manifest files changes, skipping scan") - # log.info("No dependency changes") elif scm.check_event_type() == "diff": - diff = core.create_new_diff(target_path, params, workspace=target_path, no_change=no_change) + diff = core.create_new_diff(config.target_path, params, workspace=config.target_path, no_change=no_change) log.info("Starting comment logic for PR/MR event") log.debug(f"Getting comments for Repo {scm.repository} for PR {scm.pr_number}") - comments = scm.get_comments_for_pr(repo, str(pr_number)) + comments = scm.get_comments_for_pr(config.repo, str(config.pr_number)) log.debug("Removing comment alerts") diff.new_alerts = Comments.remove_alerts(comments, diff.new_alerts) log.debug("Creating Dependency Overview Comment") @@ -364,19 +178,19 @@ def main_code(): overview_comment == "" or (len(comments) != 0 and comments.get("overview") is not None) ) - if len(diff.new_alerts) == 0 or disable_security_issue: + if len(diff.new_alerts) == 0 or config.disable_security_issue: if not update_old_security_comment: new_security_comment = False log.debug("No new alerts or security issue comment disabled") else: log.debug("Updated security comment with no new alerts") - if (len(diff.new_packages) == 0 and len(diff.removed_packages) == 0) or disable_overview: + if (len(diff.new_packages) == 0 and len(diff.removed_packages) == 0) or config.disable_overview: if not update_old_overview_comment: new_overview_comment = False log.debug("No new/removed packages or Dependency Overview comment disabled") else: log.debug("Updated overview comment with no dependencies") - log.debug(f"Adding comments for {scm_type}") + log.debug(f"Adding comments for {config.scm}") scm.add_socket_comments( security_comment, overview_comment, @@ -386,24 +200,26 @@ def main_code(): ) else: log.info("Starting non-PR/MR flow") - diff = core.create_new_diff(target_path, params, workspace=target_path, no_change=no_change) - if enable_json: + diff = core.create_new_diff(config.target_path, params, workspace=config.target_path, no_change=no_change) + + # Use output handler for results + if config.enable_json: log.debug("Outputting JSON Results") - output_console_json(diff, sbom_file) + output_handler.output_console_json(diff, config.sbom_file) else: - output_console_comments(diff, sbom_file) + output_handler.output_console_comments(diff, config.sbom_file) else: - log.info("API Mode") - diff: Diff - diff = core.create_new_diff(target_path, params, workspace=target_path, no_change=no_change) - if enable_json: - output_console_json(diff, sbom_file) + cli_logger.info("API Mode") + diff = core.create_new_diff(config.target_path, params, workspace=config.target_path, no_change=no_change) + if config.enable_json: + output_handler.output_console_json(diff, config.sbom_file) else: - output_console_comments(diff, sbom_file) - if diff is not None and license_mode: + 
output_handler.output_console_comments(diff, config.sbom_file) + + # Handle license generation + if diff is not None and config.generate_license: all_packages = {} for package_id in diff.packages: - package: Package package = diff.packages[package_id] output = { "id": package_id, @@ -416,6 +232,10 @@ def main_code(): "license_text": package.license_text } all_packages[package_id] = output + license_file = f"{config.repo}" + if config.branch: + license_file += f"_{config.branch}" + license_file += ".json" core.save_file(license_file, json.dumps(all_packages)) From 629a500f4c46a78f177c41b3bfdaee8030ed4fae Mon Sep 17 00:00:00 2001 From: Eric Hibbs Date: Wed, 13 Nov 2024 01:07:48 -0800 Subject: [PATCH 5/6] github scm passing tests; time for refactor --- pyproject.toml | 13 +- socketsecurity/core/__init__.py | 67 ++++-- .../core/{client.py => cli_client.py} | 0 socketsecurity/core/scm/__init__.py | 0 socketsecurity/core/scm/base.py | 37 ++++ socketsecurity/core/scm/client.py | 41 ++++ socketsecurity/core/{ => scm}/github.py | 103 ++++++--- socketsecurity/core/{ => scm}/gitlab.py | 0 socketsecurity/socketcli.py | 108 +++++----- tests/core/scm/test_github.py | 202 ++++++++++++++++++ tests/unit/test_client.py | 2 +- tests/unit/test_github.py | 0 12 files changed, 466 insertions(+), 107 deletions(-) rename socketsecurity/core/{client.py => cli_client.py} (100%) create mode 100644 socketsecurity/core/scm/__init__.py create mode 100644 socketsecurity/core/scm/base.py create mode 100644 socketsecurity/core/scm/client.py rename socketsecurity/core/{ => scm}/github.py (76%) rename socketsecurity/core/{ => scm}/gitlab.py (100%) create mode 100644 tests/core/scm/test_github.py create mode 100644 tests/unit/test_github.py diff --git a/pyproject.toml b/pyproject.toml index a571939..b38cfa6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -110,8 +110,14 @@ exclude = [ # Enable Pyflakes (`F`) and a subset of the pycodestyle (`E`) codes by default. # Unlike Flake8, Ruff doesn't enable pycodestyle warnings (`W`) or # McCabe complexity (`C901`) by default. -select = ["E4", "E7", "E9", "F"] -ignore = [] +select = [ + "E4", "E7", "E9", "F", # Current rules + "I", # isort + "F401", # Unused imports + "F403", # Star imports + "F405", # Star imports undefined + "F821", # Undefined names +] # Allow fix for all enabled rules (when `--fix`) is provided. fixable = ["ALL"] @@ -120,6 +126,9 @@ unfixable = [] # Allow unused variables when underscore-prefixed. dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" +[tool.ruff.lint.isort] +known-first-party = ["socketsecurity"] + [tool.ruff.format] # Like Black, use double quotes for strings. 
quote-style = "double" diff --git a/socketsecurity/core/__init__.py b/socketsecurity/core/__init__.py index 35810a9..ad43d96 100644 --- a/socketsecurity/core/__init__.py +++ b/socketsecurity/core/__init__.py @@ -1,32 +1,30 @@ +import base64 +import json import logging +import platform +import time +from glob import glob from pathlib import PurePath -from .utils import socket_globs -from .config import SocketConfig -from .client import CliClient -import requests from urllib.parse import urlencode -import base64 -import json + +import requests + +from socketsecurity import __version__ +from socketsecurity.core.classes import Alert, Diff, FullScan, FullScanParams, Issue, Package, Purl, Report, Repository from socketsecurity.core.exceptions import ( - APIFailure, APIKeyMissing, APIAccessDenied, APIInsufficientQuota, APIResourceNotFound, APICloudflareError + APIAccessDenied, + APICloudflareError, + APIFailure, + APIInsufficientQuota, + APIKeyMissing, + APIResourceNotFound, ) -from socketsecurity import __version__ -from socketsecurity.core.licenses import Licenses from socketsecurity.core.issues import AllIssues -from socketsecurity.core.classes import ( - Report, - Issue, - Package, - Alert, - FullScan, - FullScanParams, - Repository, - Diff, - Purl -) -import platform -from glob import glob -import time +from socketsecurity.core.licenses import Licenses + +from .cli_client import CliClient +from .config import SocketConfig +from .utils import socket_globs __all__ = [ "Core", @@ -270,6 +268,11 @@ def create_sbom_output(self, diff: Diff) -> dict: # TODO: verify what this does. It looks like it should be named "all_files_unsupported" @staticmethod def match_supported_files(files: list) -> bool: + """ + Checks if any of the files in the list match the supported file patterns + Returns True if NO files match (meaning no changes to manifest files) + Returns False if ANY files match (meaning there are manifest changes) + """ matched_files = [] not_matched = False for ecosystem in socket_globs: @@ -724,3 +727,21 @@ def save_file(file_name: str, content: str) -> None: # output = [] # for package_id in diff.packages: # purl = Core.create_purl(package_id, diff.packages) + + @staticmethod + def has_manifest_files(files: list) -> bool: + """ + Checks if any files in the list are supported manifest files. 
+ Returns True if ANY files match our manifest patterns (meaning we need to scan) + Returns False if NO files match (meaning we can skip scanning) + """ + for ecosystem in socket_globs: + patterns = socket_globs[ecosystem] + for file_name in patterns: + pattern = patterns[file_name]["pattern"] + for file in files: + if "\\" in file: + file = file.replace("\\", "/") + if PurePath(file).match(pattern): + return True # Found a manifest file, no need to check further + return False # No manifest files found diff --git a/socketsecurity/core/client.py b/socketsecurity/core/cli_client.py similarity index 100% rename from socketsecurity/core/client.py rename to socketsecurity/core/cli_client.py diff --git a/socketsecurity/core/scm/__init__.py b/socketsecurity/core/scm/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/socketsecurity/core/scm/base.py b/socketsecurity/core/scm/base.py new file mode 100644 index 0000000..715f6de --- /dev/null +++ b/socketsecurity/core/scm/base.py @@ -0,0 +1,37 @@ +from abc import ABC, abstractmethod +from typing import Dict + +from ..classes import Comment +from .client import ScmClient + + +class SCM(ABC): + def __init__(self, client: ScmClient): + self.client = client + + @abstractmethod + def check_event_type(self) -> str: + """Determine the type of event (push, pr, comment)""" + pass + + @abstractmethod + def add_socket_comments( + self, + security_comment: str, + overview_comment: str, + comments: Dict[str, Comment], + new_security_comment: bool = True, + new_overview_comment: bool = True + ) -> None: + """Add or update comments on PR""" + pass + + @abstractmethod + def get_comments_for_pr(self, repo: str, pr: str) -> Dict[str, Comment]: + """Get existing comments for PR""" + pass + + @abstractmethod + def remove_comment_alerts(self, comments: Dict[str, Comment]) -> None: + """Process and remove alerts from comments""" + pass diff --git a/socketsecurity/core/scm/client.py b/socketsecurity/core/scm/client.py new file mode 100644 index 0000000..e5bbb73 --- /dev/null +++ b/socketsecurity/core/scm/client.py @@ -0,0 +1,41 @@ +from abc import abstractmethod +from typing import Dict + +from ..cli_client import CliClient + + +class ScmClient(CliClient): + def __init__(self, token: str, api_url: str): + self.token = token + self.api_url = api_url + + @abstractmethod + def get_headers(self) -> Dict: + """Each SCM implements its own auth headers""" + pass + + def request(self, path: str, **kwargs): + """Override base request to use SCM-specific headers and base_url""" + headers = kwargs.pop('headers', None) or self.get_headers() + return super().request( + path=path, + headers=headers, + base_url=self.api_url, + **kwargs + ) + +class GithubClient(ScmClient): + def get_headers(self) -> Dict: + return { + 'Authorization': f"Bearer {self.token}", + 'User-Agent': 'SocketPythonScript/0.0.1', + "accept": "application/json" + } + +class GitlabClient(ScmClient): + def get_headers(self) -> Dict: + return { + 'Authorization': f"Bearer {self.token}", + 'User-Agent': 'SocketPythonScript/0.0.1', + "accept": "application/json" + } diff --git a/socketsecurity/core/github.py b/socketsecurity/core/scm/github.py similarity index 76% rename from socketsecurity/core/github.py rename to socketsecurity/core/scm/github.py index bd24339..81a431e 100644 --- a/socketsecurity/core/github.py +++ b/socketsecurity/core/scm/github.py @@ -1,30 +1,33 @@ import json import os -from socketsecurity.core import log, do_request -import requests -from socketsecurity.core.classes import Comment 
-from socketsecurity.core.scm_comments import Comments import sys +from dataclasses import dataclass +from git import Optional + +from socketsecurity.core import do_request, log +from socketsecurity.core.classes import Comment +from socketsecurity.core.scm_comments import Comments -global github_sha -global github_api_url -global github_ref_type -global github_event_name -global github_workspace -global github_repository -global github_ref_name -global github_actor -global default_branch -global github_env -global pr_number -global pr_name -global is_default_branch -global commit_message -global committer -global gh_api_token -global github_repository_owner -global event_action +# Declare all globals with initial None values +github_sha: Optional[str] = None +github_api_url: Optional[str] = None +github_ref_type: Optional[str] = None +github_event_name: Optional[str] = None +github_workspace: Optional[str] = None +github_repository: Optional[str] = None +github_ref_name: Optional[str] = None +github_actor: Optional[str] = None +default_branch: Optional[str] = None +github_env: Optional[str] = None +pr_number: Optional[str] = None +pr_name: Optional[str] = None +is_default_branch: bool = False +commit_message: Optional[str] = None +committer: Optional[str] = None +gh_api_token: Optional[str] = None +github_repository_owner: Optional[str] = None +event_action: Optional[str] = None github_variables = [ "GITHUB_SHA", @@ -45,11 +48,58 @@ "EVENT_ACTION" ] +@dataclass +class GithubConfig: + """Configuration from GitHub environment variables""" + sha: str + api_url: str + ref_type: str + event_name: str + workspace: str + repository: str + ref_name: str + default_branch: bool + pr_number: Optional[str] + pr_name: Optional[str] + commit_message: Optional[str] + actor: str + env: str + token: str + owner: str + event_action: Optional[str] + + @classmethod + def from_env(cls) -> 'GithubConfig': + """Create config from environment variables""" + token = os.getenv('GH_API_TOKEN') + if not token: + log.error("Unable to get Github API Token from GH_API_TOKEN") + sys.exit(2) + + return cls( + sha=os.getenv('GITHUB_SHA', ''), + api_url=os.getenv('GITHUB_API_URL', ''), + ref_type=os.getenv('GITHUB_REF_TYPE', ''), + event_name=os.getenv('GITHUB_EVENT_NAME', ''), + workspace=os.getenv('GITHUB_WORKSPACE', ''), + repository=os.getenv('GITHUB_REPOSITORY', '').split('/')[-1], + ref_name=os.getenv('GITHUB_REF_NAME', ''), + default_branch=os.getenv('DEFAULT_BRANCH', '').lower() == 'true', + pr_number=os.getenv('PR_NUMBER'), + pr_name=os.getenv('PR_NAME'), + commit_message=os.getenv('COMMIT_MESSAGE'), + actor=os.getenv('GITHUB_ACTOR', ''), + env=os.getenv('GITHUB_ENV', ''), + token=token, + owner=os.getenv('GITHUB_REPOSITORY_OWNER', ''), + event_action=os.getenv('EVENT_ACTION') + ) + + for env in github_variables: var_name = env.lower() globals()[var_name] = os.getenv(env) or None if var_name == "default_branch": - global is_default_branch if default_branch is None or default_branch.lower() == "false": is_default_branch = False else: @@ -233,15 +283,14 @@ def post_reaction(comment_id: int) -> None: def comment_reaction_exists(comment_id: int) -> bool: repo = github_repository.rsplit("/", 1)[1] path = f"repos/{github_repository_owner}/{repo}/issues/comments/{comment_id}/reactions" - response = do_request(path, headers=headers, base_url=github_api_url) - exists = False try: + response = do_request(path, headers=headers, base_url=github_api_url) data = response.json() for reaction in data: content = 
reaction.get("content") if content is not None and content == ":thumbsup:": - exists = True + return True except Exception as error: log.error(f"Unable to get reaction for {comment_id} for PR {pr_number}") log.error(error) - return exists + return False diff --git a/socketsecurity/core/gitlab.py b/socketsecurity/core/scm/gitlab.py similarity index 100% rename from socketsecurity/core/gitlab.py rename to socketsecurity/core/scm/gitlab.py diff --git a/socketsecurity/socketcli.py b/socketsecurity/socketcli.py index 6b0858f..5cc4e8e 100644 --- a/socketsecurity/socketcli.py +++ b/socketsecurity/socketcli.py @@ -1,39 +1,34 @@ import json +import sys + +from git import InvalidGitRepositoryError, NoSuchPathError +from socketsecurity.config import CliConfig from socketsecurity.core import Core -from socketsecurity.logging import initialize_logging, set_debug_mode -from socketsecurity.core.classes import FullScanParams, Diff +from socketsecurity.core.classes import Diff, FullScanParams +from socketsecurity.core.cli_client import CliClient +from socketsecurity.core.config import SocketConfig +from socketsecurity.core.git_interface import Git +from socketsecurity.core.logging import initialize_logging, set_debug_mode from socketsecurity.core.messages import Messages from socketsecurity.core.scm_comments import Comments -from socketsecurity.core.git_interface import Git -from git import InvalidGitRepositoryError, NoSuchPathError -from socketsecurity.config import CliConfig from socketsecurity.output import OutputHandler -from socketsecurity.core.config import SocketConfig -from socketsecurity.core.client import CliClient - -import sys - -socket_logger, cli_logger = initialize_logging() -log = cli_logger -blocking_disabled = False - - +socket_logger, log = initialize_logging() def cli(): try: main_code() except KeyboardInterrupt: - cli_logger.info("Keyboard Interrupt detected, exiting") + log.info("Keyboard Interrupt detected, exiting") config = CliConfig.from_args() # Get current config if not config.disable_blocking: sys.exit(2) else: sys.exit(0) except Exception as error: - cli_logger.error("Unexpected error when running the cli") - cli_logger.error(error) + log.error("Unexpected error when running the cli") + log.error(error) config = CliConfig.from_args() # Get current config if not config.disable_blocking: sys.exit(3) @@ -49,13 +44,9 @@ def main_code(): set_debug_mode(True) log.debug("Debug logging enabled") - if config.disable_blocking: - global blocking_disabled - blocking_disabled = True - # Validate API token if not config.api_token: - cli_logger.info("Unable to find Socket API Token") + log.info("Unable to find Socket API Token") sys.exit(3) # Initialize Socket core components @@ -66,11 +57,12 @@ def main_code(): client = CliClient(socket_config) core = Core(socket_config, client) - # Load files + # Load files - files defaults to "[]" in CliConfig try: - files = json.loads(config.files) - is_repo = True + files = json.loads(config.files) # Will always succeed with empty list by default + is_repo = True # FIXME: This is misleading - JSON parsing success doesn't indicate repo status except Exception as error: + # Only hits this if files was manually set to invalid JSON log.error(f"Unable to parse {config.files}") log.error(error) sys.exit(3) @@ -88,12 +80,12 @@ def main_code(): config.committer = git_repo.committer if not config.commit_message: config.commit_message = git_repo.commit_message - if len(files) == 0 and not config.ignore_commit_files: - files = git_repo.changed_files - is_repo = 
True + if files and not config.ignore_commit_files: # files is empty by default, so this is False unless files manually specified + files = git_repo.changed_files # Only gets git's changed files if files were manually specified + is_repo = True # Redundant since already True except InvalidGitRepositoryError: - is_repo = False - config.ignore_commit_files = True + is_repo = False # Overwrites previous True - this is the REAL repo status + config.ignore_commit_files = True # Silently changes config - should log this except NoSuchPathError: raise Exception(f"Unable to find path {config.target_path}") @@ -101,32 +93,35 @@ def main_code(): log.info("Repo name needs to be set") sys.exit(2) - # license_file = f"{repo}" - # if branch is not None: - # license_file += f"_{branch}" - # license_file += ".json" - scm = None if config.scm == "github": - from socketsecurity.core.github import Github + from socketsecurity.core.scm.github import Github scm = Github() elif config.scm == 'gitlab': - from socketsecurity.core.gitlab import Gitlab + from socketsecurity.core.scm.gitlab import Gitlab scm = Gitlab() if scm is not None: config.default_branch = scm.is_default_branch - # Check for manifest changes - no_change = True + # Combine manually specified files with git changes if applicable + files_to_check = set(json.loads(config.files)) # Start with manually specified files + + # Add git changes if this is a repo and we're not ignoring commit files + if is_repo and not config.ignore_commit_files: + files_to_check.update(git_repo.changed_files) + + # Determine if we need to scan based on manifest files + should_skip_scan = True # Default to skipping if config.ignore_commit_files: - no_change = False - elif is_repo and files is not None and len(files) > 0: - log.info(files) - no_change = core.match_supported_files(files) + should_skip_scan = False # Force scan if ignoring commit files + elif files_to_check: # If we have any files to check + should_skip_scan = not core.has_manifest_files(list(files_to_check)) - # Set up scan params - set_as_pending_head = config.default_branch + if should_skip_scan: + log.debug("No manifest files found in changes, skipping scan") + else: + log.debug("Found manifest files or forced scan, proceeding") params = FullScanParams( repo=config.repo, @@ -135,8 +130,8 @@ def main_code(): commit_hash=config.commit_sha, pull_request=config.pr_number, committers=config.committer, - make_default_branch=config.default_branch, - set_as_pending_head=set_as_pending_head + make_default_branch=config.default_branch, # This and + set_as_pending_head=config.default_branch # This are the same, do we need both? ) # Initialize diff @@ -145,6 +140,12 @@ def main_code(): # Handle SCM-specific flows if scm is not None and scm.check_event_type() == "comment": + # FIXME: This entire flow should be a separate command called "filter_ignored_alerts_in_comments" + # It's not related to scanning or diff generation - it just: + # 1. Triggers on comments in GitHub/GitLab + # 2. If comment was from Socket, checks for ignore reactions + # 3. 
Updates the comment to remove ignored alerts + # This is completely separate from the main scanning functionality log.info("Comment initiated flow") log.debug(f"Getting comments for Repo {scm.repository} for PR {scm.pr_number}") comments = scm.get_comments_for_pr(scm.repository, str(scm.pr_number)) @@ -152,11 +153,10 @@ def main_code(): scm.remove_comment_alerts(comments) elif scm is not None and scm.check_event_type() != "comment": log.info("Push initiated flow") - diff: Diff - if no_change: + if should_skip_scan: log.info("No manifest files changes, skipping scan") elif scm.check_event_type() == "diff": - diff = core.create_new_diff(config.target_path, params, workspace=config.target_path, no_change=no_change) + diff = core.create_new_diff(config.target_path, params, workspace=config.target_path, no_change=should_skip_scan) log.info("Starting comment logic for PR/MR event") log.debug(f"Getting comments for Repo {scm.repository} for PR {scm.pr_number}") comments = scm.get_comments_for_pr(config.repo, str(config.pr_number)) @@ -200,7 +200,7 @@ def main_code(): ) else: log.info("Starting non-PR/MR flow") - diff = core.create_new_diff(config.target_path, params, workspace=config.target_path, no_change=no_change) + diff = core.create_new_diff(config.target_path, params, workspace=config.target_path, no_change=should_skip_scan) # Use output handler for results if config.enable_json: @@ -209,8 +209,8 @@ def main_code(): else: output_handler.output_console_comments(diff, config.sbom_file) else: - cli_logger.info("API Mode") - diff = core.create_new_diff(config.target_path, params, workspace=config.target_path, no_change=no_change) + log.info("API Mode") + diff = core.create_new_diff(config.target_path, params, workspace=config.target_path, no_change=should_skip_scan) if config.enable_json: output_handler.output_console_json(diff, config.sbom_file) else: diff --git a/tests/core/scm/test_github.py b/tests/core/scm/test_github.py new file mode 100644 index 0000000..ce226c0 --- /dev/null +++ b/tests/core/scm/test_github.py @@ -0,0 +1,202 @@ +from unittest.mock import MagicMock, mock_open, patch + +import pytest + +from socketsecurity.core.exceptions import APIAccessDenied +from socketsecurity.core.scm.github import Github, GithubConfig + + +@pytest.fixture +def mock_env_vars(): + return { + "GH_API_TOKEN": "fake-token", + "GITHUB_SHA": "abc123", + "GITHUB_API_URL": "https://api.github.com", + "GITHUB_REF_TYPE": "branch", + "GITHUB_EVENT_NAME": "pull_request", + "GITHUB_WORKSPACE": "/workspace", + "GITHUB_REPOSITORY": "owner/repo", + "GITHUB_REF_NAME": "main", + "DEFAULT_BRANCH": "true", + "PR_NUMBER": "123", + "PR_NAME": "test-pr", + "COMMIT_MESSAGE": "test commit", + "GITHUB_ACTOR": "test-user", + "GITHUB_ENV": "/github/env", + "GITHUB_REPOSITORY_OWNER": "owner", + "EVENT_ACTION": "opened" + } + +@pytest.fixture +def github_instance(mock_env_vars): + with patch.dict('os.environ', mock_env_vars), \ + patch('socketsecurity.core.scm.github.github_repository', 'owner/repo'), \ + patch('socketsecurity.core.scm.github.github_repository_owner', 'owner'), \ + patch('socketsecurity.core.scm.github.github_sha', 'abc123'), \ + patch('socketsecurity.core.scm.github.github_api_url', 'https://api.github.com'), \ + patch('socketsecurity.core.scm.github.github_ref_type', 'branch'), \ + patch('socketsecurity.core.scm.github.github_event_name', 'pull_request'), \ + patch('socketsecurity.core.scm.github.github_workspace', '/workspace'), \ + patch('socketsecurity.core.scm.github.github_ref_name', 'main'), \ + 
patch('socketsecurity.core.scm.github.default_branch', 'true'), \ + patch('socketsecurity.core.scm.github.is_default_branch', True), \ + patch('socketsecurity.core.scm.github.pr_number', '123'), \ + patch('socketsecurity.core.scm.github.pr_name', 'test-pr'), \ + patch('socketsecurity.core.scm.github.commit_message', 'test commit'), \ + patch('socketsecurity.core.scm.github.github_actor', 'test-user'), \ + patch('socketsecurity.core.scm.github.github_env', '/github/env'), \ + patch('socketsecurity.core.scm.github.gh_api_token', 'fake-token'), \ + patch('socketsecurity.core.scm.github.event_action', 'opened'): + return Github() + +class TestGithubConfig: + def test_from_env_success(self, mock_env_vars): + with patch.dict('os.environ', mock_env_vars): + config = GithubConfig.from_env() + assert config.token == "fake-token" + assert config.repository == "repo" + assert config.default_branch is True + + def test_from_env_missing_token(self): + with patch.dict('os.environ', {"GH_API_TOKEN": ""}), \ + pytest.raises(SystemExit) as exc: + GithubConfig.from_env() + assert exc.value.code == 2 + +class TestGithubEventTypes: + @pytest.mark.parametrize("event_name,pr_number,event_action,expected", [ + ("push", None, None, "main"), + ("push", "123", None, "diff"), + ("pull_request", None, "opened", "diff"), + ("pull_request", None, "synchronize", "diff"), + ("issue_comment", None, None, "comment"), + ]) + def test_check_event_type_valid(self, event_name, pr_number, event_action, expected): + with patch('socketsecurity.core.scm.github.github_event_name', event_name), \ + patch('socketsecurity.core.scm.github.pr_number', pr_number), \ + patch('socketsecurity.core.scm.github.event_action', event_action): + assert Github.check_event_type() == expected + + def test_check_event_type_unsupported_pr_action(self): + with patch('socketsecurity.core.scm.github.github_event_name', 'pull_request'), \ + patch('socketsecurity.core.scm.github.event_action', 'closed'), \ + pytest.raises(SystemExit) as exc: + Github.check_event_type() + assert exc.value.code == 0 + + def test_check_event_type_unknown(self): + with patch('socketsecurity.core.scm.github.github_event_name', 'unknown'), \ + pytest.raises(SystemExit) as exc: + Github.check_event_type() + assert exc.value.code == 0 + +class TestGithubComments: + @pytest.fixture + def mock_do_request(self): + with patch('socketsecurity.core.scm.github.do_request') as mock: + yield mock + + @pytest.fixture(autouse=True) + def setup_globals(self): + with patch.multiple('socketsecurity.core.scm.github', + github_repository='owner/repo', + github_repository_owner='owner', + pr_number='123', + github_api_url='https://api.github.com', + github_env='/github/env', + headers={'Authorization': 'Bearer fake-token'}): + yield + + def test_post_comment(self, mock_do_request): + Github.post_comment("test comment") + mock_do_request.assert_called_once() + assert mock_do_request.call_args[1]["method"] == "POST" + + def test_update_comment(self, mock_do_request): + Github.update_comment("updated comment", "123") + mock_do_request.assert_called_once() + assert mock_do_request.call_args[1]["method"] == "PATCH" + + def test_write_new_env(self): + m = mock_open() + with patch('builtins.open', m): + Github.write_new_env("TEST", "value\nwith\nnewlines") + m.assert_called_once_with("/github/env", "a") + handle = m() + handle.write.assert_called_once_with("TEST=value\\nwith\\nnewlines") + + def test_get_comments_for_pr_success(self, mock_do_request): + mock_response = MagicMock() + 
mock_response.json.return_value = [{ + "id": 1, + "body": "test comment", + "user": {"login": "test-user"} + }] + mock_do_request.return_value = mock_response + + comments = Github.get_comments_for_pr("repo", "123") + assert isinstance(comments, dict) + mock_do_request.assert_called_once() + + def test_get_comments_for_pr_error(self, mock_do_request): + mock_response = MagicMock() + mock_response.json.return_value = {"error": "test error"} + mock_do_request.return_value = mock_response + + comments = Github.get_comments_for_pr("repo", "123") + assert comments == {} + +class TestGithubReactions: + @pytest.fixture(autouse=True) + def setup_mocks(self): + with patch.dict('socketsecurity.core.__dict__', {'encoded_key': 'fake-encoded-key'}), \ + patch('requests.request') as mock_request, \ + patch.multiple('socketsecurity.core.scm.github', + github_repository='owner/repo', + github_repository_owner='owner', + pr_number='123', + github_api_url='https://api.github.com', + gh_api_token='fake-token', + headers={'Authorization': 'Bearer fake-token'}): + + # Set up a default successful response + mock_response = MagicMock() + mock_response.json.return_value = [] + mock_response.status_code = 200 + mock_response.text = "" + mock_request.return_value = mock_response + + yield mock_request + + def test_post_reaction(self, setup_mocks): + mock_request = setup_mocks + Github.post_reaction(123) + mock_request.assert_called_once() + assert mock_request.call_args[0][0] == "POST" + assert '"content": "+1"' in mock_request.call_args[1]["data"] + + def test_comment_reaction_exists_true(self, setup_mocks): + mock_request = setup_mocks + mock_response = MagicMock() + mock_response.json.return_value = [{"content": ":thumbsup:"}] + mock_response.status_code = 200 + mock_request.return_value = mock_response + + assert Github.comment_reaction_exists(123) is True + + def test_comment_reaction_exists_false(self, setup_mocks): + mock_request = setup_mocks + mock_response = MagicMock() + mock_response.json.return_value = [{"content": ":thumbsdown:"}] + mock_response.status_code = 200 + mock_request.return_value = mock_response + + assert Github.comment_reaction_exists(123) is False + + def test_comment_reaction_exists_error(self, setup_mocks): + mock_request = setup_mocks + mock_request.side_effect = APIAccessDenied("Unauthorized") + + with patch('socketsecurity.core.log.error'): # Suppress error logs + result = Github.comment_reaction_exists(123) + assert result is False \ No newline at end of file diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 0077687..46ad4a4 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -1,7 +1,7 @@ import pytest from unittest.mock import Mock, patch import requests -from socketsecurity.core.client import CliClient +from socketsecurity.core.cli_client import CliClient from socketsecurity.core.config import SocketConfig from socketsecurity.core.exceptions import APIFailure diff --git a/tests/unit/test_github.py b/tests/unit/test_github.py new file mode 100644 index 0000000..e69de29 From aefa447264b4c08b3a2201ba8ead28ddd8b0bdc9 Mon Sep 17 00:00:00 2001 From: Eric Hibbs Date: Wed, 13 Nov 2024 10:00:57 -0800 Subject: [PATCH 6/6] refactored github --- socketsecurity/core/scm/github.py | 327 +++++++++++------------------- 1 file changed, 115 insertions(+), 212 deletions(-) diff --git a/socketsecurity/core/scm/github.py b/socketsecurity/core/scm/github.py index 81a431e..06e845c 100644 --- a/socketsecurity/core/scm/github.py +++ 
@@ -9,44 +9,6 @@
 from socketsecurity.core.classes import Comment
 from socketsecurity.core.scm_comments import Comments
 
-# Declare all globals with initial None values
-github_sha: Optional[str] = None
-github_api_url: Optional[str] = None
-github_ref_type: Optional[str] = None
-github_event_name: Optional[str] = None
-github_workspace: Optional[str] = None
-github_repository: Optional[str] = None
-github_ref_name: Optional[str] = None
-github_actor: Optional[str] = None
-default_branch: Optional[str] = None
-github_env: Optional[str] = None
-pr_number: Optional[str] = None
-pr_name: Optional[str] = None
-is_default_branch: bool = False
-commit_message: Optional[str] = None
-committer: Optional[str] = None
-gh_api_token: Optional[str] = None
-github_repository_owner: Optional[str] = None
-event_action: Optional[str] = None
-
-github_variables = [
-    "GITHUB_SHA",
-    "GITHUB_API_URL",
-    "GITHUB_REF_TYPE",
-    "GITHUB_EVENT_NAME",
-    "GITHUB_WORKSPACE",
-    "GITHUB_REPOSITORY",
-    "GITHUB_REF_NAME",
-    "DEFAULT_BRANCH",
-    "PR_NUMBER",
-    "PR_NAME",
-    "COMMIT_MESSAGE",
-    "GITHUB_ACTOR",
-    "GITHUB_ENV",
-    "GH_API_TOKEN",
-    "GITHUB_REPOSITORY_OWNER",
-    "EVENT_ACTION"
-]
 
 @dataclass
 class GithubConfig:
@@ -67,6 +29,7 @@ class GithubConfig:
     token: str
     owner: str
     event_action: Optional[str]
+    headers: dict
 
     @classmethod
     def from_env(cls) -> 'GithubConfig':
@@ -76,13 +39,19 @@ def from_env(cls) -> 'GithubConfig':
             log.error("Unable to get Github API Token from GH_API_TOKEN")
             sys.exit(2)
 
+        repository = os.getenv('GITHUB_REPOSITORY', '')
+        owner = os.getenv('GITHUB_REPOSITORY_OWNER', '')
+        if '/' in repository:
+            owner = repository.split('/')[0]
+            repository = repository.split('/')[1]
+
         return cls(
             sha=os.getenv('GITHUB_SHA', ''),
             api_url=os.getenv('GITHUB_API_URL', ''),
             ref_type=os.getenv('GITHUB_REF_TYPE', ''),
             event_name=os.getenv('GITHUB_EVENT_NAME', ''),
             workspace=os.getenv('GITHUB_WORKSPACE', ''),
-            repository=os.getenv('GITHUB_REPOSITORY', '').split('/')[-1],
+            repository=repository,
             ref_name=os.getenv('GITHUB_REF_NAME', ''),
             default_branch=os.getenv('DEFAULT_BRANCH', '').lower() == 'true',
             pr_number=os.getenv('PR_NUMBER'),
@@ -91,206 +60,140 @@ def from_env(cls) -> 'GithubConfig':
             actor=os.getenv('GITHUB_ACTOR', ''),
             env=os.getenv('GITHUB_ENV', ''),
             token=token,
-            owner=os.getenv('GITHUB_REPOSITORY_OWNER', ''),
-            event_action=os.getenv('EVENT_ACTION')
+            owner=owner,
+            event_action=os.getenv('EVENT_ACTION'),
+            headers={
+                'Authorization': f"Bearer {token}",
+                'User-Agent': 'SocketPythonScript/0.0.1',
+                "accept": "application/json"
+            }
         )
 
-for env in github_variables:
-    var_name = env.lower()
-    globals()[var_name] = os.getenv(env) or None
-    if var_name == "default_branch":
-        if default_branch is None or default_branch.lower() == "false":
-            is_default_branch = False
-        else:
-            is_default_branch = True
-    if var_name != "gh_api_token":
-        value = globals()[var_name] = os.getenv(env) or None
-        log.debug(f"{env}={value}")
 
+class Github:
+    def __init__(self, config: Optional[GithubConfig] = None):
+        self.config = config or GithubConfig.from_env()
 
-headers = {
-    'Authorization': f"Bearer {gh_api_token}",
-    'User-Agent': 'SocketPythonScript/0.0.1',
-    "accept": "application/json"
-}
+        if not self.config.token:
+            log.error("Unable to get Github API Token")
+            sys.exit(2)
 
+    def check_event_type(self) -> str:
+        if self.config.event_name.lower() == "push":
+            if not self.config.pr_number:
+                return "main"
+            return "diff"
+        elif self.config.event_name.lower() == "pull_request":
+            if self.config.event_action and self.config.event_action.lower() in ['opened', 'synchronize']:
+                return "diff"
+            log.info(f"Pull Request Action {self.config.event_action} is not a supported type")
+            sys.exit(0)
+        elif self.config.event_name.lower() == "issue_comment":
+            return "comment"
+
+        log.error(f"Unknown event type {self.config.event_name}")
+        sys.exit(0)
+
+    def post_comment(self, body: str) -> None:
+        path = f"repos/{self.config.owner}/{self.config.repository}/issues/{self.config.pr_number}/comments"
+        payload = json.dumps({"body": body})
+        do_request(
+            path=path,
+            payload=payload,
+            method="POST",
+            headers=self.config.headers,
+            base_url=self.config.api_url
+        )
 
-class Github:
-    commit_sha: str
-    api_url: str
-    ref_type: str
-    event_name: str
-    workspace: str
-    repository: str
-    ref_name: str
-    default_branch: str
-    is_default_branch: bool
-    pr_number: int
-    pr_name: str
-    commit_message: str
-    committer: str
-    github_env: str
-    api_token: str
-    project_id: int
-    event_action: str
 
+    def update_comment(self, body: str, comment_id: str) -> None:
+        path = f"repos/{self.config.owner}/{self.config.repository}/issues/comments/{comment_id}"
+        payload = json.dumps({"body": body})
+        do_request(
+            path=path,
+            payload=payload,
+            method="PATCH",
+            headers=self.config.headers,
+            base_url=self.config.api_url
+        )
 
-    def __init__(self):
-        self.commit_sha = github_sha
-        self.api_url = github_api_url
-        self.ref_type = github_ref_type
-        self.event_name = github_event_name
-        self.workspace = github_workspace
-        self.repository = github_repository
-        if "/" in self.repository:
-            self.repository = self.repository.rsplit("/")[1]
-        self.branch = github_ref_name
-        self.default_branch = default_branch
-        self.is_default_branch = is_default_branch
-        self.pr_number = pr_number
-        self.pr_name = pr_name
-        self.commit_message = commit_message
-        self.committer = github_actor
-        self.github_env = github_env
-        self.api_token = gh_api_token
-        self.project_id = 0
-        self.event_action = event_action
-        if self.api_token is None:
-            print("Unable to get Github API Token from GH_API_TOKEN")
-            sys.exit(2)
 
+    def write_new_env(self, name: str, content: str) -> None:
+        with open(self.config.env, "a") as f:
+            new_content = content.replace("\n", "\\n")
+            f.write(f"{name}={new_content}")
 
-    @staticmethod
-    def check_event_type() -> str:
-        if github_event_name.lower() == "push":
-            if pr_number is None or pr_number == "" or pr_number == "0":
-                event_type = "main"
-            else:
-                event_type = "diff"
-        elif github_event_name.lower() == "pull_request":
-            if event_action is not None and event_action != "" and (
-                    event_action.lower() == "opened" or event_action.lower() == 'synchronize'):
-                event_type = "diff"
-            else:
-                log.info(f"Pull Request Action {event_action} is not a supported type")
-                sys.exit(0)
-        elif github_event_name.lower() == "issue_comment":
-            event_type = "comment"
+    def get_comments_for_pr(self) -> dict:
+        path = f"repos/{self.config.owner}/{self.config.repository}/issues/{self.config.pr_number}/comments"
+        raw_comments = Comments.process_response(
+            do_request(path, headers=self.config.headers, base_url=self.config.api_url)
+        )
+
+        comments = {}
+        if "error" not in raw_comments:
+            for item in raw_comments:
+                comment = Comment(**item)
+                comments[comment.id] = comment
+                comment.body_list = comment.body.split("\n")
         else:
-            event_type = None
-            log.error(f"Unknown event type {github_event_name}")
-            sys.exit(0)
-        return event_type
+            log.error(raw_comments)
+
+        return Comments.check_for_socket_comments(comments)
 
-    @staticmethod
     def add_socket_comments(
-            security_comment: str,
-            overview_comment: str,
-            comments: dict,
-            new_security_comment: bool = True,
-            new_overview_comment: bool = True
+        self,
+        security_comment: str,
+        overview_comment: str,
+        comments: dict,
+        new_security_comment: bool = True,
+        new_overview_comment: bool = True
     ) -> None:
-        existing_overview_comment = comments.get("overview")
-        existing_security_comment = comments.get("security")
         if new_overview_comment:
             log.debug("New Dependency Overview comment")
-            if existing_overview_comment is not None:
+            if overview := comments.get("overview"):
                 log.debug("Previous version of Dependency Overview, updating")
-                existing_overview_comment: Comment
-                Github.update_comment(overview_comment, str(existing_overview_comment.id))
+                self.update_comment(overview_comment, str(overview.id))
             else:
                 log.debug("No previous version of Dependency Overview, posting")
-                Github.post_comment(overview_comment)
+                self.post_comment(overview_comment)
+
         if new_security_comment:
             log.debug("New Security Issue Comment")
-            if existing_security_comment is not None:
+            if security := comments.get("security"):
                 log.debug("Previous version of Security Issue comment, updating")
-                existing_security_comment: Comment
-                Github.update_comment(security_comment, str(existing_security_comment.id))
+                self.update_comment(security_comment, str(security.id))
             else:
                 log.debug("No Previous version of Security Issue comment, posting")
-                Github.post_comment(security_comment)
-
-    @staticmethod
-    def post_comment(body: str) -> None:
-        repo = github_repository.rsplit("/", 1)[1]
-        path = f"repos/{github_repository_owner}/{repo}/issues/{pr_number}/comments"
-        payload = {
-            "body": body
-        }
-        payload = json.dumps(payload)
-        do_request(path, payload=payload, method="POST", headers=headers, base_url=github_api_url)
-
-    @staticmethod
-    def update_comment(body: str, comment_id: str) -> None:
-        repo = github_repository.rsplit("/", 1)[1]
-        path = f"repos/{github_repository_owner}/{repo}/issues/comments/{comment_id}"
-        payload = {
-            "body": body
-        }
-        payload = json.dumps(payload)
-        do_request(path, payload=payload, method="PATCH", headers=headers, base_url=github_api_url)
+                self.post_comment(security_comment)
 
-    @staticmethod
-    def write_new_env(name: str, content: str) -> None:
-        file = open(github_env, "a")
-        new_content = content.replace("\n", "\\n")
-        env_output = f"{name}={new_content}"
-        file.write(env_output)
-
-    @staticmethod
-    def get_comments_for_pr(repo: str, pr: str) -> dict:
-        path = f"repos/{github_repository_owner}/{repo}/issues/{pr}/comments"
-        raw_comments = Comments.process_response(do_request(path, headers=headers, base_url=github_api_url))
-        comments = {}
-        if "error" not in raw_comments:
-            for item in raw_comments:
-                comment = Comment(**item)
-                comments[comment.id] = comment
-                for line in comment.body.split("\n"):
-                    comment.body_list.append(line)
-        else:
-            log.error(raw_comments)
-        socket_comments = Comments.check_for_socket_comments(comments)
-        return socket_comments
-
-    @staticmethod
-    def remove_comment_alerts(comments: dict):
-        security_alert = comments.get("security")
-        if security_alert is not None:
-            security_alert: Comment
+    def remove_comment_alerts(self, comments: dict) -> None:
+        if security_alert := comments.get("security"):
             new_body = Comments.process_security_comment(security_alert, comments)
-            Github.handle_ignore_reactions(comments)
-            Github.update_comment(new_body, str(security_alert.id))
-
-    @staticmethod
-    def handle_ignore_reactions(comments: dict) -> None:
-        for comment in comments["ignore"]:
-            comment: Comment
-            if "SocketSecurity ignore" in comment.body:
-                if not Github.comment_reaction_exists(comment.id):
-                    Github.post_reaction(comment.id)
-
-    @staticmethod
-    def post_reaction(comment_id: int) -> None:
-        repo = github_repository.rsplit("/", 1)[1]
-        path = f"repos/{github_repository_owner}/{repo}/issues/comments/{comment_id}/reactions"
-        payload = {
-            "content": "+1"
-        }
-        payload = json.dumps(payload)
-        do_request(path, payload=payload, method="POST", headers=headers, base_url=github_api_url)
+            self.handle_ignore_reactions(comments)
+            self.update_comment(new_body, str(security_alert.id))
+
+    def handle_ignore_reactions(self, comments: dict) -> None:
+        for comment in comments.get("ignore", []):
+            if "SocketSecurity ignore" in comment.body and not self.comment_reaction_exists(comment.id):
+                self.post_reaction(comment.id)
+
+    def post_reaction(self, comment_id: int) -> None:
+        path = f"repos/{self.config.owner}/{self.config.repository}/issues/comments/{comment_id}/reactions"
+        payload = json.dumps({"content": "+1"})
+        do_request(
+            path=path,
+            payload=payload,
+            method="POST",
+            headers=self.config.headers,
+            base_url=self.config.api_url
+        )
 
-    @staticmethod
-    def comment_reaction_exists(comment_id: int) -> bool:
-        repo = github_repository.rsplit("/", 1)[1]
-        path = f"repos/{github_repository_owner}/{repo}/issues/comments/{comment_id}/reactions"
+    def comment_reaction_exists(self, comment_id: int) -> bool:
+        path = f"repos/{self.config.owner}/{self.config.repository}/issues/comments/{comment_id}/reactions"
         try:
-            response = do_request(path, headers=headers, base_url=github_api_url)
-            data = response.json()
-            for reaction in data:
-                content = reaction.get("content")
-                if content is not None and content == ":thumbsup:":
+            response = do_request(path, headers=self.config.headers, base_url=self.config.api_url)
+            for reaction in response.json():
+                if reaction.get("content") == ":thumbsup:":
                     return True
         except Exception as error:
-            log.error(f"Unable to get reaction for {comment_id} for PR {pr_number}")
+            log.error(f"Unable to get reaction for {comment_id} for PR {self.config.pr_number}")
+            log.error(error)
         return False