diff --git a/modelscan/cli.py b/modelscan/cli.py index 7383d7d..db93869 100644 --- a/modelscan/cli.py +++ b/modelscan/cli.py @@ -1,16 +1,21 @@ import logging import sys import os +import importlib from pathlib import Path -from typing import Optional +from typing import Optional, Dict, Any from tomlkit import parse import click from modelscan.modelscan import ModelScan -from modelscan.reports import ConsoleReport +from modelscan.reports import Report from modelscan._version import __version__ -from modelscan.settings import SettingsUtils, DEFAULT_SETTINGS +from modelscan.settings import ( + SettingsUtils, + DEFAULT_SETTINGS, + DEFAULT_REPORTING_MODULES, +) from modelscan.tools.cli_utils import DefaultGroup logger = logging.getLogger("modelscan") @@ -69,6 +74,20 @@ def cli() -> None: type=click.Path(exists=True, dir_okay=False), help="Specify a settings file to use for the scan. Defaults to ./modelscan-settings.toml.", ) +@click.option( + "-r", + "--reporting-format", + type=click.Choice(["console", "json", "custom"]), + default="console", + help="Format of the output. Options are console, json, or custom (to be defined in settings-file). Default is console.", +) +@click.option( + "-o", + "--output-file", + type=click.Path(), + default=None, + help="Optional file name for output report", +) @cli.command( help="[Default] Scan a model file or diretory for ability to execute suspicious actions. " ) # type: ignore @@ -79,6 +98,8 @@ def scan( path: Optional[str], show_skipped: bool, settings_file: Optional[str], + reporting_format: str, + output_file: Path, ) -> int: logger.setLevel(logging.INFO) logger.addHandler(logging.StreamHandler(stream=sys.stdout)) @@ -111,12 +132,26 @@ def scan( modelscan.scan(pathlibPath) else: raise click.UsageError("Command line must include a path") - ConsoleReport.generate( - modelscan.issues, - modelscan.errors, - modelscan._skipped, - show_skipped=show_skipped, - ) + + report_settings: Dict[str, Any] = {} + if reporting_format == "custom": + reporting_module = settings["reporting"]["module"] # type: ignore[index] + else: + reporting_module = DEFAULT_REPORTING_MODULES[reporting_format] + + report_settings = settings["reporting"]["settings"] # type: ignore[index] + report_settings["show_skipped"] = show_skipped + report_settings["output_file"] = output_file + + try: + (modulename, classname) = reporting_module.rsplit(".", 1) + imported_module = importlib.import_module(name=modulename, package=classname) + + report_class: Report = getattr(imported_module, classname) + report_class.generate(scan=modelscan, settings=report_settings) + + except Exception as e: + logger.error(f"Error generating report using {reporting_module}: {e}") # exit code 3 if no supported files were passed if not modelscan.scanned: diff --git a/modelscan/reports.py b/modelscan/reports.py index f0e9fdb..231bf23 100644 --- a/modelscan/reports.py +++ b/modelscan/reports.py @@ -1,9 +1,11 @@ import abc import logging -from typing import List, Optional +import json +from typing import Optional, Dict, Any from rich import print +from modelscan.modelscan import ModelScan from modelscan.error import Error from modelscan.issues import Issues, IssueSeverity @@ -20,10 +22,8 @@ def __init__(self) -> None: @staticmethod def generate( - issues: Issues, - errors: List[Error], - skipped: List[str], - show_skipped: bool = False, + scan: ModelScan, + settings: Dict[str, Any] = {}, ) -> Optional[str]: """ Generate report for the given codebase. @@ -39,14 +39,12 @@ def generate( class ConsoleReport(Report): @staticmethod def generate( - issues: Issues, - errors: List[Error], - skipped: List[str], - show_skipped: bool = False, + scan: ModelScan, + settings: Dict[str, Any] = {}, ) -> None: - issues_by_severity = issues.group_by_severity() + issues_by_severity = scan.issues.group_by_severity() print("\n[blue]--- Summary ---") - total_issue_count = len(issues.all_issues) + total_issue_count = len(scan.issues.all_issues) if total_issue_count > 0: print(f"\nTotal Issues: {total_issue_count}") print(f"\nTotal Issues By Severity:\n") @@ -66,18 +64,36 @@ def generate( else: print("\n[green] No issues found! 🎉") - if len(errors) > 0: + if len(scan.errors) > 0: print("\n[red]--- Errors --- ") - for index, error in enumerate(errors): + for index, error in enumerate(scan.errors): print(f"\nError {index+1}:") print(str(error)) - if len(skipped) > 0: + if len(scan.skipped) > 0: print("\n[blue]--- Skipped --- ") print( - f"\nTotal skipped: {len(skipped)} - run with --show-skipped to see the full list." + f"\nTotal skipped: {len(scan.skipped)} - run with --show-skipped to see the full list." ) - if show_skipped: + if settings["show_skipped"]: print(f"\nSkipped files list:\n") - for file_name in skipped: + for file_name in scan.skipped: print(str(file_name)) + + +class JSONReport(Report): + @staticmethod + def generate( + scan: ModelScan, + settings: Dict[str, Any] = {}, + ) -> None: + report: Dict[str, Any] = scan._generate_results() + if not settings["show_skipped"]: + del report["skipped"] + + print(json.dumps(report)) + + output = settings["output_file"] + if output: + with open(output, "w") as outfile: + json.dump(report, outfile) diff --git a/modelscan/settings.py b/modelscan/settings.py index 9f92cce..c4bad2e 100644 --- a/modelscan/settings.py +++ b/modelscan/settings.py @@ -2,6 +2,10 @@ from typing import Any +DEFAULT_REPORTING_MODULES = { + "console": "modelscan.reports.ConsoleReport", + "json": "modelscan.reports.JSONReport", +} DEFAULT_SETTINGS = { "supported_zip_extensions": [".zip", ".npz"], @@ -87,6 +91,10 @@ "MEDIUM": {}, "LOW": {}, }, + "reporting": { + "module": "modelscan.reports.ConsoleReport", + "settings": {}, + }, # JSON reporting can be configured by changing "module" to "modelscan.reports.JSONReport" and adding an optional "output_file" field. For custom reporting modules, change "module" to the module name and add the applicable settings fields }