From 42d33a7919a6787d0add7ca511997ddddc30050a Mon Sep 17 00:00:00 2001 From: Daniel Bolya Date: Fri, 1 Mar 2024 23:15:10 -0500 Subject: [PATCH] [v0.1.3] Added Hugging Face Hub support. (#26) * [v0.1.3] Added Hugging Face Hub support. * [v0.1.3] Added warning if huggingface-hub isn't installed but a huggingface-hub function was used. * [v0.1.3] Cleaned up PR for release + removed branch wording from readme. * [v0.1.3] Add checkpoints to the huggingface repos. * [v0.1.3] Added a requirement for hfhub >= 0.21.0 and accompanying error msg. --- CHANGELOG.md | 6 +++- README.md | 20 +++++++++++-- hiera/hfhub.py | 67 ++++++++++++++++++++++++++++++++++++++++++++ hiera/hiera.py | 13 ++++++--- hiera/hiera_utils.py | 6 +++- setup.py | 6 ++-- 6 files changed, 107 insertions(+), 11 deletions(-) create mode 100644 hiera/hfhub.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 803c256..0936ab1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,7 @@ +### **[2024.02.07]** v0.1.3 + - Added support to save and load models to the huggingface hub, if huggingface_hub is installed. + - Most Hiera models have been uploaded to HuggingFace. + ### **[2023.07.20]** v0.1.2 - Released the full model zoo. - Added MAE functionality to the video models. @@ -8,4 +12,4 @@ - Released all in1k finetuned models. ### **[2023.06.01]** v0.1.0 - - Initial Release. \ No newline at end of file + - Initial Release. diff --git a/README.md b/README.md index 51b40b0..53a1114 100644 --- a/README.md +++ b/README.md @@ -63,6 +63,8 @@ python setup.py build develop ## Model Zoo +### Torch Hub + Here we provide model checkpoints for Hiera. Each model listed is accessible on [torch hub](https://pytorch.org/docs/stable/hub.html) even without the `hiera-transformer` package installed, e.g. the following initializes a base model pretrained and finetuned on ImageNet-1k: ```py model = torch.hub.load("facebookresearch/hiera", model="hiera_base_224", pretrained=True, checkpoint="mae_in1k_ft_in1k") @@ -74,8 +76,21 @@ model = torch.hub.load("facebookresearch/hiera", model="mae_hiera_base_224", pre ``` **Note:** Our MAE models were trained with a _normalized pixel loss_. That means that the patches were normalized before the network had to predict them. If you want to visualize the predictions, you'll have to unnormalize them using the visible patches (which might work but wouldn't be perfect) or unnormalize them using the ground truth. For model more names and corresponding checkpoint names see below. +### Hugging Face Hub + +This repo also has [🤗 hub](https://huggingface.co/docs/hub/index) support. With the `hiera-transformer` and `huggingface-hub` packages installed, you can simply run, e.g., +```py +from hiera import Hiera +model = Hiera.from_pretrained("facebook/hiera_base_224.mae_in1k_ft_in1k") # mae pt then in1k ft'd model +model = Hiera.from_pretrained("facebook/hiera_base_224.mae_in1k") # just mae pt, no ft +``` +to load a model. Use `.` from model zoo below. + +If you want to save a model, use `model.config` as the config, e.g., +```py +model.save_pretrained("hiera-base-224", config=model.config) +``` -**Note:** the speeds listed here were benchmarked _without_ PyTorch's optimized [scaled dot product attention](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html). If using PyTorch 2.0 or above, your inference speed will probably be faster than what's listed here. ### Image Models | Model | Model Name | Pretrained Models
(IN-1K MAE) | Finetuned Models
(IN-1K Supervised) | IN-1K
Top-1 (%) | A100 fp16
Speed (im/s) | |----------|-----------------------|----------------------------------|----------------------------------------|:------------------:|:-------------------------:| @@ -97,6 +112,7 @@ Each model inputs a 224x224 image. Each model inputs 16 224x224 frames with a temporal stride of 4. +**Note:** the speeds listed here were benchmarked _without_ PyTorch's optimized [scaled dot product attention](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html). If using PyTorch 2.0 or above, your inference speed will probably be faster than what's listed here. ## Usage @@ -116,7 +132,7 @@ See [examples](https://github.com/facebookresearch/hiera/tree/main/examples) for See [examples/inference](https://github.com/facebookresearch/hiera/blob/main/examples/inference.ipynb) for an example of how to prepare the data for inference. -Instantiate a model with either [torch hub](#model-zoo) or by [installing hiera](#installing-from-source) and running: +Instantiate a model with either [torch hub](#model-zoo) or [🤗 hub](#model-zoo) or by [installing hiera](#installing-from-source) and running: ```py import hiera model = hiera.hiera_base_224(pretrained=True, checkpoint="mae_in1k_ft_in1k") diff --git a/hiera/hfhub.py b/hiera/hfhub.py new file mode 100644 index 0000000..880245d --- /dev/null +++ b/hiera/hfhub.py @@ -0,0 +1,67 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# -------------------------------------------------------- +# References: +# https://github.com/facebookresearch/hiera/pull/25 +# -------------------------------------------------------- + +import importlib.util +import importlib.metadata +from packaging import version + +import inspect + +def is_huggingface_hub_available(): + available: bool = importlib.util.find_spec("huggingface_hub") is not None + + if not available: + return False + else: + hfversion = importlib.metadata.version("huggingface_hub") + return version.parse(hfversion) >= version.parse("0.21.0") + + +if is_huggingface_hub_available(): + from huggingface_hub import PyTorchModelHubMixin +else: + # Empty class in case modelmixins dont exist + class PyTorchModelHubMixin: + error_str: str = 'This feature requires "huggingface-hub >= 0.21.0" to be installed.' + + @classmethod + def from_pretrained(cls, *args, **kwdargs): + raise RuntimeError(cls.error_str) + + @classmethod + def save_pretrained(cls, *args, **kwdargs): + raise RuntimeError(cls.error_str) + + @classmethod + def push_to_hub(cls, *args, **kwdargs): + raise RuntimeError(cls.error_str) + + + +# Saves the input args to the function as self.config, also allows +# loading a config instead of kwdargs. +def has_config(func): + signature = inspect.signature(func) + + def wrapper(self, *args, **kwdargs): + if "config" in kwdargs: + config = kwdargs["config"] + del kwdargs["config"] + kwdargs.update(**config) + + self.config = { + k: v.default if (i-1) >= len(args) else args[i-1] + for i, (k, v) in enumerate(signature.parameters.items()) + if v.default is not inspect.Parameter.empty + } + self.config.update(**kwdargs) + + func(self, **kwdargs) + return wrapper diff --git a/hiera/hiera.py b/hiera/hiera.py index 35e8c93..ba78700 100644 --- a/hiera/hiera.py +++ b/hiera/hiera.py @@ -20,7 +20,7 @@ import math from functools import partial -from typing import List, Tuple, Callable, Optional +from typing import List, Tuple, Callable, Optional, Union import torch import torch.nn as nn @@ -29,7 +29,7 @@ from timm.models.layers import DropPath, Mlp from .hiera_utils import pretrained_model, conv_nd, do_pool, do_masked_conv, Unroll, Reroll - +from .hfhub import has_config, PyTorchModelHubMixin class MaskUnitAttention(nn.Module): @@ -204,7 +204,8 @@ def forward( return x -class Hiera(nn.Module): +class Hiera(nn.Module, PyTorchModelHubMixin): + @has_config def __init__( self, input_size: Tuple[int, ...] = (224, 224), @@ -225,13 +226,17 @@ def __init__( patch_padding: Tuple[int, ...] = (3, 3), mlp_ratio: float = 4.0, drop_path_rate: float = 0.0, - norm_layer: nn.Module = partial(nn.LayerNorm, eps=1e-6), + norm_layer: Union[str, nn.Module] = "LayerNorm", head_dropout: float = 0.0, head_init_scale: float = 0.001, sep_pos_embed: bool = False, ): super().__init__() + # Do it this way to ensure that the init args are all PoD (for config usage) + if isinstance(norm_layer, str): + norm_layer = partial(getattr(nn, norm_layer), eps=1e-6) + depth = sum(stages) self.patch_stride = patch_stride self.tokens_spatial_shape = [i // s for i, s in zip(input_size, patch_stride)] diff --git a/hiera/hiera_utils.py b/hiera/hiera_utils.py index 62d087b..ced15c0 100644 --- a/hiera/hiera_utils.py +++ b/hiera/hiera_utils.py @@ -60,6 +60,10 @@ def model_def(pretrained: bool = False, checkpoint: str = default, strict: bool return model + # Keep some metadata so we can do things that require looping through all available models + model_def.checkpoints = checkpoints + model_def.default = default + return model_def return inner @@ -284,4 +288,4 @@ def forward( # If not masked, we can return [B, H, W, C] x = undo_windowing(x, size, cur_mu_shape) - return x \ No newline at end of file + return x diff --git a/setup.py b/setup.py index f51318b..11b1699 100644 --- a/setup.py +++ b/setup.py @@ -9,13 +9,13 @@ setup( name="hiera-transformer", - version="0.1.2", + version="0.1.3", author="Chaitanya Ryali, Daniel Bolya", url="https://github.com/facebookresearch/hiera", description="A fast, powerful, and simple hierarchical vision transformer", - install_requires=["torch>=1.8.1", "timm>=0.4.12", "tqdm"], + install_requires=["torch>=1.8.1", "timm>=0.4.12", "tqdm", "packaging"], packages=find_packages(exclude=("examples", "build")), license = 'CC BY-NC 4.0', long_description=open("README.md", "r", encoding="utf-8").read(), long_description_content_type="text/markdown", -) \ No newline at end of file +)