Skip to content

Commit

Permalink
default to returning projections, but can be turned off with return_p…
Browse files Browse the repository at this point in the history
…rojection = False on forward
  • Loading branch information
lucidrains committed Apr 12, 2021
1 parent 8c08cef commit 2aa84ee
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 5 deletions.
13 changes: 9 additions & 4 deletions byol_pytorch/byol_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,10 +145,10 @@ def get_representation(self, x):
assert hidden is not None, f'hidden layer {self.layer} never emitted an output'
return hidden

def forward(self, x, return_embedding = False):
def forward(self, x, return_projection = True):
representation = self.get_representation(x)

if return_embedding:
if not return_projection:
return representation

projector = self._get_projector(representation)
Expand Down Expand Up @@ -225,9 +225,14 @@ def update_moving_average(self):
assert self.target_encoder is not None, 'target encoder has not been created yet'
update_moving_average(self.target_ema_updater, self.target_encoder, self.online_encoder)

def forward(self, x, return_embedding = False):
def forward(
self,
x,
return_embedding = False,
return_projection = True
):
if return_embedding:
return self.online_encoder(x, True)
return self.online_encoder(x, return_projection = return_projection)

image_one, image_two = self.augment1(x), self.augment2(x)

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'byol-pytorch',
packages = find_packages(exclude=['examples']),
version = '0.5.5',
version = '0.5.6',
license='MIT',
description = 'Self-supervised contrastive learning made simple',
author = 'Phil Wang',
Expand Down

0 comments on commit 2aa84ee

Please sign in to comment.