Skip to content

Commit

Permalink
Update Naming and Split SavedModelScan (#83)
Browse files Browse the repository at this point in the history
  • Loading branch information
swashko authored Jan 22, 2024
1 parent ee254ab commit 8dfa996
Show file tree
Hide file tree
Showing 7 changed files with 161 additions and 112 deletions.
16 changes: 10 additions & 6 deletions modelscan/scanners/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
from modelscan.scanners.h5.scan import H5Scan
from modelscan.scanners.h5.scan import H5LambdaDetectScan
from modelscan.scanners.pickle.scan import (
PickleScan,
NumpyScan,
PyTorchScan,
PickleUnsafeOpScan,
NumpyUnsafeOpScan,
PyTorchUnsafeOpScan,
)
from modelscan.scanners.saved_model.scan import SavedModelScan
from modelscan.scanners.keras.scan import KerasScan
from modelscan.scanners.saved_model.scan import (
SavedModelScan,
SavedModelLambdaDetectScan,
SavedModelTensorflowOpScan,
)
from modelscan.scanners.keras.scan import KerasLambdaDetectScan
37 changes: 20 additions & 17 deletions modelscan/scanners/h5/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,20 +12,22 @@

from modelscan.error import ModelScanError
from modelscan.scanners.scan import ScanResults
from modelscan.scanners.saved_model.scan import SavedModelScan
from modelscan.scanners.saved_model.scan import SavedModelLambdaDetectScan

logger = logging.getLogger("modelscan")


class H5Scan(SavedModelScan):
class H5LambdaDetectScan(SavedModelLambdaDetectScan):
def scan(
self,
source: Union[str, Path],
data: Optional[IO[bytes]] = None,
) -> Optional[ScanResults]:
if (
not Path(source).suffix
in self._settings["scanners"][H5Scan.full_name()]["supported_extensions"]
in self._settings["scanners"][H5LambdaDetectScan.full_name()][
"supported_extensions"
]
):
return None

Expand All @@ -44,11 +46,13 @@ def scan(
def _scan_keras_h5_file(self, source: Union[str, Path]) -> ScanResults:
machine_learning_library_name = "Keras"
operators_in_model = self._get_keras_h5_operator_names(source)
return H5Scan._check_for_unsafe_tf_keras_operator(
return H5LambdaDetectScan._check_for_unsafe_tf_keras_operator(
module_name=machine_learning_library_name,
raw_operator=operators_in_model,
source=source,
settings=self._settings,
unsafe_operators=self._settings["scanners"][
SavedModelLambdaDetectScan.full_name()
]["unsafe_keras_operators"],
)

def _get_keras_h5_operator_names(self, source: Union[str, Path]) -> List[str]:
Expand All @@ -73,21 +77,20 @@ def _get_keras_h5_operator_names(self, source: Union[str, Path]) -> List[str]:

return []

@staticmethod
def name() -> str:
return "hdf5"

@staticmethod
def full_name() -> str:
return "modelscan.scanners.H5Scan"

@staticmethod
def handle_binary_dependencies(
settings: Optional[Dict[str, Any]] = None
self, settings: Optional[Dict[str, Any]] = None
) -> Optional[ModelScanError]:
if not h5py_installed:
return ModelScanError(
SavedModelScan.name(),
f"To use {H5Scan.full_name()}, please install modelscan with h5py extras. 'pip install \"modelscan\[h5py]\"' if you are using pip.",
H5LambdaDetectScan.name(),
f"To use {H5LambdaDetectScan.full_name()}, please install modelscan with h5py extras. 'pip install \"modelscan\[h5py]\"' if you are using pip.",
)
return None

@staticmethod
def name() -> str:
return "hdf5"

@staticmethod
def full_name() -> str:
return "modelscan.scanners.H5LambdaDetectScan"
20 changes: 12 additions & 8 deletions modelscan/scanners/keras/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,23 @@

from modelscan.error import ModelScanError
from modelscan.scanners.scan import ScanResults
from modelscan.scanners.saved_model.scan import SavedModelScan
from modelscan.scanners.saved_model.scan import SavedModelLambdaDetectScan


logger = logging.getLogger("modelscan")


class KerasScan(SavedModelScan):
class KerasLambdaDetectScan(SavedModelLambdaDetectScan):
def scan(
self,
source: Union[str, Path],
data: Optional[IO[bytes]] = None,
) -> Optional[ScanResults]:
if (
not Path(source).suffix
in self._settings["scanners"][KerasScan.full_name()]["supported_extensions"]
in self._settings["scanners"][KerasLambdaDetectScan.full_name()][
"supported_extensions"
]
):
return None

Expand All @@ -46,7 +48,7 @@ def scan(
[],
[
ModelScanError(
KerasScan.name(),
KerasLambdaDetectScan.name(),
f"Skipping zip file {source}, due to error: {e}",
)
],
Expand All @@ -57,7 +59,7 @@ def scan(
[],
[
ModelScanError(
KerasScan.name(),
KerasLambdaDetectScan.name(),
f"Unable to scan .keras file", # Not sure if this is a representative message for ModelScanError
)
],
Expand All @@ -68,11 +70,13 @@ def _scan_keras_config_file(
) -> ScanResults:
machine_learning_library_name = "Keras"
operators_in_model = self._get_keras_operator_names(source, config_file)
return KerasScan._check_for_unsafe_tf_keras_operator(
return KerasLambdaDetectScan._check_for_unsafe_tf_keras_operator(
module_name=machine_learning_library_name,
raw_operator=operators_in_model,
source=source,
settings=self._settings,
unsafe_operators=self._settings["scanners"][
SavedModelLambdaDetectScan.full_name()
]["unsafe_keras_operators"],
)

def _get_keras_operator_names(
Expand Down Expand Up @@ -101,4 +105,4 @@ def name() -> str:

@staticmethod
def full_name() -> str:
return "modelscan.scanners.KerasScan"
return "modelscan.scanners.KerasLambdaDetectScan"
20 changes: 11 additions & 9 deletions modelscan/scanners/pickle/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,15 @@
logger = logging.getLogger("modelscan")


class PyTorchScan(ScanBase):
class PyTorchUnsafeOpScan(ScanBase):
def scan(
self,
source: Union[str, Path],
data: Optional[IO[bytes]] = None,
) -> Optional[ScanResults]:
if (
not Path(source).suffix
in self._settings["scanners"][PyTorchScan.full_name()][
in self._settings["scanners"][PyTorchUnsafeOpScan.full_name()][
"supported_extensions"
]
):
Expand All @@ -47,18 +47,20 @@ def name() -> str:

@staticmethod
def full_name() -> str:
return "modelscan.scanners.PyTorchScan"
return "modelscan.scanners.PyTorchUnsafeOpScan"


class NumpyScan(ScanBase):
class NumpyUnsafeOpScan(ScanBase):
def scan(
self,
source: Union[str, Path],
data: Optional[IO[bytes]] = None,
) -> Optional[ScanResults]:
if (
not Path(source).suffix
in self._settings["scanners"][NumpyScan.full_name()]["supported_extensions"]
in self._settings["scanners"][NumpyUnsafeOpScan.full_name()][
"supported_extensions"
]
):
return None

Expand All @@ -76,18 +78,18 @@ def name() -> str:

@staticmethod
def full_name() -> str:
return "modelscan.scanners.NumpyScan"
return "modelscan.scanners.NumpyUnsafeOpScan"


class PickleScan(ScanBase):
class PickleUnsafeOpScan(ScanBase):
def scan(
self,
source: Union[str, Path],
data: Optional[IO[bytes]] = None,
) -> Optional[ScanResults]:
if (
not Path(source).suffix
in self._settings["scanners"][PickleScan.full_name()][
in self._settings["scanners"][PickleUnsafeOpScan.full_name()][
"supported_extensions"
]
):
Expand All @@ -112,4 +114,4 @@ def name() -> str:

@staticmethod
def full_name() -> str:
return "modelscan.scanners.PickleScan"
return "modelscan.scanners.PickleUnsafeOpScan"
Loading

0 comments on commit 8dfa996

Please sign in to comment.