[v0.1.3] Added Hugging Face Hub support. (#26)
* [v0.1.3] Added Hugging Face Hub support.

* [v0.1.3] Added warning if huggingface-hub isn't installed but a huggingface-hub function was used.

* [v0.1.3] Cleaned up PR for release + removed branch wording from readme.

* [v0.1.3] Add checkpoints to the huggingface repos.

* [v0.1.3] Added a requirement for hfhub >= 0.21.0 and accompanying error msg.
dbolya authored Mar 2, 2024
1 parent 1f825a3 commit 42d33a7
Showing 6 changed files with 107 additions and 11 deletions.
6 changes: 5 additions & 1 deletion CHANGELOG.md
@@ -1,3 +1,7 @@
### **[2024.02.07]** v0.1.3
- Added support for saving and loading models to the Hugging Face Hub, if `huggingface_hub` is installed.
- Most Hiera models have been uploaded to Hugging Face.

### **[2023.07.20]** v0.1.2
- Released the full model zoo.
- Added MAE functionality to the video models.
@@ -8,4 +12,4 @@
- Released all in1k finetuned models.

### **[2023.06.01]** v0.1.0
- Initial Release.
- Initial Release.
20 changes: 18 additions & 2 deletions README.md
@@ -63,6 +63,8 @@ python setup.py build develop

## Model Zoo

### Torch Hub

Here we provide model checkpoints for Hiera. Each model listed is accessible on [torch hub](https://pytorch.org/docs/stable/hub.html) even without the `hiera-transformer` package installed, e.g. the following initializes a base model pretrained and finetuned on ImageNet-1k:
```py
model = torch.hub.load("facebookresearch/hiera", model="hiera_base_224", pretrained=True, checkpoint="mae_in1k_ft_in1k")
@@ -74,8 +76,21 @@ model = torch.hub.load("facebookresearch/hiera", model="mae_hiera_base_224", pre
```
**Note:** Our MAE models were trained with a _normalized pixel loss_. That means that the patches were normalized before the network had to predict them. If you want to visualize the predictions, you'll have to unnormalize them using the visible patches (which might work but wouldn't be perfect) or using the ground truth. For more model names and corresponding checkpoint names, see below.
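
For instance, the ground-truth route could look like the following hedged sketch (`unnormalize_pred` is a hypothetical helper, not part of this repo):
```py
import torch

# Invert MAE's per-patch normalization using ground-truth patch statistics.
# `pred` and `target` are [..., pixels_per_patch] tensors of flattened patches.
def unnormalize_pred(pred: torch.Tensor, target: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    mean = target.mean(dim=-1, keepdim=True)
    var = target.var(dim=-1, keepdim=True)
    return pred * (var + eps) ** 0.5 + mean
```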

### Hugging Face Hub

This repo also has [🤗 hub](https://huggingface.co/docs/hub/index) support. With the `hiera-transformer` and `huggingface-hub` packages installed, you can simply run, e.g.,
```py
from hiera import Hiera
model = Hiera.from_pretrained("facebook/hiera_base_224.mae_in1k_ft_in1k") # mae pt then in1k ft'd model
model = Hiera.from_pretrained("facebook/hiera_base_224.mae_in1k") # just mae pt, no ft
```
to load a model. Use `<model_name>.<checkpoint_name>` from the model zoo below.

If you want to save a model, use `model.config` as the config, e.g.,
```py
model.save_pretrained("hiera-base-224", config=model.config)
```
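
The same mixin also provides `push_to_hub`; assuming you are authenticated with the Hub and own the target repo (the repo id below is only a placeholder), uploading could look like:
```py
model.push_to_hub("your-username/hiera-base-224", config=model.config)
```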

**Note:** the speeds listed here were benchmarked _without_ PyTorch's optimized [scaled dot product attention](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html). If using PyTorch 2.0 or above, your inference speed will probably be faster than what's listed here.
### Image Models
| Model | Model Name | Pretrained Models<br>(IN-1K MAE) | Finetuned Models<br>(IN-1K Supervised) | IN-1K<br>Top-1 (%) | A100 fp16<br>Speed (im/s) |
|----------|-----------------------|----------------------------------|----------------------------------------|:------------------:|:-------------------------:|
@@ -97,6 +112,7 @@ Each model inputs a 224x224 image.

Each model inputs 16 224x224 frames with a temporal stride of 4.

**Note:** the speeds listed here were benchmarked _without_ PyTorch's optimized [scaled dot product attention](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html). If using PyTorch 2.0 or above, your inference speed will probably be faster than what's listed here.

## Usage

@@ -116,7 +132,7 @@ See [examples](https://github.com/facebookresearch/hiera/tree/main/examples) for

See [examples/inference](https://github.com/facebookresearch/hiera/blob/main/examples/inference.ipynb) for an example of how to prepare the data for inference.

Instantiate a model with either [torch hub](#model-zoo) or by [installing hiera](#installing-from-source) and running:
Instantiate a model with either [torch hub](#model-zoo) or [🤗 hub](#model-zoo) or by [installing hiera](#installing-from-source) and running:
```py
import hiera
model = hiera.hiera_base_224(pretrained=True, checkpoint="mae_in1k_ft_in1k")
```
67 changes: 67 additions & 0 deletions hiera/hfhub.py
@@ -0,0 +1,67 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# https://github.com/facebookresearch/hiera/pull/25
# --------------------------------------------------------

import importlib.util
import importlib.metadata
from packaging import version

import inspect

def is_huggingface_hub_available():
    available: bool = importlib.util.find_spec("huggingface_hub") is not None

    if not available:
        return False
    else:
        hfversion = importlib.metadata.version("huggingface_hub")
        return version.parse(hfversion) >= version.parse("0.21.0")


if is_huggingface_hub_available():
    from huggingface_hub import PyTorchModelHubMixin
else:
    # Empty stand-in class in case the real model mixin doesn't exist
    class PyTorchModelHubMixin:
        error_str: str = 'This feature requires "huggingface-hub >= 0.21.0" to be installed.'

        @classmethod
        def from_pretrained(cls, *args, **kwdargs):
            raise RuntimeError(cls.error_str)

        @classmethod
        def save_pretrained(cls, *args, **kwdargs):
            raise RuntimeError(cls.error_str)

        @classmethod
        def push_to_hub(cls, *args, **kwdargs):
            raise RuntimeError(cls.error_str)



# Saves the input args to the function as self.config, also allows
# loading a config instead of kwdargs.
def has_config(func):
    signature = inspect.signature(func)

    def wrapper(self, *args, **kwdargs):
        if "config" in kwdargs:
            config = kwdargs["config"]
            del kwdargs["config"]
            kwdargs.update(**config)

        self.config = {
            k: v.default if (i-1) >= len(args) else args[i-1]
            for i, (k, v) in enumerate(signature.parameters.items())
            if v.default is not inspect.Parameter.empty
        }
        self.config.update(**kwdargs)

        func(self, **kwdargs)
    return wrapper
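
For illustration, here is a hedged toy sketch of what `@has_config` captures (the `Toy` class is hypothetical and not part of this repo; only keyword arguments are used):
```py
class Toy:
    @has_config
    def __init__(self, width: int = 8, depth: int = 2):
        self.width, self.depth = width, depth

t = Toy(width=16)
print(t.config)            # {'width': 16, 'depth': 2} -- defaults filled in, overrides applied
t2 = Toy(config=t.config)  # rebuild an equivalent object from a saved config dict
```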
13 changes: 9 additions & 4 deletions hiera/hiera.py
@@ -20,7 +20,7 @@

import math
from functools import partial
from typing import List, Tuple, Callable, Optional
from typing import List, Tuple, Callable, Optional, Union

import torch
import torch.nn as nn
@@ -29,7 +29,7 @@
from timm.models.layers import DropPath, Mlp

from .hiera_utils import pretrained_model, conv_nd, do_pool, do_masked_conv, Unroll, Reroll

from .hfhub import has_config, PyTorchModelHubMixin


class MaskUnitAttention(nn.Module):
@@ -204,7 +204,8 @@ def forward(
        return x


class Hiera(nn.Module):
class Hiera(nn.Module, PyTorchModelHubMixin):
    @has_config
    def __init__(
        self,
        input_size: Tuple[int, ...] = (224, 224),
@@ -225,13 +226,17 @@ def __init__(
        patch_padding: Tuple[int, ...] = (3, 3),
        mlp_ratio: float = 4.0,
        drop_path_rate: float = 0.0,
        norm_layer: nn.Module = partial(nn.LayerNorm, eps=1e-6),
        norm_layer: Union[str, nn.Module] = "LayerNorm",
        head_dropout: float = 0.0,
        head_init_scale: float = 0.001,
        sep_pos_embed: bool = False,
    ):
        super().__init__()

        # Do it this way to ensure that the init args are all PoD (for config usage)
        if isinstance(norm_layer, str):
            norm_layer = partial(getattr(nn, norm_layer), eps=1e-6)

        depth = sum(stages)
        self.patch_stride = patch_stride
        self.tokens_spatial_shape = [i // s for i, s in zip(input_size, patch_stride)]
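
Accepting the norm layer by name keeps every captured init argument plain data, so the config the 🤗 mixin writes stays serializable; a hedged one-line sketch:
```py
model = Hiera(norm_layer="LayerNorm")  # captured in model.config as the plain string "LayerNorm"
```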
6 changes: 5 additions & 1 deletion hiera/hiera_utils.py
@@ -60,6 +60,10 @@ def model_def(pretrained: bool = False, checkpoint: str = default, strict: bool

            return model

        # Keep some metadata so we can do things that require looping through all available models
        model_def.checkpoints = checkpoints
        model_def.default = default

        return model_def

    return inner
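
As a hedged sketch of what that metadata enables (assuming the decorated builders are exposed at the top level of the `hiera` package), one could enumerate every registered model and its checkpoints like this:
```py
import hiera

for name in dir(hiera):
    fn = getattr(hiera, name)
    if callable(fn) and hasattr(fn, "checkpoints"):
        print(f"{name}: default={fn.default}, checkpoints={list(fn.checkpoints)}")
```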
@@ -284,4 +288,4 @@ def forward(
        # If not masked, we can return [B, H, W, C]
        x = undo_windowing(x, size, cur_mu_shape)

        return x
        return x
6 changes: 3 additions & 3 deletions setup.py
@@ -9,13 +9,13 @@

setup(
name="hiera-transformer",
version="0.1.2",
version="0.1.3",
author="Chaitanya Ryali, Daniel Bolya",
url="https://github.com/facebookresearch/hiera",
description="A fast, powerful, and simple hierarchical vision transformer",
install_requires=["torch>=1.8.1", "timm>=0.4.12", "tqdm"],
install_requires=["torch>=1.8.1", "timm>=0.4.12", "tqdm", "packaging"],
packages=find_packages(exclude=("examples", "build")),
license = 'CC BY-NC 4.0',
long_description=open("README.md", "r", encoding="utf-8").read(),
long_description_content_type="text/markdown",
)
)
