Skip to content

Commit

Permalink
Change: Introduce more flexible NVDResults class
Browse files Browse the repository at this point in the history
Before all NVD API classes methods returned an async iterator. This
didn't allow much control of what the user actually wants and how the
requests are issued. To improve the situation a new NVDResults class is
returned which itself is an async iterator so that the previous API is
kept compatible.

But additionally the NVDResults instance allows to get the plain JSON
data, the number of available results and also to iterate over chunks of
results (which the NVD API is always returning).

Most important improvement the NVDResults instance keeps the state. That
means if an http error occurs it is possible to request the same data
again. With the old API the requests need to start from the beginning if
something did go wrong. For example if we downloaded already 100k CVEs
and a http error was raised we needed to start from CVE number 1 again.
With the new implementation we can just continue with the last request
again.
  • Loading branch information
bjoernricks committed Nov 29, 2023
1 parent 1ebde9f commit 0b6b52b
Show file tree
Hide file tree
Showing 6 changed files with 721 additions and 135 deletions.
255 changes: 252 additions & 3 deletions pontos/nvd/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,26 @@
from abc import ABC
from datetime import datetime, timezone
from types import TracebackType
from typing import Any, Dict, Optional, Type, Union
from typing import (
Any,
AsyncIterable,
AsyncIterator,
Awaitable,
Callable,
Dict,
Generator,
Generic,
Iterator,
Optional,
Sequence,
Type,
TypeVar,
Union,
)

from httpx import AsyncClient, Response, Timeout
from httpx import URL, AsyncClient, Response, Timeout

from pontos.errors import PontosError
from pontos.helper import snake_case

SLEEP_TIMEOUT = 30.0 # in seconds
Expand Down Expand Up @@ -78,6 +94,239 @@ def convert_camel_case(dct: Dict[str, Any]) -> Dict[str, Any]:
return converted


class NoMoreResults(PontosError):
"""
Raised if the NVD API has no more results to consume
"""


class InvalidState(PontosError):
"""
Raised if the state of the NVD API is invalid
"""


T = TypeVar("T")

result_iterator_func = Callable[[JSON], Iterator[T]]


class NVDResults(Generic[T], AsyncIterable[T], Awaitable["NVDResults"]):
"""
A generic object for accessing the results of a NVD API response
It implements the pagination and will issue requests against the NVD API.
"""

def __init__(
self,
api: "NVDApi",
params: Params,
result_func: result_iterator_func,
*,
request_results: Optional[int] = None,
results_per_page: Optional[int] = None,
start_index: int = 0,
) -> None:
self._api = api
self._params = params
self._url: Optional[URL] = None

self._data: Optional[JSON] = None
self._it: Optional[Iterator[T]] = None
self._total_results: Optional[int] = None
self._downloaded_results: int = 0

self._start_index = start_index
self._request_results = request_results
self._results_per_page = results_per_page

self._current_index = start_index
self._current_request_results = request_results
self._current_results_per_page = results_per_page

self._result_func = result_func

async def chunks(self) -> AsyncIterator[Sequence[T]]:
"""
Return the results in chunks
The size of the chunks is defined by results_per_page.
Examples:
.. code-block:: python
nvd_results: NVDResults = ...
async for results in nvd_results.chunks():
for result in results:
print(result)
"""
while True:
try:
if self._it:
yield list(self._it)
await self._next_iterator()
except NoMoreResults:
return

Check warning on line 171 in pontos/nvd/api.py

View check run for this annotation

Codecov / codecov/patch

pontos/nvd/api.py#L171

Added line #L171 was not covered by tests

async def items(self) -> AsyncIterator[T]:
"""
Return the results of the NVD API response
Examples:
.. code-block:: python
nvd_results: NVDResults = ...
async for result in nvd_results.items():
print(result)
"""
while True:
try:
if self._it:
for result in self._it:
yield result
await self._next_iterator()
except NoMoreResults:
return

async def json(self) -> Optional[JSON]:
"""
Return the result from the NVD API request as JSON
Examples:
.. code-block:: python
nvd_results: NVDResults = ...
while data := await nvd_results.json():
print(data)
Returns:
The response data as JSON or None if the response is exhausted.
"""
try:
if not self._data:
await self._next_iterator()

data = self._data
self._data = None
return data
except NoMoreResults:
return None

def __len__(self) -> int:
"""
Get the number of available result items for a NVD API request
Examples:
.. code-block:: python
nvd_results: NVDResults = ...
total_results = len(nvd_results) # None because it hasn't been awaited yet
json = await nvd_results.json() # request the plain JSON data
total_results = len(nvd_results) # contains the total number of results now
nvd_results: NVDResults = ...
total_results = len(nvd_results) # None because it hasn't been awaited yet
async for result in nvd_results:
print(result)
total_results = len(nvd_results) # contains the total number of results now
Returns:
The total number of available results if the NVDResults has been awaited
"""
if self._total_results is None:
raise InvalidState(
f"{self.__class__.__name__} has not been awaited yet."
)
return self._total_results

def __aiter__(self) -> AsyncIterator[T]:
"""
Return the results of the NVD API response
Same as the items() method. @see items()
Examples:
.. code-block:: python
nvd_results: NVDResults = ...
async for result in nvd_results:
print(result)
"""
return self.items()

def __await__(self) -> Generator[Any, None, "NVDResults"]:
"""
Request the next results from the NVD API
Examples:
.. code-block:: python
nvd_results: NVDResults = ...
print(len(nvd_results)) # None, because no request has been send yet
await nvd_results # creates a request to the NVD API
print(len(nvd_results))
Returns:
The response data as JSON or None if the response is exhausted.
"""

return self._next_iterator().__await__()

async def _load_next_data(self) -> None:
if (
not self._current_request_results
or self._downloaded_results < self._current_request_results
):
params = self._params
params["startIndex"] = self._current_index

if self._current_results_per_page is not None:
params["resultsPerPage"] = self._current_results_per_page

response = await self._api._get(params=params)
response.raise_for_status()

self._url = response.url
data: JSON = response.json(object_hook=convert_camel_case)

self._data = data
self._current_results_per_page = int(data["results_per_page"]) # type: ignore
self._total_results = int(data["total_results"]) # type: ignore
self._current_index += self._current_results_per_page
self._downloaded_results += self._current_results_per_page

if not self._current_request_results:
self._current_request_results = self._total_results

if (
self._request_results
and self._downloaded_results + self._current_results_per_page
> self._request_results
):
# avoid downloading more results then requested
self._current_results_per_page = (

Check warning on line 311 in pontos/nvd/api.py

View check run for this annotation

Codecov / codecov/patch

pontos/nvd/api.py#L311

Added line #L311 was not covered by tests
self._request_results - self._downloaded_results
)

else:
raise NoMoreResults()

async def _get_next_iterator(self) -> Iterator[T]:
await self._load_next_data()
return self._result_func(self._data) # type: ignore

async def _next_iterator(self) -> "NVDResults":
self._it = await self._get_next_iterator()
return self

def __repr__(self) -> str:
return f'<{self.__class__.__name__} url="{self._url}" total_results={self._total_results} start_index={self._start_index} current_index={self._current_index} results_per_page={self._results_per_page}>'

Check warning on line 327 in pontos/nvd/api.py

View check run for this annotation

Codecov / codecov/patch

pontos/nvd/api.py#L327

Added line #L327 was not covered by tests


class NVDApi(ABC):
"""
Abstract base class for querying the NIST NVD API.
Expand Down Expand Up @@ -155,7 +404,7 @@ async def _get(
params: Optional[Params] = None,
) -> Response:
"""
A request against the NIST NVD CVE REST API.
A request against the NIST NVD REST API.
"""
headers = self._request_headers()

Expand Down
Loading

0 comments on commit 0b6b52b

Please sign in to comment.