Skip to content

Commit

Permalink
* use model in each scanner
Browse files Browse the repository at this point in the history
  • Loading branch information
asofter committed Mar 14, 2024
1 parent a997b6d commit 705d17f
Show file tree
Hide file tree
Showing 9 changed files with 210 additions and 201 deletions.
39 changes: 27 additions & 12 deletions modelscan/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,26 +2,29 @@
from typing import List, Union, Optional, IO, Generator
from modelscan.tools.utils import _is_zipfile
import zipfile
from dataclasses import dataclass


class ModelPathNotValid(ValueError):
    """Error type for an invalid model path.

    NOTE(review): the raise site is not visible in this excerpt; presumably
    used by ``Model.from_path`` -- confirm against the full file.
    """


class ModelDataEmpty(ValueError):
    """Raised by ``Model.get_data`` when the model carries no data stream."""


class ModelBadZip(ValueError):
    """A ``zipfile.BadZipFile`` failure tagged with the source it came from.

    Attributes are documented inline; the exception message is
    ``"Bad Zip File: <original error>"``.
    """

    def __init__(self, e: zipfile.BadZipFile, source: str):
        message = f"Bad Zip File: {e}"
        super().__init__(message)
        # Identifier of the archive (or archive member) that failed to read.
        self.source = source


class Model:
source: Path
data: Optional[IO[bytes]] = None
_source: Path
_data: Optional[IO[bytes]] = None

def __init__(self, source: Union[str, Path], data: Optional[IO[bytes]] = None):
self.source = Path(source)
self.data = data
self._source = Path(source)
self._data = data

@staticmethod
def from_path(path: Path) -> "Model":
Expand All @@ -31,25 +34,37 @@ def from_path(path: Path) -> "Model":
return Model(path)

def get_files(self) -> Generator["Model", None, None]:
    """Yield a ``Model`` for every regular file found under a directory source.

    The directory is walked recursively. Nothing is yielded when the source
    is not a directory.
    """
    root = self._source
    if not root.is_dir():
        return

    for entry in root.rglob("*"):
        if entry.is_file():
            yield Model(entry)

def get_zip_files(
    self, supported_extensions: List[str]
) -> Generator["Model", None, None]:
    """Yield one ``Model`` per member of the zip archive at this model's source.

    Nothing is yielded when the source neither passes ``_is_zipfile`` nor has
    a suffix listed in ``supported_extensions``.

    Each yielded model is named ``"<source>:<member>"`` and carries the
    member's open binary stream. The stream is closed as soon as the
    generator is resumed, so it must be consumed before advancing.

    Args:
        supported_extensions: File suffixes (including the leading dot) that
            are treated as zip containers.

    Raises:
        ModelBadZip: If ``zipfile`` reports the archive, or one of its
            members, as corrupt. Its ``source`` is the member being read, or
            the archive itself when the failure happens before any member
            is reached.
    """
    archive_path = self._source
    if (
        not _is_zipfile(archive_path)
        and Path(archive_path).suffix not in supported_extensions
    ):
        return

    # Until a member is reached, the archive itself is the failing source.
    # Previously the handler read a loop variable that is unbound when
    # ZipFile() raises, turning a bad archive into an UnboundLocalError.
    failing_source = str(archive_path)
    try:
        with zipfile.ZipFile(archive_path, "r") as archive:
            for file_name in archive.namelist():
                failing_source = f"{archive_path}:{file_name}"
                with archive.open(file_name, "r") as file_io:
                    yield Model(failing_source, file_io)
    except zipfile.BadZipFile as e:
        raise ModelBadZip(e, failing_source) from e

def get_source(self) -> Path:
    """Return the source path stored on this model."""
    source_path = self._source
    return source_path

def has_data(self) -> bool:
    """Tell whether a data stream is attached (i.e. ``_data`` is not ``None``)."""
    if self._data is None:
        return False
    return True

def get_data(self) -> IO[bytes]:
    """Return the binary stream attached to this model.

    Returns:
        The stream object that was supplied when the model was created.

    Raises:
        ModelDataEmpty: If no stream is attached (``_data`` is ``None``).
    """
    # Test identity with None, exactly as has_data() does. A truthiness test
    # would reject a stream object that is present but evaluates falsy,
    # making has_data() and get_data() disagree.
    if self._data is None:
        raise ModelDataEmpty("Model data is empty.")

    return self._data
32 changes: 18 additions & 14 deletions modelscan/modelscan.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,9 @@ def _scan_model(
)
except ModelBadZip as e:
logger.debug(
f"Skipping zip file {model.source}, due to error", e, exc_info=True
f"Skipping zip file {str(model.get_source())}, due to error",
e,
exc_info=True,
)
self._skipped.append(
ModelScanSkipped(
Expand All @@ -125,38 +127,43 @@ def _scan_model(
has_extracted = True
scanned = self._scan_source(extracted_model)
if not scanned:
if _is_zipfile(extracted_model.source, data=extracted_model.data):
if _is_zipfile(
extracted_model.get_source(),
data=extracted_model.get_data()
if extracted_model.has_data()
else None,
):
self._errors.append(
ModelScanError(
"ModelScan",
ErrorCategories.NESTED_ZIP,
"ModelScan does not support nested zip files.",
str(extracted_model.source),
str(extracted_model.get_source()),
)
)

# check if added to skipped already
all_skipped_files = [skipped.source for skipped in self._skipped]
if str(extracted_model.source) not in all_skipped_files:
if str(extracted_model.get_source()) not in all_skipped_files:
self._skipped.append(
ModelScanSkipped(
"ModelScan",
SkipCategories.SCAN_NOT_SUPPORTED,
f"Model Scan did not scan file",
str(extracted_model.source),
str(extracted_model.get_source()),
)
)

if not scanned and not has_extracted:
# check if added to skipped already
all_skipped_files = [skipped.source for skipped in self._skipped]
if str(model.source) not in all_skipped_files:
if str(model.get_source()) not in all_skipped_files:
self._skipped.append(
ModelScanSkipped(
"ModelScan",
SkipCategories.SCAN_NOT_SUPPORTED,
f"Model Scan did not scan file",
str(model.source),
str(model.get_source()),
)
)

Expand All @@ -167,26 +174,23 @@ def _scan_source(
scanned = False
for scan_class in self._scanners_to_run:
scanner = scan_class(self._settings) # type: ignore[operator]
scan_results = scanner.scan(
source=model.source,
data=model.data,
)
scan_results = scanner.scan(model)

if scan_results is not None:
scanned = True
logger.info(
f"Scanning {model.source} using {scanner.full_name()} model scan"
f"Scanning {model.get_source()} using {scanner.full_name()} model scan"
)
if scan_results.errors:
self._errors.extend(scan_results.errors)
elif scan_results.issues:
self._scanned.append(str(model.source))
self._scanned.append(str(model.get_source()))
self._issues.add_issues(scan_results.issues)

elif scan_results.skipped:
self._skipped.extend(scan_results.skipped)
else:
self._scanned.append(str(model.source))
self._scanned.append(str(model.get_source()))

return scanned

Expand Down
47 changes: 23 additions & 24 deletions modelscan/scanners/h5/scan.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import json
import logging
from pathlib import Path
from typing import IO, List, Union, Optional, Dict, Any
from typing import List, Optional, Dict, Any


try:
Expand All @@ -15,18 +14,18 @@
from modelscan.skip import ModelScanSkipped, SkipCategories
from modelscan.scanners.scan import ScanResults
from modelscan.scanners.saved_model.scan import SavedModelLambdaDetectScan
from modelscan.model import Model

logger = logging.getLogger("modelscan")


class H5LambdaDetectScan(SavedModelLambdaDetectScan):
def scan(
self,
source: Union[str, Path],
data: Optional[IO[bytes]] = None,
model: Model,
) -> Optional[ScanResults]:
if (
not Path(source).suffix
not model.get_source().suffix
in self._settings["scanners"][H5LambdaDetectScan.full_name()][
"supported_extensions"
]
Expand All @@ -46,7 +45,7 @@ def scan(
[],
)

if data:
if model.has_data():
logger.warning(
f"{self.full_name()} got data bytes. It only support direct file scanning."
)
Expand All @@ -58,21 +57,21 @@ def scan(
self.name(),
SkipCategories.H5_DATA,
f"{self.full_name()} got data bytes. It only support direct file scanning.",
str(source),
str(model.get_source()),
)
],
)

results = self._scan_keras_h5_file(source)
results = self._scan_keras_h5_file(model)
if results:
return self.label_results(results)
else:
return None

def _scan_keras_h5_file(self, source: Union[str, Path]) -> Optional[ScanResults]:
return None

def _scan_keras_h5_file(self, model: Model) -> Optional[ScanResults]:
machine_learning_library_name = "Keras"
if self._check_model_config(source):
operators_in_model = self._get_keras_h5_operator_names(source)
if self._check_model_config(model):
operators_in_model = self._get_keras_h5_operator_names(model)
if operators_in_model is None:
return None

Expand All @@ -84,15 +83,15 @@ def _scan_keras_h5_file(self, source: Union[str, Path]) -> Optional[ScanResults]
self.name(),
ErrorCategories.JSON_DECODE,
f"Not a valid JSON data",
str(source),
str(model.get_source()),
)
],
[],
)
return H5LambdaDetectScan._check_for_unsafe_tf_keras_operator(
module_name=machine_learning_library_name,
raw_operator=operators_in_model,
source=source,
model=model,
unsafe_operators=self._settings["scanners"][
SavedModelLambdaDetectScan.full_name()
]["unsafe_keras_operators"],
Expand All @@ -106,25 +105,23 @@ def _scan_keras_h5_file(self, source: Union[str, Path]) -> Optional[ScanResults]
self.name(),
SkipCategories.MODEL_CONFIG,
f"Model Config not found",
str(source),
str(model.get_source()),
)
],
)

def _check_model_config(self, model: Model) -> bool:
    """Report whether the HDF5 file at the model's source has ``model_config``.

    Opens the file read-only with ``h5py`` and looks for a ``model_config``
    key among its top-level attributes. An error is logged when the key is
    absent.

    Returns:
        ``True`` if the attribute exists, ``False`` otherwise.
    """
    with h5py.File(model.get_source(), "r") as model_hdf5:
        config_present = "model_config" in model_hdf5.attrs.keys()
        if not config_present:
            logger.error(f"Model Config not found in: {model.get_source()}")
        return config_present

def _get_keras_h5_operator_names(
self, source: Union[str, Path]
) -> Optional[List[Any]]:
def _get_keras_h5_operator_names(self, model: Model) -> Optional[List[Any]]:
# Todo: source isn't guaranteed to be a file

with h5py.File(source, "r") as model_hdf5:
with h5py.File(model.get_source(), "r") as model_hdf5:
try:
if not "model_config" in model_hdf5.attrs.keys():
return None
Expand All @@ -138,7 +135,9 @@ def _get_keras_h5_operator_names(
layer.get("config", {}).get("function", {})
)
except json.JSONDecodeError as e:
logger.error(f"Not a valid JSON data from source: {source}, error: {e}")
logger.error(
f"Not a valid JSON data from source: {model.get_source()}, error: {e}"
)
return ["JSONDecodeError"]

if lambda_layers:
Expand Down
Loading

0 comments on commit 705d17f

Please sign in to comment.