Skip to content

Commit

Permalink
Fix issues identified by mypy (#265)
Browse files Browse the repository at this point in the history
  • Loading branch information
nightlark committed Oct 23, 2024
1 parent 11ce1cc commit b3597fd
Show file tree
Hide file tree
Showing 17 changed files with 131 additions and 86 deletions.
4 changes: 2 additions & 2 deletions surfactant/cmd/config.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Optional
from typing import Any, List, Optional

import click

Expand Down Expand Up @@ -31,7 +31,7 @@ def config(key: str, values: Optional[List[str]]):
else:
# Set the configuration value
# Convert 'true' and 'false' strings to boolean
converted_values = []
converted_values: List[Any] = []
for value in values:
if value.lower() == "true":
converted_values.append(True)
Expand Down
97 changes: 67 additions & 30 deletions surfactant/cmd/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import pathlib
import queue
import re
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union

import click
from loguru import logger
Expand Down Expand Up @@ -51,11 +51,11 @@ def get_software_entry(
else:
sw_entry.containerPath = [re.sub("^" + root_path, container_uuid + "/", filepath)]
sw_entry.recordedInstitution = user_institution_name
sw_children = []
sw_children: List[Software] = []

# for unsupported file types, details are just empty; this is the case for archive files (e.g. zip, tar, iso)
# as well as Intel HEX or Motorola S-record files
extracted_info_results = pluginmanager.hook.extract_file_info(
extracted_info_results: List[object] = pluginmanager.hook.extract_file_info(
sbom=parent_sbom,
software=sw_entry,
filename=filepath,
Expand All @@ -66,8 +66,19 @@ def get_software_entry(
)
# add metadata extracted from the file, and set SBOM fields if metadata has relevant info
for file_details in extracted_info_results:
# None as details doesn't add any useful info...
if file_details is None:
continue

# ensure metadata exists for the software entry
if sw_entry.metadata is None:
sw_entry.metadata = []
sw_entry.metadata.append(file_details)

# before checking for keys, make sure the file details object is a dictionary
if not isinstance(file_details, Dict):
continue

# common case: a Windows PE file has these details under FileInfo; otherwise the fallback default value is fine
if "FileInfo" in file_details:
fi = file_details["FileInfo"]
Expand All @@ -90,6 +101,9 @@ def get_software_entry(
if "revision_number" in file_details["ole"]:
sw_entry.version = file_details["ole"]["revision_number"]
if "author" in file_details["ole"]:
# ensure the vendor list has been created
if sw_entry.vendor is None:
sw_entry.vendor = []
sw_entry.vendor.append(file_details["ole"]["author"])
if "comments" in file_details["ole"]:
sw_entry.comments = file_details["ole"]["comments"]
Expand Down Expand Up @@ -149,9 +163,9 @@ def determine_install_prefix(
Optional[str]: The install prefix to use, or None if an install path shouldn't be listed.
"""
install_prefix = None
if entry.installPrefix or entry.installPrefix == "":
if entry and (entry.installPrefix or entry.installPrefix == ""):
install_prefix = entry.installPrefix
elif not skip_extract_path:
elif not skip_extract_path and extract_path is not None:
# pathlib doesn't include the trailing slash
epath = pathlib.Path(extract_path)
if epath.is_file():
Expand Down Expand Up @@ -252,16 +266,16 @@ def get_default_from_config(option: str, fallback: Optional[Any] = None) -> Any:
# Disable positional argument linter check -- could make keyword-only, but then defaults need to be set
# pylint: disable-next=too-many-positional-arguments
def sbom(
config_file,
sbom_outfile,
input_sbom,
skip_gather,
skip_relationships,
skip_install_path,
recorded_institution,
output_format,
input_format,
include_all_files,
config_file: str,
sbom_outfile: click.File,
input_sbom: click.File,
skip_gather: bool,
skip_relationships: bool,
skip_install_path: bool,
recorded_institution: str,
output_format: str,
input_format: str,
include_all_files: bool,
):
"""Generate a sbom configured in CONFIG_FILE and output to SBOM_OUTPUT.
Expand Down Expand Up @@ -289,11 +303,13 @@ def sbom(
if not validate_config(config):
return

context = queue.Queue()
context: queue.Queue[ContextEntry] = queue.Queue()

for entry in config:
context.put(ContextEntry(**entry))
for cfg_entry in config:
context.put(ContextEntry(**cfg_entry))

# define the new_sbom variable type
new_sbom: SBOM
if not input_sbom:
new_sbom = SBOM()
else:
Expand Down Expand Up @@ -322,7 +338,11 @@ def sbom(
user_institution_name=recorded_institution,
)
archive_entry = new_sbom.find_software(parent_entry.sha256)
if Software.check_for_hash_collision(archive_entry, parent_entry):
if (
archive_entry
and parent_entry
and Software.check_for_hash_collision(archive_entry, parent_entry)
):
logger.warning(
f"Hash collision between {archive_entry.name} and {parent_entry.name}; unexpected results may occur"
)
Expand Down Expand Up @@ -351,17 +371,20 @@ def sbom(
)
entry.installPrefix = entry.installPrefix.replace("\\", "\\\\")

for epath in entry.extractPaths:
for epath_str in entry.extractPaths:
# convert to pathlib.Path, ensures trailing "/" won't be present and some more consistent path formatting
epath = pathlib.Path(epath)
epath = pathlib.Path(epath_str)
install_prefix = determine_install_prefix(
entry, epath, skip_extract_path=skip_install_path
)
logger.trace("Extracted Path: " + epath.as_posix())

# variable used to track software entries to add to the SBOM
entries: List[Software]

# handle individual file case, since os.walk doesn't
if epath.is_file():
entries: List[Software] = []
entries = []
filepath = epath.as_posix()
# breakpoint()
try:
Expand Down Expand Up @@ -404,11 +427,11 @@ def sbom(
)
dir_symlinks.append((install_source, install_dest))

entries: List[Software] = []
for f in files:
entries = []
for file in files:
# os.path.join will insert an OS specific separator between cdir and the file name
# need to make sure that separator is a / and not a \ on windows
filepath = pathlib.Path(cdir, f).as_posix()
filepath = pathlib.Path(cdir, file).as_posix()
# TODO: add CI tests for generating SBOMs in scenarios with symlinks... (and just generally more CI tests overall...)
# Record symlink details but don't run info extractors on them
if os.path.islink(filepath):
Expand Down Expand Up @@ -488,6 +511,14 @@ def sbom(

# Add symlinks to install paths and file names
for software in new_sbom.software:
# ensure fileName, installPath, and metadata lists for the software entry have been created
# for a user supplied input SBOM, there are no guarantees
if software.fileName is None:
software.fileName = []
if software.installPath is None:
software.installPath = []
if software.metadata is None:
software.metadata = []
if software.sha256 in filename_symlinks:
filename_symlinks_added = []
for filename in filename_symlinks[software.sha256]:
Expand All @@ -511,6 +542,8 @@ def sbom(
for software in new_sbom.software:
# NOTE: this probably doesn't actually add any containerPath symlinks
for paths in (software.containerPath, software.installPath):
if paths is None:
continue
paths_to_add = []
for path in paths:
for link_source, link_dest in dir_symlinks:
Expand All @@ -521,10 +554,14 @@ def sbom(
paths_to_add.append(path.replace(link_dest, link_source, 1))
if paths_to_add:
found_md_installpathsymlinks = False
for md in software.metadata:
if "installPathSymlinks" in md:
found_md_installpathsymlinks = True
md["installPathSymlinks"] += paths_to_add
# make sure software.metadata list has been initialized
if software.metadata is None:
software.metadata = []
if isinstance(software.metadata, Iterable):
for md in software.metadata:
if isinstance(md, Dict) and "installPathSymlinks" in md:
found_md_installpathsymlinks = True
md["installPathSymlinks"] += paths_to_add
if not found_md_installpathsymlinks:
software.metadata.append({"installPathSymlinks": paths_to_add})
paths += paths_to_add
Expand All @@ -542,7 +579,7 @@ def sbom(


def resolve_link(
path: str, cur_dir: str, extract_dir: str, install_prefix: str = None
path: str, cur_dir: str, extract_dir: str, install_prefix: Optional[str] = None
) -> Union[str, None]:
assert cur_dir.startswith(extract_dir)
# Links seen before
Expand Down
3 changes: 2 additions & 1 deletion surfactant/cmd/merge.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import json
import uuid as uuid_module
from collections import deque
from typing import Dict, List

import click
from loguru import logger
Expand Down Expand Up @@ -83,7 +84,7 @@ def construct_relationship_graph(sbom: SBOM):
sbom (SBOM): The sbom to generate relationship graph from.
"""
# construct a graph for adding a system relationship to all root software entries
rel_graph = {}
rel_graph: Dict[str, List[str]] = {}
# add all UUIDs as nodes in the graph
for system in sbom.systems:
rel_graph[system.UUID] = []
Expand Down
5 changes: 3 additions & 2 deletions surfactant/configmanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import platform
from pathlib import Path
from threading import Lock
from typing import Any, Optional, Union
from typing import Any, Dict, Optional, Union

import tomlkit

Expand All @@ -19,7 +19,8 @@ class ConfigManager:
config_file_path (Path): The path to the configuration file.
"""

_instances = {}
_initialized: bool = False
_instances: Dict[str, "ConfigManager"] = {}
_lock = Lock()

def __new__(
Expand Down
2 changes: 1 addition & 1 deletion surfactant/infoextractors/elf_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def extract_file_info(sbom: SBOM, software: Software, filename: str, filetype: s
}


def extract_elf_info(filename):
def extract_elf_info(filename: str) -> object:
with open(filename, "rb") as f:
try:
elf = ELFFile(f)
Expand Down
2 changes: 1 addition & 1 deletion surfactant/infoextractors/java_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def handle_java_class(info: Dict[str, Any], class_info: javatools.JavaClassInfo)


def extract_java_info(filename: str, filetype: str) -> object:
info = {"javaClasses": {}}
info: Dict[str, Any] = {"javaClasses": {}}
if filetype in ("JAR", "EAR", "WAR"):
with javatools.jarinfo.JarInfo(filename) as jarinfo:
for class_ in jarinfo.get_classes():
Expand Down
2 changes: 1 addition & 1 deletion surfactant/infoextractors/js_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def extract_file_info(sbom: SBOM, software: Software, filename: str, filetype: s
return extract_js_info(filename)


def extract_js_info(filename):
def extract_js_info(filename: str) -> object:
js_info: Dict[str, Any] = {"jsLibraries": []}
js_lib_file = pathlib.Path(__file__).parent / "js_library_patterns.json"

Expand Down
2 changes: 1 addition & 1 deletion surfactant/infoextractors/mach_o_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def extract_mach_o_info(filename: str) -> object:
except OSError:
return {}

file_details: Dict[str:Any] = {"OS": "MacOS", "numBinaries": binaries.size, "binaries": []}
file_details: Dict[str, Any] = {"OS": "MacOS", "numBinaries": binaries.size, "binaries": []}

# Iterate over all binaries in the FAT binary
for binary in binaries:
Expand Down
2 changes: 1 addition & 1 deletion surfactant/infoextractors/ole_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def extract_file_info(sbom: SBOM, software: Software, filename: str, filetype: s
return extract_ole_info(filename)


def extract_ole_info(filename):
def extract_ole_info(filename: str) -> object:
file_details: Dict[str, Any] = {}

ole = olefile.OleFileIO(filename)
Expand Down
21 changes: 10 additions & 11 deletions surfactant/infoextractors/pe_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

import pathlib
import re
from typing import Any, Dict
from typing import Any, Dict, List, Optional

import defusedxml.ElementTree
import dnfile
Expand Down Expand Up @@ -79,8 +79,7 @@ def extract_file_info(sbom: SBOM, software: Software, filename: str, filetype: s
}


def extract_pe_info(filename):
dnfile.fast_load = False
def extract_pe_info(filename: str) -> object:
try:
pe = dnfile.dnPE(filename, fast_load=False)
except (OSError, dnfile.PEFormatError):
Expand Down Expand Up @@ -169,7 +168,7 @@ def extract_pe_info(filename):
assembly_refs.append(get_assemblyref_info(ar_info))
file_details["dotnetAssemblyRef"] = assembly_refs
if implmap_info := getattr(dnet_mdtables, "ImplMap", None):
imp_modules = []
imp_modules: List[Dict[str, Any]] = []
for im_info in implmap_info:
insert_implmap_info(im_info, imp_modules)
file_details["dotnetImplMap"] = imp_modules
Expand All @@ -190,7 +189,7 @@ def extract_pe_info(filename):
return file_details


def add_core_assembly_info(asm_dict, asm_info):
def add_core_assembly_info(asm_dict: Dict[str, Any], asm_info):
# REFERENCE: https://github.com/malwarefrank/dnfile/blob/096de1b3/src/dnfile/stream.py#L36-L39
# HeapItemString value will be decoded string, or None if there was a UnicodeDecodeError
asm_dict["Name"] = asm_info.Name.value if asm_info.Name.value else asm_info.raw_data.hex()
Expand Down Expand Up @@ -233,7 +232,7 @@ def add_assembly_flags_info(asm_dict, asm_info):
}


def get_assembly_info(asm_info):
def get_assembly_info(asm_info) -> Dict[str, Any]:
asm: Dict[str, Any] = {}
add_core_assembly_info(asm, asm_info)
# REFERENCE: https://github.com/malwarefrank/dnfile/blob/fcccdaf/src/dnfile/enums.py#L851-L863
Expand All @@ -243,7 +242,7 @@ def get_assembly_info(asm_info):
return asm


def get_assemblyref_info(asmref_info):
def get_assemblyref_info(asmref_info) -> Dict[str, Any]:
asmref: Dict[str, Any] = {}
add_core_assembly_info(asmref, asmref_info)
# REFERENCE: https://github.com/malwarefrank/dnfile/blob/096de1b3/src/dnfile/stream.py#L62-L66
Expand All @@ -254,7 +253,7 @@ def get_assemblyref_info(asmref_info):
return asmref


def insert_implmap_info(im_info, imp_modules):
def insert_implmap_info(im_info, imp_modules: List[Dict[str, Any]]):
# REFERENCE: https://github.com/malwarefrank/dnfile/blob/096de1b3/src/dnfile/stream.py#L36-L39
# HeapItemString value will be decoded string, or None if there was a UnicodeDecodeError
dllName = (
Expand Down Expand Up @@ -284,7 +283,7 @@ def get_xmlns_and_tag(uri):

# check for manifest file on Windows (note: could also be a resource contained within an exe/dll)
# return any info that could be useful for establishing "Uses" relationships later
def get_windows_manifest_info(filename):
def get_windows_manifest_info(filename: str) -> Optional[Dict[str, Any]]:
binary_filepath = pathlib.Path(filename)
manifest_filepath = binary_filepath.with_suffix(binary_filepath.suffix + ".manifest")
if manifest_filepath.exists():
Expand Down Expand Up @@ -428,7 +427,7 @@ def get_assemblyBinding_info(ab_et, config_filepath=""):

# DLL redirection summary: redirection file with name_of_exe.local (contents are ignored) makes a check for mydll.dll happen in the application directory first,
# regardless of what the full path specified for LoadLibrary or LoadLibraryEx is (if no dll found in local directory, uses the typical search order)
def check_windows_dll_redirection_local(filename):
def check_windows_dll_redirection_local(filename: str):
binary_filepath = pathlib.Path(filename)
config_filepath = binary_filepath.with_suffix(binary_filepath.suffix + ".local")
return config_filepath.exists()
Expand All @@ -437,7 +436,7 @@ def check_windows_dll_redirection_local(filename):
# check for an application configuration file and return (potentially) useful information
# https://learn.microsoft.com/en-us/dotnet/framework/deployment/how-the-runtime-locates-assemblies#application-configuration-file
# https://learn.microsoft.com/en-us/windows/win32/sbscs/application-configuration-files
def get_windows_application_config_info(filename):
def get_windows_application_config_info(filename: str):
binary_filepath = pathlib.Path(filename)
config_filepath = binary_filepath.with_suffix(binary_filepath.suffix + ".config")
if config_filepath.exists():
Expand Down
Loading

0 comments on commit b3597fd

Please sign in to comment.