Skip to content

Commit

Permalink
* updates to the model
Browse files Browse the repository at this point in the history
  • Loading branch information
asofter committed Mar 13, 2024
1 parent 80fc9db commit a997b6d
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 26 deletions.
9 changes: 6 additions & 3 deletions modelscan/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,14 @@ def __init__(self, e: zipfile.BadZipFile, source: str):
super().__init__(f"Bad Zip File: {e}")


@dataclass
class Model:
source: Union[str, Path]
source: Path
data: Optional[IO[bytes]] = None

def __init__(self, source: Union[str, Path], data: Optional[IO[bytes]] = None):
self.source = Path(source)
self.data = data

@staticmethod
def from_path(path: Path) -> "Model":
if not Path.exists(path):
Expand All @@ -38,7 +41,7 @@ def get_zip_files(
) -> Generator["Model", None, None]:
if (
not _is_zipfile(self.source)
or Path(self.source).suffix not in supported_extensions
and Path(self.source).suffix not in supported_extensions
):
return

Expand Down
45 changes: 22 additions & 23 deletions modelscan/modelscan.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,36 +72,35 @@ def scan(
self._scanned = []
self._input_path = str(path)
pathlibPath = Path().cwd() if path == "." else Path(path).absolute()
self._scan_path(Path(pathlibPath))
return self._generate_results()

def _scan_path(
self,
path: Path,
) -> None:
try:
model = Model.from_path(path)
model = Model.from_path(Path(pathlibPath))
self._scan_model(model)
except ModelPathNotValid as e:
logger.exception(e)
self._errors.append(
ModelScanError(
"ModelScan", ErrorCategories.PATH, "Path is not valid", str(path)
)
)
return

return self._generate_results()

def _scan_model(
self,
model: Model,
) -> None:
scanned = self._scan_source(model)
if scanned:
return

has_extracted = False
extracted_models = model.get_files()
for extracted_model in extracted_models:
has_extracted = True
self._scan_source(extracted_model)
if not scanned:
extracted_models = model.get_files()
for extracted_model in extracted_models:
has_extracted = True
self._scan_model(extracted_model)

if has_extracted:
return
if has_extracted:
return

try:
extracted_models = model.get_zip_files(
Expand All @@ -121,43 +120,43 @@ def _scan_path(
)
return

has_extracted = False
for extracted_model in extracted_models:
has_extracted = True
scanned = self._scan_source(extracted_model)

if not scanned:
if _is_zipfile(extracted_model.source, data=extracted_model.data):
self._errors.append(
ModelScanError(
"ModelScan",
ErrorCategories.NESTED_ZIP,
"ModelScan does not support nested zip files.",
extracted_model.source,
str(extracted_model.source),
)
)

# check if added to skipped already
all_skipped_files = [skipped.source for skipped in self._skipped]
if extracted_model.source not in all_skipped_files:
if str(extracted_model.source) not in all_skipped_files:
self._skipped.append(
ModelScanSkipped(
"ModelScan",
SkipCategories.SCAN_NOT_SUPPORTED,
f"Model Scan did not scan file",
extracted_model.source,
str(extracted_model.source),
)
)

if not has_extracted:
if not scanned and not has_extracted:
# check if added to skipped already
all_skipped_files = [skipped.source for skipped in self._skipped]
if str(path) not in all_skipped_files:
if str(model.source) not in all_skipped_files:
self._skipped.append(
ModelScanSkipped(
"ModelScan",
SkipCategories.SCAN_NOT_SUPPORTED,
f"Model Scan did not scan file",
str(path),
str(model.source),
)
)

Expand Down

0 comments on commit a997b6d

Please sign in to comment.