Commit
fix simsiam, thanks to @chingisooinar
lucidrains committed Apr 6, 2022
1 parent caa65d7 commit 6717204
Showing 2 changed files with 30 additions and 17 deletions.
40 changes: 25 additions & 15 deletions byol_pytorch/byol_pytorch.py
@@ -75,25 +75,32 @@ def update_moving_average(ema_updater, ma_model, current_model):
 
 # MLP class for projector and predictor
 
-class MLP(nn.Module):
-    def __init__(self, dim, projection_size, hidden_size = 4096):
-        super().__init__()
-        self.net = nn.Sequential(
-            nn.Linear(dim, hidden_size),
-            nn.BatchNorm1d(hidden_size),
-            nn.ReLU(inplace=True),
-            nn.Linear(hidden_size, projection_size)
-        )
-
-    def forward(self, x):
-        return self.net(x)
+def MLP(dim, projection_size, hidden_size=4096):
+    return nn.Sequential(
+        nn.Linear(dim, hidden_size),
+        nn.BatchNorm1d(hidden_size),
+        nn.ReLU(inplace=True),
+        nn.Linear(hidden_size, projection_size)
+    )
+
+def SimSiamMLP(dim, projection_size, hidden_size=4096):
+    return nn.Sequential(
+        nn.Linear(dim, hidden_size, bias=False),
+        nn.BatchNorm1d(hidden_size),
+        nn.ReLU(inplace=True),
+        nn.Linear(hidden_size, hidden_size, bias=False),
+        nn.BatchNorm1d(hidden_size),
+        nn.ReLU(inplace=True),
+        nn.Linear(hidden_size, projection_size, bias=False),
+        nn.BatchNorm1d(projection_size, affine=False)
+    )
 
 # a wrapper class for the base neural network
 # will manage the interception of the hidden layer output
 # and pipe it into the projecter and predictor nets
 
 class NetWrapper(nn.Module):
-    def __init__(self, net, projection_size, projection_hidden_size, layer = -2):
+    def __init__(self, net, projection_size, projection_hidden_size, layer = -2, use_simsiam_mlp = False):
         super().__init__()
         self.net = net
         self.layer = layer
@@ -102,6 +109,8 @@ def __init__(self, net, projection_size, projection_hidden_size, layer = -2):
         self.projection_size = projection_size
         self.projection_hidden_size = projection_hidden_size
 
+        self.use_simsiam_mlp = use_simsiam_mlp
+
         self.hidden = {}
         self.hook_registered = False
@@ -127,7 +136,8 @@ def _register_hook(self):
     @singleton('projector')
     def _get_projector(self, hidden):
         _, dim = hidden.shape
-        projector = MLP(dim, self.projection_size, self.projection_hidden_size)
+        create_mlp_fn = MLP if not self.use_simsiam_mlp else SimSiamMLP
+        projector = create_mlp_fn(dim, self.projection_size, self.projection_hidden_size)
         return projector.to(hidden)
 
     def get_representation(self, x):
@@ -195,7 +205,7 @@ def __init__(
         self.augment1 = default(augment_fn, DEFAULT_AUG)
         self.augment2 = default(augment_fn2, self.augment1)
 
-        self.online_encoder = NetWrapper(net, projection_size, projection_hidden_size, layer=hidden_layer)
+        self.online_encoder = NetWrapper(net, projection_size, projection_hidden_size, layer=hidden_layer, use_simsiam_mlp=not use_momentum)
 
         self.use_momentum = use_momentum
         self.target_encoder = None
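With this change, passing use_momentum = False to the BYOL wrapper makes the online encoder build the SimSiam-style projector (bias-free linear layers ending in a non-affine BatchNorm) instead of the BYOL MLP. A minimal sketch of that path, following the usage shown in the repository README; the ResNet-50 backbone, image size, batch size, and training loop below are illustrative assumptions, not part of this commit:

import torch
from torchvision import models
from byol_pytorch import BYOL

resnet = models.resnet50(pretrained=True)

learner = BYOL(
    resnet,
    image_size = 256,
    hidden_layer = 'avgpool',
    use_momentum = False  # SimSiam mode: NetWrapper now builds the SimSiamMLP projector
)

opt = torch.optim.Adam(learner.parameters(), lr=3e-4)

images = torch.randn(8, 3, 256, 256)  # toy batch; BatchNorm1d requires batch size > 1
loss = learner(images)
opt.zero_grad()
loss.backward()
opt.step()
# no learner.update_moving_average() call is needed when use_momentum = False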
7 changes: 5 additions & 2 deletions setup.py
@@ -3,13 +3,16 @@
 setup(
   name = 'byol-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '0.5.7',
+  version = '0.6.0',
   license='MIT',
   description = 'Self-supervised contrastive learning made simple',
   author = 'Phil Wang',
   author_email = '[email protected]',
   url = 'https://github.com/lucidrains/byol-pytorch',
-  keywords = ['self-supervised learning', 'artificial intelligence'],
+  keywords = [
+    'self-supervised learning',
+    'artificial intelligence'
+  ],
   install_requires=[
     'torch>=1.6',
     'torchvision>=0.8'
