-
Notifications
You must be signed in to change notification settings - Fork 48
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[v0.1.3] Added Hugging Face Hub support. (#26)
* [v0.1.3] Added Hugging Face Hub support. * [v0.1.3] Added warning if huggingface-hub isn't installed but a huggingface-hub function was used. * [v0.1.3] Cleaned up PR for release + removed branch wording from readme. * [v0.1.3] Add checkpoints to the huggingface repos. * [v0.1.3] Added a requirement for hfhub >= 0.21.0 and accompanying error msg.
- Loading branch information
Showing
6 changed files
with
107 additions
and
11 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# All rights reserved. | ||
|
||
# This source code is licensed under the license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
# -------------------------------------------------------- | ||
# References: | ||
# https://github.com/facebookresearch/hiera/pull/25 | ||
# -------------------------------------------------------- | ||
|
||
import importlib.util | ||
import importlib.metadata | ||
from packaging import version | ||
|
||
import inspect | ||
|
||
def is_huggingface_hub_available(): | ||
available: bool = importlib.util.find_spec("huggingface_hub") is not None | ||
|
||
if not available: | ||
return False | ||
else: | ||
hfversion = importlib.metadata.version("huggingface_hub") | ||
return version.parse(hfversion) >= version.parse("0.21.0") | ||
|
||
|
||
if is_huggingface_hub_available(): | ||
from huggingface_hub import PyTorchModelHubMixin | ||
else: | ||
# Empty class in case modelmixins dont exist | ||
class PyTorchModelHubMixin: | ||
error_str: str = 'This feature requires "huggingface-hub >= 0.21.0" to be installed.' | ||
|
||
@classmethod | ||
def from_pretrained(cls, *args, **kwdargs): | ||
raise RuntimeError(cls.error_str) | ||
|
||
@classmethod | ||
def save_pretrained(cls, *args, **kwdargs): | ||
raise RuntimeError(cls.error_str) | ||
|
||
@classmethod | ||
def push_to_hub(cls, *args, **kwdargs): | ||
raise RuntimeError(cls.error_str) | ||
|
||
|
||
|
||
# Saves the input args to the function as self.config, also allows | ||
# loading a config instead of kwdargs. | ||
def has_config(func): | ||
signature = inspect.signature(func) | ||
|
||
def wrapper(self, *args, **kwdargs): | ||
if "config" in kwdargs: | ||
config = kwdargs["config"] | ||
del kwdargs["config"] | ||
kwdargs.update(**config) | ||
|
||
self.config = { | ||
k: v.default if (i-1) >= len(args) else args[i-1] | ||
for i, (k, v) in enumerate(signature.parameters.items()) | ||
if v.default is not inspect.Parameter.empty | ||
} | ||
self.config.update(**kwdargs) | ||
|
||
func(self, **kwdargs) | ||
return wrapper |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters