Commit f12087b

just automatically sync batchnorm if more than one machine detected, allow for overriding with sync_batchnorm keyword argument

lucidrains committed Oct 9, 2023
1 parent 6717204 commit f12087b
Showing 3 changed files with 41 additions and 21 deletions.
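For context, a minimal usage sketch (not part of this commit) of the new keyword, assuming the BYOL constructor arguments shown in the repository's README (image_size, hidden_layer): leaving sync_batchnorm as None lets the library pick SyncBatchNorm automatically when torch.distributed reports more than one process, while passing True or False overrides the detection.

import torch
from torchvision import models
from byol_pytorch import BYOL

resnet = models.resnet50()

learner = BYOL(
    resnet,
    image_size = 256,
    hidden_layer = 'avgpool',
    sync_batchnorm = None   # default: auto-detect; pass True/False to force the choice
)

images = torch.randn(4, 3, 256, 256)
loss = learner(images)   # BYOL loss on two augmented views of the batch
loss.backward()
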
25 changes: 15 additions & 10 deletions .github/workflows/python-publish.yml
@@ -1,11 +1,16 @@
-# This workflows will upload a Python Package using Twine when a release is created
+# This workflow will upload a Python Package using Twine when a release is created
 # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries

+# This workflow uses actions that are not certified by GitHub.
+# They are provided by a third-party and are governed by
+# separate terms of service, privacy policy, and support
+# documentation.
+
 name: Upload Python Package

 on:
   release:
-    types: [created]
+    types: [published]

 jobs:
   deploy:
@@ -21,11 +26,11 @@ jobs:
     - name: Install dependencies
       run: |
         python -m pip install --upgrade pip
-        pip install setuptools wheel twine
-    - name: Build and publish
-      env:
-        TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
-        TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
-      run: |
-        python setup.py sdist bdist_wheel
-        twine upload dist/*
+        pip install build
+    - name: Build package
+      run: python -m build
+    - name: Publish package
+      uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
+      with:
+        user: __token__
+        password: ${{ secrets.PYPI_API_TOKEN }}
34 changes: 24 additions & 10 deletions byol_pytorch/byol_pytorch.py
@@ -5,6 +5,7 @@
 import torch
 from torch import nn
 import torch.nn.functional as F
+import torch.distributed as dist

 from torchvision import transforms as T

@@ -37,6 +38,10 @@ def set_requires_grad(model, val):
     for p in model.parameters():
         p.requires_grad = val

+def MaybeSyncBatchnorm(is_distributed = None):
+    is_distributed = default(is_distributed, dist.is_initialized() and dist.get_world_size() > 1)
+    return nn.SyncBatchNorm if is_distributed else nn.BatchNorm1d
+
 # loss fn

 def loss_fn(x, y):
@@ -75,32 +80,32 @@ def update_moving_average(ema_updater, ma_model, current_model):

 # MLP class for projector and predictor

-def MLP(dim, projection_size, hidden_size=4096):
+def MLP(dim, projection_size, hidden_size=4096, sync_batchnorm=None):
     return nn.Sequential(
         nn.Linear(dim, hidden_size),
-        nn.BatchNorm1d(hidden_size),
+        MaybeSyncBatchnorm(sync_batchnorm)(hidden_size),
         nn.ReLU(inplace=True),
         nn.Linear(hidden_size, projection_size)
     )

-def SimSiamMLP(dim, projection_size, hidden_size=4096):
+def SimSiamMLP(dim, projection_size, hidden_size=4096, sync_batchnorm=None):
     return nn.Sequential(
         nn.Linear(dim, hidden_size, bias=False),
-        nn.BatchNorm1d(hidden_size),
+        MaybeSyncBatchnorm(sync_batchnorm)(hidden_size),
         nn.ReLU(inplace=True),
         nn.Linear(hidden_size, hidden_size, bias=False),
-        nn.BatchNorm1d(hidden_size),
+        MaybeSyncBatchnorm(sync_batchnorm)(hidden_size),
         nn.ReLU(inplace=True),
         nn.Linear(hidden_size, projection_size, bias=False),
-        nn.BatchNorm1d(projection_size, affine=False)
+        MaybeSyncBatchnorm(sync_batchnorm)(projection_size, affine=False)
     )

 # a wrapper class for the base neural network
 # will manage the interception of the hidden layer output
 # and pipe it into the projecter and predictor nets

 class NetWrapper(nn.Module):
-    def __init__(self, net, projection_size, projection_hidden_size, layer = -2, use_simsiam_mlp = False):
+    def __init__(self, net, projection_size, projection_hidden_size, layer = -2, use_simsiam_mlp = False, sync_batchnorm = None):
         super().__init__()
         self.net = net
         self.layer = layer
@@ -110,6 +115,7 @@ def __init__(self, net, projection_size, projection_hidden_size, layer = -2, use
         self.projection_hidden_size = projection_hidden_size

         self.use_simsiam_mlp = use_simsiam_mlp
+        self.sync_batchnorm = sync_batchnorm

         self.hidden = {}
         self.hook_registered = False
@@ -137,7 +143,7 @@ def _register_hook(self):
     def _get_projector(self, hidden):
         _, dim = hidden.shape
         create_mlp_fn = MLP if not self.use_simsiam_mlp else SimSiamMLP
-        projector = create_mlp_fn(dim, self.projection_size, self.projection_hidden_size)
+        projector = create_mlp_fn(dim, self.projection_size, self.projection_hidden_size, sync_batchnorm = self.sync_batchnorm)
         return projector.to(hidden)

     def get_representation(self, x):
@@ -178,7 +184,8 @@ def __init__(
         augment_fn = None,
         augment_fn2 = None,
         moving_average_decay = 0.99,
-        use_momentum = True
+        use_momentum = True,
+        sync_batchnorm = None
     ):
         super().__init__()
         self.net = net
@@ -205,7 +212,14 @@ def __init__(
         self.augment1 = default(augment_fn, DEFAULT_AUG)
         self.augment2 = default(augment_fn2, self.augment1)

-        self.online_encoder = NetWrapper(net, projection_size, projection_hidden_size, layer=hidden_layer, use_simsiam_mlp=not use_momentum)
+        self.online_encoder = NetWrapper(
+            net,
+            projection_size,
+            projection_hidden_size,
+            layer = hidden_layer,
+            use_simsiam_mlp = not use_momentum,
+            sync_batchnorm = sync_batchnorm
+        )

         self.use_momentum = use_momentum
         self.target_encoder = None
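
A short sketch (not from the commit) of how the MaybeSyncBatchnorm helper added above resolves; the library's default helper is reconstructed here under the assumption that it simply falls back to its second argument when the first is None.

import torch.distributed as dist
from torch import nn

def default(val, def_val):
    # assumed behaviour of the library's helper: use def_val when val is None
    return def_val if val is None else val

def MaybeSyncBatchnorm(is_distributed = None):
    # mirrors the function added in this commit
    is_distributed = default(is_distributed, dist.is_initialized() and dist.get_world_size() > 1)
    return nn.SyncBatchNorm if is_distributed else nn.BatchNorm1d

bn_cls = MaybeSyncBatchnorm(is_distributed = False)   # forced single-process behaviour
assert bn_cls is nn.BatchNorm1d

sync_cls = MaybeSyncBatchnorm(is_distributed = True)  # forced sync, e.g. under DDP
assert sync_cls is nn.SyncBatchNorm

auto_cls = MaybeSyncBatchnorm()                        # None: decided by dist.is_initialized() and world size
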
3 changes: 2 additions & 1 deletion setup.py
@@ -3,12 +3,13 @@
 setup(
   name = 'byol-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '0.6.0',
+  version = '0.7.0',
   license='MIT',
   description = 'Self-supervised contrastive learning made simple',
   author = 'Phil Wang',
   author_email = '[email protected]',
   url = 'https://github.com/lucidrains/byol-pytorch',
+  long_description_content_type = 'text/markdown',
   keywords = [
     'self-supervised learning',
     'artificial intelligence'
