Skip to content

Commit

Permalink
Introduce middlewares (#119)
Browse files Browse the repository at this point in the history
  • Loading branch information
asofter authored Mar 18, 2024
1 parent a1363ec commit ee812c2
Show file tree
Hide file tree
Showing 10 changed files with 135 additions and 51 deletions.
1 change: 1 addition & 0 deletions modelscan/middlewares/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from modelscan.middlewares.format_via_extension import FormatViaExtensionMiddleware
17 changes: 17 additions & 0 deletions modelscan/middlewares/format_via_extension.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from .middleware import MiddlewareBase
from modelscan.model import Model
from typing import Callable


class FormatViaExtensionMiddleware(MiddlewareBase):
def __call__(self, model: Model, call_next: Callable[[Model], None]) -> None:
extension = model.get_source().suffix
formats = [
format
for format, extensions in self._settings["formats"].items()
if extension in extensions
]
if len(formats) > 0:
model.set_context("formats", model.get_context("formats") or [] + formats)

call_next(model)
59 changes: 59 additions & 0 deletions modelscan/middlewares/middleware.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import abc
from modelscan.model import Model
from typing import Callable, Dict, Any, List
import importlib


class MiddlewareImportError(Exception):
pass


class MiddlewareBase(metaclass=abc.ABCMeta):
_settings: Dict[str, Any]

def __init__(self, settings: Dict[str, Any]):
self._settings = settings

@abc.abstractmethod
def __call__(
self,
model: Model,
call_next: Callable[[Model], None],
) -> None:
raise NotImplementedError


class MiddlewarePipeline:
_middlewares: List[MiddlewareBase]

def __init__(self) -> None:
self._middlewares = []

@staticmethod
def from_settings(middleware_settings: Dict[str, Any]) -> "MiddlewarePipeline":
pipeline = MiddlewarePipeline()

for path, params in middleware_settings.items():
try:
(modulename, classname) = path.rsplit(".", 1)
imported_module = importlib.import_module(
name=modulename, package=classname
)

middleware_class: MiddlewareBase = getattr(imported_module, classname)
pipeline.add_middleware(middleware_class(params)) # type: ignore
except Exception as e:
raise MiddlewareImportError(f"Error importing middleware {path}: {e}")

return pipeline

def add_middleware(self, middleware: MiddlewareBase) -> "MiddlewarePipeline":
self._middlewares.append(middleware)
return self

def run(self, model: Model) -> None:
def runner(model: Model, index: int) -> None:
if index < len(self._middlewares):
self._middlewares[index](model, lambda model: runner(model, index + 1))

runner(model, 0)
15 changes: 12 additions & 3 deletions modelscan/model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from pathlib import Path
from typing import List, Union, Optional, IO, Generator
from typing import List, Union, Optional, IO, Generator, Dict, Any
from modelscan.tools.utils import _is_zipfile
import zipfile

Expand All @@ -24,12 +24,15 @@ def __init__(self, e: zipfile.BadZipFile, source: str):

class Model:
_source: Path
_stream: Optional[IO[bytes]] = None
_source_file_used: bool = False
_stream: Optional[IO[bytes]]
_source_file_used: bool
_context: Dict[str, Any]

def __init__(self, source: Union[str, Path], stream: Optional[IO[bytes]] = None):
self._source = Path(source)
self._stream = stream
self._source_file_used = False
self._context = {"formats": []}

@staticmethod
def from_path(path: Path) -> "Model":
Expand All @@ -41,6 +44,12 @@ def from_path(path: Path) -> "Model":

return Model(path)

def set_context(self, key: str, value: Any) -> None:
self._context[key] = value

def get_context(self, key: str) -> Any:
return self._context.get(key)

def open(self) -> "Model":
if self._stream:
return self
Expand Down
19 changes: 19 additions & 0 deletions modelscan/modelscan.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from modelscan._version import __version__
from modelscan.tools.utils import _is_zipfile
from modelscan.model import Model, ModelPathNotValid, ModelBadZip, ModelIsDir
from modelscan.middlewares.middleware import MiddlewarePipeline, MiddlewareImportError

logger = logging.getLogger("modelscan")

Expand All @@ -35,6 +36,22 @@ def __init__(
self._scanners_to_run: List[ScanBase] = []
self._settings: Dict[str, Any] = settings
self._load_scanners()
self._load_middlewares()

def _load_middlewares(self) -> None:
try:
self._middleware_pipeline = MiddlewarePipeline.from_settings(
self._settings["middlewares"] or {}
)
except MiddlewareImportError as e:
logger.exception(e)
self._init_errors.append(
ModelScanError(
"MiddlewarePipeline",
ErrorCategories.MODEL_SCAN,
f"Error loading middlewares: {e}",
)
)

def _load_scanners(self) -> None:
for scanner_path, scanner_settings in self._settings["scanners"].items():
Expand Down Expand Up @@ -173,6 +190,8 @@ def _scan_source(
self,
model: Model,
) -> bool:
self._middleware_pipeline.run(model)

scanned = False
for scan_class in self._scanners_to_run:
scanner = scan_class(self._settings) # type: ignore[operator]
Expand Down
8 changes: 2 additions & 6 deletions modelscan/scanners/h5/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,9 @@ def scan(
self,
model: Model,
) -> Optional[ScanResults]:
if (
not model.get_source().suffix
in self._settings["scanners"][H5LambdaDetectScan.full_name()][
"supported_extensions"
]
):
if "keras_h5" not in model.get_context("formats"):
return None

dep_error = self.handle_binary_dependencies()
if dep_error:
return ScanResults(
Expand Down
7 changes: 1 addition & 6 deletions modelscan/scanners/keras/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,7 @@

class KerasLambdaDetectScan(SavedModelLambdaDetectScan):
def scan(self, model: Model) -> Optional[ScanResults]:
if (
not model.get_source().suffix
in self._settings["scanners"][KerasLambdaDetectScan.full_name()][
"supported_extensions"
]
):
if "keras" not in model.get_context("formats"):
return None

dep_error = self.handle_binary_dependencies()
Expand Down
21 changes: 3 additions & 18 deletions modelscan/scanners/pickle/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,7 @@ def scan(
self,
model: Model,
) -> Optional[ScanResults]:
if (
not model.get_source().suffix
in self._settings["scanners"][PyTorchUnsafeOpScan.full_name()][
"supported_extensions"
]
):
if "pytorch" not in model.get_context("formats"):
return None

if _is_zipfile(model.get_source(), model.get_stream()):
Expand All @@ -50,12 +45,7 @@ def scan(
self,
model: Model,
) -> Optional[ScanResults]:
if (
not model.get_source().suffix
in self._settings["scanners"][NumpyUnsafeOpScan.full_name()][
"supported_extensions"
]
):
if "numpy" not in model.get_context("formats"):
return None

results = scan_numpy(
Expand All @@ -79,12 +69,7 @@ def scan(
self,
model: Model,
) -> Optional[ScanResults]:
if (
not model.get_source().suffix
in self._settings["scanners"][PickleUnsafeOpScan.full_name()][
"supported_extensions"
]
):
if "pickle" not in model.get_context("formats"):
return None

results = scan_pickle_bytes(
Expand Down
5 changes: 1 addition & 4 deletions modelscan/scanners/saved_model/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,7 @@ def scan(
self,
model: Model,
) -> Optional[ScanResults]:
if (
not model.get_source().suffix
in self._settings["scanners"][self.full_name()]["supported_extensions"]
):
if "tf_saved_model" not in model.get_context("formats"):
return None

dep_error = self.handle_binary_dependencies()
Expand Down
34 changes: 20 additions & 14 deletions modelscan/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,47 +15,53 @@
"scanners": {
"modelscan.scanners.H5LambdaDetectScan": {
"enabled": True,
"supported_extensions": [".h5"],
},
"modelscan.scanners.KerasLambdaDetectScan": {
"enabled": True,
"supported_extensions": [".keras"],
},
"modelscan.scanners.SavedModelLambdaDetectScan": {
"enabled": True,
"supported_extensions": [".pb"],
"unsafe_keras_operators": {
"Lambda": "MEDIUM",
},
},
"modelscan.scanners.SavedModelTensorflowOpScan": {
"enabled": True,
"supported_extensions": [".pb"],
"unsafe_tf_operators": {
"ReadFile": "HIGH",
"WriteFile": "HIGH",
},
},
"modelscan.scanners.NumpyUnsafeOpScan": {
"enabled": True,
"supported_extensions": [".npy"],
},
"modelscan.scanners.PickleUnsafeOpScan": {
"enabled": True,
"supported_extensions": [
".pkl",
".pickle",
".joblib",
".dill",
".dat",
".data",
],
},
"modelscan.scanners.PyTorchUnsafeOpScan": {
"enabled": True,
"supported_extensions": [".bin", ".pt", ".pth", ".ckpt"],
},
},
"middlewares": {
"modelscan.middlewares.FormatViaExtensionMiddleware": {
"formats": {
"tf": [".pb"],
"tf_saved_model": [".pb"],
"keras_h5": [".h5"],
"keras": [".keras"],
"numpy": [".npy"],
"pytorch": [".bin", ".pt", ".pth", ".ckpt"],
"pickle": [
".pkl",
".pickle",
".joblib",
".dill",
".dat",
".data",
],
}
}
},
"unsafe_globals": {
"CRITICAL": {
"__builtin__": [
Expand Down

0 comments on commit ee812c2

Please sign in to comment.