Skip to content

Commit

Permalink
add CLI option for json or custom reporting from toml
Browse files Browse the repository at this point in the history
  • Loading branch information
swashko committed Jan 13, 2024
1 parent 6c75687 commit 159cbcb
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 18 deletions.
51 changes: 42 additions & 9 deletions modelscan/cli.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,21 @@
import logging
import sys
import os
import importlib
from pathlib import Path
from typing import Optional
from typing import Optional, Dict, Any
from tomlkit import parse

import click

from modelscan.modelscan import ModelScan
from modelscan.reports import ConsoleReport
from modelscan.reports import Report
from modelscan._version import __version__
from modelscan.settings import SettingsUtils, DEFAULT_SETTINGS
from modelscan.settings import (
SettingsUtils,
DEFAULT_SETTINGS,
AVAILABLE_REPORTING_MODULES,
)
from modelscan.tools.cli_utils import DefaultGroup

logger = logging.getLogger("modelscan")
Expand Down Expand Up @@ -69,6 +74,20 @@ def cli() -> None:
type=click.Path(exists=True, dir_okay=False),
help="Specify a settings file to use for the scan. Defaults to ./modelscan-settings.toml.",
)
@click.option(
"-f",
"--format",
type=click.Choice(["console", "json", "custom"]),
default="console",
help="Format of the output. Options are console or json (default: console)",
)
@click.option(
"-o",
"--output-file",
type=click.Path(),
default=None,
help="Optional json reporting output file",
)
@cli.command(
help="[Default] Scan a model file or diretory for ability to execute suspicious actions. "
) # type: ignore
Expand All @@ -79,6 +98,8 @@ def scan(
path: Optional[str],
show_skipped: bool,
settings_file: Optional[str],
format: str,
output_file: Path,
) -> int:
logger.setLevel(logging.INFO)
logger.addHandler(logging.StreamHandler(stream=sys.stdout))
Expand Down Expand Up @@ -112,13 +133,25 @@ def scan(
else:
raise click.UsageError("Command line must include a path")

# get reporting module
report_settings: Dict[str, Any] = {}
if format == "custom":
reporting_module = settings["reporting"]["module"] # type: ignore[index]
report_settings = settings["reporting"]["settings"] # type: ignore[index]
else:
reporting_module = AVAILABLE_REPORTING_MODULES[format]

report_settings["show_skipped"] = show_skipped
report_settings["output_file"] = output_file

try:
(modulename, classname) = reporting_module.rsplit(".", 1)
imported_module = importlib.import_module(name=modulename, package=classname)

report_class: Report = getattr(imported_module, classname)
report_class.generate(scan=modelscan, settings=report_settings)

# ConsoleReport.generate(
# scan=modelscan,
# show_skipped=show_skipped,
# settings=settings["reporting"]
# )
except Exception as e:
logger.error(f"Error generating report using {reporting_module}: {e}")

# exit code 3 if no supported files were passed
if not modelscan.scanned:
Expand Down
7 changes: 2 additions & 5 deletions modelscan/reports.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ def __init__(self) -> None:
@staticmethod
def generate(
scan: ModelScan,
show_skipped: bool = False,
settings: Dict[str, Any] = {},
) -> Optional[str]:
"""
Expand All @@ -41,7 +40,6 @@ class ConsoleReport(Report):
@staticmethod
def generate(
scan: ModelScan,
show_skipped: bool = False,
settings: Dict[str, Any] = {},
) -> None:
issues_by_severity = scan.issues.group_by_severity()
Expand Down Expand Up @@ -77,7 +75,7 @@ def generate(
print(
f"\nTotal skipped: {len(scan.skipped)} - run with --show-skipped to see the full list."
)
if show_skipped:
if settings["show_skipped"]:
print(f"\nSkipped files list:\n")
for file_name in scan.skipped:
print(str(file_name))
Expand All @@ -87,11 +85,10 @@ class JSONReport(Report):
@staticmethod
def generate(
scan: ModelScan,
show_skipped: bool = False,
settings: Dict[str, Any] = {},
) -> None:
report: Dict[str, Any] = scan._generate_results()
if not show_skipped:
if not settings["show_skipped"]:
del report["skipped"]

print(json.dumps(report))
Expand Down
9 changes: 5 additions & 4 deletions modelscan/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@

from typing import Any

AVAILABLE_REPORTING_MODULES = {
"console": "modelscan.reports.ConsoleReport",
"json": "modelscan.reports.JSONReport",
}

DEFAULT_SETTINGS = {
"supported_zip_extensions": [".zip", ".npz"],
Expand Down Expand Up @@ -86,11 +90,8 @@
},
"MEDIUM": {},
"LOW": {},
"reporting_module": {
"module": "modelscan.reports.ConsoleReport",
"settings": {},
},
},
"reporting": {"module": "modelscan.reports.ConsoleReport", "settings": {}},
}


Expand Down

0 comments on commit 159cbcb

Please sign in to comment.