From 6717204748c2a4f4f44b991d4c59ce5b99995582 Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Tue, 5 Apr 2022 17:10:18 -0700
Subject: [PATCH] fix simsiam, thanks to @chingisooinar

---
 byol_pytorch/byol_pytorch.py | 40 ++++++++++++++++++++++--------------
 setup.py                     |  7 +++++--
 2 files changed, 30 insertions(+), 17 deletions(-)

diff --git a/byol_pytorch/byol_pytorch.py b/byol_pytorch/byol_pytorch.py
index 055e4ca6f..05e85b976 100644
--- a/byol_pytorch/byol_pytorch.py
+++ b/byol_pytorch/byol_pytorch.py
@@ -75,25 +75,32 @@ def update_moving_average(ema_updater, ma_model, current_model):
 
 # MLP class for projector and predictor
 
-class MLP(nn.Module):
-    def __init__(self, dim, projection_size, hidden_size = 4096):
-        super().__init__()
-        self.net = nn.Sequential(
-            nn.Linear(dim, hidden_size),
-            nn.BatchNorm1d(hidden_size),
-            nn.ReLU(inplace=True),
-            nn.Linear(hidden_size, projection_size)
-        )
-
-    def forward(self, x):
-        return self.net(x)
+def MLP(dim, projection_size, hidden_size=4096):
+    return nn.Sequential(
+        nn.Linear(dim, hidden_size),
+        nn.BatchNorm1d(hidden_size),
+        nn.ReLU(inplace=True),
+        nn.Linear(hidden_size, projection_size)
+    )
+
+def SimSiamMLP(dim, projection_size, hidden_size=4096):
+    return nn.Sequential(
+        nn.Linear(dim, hidden_size, bias=False),
+        nn.BatchNorm1d(hidden_size),
+        nn.ReLU(inplace=True),
+        nn.Linear(hidden_size, hidden_size, bias=False),
+        nn.BatchNorm1d(hidden_size),
+        nn.ReLU(inplace=True),
+        nn.Linear(hidden_size, projection_size, bias=False),
+        nn.BatchNorm1d(projection_size, affine=False)
+    )
 
 # a wrapper class for the base neural network
 # will manage the interception of the hidden layer output
 # and pipe it into the projecter and predictor nets
 
 class NetWrapper(nn.Module):
-    def __init__(self, net, projection_size, projection_hidden_size, layer = -2):
+    def __init__(self, net, projection_size, projection_hidden_size, layer = -2, use_simsiam_mlp = False):
         super().__init__()
         self.net = net
         self.layer = layer
@@ -102,6 +109,8 @@ def __init__(self, net, projection_size, projection_hidden_size, layer = -2):
         self.projection_size = projection_size
         self.projection_hidden_size = projection_hidden_size
 
+        self.use_simsiam_mlp = use_simsiam_mlp
+
         self.hidden = {}
         self.hook_registered = False
 
@@ -127,7 +136,8 @@ def _register_hook(self):
     @singleton('projector')
     def _get_projector(self, hidden):
         _, dim = hidden.shape
-        projector = MLP(dim, self.projection_size, self.projection_hidden_size)
+        create_mlp_fn = MLP if not self.use_simsiam_mlp else SimSiamMLP
+        projector = create_mlp_fn(dim, self.projection_size, self.projection_hidden_size)
         return projector.to(hidden)
 
     def get_representation(self, x):
@@ -195,7 +205,7 @@ def __init__(
         self.augment1 = default(augment_fn, DEFAULT_AUG)
         self.augment2 = default(augment_fn2, self.augment1)
 
-        self.online_encoder = NetWrapper(net, projection_size, projection_hidden_size, layer=hidden_layer)
+        self.online_encoder = NetWrapper(net, projection_size, projection_hidden_size, layer=hidden_layer, use_simsiam_mlp=not use_momentum)
 
         self.use_momentum = use_momentum
         self.target_encoder = None
diff --git a/setup.py b/setup.py
index cbbeb3d8b..59303a342 100644
--- a/setup.py
+++ b/setup.py
@@ -3,13 +3,16 @@
 setup(
   name = 'byol-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '0.5.7',
+  version = '0.6.0',
   license='MIT',
   description = 'Self-supervised contrastive learning made simple',
   author = 'Phil Wang',
   author_email = 'lucidrains@gmail.com',
   url = 'https://github.com/lucidrains/byol-pytorch',
-  keywords = ['self-supervised learning', 'artificial intelligence'],
+  keywords = [
+    'self-supervised learning',
+    'artificial intelligence'
+  ],
  install_requires=[
    'torch>=1.6',
    'torchvision>=0.8'
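
Usage sketch (not part of the patch): with this change, constructing the BYOL wrapper with use_momentum = False routes NetWrapper through the new SimSiamMLP projector instead of the BYOL-style MLP, and no EMA target update is needed. The constructor keywords below (image_size, hidden_layer, use_momentum) follow this repo's README; treat the batch size and learning rate as placeholder values.

    # minimal SimSiam-mode training step, assuming byol-pytorch >= 0.6.0
    import torch
    from torchvision import models
    from byol_pytorch import BYOL

    resnet = models.resnet50(pretrained=True)

    learner = BYOL(
        resnet,
        image_size = 256,
        hidden_layer = 'avgpool',
        use_momentum = False        # SimSiam: stop-gradient only, no EMA target encoder
    )

    opt = torch.optim.Adam(learner.parameters(), lr = 3e-4)

    images = torch.randn(4, 3, 256, 256)   # stand-in for a real image batch
    loss = learner(images)
    opt.zero_grad()
    loss.backward()
    opt.step()
    # learner.update_moving_average() is only required when use_momentum = True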