Skip to content

Commit

Permalink
concat augmented images and only call online encoder once
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Nov 16, 2023
1 parent 4258651 commit 40565b7
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 9 deletions.
19 changes: 11 additions & 8 deletions byol_pytorch/byol_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,18 +262,21 @@ def forward(

image_one, image_two = self.augment1(x), self.augment2(x)

online_proj_one, _ = self.online_encoder(image_one)
online_proj_two, _ = self.online_encoder(image_two)
images = torch.cat((image_one, image_two), dim = 0)

online_pred_one = self.online_predictor(online_proj_one)
online_pred_two = self.online_predictor(online_proj_two)
online_projections, _ = self.online_encoder(images)
online_predictions = self.online_predictor(online_projections)

online_pred_one, online_pred_two = online_predictions.chunk(2, dim = 0)
online_proj_one, online_proj_two = online_projections.chunk(2, dim = 0)

with torch.no_grad():
target_encoder = self._get_target_encoder() if self.use_momentum else self.online_encoder
target_proj_one, _ = target_encoder(image_one)
target_proj_two, _ = target_encoder(image_two)
target_proj_one = target_proj_one.detach()
target_proj_two = target_proj_two.detach()

target_projections, _ = target_encoder(images)
target_projections = target_projections.detach()

target_proj_one, target_proj_two = target_projections.chunk(2, dim = 0)

loss_one = loss_fn(online_pred_one, target_proj_two.detach())
loss_two = loss_fn(online_pred_two, target_proj_one.detach())
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'byol-pytorch',
packages = find_packages(exclude=['examples']),
version = '0.7.1',
version = '0.7.2',
license='MIT',
description = 'Self-supervised contrastive learning made simple',
author = 'Phil Wang',
Expand Down

0 comments on commit 40565b7

Please sign in to comment.