Skip to content

Commit

Permalink
Merge pull request #44 from espnet/huggingface
Browse files Browse the repository at this point in the history
Huggingface compatibility
  • Loading branch information
kamo-naoyuki authored Aug 26, 2021
2 parents df4970f + dc0e4d2 commit 40ba4b4
Show file tree
Hide file tree
Showing 4 changed files with 142 additions and 23 deletions.
50 changes: 34 additions & 16 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

Utilities managing the pretrained models created by [ESPnet](https://github.com/espnet/espnet). This function is inspired by the [Asteroid pretrained model function](https://github.com/mpariente/asteroid/blob/master/docs/source/readmes/pretrained_models.md).

- **From version 0.1.0, the huggingface models can be also used**: https://huggingface.co/models?filter=espnet
- Zenodo community: https://zenodo.org/communities/espnet/
- Registered models: [table.csv](espnet_model_zoo/table.csv)

Expand All @@ -22,7 +23,7 @@ pip install espnet_model_zoo
```

## Python API for inference
See the next section about `model_name`
`model_name` in the following section should be `huggingface_id` or one of the tags in the [table.csv](espnet_model_zoo/table.csv).

### ASR

Expand Down Expand Up @@ -100,13 +101,20 @@ d = ModelDownloader("~/.cache/espnet") # Specify cachedir
d = ModelDownloader() # <module_dir> is used as cachedir by default
```

To obtain a model, you need to give a model name, which is listed in [table.csv](espnet_model_zoo/table.csv).
To obtain a model, you need to give a `huggingface_id`model` or a tag , which is listed in [table.csv](espnet_model_zoo/table.csv).

```python
>>> d.download_and_unpack("kamo-naoyuki/mini_an4_asr_train_raw_bpe_valid.acc.best")
{"asr_train_config": <config path>, "asr_model_file": <model path>, ...}
```

You can specify the revision if it's huggingface_id giving with `@`:

```python
>>> d.download_and_unpack("kamo-naoyuki/mini_an4_asr_train_raw_bpe_valid.acc.best@<revision>")
{"asr_train_config": <config path>, "asr_model_file": <model path>, ...}
```

Note that if the model already exists, you can skip downloading and unpacking.

You can also get a model with certain conditions.
Expand All @@ -129,22 +137,22 @@ You can also obtain it from the URL directly.
d.download_and_unpack("https://zenodo.org/record/...")
```

If you need to use a local model file using this API, you can also give it.
If you need to use a local model file using this API, you can also give it.

```python
d.download_and_unpack("./some/where/model.zip")
```

In this case, the contents are also expanded in the cache directory,
but the model is identified by the file path,
so if you move the model to somewhere and unpack again,
it's treated as another model,
but the model is identified by the file path,
so if you move the model to somewhere and unpack again,
it's treated as another model,
thus the contents are expanded again at another place.

## Query model names

You can view the model names from our Zenodo community, https://zenodo.org/communities/espnet/,
or using `query()`. All information are written in [table.csv](espnet_model_zoo/table.csv).
You can view the model names from our Zenodo community, https://zenodo.org/communities/espnet/,
or using `query()`. All information are written in [table.csv](espnet_model_zoo/table.csv).

```python
d.query("name")
Expand All @@ -162,11 +170,11 @@ d.query("name", task="asr")

```sh
# Query model name
espnet_model_zoo_query task=asr corpus=wsj
espnet_model_zoo_query task=asr corpus=wsj
# Show all model name
espnet_model_zoo_query
# Query the other key
espnet_model_zoo_query --key url task=asr corpus=wsj
espnet_model_zoo_query --key url task=asr corpus=wsj
```
- `espnet_model_zoo_download`

Expand Down Expand Up @@ -197,19 +205,29 @@ cd egs2/wsj/asr1

## Register your model

### Huggingface
1. Upload your model using huggingface API

Coming soon...

1. Create a Pull Request to modify [table.csv](espnet_model_zoo/table.csv)

The models registered in this `table.csv`, the model are tested in the CI.
Indeed, the model can be downloaded without modification `table.csv`.
1. (Administrator does) Increment the third version number of [setup.py](setup.py), e.g. 0.0.3 -> 0.0.4
1. (Administrator does) Release new version


### Zenodo (Obsolete)

1. Upload your model to Zenodo

You need to [signup to Zenodo](https://zenodo.org/) and [create an access token](https://zenodo.org/account/settings/applications/tokens/new/) to upload models.
You can upload your own model by using `espnet_model_zoo_upload` command freely,
You can upload your own model by using `espnet_model_zoo_upload` command freely,
but we normally upload a model using [recipes](https://github.com/espnet/espnet/blob/master/egs2/TEMPLATE).

1. Create a Pull Request to modify [table.csv](espnet_model_zoo/table.csv)

You need to append your record at the last line.
1. (Administrator does) Increment the third version number of [setup.py](setup.py), e.g. 0.0.3 -> 0.0.4
1. (Administrator does) Release new version


## Update your model

If your model has some troubles, please modify the record at Zenodo directly or reupload a corrected file using `espnet_zenodo_upload` as another record.
102 changes: 97 additions & 5 deletions espnet_model_zoo/downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,14 @@
from typing import Union
import warnings

from filelock import FileLock
from huggingface_hub import snapshot_download
import pandas as pd
import requests
from tqdm import tqdm
import yaml

from espnet2.main_funcs.pack_funcs import find_path_and_change_it_recursive
from espnet2.main_funcs.pack_funcs import get_dict_from_cache
from espnet2.main_funcs.pack_funcs import unpack

Expand Down Expand Up @@ -173,12 +177,9 @@ def get_url(self, name: str = None, version: int = -1, **kwargs: str) -> str:
# If Specifying local file path
if name is not None and Path(name).exists() and len(kwargs) == 1:
url = str(Path(name).absolute())
# If no models satisfy the conditions, raise an error

else:
message = "Not found models:"
for key, value in kwargs.items():
message += f" {key}={value}"
raise RuntimeError(message)
return "huggingface.co"
else:
urls = self.data_frame[conditions]["url"]
if version < 0:
Expand Down Expand Up @@ -236,10 +237,90 @@ def unpack_local_file(self, name: str = None) -> Dict[str, Union[str, List[str]]
# Extract files from archived file
return unpack(filename, outdir)

def huggingface_download(
self, name: str = None, version: int = -1, quiet: bool = False, **kwargs: str
) -> str:
# Get huggingface_id from table.csv
if name is None:
names = self.query(key="name", **kwargs)
if len(names) == 0:
message = "Not found models:"
for key, value in kwargs.items():
message += f" {key}={value}"
raise RuntimeError(message)
if version < 0:
version = len(names) + version
name = list(names)[version]

if "@" in name:
huggingface_id, revision = name.split("@", 1)
else:
huggingface_id = name
revision = None

return snapshot_download(
huggingface_id,
revision=revision,
library_name="espnet",
cache_dir=self.cachedir,
)

@staticmethod
def _unpack_cache_dir_for_huggingfase(cache_dir: str):
meta_yaml = Path(cache_dir) / "meta.yaml"
lock_file = Path(cache_dir) / ".lock"
flag_file = Path(cache_dir) / ".done"

with meta_yaml.open("r", encoding="utf-8") as f:
d = yaml.safe_load(f)
assert isinstance(d, dict), type(d)

yaml_files = d["yaml_files"]
files = d["files"]
assert isinstance(yaml_files, dict), type(yaml_files)
assert isinstance(files, dict), type(files)

# Rewrite yaml_files for first case
with FileLock(lock_file):
if not flag_file.exists():
for key, value in yaml_files.items():
yaml_file = Path(cache_dir) / value
with yaml_file.open("r", encoding="utf-8") as f:
d = yaml.safe_load(f)
assert isinstance(d, dict), type(d)
for name in Path(cache_dir).glob("**/*"):
name = name.relative_to(Path(cache_dir))
d = find_path_and_change_it_recursive(
d, name, str(Path(cache_dir) / name)
)

with yaml_file.open("w", encoding="utf-8") as f:
yaml.safe_dump(d, f)

with flag_file.open("w"):
pass

retval = {}
for key, value in list(yaml_files.items()) + list(files.items()):
retval[key] = str(Path(cache_dir) / value)
return retval

def download(
self, name: str = None, version: int = -1, quiet: bool = False, **kwargs: str
) -> str:
url = self.get_url(name=name, version=version, **kwargs)

# For huggingface compatibility
if url in [
"https://huggingface.co/",
"https://huggingface.co",
"huggingface.co",
]:
# TODO(kamo): Support quiet
cache_dir = self.huggingface_download(name=name, version=version, **kwargs)
self._unpack_cache_dir_for_huggingfase(cache_dir)
return cache_dir

if not is_url(url) and Path(url).exists():
return url

Expand Down Expand Up @@ -281,6 +362,17 @@ def download_and_unpack(
if not is_url(url) and Path(url).exists():
return self.unpack_local_file(url)

# For huggingface compatibility
if url in [
"https://huggingface.co/",
"https://huggingface.co",
"huggingface.co",
]:
# download_and_unpack and download are same if huggingface case
# TODO(kamo): Support quiet
cache_dir = self.huggingface_download(name=name, version=version, **kwargs)
return self._unpack_cache_dir_for_huggingfase(cache_dir)

# Unpack to <cachedir>/<hash> in order to give an unique name
outdir = self.cachedir / str_to_hash(url)

Expand Down
1 change: 1 addition & 0 deletions espnet_model_zoo/table.csv
Original file line number Diff line number Diff line change
Expand Up @@ -113,3 +113,4 @@ dns_ins20,enh,Yen-Ju Lu/dns_ins20_enh_train_enh_blstm_tf_raw_valid.loss.best,htt
jv_openslr35,asr,jv_openslr35,https://zenodo.org/record/5090139/files/asr_train_asr_raw_bpe1000_valid.acc.best.zip?download=1,16000,jv,,1.8.1,0.9.10,,true
su_openslr36,asr,su_openslr36,https://zenodo.org/record/5090135/files/asr_train_asr_raw_bpe1000_valid.acc.best.zip?download=1,16000,su,,1.8.1,0.9.10,,true
ksponspeech,asr,Yushi Ueda/ksponspeech_asr_train_asr_conformer8_n_fft512_hop_length256_raw_kr_bpe2309_valid.acc.best,https://zenodo.org/record/5154341/files/asr_train_asr_conformer8_n_fft512_hop_length256_raw_kr_bpe2309_valid.acc.best.zip?download=1,16000,kr,,1.8.1,0.10.0,538393c,true
librispeech,asr,byan/librispeech_asr_train_asr_conformer_raw_bpe_batch_bins30000000_accum_grad3_optim_conflr0.001_sp,https://huggingface.co/,16000,en,,,,,true
12 changes: 10 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,15 @@


requirements = {
"install": ["pandas", "requests", "tqdm", "numpy", "espnet"],
"install": [
"pandas",
"requests",
"tqdm",
"numpy",
"espnet",
"huggingface_hub",
"filelock",
],
"setup": ["pytest-runner"],
"test": [
"pytest>=3.3.0",
Expand All @@ -29,7 +37,7 @@
dirname = os.path.dirname(__file__)
setup(
name="espnet_model_zoo",
version="0.0.0a31",
version="0.1.0",
url="http://github.com/espnet/espnet_model_zoo",
description="ESPnet Model Zoo",
long_description=open(os.path.join(dirname, "README.md"), encoding="utf-8").read(),
Expand Down

0 comments on commit 40ba4b4

Please sign in to comment.