Skip to content

Commit

Permalink
fix issue with hidden representation coming from last layer
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Jun 17, 2020
1 parent 4d27acd commit 85997ee
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 4 deletions.
9 changes: 6 additions & 3 deletions byol_pytorch/byol_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,17 +122,20 @@ def _get_projector(self, hidden):
projector = MLP(dim, self.projection_size, self.projection_hidden_size)
return projector.to(hidden)

def forward(self, x):
def get_representation(self, x):
if self.layer == -1:
return self.net(x)

_ = self.net(x)
hidden = self.hidden
self.hidden = None
assert hidden is not None, f'hidden layer {self.layer} never emitted an output'
return hidden

projector = self._get_projector(hidden)
projection = projector(hidden)
def forward(self, x):
representation = self.get_representation(x)
projector = self._get_projector(representation)
projection = projector(representation)
return projection

# main class
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'byol-pytorch',
packages = find_packages(),
version = '0.0.3',
version = '0.0.5',
license='MIT',
description = 'Self-supervised contrastive learning made simple',
author = 'Phil Wang',
Expand Down

0 comments on commit 85997ee

Please sign in to comment.