Improve support for scanning HuggingFace models. (#39)

iamfaisalkhan authored Aug 23, 2023
1 parent a52bc33 commit 19b0429

Showing 10 changed files with 280 additions and 190 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -21,6 +21,6 @@ repos:
       hooks:
         - id: mypy
           args: ["--ignore-missing-imports", "--strict", "--check-untyped-defs"]
-          additional_dependencies: ["click>=8.1.3","numpy==1.24.0", "pytest==7.4.0"]
+          additional_dependencies: ["click>=8.1.3","numpy==1.24.0", "pytest==7.4.0", "types-requests>=1.26"]
           exclude: notebooks

2 changes: 1 addition & 1 deletion Makefile
@@ -1,7 +1,7 @@
 VERSION ?= $(shell dunamai from git --style pep440 --format "{base}.dev{distance}+{commit}")
 
 install-dev:
-	poetry install --with dev --extras "tensorflow h5py"
+	poetry install --with dev --with test --extras "tensorflow h5py"
 	pre-commit install
 
 install:
14 changes: 13 additions & 1 deletion modelscan/cli.py
@@ -44,13 +44,20 @@
     default="INFO",
     help="level of log messages to display (default: INFO)",
 )
+@click.option(
+    "--show-skipped",
+    is_flag=True,
+    default=False,
+    help="Print a list of files that were skipped during the scan",
+)
 @click.pass_context
 def cli(
     ctx: click.Context,
     log: str,
     # url: Optional[str],
     huggingface: Optional[str],
     path: Optional[str],
+    show_skipped: bool,
 ) -> int:
     logger.setLevel(logging.INFO)
     logger.addHandler(logging.StreamHandler(stream=sys.stdout))
@@ -74,7 +81,12 @@ def cli(
         raise click.UsageError(
             "Command line must include either a path or a Hugging Face model"
         )
-        ConsoleReport.generate(modelscan.issues, modelscan.errors)
+        ConsoleReport.generate(
+            modelscan.issues,
+            modelscan.errors,
+            modelscan._skipped,
+            show_skipped=show_skipped,
+        )
         return 0
 
     except click.UsageError as e:
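A quick way to exercise the new flag end to end is click's test runner. A hedged sketch: the cli entry point and --show-skipped come from this diff, but the --huggingface option name is an assumption (its decorator sits outside the hunk shown above), and the repo id is a placeholder.

    # Hypothetical smoke test for the new --show-skipped flag.
    from click.testing import CliRunner

    from modelscan.cli import cli

    runner = CliRunner()
    # --huggingface is assumed; only --show-skipped is visible in this diff.
    result = runner.invoke(cli, ["--huggingface", "some-org/some-model", "--show-skipped"])
    print(result.output)  # summary report, plus the skipped-files list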
86 changes: 50 additions & 36 deletions modelscan/modelscan.py
@@ -13,7 +13,8 @@
 from modelscan import models
 from modelscan.models.keras.scan import KerasScan
 from modelscan.models.scan import ScanBase
-from modelscan.tools.utils import _http_get, _is_zipfile
+from modelscan.tools.utils import _is_zipfile
+from modelscan.utils import fetch_huggingface_repo_files, read_huggingface_file
 
 logger = logging.getLogger("modelscan")
 
@@ -33,8 +34,9 @@ def __init__(self) -> None:
         self.supported_extensions = set()
         for scan in self.supported_model_scans:
             self.supported_extensions.update(scan.supported_extensions())
-
+        self.supported_zip_extensions = set([".zip", ".npz"])
         logger.debug(f"Supported model files {self.supported_extensions}")
+        logger.debug(f"Supported zip model files {self.supported_zip_extensions}")
 
         # Output
         self._issues = Issues()
@@ -44,7 +46,7 @@
     def scan_path(self, path: Path) -> None:
         if path.is_dir():
             self._scan_directory(path)
-        elif _is_zipfile(path) or path.suffix in self._supported_zip_extensions():
+        elif _is_zipfile(path) or path.suffix in self.supported_zip_extensions:
             is_keras_file = path.suffix in KerasScan.supported_extensions()
             if is_keras_file:
                 self._scan_source(source=path, extension=path.suffix)
@@ -60,31 +62,39 @@ def _scan_directory(self, directory_path: Path) -> None:
 
     def scan_huggingface_model(self, repo_id: str) -> None:
         # List model files
-        model = json.loads(
-            _http_get(f"https://huggingface.co/api/models/{repo_id}").decode("utf-8")
-        )
-        file_names = [
-            file_name
-            for file_name in (sibling.get("rfilename") for sibling in model["siblings"])
-            if file_name is not None
-        ]
+        file_names = fetch_huggingface_repo_files(repo_id)
+        if not file_names:
+            logger.error(f"Hugging face repo {repo_id} didn't return any files")
+            return
 
         # Scan model files
         for file_name in file_names:
+            fq_file_name = f"{repo_id}/{file_name}"
             file_ext = os.path.splitext(file_name)[1]
-            url = f"https://huggingface.co/{repo_id}/resolve/main/{file_name}"
-            data = io.BytesIO(_http_get(url))
             if (
-                _is_zipfile(source=url, data=data)
-                or file_ext in self._supported_zip_extensions()
+                file_ext not in self.supported_extensions
+                and file_ext not in self.supported_zip_extensions
             ):
+                logger.debug(f"Skipping: {fq_file_name} is not supported")
+                self._skipped.append(fq_file_name)
+                continue
+
+            raw_bytes = read_huggingface_file(repo_id, file_name)
+            if not raw_bytes:
+                logger.debug(f"Skipping: Failed to read {fq_file_name}")
+                self._skipped.append(fq_file_name)
+                continue
+
+            data = io.BytesIO(raw_bytes)
+            if (
+                _is_zipfile(source=fq_file_name, data=data)
+                or file_ext in self.supported_zip_extensions
+            ):
-                self._scan_zip(source=url, data=data)
+                try:
+                    self._scan_zip(source=fq_file_name, data=data)
+                except:
+                    logger.debug(f"Skipped: Failed to read")
             else:
-                self._scan_source(
-                    source=url,
-                    extension=file_ext,
-                    data=data,
-                )
+                self._scan_source(source=fq_file_name, extension=file_ext, data=data)
 
     def scan_url(self, url: str) -> None:
         # Todo: before it was just scanning scanning_pickle_bytes
@@ -121,20 +131,20 @@ def _scan_source(
     def _scan_zip(
         self, source: Union[str, Path], data: Optional[IO[bytes]] = None
     ) -> None:
-        with zipfile.ZipFile(data or source, "r") as zip:
-            file_names = zip.namelist()
-            for file_name in file_names:
-                file_ext = os.path.splitext(file_name)[1]
-                with zip.open(file_name, "r") as file_io:
-                    self._scan_source(
-                        source=f"{source}:{file_name}",
-                        extension=file_ext,
-                        data=file_io,
-                    )
-
-    @staticmethod
-    def _supported_zip_extensions() -> List[str]:
-        return [".zip", ".npz"]
+        try:
+            with zipfile.ZipFile(data or source, "r") as zip:
+                file_names = zip.namelist()
+                for file_name in file_names:
+                    file_ext = os.path.splitext(file_name)[1]
+                    with zip.open(file_name, "r") as file_io:
+                        self._scan_source(
+                            source=f"{source}:{file_name}",
+                            extension=file_ext,
+                            data=file_io,
+                        )
+        except zipfile.BadZipFile as e:
+            logger.debug(f"Skipping zip file {source}, due to error", e, exc_info=True)
+            self._skipped.append(str(source))
 
     @property
     def issues(self) -> Issues:
@@ -143,3 +153,7 @@ def issues(self) -> Issues:
     @property
     def errors(self) -> List[Error]:
         return self._errors
+
+    @property
+    def skipped(self) -> List[str]:
+        return self._skipped
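Taken together, the reworked flow can be driven from Python roughly like this. A hedged sketch: the ModelScan class name and import path are assumptions (the class statement is outside the hunks above); scan_huggingface_model, issues, errors, and the new skipped property are taken from this diff, and the repo id is a placeholder.

    # Hedged usage sketch of the new scan-and-skip flow.
    from modelscan.modelscan import ModelScan  # import path assumed

    scanner = ModelScan()
    scanner.scan_huggingface_model("some-org/some-model")  # placeholder repo id

    print(scanner.issues.group_by_severity())  # group_by_severity appears in reports.py below
    print(scanner.errors)
    for file_name in scanner.skipped:  # new property added in this commit
        print("skipped:", file_name)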
14 changes: 14 additions & 0 deletions modelscan/reports.py
@@ -22,6 +22,8 @@ def __init__(self) -> None:
     def generate(
         issues: Issues,
         errors: List[Error],
+        skipped: List[str],
+        show_skipped: bool = False,
     ) -> Optional[str]:
         """
         Generate report for the given codebase.
@@ -39,6 +41,8 @@ class ConsoleReport(Report):
     def generate(
         issues: Issues,
         errors: List[Error],
+        skipped: List[str],
+        show_skipped: bool = False,
     ) -> None:
         issues_by_severity = issues.group_by_severity()
         print("\n[blue]--- Summary ---")
@@ -67,3 +71,13 @@ def generate(
         for index, error in enumerate(errors):
             print(f"\nError {index+1}:")
             print(str(error))
+
+        if len(skipped) > 0:
+            print("\n[blue]--- Skipped --- ")
+            print(
+                f"\nTotal skipped: {len(skipped)} - run with --show-skipped to see the full list."
+            )
+            if show_skipped:
+                print(f"\nSkipped files list:\n")
+                for file_name in skipped:
+                    print(str(file_name))
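The report call then just threads those values through. A minimal sketch, assuming the scanner object from the sketch above; the signature is as shown in this diff.

    from modelscan.reports import ConsoleReport

    ConsoleReport.generate(
        scanner.issues,
        scanner.errors,
        scanner.skipped,
        show_skipped=True,  # off by default; the CLI gates it behind --show-skipped
    )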
23 changes: 0 additions & 23 deletions modelscan/tools/utils.py
@@ -92,26 +92,3 @@ def get_magic_number(data: IO[bytes]) -> Optional[int]:
         except ValueError:
             return None
     return None
-
-
-# TODO we can rewrite this function with better error logging and move it
-# modelscan/tools/utils.py
-def _http_get(url: str) -> bytes:
-    parsed_url = urllib.parse.urlparse(url)
-    path_and_query = parsed_url.path + (
-        "?" + parsed_url.query if len(parsed_url.query) > 0 else ""
-    )
-
-    conn = http.client.HTTPSConnection(parsed_url.netloc)
-    try:
-        conn.request("GET", path_and_query)
-        response = conn.getresponse()
-        if response.status == 302:  # Follow redirections
-            return _http_get(response.headers["Location"])
-        elif response.status >= 400:
-            raise RuntimeError(
-                f"HTTP {response.status} ({response.reason}) calling GET {parsed_url.scheme}://{parsed_url.netloc}{path_and_query}"
-            )
-        return response.read()
-    finally:
-        conn.close()
53 changes: 53 additions & 0 deletions modelscan/utils.py
@@ -0,0 +1,53 @@
+import requests
+import logging
+import json
+from typing import Union, Optional, Generator, List, Tuple
+
+logger = logging.getLogger("modelscan")
+
+
+def fetch_url(url: str, allow_redirect: bool = False) -> Union[None, bytes]:
+    try:
+        response = requests.get(url, allow_redirects=allow_redirect, timeout=10)
+        response.raise_for_status()
+        return response.content
+    except requests.exceptions.HTTPError as e:
+        logger.error("Error with request: %s", e.response.status_code)
+        logger.error(e.response.text)
+        return None
+    except requests.exceptions.JSONDecodeError as e:
+        logger.error("Response was not valid JSON")
+        return None
+    except requests.exceptions.Timeout as e:
+        logger.error("Request timed out")
+        return None
+    except Exception as e:
+        logger.error("Unexpected error during request to %s: %s", url, str(e))
+        return None
+
+
+def fetch_huggingface_repo_files(
+    repo_id: str,
+) -> Union[None, List[str]]:
+    # Return list of model files
+    url = f"https://huggingface.co/api/models/{repo_id}"
+    data = fetch_url(url)
+    if not data:
+        return None
+
+    try:
+        model = json.loads(data.decode("utf-8"))
+        filenames = []
+        for sibling in model.get("siblings", []):
+            if sibling.get("rfilename"):
+                filenames.append(sibling.get("rfilename"))
+        return filenames
+    except json.decoder.JSONDecodeError as e:
+        logger.error(f"Failed to parse response for HuggingFace model repo {repo_id}")
+
+    return None
+
+
+def read_huggingface_file(repo_id: str, file_name: str) -> Union[None, bytes]:
+    url = f"https://huggingface.co/{repo_id}/resolve/main/{file_name}"
+    return fetch_url(url, True)
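The new helpers are self-contained enough to try on their own. A quick sketch: the function names and signatures are from modelscan/utils.py above; only the repo id is a placeholder.

    # Exercise the new Hugging Face helpers directly.
    from modelscan.utils import fetch_huggingface_repo_files, read_huggingface_file

    repo_id = "some-org/some-model"  # placeholder
    files = fetch_huggingface_repo_files(repo_id)
    if files is None:
        print("repo listing failed (see log output)")
    else:
        for name in files:
            blob = read_huggingface_file(repo_id, name)
            print(name, "<unreadable>" if blob is None else f"{len(blob)} bytes")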