-
Notifications
You must be signed in to change notification settings - Fork 187
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
FroSSL Implementation #390
Open
OFSkean
wants to merge
5
commits into
vturrisi:main
Choose a base branch
from
OFSkean:frossl
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
5 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
defaults: | ||
- _self_ | ||
- augmentations: symmetric.yaml | ||
- wandb: private.yaml | ||
- override hydra/hydra_logging: disabled | ||
- override hydra/job_logging: disabled | ||
|
||
# disable hydra outputs | ||
hydra: | ||
output_subdir: null | ||
run: | ||
dir: . | ||
|
||
name: "frossl-cifar10" # change here for cifar100 | ||
method: "frossl" | ||
backbone: | ||
name: "resnet18" | ||
method_kwargs: | ||
proj_hidden_dim: 2048 | ||
proj_output_dim: 1024 | ||
invariance_weight: 1.4 | ||
|
||
data: | ||
dataset: cifar10 # change here for cifar100 | ||
train_path: "./datasets" | ||
val_path: "./datasets" | ||
format: "image_folder" | ||
num_workers: 8 | ||
optimizer: | ||
name: "lars" | ||
batch_size: 256 | ||
lr: 0.3 | ||
classifier_lr: 0.1 | ||
weight_decay: 1e-4 | ||
kwargs: | ||
clip_lr: True | ||
eta: 0.02 | ||
exclude_bias_n_norm: True | ||
scheduler: | ||
name: "warmup_cosine" | ||
checkpoint: | ||
enabled: True | ||
dir: "trained_models" | ||
frequency: 1 | ||
auto_resume: | ||
enabled: True | ||
|
||
# overwrite PL stuff | ||
max_epochs: 1000 | ||
devices: [0] | ||
sync_batchnorm: True | ||
accelerator: "gpu" | ||
strategy: "ddp" | ||
precision: 16-mixed |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
defaults: | ||
- _self_ | ||
- augmentations: symmetric.yaml | ||
- wandb: private.yaml | ||
- override hydra/hydra_logging: disabled | ||
- override hydra/job_logging: disabled | ||
|
||
# disable hydra outputs | ||
hydra: | ||
output_subdir: null | ||
run: | ||
dir: . | ||
|
||
name: "frossl-imagenet100" | ||
method: "frossl" | ||
backbone: | ||
name: "resnet18" | ||
method_kwargs: | ||
proj_hidden_dim: 2048 | ||
proj_output_dim: 1024 | ||
invariance_weight: 2.0 | ||
data: | ||
dataset: imagenet100 | ||
train_path: "./datasets/imagenet100/train" | ||
val_path: "./datasets/imagenet100/val" | ||
format: "dali" | ||
num_workers: 16 | ||
optimizer: | ||
name: "lars" | ||
batch_size: 256 | ||
lr: 0.3 | ||
classifier_lr: 0.1 | ||
weight_decay: 1e-4 | ||
kwargs: | ||
clip_lr: True | ||
eta: 0.02 | ||
exclude_bias_n_norm: True | ||
scheduler: | ||
name: "warmup_cosine" | ||
checkpoint: | ||
enabled: True | ||
dir: "trained_models" | ||
frequency: 1 | ||
auto_resume: | ||
enabled: False | ||
|
||
# overwrite PL stuff | ||
max_epochs: 400 | ||
devices: [0, 1] | ||
sync_batchnorm: True | ||
accelerator: "gpu" | ||
strategy: "ddp" | ||
precision: 16-mixed |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
defaults: | ||
- _self_ | ||
- augmentations: vicreg.yaml | ||
- wandb: private.yaml | ||
- override hydra/hydra_logging: disabled | ||
- override hydra/job_logging: disabled | ||
|
||
# disable hydra outputs | ||
hydra: | ||
output_subdir: null | ||
run: | ||
dir: . | ||
|
||
name: "frossl-imagenet" | ||
method: "frossl" | ||
backbone: | ||
name: "resnet18" | ||
method_kwargs: | ||
proj_hidden_dim: 2048 | ||
proj_output_dim: 1024 | ||
invariance_weight: 2.0 | ||
|
||
data: | ||
dataset: imagenet | ||
train_path: "./datasets/imagenet/train" | ||
val_path: "./datasets/imagenet/val" | ||
format: "dali" | ||
num_workers: 8 | ||
optimizer: | ||
name: "lars" | ||
batch_size: 256 | ||
lr: 0.3 | ||
classifier_lr: 0.1 | ||
weight_decay: 1e-4 | ||
kwargs: | ||
clip_lr: True | ||
eta: 0.02 | ||
exclude_bias_n_norm: True | ||
scheduler: | ||
name: "warmup_cosine" | ||
checkpoint: | ||
enabled: True | ||
dir: "trained_models" | ||
frequency: 1 | ||
auto_resume: | ||
enabled: True | ||
|
||
# overwrite PL stuff | ||
max_epochs: 100 | ||
devices: [0, 1] | ||
sync_batchnorm: True | ||
accelerator: "gpu" | ||
strategy: "ddp" | ||
precision: 16-mixed |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
# Copyright 2024 solo-learn development team. | ||
|
||
# Permission is hereby granted, free of charge, to any person obtaining a copy of | ||
# this software and associated documentation files (the "Software"), to deal in | ||
# the Software without restriction, including without limitation the rights to use, | ||
# copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the | ||
# Software, and to permit persons to whom the Software is furnished to do so, | ||
# subject to the following conditions: | ||
|
||
# The above copyright notice and this permission notice shall be included in all copies | ||
# or substantial portions of the Software. | ||
|
||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, | ||
# INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR | ||
# PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE | ||
# FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR | ||
# OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER | ||
# DEALINGS IN THE SOFTWARE. | ||
|
||
from typing import Any, List, Sequence, Dict | ||
import torch | ||
import torch.distributed as dist | ||
import torch.nn.functional as F | ||
|
||
def calculate_frobenius_regularization_term(z: torch.Tensor) -> torch.Tensor: | ||
V, N, D = z.shape | ||
|
||
if N > D: | ||
cov = torch.matmul(z.transpose(1, 2), z) # V x D x D | ||
else: | ||
cov = torch.matmul(z, z.transpose(1, 2)) # V x N x N | ||
|
||
# divide each view covariance by its trace | ||
trace = torch.diagonal(cov, dim1=1, dim2=2) # V x D | ||
trace = torch.sum(trace, dim=1) # V x 1 | ||
cov = cov / trace.unsqueeze(-1).unsqueeze(-1) | ||
|
||
# REGULARIZATION TERM - sum the log-frobenius norm of each view covariance matrix | ||
fro_norm_per_view = torch.linalg.norm(cov, dim=(1,2), ord='fro') # V x 1 | ||
regularization_term = -torch.sum( 2*torch.log(fro_norm_per_view) ) # we bring frobenius square outside log | ||
|
||
return regularization_term | ||
|
||
def calculate_invariance_term(z: torch.Tensor) -> torch.Tensor: | ||
V, N, D = z.shape | ||
|
||
# INVARIANCE - align each view to the average view | ||
average_z = torch.mean(z, dim=0) # N x D, samples are averaged across views | ||
average_z = average_z.repeat(V, 1, 1) # V x N x D | ||
invariance_loss_term = F.mse_loss(z, average_z) | ||
|
||
return invariance_loss_term | ||
|
||
def frossl_loss_func( | ||
z: torch.Tensor, invariance_weight=1, logger=None | ||
) -> torch.Tensor: | ||
""" | ||
Implements FroSSL (https://arxiv.org/pdf/2310.02903) | ||
Heavily adapted from https://github.com/OFSkean/FroSSL. The main difference is that this | ||
implementation stacks the views and operates on all of them at once, rather than one at a time. | ||
This saves ~2 seconds (about 5% improvement) per batch with N=2,D=1024 on a A5000 GPU. For a simpler, | ||
ableit slower, implementation of loss that operates on one view at a time, please see | ||
the original implementation. | ||
|
||
Args: | ||
z (torch.Tensor): V x N x D Tensor containing projected features from the views. | ||
Every N-th sample is a different view of the same image. | ||
invariance_weight (float): weight for the invariance loss term. default is 1. | ||
|
||
Return: | ||
torch.Tensor: FroSSL loss. | ||
""" | ||
V, N, D = z.shape | ||
|
||
z = F.normalize(z, dim=1) # V x N x D | ||
|
||
regularization_term = calculate_frobenius_regularization_term(z) | ||
regularization_term = -1 * regularization_term # make sure its maximized | ||
|
||
invariance_tradeoff = V * D * invariance_weight | ||
invariance_term = calculate_invariance_term(z) | ||
invariance_term = invariance_tradeoff * invariance_term | ||
|
||
if logger is not None: | ||
logger("frossl_regularization_loss", -regularization_term, sync_dist=True) | ||
logger("frossl_invariance_loss", invariance_term, sync_dist=True) | ||
|
||
total_loss = regularization_term + invariance_term | ||
return total_loss |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this supposed to be logged as
- regularization_term
? I would also say that for consistency with the other loss functions, it's better if you return the individual elements and log them on the method's side. You could return the two terms and sum them in the method.