From 159cbcb7511fd1c0381d3c682551afbaa0bf6e81 Mon Sep 17 00:00:00 2001 From: Sam Washko Date: Fri, 12 Jan 2024 16:05:46 -0800 Subject: [PATCH] add CLI option for json or custom reporting from toml --- modelscan/cli.py | 51 +++++++++++++++++++++++++++++++++++-------- modelscan/reports.py | 7 ++---- modelscan/settings.py | 9 ++++---- 3 files changed, 49 insertions(+), 18 deletions(-) diff --git a/modelscan/cli.py b/modelscan/cli.py index b5bd5ff..1ffd394 100644 --- a/modelscan/cli.py +++ b/modelscan/cli.py @@ -1,16 +1,21 @@ import logging import sys import os +import importlib from pathlib import Path -from typing import Optional +from typing import Optional, Dict, Any from tomlkit import parse import click from modelscan.modelscan import ModelScan -from modelscan.reports import ConsoleReport +from modelscan.reports import Report from modelscan._version import __version__ -from modelscan.settings import SettingsUtils, DEFAULT_SETTINGS +from modelscan.settings import ( + SettingsUtils, + DEFAULT_SETTINGS, + AVAILABLE_REPORTING_MODULES, +) from modelscan.tools.cli_utils import DefaultGroup logger = logging.getLogger("modelscan") @@ -69,6 +74,20 @@ def cli() -> None: type=click.Path(exists=True, dir_okay=False), help="Specify a settings file to use for the scan. Defaults to ./modelscan-settings.toml.", ) +@click.option( + "-f", + "--format", + type=click.Choice(["console", "json", "custom"]), + default="console", + help="Format of the output. Options are console or json (default: console)", +) +@click.option( + "-o", + "--output-file", + type=click.Path(), + default=None, + help="Optional json reporting output file", +) @cli.command( help="[Default] Scan a model file or diretory for ability to execute suspicious actions. " ) # type: ignore @@ -79,6 +98,8 @@ def scan( path: Optional[str], show_skipped: bool, settings_file: Optional[str], + format: str, + output_file: Path, ) -> int: logger.setLevel(logging.INFO) logger.addHandler(logging.StreamHandler(stream=sys.stdout)) @@ -112,13 +133,25 @@ def scan( else: raise click.UsageError("Command line must include a path") - # get reporting module + report_settings: Dict[str, Any] = {} + if format == "custom": + reporting_module = settings["reporting"]["module"] # type: ignore[index] + report_settings = settings["reporting"]["settings"] # type: ignore[index] + else: + reporting_module = AVAILABLE_REPORTING_MODULES[format] + + report_settings["show_skipped"] = show_skipped + report_settings["output_file"] = output_file + + try: + (modulename, classname) = reporting_module.rsplit(".", 1) + imported_module = importlib.import_module(name=modulename, package=classname) + + report_class: Report = getattr(imported_module, classname) + report_class.generate(scan=modelscan, settings=report_settings) - # ConsoleReport.generate( - # scan=modelscan, - # show_skipped=show_skipped, - # settings=settings["reporting"] - # ) + except Exception as e: + logger.error(f"Error generating report using {reporting_module}: {e}") # exit code 3 if no supported files were passed if not modelscan.scanned: diff --git a/modelscan/reports.py b/modelscan/reports.py index 721e05a..231bf23 100644 --- a/modelscan/reports.py +++ b/modelscan/reports.py @@ -23,7 +23,6 @@ def __init__(self) -> None: @staticmethod def generate( scan: ModelScan, - show_skipped: bool = False, settings: Dict[str, Any] = {}, ) -> Optional[str]: """ @@ -41,7 +40,6 @@ class ConsoleReport(Report): @staticmethod def generate( scan: ModelScan, - show_skipped: bool = False, settings: Dict[str, Any] = {}, ) -> None: issues_by_severity = scan.issues.group_by_severity() @@ -77,7 +75,7 @@ def generate( print( f"\nTotal skipped: {len(scan.skipped)} - run with --show-skipped to see the full list." ) - if show_skipped: + if settings["show_skipped"]: print(f"\nSkipped files list:\n") for file_name in scan.skipped: print(str(file_name)) @@ -87,11 +85,10 @@ class JSONReport(Report): @staticmethod def generate( scan: ModelScan, - show_skipped: bool = False, settings: Dict[str, Any] = {}, ) -> None: report: Dict[str, Any] = scan._generate_results() - if not show_skipped: + if not settings["show_skipped"]: del report["skipped"] print(json.dumps(report)) diff --git a/modelscan/settings.py b/modelscan/settings.py index 00b008d..441adfa 100644 --- a/modelscan/settings.py +++ b/modelscan/settings.py @@ -2,6 +2,10 @@ from typing import Any +AVAILABLE_REPORTING_MODULES = { + "console": "modelscan.reports.ConsoleReport", + "json": "modelscan.reports.JSONReport", +} DEFAULT_SETTINGS = { "supported_zip_extensions": [".zip", ".npz"], @@ -86,11 +90,8 @@ }, "MEDIUM": {}, "LOW": {}, - "reporting_module": { - "module": "modelscan.reports.ConsoleReport", - "settings": {}, - }, }, + "reporting": {"module": "modelscan.reports.ConsoleReport", "settings": {}}, }