From ee812c204aa01176423255e9b0ece4659d2f7133 Mon Sep 17 00:00:00 2001
From: Oleksandr Yaremchuk
Date: Mon, 18 Mar 2024 19:07:08 +0100
Subject: [PATCH] Introduce middlewares (#119)

---
 modelscan/middlewares/__init__.py             |  1 +
 modelscan/middlewares/format_via_extension.py | 17 ++++++
 modelscan/middlewares/middleware.py           | 73 +++++++++++++++++++++++
 modelscan/model.py                            | 15 ++++-
 modelscan/modelscan.py                        | 19 ++++++
 modelscan/scanners/h5/scan.py                 |  8 +--
 modelscan/scanners/keras/scan.py              |  7 +--
 modelscan/scanners/pickle/scan.py             | 21 +------
 modelscan/scanners/saved_model/scan.py        |  5 +-
 modelscan/settings.py                         | 34 ++++++-----
 10 files changed, 149 insertions(+), 51 deletions(-)
 create mode 100644 modelscan/middlewares/__init__.py
 create mode 100644 modelscan/middlewares/format_via_extension.py
 create mode 100644 modelscan/middlewares/middleware.py

diff --git a/modelscan/middlewares/__init__.py b/modelscan/middlewares/__init__.py
new file mode 100644
index 0000000..e6dcece
--- /dev/null
+++ b/modelscan/middlewares/__init__.py
@@ -0,0 +1 @@
+from modelscan.middlewares.format_via_extension import FormatViaExtensionMiddleware
diff --git a/modelscan/middlewares/format_via_extension.py b/modelscan/middlewares/format_via_extension.py
new file mode 100644
index 0000000..2150acf
--- /dev/null
+++ b/modelscan/middlewares/format_via_extension.py
@@ -0,0 +1,17 @@
+from .middleware import MiddlewareBase
+from modelscan.model import Model
+from typing import Callable
+
+
+class FormatViaExtensionMiddleware(MiddlewareBase):
+    """Tag a model with every format whose configured extensions match its suffix."""
+
+    def __call__(self, model: Model, call_next: Callable[[Model], None]) -> None:
+        suffix = model.get_source().suffix
+        configured = self._settings["formats"].items()
+        found = [name for name, extensions in configured if suffix in extensions]
+        if found:
+            # Parenthesised: ``a or [] + b`` parses as ``a or ([] + b)`` and drops b.
+            model.set_context("formats", (model.get_context("formats") or []) + found)
+
+        call_next(model)
diff --git a/modelscan/middlewares/middleware.py b/modelscan/middlewares/middleware.py
new file mode 100644
index 0000000..613c895
--- /dev/null
+++ b/modelscan/middlewares/middleware.py
@@ -0,0 +1,73 @@
+import abc
+from modelscan.model import Model
+from typing import Callable, Dict, Any, List, Type
+import importlib
+
+
+class MiddlewareImportError(Exception):
+    """Raised when a configured middleware cannot be imported or constructed."""
+
+
+class MiddlewareBase(metaclass=abc.ABCMeta):
+    """Base class for middlewares; each instance keeps its own settings dict."""
+
+    _settings: Dict[str, Any]
+
+    def __init__(self, settings: Dict[str, Any]):
+        self._settings = settings
+
+    @abc.abstractmethod
+    def __call__(
+        self,
+        model: Model,
+        call_next: Callable[[Model], None],
+    ) -> None:
+        """Process ``model``; call ``call_next(model)`` to continue the chain."""
+        raise NotImplementedError
+
+
+class MiddlewarePipeline:
+    """Ordered chain of middlewares that is run over a model."""
+
+    _middlewares: List[MiddlewareBase]
+
+    def __init__(self) -> None:
+        self._middlewares = []
+
+    @staticmethod
+    def from_settings(middleware_settings: Dict[str, Any]) -> "MiddlewarePipeline":
+        """Build a pipeline from a ``{"pkg.module.ClassName": params}`` mapping.
+
+        Raises MiddlewareImportError if an entry cannot be imported or built.
+        """
+        pipeline = MiddlewarePipeline()
+
+        for path, params in middleware_settings.items():
+            try:
+                (modulename, classname) = path.rsplit(".", 1)
+                # Absolute import: ``package=`` only matters for relative names.
+                module = importlib.import_module(modulename)
+
+                middleware_class: Type[MiddlewareBase] = getattr(module, classname)
+                pipeline.add_middleware(middleware_class(params))
+            except Exception as e:
+                # Chain the cause so the original traceback is kept.
+                raise MiddlewareImportError(
+                    f"Error importing middleware {path}: {e}"
+                ) from e
+
+        return pipeline
+
+    def add_middleware(self, middleware: MiddlewareBase) -> "MiddlewarePipeline":
+        """Append ``middleware`` and return the pipeline to allow chaining."""
+        self._middlewares.append(middleware)
+        return self
+
+    def run(self, model: Model) -> None:
+        """Run the chain on ``model``, starting from the first middleware."""
+
+        def runner(model: Model, index: int) -> None:
+            if index < len(self._middlewares):
+                self._middlewares[index](model, lambda model: runner(model, index + 1))
+
+        runner(model, 0)
diff --git a/modelscan/model.py b/modelscan/model.py index cdf3ab0..f3950ea 100644 --- a/modelscan/model.py +++ b/modelscan/model.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import List, Union, Optional, IO, Generator +from typing import List, Union, Optional, IO, Generator, Dict, Any from modelscan.tools.utils import _is_zipfile import zipfile @@ -24,12 +24,15 
@@ def __init__(self, e: zipfile.BadZipFile, source: str): class Model: _source: Path - _stream: Optional[IO[bytes]] = None - _source_file_used: bool = False + _stream: Optional[IO[bytes]] + _source_file_used: bool + _context: Dict[str, Any] def __init__(self, source: Union[str, Path], stream: Optional[IO[bytes]] = None): self._source = Path(source) self._stream = stream + self._source_file_used = False + self._context = {"formats": []} @staticmethod def from_path(path: Path) -> "Model": @@ -41,6 +44,12 @@ def from_path(path: Path) -> "Model": return Model(path) + def set_context(self, key: str, value: Any) -> None: + self._context[key] = value + + def get_context(self, key: str) -> Any: + return self._context.get(key) + def open(self) -> "Model": if self._stream: return self diff --git a/modelscan/modelscan.py b/modelscan/modelscan.py index 72cdb6c..5739421 100644 --- a/modelscan/modelscan.py +++ b/modelscan/modelscan.py @@ -14,6 +14,7 @@ from modelscan._version import __version__ from modelscan.tools.utils import _is_zipfile from modelscan.model import Model, ModelPathNotValid, ModelBadZip, ModelIsDir +from modelscan.middlewares.middleware import MiddlewarePipeline, MiddlewareImportError logger = logging.getLogger("modelscan") @@ -35,6 +36,22 @@ def __init__( self._scanners_to_run: List[ScanBase] = [] self._settings: Dict[str, Any] = settings self._load_scanners() + self._load_middlewares() + + def _load_middlewares(self) -> None: + try: + self._middleware_pipeline = MiddlewarePipeline.from_settings( + self._settings["middlewares"] or {} + ) + except MiddlewareImportError as e: + logger.exception(e) + self._init_errors.append( + ModelScanError( + "MiddlewarePipeline", + ErrorCategories.MODEL_SCAN, + f"Error loading middlewares: {e}", + ) + ) def _load_scanners(self) -> None: for scanner_path, scanner_settings in self._settings["scanners"].items(): @@ -173,6 +190,8 @@ def _scan_source( self, model: Model, ) -> bool: + self._middleware_pipeline.run(model) + scanned 
= False for scan_class in self._scanners_to_run: scanner = scan_class(self._settings) # type: ignore[operator] diff --git a/modelscan/scanners/h5/scan.py b/modelscan/scanners/h5/scan.py index ac1c197..d8452e6 100644 --- a/modelscan/scanners/h5/scan.py +++ b/modelscan/scanners/h5/scan.py @@ -24,13 +24,9 @@ def scan( self, model: Model, ) -> Optional[ScanResults]: - if ( - not model.get_source().suffix - in self._settings["scanners"][H5LambdaDetectScan.full_name()][ - "supported_extensions" - ] - ): + if "keras_h5" not in model.get_context("formats"): return None + dep_error = self.handle_binary_dependencies() if dep_error: return ScanResults( diff --git a/modelscan/scanners/keras/scan.py b/modelscan/scanners/keras/scan.py index b9b276b..08f897e 100644 --- a/modelscan/scanners/keras/scan.py +++ b/modelscan/scanners/keras/scan.py @@ -17,12 +17,7 @@ class KerasLambdaDetectScan(SavedModelLambdaDetectScan): def scan(self, model: Model) -> Optional[ScanResults]: - if ( - not model.get_source().suffix - in self._settings["scanners"][KerasLambdaDetectScan.full_name()][ - "supported_extensions" - ] - ): + if "keras" not in model.get_context("formats"): return None dep_error = self.handle_binary_dependencies() diff --git a/modelscan/scanners/pickle/scan.py b/modelscan/scanners/pickle/scan.py index f9b4105..d138202 100644 --- a/modelscan/scanners/pickle/scan.py +++ b/modelscan/scanners/pickle/scan.py @@ -18,12 +18,7 @@ def scan( self, model: Model, ) -> Optional[ScanResults]: - if ( - not model.get_source().suffix - in self._settings["scanners"][PyTorchUnsafeOpScan.full_name()][ - "supported_extensions" - ] - ): + if "pytorch" not in model.get_context("formats"): return None if _is_zipfile(model.get_source(), model.get_stream()): @@ -50,12 +45,7 @@ def scan( self, model: Model, ) -> Optional[ScanResults]: - if ( - not model.get_source().suffix - in self._settings["scanners"][NumpyUnsafeOpScan.full_name()][ - "supported_extensions" - ] - ): + if "numpy" not in 
model.get_context("formats"): return None results = scan_numpy( @@ -79,12 +69,7 @@ def scan( self, model: Model, ) -> Optional[ScanResults]: - if ( - not model.get_source().suffix - in self._settings["scanners"][PickleUnsafeOpScan.full_name()][ - "supported_extensions" - ] - ): + if "pickle" not in model.get_context("formats"): return None results = scan_pickle_bytes( diff --git a/modelscan/scanners/saved_model/scan.py b/modelscan/scanners/saved_model/scan.py index eb57395..c78bb73 100644 --- a/modelscan/scanners/saved_model/scan.py +++ b/modelscan/scanners/saved_model/scan.py @@ -29,10 +29,7 @@ def scan( self, model: Model, ) -> Optional[ScanResults]: - if ( - not model.get_source().suffix - in self._settings["scanners"][self.full_name()]["supported_extensions"] - ): + if "tf_saved_model" not in model.get_context("formats"): return None dep_error = self.handle_binary_dependencies() diff --git a/modelscan/settings.py b/modelscan/settings.py index 8041ffe..f4bfa9e 100644 --- a/modelscan/settings.py +++ b/modelscan/settings.py @@ -15,22 +15,18 @@ "scanners": { "modelscan.scanners.H5LambdaDetectScan": { "enabled": True, - "supported_extensions": [".h5"], }, "modelscan.scanners.KerasLambdaDetectScan": { "enabled": True, - "supported_extensions": [".keras"], }, "modelscan.scanners.SavedModelLambdaDetectScan": { "enabled": True, - "supported_extensions": [".pb"], "unsafe_keras_operators": { "Lambda": "MEDIUM", }, }, "modelscan.scanners.SavedModelTensorflowOpScan": { "enabled": True, - "supported_extensions": [".pb"], "unsafe_tf_operators": { "ReadFile": "HIGH", "WriteFile": "HIGH", @@ -38,24 +34,34 @@ }, "modelscan.scanners.NumpyUnsafeOpScan": { "enabled": True, - "supported_extensions": [".npy"], }, "modelscan.scanners.PickleUnsafeOpScan": { "enabled": True, - "supported_extensions": [ - ".pkl", - ".pickle", - ".joblib", - ".dill", - ".dat", - ".data", - ], }, "modelscan.scanners.PyTorchUnsafeOpScan": { "enabled": True, - "supported_extensions": [".bin", ".pt", ".pth", 
".ckpt"], }, }, + "middlewares": { + "modelscan.middlewares.FormatViaExtensionMiddleware": { + "formats": { + "tf": [".pb"], + "tf_saved_model": [".pb"], + "keras_h5": [".h5"], + "keras": [".keras"], + "numpy": [".npy"], + "pytorch": [".bin", ".pt", ".pth", ".ckpt"], + "pickle": [ + ".pkl", + ".pickle", + ".joblib", + ".dill", + ".dat", + ".data", + ], + } + } + }, "unsafe_globals": { "CRITICAL": { "__builtin__": [