Weight Watcher analyzes the Fat Tails in the weight matrices of Deep Neural Networks (DNNs).
This tool can predict the trends in the generalization accuracy of a series of DNNs, such as VGG11, VGG13, ..., or even the entire series of ResNet models, without needing a test set!
This relies upon recent research into the Heavy (Fat) Tailed Self Regularization in DNNs.
The tool lets one compute an average capacity, or quality, metric for a series of DNNs, trained on the same data, but with different hyperparameters, or even different but related architectures. For example, it can predict that VGG19_BN generalizes better than VGG19, and better than VGG16_BN, VGG16, etc.
There are 2 basic types of metrics we use:
- alpha (the average power law exponent)
- weighted alpha / log_alpha_norm (scale adjusted alpha metrics)
The average alpha can be used to compare one or more DNN models with different hyperparameter settings, but of the same depth. The average weighted alpha is suitable for DNNs of differing depths.
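As a rough illustration of the two metric types, the sketch below averages per-layer values. It assumes that the weighted alpha of a layer is alpha scaled by log10 of that layer's largest eigenvalue; the function names and the per-layer numbers are hypothetical, and the tool's exact normalization may differ.

```python
import math

def average_alpha(alphas):
    """Unweighted mean of the per-layer power law exponents."""
    return sum(alphas) / len(alphas)

def weighted_alpha(alphas, lambda_maxes):
    """Mean over layers of alpha * log10(lambda_max) (assumed definition)."""
    terms = [a * math.log10(lm) for a, lm in zip(alphas, lambda_maxes)]
    return sum(terms) / len(terms)

# Hypothetical per-layer values, for illustration only
alphas = [2.5, 3.0, 3.5]
lambda_maxes = [10.0, 100.0, 1000.0]

print(average_alpha(alphas))
print(weighted_alpha(alphas, lambda_maxes))
```

The scale factor is what lets the weighted metric account for layers whose eigenvalue spectra differ in size.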
Here is an example of the Weighted Alpha capacity metric for all the current pretrained VGG models.
Notice: we did not peek at the ImageNet test data to build this plot.
- Tensorflow 2.x / Keras
- PyTorch
- HuggingFace
- Dense / Linear / Fully Connected (and Conv1D)
- Conv2D
pip install weightwatcher
import weightwatcher as ww
import torchvision.models as models
model = models.vgg19_bn(pretrained=True)
watcher = ww.WeightWatcher(model=model)
details = watcher.analyze()
summary = watcher.get_summary(details)
It is easy to run, and generates a pandas dataframe with details (and plots) for each layer
and a summary dict of generalization metrics
{'log_norm': 2.11,
'alpha': 3.06,
'alpha_weighted': 2.78,
'log_alpha_norm': 3.21,
'log_spectral_norm': 0.89,
'stable_rank': 20.90,
'mp_softrank': 0.52}
More examples are included in the Demo Notebook
and will be made available shortly in a Jupyter book
The watcher object has several functions, and the analyze function has several features, described below
analyze( model=None, layers=[], min_evals=0, max_evals=None,
plot=True, randomize=True, mp_fit=True, ww2x=False):
...
describe( model=None, layers=[], min_evals=0, max_evals=None,
plot=True, randomize=True, mp_fit=True, ww2x=False):
...
get_details()
get_summary(details) or get_summary()
get_ESD()
...
distances(model_1, model_2)
Layers can be filtered by layer type, using ww.LAYER_TYPE.CONV2D | ww.LAYER_TYPE.DENSE, as
details = watcher.analyze(layers=[ww.LAYER_TYPE.CONV2D])
or by layer id, as
details = watcher.analyze(layers=[20])
Sets the minimum and maximum size of the weight matrices analyzed. Setting max_evals is useful for quick debugging.
details = watcher.analyze(min_evals=50, max_evals=500)
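The effect of these bounds can be sketched as a filter on layer shapes. This assumes the size of a layer is counted as its number of eigenvalues, min(N, M) for an N x M weight matrix, and that both bounds are inclusive. The helper name and the shapes are hypothetical.

```python
def select_layers(shapes, min_evals=0, max_evals=None):
    """Return the indices of layers whose eigenvalue count lies within the bounds."""
    selected = []
    for idx, (n_rows, n_cols) in enumerate(shapes):
        num_evals = min(n_rows, n_cols)  # assumed size measure for an N x M matrix
        if num_evals < min_evals:
            continue
        if max_evals is not None and num_evals > max_evals:
            continue
        selected.append(idx)
    return selected

# Hypothetical layer shapes, for illustration only
shapes = [(64, 27), (512, 256), (4096, 25088), (1000, 4096)]
kept = select_layers(shapes, min_evals=50, max_evals=500)
print(kept)
```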
Create ESD plots for each layer weight matrix to observe how well the power law fits work.
details = watcher.analyze(plot=True)
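To give a feel for what a power law fit estimates, here is a minimal sketch of the standard maximum-likelihood estimator for the exponent, with the lower cutoff xmin taken as given. It runs on synthetic data drawn from an exact power law. The tool's own fitting procedure, which also has to choose xmin, is not reproduced here, and the function name is hypothetical.

```python
import math
import random

def fit_alpha(eigenvalues, xmin):
    """Maximum-likelihood estimate of alpha for p(x) ~ x^(-alpha), x >= xmin."""
    tail = [x for x in eigenvalues if x >= xmin]
    log_sum = sum(math.log(x / xmin) for x in tail)
    return 1.0 + len(tail) / log_sum

# Synthetic "eigenvalues" sampled from a power law with alpha = 3 (inverse transform)
rng = random.Random(0)
true_alpha, xmin = 3.0, 1.0
samples = [xmin * (1.0 - rng.random()) ** (-1.0 / (true_alpha - 1.0))
           for _ in range(20000)]

alpha_hat = fit_alpha(samples, xmin)
print(round(alpha_hat, 2))
```

On real ESDs the tail is only approximately a power law, which is why inspecting the plots is worthwhile.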
The randomize option compares the ESD of the layer weight matrix (W) to the ESD of the randomized W matrix. This is a good way to visualize the correlations in the true ESD.
details = watcher.analyze(randomize=True, plot=True)
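The idea behind the comparison can be sketched without any plotting. The example below assumes that randomizing means shuffling the entries of W, and it compares only the largest eigenvalue of W^T W (unnormalized) rather than the full ESD. The toy matrix and helper names are hypothetical.

```python
import random

def top_eigenvalue(W, rng, iters=100):
    """Estimate the largest eigenvalue of W^T W by power iteration."""
    n_rows, n_cols = len(W), len(W[0])
    v = [rng.gauss(0.0, 1.0) for _ in range(n_cols)]  # random start vector
    lam = 0.0
    for _ in range(iters):
        norm = sum(x * x for x in v) ** 0.5
        v = [x / norm for x in v]
        Wv = [sum(W[i][j] * v[j] for j in range(n_cols)) for i in range(n_rows)]
        v = [sum(W[i][j] * Wv[i] for i in range(n_rows)) for j in range(n_cols)]
        lam = sum(x * x for x in v) ** 0.5  # |W^T W v| for a unit vector v
    return lam

def shuffle_elements(W, rng):
    """Randomly permute all entries of W, keeping its shape."""
    n_cols = len(W[0])
    flat = [x for row in W for x in row]
    rng.shuffle(flat)
    return [flat[k:k + n_cols] for k in range(0, len(flat), n_cols)]

rng = random.Random(0)
# Strongly correlated toy matrix: a rank-one sign pattern plus small noise
W = [[(-1.0) ** (i + j) + 0.1 * rng.gauss(0.0, 1.0) for j in range(20)]
     for i in range(30)]

lam_true = top_eigenvalue(W, rng)
lam_rand = top_eigenvalue(shuffle_elements(W, rng), rng)
print(lam_true > lam_rand)
```

Shuffling keeps the entries but destroys their arrangement, so any eigenvalue that stands out only in the original matrix reflects correlations rather than the element distribution.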
Attempts to fit the ESD to a Marchenko-Pastur (MP) distribution
details = watcher.analyze(mp_fit=True, plot=True)
and reports
num_spikes, mp_sigma, and mp_softrank
Also works for the randomized ESD and reports
rand_num_spikes, rand_mp_sigma, and rand_mp_softrank
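For intuition about these quantities, the sketch below uses the textbook MP bulk edge for an N x M matrix with N >= M and a known sigma. It counts eigenvalues above the edge as spikes, and takes the soft rank as the ratio of the bulk edge to the largest eigenvalue. These definitions are assumptions for illustration; the tool fits sigma itself, and the helper names and eigenvalues are hypothetical.

```python
import math

def mp_bulk_edge(n_rows, n_cols, sigma=1.0):
    """Upper edge of the MP bulk for X = W^T W / N, with Q = N / M >= 1."""
    q = n_rows / n_cols
    return sigma ** 2 * (1.0 + 1.0 / math.sqrt(q)) ** 2

def mp_report(eigenvalues, n_rows, n_cols, sigma=1.0):
    """Count spikes above the bulk edge and compute the soft rank (assumed definitions)."""
    edge = mp_bulk_edge(n_rows, n_cols, sigma)
    num_spikes = sum(1 for ev in eigenvalues if ev > edge)
    softrank = edge / max(eigenvalues)
    return {"num_spikes": num_spikes, "mp_softrank": softrank}

# Hypothetical eigenvalues for a 400 x 100 layer, for illustration only
evals = [0.3, 0.8, 1.5, 2.1, 4.5, 9.0]
report = mp_report(evals, n_rows=400, n_cols=100)
print(report)
```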
watcher.analyze()
esd = watcher.get_ESD()
Describe a model and report the details dataframe, without analyzing it
details = watcher.describe(model=model)
Get the average metrics, as a summary (dict), from the given (or current) details dataframe
details = watcher.analyze(model=model)
summary = watcher.get_summary(details)
or just
watcher.analyze()
summary = watcher.get_summary()
The new distances method reports the distances between 2 models, such as the norm between the initial weight matrices and the final, trained weight matrices.
details = watcher.distances(initial_model, trained_model)
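As a minimal sketch of such a norm for a single layer, the example below computes the Frobenius norm of the difference between two weight matrices. Which norm the tool actually reports is not stated here, so treat this as one plausible choice; the function name and matrices are hypothetical.

```python
import math

def frobenius_distance(W_a, W_b):
    """Frobenius norm of (W_a - W_b) for two matrices of the same shape."""
    total = 0.0
    for row_a, row_b in zip(W_a, W_b):
        for x, y in zip(row_a, row_b):
            total += (x - y) ** 2
    return math.sqrt(total)

# Toy initial and trained weights for one layer
initial = [[0.0, 0.0], [0.0, 0.0]]
trained = [[3.0, 0.0], [0.0, 4.0]]
print(frobenius_distance(initial, trained))  # 5.0
```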
The new 0.4 version of weightwatcher treats each layer as a single, unified set of eigenvalues. In contrast, the 0.2x versions split the Conv2D layers into n slices, 1 for each receptive field. The ww2x option provides results which are backward-compatible with the 0.2x version of weightwatcher, with details provided for each slice of each layer.
details = watcher.analyze(ww2x=True)
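The slicing can be pictured with a toy kernel. The sketch assumes a kernel laid out as [kh][kw][c_in][c_out], so that each of the kh * kw positions yields one c_in x c_out matrix. The layout and the helper name are assumptions, and real frameworks order these axes differently.

```python
def conv2d_slices(weights):
    """Split a kernel laid out as [kh][kw][c_in][c_out] into kh * kw matrices."""
    return [weights[i][j]
            for i in range(len(weights))
            for j in range(len(weights[0]))]

# Toy 3 x 3 kernel with 2 input and 4 output channels (all zeros, shape only)
kh, kw, c_in, c_out = 3, 3, 2, 4
weights = [[[[0.0] * c_out for _ in range(c_in)] for _ in range(kw)]
           for _ in range(kh)]

slices = conv2d_slices(weights)
print(len(slices), len(slices[0]), len(slices[0][0]))
```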
Calculation Consulting homepage
This tool is based on state-of-the-art research done in collaboration with UC Berkeley:
- Traditional and Heavy Tailed Self Regularization in Neural Network Models
- Notebook for above 2 papers (https://github.com/CalculatedContent/ImplicitSelfRegularization)
-
- Notebook for paper (https://github.com/CalculatedContent/PredictingTestAccuracies)
and has been presented at Stanford, UC Berkeley, etc:
and major AI conferences like ICML, KDD, etc.
and has been the subject of many popular podcasts
- Data Science at Home Podcast
KDD 2019 Workshop: Statistical Mechanics Methods for Discovering Knowledge from Production-Scale Neural Networks
Talk on latest results, Stanford ICME 2020
Publishing to the PyPI repository:
# 1. Check in the latest code with the correct revision number (__version__ in __init__.py)
vi weightwatcher/__init__.py # Increase release number, remove -dev from revision number
git commit
# 2. Check out latest version from the repo in a fresh directory
cd ~/temp/
git clone https://github.com/CalculatedContent/WeightWatcher
cd WeightWatcher/
# 3. Use the latest version of the tools
python -m pip install --upgrade setuptools wheel twine
# 4. Create the package
python setup.py sdist bdist_wheel
# 5. Test the package
twine check dist/*
# 6. Upload the package to PyPI
twine upload dist/*
# 7. Tag/Release in github by creating a new release (https://github.com/CalculatedContent/WeightWatcher/releases/new)
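Before uploading, it may help to check that the version string no longer carries a -dev suffix. The helper below is a hypothetical sketch, not part of the package, and assumes plain X.Y or X.Y.Z release numbers.

```python
import re

def is_release_version(version):
    """True for plain X.Y or X.Y.Z versions with no suffix such as -dev."""
    return re.fullmatch(r"\d+(\.\d+){1,2}", version) is not None

print(is_release_version("0.4.0"))      # True
print(is_release_version("0.4.0-dev"))  # False
```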
We have a Slack channel for the tool if you need help. For an invite, please send an email to [email protected]