Skip to content

Commit

Permalink
Remove Hugging Face Support
Browse files Browse the repository at this point in the history
  • Loading branch information
mehrinkiani committed Nov 21, 2023
1 parent 045ddb5 commit 0c30962
Show file tree
Hide file tree
Showing 10 changed files with 647 additions and 760 deletions.
13 changes: 3 additions & 10 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ Despite this, models are not scanned with the rigor of a PDF file in your inbox.

This needs to change, and proper tooling is the first step.

![ModelScan Preview](https://github.com/protectai/modelscan/raw/main/imgs/modelscan_hf_demo.gif)
![ModelScan Preview](/imgs/modelscan-unsafe-model.gif)

ModelScan is an open source project that scans models to determine if they contain
unsafe code. It is the first model scanning tool to support multiple model formats.
Expand All @@ -34,12 +34,6 @@ With it installed, scan a model:
modelscan -p /path/to/model_file.h5
```

Or if you want to scan all the models in a Hugging Face Repository:

```bash
modelscan -hf owner/model-repository-name
```

## Why You Should Scan Models

Models are often created from automated pipelines, others may come from a data scientist’s laptop. In either case the model needs to move from one machine to another before it is used. That process of saving a model to disk is called serialization.
Expand Down Expand Up @@ -75,8 +69,7 @@ ModelScan ranks the unsafe code as:
* MEDIUM
* LOW

Things are ranked consistently if the models are local or stored in Hugging Face
![ModelScan Flow Chart](https://github.com/protectai/modelscan/raw/main/imgs/model_scan_flow_chart.png)
![ModelScan Flow Chart](/imgs/model_scan_flow_chart.png)

If an issue is detected, reach out to the author's of the model immediately to determine the cause.

Expand Down Expand Up @@ -123,7 +116,7 @@ ModelScan supports the following arguments via the CLI:
| ```modelscan -h ``` | -h or --help | View usage help |
| ```modelscan -v ``` | -v or --version | View version information |
| ```modelscan -p /path/to/model_file```| -p or --path | Scan a locally stored model |
| ```modelscan -hf repo/model_file``` |-hf or --huggingface | Scan all the models in a Hugging Face model repository|


Remember models are just like any other form of digital media, you should scan content from any untrusted source before use.

Expand Down
Binary file modified imgs/model_scan_flow_chart.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added imgs/modelscan-unsafe-model.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file removed imgs/modelscan_hf_demo.gif
Binary file not shown.
16 changes: 3 additions & 13 deletions modelscan/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,8 @@ def format_usage(self, ctx: click.Context, formatter: click.HelpFormatter) -> No
# @click.option(
# "-u", "--url", type=str, default=None, help="URL to the file or folder to scan"
# )
@click.option(
"-hf",
"--huggingface",
type=str,
default=None,
help="Name of the Hugging Face model to scan",
)


@click.option(
"-l",
"--log",
Expand All @@ -63,7 +58,6 @@ def cli(
ctx: click.Context,
log: str,
# url: Optional[str],
huggingface: Optional[str],
path: Optional[str],
show_skipped: bool,
) -> int:
Expand All @@ -82,12 +76,8 @@ def cli(
modelscan.scan_path(pathlibPath)
# elif url is not None:
# modelscan.scan_url(url)
elif huggingface is not None:
modelscan.scan_huggingface_model(huggingface)
else:
raise click.UsageError(
"Command line must include either a path or a Hugging Face model"
)
raise click.UsageError("Command line must include a path")
ConsoleReport.generate(
modelscan.issues,
modelscan.errors,
Expand Down
48 changes: 1 addition & 47 deletions modelscan/modelscan.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from modelscan.models.keras.scan import KerasScan
from modelscan.models.scan import ScanBase
from modelscan.tools.utils import _is_zipfile
from modelscan.utils import fetch_huggingface_repo_files, read_huggingface_file


logger = logging.getLogger("modelscan")

Expand Down Expand Up @@ -61,52 +61,6 @@ def _scan_directory(self, directory_path: Path) -> None:
if not path.is_dir():
self.scan_path(path)

def scan_huggingface_model(self, repo_id: str) -> None:
# List model files
file_names = fetch_huggingface_repo_files(repo_id)
if not file_names:
logger.error(f"Hugging face repo {repo_id} didn't return any files")
return

for file_name in file_names:
fq_file_name = f"{repo_id}/{file_name}"
file_ext = os.path.splitext(file_name)[1]
if (
file_ext not in self.supported_extensions
and file_ext not in self.supported_zip_extensions
):
logger.debug(f"Skipping: {fq_file_name} is not supported")
self._skipped.append(fq_file_name)
continue

raw_bytes = read_huggingface_file(repo_id, file_name)
if not raw_bytes:
logger.debug(f"Skipping: Failed to read {fq_file_name}")
self._skipped.append(fq_file_name)
continue

data = io.BytesIO(raw_bytes)
if (
_is_zipfile(source=fq_file_name, data=data)
or file_ext in self.supported_zip_extensions
):
try:
self._scan_zip(source=fq_file_name, data=data)
except:
logger.debug(f"Skipped: Failed to read")
else:
self._scan_source(source=fq_file_name, extension=file_ext, data=data)

def scan_url(self, url: str) -> None:
# Todo: before it was just scanning scanning_pickle_bytes
# We need to validate this url and determine what type of file it is
# self._scan_bytes(
# data=io.BytesIO(_http_get(url)),
# source=url,
# extension=file_ext,
# )
pass

def _scan_source(
self,
source: Union[str, Path],
Expand Down
30 changes: 1 addition & 29 deletions modelscan/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import requests
import logging
import json
from typing import Union, Optional, Generator, List, Tuple
from typing import Union

logger = logging.getLogger("modelscan")

Expand All @@ -24,30 +23,3 @@ def fetch_url(url: str, allow_redirect: bool = False) -> Union[None, bytes]:
except Exception as e:
logger.error("Unexpected error during request to %s: %s", url, str(e))
return None


def fetch_huggingface_repo_files(
repo_id: str,
) -> Union[None, List[str]]:
# Return list of model files
url = f"https://huggingface.co/api/models/{repo_id}"
data = fetch_url(url)
if not data:
return None

try:
model = json.loads(data.decode("utf-8"))
filenames = []
for sibling in model.get("siblings", []):
if sibling.get("rfilename"):
filenames.append(sibling.get("rfilename"))
return filenames
except json.decoder.JSONDecodeError as e:
logger.error(f"Failed to parse response for HuggingFace model repo {repo_id}")

return None


def read_huggingface_file(repo_id: str, file_name: str) -> Union[None, bytes]:
url = f"https://huggingface.co/{repo_id}/resolve/main/{file_name}"
return fetch_url(url, True)
Loading

0 comments on commit 0c30962

Please sign in to comment.