Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update Naming and Split SavedModelScan #83

Merged
merged 5 commits into from
Jan 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 10 additions & 6 deletions modelscan/scanners/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
from modelscan.scanners.h5.scan import H5Scan
from modelscan.scanners.h5.scan import H5LambdaDetectScan
from modelscan.scanners.pickle.scan import (
PickleScan,
NumpyScan,
PyTorchScan,
PickleUnsafeOpScan,
NumpyUnsafeOpScan,
PyTorchUnsafeOpScan,
)
from modelscan.scanners.saved_model.scan import SavedModelScan
from modelscan.scanners.keras.scan import KerasScan
from modelscan.scanners.saved_model.scan import (
SavedModelScan,
SavedModelLambdaDetectScan,
SavedModelTensorflowOpScan,
)
from modelscan.scanners.keras.scan import KerasLambdaDetectScan
37 changes: 20 additions & 17 deletions modelscan/scanners/h5/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,20 +12,22 @@

from modelscan.error import ModelScanError
from modelscan.scanners.scan import ScanResults
from modelscan.scanners.saved_model.scan import SavedModelScan
from modelscan.scanners.saved_model.scan import SavedModelLambdaDetectScan

logger = logging.getLogger("modelscan")


class H5Scan(SavedModelScan):
class H5LambdaDetectScan(SavedModelLambdaDetectScan):
def scan(
self,
source: Union[str, Path],
data: Optional[IO[bytes]] = None,
) -> Optional[ScanResults]:
if (
not Path(source).suffix
in self._settings["scanners"][H5Scan.full_name()]["supported_extensions"]
in self._settings["scanners"][H5LambdaDetectScan.full_name()][
"supported_extensions"
]
):
return None

Expand All @@ -44,11 +46,13 @@ def scan(
def _scan_keras_h5_file(self, source: Union[str, Path]) -> ScanResults:
machine_learning_library_name = "Keras"
operators_in_model = self._get_keras_h5_operator_names(source)
return H5Scan._check_for_unsafe_tf_keras_operator(
return H5LambdaDetectScan._check_for_unsafe_tf_keras_operator(
module_name=machine_learning_library_name,
raw_operator=operators_in_model,
source=source,
settings=self._settings,
unsafe_operators=self._settings["scanners"][
SavedModelLambdaDetectScan.full_name()
]["unsafe_keras_operators"],
)

def _get_keras_h5_operator_names(self, source: Union[str, Path]) -> List[str]:
Expand All @@ -73,21 +77,20 @@ def _get_keras_h5_operator_names(self, source: Union[str, Path]) -> List[str]:

return []

@staticmethod
def name() -> str:
return "hdf5"

@staticmethod
def full_name() -> str:
return "modelscan.scanners.H5Scan"

@staticmethod
def handle_binary_dependencies(
settings: Optional[Dict[str, Any]] = None
self, settings: Optional[Dict[str, Any]] = None
) -> Optional[ModelScanError]:
if not h5py_installed:
return ModelScanError(
SavedModelScan.name(),
f"To use {H5Scan.full_name()}, please install modelscan with h5py extras. 'pip install \"modelscan\[h5py]\"' if you are using pip.",
H5LambdaDetectScan.name(),
f"To use {H5LambdaDetectScan.full_name()}, please install modelscan with h5py extras. 'pip install \"modelscan\[h5py]\"' if you are using pip.",
)
return None

@staticmethod
def name() -> str:
return "hdf5"

@staticmethod
def full_name() -> str:
return "modelscan.scanners.H5LambdaDetectScan"
20 changes: 12 additions & 8 deletions modelscan/scanners/keras/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,23 @@

from modelscan.error import ModelScanError
from modelscan.scanners.scan import ScanResults
from modelscan.scanners.saved_model.scan import SavedModelScan
from modelscan.scanners.saved_model.scan import SavedModelLambdaDetectScan


logger = logging.getLogger("modelscan")


class KerasScan(SavedModelScan):
class KerasLambdaDetectScan(SavedModelLambdaDetectScan):
def scan(
self,
source: Union[str, Path],
data: Optional[IO[bytes]] = None,
) -> Optional[ScanResults]:
if (
not Path(source).suffix
in self._settings["scanners"][KerasScan.full_name()]["supported_extensions"]
in self._settings["scanners"][KerasLambdaDetectScan.full_name()][
"supported_extensions"
]
):
return None

Expand All @@ -46,7 +48,7 @@ def scan(
[],
[
ModelScanError(
KerasScan.name(),
KerasLambdaDetectScan.name(),
f"Skipping zip file {source}, due to error: {e}",
)
],
Expand All @@ -57,7 +59,7 @@ def scan(
[],
[
ModelScanError(
KerasScan.name(),
KerasLambdaDetectScan.name(),
f"Unable to scan .keras file", # Not sure if this is a representative message for ModelScanError
)
],
Expand All @@ -68,11 +70,13 @@ def _scan_keras_config_file(
) -> ScanResults:
machine_learning_library_name = "Keras"
operators_in_model = self._get_keras_operator_names(source, config_file)
return KerasScan._check_for_unsafe_tf_keras_operator(
return KerasLambdaDetectScan._check_for_unsafe_tf_keras_operator(
module_name=machine_learning_library_name,
raw_operator=operators_in_model,
source=source,
settings=self._settings,
unsafe_operators=self._settings["scanners"][
SavedModelLambdaDetectScan.full_name()
]["unsafe_keras_operators"],
)

def _get_keras_operator_names(
Expand Down Expand Up @@ -101,4 +105,4 @@ def name() -> str:

@staticmethod
def full_name() -> str:
return "modelscan.scanners.KerasScan"
return "modelscan.scanners.KerasLambdaDetectScan"
20 changes: 11 additions & 9 deletions modelscan/scanners/pickle/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,15 @@
logger = logging.getLogger("modelscan")


class PyTorchScan(ScanBase):
class PyTorchUnsafeOpScan(ScanBase):
def scan(
self,
source: Union[str, Path],
data: Optional[IO[bytes]] = None,
) -> Optional[ScanResults]:
if (
not Path(source).suffix
in self._settings["scanners"][PyTorchScan.full_name()][
in self._settings["scanners"][PyTorchUnsafeOpScan.full_name()][
"supported_extensions"
]
):
Expand All @@ -47,18 +47,20 @@ def name() -> str:

@staticmethod
def full_name() -> str:
return "modelscan.scanners.PyTorchScan"
return "modelscan.scanners.PyTorchUnsafeOpScan"


class NumpyScan(ScanBase):
class NumpyUnsafeOpScan(ScanBase):
def scan(
self,
source: Union[str, Path],
data: Optional[IO[bytes]] = None,
) -> Optional[ScanResults]:
if (
not Path(source).suffix
in self._settings["scanners"][NumpyScan.full_name()]["supported_extensions"]
in self._settings["scanners"][NumpyUnsafeOpScan.full_name()][
"supported_extensions"
]
):
return None

Expand All @@ -76,18 +78,18 @@ def name() -> str:

@staticmethod
def full_name() -> str:
return "modelscan.scanners.NumpyScan"
return "modelscan.scanners.NumpyUnsafeOpScan"


class PickleScan(ScanBase):
class PickleUnsafeOpScan(ScanBase):
def scan(
self,
source: Union[str, Path],
data: Optional[IO[bytes]] = None,
) -> Optional[ScanResults]:
if (
not Path(source).suffix
in self._settings["scanners"][PickleScan.full_name()][
in self._settings["scanners"][PickleUnsafeOpScan.full_name()][
"supported_extensions"
]
):
Expand All @@ -112,4 +114,4 @@ def name() -> str:

@staticmethod
def full_name() -> str:
return "modelscan.scanners.PickleScan"
return "modelscan.scanners.PickleUnsafeOpScan"
Loading
Loading