
PyTorch code for training neural networks without global back-propagation


scotwilli/local-loss

 
 


Training neural networks with local error signals

This repo contains PyTorch code for training neural networks without global backprop. The experiments were performed by Arild Nøkland and Lars Hiller Eidnes.

A more detailed description of the experiments is available on arXiv here: https://arxiv.org/abs/1901.06656

Supervised training of neural networks for classification is typically performed with a global loss function. The loss function provides a gradient for the output layer, and this gradient is back-propagated to hidden layers to dictate an update direction for the weights. An alternative approach is to train the network with layer-wise loss functions. In this paper we demonstrate, for the first time, that layer-wise training can approach the state-of-the-art on a variety of image datasets. We use single-layer sub-networks and two different supervised loss functions to generate local error signals for the hidden layers, and we show that the combination of these losses helps optimization in the context of local learning. Using local errors could be a step towards more biologically plausible deep learning because the global error does not have to be transported back to hidden layers.
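The layer-wise idea can be illustrated with a deliberately tiny example. The following is a minimal pure-Python sketch, not the repository's PyTorch code: it assumes a stack of scalar linear layers, each paired with its own hypothetical linear read-out trained with a squared-error loss on the target (standing in for the cross-entropy and similarity losses used in the experiments). Each layer is updated only from its own loss, and its input is treated as a constant, so no error signal travels back through the stack. The function names are illustrative.

```python
def train_layerwise(xs, ys, num_layers=3, lr=0.05, steps=5000):
    """Train a stack of scalar linear layers, each with its own local loss.

    Layer k computes h_k = w[k] * h_{k-1} and has a local read-out
    y_hat_k = a[k] * h_k trained to minimise 0.5 * mean((y_hat_k - y) ** 2).
    The input h_{k-1} is treated as a constant, so no error signal is sent
    back to earlier layers (no global back-propagation).
    """
    w = [1.0] * num_layers  # hidden-layer weights
    a = [0.5] * num_layers  # local read-out weights, one per layer
    n = len(xs)
    for _ in range(steps):
        h_in = list(xs)  # the first layer sees the raw inputs
        for k in range(num_layers):
            h_out = [w[k] * h for h in h_in]  # forward pass of layer k
            grad_w = grad_a = 0.0
            for h_prev, h, y in zip(h_in, h_out, ys):
                err = a[k] * h - y  # per-sample derivative of 0.5 * (y_hat_k - y) ** 2
                grad_a += err * h / n
                grad_w += err * a[k] * h_prev / n
            # Each layer is updated from its own local loss only.
            a[k] -= lr * grad_a
            w[k] -= lr * grad_w
            # The next layer receives h_out as plain numbers ("detached").
            h_in = h_out
    return w, a


def local_losses(w, a, xs, ys):
    """Return the local loss of every layer for the given data."""
    losses, h_in = [], list(xs)
    for wk, ak in zip(w, a):
        h_in = [wk * h for h in h_in]
        losses.append(sum(0.5 * (ak * h - y) ** 2 for h, y in zip(h_in, ys)) / len(ys))
    return losses


xs = [-1.0, -0.5, 0.5, 1.0]
ys = [3.0 * x for x in xs]  # toy target: y = 3x
w, a = train_layerwise(xs, ys)
print(local_losses(w, a, xs, ys))  # every local loss should end up close to zero
```

In PyTorch, the "treat the input as a constant" step corresponds to calling `.detach()` on a layer's output before passing it to the next layer.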

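In the tables below, 'pred' indicates a layer-wise cross-entropy loss, 'sim' indicates a layer-wise similarity matching loss, and 'predsim' indicates a combination of these losses. For the local losses, the computational graph is detached after each hidden layer, so no gradient flows from a hidden layer's loss into the layers below it. All reported numbers are test errors in percent (lower is better).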
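For a single mini-batch, the three losses can be sketched in plain Python as follows. This is an illustration under assumptions, not the repository's code: 'pred' is softmax cross-entropy on a local classifier's logits; 'sim' is taken to be the mean squared difference between the cosine-similarity matrix of the mean-centred layer features and that of the one-hot targets; and 'predsim' is the weighted sum (1 - beta) * pred + beta * sim, where beta = 0.99 is an assumed default. In the experiments, these losses are computed on the outputs of the small single-layer sub-networks attached to each hidden layer.

```python
import math


def cross_entropy(logits, label):
    """'pred': softmax cross-entropy for one sample's logits and integer label."""
    m = max(logits)  # subtract the maximum for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[label]


def similarity_matrix(rows, eps=1e-8):
    """Cosine similarities between mean-centred rows (one row per sample)."""
    normed = []
    for row in rows:
        mu = sum(row) / len(row)
        centred = [v - mu for v in row]
        norm = math.sqrt(sum(v * v for v in centred)) + eps  # eps avoids division by zero
        normed.append([v / norm for v in centred])
    return [[sum(p * q for p, q in zip(u, v)) for v in normed] for u in normed]


def sim_loss(features, labels, num_classes):
    """'sim': mean squared difference between feature and target similarity matrices."""
    one_hot = [[1.0 if k == y else 0.0 for k in range(num_classes)] for y in labels]
    s_h, s_y = similarity_matrix(features), similarity_matrix(one_hot)
    n = len(labels)
    return sum((s_h[i][j] - s_y[i][j]) ** 2 for i in range(n) for j in range(n)) / (n * n)


def predsim_loss(logits, features, labels, num_classes, beta=0.99):
    """'predsim': (1 - beta) * pred + beta * sim, with beta an assumed weight."""
    pred = sum(cross_entropy(z, y) for z, y in zip(logits, labels)) / len(labels)
    return (1 - beta) * pred + beta * sim_loss(features, labels, num_classes)


labels = [0, 1, 0]
logits = [[2.0, -1.0], [0.5, 1.5], [1.0, 0.0]]
features = [[0.9, 0.1, 0.3], [0.1, 0.8, 0.2], [1.0, 0.2, 0.1]]
print(predsim_loss(logits, features, labels, num_classes=2))
```

The 'sim' term rewards a layer whose features make same-class samples similar and different-class samples dissimilar, without prescribing which feature encodes which class.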

Experiments

Results on MNIST with 2 pixel jittering:

Network | #Params | Global loss | Local loss 'pred' | Local loss 'sim' | Local loss 'predsim'
--- | --- | --- | --- | --- | ---
mlp | 2.9M | 0.75 | 0.68 | 0.80 | 0.62
vgg8b | 7.3M | 0.26 | 0.40 | 0.65 | 0.31
vgg8b + cutout | 7.3M | - | - | - | 0.26

Results on Fashion-MNIST with 2 pixel jittering and horizontal flipping:

Network | #Params | Global loss | Local loss 'pred' | Local loss 'sim' | Local loss 'predsim'
--- | --- | --- | --- | --- | ---
mlp | 2.9M | 8.37 | 8.60 | 9.70 | 8.54
vgg8b | 7.3M | 4.53 | 5.66 | 5.12 | 4.65
vgg8b (2x) | 28.2M | 4.55 | 5.11 | 4.92 | 4.33
vgg8b (2x) + cutout | 28.2M | - | - | - | 4.14
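"2 pixel jittering" can be read as randomly shifting each training image by up to 2 pixels along each axis. The sketch below is a minimal pure-Python version on a 2-D list image, assuming zero fill at the borders (equivalent to zero-padding by 2 and taking a random crop of the original size); the repository may implement this through torchvision transforms instead. The optional horizontal flip corresponds to the extra augmentation used for Fashion-MNIST. The function names are hypothetical.

```python
import random


def jitter(image, max_shift=2, rng=random):
    """Shift a 2-D image by up to max_shift pixels per axis, filling with zeros.

    Equivalent to zero-padding by max_shift and taking a random crop of the
    original size.
    """
    h, w = len(image), len(image[0])
    dy = rng.randint(-max_shift, max_shift)  # inclusive on both ends
    dx = rng.randint(-max_shift, max_shift)
    out = [[0.0] * w for _ in range(h)]
    for y in range(h):
        for x in range(w):
            sy, sx = y + dy, x + dx  # source pixel for output position (y, x)
            if 0 <= sy < h and 0 <= sx < w:
                out[y][x] = image[sy][sx]
    return out


def random_hflip(image, p=0.5, rng=random):
    """Mirror the image left-to-right with probability p (returns a copy)."""
    if rng.random() < p:
        return [row[::-1] for row in image]
    return [row[:] for row in image]


img = [[float(x) for x in range(28)] for _ in range(28)]  # toy 28x28 image
augmented = random_hflip(jitter(img, rng=random.Random(0)), rng=random.Random(1))
```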

Results on Kuzushiji-MNIST with no data augmentation:

Network | #Params | Global loss | Local loss 'pred' | Local loss 'sim' | Local loss 'predsim'
--- | --- | --- | --- | --- | ---
mlp | 2.9M | 5.99 | 7.26 | 9.80 | 7.33
vgg8b | 7.3M | 1.53 | 2.22 | 2.19 | 1.36
vgg8b + cutout | 7.3M | - | - | - | 0.99

Results on CIFAR-10 with data augmentation:

Network | #Params | Global loss | Local loss 'pred' | Local loss 'sim' | Local loss 'predsim'
--- | --- | --- | --- | --- | ---
mlp | 27.3M | 33.56 | 32.33 | 33.48 | 30.93
vgg8b | 8.9M | 5.99 | 8.40 | 7.16 | 5.58
vgg11b | 11.6M | 5.56 | 8.39 | 6.70 | 5.30
vgg11b (2x) | 42.0M | 4.91 | 7.30 | 6.66 | 4.42
vgg11b (3x) | 91.3M | 5.02 | 7.37 | 9.34 | 3.97
vgg11b (3x) + cutout | 91.3M | - | - | - | 3.60

Results on CIFAR-100 with data augmentation:

Network | #Params | Global loss | Local loss 'pred' | Local loss 'sim' | Local loss 'predsim'
--- | --- | --- | --- | --- | ---
mlp | 27.3M | 62.57 | 58.87 | 62.46 | 56.88
vgg8b | 9.0M | 26.24 | 29.32 | 32.64 | 24.07
vgg11b | 11.7M | 25.18 | 29.58 | 30.82 | 24.05
vgg11b (2x) | 42.1M | 23.44 | 26.91 | 28.03 | 21.20
vgg11b (3x) | 91.4M | 23.69 | 25.90 | 28.01 | 20.13

Results on SVHN with extra training data, but no augmentation:

Network | #Params | Global loss | Local loss 'pred' | Local loss 'sim' | Local loss 'predsim'
--- | --- | --- | --- | --- | ---
vgg8b | 8.9M | 2.29 | 2.12 | 1.89 | 1.74
vgg8b + cutout | 8.9M | - | - | - | 1.65

Results on STL-10 with no data augmentation:

Network | #Params | Global loss | Local loss 'pred' | Local loss 'sim' | Local loss 'predsim'
--- | --- | --- | --- | --- | ---
vgg8b | 11.5M | 33.08 | 26.83 | 23.15 | 20.51
vgg8b + cutout | 11.5M | - | - | - | 19.25

Training recipes

To replicate training of MLP on MNIST with local loss 'predsim':

python train.py --model mlp --dataset MNIST --dropout 0.1 --lr 5e-4 --num-layers 3 --epochs 100 --lr-decay-milestones 50 75 89 94 --nonlin leakyrelu

To replicate training of VGG8b on MNIST with local loss 'predsim':

python train.py --model vgg8b --dataset MNIST --dropout 0.2 --lr 5e-4 --epochs 100 --lr-decay-milestones 50 75 89 94 --nonlin leakyrelu --dim-in-decoder 1024
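The --lr-decay-milestones argument lists the epochs at which the learning rate is decayed. The sketch below assumes multi-step semantics like those of PyTorch's torch.optim.lr_scheduler.MultiStepLR: the rate is multiplied by a fixed factor each time a milestone is passed. The decay factor of 0.25 and the helper name lr_at_epoch are hypothetical; this README does not state the factor.

```python
from bisect import bisect_right


def lr_at_epoch(epoch, base_lr=5e-4, milestones=(50, 75, 89, 94), gamma=0.25):
    """Learning rate in effect at a 0-based epoch under multi-step decay.

    The rate is multiplied by gamma once for every milestone <= epoch,
    matching the closed form of torch.optim.lr_scheduler.MultiStepLR.
    gamma=0.25 is an assumed value for illustration.
    """
    return base_lr * gamma ** bisect_right(sorted(milestones), epoch)


for epoch in (0, 49, 50, 75, 89, 94, 99):
    print(epoch, lr_at_epoch(epoch))
```

In PyTorch itself one would normally pass the milestones to MultiStepLR and call scheduler.step() once per epoch rather than computing the rate by hand.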

To replicate training of MLP on CIFAR10 with local loss 'predsim':

python train.py --model mlp --dataset CIFAR10 --dropout 0.1 --lr 5e-4 --num-layers 3 --num-hidden 3000 --nonlin leakyrelu
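As a sanity check on the #Params column, the parameter count of a fully connected network can be computed directly. This sketch assumes that --num-layers 3 means three hidden layers, that the MNIST MLP uses a width of 1024 (the MNIST recipe does not set --num-hidden, so this is an assumed default), and it ignores the small local classifier heads. Under these assumptions, the totals round to the 2.9M and 27.3M listed in the tables.

```python
def mlp_param_count(dim_in, num_hidden, num_layers, num_classes):
    """Weights plus biases of a fully connected net with num_layers hidden layers."""
    dims = [dim_in] + [num_hidden] * num_layers + [num_classes]
    return sum(d_in * d_out + d_out for d_in, d_out in zip(dims, dims[1:]))


mnist = mlp_param_count(28 * 28, 1024, 3, 10)      # num_hidden=1024 is an assumption
cifar = mlp_param_count(32 * 32 * 3, 3000, 3, 10)  # from --num-hidden 3000
print(f"{mnist / 1e6:.1f}M {cifar / 1e6:.1f}M")  # → 2.9M 27.3M
```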

To replicate training of VGG8b on CIFAR10 with local loss 'predsim':

python train.py --model vgg8b --dataset CIFAR10 --dropout 0.2 --lr 5e-4 --nonlin leakyrelu --dim-in-decoder 2048

To replicate training of VGG11b (3x) on CIFAR10 with local loss 'predsim':

python train.py --model vgg11b --dataset CIFAR10 --dropout 0.3 --lr 3e-4 --feat-mult 3 --nonlin leakyrelu

For any of the above recipes, to train with the local cross-entropy loss ('pred'), add the argument

--loss-sup pred

For any of the above recipes, to train with the local similarity matching loss ('sim'), add the argument

--loss-sup sim

For any of the above recipes, to train with a global loss (standard back-propagation), add the argument

--backprop

For any of the above recipes, to train with a more biologically plausible version of the local loss, add the argument

--bio

To add cutout regularization with a cutout hole size of 14, add the arguments

--cutout --length 14
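Cutout regularization masks out a square patch of each training image. The following is a minimal pure-Python sketch on a 2-D list image, assuming the common variant in which the patch centre is drawn uniformly over the image and the square is clipped at the borders. The length parameter plays the role of --length; the function name is hypothetical, and the repository's own implementation may differ in detail.

```python
import random


def cutout(image, length=14, rng=random):
    """Zero a length x length square (clipped at the borders) at a random centre."""
    h, w = len(image), len(image[0])
    cy, cx = rng.randrange(h), rng.randrange(w)  # patch centre, may lie near an edge
    y1, y2 = max(cy - length // 2, 0), min(cy + length // 2, h)
    x1, x2 = max(cx - length // 2, 0), min(cx + length // 2, w)
    out = [row[:] for row in image]  # copy so the input image is left untouched
    for y in range(y1, y2):
        for x in range(x1, x2):
            out[y][x] = 0.0
    return out


img = [[1.0] * 28 for _ in range(28)]
masked = cutout(img, length=14, rng=random.Random(0))
```

Because the square is clipped, a patch centred near a corner masks fewer pixels (as few as 7 x 7 here) than one centred in the middle (14 x 14).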

To replicate all the above experiments, run

./run_experiments.sh
