Skip to content

Commit

Permalink
Fix comments
Browse files Browse the repository at this point in the history
  • Loading branch information
elyousfi5 committed Jan 29, 2024
1 parent 43f5544 commit 874c666
Show file tree
Hide file tree
Showing 6 changed files with 149 additions and 113 deletions.
93 changes: 36 additions & 57 deletions agent/api_manager/osv_service_api.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import dataclasses
import json
import logging
from typing import Optional, Iterator
from typing import Iterator, Any

import requests
import tenacity
from ostorlab.agent.kb import kb
from ostorlab.agent.mixins import agent_report_vulnerability_mixin

Expand All @@ -13,15 +14,8 @@
logger = logging.getLogger(__name__)

OSV_ENDPOINT = "https://api.osv.dev/v1/query"

RISK_RATING_MAPPING = {
"POTENTIALLY": agent_report_vulnerability_mixin.RiskRating.POTENTIALLY,
"LOW": agent_report_vulnerability_mixin.RiskRating.LOW,
"MEDIUM": agent_report_vulnerability_mixin.RiskRating.MEDIUM,
"MODERATE": agent_report_vulnerability_mixin.RiskRating.MEDIUM,
"HIGH": agent_report_vulnerability_mixin.RiskRating.HIGH,
"CRITICAL": agent_report_vulnerability_mixin.RiskRating.CRITICAL,
}
NUMBER_RETRIES = 3
WAIT_BETWEEN_RETRIES = 2


@dataclasses.dataclass
Expand All @@ -35,24 +29,25 @@ class VulnData:
cves: list[str]


@tenacity.retry(
stop=tenacity.stop_after_attempt(NUMBER_RETRIES),
wait=tenacity.wait_fixed(WAIT_BETWEEN_RETRIES),
retry=tenacity.retry_if_exception_type(),
retry_error_callback=lambda retry_state: retry_state.outcome.result()
if retry_state.outcome is not None
else None,
)
def query_osv_api(
package_name: str | None, version: str | None, ecosystem: str | None
) -> Optional[str]:
"""Query the OSv API with the specified version, package name, and ecosystem.
) -> dict[str, Any] | None:
"""Query the OSV API with the specified version, package name, and ecosystem.
Args:
version: The version to query.
package_name: The name of the package to query.
ecosystem: The ecosystem of the package e.g., javascript.
Returns:
The API response text if successful, None otherwise.
"""
if version is None:
logger.error("Error: Version must not be None.")
return None
if package_name is None:
logger.error("Error: Package name must not be None.")
return None

data = {
"version": version,
"package": {"name": package_name, "ecosystem": ecosystem},
Expand All @@ -61,28 +56,24 @@ def query_osv_api(
response = requests.post(OSV_ENDPOINT, data=json.dumps(data), headers=headers)

if response.status_code == 200:
return response.text
else:
logger.error(f"Error: Request failed with status code {response.status_code}")
return None
resp: dict[str, Any] = response.json()
return resp

return None


def parse_output(
api_response: Optional[str], api_key: str | None = None
api_response: dict[str, Any], api_key: str | None = None
) -> list[VulnData]:
"""Parse the OSv API response to extract vulnerabilities.
"""Parse the OSV API response to extract vulnerabilities.
Args:
api_response: The API response text.
api_response: The API response json.
api_key: The API key.
Returns:
Parsed output.
"""
if api_response is None:
logger.error("Error: API response must not be None.")
return []

try:
response_data = json.loads(api_response)
vulnerabilities = response_data.get("vulns", [])
vulnerabilities = api_response.get("vulns", [])

parsed_vulns = []
for vulnerability in vulnerabilities:
Expand All @@ -103,16 +94,13 @@ def parse_output(
summary = vulnerability.get("summary", "")
fixed_version = _get_fixed_version(vulnerability.get("affected"))
cvss_v3_vector = _get_cvss_v3_vector(vulnerability.get("severity"))
references = []
for reference in vulnerability.get("references"):
references.append(reference.get("url"))
vuln = VulnData(
risk=risk,
description=description,
summary=summary,
fixed_version=fixed_version,
cvss_v3_vector=cvss_v3_vector,
references=vulnerability.get("references"),
references=vulnerability.get("references", {}),
cves=filtered_cves,
)
parsed_vulns.append(vuln)
Expand All @@ -125,7 +113,7 @@ def parse_output(


def construct_vuln(
parsed_vulns: list[VulnData], package_name: str | None, package_version: str | None
parsed_vulns: list[VulnData], package_name: str, package_version: str
) -> Iterator[osv_output_handler.Vulnerability]:
"""Construct Vulneravilities from the parse output.
Args:
Expand Down Expand Up @@ -160,35 +148,26 @@ def construct_vuln(
recommendation=recommendation,
),
technical_detail=f"{vuln.description} \n#### CVEs:\n {', '.join(vuln.cves)}",
risk_rating=RISK_RATING_MAPPING[vuln.risk],
risk_rating=agent_report_vulnerability_mixin.RiskRating[vuln.risk],
)


def _get_fixed_version(
affected_data: list[dict[str, list[dict[str, list[dict[str, str]]]]]],
affected_data: list[dict[str, Any]],
) -> str:
fixed_version = ""
if affected_data is not None:
try:
ranges_data: list[dict[str, list[dict[str, str]]]] = affected_data[0].get(
"ranges", []
)
if ranges_data:
events_data = ranges_data[0].get("events", [])
if len(events_data) > 1:
fixed_version = events_data[1].get("fixed", "")
except IndexError:
logger.warning("Can't get the fixed version.")
ranges_data: list[dict[str, Any]] = affected_data[0].get("ranges", [])
if ranges_data is not None and len(ranges_data) > 0:
events_data = ranges_data[0].get("events", [])
if len(events_data) > 1:
fixed_version = events_data[1].get("fixed", "")

return fixed_version


def _get_cvss_v3_vector(severity_data: list[dict[str, str]]) -> str:
cvss_v3_vector = ""
if severity_data:
try:
cvss_data = severity_data[0].get("score", "")
cvss_v3_vector = cvss_data if isinstance(cvss_data, str) else ""
except IndexError:
logger.warning("Can't get the cvss v3 vector.")
return cvss_v3_vector
if severity_data is not None and len(severity_data) > 0:
return severity_data[0].get("score", "")

return ""
98 changes: 60 additions & 38 deletions agent/osv_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,14 @@
"requirements.txt",
"yarn.lock",
]
OSV_ECOSYSTEM_MAPPING = {
"JAVASCRIPT_LIBRARY": "npm",
"JAVA_LIBRARY": "Maven",
"FLUTTER_FRAMEWORK": "Pub",
"CORDOVA_LIBRARY": "npm",
"DOTNET_FRAMEWORK": "NuGet",
"IOS_FRAMEWORK": "SwiftURL",
}

logging.basicConfig(
format="%(message)s",
Expand Down Expand Up @@ -120,50 +128,14 @@ def process(self, message: m.Message) -> None:
"""
logger.info("processing message of selector : %s", message.selector)
if message.selector == "v3.asset.file":
content = _get_content(message)
if content is None or content == b"":
logger.warning("Message file content is empty.")
return
for file_name in SUPPORTED_OSV_FILE_NAMES:
scan_results = _run_osv(file_name, content)
if scan_results is not None:
logger.info("Found valid name for file: %s", file_name)
self._emit_results(scan_results)
break
self._process_asset_file(message)

elif message.selector in [
"v3.fingerprint.file.android.library",
"v3.fingerprint.file.ios.library",
"v3.fingerprint.file.library",
]:
package_name = message.data.get("library_name")
package_version = message.data.get("library_version")
package_type = message.data.get("library_type")

api_result = osv_service_api.query_osv_api(
package_name=package_name,
version=package_version,
ecosystem=package_type,
)
if api_result is None:
return

parsed_osv_output = osv_service_api.parse_output(api_result, self.api_key)

if len(parsed_osv_output) == 0:
return

vulnz = osv_service_api.construct_vuln(
parsed_osv_output, package_name, package_version
)

for vuln in vulnz:
logger.info("Reporting vulnerability.")
self.report_vulnerability(
entry=vuln.entry,
technical_detail=vuln.technical_detail,
risk_rating=vuln.risk_rating,
)
self._process_fingerprint_file(message)

def _emit_results(self, output: str) -> None:
"""Parses results and emits vulnerabilities."""
Expand All @@ -176,6 +148,56 @@ def _emit_results(self, output: str) -> None:
risk_rating=vuln.risk_rating,
)

def _process_asset_file(self, message: m.Message) -> None:
"""Process message of type v3.asset.file."""
content = _get_content(message)
if content is None or content == b"":
logger.warning("Message file content is empty.")
return
for file_name in SUPPORTED_OSV_FILE_NAMES:
scan_results = _run_osv(file_name, content)
if scan_results is not None:
logger.info("Found valid name for file: %s", file_name)
self._emit_results(scan_results)
break

def _process_fingerprint_file(self, message: m.Message) -> None:
"""Process message of type v3.fingerprint.file."""
package_name = message.data.get("library_name")
package_version = message.data.get("library_version")
package_type = message.data.get("library_type")

if package_version is None:
logger.error("Error: Version must not be None.")
return None
if package_name is None:
logger.error("Error: Package name must not be None.")
return None

api_result = osv_service_api.query_osv_api(
package_name=package_name,
version=package_version,
ecosystem=OSV_ECOSYSTEM_MAPPING.get(str(package_type), ""),
)
if api_result is None:
return None

parsed_osv_output = osv_service_api.parse_output(api_result, self.api_key)

if len(parsed_osv_output) == 0:
return None

vulnz = osv_service_api.construct_vuln(
parsed_osv_output, package_name, package_version
)

for vuln in vulnz:
self.report_vulnerability(
entry=vuln.entry,
technical_detail=vuln.technical_detail,
risk_rating=vuln.risk_rating,
)


def _is_valid_osv_result(results: str | None) -> bool:
"""Check if the results are valid."""
Expand Down
12 changes: 9 additions & 3 deletions tests/api_manager/osv_service_api_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
"""Unit tests for OSV service api."""
from typing import Any

from agent.api_manager import osv_service_api


Expand All @@ -9,12 +11,16 @@ def testQueryOSVOutput_withPackage_returnListOfVulnerabilities() -> None:
)

assert osv_output is not None
assert "vulns" in osv_output
assert "Jinja2 sandbox escape via string formatting" in osv_output
assert isinstance(osv_output, dict) is True
assert len(osv_output["vulns"]) == 11
assert any(
"Jinja2 sandbox escape via string formatting" in vuln["summary"]
for vuln in osv_output["vulns"]
)


def testPasrseOSVOutput_withValidResponse_returnListOfVulnzData(
osv_api_output: str,
osv_api_output: dict[str, Any],
) -> None:
"""Parse the output of osv api call."""
cves_data = osv_service_api.parse_output(osv_api_output)
Expand Down
17 changes: 8 additions & 9 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,12 +191,11 @@ def output_file(tmp_path: pathlib.Path) -> str:


@pytest.fixture(name="osv_api_output")
def osv_api_output() -> str:
"""Return a temporary file and write JSON data to it"""
with open(
f"{pathlib.Path(__file__).parent.parent}/tests/files/osv_api_output.json",
"r",
encoding="utf-8",
) as of:
data = of.read()
return data
def osv_api_output() -> dict[str, Any]:
"""Read and return the OSV API output from a file as a dict."""
file_path = (
f"{pathlib.Path(__file__).parent.parent}/tests/files/osv_api_output.json"
)
data = pathlib.Path(file_path).read_text(encoding="utf-8")
json_data: dict[str, Any] = json.loads(data)
return json_data
9 changes: 3 additions & 6 deletions tests/osv_agent_test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Unittests for OSV agent."""
import subprocess
from typing import Union, Callable
from typing import Union, Callable, Any

import requests_mock as rq_mock
from ostorlab.agent.message import message
Expand Down Expand Up @@ -184,13 +184,10 @@ def testAgentOSV_whenFingerprintMessage_processMessage(
test_agent: osv_agent.OSVAgent,
agent_mock: list[message.Message],
agent_persist_mock: dict[Union[str, bytes], Union[str, bytes]],
scan_message_file: message.Message,
osv_output_as_dict: dict[str, str],
fake_osv_output: str,
mocker: plugin.MockerFixture,
osv_api_output: str,
osv_api_output: dict[str, Any],
) -> None:
"""Unittest for the full life cycle of the agent:
"""Unit test for the full life cycle of the agent:
case where the osv scan a package.
"""
mocker.patch(
Expand Down
Loading

0 comments on commit 874c666

Please sign in to comment.